進化的モデルマージで日本語がわかるソースコード生成LLMを作ってみる

こんにちは、イノベーションセンターの加藤です。普段はコンピュータビジョンの技術開発やAI/機械学習(ML: Machine Learning)システムの検証に取り組んでいます。一方で、兼務で生成AIチームに参加し、大規模言語モデル(LLM: Large Language Model)に関する技術の調査を行なっています。

この記事では、日本語のコード生成のデータセットが無い条件下で、進化的モデルマージを活用することで日本語とソースコード生成に特化した大規模言語モデル(LLM)を合成した試みについて紹介します。

目次

モデルマージとは

学習済みモデルの重みを組み合わせることで、元のモデルの持つ能力を組み合わせたり精度を向上させたりできる手法です。例えばベースモデルの重み \theta_0に対し、さまざまなハイパーパラメータでファインチューニングした複数の派生モデルの重み \theta_1, \theta_2, ..., \theta_nがあるとき、これらを平均した \theta_S=\sum_{i=1}^n\theta_i/nはmodel soupと呼ばれ元の派生モデルよりも性能が良くなり、かつ入力の変化にも頑健になると言われています1

一方で、ベースとなるLLMの重み \theta_0に対し数学の問題データでファインチューニングした派生モデル \theta_Mと、ソースコード生成のためのデータセットでファインチューニングした派生モデル \theta_Cがあったとします。このとき、派生モデルの重みとベースモデルの重みの差分 \tau_B = \theta_M - \theta_0, \tau_C = \theta_C - \theta_0はtask vectorと呼ばれ、ファインチューニングによって得られた新しい能力の重みと見なすことができ、これの和を取ることで数学とプログラミング両方に強い新たなLLMの重み \theta_{MC}=\theta_0 + \tau_B + \tau_Cを作ることができます。このような差分を線型結合するテクニックはtask arithmetic2と呼ばれ、足すだけでなく引くことで特定の性質(例えば攻撃性など)を取り除くという操作も可能とされています。

これらの手法は複数のモデルの出力を合成するアンサンブル手法とは違い、モデルのサイズが変化しないので推論速度が元のモデルから変化しないという利点があります。非線形操作を多量に含んでいるモデルの重みを単純に足し引きする操作は直観的に機能しなさそうですが、実験的にはうまくいく3ようで言語モデルや画像生成などさまざまな分野でいろいろなマージモデルが作成されています。

また、重みの合成方法についても単純な線形補間にとどまらないさまざまな手法が提案されています。今回の記事では、task vectorからランダムにパラメータを0にして残りのパラメータをスケーリングするDARE4という手法と、マージするパラメータの符号ができるだけ一致するようにいくつかのパラメータを0にするTIES5という手法を組み合わせたものを利用します。いずれの手法も、それぞれのtask vectorには能力の発現に実際に寄与するパラメータは少数であるという仮定をおき、そのようなパラメータがもう一方のtask vectorに乱されないように工夫をするというアプローチによって、より性能の良いマージモデルの作成を可能にしています。

進化的モデルマージとは

モデルのパラメータを合成する際の重み付け係数やDAREなどにおける非ゼロの密度などは、マージ後の精度を左右する重要なハイパーパラメータです。このハイパーパラメータを探索するために進化的アルゴリズムを活用したのがSakana AIの提案した進化的モデルマージです6。この手法は指定されたデータセットの評価指標を最適化するようなパラメータを繰り返し探索するもので、データセット全体に対する評価を何度も実行する必要がある一方で勾配を必要としないという利点があります。Sakana AIが提案した手法はDAREとTIESを併用したマージ手法とCMA-ESと呼ばれる進化的アルゴリズムを組み合わせたもので、モデルマージの際に広く使われているライブラリであるmergekit7にも実装されています。この記事では、mergekitに実装された進化的モデルマージを用いて、日本語LLMモデルとコード生成モデルを合成してみました。

利用したモデル

マージに利用したモデルを紹介します。モデルマージの制約上、全てのモデルは同じベースモデルから派生している必要があるため、今回はMeta社の基盤モデルであるLlama2 (70億パラメータ)をベースとした派生モデルを合成しました。利用モデルの派生関係は以下の図に示しています。

日本語LLM

日本語を理解できるLLMとして今回はELYZA Japanese Llama Instruct 7Bを利用しました。Llama2をベースに約180億トークンの日本語データセットで追加事前学習したモデルであり8、日本語に対する高い処理能力が期待できます。

コード生成特化

マージ用のモデルとして、大部分をソースコードが占める計5000億トークンのデータセットで追加学習し質問応答用にファインチューニングしたCodeLlama Instructと、ソースコードに関する質問応答データセットであるEvolved CodeAlpacaを用いてLlama2をファインチューニングしたLlama-2-7b-evolcodealpacaの2つを用意しました。 また、比較対象としてCodeLlama Instructに日本語データセットを追加事前学習させたELYZA Japanese CodeLlama Instruct 7Bを用意し、マージモデルがどれくらいこの追加事前学習モデルに匹敵するかを調査しました。

MergeKitによる実験

利用モデル

mergekitに実装された進化的モデルマージを用いて、以下の組み合わせでモデルマージを行いました。

  • ELYZA Japanese Llama Instruct 7B + CodeLlama Instruct
  • ELYZA Japanese Llama Instruct 7B + Llama-2-7b-evolcodealpaca

マージモデルと元モデルとの関係は次の図のようになります。

マージ用データセット

ハイパーパラメータ探索の際の評価用データセットは以下の2つを利用しました。

JSQuAD

日本語向けの言語理解ベンチマークJGLUE9の一部であり、質問応答データセットSQuADの日本語版です。日本語の文章読解能力を評価するために使用しました。回答は文章生成で行いますが、正しい形式で答えさせるために2つの例題をプロンプトに加えています。この工夫によって、LLMに「はい、お答えします。」のような無意味な回答をしないように誘導できます。 指標はexact_matchを使いました。これはLLMの出力が想定回答と完全一致した場合に正解とみなすものです。

プロンプトと想定回答の例

入力

以下は、タスクを説明する指示と、文脈のある入力の組み合わせです。要求を適切に満たす応答を書きなさい。
### 指示:
与えられた文脈から、質問に対する答えを抜き出してください。

### 入力:
{例題1}
### 応答:
{回答1}

### 指示:
与えられた文脈から、質問に対する答えを抜き出してください。

### 入力:
{例題2}
### 応答:
{回答2}

### 指示:
与えられた文脈から、質問に対する答えを抜き出してください。

### 入力:
文脈:梅雨(つゆ、ばいう)は、北海道と小笠原諸島を除く日本、朝鮮半島南部、中国の南部から長江流域にかけての沿海部、および台湾など、東アジアの広範囲においてみられる特有の気象現象で、5月から7月にかけて来る曇りや雨の多い期間のこと。雨季の一種である。
質問:日本で梅雨がないのは北海道とどこか。

### 応答:

出力

小笠原諸島

CoNaLa

CoNaLa10はStack Overflowから収集されたプログラミング言語のデータセットであり、質問文と短いPythonプログラムのペアからできています。ソースコード生成の能力を評価するためにこれを利用しました。こちらも2つの例題をプロンプトに加えています。 指標はBLEU11が使われます。

プロンプトと想定回答の例

入力

Answer the following instructions in one line of Python code:
Instruction:
{例題1}
Solution:
{回答1}
Instruction:
{例題2}
Solution:
{回答2}
Instruction:
send a signal `signal.SIGUSR1` to the current process
Solution:

出力

os.kill(os.getpid(), signal.SIGUSR1)

評価用データセット

マージ後のモデルの性能を測るために、上記のデータセットに加えて、以下の日本語データセットとプログラミング言語データセットを利用しました。

JCommonsenseQA

JSQuADと同様に日本語向けの言語理解ベンチマークJGLUEの一部であり、常識を問うタスクCommonsenseQAの日本語版です。日本語の文章読解能力を評価するためにJSQuADと共に使用しました。3つの例題をプロンプトに加え、選択式で答えさせています。指標はaccuracyを使いました。

入出力例

入力

以下は、タスクを説明する指示と、文脈のある入力の組み合わせです。要求を適切に満たす応答を書きなさい。
### 指示:
与えられた選択肢の中から、最適な答えを選んでください。出力は以下から選択してください:
{選択リスト1}
### 入力:
{例題1}

### 応答:
{回答1}

### 指示:
与えられた選択肢の中から、最適な答えを選んでください。出力は以下から選択してください:
{選択リスト2}
### 入力:
{例題2}

### 応答:
{回答2}

### 指示:
与えられた選択肢の中から、最適な答えを選んでください。出力は以下から選択してください:
{選択リスト3}
### 入力:
{例題3}

### 応答:
{回答3}

### 指示:
与えられた選択肢の中から、最適な答えを選んでください。出力は以下から選択してください:
- 掲示板
- パソコン
- マザーボード
- ハードディスク
- まな板

### 入力:
電子機器で使用される最も主要な電子回路基板の事をなんと言う?

### 応答:

出力

マザーボード

HumanEval

HumanEval12は関数名とそのドキュメント(説明文)から内部のPython実装を生成させるデータセットです。一般的な文書生成の評価手法とは異なり、実際に実装を動作させてテストが通るかどうかで生成結果を評価しています。変数名が正解と異なっていてもプログラムが正しければ正解とみなされるため、より実用に向いた評価指標といえます。また、人手で作成されたため、少量ながら高品質なデータであることも特徴です。 指標はpass@kが使われ、LLMが k個の答えを生成し1つでも合っていれば正解とみなします。

入出力例

入力

from typing import List
def has_close_elements(numbers: List[float], threshold: float) -> bool:
  """
  Check if in given list of numbers, are any two numbers closer to each other than given threshold.
  >>> has_close_elements([1.0, 2.0, 3.0], 0.5)
  False
  >>> has_close_elements([1.0, 2.8, 3.0, 4.0, 5.0, 2.0], 0.3)
  True
  """

出力

for idx, elem in enumerate(numbers):
 for idx2, elem2 in enumerate(numbers):
  if idx != idx2:
   distance = abs(elem - elem2)
if distance < threshold:
  return True
return False

テスト

def check(candidate):
  assert candidate([1.0, 2.0, 3.9, 4.0, 5.0, 2.2], 0.3) == True
  assert candidate([1.0, 2.0, 3.9, 4.0, 5.0, 2.2], 0.05) == False
  assert candidate([1.0, 2.0, 5.9, 4.0, 5.0], 0.95) == True
  assert candidate([1.0, 2.0, 5.9, 4.0, 5.0], 0.8) == False
  assert candidate([1.0, 2.0, 3.0, 4.0, 5.0, 2.0], 0.1) == True
  assert candidate([1.1, 2.2, 3.1, 4.1, 5.1], 1.0) == True
  assert candidate([1.1, 2.2, 3.1, 4.1, 5.1], 0.5) == False

JHumanEval

JHumanEval13はHumanEvalのドキュメント部分を日本語化したもので、日本語を理解しかつソースコードを生成する能力を測ることができます。

入力例

from typing import List
def has_close_elements(numbers: List[float], threshold: float) -> bool:
  """リストnumbersの中に、与えられたthresholdより近い2つの数値が存在するか判定する
  >>> has_close_elements([1.0, 2.0, 3.0], 0.5)
  False
  >>> has_close_elements([1.0, 2.8, 3.0, 4.0, 5.0, 2.0], 0.3)
  True
  """

実験結果

マージ前、マージ後、比較用モデルそれぞれの評価結果は次のようになりました。

model JSQuAD (2shot, exact match) JCommonsenseQA (3shot) CoNaLa (2shot, bleu) HumanEval (pass@1) JHumanEval (pass@1)
(J) ELYZA Japanese Llama Instruct 0.6673 0.7194 0.2743 0.1280 0.0976
(C1) CodeLlama Instruct 0.6855 0.6452 0.3582 0.3293 0.2744
(C2) Llama-2-7b-evolcodealpaca 0.6171 0.58 0.2435 0.3354 0.2683
--- --- --- --- --- ---
ELYZA Japanese CodeLlama Instruct 0.6828 0.7373 0.3358 0.3232 0.2378
--- --- --- --- --- ---
J + C1 0.3165 0.5505 0.2236 0.0732 0.0610
J + C2 0.6517 0.7051 0.2412 0.3415 0.2134

作成した2つのマージモデルのうち、Llama-2-7b-evolcodealpacaを用いた方(J + C2)は日本語能力(JSQuAD,JCommonsenseQA)を保ったままコード生成の性能を元の日本語LLMから向上させることができ、追加事前学習を行なったELYZA Japanese CodeLlama Instructとほぼ同等のHumanEval, JHumanEval性能を持たせることができました。

考察

JHumanEvalの回答に必要な日本語能力

結果を見ると、JHumanEvalの性能はJSQuADやJCommonsenseQAよりもHumanEvalの性能に相関しているようです。関数ドキュメント程度の文章量ではあまり読解能力や常識を必要としなくてもうまくコード生成ができるようです。HumanEvalでは網羅できないような、長く複雑な仕様文書からのコード生成能力を測る必要がありそうです。

CodeLlamaのマージ困難性

日本語LLMとCodeLLamaをマージすると(J + C1)全ての性能が大きく劣化してしまいました。DAREの論文でも挙げられている通りCodeLlamaは追加学習量が多く、ベースモデルからパラメータが大きくずれているためか上手くマージできないようです。この問題は進化的アルゴリズムによるハイパーパラメータ探索でも解決が難しそうだということがわかりました。

まとめ

進化的モデルマージを利用して日本語LLMとコード生成LLMを合成し、両方の能力を獲得できるか実験しました。結果として、日本語のタスクとコード生成のタスクの両方で性能の良いモデルは作成できたものの、日本語からソースコードを生成するというタスクは設計が難しそうなこと、ベースラインからの差分が大きな派生モデルはハイパラ探索を活用してもマージ困難であることがわかりました。


  1. Model soups: averaging weights of multiple fine-tuned models improves accuracy without increasing inference time. https://arxiv.org/abs/2203.05482
  2. Editing Models with Task Arithmetic. https://arxiv.org/abs/2212.04089
  3. モデルマージの理論的な背景について研究している論文に次のようなものがあります。 Task Arithmetic in the Tangent Space: Improved Editing of Pre-Trained Models. https://arxiv.org/abs/2305.12827
  4. Language Models are Super Mario: Absorbing Abilities from Homologous Models as a Free Lunch. https://arxiv.org/abs/2311.03099
  5. TIES-Merging: Resolving Interference When Merging Models. https://arxiv.org/abs/2306.01708
  6. https://sakana.ai/evolutionary-model-merge-jp/
  7. https://github.com/arcee-ai/mergekit
  8. https://note.com/elyza/n/na405acaca130
  9. https://techblog.yahoo.co.jp/entry/2022122030379907/
  10. https://conala-corpus.github.io
  11. https://huggingface.co/spaces/evaluate-metric/bleu
  12. Evaluating Large Language Models Trained on Code. https://arxiv.org/abs/2107.03374v2
  13. https://huggingface.co/datasets/kogi-jwu/jhumaneval
© NTT Communications Corporation 2014