JetsonでJAXが使えるようになりたい

この記事は NTTコミュニケーションズ Advent Calendar 2021 18日目の記事です。

はじめに

こんにちは、イノベーションセンターの齋藤 暁です。普段はコンピュータビジョンの技術開発やAI/MLシステムの検証に取り組んでいます。今回は、JetsonでJAXが使えるように環境の構築をしていくのですが、時間の関係上ビルドまでたどりつくことができませんでした。ただ、せっかくなので奮闘記として、どのような方法でエラーハンドリングしたかを残しておきたいと思いますので。自身への戒めという側面もありますが、次にJetsonでJAXを使いたい人が、この方法はクリティカルなエラーハンドリングではないんだということを知っていただくことが嬉しいです。

(JAXは勉強したいと思っているので、ビルドがうまく行った際には、情報を更新したいと考えておりますので、適宜のぞいてみていただけるとビルドできた例を見ることができるかもしれないです。)

JAXとは

f:id:NTTCom:20211217122246p:plain
google/jaxから引用

まずJAXについて軽く説明をしたいと思います。 JAXとはAutogradとXLAを組み合わせた機械学習のライブラリです。 特徴的な点としては、JAXがXLAを使ってGPUやTPU上でNumPyをコンパイルして実行できるところです。jitのデコレータによるコンパイルとgradによる自動微分を組み合わせて使用できます。 今後も新しい機能がでることを明言していることから、今後ますます発展するライブラリであると思っています。

モチベーション

現在私は、Jetsonなどのエッジデバイスを用いた映像解析アプリケーションの開発をしています。この開発では、エッジデバイスを含めたモデルの最適な配置についても開発の課題としております。 高速に計算をできれば、よりリアルタイムに近い推論をできる嬉しさがあると考えられます。ただ、JAXは2018年にローンチされたライブラリであるため、エッジデバイスで利用することに適しているのかという議論についてもあまり数がないように思いました。 そのため、今回はJAXをJetsonにそもそもインストールし、使用できるのかという部分から検証したいとモチベーションがあります。

環境構築

※注意: まだJetsonでJAXを使えるようにできておりません。できた際には、この記事を更新したいと思っていますが、下記の方法では私の環境ではビルドできないことを確認しています。

最初は、Jetson Xavier NXで環境構築を試みたのですが、私の環境ではbuildの途中でメモリが溢れてしまいました。もしかしたらSwap領域を増やせば、落ちないかもしれないですが、今回はJetson AGX Xavierでビルドを行いました。(Jetson Nanoでできたという記事がありますが、Jetson Xavier NXでのビルドで落ちていたので、こちらもどこかでやりたいですね)

Jetsonの環境

まず、Jetson環境の構築から始めます。 今回私が使用するJetsonのOSであるL4Tのバージョンは、32.6.1。Jetpackのバージョンは4.6を選択しました。内部ストレージは、32GBしかないため、NvMeのSSD 256GBを増設しました。

bootFromExternalStorage の手順に従ってJetPack4.6をNvMeから起動します。以下に簡単にまとめます。詳しい説明については、githubに載っているのでそちらの参照をお願いいたします。

# Ubuntu18.04 or Ubuntu20.04を入れたx86_64ベースのホストPCを用意します
# git cloneでbootFromExternalStorageをインストールします
git clone https://github.com/jetsonhacks/bootFromExternalStorage.git
cd bootFromExternalStorage
. install_dependencies.sh
. get_jetson_files.sh

# JetsonとホストPCをUSBケーブルでつなぐ
# このスクリプトがうまくいくと、Jetson AGX Xavierが起動する
. flash_jetson_external_storage.sh

# Jetson AGX Xavierに、bootFromExternalStorageをインストールします
# 以下のスクリプトを実行することによってcuda、cudnnなどをインストールできます
. install_jetson_default_packages.sh

JAXの環境構築

.bashrcの一番下に以下を追加します。

sudo vim ~/.bashrc
export PATH=/usr/local/cuda/bin${PATH:+:${PATH}}
export LD_LIBRARY_PATH=/usr/local/cuda/lib64${LD_LIBRARY_PATH:+:${LD_LIBRARY_PATH}}

次に以下のスクリプトを実行してインストールするとビルドが始まります。

source ~/.bashrc
#PPAの追加
sudo add-apt-repository ppa:deadsnakes/ppa
sudo apt update

# 現在(2021.12.15)のJAXは、python3.6を実行環境から外しているため、
# 今後アップデートの際にサポートから外れないように、python3.9を選択しました
sudo apt install python3.9 python3.9-dev
sudo apt install python3-pip
sudo pip3 install virtualenv
virtualenv -p /usr/bin/python3.9 jax
source ./jax/bin/activate
sudo apt-get install python3.9-distutils

#必要なライブラリのインストール
python -m pip install numpy scipy six wheel
sudo apt install g++

#jaxのインストール
git clone https://github.com/google/jax
cd jax
python ./build/build.py

おやおや...? 何かエラーが出てますね長いエラーがでているので、とりあえず怪しい部分を抜粋してみます。

#エラー文
[2,118 / 5,712] Compiling llvm/lib/Target/X86/X86ISelLowering.cpp; 50s local ... (8 actions running)
[2,155 / 5,712] Compiling llvm/lib/Target/X86/X86ISelLowering.cpp; 216s local ... (8 actions running)

ERROR: /home/hoge/.cache/bazel/_bazel_nvidia/a5643b5cc286b9b13a96818003a4a7dd/external/org_tensorflow/tensorflow/python/lib/core/BUILD:49:11: Compiling tensorflow/python/lib/core/bfloat16.cc failed: (Exit 1): gcc failed: error executing command
  (cd /home/hoge/.cache/bazel/_bazel_nvidia/a5643b5cc286b9b13a96818003a4a7dd/execroot/__main__ ...

JetsonのCPUのアーキテクチャはARM64ですが、X86のllvmをコンパイルしていますね。ですので、とりあえずarmのllvm一式をインストールします。

sudo apt-get -y install libllvm-7-ocaml-dev libllvm7 llvm-7 llvm-7-dev llvm-7-doc llvm-7-examples llvm-7-runtime

エラー文は変わってますね、今度は /.cache/bazel のパーミッションがないのか?

#エラー文
ERROR: /home/nvidia/.cache/bazel/_bazel_nvidia/a5643b5cc286b9b13a96818003a4a7dd/external/llvm-project/llvm/BUILD.bazel:2068:11: Compiling llvm/lib/Passes/PassBuilder.cpp failed: (Exit 4): gcc failed: error executing command

とりあえず、パーミッションをこのコマンドで変えてみます。

sudo chmod -R u+rw ~/.cache/bazel/

エラーが止まらない..。

次にできることとして、JetPackを使うと感じることが、ライブラリのバージョンによって動く動かないがあるので、bazelのバージョンによる問題ではないかと仮定しました。 現在のJAXのbranchのbazelのバージョンは、4.2.1です。しかし、bazelのpreinstallができるバージョンを見ているとbazel4.2.1は、JetPack4.6.1に対応していないのかなと思われます。そのため、今回はJAXの最新バージョンを使うのではなく、bazel3.7.2がビルドされるバージョンを使用してみます。

#bazel3.7.2 且つarmに対応しているtagの選択
git checkout -b jax-v0.2.17
# .cache/bazel/ のパーミッション系の問題が起きないように
sudo chmod -R u+rw ~/.cache/bazel/
python ./build/build.py

うーん。さっきとエラーが変わっていないような気がします。

numpyのバージョンという線もあるので、一応numpy1.21.4から下げてビルドをしてみましたが、エラーは変わらず..。

#エラー文
ERROR: /home/nvidia/.cache/bazel/_bazel_nvidia/f136147ae544c503f5fd3870723c0471/external/org_tensorflow/tensorflow/compiler/xla/service/cpu/BUILD:393:11: C++ compilation of rule '@org_tensorflow//tensorflow/compiler/xla/service/cpu:ir_emitter' failed (Exit 4): gcc failed: error executing command

f:id:NTTCom:20211217132352p:plain

ここで、タイムアップとなってしまいました。他にも細々とやったのですが、主にやったことが以上となっています。

所感

JetsonでJAXの環境構築にチャレンジしたのですが、冬休みの課題となりました。JAXで遊ぶことは今回達成できませんでしたが、中で何がコンパイルされているのか見ることはできたので、良い勉強になったかと思われます。また、この方法でやれば良いのではという情報を共有してくださる方がおりましたら str.saito@ntt.com まで教えていただけると嬉しいです。

おわりに

今回は、JAXをJetson AGX Xavierで環境構築をしたかったのですが、まだこれをすれば確実にビルドできる!という方法がなく1週間ほどでは難しかったです。そのため、とりあえずビルドをすることが今後の課題となりそうです。 そして今後は、JAXを使って物体検出のモデルを書いてみたいと思っているので、その際にベンチマークを測れればいいなと思います。(年末年始にチャレンジしてみたい) それでは、明日の記事もお楽しみに!

参考

@software{jax2018github,
  author = {James Bradbury and Roy Frostig and Peter Hawkins and Matthew James Johnson and Chris Leary and Dougal Maclaurin and George Necula and Adam Paszke and Jake Vander{P}las and Skye Wanderman-{M}ilne and Qiao Zhang},
  title = {{JAX}: composable transformations of {P}ython+{N}um{P}y programs},
  url = {http://github.com/google/jax},
  version = {0.2.5},
  year = {2018},
}
© NTT Communications Corporation All Rights Reserved.