最新のPyTorchで軽量OCRモデルPARSeqをTensorRT化する

こんにちは。イノベーションセンターの加藤です。普段はコンピュータビジョンの技術開発やAIシステムの検証に取り組んでいます。 今回は最新版のPyTorchを使って軽量なTransformerベースOCRモデルであるPARSeq(Permuted Autoregressive Sequence)をTensorRTモデルに変換して高速化した取り組みについて紹介します。

PARSeqとは

PARSeq1はVision Transformer(ViT)を特徴抽出器として用いる文字認識モデルであり、以下の画像のような文章生成の形をとっています。 このような文章生成モデルでは、まず画像をトークンに分割したものをTransformer Encoderで特徴抽出し、これをもとにTransformer Decoderで次の文字トークンの予測を繰り返します。PARSeqの場合は文字トークンの予測方法にオプションがあり、以前の予測を参照しながら1文字ずつ予測するもの(Autoregressive)、一度に全部の文字を予測するもの(Non-autoregressive)、一度予測した文字を入力し直して洗練するもの(Iterative refinement)の三通りのデコード戦略があります。

PARSeqの特徴はTransformerベースでありながら非常に軽量である点です。 Encoder部分は一般的なViTと同様に12層のTransformerレイヤーで構成されていますが、Decoder部分はたった1層しかなく、 一般的なVision Language Modelが数十億のパラメータを抱えている一方でPARSeqは数千万パラメータに留まっています。

PARSeqのTensorRT化

このPARSeqモデルをさらに高速化するために、今回はTensorRTモデルに変換します。 TensorRT2は、NVIDIAが提供しているディープラーニングモデルの推論を高速化するためのツールで、さまざまなAIフレームワークが対応している共通フォーマットのONNX3からの変換や、PyTorchモデルからの直接変換が可能です。 実はNVIDIAが公式ブログでPARSeqをTensorRT化する記事を公開している4のですが、 PARSeqやその依存先のPyTorchのバージョンが古くそのままでは動作しないため、本稿では最新版(PyTorch 2.10, PARSeq 2024年2月版)を使ったTensorRT化の流れを紹介します。

PyTorch Lightningによるモデル変換

PARSeqはPyTorchによって実装されたモデルをPyTorch-Lightningで制御しており、ONNXやTensorRTへの変換はPyTorch-Lightningが提供する関数を利用できます。 NVIDIAのブログでもto_onnx()を利用して一度ONNX化したのち、trtexecと呼ばれるツールを使ってONNXからTensorRTへ変換しています。 今回はto_tensorrt()を利用して、モデルを直接TensorRTに変換してみます。

import torch

parseq = torch.hub.load('baudm/parseq', 'parseq', pretrained=True).eval()
parseq.model.refine_iters = 0  # Iterative refinementを無効化
parseq.model.decode_ar = False  # Non-autoregressive mode

output_path = "engine.pt2"
img = torch.randn(1,3,32,128)
parseq.to_tensorrt(output_path, img, ir="dynamo")

これで無事TensorRTモデルengine.pt2に変換できました。このモデルは以下のように呼び出すことができます。

import torch
import torch_tensorrt  # <- 必須

parseq = torch.export.load("engine.pt2").module()
img = torch.randn(1,3,32,128).cuda()
parseq(img)
# torch.Size([1, 26, 95])  26は一度に推測可能な文字数、95は対応文字種

AutoregressiveとIterative refinementがTensorRT化できない問題

しかしながら、この方法ではAutoregressive(decode_ar=True)またはIterative refinement(refine_iters>0)に対応したモデルを作ろうとするとエラーになってしまいます。 論文ではNon-autoregressiveよりAutoregressiveの方が高精度5とされており、またIterative refinementも1回適用するだけでそれなりに精度が向上するため、ぜひこれらのモードもTensorRTで活用したいです。 そこでPARSeqの実装を改造しTensorRT化に挑戦しました。

Autoregressive modeのTensorRT化

まず先ほどと同じ方法ではどこで落ちるかをみてみます。

import torch

parseq = torch.hub.load('baudm/parseq', 'parseq', pretrained=True).eval()
parseq.model.refine_iters = 0
parseq.model.decode_ar = True  # AR mode

output_path = "engine.pt2"
img = torch.randn(1,3,32,128)
parseq.to_tensorrt(output_path, img, ir="dynamo")

表示されるエラーは以下のとおりです。

  File "/root/.cache/torch/hub/baudm_parseq_main/strhub/models/parseq/model.py", line 144, in forward
    if testing and (tgt_in == tokenizer.eos_id).any(dim=-1).all():
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
...
torch.fx.experimental.symbolic_shapes.GuardOnDataDependentSymNode: Could not guard on data-dependent expression Eq(u0, 1) (unhinted: Eq(u0, 1)).  (Size-like symbols: none)

これは「文章の終了を示すEOSトークンが出たら生成を停止する」処理の部分であり、どうもif文による分岐はTensorRTと相性が悪いようです。 しかしこれは1文字生成を繰り返すAutoregressive modeでは必須の処理であるため、1文字生成する実装のみをTensorRT化し、繰り返し部分はモデルの外側でやるように変えてみます。

import pytorch_lightning as pl                                                                            
from torch import Tensor                                                                                  
from typing import Optional                                                                               
                                                                                                          
class PARSeqEncoder(pl.LightningModule):                                                                  
    def __init__(self, model):                                                                            
        super().__init__()                                                                                
        self.encoder = model.encoder                                                                      
    def forward(self, images: Tensor) -> Tensor:                                                          
        memory = self.encoder(images)                                                                     
        return memory 

class PARSeqDecoder(pl.LightningModule):                                                                 
    def __init__(self, tokenizer, model):                                                                
        super().__init__()                                                                               
        self.tokenizer = tokenizer                                                                       
        self.max_label_length = model.max_label_length                                                   
        self.text_embed = model.text_embed                                                               
        self.pos_queries = model.pos_queries                                                             
        self.decoder = model.decoder
        self.head = model.head                                                                           
    def forward(self, memory: Tensor, input_ids: Tensor) -> Tensor:                                      
        B, S = input_ids.size(0), input_ids.size(1)                                                      
        null_ctx = self.text_embed(input_ids[:, :1])                                                     
        tgt_emb = self.pos_queries[:, :S-1] + self.text_embed(input_ids[:, 1:])                          
        tgt_emb = torch.cat([null_ctx, tgt_emb], dim=1)                                                  
        tgt_query = self.pos_queries[:, S-1:S].expand(B, -1, -1)
        tgt_mask = torch.triu(torch.ones((S, S), dtype=torch.bool), 1).to(tgt_emb.device)                
        decoder_outputs = self.decoder(tgt_query, tgt_emb, memory, content_mask=tgt_mask)
        return self.head(decoder_outputs) 

ここでエンコーダとデコーダが切り離されています。これはエンコードを一度実行したのち、デコードをEOSトークンが出るまで繰り返す必要があるためです。 推論は以下のようになります。

img_transform = T.Compose([
    T.Resize((32, 128), T.InterpolationMode.BICUBIC),
    T.ToTensor(),
    T.Normalize(0.5, 0.5),
])

_parseq = torch.hub.load('baudm/parseq', 'parseq', pretrained=True).eval()
bos_id = _parseq.tokenizer.bos_id
pad_id = _parseq.tokenizer.pad_id
eos_id = _parseq.tokenizer.eos_id
parseq_encoder = PARSeqEncoder(_parseq.model)
parseq_decoder = PARSeqDecoder(_parseq.tokenizer, _parseq.model)

img = Image.open("world.png").convert("RGB")
img = img_transform(img).unsqueeze(0)

with torch.no_grad():
    num_steps = _parseq.model.max_label_length + 1
    input_ids = torch.full((1, num_steps), pad_id, dtype=torch.long)
    input_ids[:,0] = bos_id
    memory = parseq_encoder(img)
    preds = []
    for i in range(num_steps-1):
        j = i + 1
        logit = parseq_decoder(memory, input_ids[:, :j])
        preds.append(logit.softmax(-1))
        input_ids[:, j:j+1] = logit.argmax(-1)
        if (input_ids == eos_id).any(dim=-1).all():
            break
    label, confidence = _parseq.tokenizer.decode(torch.cat(preds, dim=1))
print(f"AR result: {label[0]}")

そして変換は次のように行います。input_idsの長さは伸び縮みするため最短・最長を指定しておく必要があります。

parseq_encoder.to_tensorrt("encoder.pt2", img, ir="dynamo")
decoder_input_ids = torch_tensorrt.Input(
    min_shape=[1, 1],
    opt_shape=[1, num_steps],
    max_shape=[1, num_steps],
    dtype=torch.int64)
encoder_outputs = torch_tensorrt.Input(
    min_shape=[1, 128, 384],
    opt_shape=[1, 128, 384],
    max_shape=[1, 128, 384],
    dtype=torch.float32)
parseq_decoder.to_tensorrt("decoder.pt2", (encoder_outputs, decoder_input_ids), ir="dynamo")

TorchDynamoの機嫌をとる

しかしながら、なぜかこれはデコーダ(PARSeqDecoder)の変換に失敗します。本来入力するinput_idsのトークン長は1以上あれば動作するはずですが、以下のように3以上に限定しなさいというエラーが出てきます。

 - Not all values of _1 = L['input_ids'].size()[1] in the specified range _1 <= 26 satisfy the generated guard 3 <= L['input_ids'].size()[1] and L['input_ids'].size()[1] <= 26
Suggested fixes:
  _1 = Dim('_1', min=3, max=26)

これはTensorRT化よりも前の、TorchDynamoがソースコードを解析するときに発生しているエラーなのですが、どこが原因なのかをTorchDynamoを使って探ってみます。

from torch_tensorrt.dynamo.utils import get_torch_inputs, to_torch_device
from torch_tensorrt.dynamo._tracer import get_dynamic_shapes_args
from torch.export import Dim, export, draft_export

arg_inputs = (encoder_outputs, decoder_input_ids)
parseq_decoder.to("cuda")

device = to_torch_device("cuda")
torch_arg_inputs = get_torch_inputs(arg_inputs, device)
dynamic_shapes = get_dynamic_shapes_args(parseq_decoder, arg_inputs)
ep = draft_export(  # エラーが起きても最後まで解析させることで全てのエラーを収集する
    parseq_decoder,
    tuple(torch_arg_inputs),
    dynamic_shapes=dynamic_shapes,
)
print(ep._report)

すると以下のような警告が確認できます。

###################################################################################################
WARNING: 2 issue(s) found during export, and it was not able to soundly produce a graph.
Please follow the instructions to fix the errors.
###################################################################################################

1. Guard Added.
    A guard was added during tracing, which might've resulted in some incorrect
    tracing or constraint violation error.
    Specifically, this guard was added: Ne(s70 - 1, 1), where {'s70': "L['input_ids'].size()[1]"}.
    This occurred at the following stacktrace:
        File /opt/venv/lib/python3.12/site-packages/torch/nn/modules/module.py, lineno 1776, in _wrapped_call_impl
        File /opt/venv/lib/python3.12/site-packages/torch/nn/modules/module.py, lineno 1787, in _call_impl
        File /workspace/src/ar_deploy_decoder.py, lineno 31, in forward
            tgt_emb = self.pos_queries[:, :S-1] + self.text_embed(input_ids[:, 1:]):

        Locals:
            self: [None]
            S: ['s70']
            input_ids: ['Tensor(shape: torch.Size([1, s70]), stride: (s70, 1), storage_offset: 0)']

        Symbols:
           s70: L['input_ids'].size()[1]

    And the following framework stacktrace:
        File /opt/venv/lib/python3.12/site-packages/torch/_prims_common/__init__.py, lineno 404, in is_contiguous_for_memory_format
        File /opt/venv/lib/python3.12/site-packages/torch/_prims_common/__init__.py, lineno 317, in is_contiguous
        File /opt/venv/lib/python3.12/site-packages/torch/_prims_common/__init__.py, lineno 277, in check_contiguous_sizes_strides
            if maybe_guard_or_false(x == 1):
(以下省略)

テンソルをS-1の長さにスライスするところでS-1 != 1という制約がDynamoによって導入されています。 どうやらスライスをした時長さ1になりうる可変長テンソルは問題があるようです。(おそらく0/1-specialization6と呼ばれる処理と関係があるのですが、なぜこうなっているのかはよく分かりません...)

そこでスライスを行わない形に実装を直しておきます。

class PARSeqDecoder(pl.LightningModule):
    def __init__(self, tokenizer, model):
        super().__init__()
        self.tokenizer = tokenizer
        self.max_label_length = model.max_label_length
        self.text_embed = model.text_embed
        # self.pos_queries = model.pos_queries                                                             
        self.prefixed_pos_queries = torch.nn.Parameter(torch.cat([torch.zeros_like(model.pos_queries)[:,:1], model.pos_queries], dim=1))
        self.decoder = model.decoder
        self.head = model.head
    def forward(self, memory: Tensor, input_ids: Tensor) -> Tensor:
        B, S = input_ids.size(0), input_ids.size(1)
        tgt_emb = self.prefixed_pos_queries[:, :S] + self.text_embed(input_ids)
        tgt_query = self.prefixed_pos_queries[:, S:S+1].expand(B, -1, -1)
        tgt_mask = torch.triu(torch.ones((S, S), dtype=torch.bool), 1).to(tgt_emb.device)
        decoder_outputs = self.decoder(tgt_query, tgt_emb, memory, content_mask=tgt_mask)
        return self.head(decoder_outputs)

これで無事変換が通るようになりました。

Iterative refinementのTensorRT化

次にIterative refinementを行うデコーダのTensorRT化を行います。 元のPARSeq実装からrefinementを行う箇所を切り出しPyTorch Lightningでラップします。

class PARSeqRefiner(pl.LightningModule):
    def __init__(self, tokenizer, model):
        super().__init__()
        self.tokenizer = tokenizer
        self.max_label_length = model.max_label_length
        self.text_embed = model.text_embed
        self.prefixed_pos_queries = torch.nn.Parameter(torch.cat([torch.zeros_like(model.pos_queries)[:,:1], model.pos_queries], dim=1))
        self.pos_queries = model.pos_queries
        self.decoder = model.decoder
        self.head = model.head
    def forward(self, memory: Tensor, input_ids: Tensor) -> Tensor:
        B, S = input_ids.size(0), input_ids.size(1)
        tgt_emb = self.prefixed_pos_queries[:, :S] + self.text_embed(input_ids)
        tgt_query = self.pos_queries
        tgt_mask = torch.triu(torch.ones((S, S), dtype=torch.bool), 1).to(tgt_emb.device)
        tgt_mask[torch.triu(torch.ones((S, S), dtype=torch.bool, device=tgt_emb.device), 2)] = 0
        tgt_padding_mask = (input_ids == self.tokenizer.eos_id).int().cumsum(-1) > 0
        decoder_outputs = self.decoder(tgt_query, tgt_emb, memory, query_mask=tgt_mask, content_mask=tgt_mask, content_key_padding_mask=tgt_padding_mask)
        return self.head(decoder_outputs)

refiner_input_ids = torch_tensorrt.Input(                                                                 
    min_shape=[1, num_steps],                                                                             
    opt_shape=[1, num_steps],
    max_shape=[1, num_steps],
    dtype=torch.int64)
print("==== export refiner ====")
parseq_refiner.to_tensorrt("refiner.pt2", (encoder_outputs, refiner_input_ids), ir="dynamo")

こちらは入力トークンが伸び縮みしないのもあり素直に変換できました。

評価

最後にTensorRT化によってどれくらい速くなったかをみてみます。 OCRのベンチマークであるIIIT-5Kに対してさまざまな設定で推論し、1枚あたりのレイテンシをH200 GPU 1台で計測しました。

結果は次の図のようになりました。

例えばAutoregressive(AR)モード・iterative refinement無しではTensorRT変換によって2.58倍の高速化、 Non-Autoregressive(NAR)モードでは3.07倍の高速化を達成しました。 グラフの傾きからiterative refinementも軽量になっていることが分かります。

まとめ

今回の実験では軽量で高性能なOCRモデルであるPARSeqを最新の環境でTensorRT化してみました。 その際、文章生成などでよく用いられるデコーダは入力サイズが動的に変化するため変換に一癖あり、ライブラリが処理しやすいようなプログラムに書き換える必要があることを紹介しました。