Rustで実装する乱数生成のベクトル化

この記事は、 NTT Communications Advent Calendar 2022 22日目の記事です。

はじめに

こんにちは、イノベーションセンターの鈴ヶ嶺(@suzu_3_14159265)です。普段は、クラウド・ハイブリッドクラウド・エッジデバイスなどを利用したAI/MLシステムに関する業務に従事しています。

本日は、Rustでベクトル化された乱数生成器を実装する方法を紹介します。乱数生成器にはPermuted congruential generator(PCG)という高速でシンプルな実装を取り扱います。ベクトル化には1つの命令で複数のデータを適用するSingle Instruction, Multiple Data(SIMD)を活用します。

また、以下のように毎年Rustネタのアドベントカレンダーを書いているのでぜひ見ていただけると嬉しいです!

PCGとは

Permuted congruential generator(PCG)1は、シンプルな実装でメモリ消費も低く高速な乱数生成アルゴリズムです。次式のような線形合同法の出力に対してXorshift2のような排他的論理和とビットシフト操作( x ^= x >> N)を加えることで、線形合同法で見られる偶数と奇数が交互にでるような下位ビットの低いランダム性を改善しています。

 \displaystyle{
X_{n+1} = (A \times X_{n} + B) \mod M
}

また、PCGはTestU013と呼ばれる乱数生成器をテストするツールのBig Crushと呼ばれる最も大きいテストを突破しています。

PCGの活用事例をみると、例えばNumpyの乱数生成器はデフォルトでPCGが採用されています。

The default BitGenerator used by Generator is PCG64. https://numpy.org/doc/stable/reference/random/generator.html

ちなみに、他の乱数生成器との比較は次のようにPCGの公式ページに表としてまとめられており分かりやすいので気になる方は参照してください。

https://www.pcg-random.org/

Rustでは、乱数生成ライブラリの rand_pcgクレートから利用が可能です。次に利用方法のサンプルコードを示します。

use rand::prelude::*;
use rand_pcg::Pcg32;

fn main() {
    let mut rng = Pcg32::from_entropy();
    let ru32: u32 = rng.gen();
    println!("{}", ru32);
}

PCGの実装方法は、Rustのrand 0.8.5のソースコードの一部を抜粋してコメントアウトを用いて次にインラインで説明します。

https://github.com/rust-random/rand/blob/0.8.5/rand_pcg/src/pcg64.rs

const MULTIPLIER: u64 = 6364136223846793005; // 乗数は固定値

pub struct Lcg64Xsh32 {
    state: u64, // 出力の2倍の状態をもつ
    increment: u64, // 任意の奇数
}

pub type Pcg32 = Lcg64Xsh32; // LCG(線形合同法)64bitにXSH(xorshift操作)をするという意味

impl Lcg64Xsh32 {
    // 線形合同法の1ステップ
    fn step(&mut self) {
        self.state = self
            .state
            .wrapping_mul(MULTIPLIER) // 積
            .wrapping_add(self.increment); // 和
    }
}

impl RngCore for Lcg64Xsh32 {
    fn next_u32(&mut self) -> u32 {
        let state = self.state;
        self.step(); // 線形合同法で現在の状態から次の状態を計算

        // xorshift操作のbit数パラメータ
        const ROTATE: u32 = 59; // 64 - 5
        const XSHIFT: u32 = 18; // (5 + 32) / 2
        const SPARE: u32 = 27; // 64 - 32 - 5

        let rot = (state >> ROTATE) as u32;
        let xsh = (((state >> XSHIFT) ^ state) >> SPARE) as u32; // xorshift操作
        xsh.rotate_right(rot) // ビット回転により出力は状態のビット数の半分(32bit)
    }
}

ベクトル化とは

ベクトル化はfor文などで繰り返し1つずつ計算している処理を高速に処理する手法です。今回はその中のSingle Instruction, Multiple Data(SIMD)という技術を活用します。SIMDは1つの命令で複数のデータを処理する方法です。例えば256bitのSIMD命令がある場合は64bit x 4個=256bitのように64bitのデータを同時に4個処理可能です。

Rustでは次のように core::arch::x86_64(https://doc.rust-lang.org/core/arch/x86_64/index.html) 以下の _mm128_**, _mm256_** などのunsafeな関数を利用することでSIMDを利用することが可能になります。 例えば64bitごとに計4個の値を同時に加算するコード例を以下に示します。 今回はIntel AVX2環境で実装します。

#[cfg(any(target_arch = "x86_64"))]
use core::arch::x86_64::*;

fn main() {
    unsafe {
        sample();
    }
}

#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "avx2")]
unsafe fn sample() {
    let a = _mm256_set_epi64x(1, 2, 3, 4);
    let b = _mm256_set_epi64x(5, 6, 7, 8);
    let c = _mm256_add_epi64(a, b);
    println!(" a : {:?}", a);
    println!(" b : {:?}", b);
    println!("a+b: {:?}", c);
}

以下は、出力結果です。

 a : __m256i(4, 3, 2, 1)
 b : __m256i(8, 7, 6, 5)
a+b: __m256i(12, 10, 8, 6)

このように2つの配列 [1, 2, 3, 4], [5, 6, 7, 8] 各要素の加算された結果 [6, 8, 10, 12] が返ってきたことが分かると思います。このように1つの命令で複数のデータを処理します。これをPCGの実装の中の線形合同法のステップ、出力のxorshift操作とビット回転に適用していきます。

その他にも応用的なSIMDの活用方法を知りたい場合はsimdjsonなどの高速JSON Parserの作者で著名なDaniel Lemire先生のgithubのrepoを覗いてみるのをお勧めします。次に説明する実装も、C言語で実装されたsimdpcgSIMDxorshiftを参考にして作成しました。

実装

ここでは、SIMDを用いてベクトル化したPCG実装を示します。基本方針としてユーザが直接利用する関数に関してはsafeな関数として実装していきます。 64bitごとの積を計算する _mm256_mullo_epi64 はavx2にはないため別途実装していることなどに注意してください。 また、portable_simdを利用するためnightlyを利用します。

#![feature(portable_simd)]
#[cfg(target_arch = "x86_64")]
use std::arch::x86_64::*;
use std::simd::{u32x4};

#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
pub struct Avx2Pcg {
    state: __m256i,
    inc: __m256i,
    mul_l: __m256i,
    mul_h: __m256i,
}

const MULTIPLIER: i64 = 6364136223846793005; // 乗数は固定

#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "avx2")]
#[inline]
// 64bitごとの積
unsafe fn _mm256_mullo_epi64(x: __m256i, ml: __m256i, mh: __m256i) -> __m256i {
    let xl = _mm256_and_si256(x, _mm256_set1_epi64x(0x00000000ffffffff));
    let xh = _mm256_srli_epi64(x, 32);
    let hl = _mm256_slli_epi64(_mm256_mul_epu32(xh, ml), 32);
    let lh = _mm256_slli_epi64(_mm256_mul_epu32(xl, mh), 32);
    let ll = _mm256_mul_epu32(xl, ml);
    let ret = _mm256_add_epi64(ll, _mm256_add_epi64(hl, lh));
    return ret;
}

#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "avx2")]
#[inline]
unsafe fn _mm256_rorv_epi32(x: __m256i, r: __m256i) -> __m256i {
    let ret = _mm256_or_si256(
        _mm256_sllv_epi32(x, _mm256_sub_epi32(_mm256_set1_epi32(32), r)),
        _mm256_srlv_epi32(x, r),
    );
    return ret;
}

impl Avx2Pcg {
    #[inline]
    fn next(&mut self) -> __m128i {
        unsafe { self.next_() }
    }

    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    #[target_feature(enable = "avx2")]
    #[inline]
    unsafe fn next_(&mut self) -> __m128i {
        let old_state = self.state;

        // 積和計算
        self.state = _mm256_add_epi64(
            _mm256_mullo_epi64(self.state, self.mul_l, self.mul_h),
            self.inc,
        );

        // xorshift
        let xorshifted = _mm256_srli_epi64(
            _mm256_xor_si256(_mm256_srli_epi64(old_state, 18), old_state),
            27,
        );
        let rot = _mm256_srli_epi64(old_state, 59);

        // ビット回転
        let ret = _mm256_castsi256_si128(_mm256_permutevar8x32_epi32(
            _mm256_rorv_epi32(xorshifted, rot),
            _mm256_set_epi32(7, 7, 7, 7, 6, 4, 2, 0),
        ));

        return ret;
    }

    #[inline]
    pub fn next_u32x4(&mut self) -> [u32; 4] {
        let m128: u32x4 = self.next().into();
        return *m128.as_array();
    }

    #[inline]
    pub fn from_state_inc(state: [i64; 4], inc: [i64; 4]) -> Avx2Pcg {
        unsafe {
            Avx2Pcg {
                state: _mm256_set_epi64x(state[0], state[1], state[2], state[3]),
                inc: _mm256_set_epi64x(inc[0] | 1, inc[1] | 1, inc[2] | 1, inc[3] | 1),  // 奇数のため "| 1" を追加
                mul_l: _mm256_set1_epi64x(MULTIPLIER & 0x00000000ffffffff),  // 乗数は固定
                mul_h: _mm256_set1_epi64x(MULTIPLIER >> 32),  // 乗数は固定
            }
        }
    }

    #[inline]
    pub fn from_entropy() -> Avx2Pcg {
        let mut state_buf: [u8; 32] = [0u8; 32];
        let mut inc_buf: [u8; 32] = [0u8; 32];
        let _ = getrandom::getrandom(&mut state_buf);
        let _ = getrandom::getrandom(&mut inc_buf);
        let state_buf64: [i64; 4] = unsafe { std::mem::transmute(state_buf) };
        let inc_buf64: [i64; 4] = unsafe { std::mem::transmute(inc_buf) };
        Avx2Pcg::from_state_inc(state_buf64, inc_buf64)
    }
}

実験

実際に、32bit x 10000000個の合計40Mbの乱数生成時間でRustのPCG実装のベースラインと今回のベクトル化されたPCG実装の比較します。 実験環境はGoogle Cloudのインスタンスサイズn2-standard-2, Ubuntu20.04LTSです。 Rust versionはrustc 1.68.0-nightly (d0dc9efff 2022-12-18)を利用します。 また、ループアンローリングでループ内に複数の命令数(1, 2, 4)を展開したケースもわけて計測します。 ちなみに、計測ツールにはCriterion.rsを利用して試行回数100回で計測しました。

use criterion::{criterion_group, criterion_main, Criterion};
use rand::prelude::*;
use rand_pcg::Pcg32;
use simd_pcg::Avx2Pcg;

const N: usize = 10000000;

pub fn avx_one(c: &mut Criterion) {
    c.bench_function("avx one", |b| {
        b.iter(|| {
            let mut arr = vec![0u32; N];

            let mut rng1 = Avx2Pcg::from_entropy();

            for i in (0..N).step_by(4) {
                let r1_arr = rng1.next_u32x4();
                arr[i..(i + 4)].copy_from_slice(&r1_arr);
            }
        })
    });
}

pub fn avx_two(c: &mut Criterion) {
    c.bench_function("avx two", |b| {
        b.iter(|| {
            let mut arr = vec![0u32; N];

            let mut rng1 = Avx2Pcg::from_entropy();
            let mut rng2 = Avx2Pcg::from_entropy();

            for i in (0..N).step_by(8) {
                let r1_arr = rng1.next_u32x4();
                let r2_arr = rng2.next_u32x4();
                arr[i..(i + 4)].copy_from_slice(&r1_arr);
                arr[(i + 4)..(i + 8)].copy_from_slice(&r2_arr);
            }
        })
    });
}

pub fn avx_four(c: &mut Criterion) {
    c.bench_function("avx four", |b| {
        b.iter(|| {
            let mut arr = vec![0u32; N];

            let mut rng1 = Avx2Pcg::from_entropy();
            let mut rng2 = Avx2Pcg::from_entropy();
            let mut rng3 = Avx2Pcg::from_entropy();
            let mut rng4 = Avx2Pcg::from_entropy();

            for i in (0..N).step_by(16) {
                let r1_arr = rng1.next_u32x4();
                let r2_arr = rng2.next_u32x4();
                let r3_arr = rng3.next_u32x4();
                let r4_arr = rng4.next_u32x4();
                arr[i..(i + 4)].copy_from_slice(&r1_arr);
                arr[(i + 4)..(i + 8)].copy_from_slice(&r2_arr);
                arr[(i + 8)..(i + 12)].copy_from_slice(&r3_arr);
                arr[(i + 12)..(i + 16)].copy_from_slice(&r4_arr);
            }
        })
    });
}

pub fn baseline_one(c: &mut Criterion) {
    c.bench_function("baseline one", |b| {
        b.iter(|| {
            let mut arr = vec![0u32; N];
            let mut rng1 = Pcg32::from_entropy();
            for i in (0..N).step_by(1) {
                arr[i] = rng1.gen();
            }
        })
    });
}

pub fn baseline_two(c: &mut Criterion) {
    c.bench_function("baseline two", |b| {
        b.iter(|| {
            let mut arr = vec![0u32; N];
            let mut rng1 = Pcg32::from_entropy();
            let mut rng2 = Pcg32::from_entropy();
            for i in (0..N).step_by(2) {
                arr[i] = rng1.gen();
                arr[i + 1] = rng2.gen();
            }
        })
    });
}

pub fn baseline_four(c: &mut Criterion) {
    c.bench_function("baseline four", |b| {
        b.iter(|| {
            let mut arr = vec![0u32; N];
            let mut rng1 = Pcg32::from_entropy();
            let mut rng2 = Pcg32::from_entropy();
            let mut rng3 = Pcg32::from_entropy();
            let mut rng4 = Pcg32::from_entropy();
            for i in (0..N).step_by(4) {
                arr[i] = rng1.gen();
                arr[i + 1] = rng2.gen();
                arr[i + 2] = rng3.gen();
                arr[i + 3] = rng4.gen();
            }
        })
    });
}
ループアンローリング 命令数 baseline avx2(今回の実装)
1 27.601 ms 24.447 ms
2 23.357 ms 19.803 ms
4 22.839 ms 19.526 ms

上の表が試行回数100回の計測時間の平均結果、図が計測時間の分布を示しています。

この結果からベクトル化された今回の実装のavx2は特にループアンローリングの命令数が多い場合baselineよりも有意に高速であることが分かります。

また、こちらの実験コードは以下に置いておきます。

https://github.com/suzusuzu/simd-pcg

次のコマンドで実行可能です。

git clone git@github.com:suzusuzu/simd-pcg.git
cd simd-pcg
rustup run nightly cargo bench

まとめ

本記事では、PCGを対象にSIMDによるベクトル化をRustで実装する方法を紹介しました。また、比較実験の結果からもSIMD化による高速化が有効であることを示しました。

それでは、明日の記事もお楽しみに!

参考


  1. O’Neill, Melissa E. "PCG: A family of simple fast space-efficient statistically good algorithms for random number generation." ACM Transactions on Mathematical Software (2014).
  2. Marsaglia, George. "Xorshift rngs." Journal of Statistical Software 8 (2003): 1-6.
  3. L'ecuyer, Pierre, and Richard Simard. "TestU01: AC library for empirical testing of random number generators." ACM Transactions on Mathematical Software (TOMS) 33.4 (2007): 1-40.
© NTT Communications Corporation 2014