先読みを用いたLLMの文章生成の高速化

こんにちは、イノベーションセンターの加藤です。普段はコンピュータビジョンの技術開発やAI/機械学習(ML: Machine Learning)システムの検証に取り組んでいます。一方で、兼務1で大規模言語モデル(LLM: Large Language Model)について調査を行なっており、特にLLMの推論や学習の高速化に関心を持っています。

今回は、小さな言語モデルによる先読みを活用してLLMの文章生成を高速化する手法(Assisted Generation2, Speculative Sampling3などと呼ばれています)についてご紹介します。 LLMの推論は計算コストが高く、文章生成の遅さが課題としてよく挙げられています。特に日本語はトークンあたりの文字数が少なく、ChatGPTのようなストリーム出力でもかなり生成が遅く感じるかと思います。

これに対して、いくらか余分にメモリを利用して、元々のLLMと全く同じものをより高速に出力できるAssisted Generationが提案されています。AI開発・機械学習のためのプラットフォームを運営するHugging Faceが公開している機械学習ライブラリであるTransformersでも実装されています。この記事ではその手法についてまとめ、より発展的な手法とともにHugging Faceのモデルで実験しました。

LLMの文章生成メカニズム

文章生成に使われるLLMの多くはCausal Language Modelで作られています。Causal Language Modelとは単語列(厳密にはトークン列)が与えられた時にのちに続く単語(トークン)を予測するモデルで、与えられたトークン列  (t_1, t_2, \cdots, t_n) に対し、各  t_i の次に来るべきトークン y_{i+1} の分布  p(y_{i+1}|t_1,\cdots,t_i) を出力します。 以下の例では、入力文"The man worked as a"の各トークンに対して、facebook/opt-125mモデルを用いて最も次に来そうなトークンを予測した結果を図示しています。一度の推論で全トークンが並列に予測できていることや、Causal Language Modelでは後ろの入力トークンが前の予測トークンに影響しないことが今回紹介する技術のポイントです。

これを用いると1回の推論で新しいトークンを1つ(ここではsecurity)生成できるため、何度も繰り返すことで文章を生成します。 これをAutoregressive Generationと呼び、Transformersでは generate() という関数が提供されています。特にこの例のように最も確率の高いものを次のトークンに選択する方法をGreedy Decodingと呼びます。 (結局最終トークンの予測 "security" しか使い道が無いように見えますが、残りのトークンの予測も後述の手法で活用できます。)

import torch
import transformers
from transformers import AutoModelForCausalLM, AutoTokenizer
print(torch.__version__)  # 2.1.0+cu121
print(transformers.__version__)  # 4.34.1

model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m").cuda()
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m", use_fast=False)
prompt = "The man worked as a"

input_ids = tokenizer(prompt, return_tensors="pt").input_ids.cuda()
print(tokenizer.decode(model.generate(input_ids, max_new_tokens=10)[0], skip_special_tokens=True))
# The man worked as a security guard at a hotel in the city of K

また、LLMが出力した確率に沿って出力をサンプリング sample() することで、多様性のある文章を生成することもできます。

print(tokenizer.decode(model.generate(input_ids, max_new_tokens=10, do_sample=True)[0], skip_special_tokens=True))
print(tokenizer.decode(model.generate(input_ids, max_new_tokens=10, do_sample=True)[0], skip_special_tokens=True))
# The man worked as a clerk to a bakery. He worked for a bakery
# The man worked as a janitor on a cleaning staff in our house.

モデルの推論速度とAssisted Generation

最近公開されているLLMは10億~1000億規模のパラメータを持っており、推論に非常に時間がかかることが知られています。 A100 GPUでfacebook/optモデルを動作させ、ランダムな文章(C4データセットの冒頭部を利用)の続きを推論させた場合は以下の図のようになり、13Bモデルは最軽量の125Mモデルの30倍の生成時間がかかります。13B, 125Mはモデルのパラメータ数を表しており、それぞれ13 billion、125 millionを示しています。

そこで1回の推論で各位置の次トークンを並列に予測できるという性質を利用し、小さなモデルでいくつか「先読み」してからLLMでそれを検証するAssisted Generationという手法で高速に文章を生成できます。以下の節ではその手法を説明します。

generate()の高速化

まず小さなモデル(以降ではドラフトモデルと呼びます)で新しい単語を複数生成します。この時使用するモデルはLLMよりも軽量でかつ同じトークンIDを使っている必要があります。以下の例では1.3Bモデルに対し、同じトークンIDを使っていてより軽量な125Mモデルをドラフトモデルにしています。

from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
import torch

assist = AutoModelForCausalLM.from_pretrained("facebook/opt-125m").cuda()
model = AutoModelForCausalLM.from_pretrained("facebook/opt-1.3B", torch_dtype=torch.float16).cuda()
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-1.3B", use_fast=False)

prompt = "The man worked as a"

input_ids = tokenizer(prompt, return_tensors="pt").input_ids.cuda()
initial_len = input_ids.size(1)
for i in range(10):
    with torch.no_grad():
        y = assist(input_ids)
        next_id = y.logits[:, -1, :].argmax(-1, keepdims=True)
        input_ids = torch.cat([input_ids, next_id], dim=-1)

# 入力文章 + [ドラフトモデルが生成した文章]
print(prompt + f"[{tokenizer.decode(input_ids[0,initial_len:], skip_special_tokens=True)}]")
# The man worked as a[ security guard at a hotel in the city of K]

次に生成した文章 "The man worked as a security guard at a hotel..." をLLMに入れ、各位置の次トークンを予測します。

with torch.no_grad():
    y = model(input_ids)
    next_ids = y.logits.argmax(-1)

for i in range(10):
    input_words = tokenizer.decode(input_ids[0, :initial_len+i], skip_special_tokens=True)
    next_words = tokenizer.decode(next_ids[0, initial_len+i-1:initial_len+i], skip_special_tokens=True)
    print(f"model[{i}]: {input_words}[{next_words}]")
    if next_ids[0, initial_len+i-1] != input_ids[0, initial_len+i]:
        assist_input = tokenizer.decode(input_ids[0, :initial_len+i], skip_special_tokens=True)
        assist_word = tokenizer.decode(input_ids[0, initial_len+i:initial_len+i+1], skip_special_tokens=True)
        print(f"assistant: {assist_input}[{assist_word}]")
        break

"""
model[0]: The man worked as a[ security]
model[1]: The man worked as a security[ guard]
model[2]: The man worked as a security guard[ at]
model[3]: The man worked as a security guard at[ the]
assistant: The man worked as a security guard at[ a]
"""

LLMの予測トークンが "security guard at" まではドラフトモデルの出力と一致していることがわかります。そこでこの一致した分とその次のLLMの予測 "the" を合わせた "security guard at the" をAssisted Generationの出力として一気に採用してしまいます。 Greedy Decodingの性質上、LLMで初期プロンプトから1つずつトークンを生成しても、初めて予測がずれたこの"the"まで一致することが以下のように確かめられます。

prompt = "The man worked as a"

input_ids = tokenizer(prompt, return_tensors="pt").input_ids.cuda()
initial_len = input_ids.size(1)
for i in range(10):
    with torch.no_grad():
        model_inputs = model.prepare_inputs_for_generation(input_ids)
        y = model(**model_inputs)
        next_id = y.logits[:, -1, :].argmax(-1, keepdims=True)
        input_ids = torch.cat([input_ids, next_id], dim=-1)
print(tokenizer.decode(input_ids[0], skip_special_tokens=True))
# The man worked as a security guard at the University of California, Berkeley,
# "The man worked as a security guard at the" が model[3] と一致している

つまり1回のLLM推論と10回のドラフトモデル推論で4トークン生成できました。これを繰り返すことで巨大なモデルの推論回数を節約しながらそのモデルの出力を完全に再現できます。これはTransformersでも assistant_model を引数に指定することで利用できます。

facebook/optのさまざまなパラメータ数のモデルに対し125Mモデル(facebook/opt-125m)を用いてAssisted Generationを行った際の実行時間は以下のようになります。

モデル規模の差が大きければ大きいほどAssisted Generationの恩恵を受けられることが分かります。

sample()の高速化

サンプリングにドラフトモデルを活用する際は、先読みで得られた確率分布を用いて棄却サンプリングを行い、LLMの出力する確率分布を再現します。棄却サンプリングとは、確率分布が非常に複雑でそこからのサンプリングが困難である時に使われる手法で、目的の確率分布  p(x) に対してサンプリングしやすい確率分布  q(x) を用意して次のように行います。

  1. (前準備) 任意の  x に対し  k \geq p(x) / q(x) が成り立つような定数  k を用意する。(用意できない場合は棄却サンプリングを使えない。すなわち  p(x) が定義されている  x では常に  q(x) も定義されている必要があり、さらに確率比  p(x) / q(x) は有界である必要がある。)
  2.  q(x) からサンプリングする  x \sim q(x)
  3. 一様乱数からのサンプリング  r \sim \mathrm{Unif}[0, 1] が  r \lt p(x)/(kq(x)) を満たせば  x p(x) からのサンプリング結果として採用、そうでなければ棄却し 2. からやり直す。

これにより p(x)から直接サンプリングすることなく、その確率分布に従う xを生成できます。この手法の弱点の1つに、用意した  k があまりに大きいと棄却の割合が増えサンプリングの効率が落ちることが挙げられます。 ただし一般的に使われている棄却サンプリングの前提とは異なり、LLMの確率分布からのサンプリングも先読みが終わる度に計算できるため、次のような工夫がなされています。

  1. 入力トークン  (x_1, \cdots x_t) に対し、ドラフトモデルで  K 個先までの確率分布  p_{t+1}, \cdots, p_{t+K} を先読みする
  2. 先読みしたトークンを用いてLLMの確率分布  q_{t+1},\cdots,q_{t+K} を計算する(一度の推論で可能)
  3.  i=1,\cdots, K に対して、一様乱数  r\sim\mathrm{Unif}[0,1] が  r \lt \min(1, q_{t+i}/p_{t+i}) を満たせば採用、そうでなければ棄却する
  4. 初めて棄却された  i のトークンを  \max(0, q_{t+i} - p_{t+i}) を正規化した分布からサンプリングする
  5.  t iを加え, 1.に戻る

この手法を用いることによって、1ループあたり1単語を必ず生成でき、また棄却サンプリングの仮定である p, qの定義域や確率比の上界などを気にする必要がなくなります。

実装のソースコードはこちら

from typing import Dict
import torch
from transformers import TemperatureLogitsWarper, LogitsProcessorList
from transformers.generation.stopping_criteria import StoppingCriteriaList

class AssistMixin:
    def draft(
        self,
        eos_token_id: int,
        input_ids: torch.LongTensor,
        max_assistant_tokens: int,
        do_sample: bool,
    ) -> torch.LongTensor:
        draft_ids = input_ids
        self.cache["assistant_prob_list"] = []
        for idx in range(max_assistant_tokens):
            if "assistant_past_key_values" in self.cache:
                prev_seq_len = self.cache["assistant_past_key_values"][0][0].shape[-2]
                # `new_token_len` can be 1 or 2 (next token in assistant + last token picked by the larger model)
                new_token_len = draft_ids.shape[1] - prev_seq_len
                assist_inputs = draft_ids[:, -new_token_len:]
                assist_attn = torch.ones_like(draft_ids)
                assistant_model_outputs = self.assist_model(
                    assist_inputs,
                    attention_mask=assist_attn,
                    past_key_values=self.cache["assistant_past_key_values"])
            else:
                assistant_model_outputs = self.assist_model(draft_ids)
            self.cache["assistant_past_key_values"] = assistant_model_outputs.past_key_values
            assist_new_logits = assistant_model_outputs.logits[:, -1, :]
            assist_new_logits = self.logits_processor(draft_ids, assist_new_logits)
            assist_new_logits = self.logits_warper(draft_ids, assist_new_logits)
            if do_sample:
                assist_new_probs = assist_new_logits.softmax(-1)
                self.cache["assistant_prob_list"].append(assist_new_probs)
                new_token = torch.multinomial(assist_new_probs, num_samples=1).squeeze(1)
            else:
                new_token = assist_new_logits.argmax(-1)
            draft_ids = torch.cat((draft_ids, new_token[:, None]), dim=-1)
            if new_token[0] == eos_token_id:
                break
        if do_sample:
            self.cache["assistant_prob_list"] = torch.stack(self.cache["assistant_prob_list"], dim=1)
        return draft_ids

    def verify(
        self,
        eos_token_id: int,
        input_ids: torch.LongTensor,
        candidate_input_ids: torch.LongTensor,
        max_len: int,
        do_sample: bool,
    ) -> torch.LongTensor:
        candidate_length = candidate_input_ids.shape[1] - input_ids.shape[1]
        cur_len = input_ids.shape[1]
        if "past_key_values" in self.cache:
            model_attn = torch.ones_like(candidate_input_ids)
            model_input_ids = candidate_input_ids[:, -candidate_length-1:]
            outputs = self.model(
                model_input_ids,
                attention_mask=model_attn,
                past_key_values=self.cache["past_key_values"],
            )
        else:
            outputs = self.model(candidate_input_ids)
        logits = outputs.logits
        for i in range(candidate_length):
            logits[:, i, :] = self.logits_processor(candidate_input_ids[:, :cur_len + i], logits[:, i, :])
        for i in range(candidate_length):
            logits[:, i, :] = self.logits_warper(candidate_input_ids[:, :cur_len + i], logits[:, i, :])
        
        speculative_ids = candidate_input_ids[:, -candidate_length:]
        if do_sample:
            probs = logits[:, -candidate_length-1:, :].float().softmax(-1)
            speculative_probs = self.cache["assistant_prob_list"].gather(dim=-1, index=speculative_ids[:,:,None]).squeeze(-1)
            speculative_actual_probs = probs[:, :-1, :].gather(dim=-1, index=speculative_ids[:,:,None]).squeeze(-1)
            resample_probs = probs.clone()
            resample_probs[:, :-1, :] = torch.clamp(resample_probs[:, :-1, :] - self.cache["assistant_prob_list"].float(), min=0)
            resample_probs /= resample_probs.sum(-1, keepdim=True)
            acceptance_thresholds = speculative_actual_probs / speculative_probs
            unif = torch.rand_like(acceptance_thresholds)
            n_matches = ((~(unif <= acceptance_thresholds)).cumsum(dim=-1) < 1).sum()
        else:
            selected_tokens = logits[:, -candidate_length-1:, :].argmax(-1)
            n_matches = ((~(speculative_ids == selected_tokens[:,:-1])).cumsum(dim=-1) < 1).sum().cpu().item()
        
        n_matches = min(max_len - cur_len, n_matches)
        self.cache["matches"].append(n_matches)
        self.cache["past_key_values"] = outputs.past_key_values
        input_ids = torch.cat((input_ids, speculative_ids[:, :n_matches]), dim=-1)
        
        if input_ids[0, -1] == eos_token_id or input_ids.shape[1] == max_len:
            # if EOS or max_len, STOP
            return input_ids

        # add one more token
        if do_sample:
            return torch.cat((input_ids, torch.multinomial(resample_probs[:, n_matches, :], num_samples=1)), dim=-1)
        else:
            return torch.cat((input_ids, selected_tokens[:, n_matches:n_matches+1]), dim=-1)

    def crop_cache(self, assist_input_ids, large_input_ids):
        # Discard past key values relative to unused assistant tokens
        self.cache["past_key_values"] = tuple([(
            kv[0][:, :, :large_input_ids.shape[1]-1, :],
            kv[1][:, :, :large_input_ids.shape[1]-1, :],
        ) for kv in self.cache["past_key_values"]])
        self.cache["assistant_past_key_values"] = tuple([(
            kv[0][:, :, :assist_input_ids.shape[1]-1, :],
            kv[1][:, :, :assist_input_ids.shape[1]-1, :],
        ) for kv in self.cache["assistant_past_key_values"]])


class SpecSampler(AssistMixin):
    def __init__(self, tokenizer, large_model, assist_model):
        self.tokenizer = tokenizer
        self.model = large_model
        self.assist_model = assist_model
    @torch.no_grad()
    def generate(self, input_ids: torch.LongTensor, max_new_len: int, temperature: float):
        max_len = input_ids.shape[1] + max_new_len
        self.cache = {}
        self.cache["matches"] = []
        self.max_assistant_tokens = 5
        self.logits_processor = LogitsProcessorList()
        self.logits_warper = LogitsProcessorList([TemperatureLogitsWarper(temperature)])
        while True:
            draft_ids = self.draft(self.tokenizer.eos_token_id, input_ids,
                                   max_assistant_tokens=self.max_assistant_tokens, do_sample=True)
            new_input_ids = self.verify(self.tokenizer.eos_token_id, input_ids, draft_ids,
                                    max_len=max_len, do_sample=True)
            n_matches = new_input_ids.shape[1] - input_ids.shape[1] - 1
            if n_matches == self.max_assistant_tokens:
                self.max_assistant_tokens += 2
            else:
                self.max_assistant_tokens = max(1, self.max_assistant_tokens - 1)
            input_ids = new_input_ids

            self.crop_cache(input_ids[:, :-1], input_ids)
            if input_ids.shape[1] >= max_len or input_ids[0, -1] == self.tokenizer.eos_token_id:
                break
        return input_ids

こちらもTransformersでは assistant_model を引数に指定することで利用できますが、バージョン4.34.1時点では若干効率が悪い実装になっています。 というのもTransformers実装では先読みの確率分布のSoftmax温度を常に0にしているのと等価なものになっており、LLMの出力する確率分布からかなり離れてしまっています。これを上で挙げたソースコードのようにLLM側で利用する温度と同一のものを利用するように実装し直すと、以下の図のようにほとんどの温度パラメータで効率化できます。特に高い温度においても、Transformers実装ではドラフトモデルなしでの生成よりも遅くなっているのに対し、私の実装では速くなっていることがわかります。

assistは独自実装、assist(hf)はTransformers実装、6.7B(float16)はドラフトモデル無しの速度、125mはドラフトモデルのみで生成した時の速度

トークナイザの互換性

トークナイザ(Tokenizer)は入力文章を言語モデルが扱えるようにトークンIDの列に変換するモジュールであり、モデルによってしばしば異なるトークナイザが使われています。 しかしAssisted Generationを使うにあたり、LLMとドラフトモデルのトークナイザは同じものを使う必要があります。これは内部ではトークンIDを用いてモデル間をやり取りしており、IDが共通であればトークン列をそのまま渡せて効率が良いためです。これに対し、軽量版モデルが存在しないなどでどうしても異なるトークナイザを使いたい場合は2つのトークナイザでトークン列を相互に変換することでAssisted Generationを正しく動かすことができます。ただしあまり高速化は期待できないかもしれません。以下の例ではllm-jp-13Bとopen-calm-small (160Mパラメータ)を用いて生成していますが、生成速度にあまり差が生まれませんでした。

実装のソースコードはこちら

import torch
import numpy as np
import transformers
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
from transformers import TemperatureLogitsWarper, LogitsProcessorList
import time

class JpCALM(AssistMixin):
    def __init__(self, large_tokenizer, assist_tokenizer, large_model, assist_model):
        self.large_tokenizer = large_tokenizer
        self.assist_tokenizer = assist_tokenizer
        self.model = large_model
        self.assist_model = assist_model
    def generate(self, text: str, max_len: int):
        assist_input_ids = self.assist_tokenizer.encode(text, add_special_tokens=False, return_tensors="pt").to(self.assist_model.device)
        large_input_ids = self.large_tokenizer.encode(text, add_special_tokens=False, return_tensors="pt").to(self.model.device)
        self.cache = {}
        self.cache["matches"] = []
        self.logits_processor = LogitsProcessorList()
        self.logits_warper = LogitsProcessorList()
        while True:
            # 5つ先読み
            assist_draft_ids = self.draft(self.assist_tokenizer.eos_token_id, assist_input_ids, max_assistant_tokens=5, do_sample=False)
            # 先読みしたassistantのtokenを変換
            candidate_words = self.assist_tokenizer.decode(assist_draft_ids[0, assist_input_ids.shape[1]:], skip_special_tokens=True)
            large_candidate_ids = self.large_tokenizer.encode(candidate_words, add_special_tokens=False, return_tensors="pt").to(self.model.device)[:,1:]
            large_draft_ids = torch.cat((large_input_ids, large_candidate_ids), dim=1)
            # 検証
            large_next_input_ids = self.verify(self.large_tokenizer.eos_token_id, large_input_ids, large_draft_ids, max_len, do_sample=False)
            # attentionのキャッシュを整理
            self.crop_cache(assist_input_ids, large_next_input_ids)
            # 検証したLLMのtokenを変換
            selected_tokens = large_next_input_ids[:, large_input_ids.shape[1]:]
            large_input_ids = large_next_input_ids
            large_input_words = self.large_tokenizer.decode(large_input_ids[0], skip_special_tokens=True)
            valid_words = self.large_tokenizer.decode(selected_tokens[0], skip_special_tokens=True)
            assist_valid_ids = self.assist_tokenizer.encode(valid_words, add_special_tokens=False, return_tensors="pt").to(self.assist_model.device)
            assist_input_ids = torch.cat((assist_input_ids, assist_valid_ids.long()), dim=-1)
            
            if large_input_ids.shape[1] >= max_len or large_input_ids[0, -1] == self.large_tokenizer.eos_token_id:
                break
        return large_input_words

large_tokenizer = AutoTokenizer.from_pretrained("llm-jp/llm-jp-13b-v1.0")
assist_tokenizer = AutoTokenizer.from_pretrained("cyberagent/open-calm-small")
large_model = AutoModelForCausalLM.from_pretrained("llm-jp/llm-jp-13b-v1.0", device_map="auto", torch_dtype=torch.float16, load_in_8bit=True)
assist_model = AutoModelForCausalLM.from_pretrained("cyberagent/open-calm-small", device_map="auto", torch_dtype=torch.float16)
assisted_jp = JpCALM(large_tokenizer, assist_tokenizer, large_model, assist_model)

print("=========assisted===========")
tic = time.perf_counter()
with torch.no_grad():
    out1 = assisted_jp.generate("自然言語処理とは何か", 128)
tac = time.perf_counter()
print(out1)
print(tac - tic, "sec")
matches = np.array(assisted_jp.cache["matches"])
print("generated tokens per cycle:", matches.mean() + 1)
print("acceptance rate:", (matches > 0).sum() / len(matches))

print("=========baseline===========")
tic = time.perf_counter()
with torch.no_grad():
    tokenized_input = assisted_jp.large_tokenizer.encode("自然言語処理とは何か", add_special_tokens=False, return_tensors="pt").to(assisted_jp.model.device)
    output = assisted_jp.model.generate(
        tokenized_input,
        max_length=128,
    )[0]
    out2 = assisted_jp.large_tokenizer.decode(output)
tac = time.perf_counter()
print(out2)
print(tac - tic, "sec")

assert out1 == out2

"""
=========assisted===========
自然言語処理とは何か (岩波新書) | 西垣 通 |本 | 通販 | AmazonKindleストアでは、 自然言語処理とは何か (岩波新書)を、Kindle無料アプリまたはKindle電子書籍リーダーで今すぐお読みいただけます。Kindle電子書籍リーダーの 詳細はこちら自然言語処理とは何か (岩波新書) がカートに入りました自然言語処理とは何か (岩波新書) 新書 – 2016/11/21¥ 886 ¥ 
13.10530449775979 sec
generated tokens per cycle: 1.7910447761194028
acceptance rate: 0.23880597014925373
=========baseline===========
自然言語処理とは何か (岩波新書) | 西垣 通 |本 | 通販 | AmazonKindleストアでは、 自然言語処理とは何か (岩波新書)を、Kindle無料アプリまたはKindle電子書籍リーダーで今すぐお読みいただけます。Kindle電子書籍リーダーの 詳細はこちら自然言語処理とは何か (岩波新書) がカートに入りました自然言語処理とは何か (岩波新書) 新書 – 2016/11/21¥ 886 ¥ 
14.313117458950728 sec
"""

更なる高速化: Token Tree Verification

Assisted Generationの速度は先読みがどれだけ当たるかにかかっています。しかし連続して当てられる確率は指数関数的に減少していくため、ドラフトモデルがLLMの出力分布をうまく再現しかなりの精度で先読みを当てないと生成速度が伸び悩んでしまいます。例えば前節で挙げたllm-jp-13Bとopen-calm-smallの組み合わせでは平均で1.79個しか生成できていませんでした。これに対し、ドラフトモデルを変えずに先読みの正解率を向上させる工夫としてToken Tree Verification4という手法を活用できます。

LLMの文章生成メカニズム② Causal Language ModelにおけるAttention Maskについて

文章生成に使われているCausal Language Modelでは後ろの入力トークンが前の予測トークンに影響を与えないという話がありました。これはTransformerモデル内でトークン同士が影響し合うSelf-Attention部分にマスクを施すことで実現しています。

このマスクを改造することでより複雑な制約を課すことができます。

例えば "The quick brown fox jumps over" と "The quick brown fox runs around" を同時にLLMに入力したい場合、 "The quick brown fox jumps over runs around" と1行にまとめたのち、"jumps"と"runs", "over"と"around"などといったトークン間のattentionをマスクし取り除くことで、それぞれの文章を独立に推論した時と同じ結果を得ることができます。この手法は、前方で一致しているトークン列が長ければ長いほど、複数文をバッチでまとめるよりメモリ量や計算量的な面で有利になります。

これを一般化すると、分岐した先読みトークン木を1回のLLM推論で検証することができます。 元論文では複数のドラフトモデルからそれぞれ先読みを生成し、それらを1つのトークン木に集約していましたが、今回は1つのドラフトモデルから複数の先読みを生成してみます。

Token Tree Verificationを用いたgenerate()

facebook/optモデルの例に戻って、先読みがどれくらい当たっているのかを可視化してみます。6.7B, 13Bに対して125Mのモデルをドラフトモデルとして利用し、LLMが出力したトークンがドラフトモデルにおいて何番目までの候補に入っていたかを以下の図に示しました。

これまでの手法ではドラフトモデルが採用するのは一番確率の高いトークン、すなわちtop-1のトークンなので約70%の割合で先読みに成功していることがわかります。さらにtop-2まで採用したとすると正解率は80%に上昇します。先読みに複数選択肢を用意することには効果がありそうです。

ではどのように選択肢を用意すれば良いでしょうか。あまり当たらなそうな先読みトークンに対しては、そのさらに先を読むことによる利益は少なそうです。そこでドラフトモデルの出力確率をそのまま先読みが当たる確率と見做してしまい、先読みが当たる数の期待値が最大になるように木を生成します。

実装のソースコードはこちら

from dataclasses import dataclass, field
import heapq
from typing import Dict, List
import torch
import torch.nn.functional as F
from transformers.generation.logits_process import LogitsProcessorList
from transformers.generation.stopping_criteria import StoppingCriteriaList

@dataclass(order=True)
class Node:
    nll: float  # 確率のnegative logが小さい順で幅優先探索を行う
    gain: int
    tokens: torch.LongTensor = field(compare=False)
    attention_mask: torch.LongTensor = field(compare=False)
    all_draft_tokens: torch.LongTensor = field(compare=False)

class AssistTreeMixin:
    def draft(
        self,
        eos_token_id: int,
        input_ids: torch.LongTensor,
        max_assistant_tokens: int,
    ) -> torch.LongTensor:
        self.cache["assistant_prob_list"] = []
        seq_len = input_ids.shape[1]
        max_len = seq_len + max_assistant_tokens
        pq = []
        device = self.assist_model.device
        heapq.heappush(pq, Node(0, 0, torch.LongTensor([]).to(device), torch.ones_like(input_ids), torch.LongTensor([]).to(device)))
        draft_ids = input_ids
        draft_masks = []
        draft_nodes = []
        for idx in range(max_assistant_tokens+1):
            top = heapq.heappop(pq)    
            if len(top.tokens) > 0:
                draft_ids = torch.cat((draft_ids, top.tokens[:,None]), dim=-1)
            draft_len = draft_ids.shape[1]
            attention_mask = F.pad(top.attention_mask, (0, draft_len - top.attention_mask.shape[1]))
            attention_mask[:, -1] = 1
            if idx > 0:
                # トークン木に採用
                draft_masks.append(F.pad(attention_mask, (0, max_len - attention_mask.shape[1])))
                draft_nodes.append(top)
            if "assistant_past_key_values" in self.cache:
                prev_seq_len = self.cache["assistant_past_key_values"][0][0].shape[-2]
                new_token_len = draft_ids.shape[1] - prev_seq_len
                assist_inputs = draft_ids[:, -new_token_len:]
                assistant_model_outputs = self.assist_model(
                    assist_inputs,
                    attention_mask=attention_mask,
                    past_key_values=self.cache["assistant_past_key_values"])
            else:
                assistant_model_outputs = self.assist_model(draft_ids, attention_mask=attention_mask)
            self.cache["assistant_past_key_values"] = assistant_model_outputs.past_key_values
            assist_new_logits = assistant_model_outputs.logits[:, -1, :]  # (batch, vocab)
            assist_new_logits = self.logits_processor(draft_ids, assist_new_logits)
            assist_new_logits = self.logits_warper(draft_ids, assist_new_logits)
            assist_new_logprobs = F.log_softmax(assist_new_logits, dim=-1)  # (batch, vocab)
            # 計算量節約のために探索数を5つに制限
            assist_new_topk = torch.topk(assist_new_logprobs, k=5, dim=-1)  # (batch, k)
            for k in range(5):
                new_token = assist_new_topk.indices[:, k]
                new_nll = -assist_new_topk.values[0, k]
                if new_token != eos_token_id:
                    heapq.heappush(pq, Node(top.nll + new_nll, top.gain+1, new_token, attention_mask, torch.cat((top.all_draft_tokens, new_token[:]))))
        return draft_ids, torch.cat(draft_masks, dim=0), draft_nodes

    def verify(
        self,
        eos_token_id: int,
        input_ids: torch.LongTensor,
        candidate_input_ids: torch.LongTensor,
        draft_masks: torch.LongTensor,
        draft_nodes: List[Node],
        max_len: int,
    ) -> torch.LongTensor:
        candidate_length = candidate_input_ids.shape[1] - input_ids.shape[1]
        tgt_len = candidate_input_ids.shape[1]
        def make_tree_mask(_attention_mask, _input_shape, inputs_embeds, past_key_values_length):
            # Causal Maskを上書きする
            tree_mask = torch.tril(torch.ones(tgt_len, tgt_len))
            tree_mask[-candidate_length:, -candidate_length:] = draft_masks[:, input_ids.shape[1]:]
            tree_mask = torch.full((tgt_len, tgt_len), torch.finfo(inputs_embeds.dtype).min).masked_fill(tree_mask > 0, 0)
            if past_key_values_length > 0:
                tree_mask = torch.cat((torch.zeros(tgt_len-past_key_values_length, past_key_values_length), tree_mask[past_key_values_length:, past_key_values_length:]), dim=-1)
            return tree_mask[None,None].to(inputs_embeds.dtype).to(inputs_embeds.device)
        self.model.model.decoder._prepare_decoder_attention_mask = make_tree_mask
        cur_len = input_ids.shape[1]
        if "past_key_values" in self.cache:
            prev_seq_len = self.cache["past_key_values"][0][0].shape[-2]
            new_token_len = candidate_input_ids.shape[1] - prev_seq_len
            model_attn = torch.ones_like(candidate_input_ids)
            model_input_ids = candidate_input_ids[:, -new_token_len:]
            outputs = self.model(
                model_input_ids,
                attention_mask=model_attn,
                past_key_values=self.cache["past_key_values"],
            )
        else:
            outputs = self.model(candidate_input_ids)
        logits = outputs.logits
        for i in range(candidate_length):
            logits[:, i, :] = self.logits_processor(candidate_input_ids[:, :cur_len + i], logits[:, i, :])
        for i in range(candidate_length):
            logits[:, i, :] = self.logits_warper(candidate_input_ids[:, :cur_len + i], logits[:, i, :])

        speculative_ids = candidate_input_ids[:, -candidate_length:]
        selected_tokens = logits[:, -candidate_length-1:, :].argmax(-1)
        best_sele = torch.LongTensor([]).to(input_ids.device)
        best_draft_mask = torch.cat((torch.ones(input_ids.shape[1]), torch.zeros(candidate_length)))
        n_matches = -1
        longest_tokens = 0
        for i, node in enumerate(draft_nodes):
            selected_tokens_i = torch.cat((selected_tokens[0, 0:1], selected_tokens[0, 1:][draft_masks[i,-candidate_length:]>0]))
            streak = (~(node.all_draft_tokens == selected_tokens_i[:-1])).cumsum(0) < 1
            n_matches_i = streak.sum().cpu().item()
            longest_tokens = max(longest_tokens, len(node.all_draft_tokens))
            if n_matches_i > n_matches:
                n_matches = n_matches_i
                best_sele = selected_tokens_i
                best_draft_mask = draft_masks[i]
        self.cache["best"] = longest_tokens == n_matches
        self.cache["mask_to_cache"] = best_draft_mask > 0
        
        n_matches = min(max_len - cur_len, n_matches)
        self.cache["matches"].append(n_matches)
        self.cache["past_key_values"] = outputs.past_key_values
        verified = torch.cat((input_ids, best_sele[None, :n_matches]), dim=-1)
        
        if verified[0, -1] == eos_token_id or verified.shape[1] == max_len:
            return verified

        # add one more token
        verified = torch.cat((verified, best_sele[None, n_matches:n_matches+1]), dim=-1)
        return verified
        
    def crop_cache(self, input_ids):
        # Discard past key values relative to unused assistant tokens
        mask = self.cache["mask_to_cache"]
        length = input_ids.shape[1] - 2
        self.cache["past_key_values"] = tuple([(
            kv[0][:, :, mask, :][:,:,:length,:],
            kv[1][:, :, mask, :][:,:,:length,:],
        ) for kv in self.cache["past_key_values"]])
        self.cache["assistant_past_key_values"] = tuple([(
            kv[0][:, :, mask, :][:,:,:length,:],
            kv[1][:, :, mask, :][:,:,:length,:],
        ) for kv in self.cache["assistant_past_key_values"]])


class SpecDecoder(AssistTreeMixin):
    def __init__(self, tokenizer, large_model, assist_model):
        self.tokenizer = tokenizer
        self.model = large_model
        self.assist_model = assist_model
    @torch.no_grad()
    def generate(self, input_ids: torch.LongTensor, max_new_len: int, only_draft=False):
        max_len = input_ids.shape[1] + max_new_len
        self.cache = {}
        self.cache["matches"] = []
        self.cache["first_assistant_prob"] = []
        self.cache["verified"] = []
        self.max_assistant_tokens = 5
        while True:
            draft_ids, draft_masks, draft_nodes = self.draft(self.tokenizer.eos_token_id, input_ids, max_assistant_tokens=self.max_assistant_tokens)
            self.cache["draft"] = (draft_ids, draft_masks, draft_nodes)
            if only_draft:
                break
            new_input_ids = self.verify(self.tokenizer.eos_token_id, input_ids, draft_ids, draft_masks, draft_nodes, max_len=max_len)
            n_matches = new_input_ids.shape[1] - input_ids.shape[1] - 1
            if self.cache["best"]:
                self.max_assistant_tokens += 2
            else:
                self.max_assistant_tokens = max(1, self.max_assistant_tokens - 1)
            self.crop_cache(new_input_ids)
            
            input_ids = new_input_ids
            if input_ids.shape[1] >= max_len or input_ids[0, -1] == self.tokenizer.eos_token_id:
                break
        return input_ids

facebook/opt-125mをドラフトモデルとして、"The man worked as a"の続きを5つ先読みすると次のようになりました。

矢印の中身は遷移確率です。"worked as a"の次の単語は自信がないものの、"security"と来たら"guard", "truck"と来たら"driver"というのはそれなりに自信がありそうです。 このように、"security guard"と先読みするだけではこころもとない時に"truck driver"という先読みも含めることができるというのがこのトークン木を活用する利点です。

これを用いてAssisted Generationがどれだけ速くなるか実験しました。

...残念ながら速くはなりませんでした。トークン木の生成アルゴリズムや先読み数などのチューニングが必要そうです。

元論文の実装

元論文では複数のドラフトモデルを使って先読みをすることを想定しており、複数のGPUにドラフトモデルを分散させるなどの効率化手法についても触れています。それに加えて、この章ではGreedy Decodingにしか触れませんでしたが、サンプリングを行った際の速度についても実験されています。また、公式実装はFlexFlowで利用できます。

おわりに

本記事では、Assisted GenerationというLLMの推論高速化手法についてご紹介しました。

NTT Com では、大規模言語モデルおよび生成AIを活用したプロダクト・ソリューション開発、社内DXに挑戦するソフトウェアエンジニアを募集しています!詳細は以下のリンクをご覧ください。

hrmos.co


  1. イノベーションセンターでは、本人の意志に応じて複数のプロジェクトへ参画できる兼務制度が用意されています。詳しくは イノベーションセンター テクノロジー部門 紹介デッキ をご覧ください。
  2. https://huggingface.co/blog/assisted-generation
  3. https://arxiv.org/abs/2302.01318
  4. https://arxiv.org/abs/2305.09781v2
© NTT Communications Corporation All Rights Reserved.