みずりゅの自由帳

主に参加したイベントやソフトウェア技術/開発について記録しています

テンソルライブラリNxをちょっと試してみた

先日、Elixirのテンソルライブラリ「Nx」が公開されました。1

存在は知ってはいたのですが、ちょっとゴタゴタしていて手を出せていませんでした。
しかし、2021/2/25に参加した 「NervesJP #15 Nxを触ってみる回」での話をきいて、「触ってみよう」って気持ちが高まったので、少しだけ触ってみました。

ちなみに、Nerves JPについては、こちらを参照。 nerves-jp.connpass.com

※2021.02.26追記:「NervesJP #15 Nxを触ってみる回」自体について、若干追記しました。

Nxとは

冒頭でも触れましたが、NxはElixir製のテンソル用ライブラリです。
「Numerical Elixir」からNxとつけられたのでしょうか。

github.com

GitHubelixir-nxのプロジェクトには、「Nx」と「EXLA」の2つのライブラリがあります。 それぞれの説明は以下の通り。

  • Nx: Nx is a multi-dimensional tensors library for Elixir with multi-staged compilation to the CPU/GPU.
  • EXLA: Elixir client for Google's XLA (Accelerated Linear Algebra).

とりあえず、Nxがテンソル用ライブラリ、EXLAがElixir用のXLA(線形代数の演算に特化したコンパイラ2 あたりでしょうか。
この辺りは、おいおい知識のアップデートをしていきます。

Nxについては、作者のJosé Valim氏のデモ動画があります。

www.youtube.com

実際にNxを動かしてみる

HelloWorld的な位置付けとして、ひとまずNxライブラリを動かしてみることだけやってみます。

Nxの「Installation」と「Examples」に書いてあることをやります。

https://github.com/elixir-nx/nx/tree/main/nx#installation

https://github.com/elixir-nx/nx/tree/main/nx#examples

自分の環境面
Nxの導入

Nxを利用するためには、Nxのライブラリをインストールする必要があります。
説明では、mixで新しいプロジェクトを作成して、mix.exsでNxライブラリの設定を行うように書かれています。これに従って、実施していきます。

1: 任意のディレクトリ上で、mix newで新規プロジェクトを作成。

$ mix new my_app

結果、配下にプロジェクトのディレクトリができるので、my_appディレクトリへ移動する。

* creating README.md
* creating .formatter.exs
* creating .gitignore
* creating mix.exs
* creating lib
* creating lib/my_app.ex
defmodule MyApp.MixProject do
* creating test
* creating test/test_helper.exs
* creating test/my_app_test.exs

Your Mix project was created successfully.
You can use "mix" to compile it, test it, and more:

    cd my_app
    mix test

Run "mix help" for more commands.
$ cd my_app

2: mix.exsにNxのライブラリを追加する。

利用するバージョンは今後どんどん変わっていくと思いますが、現在のReadmeでは「main」ブランチから取得するように設定されています。

$ vim mix.exs
  defp deps do
    [
      # {:dep_from_hexpm, "~> 0.3.0"},
      # {:dep_from_git, git: "https://github.com/elixir-lang/my_dep.git", tag: "0.1.0"}
      {:nx, "~> 0.1.0-dev", github: "elixir-nx/nx", branch: "main", sparse: "nx"}   #<- この一文を追記
    ]
  end

Nxの追記をセーブしたら、mix deps.getを実行して、Nxライブラリをインストールしましょう。

$ mix deps.get 

実行結果がこちら。
私はasdfでElixirをインストールしているので、asdfディレクトリ配下にインストールされています。

* Getting nx (https://github.com/elixir-nx/nx.git - origin/main)
remote: Enumerating objects: 231, done.
remote: Counting objects: 100% (231/231), done.
remote: Compressing objects: 100% (150/150), done.
remote: Total 6921 (delta 95), reused 179 (delta 69), pack-reused 6690
Receiving objects: 100% (6921/6921), 1.70 MiB | 1.21 MiB/s, done.
Resolving deltas: 100% (4569/4569), done.
==> nx
Could not find Hex, which is needed to build dependency :ex_doc
Shall I install Hex? (if running non-interactively, use "mix local.hex --force") [Yn] Y
* creating /Users/<ユーザ名>/.asdf/installs/elixir/1.11.3-otp-23/.mix/archives/hex-0.21.1
$

インストールがうまくいっているか、mix testで確認しときます。

$mix test
==> nx
Compiling 17 files (.ex)
Generated nx app
==> my_app
Compiling 1 file (.ex)
Generated my_app app
..

Finished in 0.05 seconds
1 doctest, 1 test, 0 failures

Randomized with seed 34230
$
Nxの動作確認

それではNxライブラリの動作確認に進みます。 iex内で作成したプロジェクトを読み込ませるので、iex -S mixでiexを起動します。

$ iex -S mix

こんな感じで起動しました。

Erlang/OTP 23 [erts-11.1.7] [source] [64-bit] [smp:4:4] [ds:4:4:10] [async-threads:1] [hipe]

==> nx
Compiling 17 files (.ex)
Generated nx app
==> my_app
Compiling 1 file (.ex)
Generated my_app app
Interactive Elixir (1.11.3) - press Ctrl+C to exit (type h() ENTER for help)
iex(1)>

Examplesの例の通りに入力。

iex(1)> t = Nx.tensor([[1, 2], [3, 4]])
#Nx.Tensor<
  s64[2][2]
  [
    [1, 2],
    [3, 4]
  ]
>
iex(2)> Nx.shape(t)
{2, 2}

2x2(2行2列)となって出力されてますね。
ちょっとデータを修正。

「3行2列」と「2行3列」を試してみる

iex(3)> t = Nx.tensor([ [1, 2], [3, 4], [5,6] ])
#Nx.Tensor<
  s64[3][2]
  [
    [1, 2],
    [3, 4],
    [5, 6]
  ]
>
iex(4)> Nx.shape(t)
{3, 2}
iex(5)> t = Nx.tensor([ [1, 2, 3], [4, 5, 6] ])
#Nx.Tensor<
  s64[2][3]
  [
    [1, 2, 3],
    [4, 5, 6]
  ]
>
iex(6)> Nx.shape(t)
{2, 3}

続いて、ソフトマックス関数(Softmax function)3の例。

ソフトマックス関数の式はこれ。

 \displaystyle
  y_i = \frac{e^{z_i}}{\sum_{j=1}^{N} e^{z_j}}

ソフトマックス関数を「Nx.divide」「Nx.exp」「Nx.sum」を利用して実施。

iex(7)> t = Nx.tensor([[1, 2], [3, 4]])
#Nx.Tensor<
  s64[2][2]
  [
    [1, 2],
    [3, 4]
  ]
>
iex(8)> Nx.divide(Nx.exp(t), Nx.sum(Nx.exp(t)))
#Nx.Tensor<
  f64[2][2]
  [
    [0.03205860328008499, 0.08714431874203257],
    [0.23688281808991013, 0.6439142598879722]
  ]
>

全部を合計すると、1になった。

iex(9)> 0.03205860328008499 + 0.08714431874203257 + 0.23688281808991013 + 0.6439142598879722
1.0

ちなみに、この例はSoftmax functionのExamplesの「Here is an example of Elixir code」としても載っていました。

https://en.wikipedia.org/wiki/Softmax_function

参考情報

既に日本のアルケミスト(=Elixir使い)達によって、Nxの実行結果の情報がQiitaに載っています。 その一部をご紹介。

今後はこちらも参考にしつつ、自分もNxについての知見を向上していきたいです。

余談:「NervesJP #15 Nxを触ってみる回」の話

冒頭で挙げた「NervesJP #15 Nxを触ってみる回」では、主にあんちぽ氏こと栗山健太郎さん(@kentaro)と、piacereさん(@piacere_ex)がお話をされていました。

栗山さんは、MNISTの手書き数字画像分類をされたお話をしてくださいました。

資料のリンクも貼っておきます。

speakerdeck.com

piacereさんについては、Elixirとマシンラーニングについての話を始めてとして、所属されているEDI (Elixir Digitalization Implementors)関連のお話など、多様な話題を提供してくださいました。

なお、Twitterのリンク先にも記載されていますが、EDIの中ではOSSとしていくつかのElixirライブラリを公開しようとしています。興味がある方は、ひとまずEDIに参加されてみてはいかがでしょうかね。
connpassでfukuoka.exに所属するか、Twitterハッシュタグ「 #ElixirDI 」で検索してみると良いでしょう。

fukuokaex.connpass.com

余談:Nxのマスコットについて

Nxのマスコットとして描かれている動物。

一見するとリスのようですが、これはNumbatMyrmecobius fasciatus)です。
和名だと、「フクロアリクイ」ですね。

マスコットについてはNxのReadmeにも、以下のように書かれています。

github.com

Nx's mascot is the Numbat, a marsupial native to southern Australia.

Nxのマスコットは、オーストラリア南部原産の有袋類Numbatです。

Unfortunately the Numbat are endangered and it is estimated to be fewer than 1000 left. If you enjoy this project, consider donating to Numbat conservation efforts, such as Project Numbat and Australian Wildlife Conservancy.

残念ながらヌンバットは絶滅の危機に瀕しており、その数は1000頭以下と推定されています。
このプロジェクトをお楽しみになられた方は、
Project NumbatやAustralian Wildlife Conservancyなどの
Numbatの保護活動への寄付をご検討ください。

Nxの恩恵に預かるなら寄付してみるか、と思い「Project Numbat」経由で寄付しようとしたら...、なんと、寄付できませんでした。
PayPal経由で実施するからなのかがいまいちわかりませんが、別の方法でも試してみます。

まとめ

Nxが活用されていくのはまだこれからだとは思われますが、期待に胸が震えるライブラリであることは間違い無いでしょう。
本格的に活用される前に、少しでも予備知識を仕入れていこうと考えています。


  1. https://github.com/elixir-nx/nx

  2. 「XLAを使うことでTensorFlowの演算を最適化し、メモリ使用量、性能、サーバやモバイル環境での移植性の面での改善が期待できる」とのこと。参考URL: https://www.tensorflow.org/xla/architecture?hl=ja

  3. 複数の出力値の合計が1.0(=100%)になるように変換して出力する関数。正規化指数関数: Normalized exponential functionともいう。