ストリーム処理を活用してLLMベース音声対話システムのレイテンシを短縮する

この記事は、 NTT Communications Advent Calendar 2024 1日目の記事です。

こんにちは、イノベーションセンターの加藤です。普段はコンピュータビジョンの技術開発やAI/機械学習(ML: Machine Learning)システムの検証に取り組んでいます。一方で、兼務で生成AIチームに参加し、大規模言語モデル(LLM: Large Language Model)に関する技術の調査を行なっています。

音声アシスタントをLLMベースで作成する際、ユーザーの入力音声を一旦テキストに変換し、LLMに応答させた後、その応答文から読み上げ音声を生成するというカスケード方式がこれまで取られてきています。 一方最近ではMini-Omni1など、音声を入力として音声を出力するLLMを一貫して学習可能なエンドツーエンド方式も登場してきています。音声アシスタントのようにユーザーとやりとりするシステムにおいて、ユーザーの入力が終わってからアシスタントが応答を返し始めるまでの時間はレイテンシと呼ばれ、ユーザーの体験に直結する大事な指標ですが、エンドツーエンド方式ではこのレイテンシが短くなる傾向にあります。 しかしながら、利用するLLMモデルや読み上げ音声のカスタマイズの柔軟性という面で既存のカスケード方式にも長所があります。 今回はストリーム処理を活用することで、カスケード方式の欠点であるレイテンシの短縮に取り組んだ結果を紹介します。

目次

音声対話システムの仕組み

今回作成する音声対話システムは次のモジュールで構成されています。

  • VAD(Voice Activity Detection, 音声区間検出)
  • ASR(Automatic Speech Recognition, 自動音声認識)
  • LLM(Large Language Model, 大規模言語モデル)
  • TTS(Text to Speech, 音声合成)

これらのモジュールを使い、以下の流れで処理を行います。

VADを用いてユーザーの会話終わりを検出する

ユーザーがシステムに音声を送信する際、話し終わったタイミングで送信ボタンを押すというUIも考えられますが、今回はスムーズな会話を実現するために自動で発話の終わりを検出します。推論モデルはSilero VADを利用し、マイクから入力された音声を32msの粒度で発言中かどうかを判定します。そして発話の終わりを検知し、無音時間が閾値の300msを超えたら、音声バッファから喋っている部分を切り出し次の処理に回します。

from silero_vad import load_silero_vad, get_speech_timestamps, VADIterator
import gradio as gr
import numpy as np
import torch
import librosa
from typing import Tuple, Optional, List

AudioType = Tuple[int, np.ndarray]

class VAD:
    def __init__(self):
        self.SAMPLING_RATE = 16000
        self.vad_iter = VADIterator(
            load_silero_vad(onnx=True), 
            sampling_rate=self.SAMPLING_RATE,
            min_silence_duration_ms=300,
        )
        self.reset()

    def reset(self):
        self.vad_iter.reset_states()
        self.buffer = np.empty(0, dtype=np.float32)
        self.ptr = 0
        self.orig_sr = None
        self.start = None
        self.end = None

    def __call__(self, audio: AudioType):
        """ストリーミングされる音声を受け取り、発話が終わった時のみ音声バッファを返す"""
        sr, wave = audio
        self.orig_sr = sr
        window = int(512 * sr / self.SAMPLING_RATE)
        if wave.dtype == np.int16:
            wave = wave.astype(np.float32) / 32768.0
        # Convert to mono if stereo
        if wave.ndim > 1:
            wave = wave.mean(axis=1)
        self.buffer = np.concatenate([self.buffer, wave])
        while len(self.buffer) - self.ptr > window:
            chunk = self.buffer[self.ptr:self.ptr + window]
            chunk = librosa.resample(chunk, orig_sr=sr, target_sr=self.SAMPLING_RATE)
            chunk = torch.tensor(chunk)
            speech_dict = self.vad_iter(chunk, return_seconds=True)
            self.ptr += window
            if speech_dict:
                if "start" in speech_dict:
                    self.start = speech_dict["start"]
                elif "end" in speech_dict:
                    self.end = speech_dict["end"]
                    speech = self.buffer[int(self.start * self.orig_sr) : int(self.end * self.orig_sr)]
                    return (sr, speech)
        return None

ASRを用いてユーザーの発話内容を文字列に変換する

VADにより会話の終了を検知したら、今までのバッファを音声認識モデルに入力し、LLMに入れるための文字列に変換します。モデルはWhisperを利用しました。

import whisper

class ASR:
    def __init__(self):
        self.whisper_sr = whisper.audio.SAMPLE_RATE
        self.whisper_model = whisper.load_model("large-v3")
    def __call__(self, value: AudioType):
        sr, wave = value
        if wave.dtype == np.int16:
            wave = wave.astype(np.float32) / 32768.0
        wave = librosa.resample(wave, orig_sr=sr, target_sr=self.whisper_sr)
        result = self.whisper_model.transcribe(wave, language="ja", temperature=0.0)
        return result["text"]

LLMを用いてチャットボットの応答を文字列として取得する

チャット用にチューニングされたLLMであるCALM3-22B-Chatを利用し、ユーザーの発話内容に対して応答をさせます。 今回は複数回のやり取りを想定し、会話記録としてhistoryを受け取り応答文を返す形にします。

from transformers import AutoModelForCausalLM, AutoTokenizer, TextStreamer, TextIteratorStreamer

class LLM:
    def __init__(self):
        self.llm_model = AutoModelForCausalLM.from_pretrained("cyberagent/calm3-22b-chat", device_map="auto", torch_dtype="auto")
        self.llm_tokenizer = AutoTokenizer.from_pretrained("cyberagent/calm3-22b-chat")
        self.prompts = [
            {"role": "system", "content": "あなたは親切なAIアシスタントです。"}
        ]
    def __call__(self, history: List[dict]) -> str:
        token_ids = self.llm_tokenizer.apply_chat_template(
            self.prompts + history,
            add_generation_prompt=True,
            return_tensors="pt")
        gen = self.llm_model.generate(
            input_ids=token_ids.to(self.llm_model.device),
            return_dict_in_generate=True,
            max_new_tokens=300,
            do_sample=False)
        seq = gen.sequences
        seq = seq[:,token_ids.shape[1]:]  # skip prompt
        response = self.llm_tokenizer.batch_decode(seq, skip_special_tokens=True)[0]
        return response

TTSを用いて応答文を読み上げさせる

最後にTTSとしてESPnetを利用し、LLMからの応答文を音声に変換します。

from espnet2.bin.tts_inference import Text2Speech
from espnet_model_zoo.downloader import ModelDownloader
class TTS:
    def __init__(self):
        model_tag = "kan-bayashi/jsut_full_band_vits_prosody"
        downloader = ModelDownloader(cachedir=Path.home() / ".cache" / "espnet_model_zoo")
        self.text2speech = Text2Speech(
            **downloader.download_and_unpack(model_tag),
            # No vocoder
            device="cuda",
            noise_scale=0.333,
            noise_scale_dur=0.333,
            always_fix_seed=True,
        )
    def __call__(self, text: str):
        return self.forward(text)
    @torch.no_grad()
    def forward(self, inpt) -> AudioType:
        wav = self.text2speech(inpt)["wav"].cpu().numpy()
        return self.text2speech.fs, wav

各処理を順に行うGUIを作成

この記事ではGUIとしてGradioを利用します。 このプログラムではVADの判定により会話部分の区間が確定した後、ASR, LLM, TTSが順に実行されます。 VADで保存されている記録中の音声バッファなど、ユーザーに依存する状態は本来Gradioを使い管理する必要がありますが、本記事では簡単のため実装を省略しています。

作成したGUIは以下の画像のようになりました。

各モジュールのレイテンシを測る

一通りシステムが完成したので、次はどのモジュールがどれだけ時間をかけているか測定します。 人力で音声を入力すると話し終わりから応答までのレイテンシを測りにくいので、入力音声はTTSで作成し、TTSによる入力文の再生終了を話し終わりとみなして実行時間を計測しました。 GPU環境はNVIDIA H100 GPUを1枚利用しました。

入力文:AIによって私たちの暮らしはどのように変わりますか?簡潔に説明してください。 * 入力音声長:5.050 sec * ASR処理時間:0.787 sec * LLM処理時間:4.730 sec (0.033 sec / token) * 応答文:AIは私たちの暮らしに多岐にわたる影響を与えます。まず、日常生活の効率化が挙げられます。スマートホームデバイスや音声アシスタントが家事や日常のタスクを自動化し、時間を節約します。また、医療分野では診断や治療の精度が向上し、早期発見や個別化医療が進みます。さらに、交通や物流の分野では、自動運転車やドローン配送が普及し、移動や配送が迅速かつ効率的になります。教育分野では、AIを活用した個別学習プログラムが学習者の理解度に応じた指導を提供し、教育の質が向上します。最後に、エンターテインメントやクリエイティブ産業でも、AIが新しいコンテンツの生成やパーソナライズされた体験を提供し、私たちの楽しみ方を一変させます。 * 応答音声長:48.228 sec * TTS処理時間:0.198 sec * 全体レイテンシ(ASR + LLM + TTS + その他処理時間):5.895 sec

以上のように5.895秒が計測された全体レイテンシとなります。実際にユーザーが喋り終わってからアシスタントが喋り始めるまでには、喋り終わりが判定されるまでの300ミリ秒や、生成音声をブラウザーに渡す際のエンコード処理などの時間が追加でかかります。 しかし全体を通して一番のボトルネックはLLMの処理時間で、レイテンシの約80%を占めていることが分かります。 次の章では、LLMのストリーム出力を活用することでこのレイテンシを低減する手法や、それに伴って新しく可能になる機能を紹介します。

TTSのストリーム処理

文で区切る

LLMの文章生成は逐次的に単語を出力するため、応答全体が揃うまで待たなくても音声を再生し始めることができます。 例えば、LLMからの出力に句点が入り次第音声を出力することで、返答の音声が始まるまでのレイテンシを短縮できます。本来はテキスト全体からイントネーションや発音の長さなどを決定し音声合成するため、文章全体を入力した場合とTTSの結果が少し異なってしまいますが、文を跨いで音声の調子が影響することは考えにくいため、レイテンシの短縮を優先して採用しました。実際に生成した音声も、文章全体から生成した場合とほぼ同じ結果が得られ、文単位で区切る手法は有用であることがわかりました。

# ただのTextIteratorStreamerは漢字が来るまで内部でバッファしてしまうのでバッファリングを無効化する
class EagerTextIteratorStreamer(TextIteratorStreamer):
    def put(self, value):
        if len(value.shape) > 1 and value.shape[0] > 1:
            raise ValueError("TextStreamer only supports batch size 1")
        elif len(value.shape) > 1:
            value = value[0]
        if self.skip_prompt and self.next_tokens_are_prompt:
            self.next_tokens_are_prompt = False
            return
        self.token_cache.extend(value.tolist())
        text = self.tokenizer.decode(self.token_cache, **self.decode_kwargs)
        printable_text = text[self.print_len :]
        self.token_cache = []
        self.print_len = 0
        self.on_finalized_text(printable_text)

class LLMStream(LLM):
    def __call__(self, history: List[dict]) -> str:
        token_ids = self.llm_tokenizer.apply_chat_template(
            self.prompts + history,
            add_generation_prompt=True,
            return_tensors="pt")
        streamer = EagerTextIteratorStreamer(self.llm_tokenizer, skip_prompt=True, skip_special_tokens=True)
        thread = Thread(
            target=self.llm_model.generate,
            kwargs=dict(
                input_ids=token_ids.to(self.llm_model.device),
                max_new_tokens=300,
                do_sample=False,
                streamer=streamer))
        thread.start()
        yield from streamer

class TTSStream(TTS):
    def __call__(self, text_streamer):
        pattern = re.compile(r".+[.!。\n]")  # 句点・改行のタイミングでTTSに渡す
        result = ""
        ptr = 0
        for text in text_streamer:  # LLMからの出力をストリームで受け取る
            result += text
            while ptr < len(result):
                m = pattern.search(result, ptr)
                if m:
                    proc_text = result[ptr:m.end()]
                    audio = self.forward(proc_text)
                    yield proc_text, audio  # 処理した文章と音声をストリームで返す
                    ptr = m.end()
                else:
                    break
        # process remaining text
        if len(result[ptr:]) > 0:
            audio = self.forward(result[ptr:])
            yield result[ptr:], audio

先ほどの例ならば「AIは私たちの暮らしに多岐にわたる影響を与えます。」まで出力された時点でTTSの実行や再生バッファへの入力を開始することで、これ以降のLLM推論によるレイテンシを隠蔽し、5.895秒から1.580秒まで短縮できました。

また、今回の実験ではH100を使っているというのもあり、基本的に生成される音声の長さよりも処理時間のほうが十分に短いため、LLMやTTSの処理が間に合わず再生バッファを枯渇させるというようなことは滅多に起こらないことがわかります。 先ほどの入力ケースでは下図のように、約50秒の応答音声を生成し切るのにかかった時間は5秒弱であり、その後はバッファされた音声を再生するだけとなっていることがわかります。

文節で区切る

前節では文単位で区切ってTTSに入力していましたが、どれくらい細かい区切りまでなら違和感のない音声を生成できるのでしょうか。今回は文節までなら別々に生成しても大丈夫だろうという仮定を置き実験してみました。

今回利用しているTTSモデルはVITSと呼ばれるものですが、前処理としてOpen JTalkによる音素推定が行われています。これは入力したテキストから漢字の読み・単語の分割・アクセント位置などを解析し、音素と呼ばれる発音表記に変換する処理のことで、例えば「東京都に住む」という入力に対しては t o [ o ky o o ] t o n i # s u ] m u というような列に変換されます。ここで[はピッチの上昇、]はピッチの下降、#は文節の区切りを指し、読みやすく書き換えると ト↑オキョオ↓トニ/スム となります。この前処理結果を活用して、以下のような処理を行います。

LLMから新しく文章が生成されるたびに推定音素列を更新し、文節の切れ目が分かっているところまでの音素列をTTSに入力します。 この処理はOpenJTalkの前処理結果を抽出することで実現できます。

class TTSStream2(TTS):
    def __call__(self, text_streamer):
        clean = self.text2speech.preprocess_fn.text_cleaner
        to_tok = self.text2speech.preprocess_fn.tokenizer.text2tokens
        to_ids = self.text2speech.preprocess_fn.token_id_converter.tokens2ids

        fulltext = ""
        ptr = 0
        token_buffer = []
        token_ptr = 0

        tts_input = []
        for text in text_streamer:
            fulltext += text
            fulltoken = to_tok(clean(fulltext))
            while True:
                token_buffer = fulltoken[token_ptr:]
                shift = len(token_buffer)
                if "#" in token_buffer[token_ptr:]:
                    shift = min(token_buffer.index("#"), shift)
                elif "_" in token_buffer[token_ptr:]:
                    shift = min(token_buffer.index("_"), shift)
                else:  # not found
                    break
                tts_input += token_buffer[:shift+1]
                token_ptr += shift+1
                if len(tts_input) > 0:
                    ints_input = np.array(to_ids(tts_input), dtype=np.int64)
                    audio = self.forward(ints_input)
                    yield fulltext[ptr:], audio
                    tts_input = []  # flush
                    ptr = len(fulltext)
        # process remaining text
        tts_input = to_tok(clean(fulltext))[token_ptr:]
        if len(tts_input) > 0:
            ints_input = np.array(to_ids(tts_input), dtype=np.int64)
            audio = self.forward(ints_input)
            yield fulltext[ptr:], audio

この結果レイテンシを1.284秒まで短縮できました。

しかし、文節に区切ってTTSを行うと、音声の調子が少し不自然になってしまいました。以下の音声は「こんにちは」という入力に対する文単位TTSと文節単位TTSの結果です。

文単位TTS

文節単位TTS

文節間に不自然な間があったり、「どの/ような」のアクセントがおかしくなっていたりしています。 これは入力を文節ごとに制限したせいで、文節を跨ぐときの発音間隔や、文節が連なることによるイントネーションの変化が推定できなかったためと考えられます。

そこで次はTTS入力時に前後1文節をマージンとして一緒に推定し、目的の文節の音声のみを切り出すという方法をとってみました。

この処理を行うためには文章上の位置と読み上げ音声の位置を対応づける必要がありますが、 VITSの内部で生成される各音素の発音長を取り出すことで、生成した音声から目的の文節が読み上げられている区間を抽出できます。

これを行うと、レイテンシが1.284秒から1.290秒に少し悪化しますが、イントネーションが以下のように改善しました。

しかしながら、文節の継ぎ目でノイズが混入しており、生成音声を綺麗に切り貼りするのは難しいようです。

まとめると以下の表になります。

手法 レイテンシ(秒) 品質
シーケンシャル 5.895 ⚪︎
ストリーム処理(文単位) 1.580 ⚪︎
ストリーム処理(文節単位) 1.284 ×
ストリーム処理(文節単位+前後1文節マージン) 1.290

音声の自然さという点では文単位で区切るのが限界のようです。システムプロンプトを工夫することで、応答の一文の長さ自体を短くすることも大事でしょう。

ユーザーからの割り込み

文章生成から音声出力までがストリーム処理になった場合、出力をいつでも中断できるという利点が生まれます。これとVADを組み合わせることで、応答中にユーザーが発言したときにLLM推論や音声出力を中断し、新しい応答を生成できます。ただし、アシスタントの応答音声をなんらかのスピーカーで出力させている場合は、応答の音声がそのままマイクに入ってしまってユーザーの発言と誤認してしまう問題を解決する必要があり、実用上はなかなか困難を伴うと思います。

まとめ

今回はGradioとLLMを用いて、ユーザーの音声を受け取り応答を音声で返す対話システムを作ってみました。このようなシステムではユーザーの発言が終わってからアシスタントが話し始めるまでのレイテンシがユーザーの体験に大きく影響しますが、LLMやTTSをストリーム処理することでレイテンシを短縮できました。また、ストリーム処理を実装することで、途中で推論をキャンセルできるという利点があることも紹介しました。

明日のアドベントカレンダーもお楽しみに。

参考資料

© NTT Communications Corporation 2014