wav2vec 2.0 を使って 手軽に音声認識モデルを触れるようになろう

この記事は NTTコミュニケーションズ Advent Calendar 2021 の20日目の記事です。

はじめに

こんにちは。プラットフォームサービス本部アプリケーションサービス部の是松です。 NTTコミュニケーションズでは自然言語処理、機械翻訳、音声認識・合成、要約、映像解析などのAI関連技術を活用した法人向けサービスを提供しています。(COTOHA シリーズ

NTTコミュニケーションズがこのようなAI関連技術を活用したサービスを展開する強みとして、

  • NTT研究所の研究成果が利用可能であること
  • 自社の他サービスを利用しているお客様に対してシナジーのあるサービスを提案できること

この2点が挙げられると思います。
実際に、私が担当している COTOHA Voice Insight は 通話音声テキスト化によってコンタクトセンターの業務効率化・高度化を実現するサービスなのですが、 NTT研究所の音声認識技術を活用しており、自社サービスとの連携も積極的に行っています。

ターゲットとしているコンタクトセンターのDX市場は変化が激しい業界でありながら、 私たちのサービスはその変化についていく体制が整っていないことが課題だと感じています。
様々な事情から、実際にサービスで使用している音声認識モデルを気軽に試すことは難しいのですが、オープンソースの音声認識技術を活用することでサービスの品質向上につながる知見を集めることが可能だと考え、技術調査をしてきました。

本記事では、そのような取り組みの1つとして、wav2vec 2.0 というオープンソースを用いて、事前学習された音声認識モデルを少量のデータセットでチューニングする方法をご紹介します。 手元で簡単に音声認識を試すことができるようになれば、PoC(Proof of Concept)などを通して知見を集めることができ、 なかなか小回りが効かないサービス開発の効率向上が期待できると考えています。

wav2vec 2.0 とは

2020.6 に wav2vec 2.0: A Framework for Self-Supervised Learning of Speech Representations という論文で提案された音声認識フレームワークです。

wav2vec 2.0 では、ラベル付き(書き起こし文がある)の音声データだけでなくラベルなし音声データも学習に活用する、自己教師あり学習の手法を採用しています。 少量のラベル付きデータ + たくさんのラベルなしデータでも、それまでの手法に匹敵する音声認識精度を達成し、 ラベル付きデータの量を増やしていくことで音声認識精度がさらに向上していくことを示しました。

現在でも wav2vec 2.0 を基にした手法が複数の音声認識タスクで最高性能を記録しています。 https://paperswithcode.com/task/speech-recognition

wav2vec 2.0 の詳細は原論文や、解説記事(こちらが分かりやすかったです)を参照ください。

wav2vec 2.0 の素晴らしい点として、大規模音声データを用いて事前学習されたモデルに対して、少量のデータセットを用いてパラメータを再調整(Fine Tuning)することによって、後から追加した音声データにうまく適合したモデルを作れることが挙げられます。 しかも、wav2vec 2.0 は、ソースコードと事前学習されたモデルが公開されているため(こちら)、誰でも簡単にモデルをチューニングできます。

日本語データセットを使って音声認識モデルをチューニングする

wav2vec 2.0 の効果的な使い方として、事前学習された言語非依存のモデルを少量の特定言語データセットで Fine Tuning する手法があります。 この手法は、話者が少ないなどの理由でデータセットが十分存在しない言語の音声認識モデルを作成する際に効果的です。 他にも、方言や特定の話者に特化したモデルや電話音声など特定のドメインのみで使用するモデルなど、これまでは十分なデータを用意できていなかった分野での活用が期待されます。

今回は、言語非依存の事前学習モデルに対して、少量の日本語音声データを使って Fine Tuning を行います。
大まかな流れは、英語データセットを使用してチューニングを行なっているこちらの記事を参考にしています。解説もわかりやすいので興味のある方は元の記事をご参照ください。

環境

Google Colaboraoty Pro を使用しました。
無料版でやる場合は、おそらく学習時にGPUのメモリが足らなくなるので、バッチサイズを小さくするなどして対応してください。

事前に必要なライブラリをインストールします。

%%capture
# 日本語データセットのダウンロード
!pip install datasets==1.13.3
# 機械学習ライブラリ
!pip install transformers==4.11.3
!pip install torchaudio
# 音声データ処理用ライブラリ
!pip install librosa
# 形態素解析
!pip install mecab-python3
!pip install unidic
# かなローマ字変換
!pip install romkan

最新の日本語辞書をダウンロードしておきます。

!python -m unidic download

学習データの準備

まずは、Fine Tuning に使用する日本語データセットを用意しましょう。 common voice というオープンソースの音声データセットを構築するプロジェクトがあり、そこでは様々な言語の音声データが収録・公開されています。

下記のライブラリを使うことで、common voice データセットを簡単に使うことができます。
https://pypi.org/project/datasets/

余談ですが、common voice のwebサイトでは、表示されるテキストを読み上げて録音したり、録音された音声の品質を評価したりすることでプロジェクトに貢献できます。興味がある方はぜひ試してみてください。

datasets.load_dataset() を使って様々な言語のデータセットをダウンロード可能です。 日本語の場合は第2引数に "ja" と入力します。 データセットは、train, validation, test の3種類に分かれており、今回は train データと test データの2つに分ければ十分なので、test と validation を train データとします。

from datasets import load_dataset, load_metric, Audio

common_voice_train = load_dataset("common_voice", "ja", split="train+validation")
common_voice_test = load_dataset("common_voice", "ja", split="test")

データセットは、train, validation, test の3種類に分かれていますが、今回は train と test データの2つに分けて使用します。

今回は、trainデータが1308組、testデータが632組でした。

len(common_voice_train), len(common_voice_test)

(1308, 632)

データセットの中身を見てみましょう。

common_voice_train[0]

{'accent': '',
'age': 'twenties',
'audio': {'array': array([ 0. , 0. , 0. , ..., 0.00085527,
-0.00014246, -0.00077921], dtype=float32),
'path': '/root/.cache/huggingface/datasets/downloads/extracted/07b0a73f2103df267f566548f7597fe8d75f8d4bdd37b7f556478ae85378bd6a/cv-corpus-6.1-2020-12-11/ja/clips/common_voice_ja_19817895.mp3',
'sampling_rate': 48000},
'client_id': 'b067e4a64d0c78c7c24b8eb93f9efc165121f9281fa6c31386d872529c2951a5a1f144ee8e5679c1bd41003695583b1c341de7d67fea995d584b5724b91ce984',
'down_votes': 1,
'gender': 'male',
'locale': 'ja',
'path': '/root/.cache/huggingface/datasets/downloads/extracted/07b0a73f2103df267f566548f7597fe8d75f8d4bdd37b7f556478ae85378bd6a/cv-corpus-6.1-2020-12-11/ja/clips/common_voice_ja_19817895.mp3',
'segment': "''",
'sentence': '予想外の事態に、電力会社も、ちょっぴり困惑気味だ。',
'up_votes': 2}

様々な情報が含まれていますが、今回は音声データとテキスト情報だけがあればいいので、 不要な情報を除きます。

common_voice_train = common_voice_train.remove_columns(["accent", "age", "client_id", "down_votes", "gender", "locale", "segment", "up_votes"])
common_voice_test = common_voice_test.remove_columns(["accent", "age", "client_id", "down_votes", "gender", "locale", "segment", "up_votes"])

音声データの確認

次に音声データを確認してみましょう。 "path" に実際の音声データが格納されているファイルの場所が記載されています。 しかし、既に"audio" に音声ファイルの中身のバイナリデータが格納されているので、 下記の通り "audio" の中身を適切に処理することで音声データを再生できます。

import IPython.display as ipd
import numpy as np
import random

rand_int = random.randint(0, len(common_voice_train)-1)

ipd.Audio(data=common_voice_train[rand_int]["audio"]["array"], autoplay=True, rate=16000)

テキストデータの確認

音声データがどんな内容を話しているかは、"sentence" の中に記載されています。 今回は句読点を除いた状態でチューニングを行うので、句読点などの不要な記号を除去する関数を作成し、テキストを整形します。

import re
chars_to_ignore_regex = '[、,。]'

def remove_special_characters(batch):
    batch["sentence"] = re.sub(chars_to_ignore_regex, '', batch["sentence"]).lower() + " "
    return batch

common_voice_train = common_voice_train.map(remove_special_characters)
common_voice_test = common_voice_test.map(remove_special_characters)

ここで、データセットの中身をランダムに10文出力する関数を作成して、表示してみます。

from datasets import ClassLabel
import random
import pandas as pd
from IPython.display import display, HTML

def show_random_elements(dataset, num_examples=10):
    assert num_examples <= len(dataset), "Can't pick more elements than there are in the dataset."
    picks = []
    for _ in range(num_examples):
        pick = random.randint(0, len(dataset)-1)
        while pick in picks:
            pick = random.randint(0, len(dataset)-1)
        picks.append(pick)
    
    df = pd.DataFrame(dataset[picks])
    display(HTML(df.to_html()))

show_random_elements(common_voice_test.remove_columns(["path","audio"]))

sentence
0 このカレーはとても辛いです
1 ボーハンはイーストマンらギャングのスピークイージーの上がりから賄賂を取っていたとも噂された
2 母はいつもわたしに買い物を頼みます
3 危ないのでそちらへ行かないでください
4 それはたいてい一時間にも及ぶ
5 不満やいらだちはもっぱら受験や身のまわりに向けられている
6 来月の初め国へ帰ります
7 この箱はとても重いです
8 娘のフィアンセでこいつだけにはどうしても負けられない
9 うちの中学は弁当制で持って行けない場合は五十円の学校販売のパンを買う

句読点が除去されていることが確認できました。 これで、必要な音声データとテキスト情報のセットの準備ができました。

学習データの変換(かな漢字・カナ・ローマ字)

次に、3パターンのデータセットを用意したいと思います。 1つ目は前節で用意したデータセット(かな漢字文)をそのまま使用するもの。
2つ目はかな漢字文をカナ文に変換したもの。
3つ目はカナ文をローマ字に変換したものです。

かな漢字文をカナ文に変換するために、形態素解析器を使用します。 今回は MeCab という有名な形態素解析ライブラリを使用します。

まずはMeCabを使用してみましょう。

import MeCab
import unidic
import romkan
mecab = MeCab.Tagger()

sentence = common_voice_train[0]['sentence']
print(sentence)
print(mecab.parse(sentence))

予想外の事態に電力会社もちょっぴり困惑気味だ
予想 名詞,普通名詞,サ変可能,,,,ヨソウ,予想,予想,ヨソー,予想,ヨソー,漢,"","","","","","",体,ヨソウ,ヨソウ,ヨソウ,ヨソウ,"0","C2","",10819203040944640,39360
外 接尾辞,名詞的,一般,,,,ガイ,外,外,ガイ,外,ガイ,漢,"","","","","","",接尾体,ガイ,ガイ,ガイ,ガイ,"","C3","",2169894821044736,7894
の 助詞,格助詞,,,,,ノ,の,の,ノ,の,ノ,和,"","","","","","",格助,ノ,ノ,ノ,ノ,"","名詞%F1","",7968444268028416,28989
事態 名詞,普通名詞,一般,,,,ジタイ,事態,事態,ジタイ,事態,ジタイ,漢,"","","","","","",体,ジタイ,ジタイ,ジタイ,ジタイ,"1","C1","",4922247303275008,17907
に 助詞,格助詞,,,,,ニ,に,に,ニ,に,ニ,和,"","","","","","",格助,ニ,ニ,ニ,ニ,"","名詞%F1","",7745518285496832,28178
電力 名詞,普通名詞,一般,,,,デンリョク,電力,電力,デンリョク,電力,デンリョク,漢,"","","","","","",体,デンリョク,デンリョク,デンリョク,デンリョク,"0,1","C2","",7095706913481216,25814
会社 名詞,普通名詞,一般,,,,カイシャ,会社,会社,カイシャ,会社,カイシャ,漢,"カ濁","基本形","","","","",体,カイシャ,カイシャ,カイシャ,カイシャ,"0","C2","",1577258053673472,5738
も 助詞,係助詞,,,,,モ,も,も,モ,も,モ,和,"","","","","","",係助,モ,モ,モ,モ,"","動詞%F2@-1,形容詞%F4@-2,名詞%F1","",10324972564259328,37562
ちょっぴり 副詞,,,,,,チョッピリ,ちょっぴり,ちょっぴり,チョッピリ,ちょっぴり,チョッピリ,和,"","","","","","",相,チョッピリ,チョッピリ,チョッピリ,チョッピリ,"3","","",6652053971673600,24200
困惑 名詞,普通名詞,サ変可能,,,,コンワク,困惑,困惑,コンワク,困惑,コンワク,漢,"","","","","","",体,コンワク,コンワク,コンワク,コンワク,"0","C2","",3654785274356224,13296
気味 名詞,普通名詞,一般,,,,キミ,気味,気味,ギミ,気味,ギミ,漢,"キ濁","濁音形","","","","",体,ギミ,ギミ,ギミ,キミ,"2","C3","",2424706640790016,8821
だ 助動詞,,,,助動詞-ダ,終止形-一般,ダ,だ,だ,ダ,だ,ダ,和,"","","","","","",助動,ダ,ダ,ダ,ダ,"","名詞%F1","",6299110739157675,22916
EOS

かな漢字文からカナ文への変換

上記より得られる出力のうち、カナの情報だけを集めて1つの文に繋げ直すための関数を作成します。

def convert_sentence_to_kana(batch):
    s = mecab.parse(batch["sentence"])
    kana = ""
    for line in s.split("\n"):
      if line.find("\t")<=0: continue
      columns = line.split(',')
      if len(columns) < 10:
        kana += line.split('\t')[0]
      else:
        kana += columns[9]
    batch["kana"] = kana
    return batch

不格好な関数ですが、これで今回使用する日本語データのテキストを全てカナ文に変換できます。 注意した点として、もともとカタカナだった単語などは読みの情報が出力されないので、 場合分けをして元の表記をそのまま使用するようにしています。

では、正しく処理ができたかどうか確認してみましょう。

show_random_elements(common_voice_test.remove_columns(["path","audio"]))

sentence kana
0 人々は花の苗や種を焼却し畑の花を全部抜きとってしまう ヒトビトワハナノナエヤタネオショーキャクシハタケノハナオゼンブヌキトッテシマウ
1 毎日忙しいのであまり休むことができません マイニチイソガシーノデアマリヤスムコトガデキマセン
2 女性とは逆で何とか常識を破ってめだってやろうと意気込む人がほとんどだ ジョセートワギャクデナントカジョーシキオヤブッテメダッテヤロートイキゴムヒトガホトンドダ
3 細長い指先で激しく鍵を叩く ホソナガイユビサキデハゲシクカギオタタク
4 クィーンズアベニューアルファに所属している クィーンズアベニューアルファニショゾクシテイル
5 山田さんは来月東京へ行くそうです ヤマダサンワライゲツトーキョーエイクソーデス
6 野球の後のビールぐらいうまいものはない ヤキューノアトノビールグライウマイモノワナイ
7 山田さんはおもしろい人です ヤマダサンワオモシロイヒトデス
8 わたしは歌がへたです ワタシワウタガヘタデス
9 熱いお茶が飲みたいです アツイオチャガノミタイデス

カナ文からローマ字文への変換

カナ文への変換ができたら、ローマ字文への変換は簡単です。 カナとローマ字を変換してくれる romkan ライブラリを使います。

def convert_sentence_to_roman(batch):
    s = mecab.parse(batch["sentence"])
    kana = ""
    for line in s.split("\n"):
      if line.find("\t")<=0: continue
      columns = line.split(',')
      if len(columns) < 10:
        kana += line.split('\t')[0]
      else:
        kana += columns[9]
    roman = romkan.to_roma(kana)
    batch["roman"] = kana
    return batch

こちらも正しく変換ができたか確認してみましょう。

# "kana" は邪魔なので除いておく
show_random_elements(common_voice_test.remove_columns(["path","audio", "kana"])) 

sentence roman
0 イさんはかぜをひいているので元気じゃありません isanwakazeohi-teirunodegenkijaarimasen
1 ツュレンハルト領はヴュルテンベルク領に編入された tsuxyurenharutoryo-wabyurutenberukuryo-nihen'nyu-sareta
2 航空事故を限りなくゼロに近づけるにはそれほどなり振りかまわぬ努力がいる ko-ku-jikookagirinakuzeronichikazukeruniwasorehodonarifurikamawanudoryokugairu
3 お偉方がぞくぞくと登場し恐縮する oerakatagazokuzokutoto-jo-shikyo-shukusuru
4 小林さんは青い傘を持っています kobayashisanwaaoikasaomotteimasu
5 冷蔵庫に卵や野菜や果物などがあります re-zo-konitamagoyayasaiyakudamononadogaarimasu
6 この絵は色がきれいです konoewairogakire-desu
7 ペンシルベニア州フィラデルフィアの郊外ウィンウッドのランケナウ病院で生まれた penshirubeniashu-firaderufianoko-gaiwin'uddonorankenaubyo-indeumareta
8 危ないのであそこの窓を開けてはいけません abunainodeasokonomadooaketewaikemasen
9 先月わたしは会社をやめました sengetsuwatashiwakaishaoyamemashita

Tokenizerの作成

ここからは用意した日本語データセットを使って、Fine Tuning するための準備を進めていきます。

事前学習された wav2vec 2.0 のモデルは、言語非依存のモデルなので出力が日本語に対応していません。 出力を日本語データセットの形式に合わせるために、専用の Tokenizer というものを作成する必要があります。

その準備のために、テキストデータを1文字ずつに分割して、重複を除去したものを vocab_dict_{sentence|kana|roman} に辞書形式で格納します。

3パターンの辞書(vocab_dict)を作成しているコード(クリックすると開きます)

def extract_all_chars_sentence(batch):
  all_text = " ".join(batch["sentence"])
  vocab = list(set(all_text))
  return {"vocab": [vocab], "all_text": [all_text]}

def extract_all_chars_kana(batch):
  all_text = " ".join(batch["kana"])
  vocab = list(set(all_text))
  return {"vocab": [vocab], "all_text": [all_text]}

def extract_all_chars_kana(batch):
  all_text = " ".join(batch["roman"])
  vocab = list(set(all_text))
  return {"vocab": [vocab], "all_text": [all_text]}

vocab_train_sentence = common_voice_train.map(extract_all_chars_sentence, batched=True, batch_size=-1, keep_in_memory=True, remove_columns=common_voice_train.column_names)
vocab_test_sentence = common_voice_test.map(extract_all_chars_sentence, batched=True, batch_size=-1, keep_in_memory=True, remove_columns=common_voice_test.column_names)
vocab_list_sentence = list(set(vocab_train_sentence["vocab"][0]) | set(vocab_test_sentence["vocab"][0]))
vocab_dict_sentence = {v: k for k, v in enumerate(vocab_list_sentence)}

vocab_train_kana = common_voice_train.map(extract_all_chars_kana, batched=True, batch_size=-1, keep_in_memory=True, remove_columns=common_voice_train.column_names)
vocab_test_kana = common_voice_test.map(extract_all_chars_kana, batched=True, batch_size=-1, keep_in_memory=True, remove_columns=common_voice_test.column_names)
vocab_list_kana = list(set(vocab_train_kana["vocab"][0]) | set(vocab_test_kana["vocab"][0]))
vocab_dict_kana = {v: k for k, v in enumerate(vocab_list_kana)}

vocab_train_roman = common_voice_train.map(extract_all_chars_roman, batched=True, batch_size=-1, keep_in_memory=True, remove_columns=common_voice_train.column_names)
vocab_test_roman = common_voice_test.map(extract_all_chars_roman, batched=True, batch_size=-1, keep_in_memory=True, remove_columns=common_voice_test.column_names)
vocab_list_roman = list(set(vocab_train_roman["vocab"][0]) | set(vocab_test_roman["vocab"][0]))
vocab_dict_roman = {v: k for k, v in enumerate(vocab_list_roman)}

それぞれの辞書の数を見てみましょう。

len(vocab_dict_sentence), len(vocab_dict_kana), len(vocab_dict_roman)

(1432, 82, 30)

かな漢字文の場合は、漢字がたくさん含まれているので、かなりの数になります。 カナ文やローマ字の場合は漢字を読みに変換しているので、 情報量が落ちている代わりに vocab_dict の数を抑えられています。

ここからは、3種類のデータセットで行う処理が全く同じなので、vocab_dict_{sentence|kana|roman} を vocab_dict に統一して記載します。実際に動かす場合は、それぞれ読み替えてください。

vocab_dict にこのあとの処理で必要な、未知語を意味する"[UNK]"と、空白を意味する"[PAD]"を追加して、jsonファイルに出力します。 CTC(Connectionist Temporal Classification)というアルゴリズムで、音声の時系列とテキストの対応を計算する際に必要な処理です。(https://distill.pub/2017/ctc/

vocab_dict["[UNK]"] = len(vocab_dict)
vocab_dict["[PAD]"] = len(vocab_dict)

import json
with open('vocab.json', 'w') as vocab_file:
    json.dump(vocab_dict, vocab_file)

出力したjsonファイルを使用して、Tokenizer を作成します。

from transformers import Wav2Vec2CTCTokenizer

tokenizer = Wav2Vec2CTCTokenizer("./vocab.json", unk_token="[UNK]", pad_token="[PAD]", word_delimiter_token="|")

Feature Extractor の作成

次に、音声データを事前学習モデルに入力できる形に変換するための Feature Extractorを作成します。

from transformers import Wav2Vec2FeatureExtractor

feature_extractor = Wav2Vec2FeatureExtractor(feature_size=1, sampling_rate=16000, padding_value=0.0, do_normalize=True, return_attention_mask=True)

これから使う事前学習モデルはサンプリングレート16000Hzで学習されているので、sampling_rate は 16000 に設定する必要があります。

この後の処理を簡単にするために、Feature Extractor と Tokenizer をまとめた Processor を作成します。

from transformers import Wav2Vec2Processor

processor = Wav2Vec2Processor(feature_extractor=feature_extractor, tokenizer=tokenizer)

データの前処理

common_voice データセットの音声データは16000Hzではないかもしれないので、全部16000Hzにリサンプリングします。

common_voice_train = common_voice_train.cast_column("audio", Audio(sampling_rate=16_000))
common_voice_test = common_voice_test.cast_column("audio", Audio(sampling_rate=16_000))

ここで先ほど tokenizer と feature Extractor をまとめた processor を使って、音声データと書き起こし文をこの後処理しやすい形に変換します。

def prepare_dataset(batch):
    audio = batch["audio"]

    # batched output is "un-batched"
    batch["input_values"] = processor(audio["array"], sampling_rate=audio["sampling_rate"]).input_values[0]
    
    with processor.as_target_processor():
        batch["labels"] = processor(batch["sentence"]).input_ids
    return batch

common_voice_train = common_voice_train.map(prepare_dataset, remove_columns=common_voice_train.column_names, num_proc=4)
common_voice_test = common_voice_test.map(prepare_dataset, remove_columns=common_voice_test.column_names, num_proc=4)

学習と推論

data collector を定義します。これ以降の学習処理は基本的に参考にした記事と同じにしています。

data collector の詳細

import torch

from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Union

@dataclass
class DataCollatorCTCWithPadding:
    """
    Data collator that will dynamically pad the inputs received.
    Args:
        processor (:class:`~transformers.Wav2Vec2Processor`)
            The processor used for proccessing the data.
        padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`True`):
            Select a strategy to pad the returned sequences (according to the model's padding side and padding index)
            among:
            * :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
              sequence if provided).
            * :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the
              maximum acceptable input length for the model if that argument is not provided.
            * :obj:`False` or :obj:`'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of
              different lengths).
        max_length (:obj:`int`, `optional`):
            Maximum length of the ``input_values`` of the returned list and optionally padding length (see above).
        max_length_labels (:obj:`int`, `optional`):
            Maximum length of the ``labels`` returned list and optionally padding length (see above).
        pad_to_multiple_of (:obj:`int`, `optional`):
            If set will pad the sequence to a multiple of the provided value.
            This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >=
            7.5 (Volta).
    """

    processor: Wav2Vec2Processor
    padding: Union[bool, str] = True
    max_length: Optional[int] = None
    max_length_labels: Optional[int] = None
    pad_to_multiple_of: Optional[int] = None
    pad_to_multiple_of_labels: Optional[int] = None

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        # split inputs and labels since they have to be of different lenghts and need
        # different padding methods
        input_features = [{"input_values": feature["input_values"]} for feature in features]
        label_features = [{"input_ids": feature["labels"]} for feature in features]

        batch = self.processor.pad(
            input_features,
            padding=self.padding,
            max_length=self.max_length,
            pad_to_multiple_of=self.pad_to_multiple_of,
            return_tensors="pt",
        )
        with self.processor.as_target_processor():
            labels_batch = self.processor.pad(
                label_features,
                padding=self.padding,
                max_length=self.max_length_labels,
                pad_to_multiple_of=self.pad_to_multiple_of_labels,
                return_tensors="pt",
            )

        # replace padding with -100 to ignore loss correctly
        labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)

        batch["labels"] = labels

        return batch

data_collator = DataCollatorCTCWithPadding(processor=processor, padding=True)

事前学習されたモデルを用意します。 今回は、wav2vec2-large-xlsr-53 という53言語のデータを用いて事前学習された大規模モデルを使用しています。

モデルの詳細

from transformers import Wav2Vec2ForCTC

model = Wav2Vec2ForCTC.from_pretrained(
    "facebook/wav2vec2-large-xlsr-53", 
    attention_dropout=0.1,
    hidden_dropout=0.1,
    feat_proj_dropout=0.0,
    mask_time_prob=0.05,
    layerdrop=0.1,
    ctc_loss_reduction="mean", 
    pad_token_id=processor.tokenizer.pad_token_id,
    vocab_size=len(processor.tokenizer)
)

# Fine Tuning で Feature Extractor が変化しないように設定を入れる
model.freeze_feature_extractor()

学習時の設定値を決めます。

from transformers import TrainingArguments

training_args = TrainingArguments(
  output_dir="./wav2vec2-large-xlsr-japanese-demo",
  group_by_length=True,
  per_device_train_batch_size=16,
  gradient_accumulation_steps=2,
  evaluation_strategy="steps",
  num_train_epochs=30,
  fp16=True,
  save_steps=100,
  eval_steps=100,
  logging_steps=10,
  learning_rate=3e-4,
  warmup_steps=500,
  save_total_limit=2,
)

今まで作成してきたものを使って、Trainer を作成します。

from transformers import Trainer

trainer = Trainer(
    model=model,
    data_collator=data_collator,
    args=training_args,
    train_dataset=common_voice_train,
    eval_dataset=common_voice_test,
    tokenizer=processor.feature_extractor,
)

ここまでお疲れ様でした。
後は実際に学習を実行して、想定したものが正しくできているかを確認するだけです。
私の環境では学習に2時間ほどかかりました。

trainer.train()

学習時の出力例(カナ文の場合)

Step Training Loss Validation Loss
100 11.832900 9.277085
200 4.031800 4.175367
300 3.988200 4.157576
400 3.896900 4.027470
500 3.612100 3.709400
600 1.262500 1.241253
700 0.588300 0.966568
800 0.411600 0.989113
900 0.343900 0.975543
1000 0.280400 1.019831
1100 0.268400 0.964782
1200 0.282000 0.991331

それでは、学習した3種類のモデルについて、それぞれテストデータを入力してどのような文字列が出力されるかを確認してみましょう。

カナ漢字文

pred_str text
先んのはいがはかませんてでした 先生の話は意味が分かりませんでした
先日のにが日とがいあります 先生の机の上に辞書が二冊と雑誌が一冊あります
そなでまどがきます 少し暑いので窓を開けます
わんのにとがあります かばんの中にノートがあります
こ前とてをしました 五年前妻と結婚しました
スはあるかられどとさんれた テロがあるからやめろとさんざんいわれた
先あのけ本ありました 先週姉の結婚パーティーがありました
テもりるのここはの学かにいきををします 友情思いやり協力の心は将来の社会生活に強い影響を及ぼします

認識精度はお世辞にも良いとは言えません。 その理由としては、vocab_dict の数が多く、同じ漢字が出現することが少ないために学習がうまくいかなかったことが挙げられそうです。 また学習データの中に出現しなかった漢字が、テストデータの中に出現していそうですね。

カナ文

pred_str text
センセーノハシワイーミカワカネマセンテシタ センセーノハナシワイミガワカリマセンデシタ
センセーノツクエノウエニチショガニサツトナッシガイッサツアリマス センセーノツクエノウエニジショガニサツトザッシガイッサツアリマス
スコシアツイナデマドオガアキマス スコシアツイノデマドオアケマス
コバンノナカニノートガアリマス カバンノナカニノートガアリマス
ゴネンマエキマトケッコンシマシタ ゴネンゼンサイトケッコンシマシタ
テロガアルカラリャトートサンザンリワレタ テロガアルカラヤメロトサンザンイワレタ
センキューアネノケッコンパーテイガアリマシタ センシューアネノケッコンパーティーガアリマシタ
ユージョーオモアリキョールクノココロワショーライノシャカイセーカツニズヨイエーキョーオオーボシマス ユージョーオモイヤリキョーリョクノココロワショーライノシャカイセーカツニツヨイエーキョーオオヨボシマス
カレノソキロホンルリダッテケキケキナサラナナガオサメタ カレモショキューオホンルイダシテゲキテキナサヨナラガチオオサメタ
ニッポンデワタミタゼガツヨイデス ニッポンデワシュンプーガツヨイデス

一見わかりにくいですが、カナ漢字文と比較して音が綺麗に推測できていると言えそうです。

一方で、うまくいかなかったケースもありました。 例えば最後の文は、かな漢字文だと「日本では春風が強いです」であり、 音声は「シュンプー」ではなく「ハルカゼ」と読んでいると思われます。 形態素解析のミスによってデータにノイズが含まれていることが示唆されます。

ニッポンデワタミタゼガツヨイデス ニッポンデワシュンプーガツヨイデス

ローマ字

pred_str text

sensu-noharashiwai-mikawakaremasenteshita sensu-nohekuta-runazuhiwaimigawakarimasendeshita
sense-notsukuenoueniiichishogan'isatsutoyasshigaissatsuarimasu sensu-nottosukuenouenijizuhogan'izuattosutozasshigaissatsuarimasu
sok-shiwatsunademadogaakemasu sukoshiatsuinodemadooakemasu
gabannonakanio-togaarimasu kabannonakanino-togaarimasu
gonenmaeimatokekonshimashita go-nenzensaitokekkonshimashita
terogaaa-rukararyajho-tosanzaniwareta terogaarukarayamerotosanzaniwareta
senkyu-anenokekkonpa-teigaarimashita senshu-anenokekkonpa-ti-gaarimashita
yu-jo-omoiarikyo-rukunokokorowasho-rainoshakaise-katshunizuyoie-kyo-oooboshimasu yu-jo-omoiyarikyo-ryokunokokorowasho-rainoshakaise-katsunitsuyoie-kyo-ooyoboshimasu
tarnoosokyrohponru-ridat-tsute-kekikekunasarananakakuo-sameta karemoshokyu-ohon'a-ruyu-aidi-azuhittoegekittoekinasayonaragachioosameta
nippondewatabikazegatsugyoidesu nippondewashunpu-gatsuyoidesu

カナ文からさらに読みづらくはありますが、比較的うまくいっていそうです。 ローマ字の中にシングルクオート「'」が含まれていますが、これは例えば「んい」と「に」を区別するための記号です。

おわりに

この記事では、手軽に自分だけの音声認識モデルを構築する方法についてご紹介しました。

日本語データセットを整備して、かな漢字・カナ・ローマ字のそれぞれでどのようなモデルが生成できるのかを確認しました。 カナ漢字文をそのまま使う場合は今回のようなシンプルなやり方だとうまくいきませんでした。 カナ文とローマ字文の場合は、認識精度が比較的高いモデルができました。

ただし、出力される文もカナ・ローマ字のままになってしまうので、 実用にあたっては、後処理でカナ漢字文に変換するなどの対応が必要です。

また、今回は言語モデルを使用せずにシンプルに音声データとテキスト情報の対応をとりました。 事前学習済みの言語モデルを使用することで単語の前後関係や文脈を考慮したより高い性能の音声認識モデルを構築できると考えられます。

大規模データセットで学習された音声認識モデルが無償で誰でも使えるなんて、素晴らしい時代に生まれたと感じます。 自前で音声認識モデルをチューニングできることで、音声認識サービスのトライアンドエラーをどんどん回すことが可能になると期待しています。 今回ご紹介した手法はまだまだ改善の余地があるので、引き続き技術検討を進めたいと思います。

それでは、明日の記事もお楽しみに!

参考にしたもの

https://maelfabien.github.io/machinelearning/wav2vec/ https://huggingface.co/blog/fine-tune-wav2vec2-english/ https://tech.fusic.co.jp/posts/2021-03-30-wav2vec-fune-tune/

© NTT Communications Corporation 2014