Tenstorrentにおけるfused kernel実装と性能評価

この記事は、NTT docomo Business Advent Calendar 2025 19日目の記事です。

こんにちは、イノベーションセンターの鈴ヶ嶺です。普段はAIアクセラレータの検証に関する業務に従事しています。

本記事では、まずTenstorrentのAIアクセラレータアーキテクチャを紹介し、その特徴について説明します。次に、複数の演算を1つのkernelに統合するfused kernelによる最適化に注目し、標準正規乱数(randn)を例にTenstorrentのアクセラレータにおける具体的な実装方法と性能評価を共有します。その結果、従来の演算の組み合わせの標準正規乱数の実装と比較して、fused kernel実装により約4倍の高速化を確認しました。

Tenstorrentとは

Tenstorrent Inc. は次世代AIアクセラレータを製造する半導体メーカーです。 オープン戦略を掲げており、アクセラレータにはRISC-Vを採用し、ソフトウェアに関してはOSS (https://github.com/tenstorrent) として積極的に公開されています。 2025年12月現在ではDEC、AMD、Apple、Teslaを歴任した半導体業界の著名なJim Keller氏がCEOを務めています。

TenstorrentのAIアクセラレータのアーキテクチャについて紹介します。

引用: https://speakerdeck.com/tenstorrent_japan/tensix-core-akitekutiyajie-shuo?slide=7

アクセラレータはTensix Coreと呼ばれる5つのBaby RISC-V、2つのNetwork-on-Chip(NoC)、SRAMで構成されるものが複数搭載されています。 一般的なハードウェア管理キャッシュを持たない構成となっており、明示的にコア付近のSRAMを操作する分散メモリ型のNear Memory Computing(NMC)な設計です。 5つのRISC-Vコアは独立な動作が可能なMIMD(Multiple Instruction、 Multiple Data)アーキテクチャです。 多くの処理は典型的にはデータ読み出しを行うReader kernel(RISC-V 1)、 計算をするCompute kernel(RISC-V 2、 3、 4)、 データ書き込みを行うWriter kernel(RISC-V 5)に分けて実行されます。 後述する標準正規乱数のfused kernel実装ではデータ読み込みが不要のためCompute、Writer kernelのみの実装となっており、処理に合わせて自由度を高く調整できます。 16x16を基本としてtileベースの演算エンジンを積んでおり、Compute kernelはこのエンジンを呼び出します。 kernel間のデータはCircular Buffer (CB)と呼ばれるSRAM上のFIFOキューでやり取りをします。 ホストとのデータ交換は外側のDRAM(GDDR)を介して行われます。

その他の技術詳細は日本法人のTenstorrent Japanから以下にさまざまな資料が公開されているためご参照ください。

https://speakerdeck.com/tenstorrent_japan

オンチップ計算を活かしたFlash Attention

アクセラレータの特徴として、低コスト化のためにHBM(High Bandwidth Memory)などの高コストなメモリを使わない設計となっています。 そのためできるだけDRAM往復によるオーバーヘッドを避けるために、オンチップのSRAM上で計算する工夫がされます。 ここではLLMのAttention計算の事例を取り上げて、どのようにTenstorrentのAIアクセラレータで効率化されるのかを説明します。

https://github.com/tenstorrent/tt-metal/blob/main/tech_reports/FlashAttention/FlashAttention.md

LLMのAttentionはそのまま計算すると、巨大な中間行列によりHBM、 DRAMへのデータ移動がオーバーヘッドとなることが知られております。 FlashAttention 1 2 は、その課題に対して行列をチャンクに分割し、より高速なSRAM上で計算しデータ移動のオーバーヘッドを削減し、高速化する手法です。

TenstorrentのAIアクセラレータでも、このFlashAttentionを適用可能です。 大容量のSRAMを利用して実装され中間データがDRAMに書き込まれないため高速化されます。 以下の図のようにベースライン実装と比較して平均して20倍高速に動作します。

引用: https://github.com/tenstorrent/tt-metal/blob/main/tech_reports/FlashAttention/images/image3.png

fused kernelの実装と評価

AIアクセラレータの実行は複数のkernelの実行による、中間計算結果のメモリアクセスや起動オーバーヘッドが課題となります。 そこで複数の計算処理を1つのkernelに統合するfused kernelにより性能を向上させる処理がよく用いられます。

例えばLLMのAttentionなどは計算を最適化するために1つのfused kernelとして実装されています。

ttnn.transformer.scaled_dot_product_attention(input_tensor_q: ttnn.Tensor, 
    input_tensor_k: ttnn.Tensor, input_tensor_v: ttnn.Tensor, *,
    attn_mask: ttnn.Tensor = None, is_causal: bool = true, scale: float = None,
    sliding_window_size: int = None, memory_config: ttnn.MemoryConfig = None,
    program_config: SDPAProgramConfig = None,
    compute_kernel_config: ttnn.DeviceComputeKernelConfig = None,
    attention_sink: ttnn.Tensor = None) → ttnn.Tensor

https://docs.tenstorrent.com/tt-metal/latest/ttnn/ttnn/api/ttnn.transformer.scaled_dot_product_attention.html#ttnn.transformer.scaled_dot_product_attention

ここではttnnに実装されていない標準正規乱数を生成するrandnを実装します。 randnは一般的な PyTorchのtorch.randn や Numpyのnp.random.randn などではサポートされています。 標準正規乱数には、Box-Muller法 3 を用います。

実装

新規のOperation追加は、次のように手順で行います。

https://docs.tenstorrent.com/tt-metal/latest/ttnn/ttnn/adding_new_ttnn_operation.html

まず、ホスト側での処理を抜粋すると以下のように実装します。

ttnn/cpp/ttnn/operations/randn/device/randn_device_operation.[cpp|hpp] ではOperationの引数やバリデーションを実装します。

struct RandnDeviceOperation {
    struct operation_attributes_t {
        const ttnn::Shape shape; // テンソルの形状
        DataType dtype;
        Layout layout;
        const MemoryConfig memory_config;
        MeshDevice* device;
        const DeviceComputeKernelConfig compute_kernel_config;
        uint32_t seed; // 乱数seed
    };

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

void RandnDeviceOperation::validate_inputs(
    const operation_attributes_t& operation_attributes, const tensor_args_t& tensor_args) {
    TT_FATAL(
        operation_attributes.dtype == DataType::FLOAT32 || operation_attributes.dtype == DataType::BFLOAT16,
        "Randn: Output tensor must be Float32 or Bfloat16"); // dtypeによるバリデーション
    TT_FATAL(operation_attributes.layout == Layout::TILE, "Randn: Not currently supporting row major layout"); // メモリレイアウトのバリデーション
}

アクセラレータ上のkernel実行の詳細は ttnn/cpp/ttnn/operations/randn/device/randn_program_factory.cpp に記述します。

  1. ユーティリティ関数 tt::tt_metal::split_work_to_cores 4 によるコアごとの処理を均等に分散
  2. CreateCircularBuffer によるCB(FIFOキュー)の作成
  3. CreateKernel によるCompute、 Writer kernelの作成
  4. SetRuntimeArgs kernel実行の引数の設定
// split_work_to_coresにより、それぞれのコアに処理を割り振る
auto [num_cores, all_cores, core_group_1, core_group_2, units_per_core_group_1, units_per_core_group_2] =
    split_work_to_cores(grid, units_to_divide);

// CBの作成(2tile分の出力ができるように確保する)
constexpr uint32_t dst_cb_id = CBIndex::c_0;
CircularBufferConfig cb_output_config =
    CircularBufferConfig(in_out_num_tiles * dtype_tile_size, {{dst_cb_id, out_data_format}})
        .set_page_size(dst_cb_id, dtype_tile_size);
tt_metal::CreateCircularBuffer(program, all_cores, cb_output_config);

// Writer kernelの設定
const std::string kernels_dir_path = "ttnn/cpp/ttnn/operations/randn/device/kernels/";
std::vector<uint32_t> writer_compile_time_args{dst_cb_id};
tt::tt_metal::TensorAccessorArgs(output.buffer()).append_to(writer_compile_time_args);
const std::string writer_file_path = kernels_dir_path + "writer_standard_normal.cpp";
KernelHandle writer_kernel_id = tt_metal::CreateKernel(
    program, writer_file_path, all_cores, WriterDataMovementConfig(writer_compile_time_args));

// Compute kernelの設定
const std::vector<uint32_t> compute_compile_time_args{dst_cb_id};
const std::string compute_file_path = kernels_dir_path + "compute_standard_normal.cpp";
auto [math_fidelity, math_approx_mode, fp32_dest_acc_en, packer_l1_acc, dst_full_sync_en] =
    get_compute_kernel_config_args(device->arch(), operation_attributes.compute_kernel_config);
KernelHandle compute_kernel_id = CreateKernel(
    program,
    compute_file_path,
    all_cores,
    ComputeConfig{
        .math_fidelity = math_fidelity, // 計算の精度 ref: https://speakerdeck.com/tenstorrent_japan/tensix-core-akitekutiyajie-shuo?slide=26
        .fp32_dest_acc_en = true,
        .dst_full_sync_en = dst_full_sync_en,
        .math_approx_mode = math_approx_mode,
        .compile_args = compute_compile_time_args,
        .defines = compute_defines,
    });

// foreach in split_work_to_coresによる割り振り
  // kernel引数(1コアあたりの乱数生成のtile数、出力のアドレス)の設定
  std::vector<uint32_t> compute_runtime_args = {seed, tile_offset, units_per_core};
  SetRuntimeArgs(program, compute_kernel_id, core, compute_runtime_args);
  std::vector<uint32_t> writer_runtime_args = {output.buffer()->address(), tile_offset, units_per_core};
  SetRuntimeArgs(program, writer_kernel_id, core, writer_runtime_args);
// end

ここからはkernelの実装を説明します。kernel内で利用可能なAPIは以下になります。

https://docs.tenstorrent.com/tt-metal/latest/tt-metalium/tt_metal/apis/kernel_apis.html

Compute kernel ttnn/cpp/ttnn/operations/randn/device/kernels/compute_standard_normal.cpp の抜粋を記述します。 tileベースの命令を用いて処理します。 ここで実際にBox-Muller法で標準正規乱数が生成されます。

// Box-Muller法で標準正規乱数 (Z1, Z2) を生成
// Z1 = sqrt(ln(U1) * -2) * cos(U2 * 2pi)
// Z2 = sqrt(ln(U1) * -2) * sin(U2 * 2pi)

// 出力CBの末尾に2tile確保
cb_reserve_back(dst_cb_id, 2);

// タイルレジスタを確保
tile_regs_acquire();

// U1、 U2の一様乱数(0, 1)をレジスタ0, 1に生成
rand_tile(0, flt_min, one_minus);
rand_tile(1, flt_min, one_minus);

// sqrt(ln(U1) * -2)を計算し、レジスタ0に格納
log_tile(0);
mul_unary_tile(0, neg_two);
sqrt_tile(0);

// レジスタ2に2piを詰める
fill_tile_bitcast(2, two_pi);

// U2 * 2piを計算し、レジスタ3, 1に格納
mul_binary_tile(1, 2, 3);
mul_binary_tile(1, 2, 1);

// cos(U2 * 2pi)を計算し、レジスタ3に格納
cos_tile(3);
// sin(U2 * 2pi)を計算し、レジスタ1に格納
sin_tile(1);

// Z1 = sqrt(ln(U1) * -2) * cos(U2 * 2pi)を計算し、レジスタ3に格納
mul_binary_tile(0, 3, 3);
// Z2 = sqrt(ln(U1) * -2) * sin(U2 * 2pi)を計算し、レジスタ1に格納
mul_binary_tile(0, 1, 1);

// 出力dtypeが BFLOAT16 の場合は型変換
#ifdef OUTPUT_DTYPE_BFLOAT16
typecast_tile<0, 5>(3);
typecast_tile<0, 5>(1);
#endif

// レジスタ計算の確定、完了待ち
tile_regs_commit();
tile_regs_wait();

// レジスタ3, 1のZ1、 Z2をCBへ書き込み
pack_tile(3, dst_cb_id);
pack_tile(1, dst_cb_id);

// レジスタ解放
tile_regs_release();

// CBの末尾に2タイル追加したことを通知
cb_push_back(dst_cb_id, 2);

次にWriter kernel ttnn/cpp/ttnn/operations/randn/device/kernels/writer_standard_normal.cpp を抜粋します。 基本的にはCompute kernelからデータを受け取り、そのままNOC経由で書き込みます。

// CBの先頭に2tileがCompute kernelからpushされるまで待つ
cb_wait_front(dst_cb_id, 2);

// CBの読み取りポインタ取得
uint32_t dst_cb_read_base = get_read_ptr(dst_cb_id);
uint32_t dst_cb_read0_ptr = dst_cb_read_base;
uint32_t dst_cb_read1_ptr = dst_cb_read_base + dst_tile_bytes;

// NOCでタイル単位に非同期書き込み
noc_async_write_tile(i, output_addrg, dst_cb_read0_ptr);
noc_async_write_tile(i + 1, output_addrg, dst_cb_read1_ptr);

// 書き込み完了までバリア
noc_async_write_barrier();

// CBから2tile pop
cb_pop_front(dst_cb_id, 2);

最後にC++やPythonから呼び出すための実装を追加します。

ttnn/cpp/ttnn/operations/randn/device/[randn|randn_pybind].[cpp|hpp]

Tensor Randn::invoke(
    const ttnn::Shape& shape,
    MeshDevice& device,
    const DataType dtype,
    const Layout layout,
    const MemoryConfig& memory_config,
    const std::optional<DeviceComputeKernelConfig>& compute_kernel_config,
    uint32_t seed) {
    auto tensor = ttnn::prim::randn(shape, dtype, layout, memory_config, device, compute_kernel_config, seed);
    if (layout != Layout::TILE) {
        tensor = ttnn::to_layout(tensor, layout);
    }
    return tensor;
}

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

void bind_randn_operation(py::module& pymodule) {
    bind_registered_operation(
        pymodule,
        ttnn::randn,
        doc,
        ttnn::pybind_overload_t{
            [](const OperationType& self,
               const ttnn::Shape& shape,
               MeshDevice& device,
               const DataType dtype,
               const Layout layout,
               const MemoryConfig& memory_config,
               const std::optional<DeviceComputeKernelConfig>& compute_kernel_config,
               uint32_t seed) {
                return self(shape, device, dtype, layout, memory_config, compute_kernel_config, seed);
            },
            py::arg("shape"),
            py::arg("device"),
            py::kw_only(),
            py::arg("dtype") = DataType::BFLOAT16,
            py::arg("layout") = Layout::TILE,
            py::arg("memory_config") = ttnn::DRAM_MEMORY_CONFIG,
            py::arg("compute_kernel_config") = std::nullopt,
            py::arg("seed") = 0});
}

今回実装したより詳しい全体コードは、以下のPull Requestを参照ください。

https://github.com/tenstorrent/tt-metal/pull/34508

性能評価

次のスクリプトで実装したttnn.randnと従来のopを組み合わせたttnn.rand + Box-Muller変換の実装と比較します。 補足としてCPUによる実装も計測します。

import math, time, ttnn, torch, numpy as np

def rand_box_muller(shape, *, device, dtype, layout, mem, seed):
    half = (*shape[:-1], shape[-1] // 2)
    u1 = ttnn.rand(half, device=device, dtype=dtype, layout=layout, memory_config=mem, seed=seed + 1234)
    u2 = ttnn.rand(half, device=device, dtype=dtype, layout=layout, memory_config=mem, seed=seed + 4321)
    r = ttnn.sqrt(ttnn.multiply(ttnn.log(u1), -2.0))
    th = ttnn.multiply(u2, 2.0 * math.pi)
    z0 = ttnn.multiply(r, ttnn.cos(th))
    z1 = ttnn.multiply(r, ttnn.sin(th))
    return ttnn.concat([z0, z1], dim=-1)

def fused(shape, *, device, dtype, layout, mem, seed):
    return ttnn.randn(shape, device=device, dtype=dtype, layout=layout, memory_config=mem, seed=seed + 1234)

def torch_randn(shape, *, dtype, seed):
    torch.manual_seed(seed+1234)
    return torch.randn(shape, dtype=dtype)

def bench(name, fn, *, iters, warmup):
    for i in range(warmup): fn(i)
    t0 = time.perf_counter_ns()
    for i in range(iters): fn(i)
    mean_ms = (time.perf_counter_ns() - t0) / 1e6 / iters
    print(f"{name}: {mean_ms:.6f} ms/iter")
    return mean_ms

DEVICE_ID = 0
SHAPE = (1, 1, 1024, 1024)
ITERS, WARMUP = 10000, 1000
LAYOUT, MEM, DTYPE = ttnn.TILE_LAYOUT, ttnn.DRAM_MEMORY_CONFIG, ttnn.float32

device = ttnn.open_device(device_id=DEVICE_ID)

res_rand_box = bench("ttnn.rand + Box-Muller", lambda i: rand_box_muller(SHAPE, device=device, dtype=DTYPE, layout=LAYOUT, mem=MEM, seed=i), iters=ITERS, warmup=WARMUP)
res_randn = bench("ttnn.randn", lambda i: fused(SHAPE, device=device, dtype=DTYPE, layout=LAYOUT, mem=MEM, seed=i), iters=ITERS, warmup=WARMUP)
print(f"Speedup: {res_rand_box / res_randn:.3f}x")
ttnn.close_device(device)

print("\nappendix")
res_torch = bench("torch.randn", lambda i: torch_randn(SHAPE, dtype=torch.float32, seed=i), iters=ITERS, warmup=WARMUP)

4つの Tenstorrent Wormhole™ n300s カードを搭載したTT-LoudBoxサーバで実行した結果が次のようになります。 従来のop組み合わせ(rand×2 + log/sqrt/sin/cos/mul + concat)の実装に比べて、今回fused kernelを実装して約4倍の高速化が達成しました。 ちなみに、CPU(Intel® Xeon® Silver 4309Y)の torch.randn で実行したものと比べるとアクセラレータによる並列実行の恩恵を感じることができると思います。

ttnn.rand + Box-Muller: 0.344376 ms/iter
ttnn.randn: 0.085173 ms/iter
Speedup: 4.043x

appendix
torch.randn: 4.509201 ms/iter

また、出力されたサンプルの分布を可視化しても標準正規分布として問題ないことが次のように確認できました。

import ttnn, matplotlib.pyplot as plt, numpy as np

device = ttnn.open_device(device_id=0)
x = ttnn.randn(
    (1, 1, 1024, 1024),
    device=device,
    dtype=ttnn.float32,
    layout=ttnn.TILE_LAYOUT,
    memory_config=ttnn.DRAM_MEMORY_CONFIG,
    seed=1234,
)

x = ttnn.to_layout(x, ttnn.ROW_MAJOR_LAYOUT)
x = ttnn.from_device(x)
x = ttnn.to_torch(x).cpu().numpy().ravel()

mean = np.mean(x)
var = np.var(x)

plt.figure(figsize=(6, 4))
plt.hist(x, bins=100, density=True, alpha=0.7)
plt.axvline(mean, linewidth=2, label=f"mean = {mean:.6f}")
plt.axvspan(mean - np.sqrt(var), mean + np.sqrt(var), alpha=0.2, label=f"var = {var:.6f}")
plt.title("Histogram of ttnn.randn()")
plt.xlabel("Value")
plt.ylabel("Probability Density")
plt.grid(True)
plt.legend()
plt.tight_layout()
plt.savefig("fig.png")

ttnn.close_device(device)

まとめ

本記事では、TenstorrentのAIアクセラレータアーキテクチャとその特徴を紹介しました。また、fused kernelによる具体的な最適化の実装方法と従来手法と比較して約4倍の高速化を達成する性能評価結果を共有しました。

明日のアドベントカレンダーもお楽しみに。


  1. Dao, Tri. "Flashattention-2: Faster attention with better parallelism and work partitioning." arXiv preprint arXiv:2307.08691 (2023).
  2. Shah, Jay, et al. "Flashattention-3: Fast and accurate attention with asynchrony and low-precision." Advances in Neural Information Processing Systems 37 (2024): 68658-68685.
  3. Box, George E. P. and Mervin E. Muller. “A Note on the Generation of Random Normal Deviates.” Annals of Mathematical Statistics 29 (1958): 610-611.
  4. https://github.com/tenstorrent/tt-metal/blob/main/METALIUM_GUIDE.md#spmd-in-metalium