Githubには13,000個のスターがある。JAXの急速な発展はTensorFlowやPyTorchに匹敵する

Githubには13,000個のスターがある。JAXの急速な発展はTensorFlowやPyTorchに匹敵する

  [[416349]]

機械学習の分野では、TensorFlow と PyTorch は誰もがよく知っているかもしれませんが、これら 2 つのフレームワークに加えて、Google が立ち上げた JAX という新たな勢力も見逃せません。多くの研究者は、TensorFlow などの多くの機械学習フレームワークを置き換えることができると期待し、大きな期待を寄せています。

JAX はもともと、Google Brain チームの Matt Johnson、Roy Frostig、Dougal Maclaurin、Chris Leary によって開始されました。

現在、JAX は GitHub で 13.7K 個のスターを獲得しています。

プロジェクトアドレス: https://github.com/google/jax

JAXの急速な発展

JAX の前身は Autograd です。Autograd の更新版の助けを借りて、XLA と組み合わせることで、Python プログラムと NumPy 操作の自動微分を実行し、ループ、分岐、再帰、クロージャ関数の導出、および 3 次導関数をサポートできます。XLA に依存することで、JAX は GPU と TPU で NumPy プログラムをコンパイルして実行できます。grad を通じて、自動モードのバックプロパゲーションとフォワードプロパゲーションをサポートでき、2 つを任意の順序で組み合わせることができます。

JAX 開発の出発点は何でしたか?これについて言えば、NumPy について触れなければなりません。 NumPy は Python の基本的な数値計算ライブラリであり、広く使用されています。ただし、NumPy は GPU やその他のハードウェア アクセラレータをサポートしておらず、バックプロパゲーションのサポートも組み込まれていません。さらに、Python 自体の速度制限により NumPy の使用が妨げられるため、NumPy を直接使用してディープラーニング モデルを実稼働環境でトレーニングまたは展開する研究者はほとんどいません。

このような状況の中で、PyTorch、TensorFlow など、数多くのディープラーニング フレームワークが登場しました。ただし、numpy には柔軟性、デバッグの容易さ、安定した API などの独自の利点があります。 JAX の主な出発点は、numpy の上記の利点とハードウェア アクセラレーションを組み合わせることです。

現在、JAX をベースにした優れたオープンソース プロジェクトが数多く存在します。たとえば、Google のニューラル ネットワーク ライブラリ チームは、Jax 用のディープラーニング コード ライブラリである Haiku を開発しました。Haiku を通じて、ユーザーは Jax 上でオブジェクト指向開発を行うことができます。もう 1 つの例は、Jax をベースにした強化学習ライブラリである RLax です。ユーザーは RLax を使用して Q 学習モデルを構築およびトレーニングできます。さらに、1 行のコードで計算グラフを定義し、GPU アクセラレーションを実行できる JAX ベースのディープラーニング ライブラリ JAXnet もあります。ここ数年、JAXはディープラーニング研究に旋風を巻き起こし、科学研究の急速な発展を促進してきたと言えます。

JAX のインストール

JAX の使い方は?まず、Python 環境または Google Colab に JAX をインストールする必要があります。pip を使用してインストールします。

  1. $ pip インストール --upgrade jax jaxlib

上記のインストール方法は、CPU 上での実行のみをサポートしていることに注意してください。プログラムを GPU 上で実行する場合は、まず CUDA と cuDNN が必要で、その後次のコマンドを実行します (jaxlib バージョンを CUDA バージョンにマッピングするようにしてください)。

  1. $ pip インストール --upgrade jax jaxlib == 0.1.61 +cuda110 -f https://storage.googleapis.com/jax-releases/jax_releases.html

次に、Numpy とともに JAX をインポートします。

  1. jaxをインポートする
  2. jax.numpyをjnpとしてインポートする
  3. numpyをnpとしてインポートする

JAXの機能

grad() 関数を使用した自動微分: これはバックプロパゲーションの実行を容易にするため、ディープラーニング アプリケーションに非常に役立ちます。以下は、単純な 2 次関数の例で、ポイント 1.0 で導関数を取得します。

  1. jaxインポートgrad から
  2. 定義f(x):
  3. 3 *x** 2 + 2 *x + 5を返す
  4. f_prime(x)を定義します:
  5. 6 *x + 2を返す
  6. 卒業率( 1.0 )
  7. # デバイス配列( 8. , dtype=float32)
  8. f_prime( 1.0 )
  9. # 8.0

jit (ジャストインタイム): XLA のパワーを活用するには、コードを XLA カーネルにコンパイルする必要があります。ここで JIT が役立ちます。 XLA と jit を使用するには、ユーザーは jit() 関数または @jit アノテーションを使用できます。

  1. jaxからjitをインポート
  2. x = np.random.rand( 1000 , 1000 )
  3. y = jnp.array(x)
  4. 定義f(x):
  5. _ が範囲( 10 )内にある場合:
  6. x = 0.5 * x + 0.1 * jnp.sin(x)
  7. xを返す
  8. g = jit(f)
  9. %timeit -n 5 -r 5 f(y).block_until_ready()
  10. # 5ループ、ベスト5 : ループあたり10.8ミリ秒
  11. %timeit -n 5 -r 5 g(y).block_until_ready()
  12. # 5ループ、ベスト5 : ループあたり341 µs

pmap: 現在のすべてのデバイスに計算を自動的に分散し、それらの間のすべての通信を処理します。 JAX は pmap 変換を通じて大規模なデータ並列処理をサポートし、単一のプロセッサでは処理できない大規模なデータを処理します。利用可能なデバイスを確認するには、jax.devices() を実行します。

  1. jaxからpmapをインポート
  2. 定義f(x):
  3. jnp.sin(x) + x** 2を返す
  4. f(np.arange( 4 )) は、
  5. #デバイス配列([ 0 . , 1.841471 , 4.9092975 , 9.14112 ], dtype=float32)
  6. pmap(f)(np.arange( 4 ))
  7. #ShardedDeviceArray([ 0 . , 1.841471 , 4.9092975 , 9.14112 ], dtype=float32)

vmap: 関数変換です。JAX は vmap 変換による自動ベクトル化アルゴリズムを提供します。これにより、このタイプの計算が大幅に簡素化され、研究者はバッチの問題に悩まされることなく新しいアルゴリズムを扱えるようになります。次に例を示します。

  1. jaxからvmapをインポート
  2. 定義f(x):
  3. jnp.square(x)を返す
  4. f(jnp.arange( 10 ))
  5. #デバイス配列([ 0 , 1 , 4 , 9 , 16 , 25 , 36 , 49 , 64 , 81 ], dtype=int32)
  6. vmap(f)(jnp.arange( 10 ))
  7. #デバイス配列([ 0 , 1 , 4 , 9 , 16 , 25 , 36 , 49 , 64 , 81 ], dtype=int32)

TensorFlow 対 PyTorch 対 Jax

ディープラーニングの分野には巨大企業がいくつもあり、彼らが提案するフレームワークは多くの研究者に利用されています。たとえば、Google の TensorFlow、Facebook の PyTorch、Microsoft の CNTK、Amazon AWS の MXnet などです。

各フレームワークには長所と短所があり、自分のニーズに応じて選択する必要があります。

Python の 3 つの主要なディープラーニング フレームワーク (TensorFlow、PyTorch、Jax) を比較します。これらのフレームワークは異なりますが、共通点が 2 つあります。

  • それらはオープンソースです。つまり、ライブラリにバグがある場合、ユーザーは GitHub で問題を報告して修正してもらうことができ、また、独自の機能をライブラリに追加することもできます。
  • Python は、グローバル インタープリタ ロックが原因で内部的に遅く実行されます。したがって、これらのフレームワークは、すべての計算と並列プロセスを処理するために、バックエンドとして C/C++ を使用します。

では、どのような点が異なるのでしょうか?次の表は、TensorFlow、PyTorch、JAX の 3 つのフレームワークの比較を示しています。

テンソルフロー

TensorFlow は Google によって開発され、その最初のバージョンは 2015 年のオープンソースの TensorFlow0.1 にまで遡ります。それ以来、着実に発展し、強力なユーザーベースを持ち、最も人気のあるディープラーニング フレームワークになりました。しかし、使用してみると、API の安定性が不十分であったり、静的計算グラフ プログラミングが複雑であったりするなど、TensorFlow の欠点も明らかになりました。そのため、TensorFlow 2.0 バージョンでは、Google が Keras を組み込み、tf.keras になりました。

TensorFlow の主な機能は次のとおりです。

  • これは非常にユーザーフレンドリーなフレームワークです。高レベル API-Keras が利用できるため、モデル レイヤーの定義、損失関数、モデルの作成が非常に簡単になります。
  • TensorFlow 2.0 には Eager Execution が付属しており、これによりライブラリがよりユーザーフレンドリーになり、以前のバージョンから大幅にアップグレードされています。
  • この高レベル インターフェースには、いくつかの欠点があります。TensorFlow は、エンド ユーザーの利便性のためだけに、多くの基礎となるメカニズムを抽象化しているため、研究者はモデルを処理する自由度が低くなります。
  • Tensorflow は、実際には Tensorflow 視覚化ツールキットである TensorBoard を提供します。これにより、研究者は損失関数、モデルグラフ、モデル分析などを視覚化できます。

パイトーチ

PyTorch (Python-Torch) は、Facebook の機械学習ライブラリです。 TensorFlow か PyTorch か? 1 年前、この質問には異論はなく、ほとんどの研究者が TensorFlow を選択しました。しかし、今では状況は大きく変わり、PyTorch を使用する研究者が増えています。 PyTorch の最も重要な機能には次のようなものがあります。

  • TensorFlow とは異なり、PyTorch は動的型グラフを使用します。つまり、実行グラフはオンザフライで作成されます。いつでもグラフの内部構造を変更したり検査したりすることができます。
  • PyTorch には、ユーザーフレンドリーな高レベル API に加えて、機械学習モデルをより細かく制御できるように慎重に構築された低レベル API も含まれています。トレーニング中に、モデルの前方パスと後方パスの両方の出力を検査および変更できます。これは、グラデーション クリッピングとニューラル スタイル転送に非常に効果的であることが示されています。
  • PyTorch を使用すると、ユーザーはコードを拡張して、新しい損失関数やユーザー定義のレイヤーを簡単に追加できます。 PyTorch の Autograd モジュールは、ディープラーニング アルゴリズムにバックプロパゲーション微分を実装します。Tensor クラスのすべての操作に対して、Autograd は微分を自動的に提供し、手動で微分を計算する複雑なプロセスを簡素化します。
  • PyTorch は、データ並列処理と GPU の使用を幅広くサポートしています。
  • PyTorch は TensorFlow よりも Python 的です。 PyTorch は Python エコシステムにうまく適合し、Python のようなデバッガー ツールを使用して PyTorch コードをデバッグできます。

ジャックス

JAX は、Google の比較的新しい機械学習ライブラリです。これは、ネイティブ Python と NumPy コードを区別できる autograd ライブラリのようなものです。 JAX の主な機能は次のとおりです。

  • 公式ウェブサイトに記載されているように、JAX は Python + NumPy プログラムの構成可能な変換 (ベクトル化、JIT から GPU/TPU など) を実行できます。
  • PyTorch と比較した JAX の最も重要な側面は、勾配の計算方法です。 Torch では、グラフはフォワード パス中に作成され、勾配はバックワード パス中に計算されますが、一方、JAX では計算は関数として表現されます。関数に grad() を使用すると、指定された入力に対する関数の勾配を直接計算する勾配関数が返されます。
  • JAX は自動グレード ツールであり、単独での使用は推奨されません。 JAX ベースの機械学習ライブラリはさまざまありますが、その中でも注目すべきものとしては ObJax、Flax、Elegy などがあります。これらはすべて同じコアを使用し、インターフェースは JAX ライブラリのラッパーにすぎないため、同じ括弧内に配置できます。
  • Flax はもともと PyTorch エコシステムの下で開発され、使用の柔軟性に重点を置いていました。一方、Elegy は Keras からインスピレーションを受けています。 ObJAX は、シンプルさとわかりやすさを重視し、主に研究指向の目的で設計されています。

<<:  人工知能の「想像力」を実現する

>>:  人工知能に関する世界インターネット会議の8つの視点のレビュー

ブログ    

推薦する

近年の機械学習の奇妙な状況

翻訳者注:人工知能分野の発展は学者の貢献と切り離せないものです。しかし、研究が進むにつれて、「クリッ...

AI開発シンポジウム:機械学習を農家に役立てる方法について議論

この記事は、公開アカウント「Reading the Core」(ID: AI_Discovery)か...

ロボットは銀行業務を破壊するのか?

[[223220]]世界経済フォーラムの最近のレポートでは、2020年までに先進国で500万の雇用...

Java スパニングツリー構造 ポイント間の最短経路アルゴリズム

まずは二分木についてお話しましょう。二分木は、各ポイントが 2 つのポイントに接続されているツリー構...

...

AIがネットワークゴミを生み出す:古いインターネットは死につつあり、新しいインターネットは困難の中で生まれる

網易科技は6月27日、ここ数カ月、インターネットの方向性が変化したことを示すさまざまな兆候があると報...

...

シングルチッププロセッサの終焉?アップルとNVIDIAはマルチチップパッケージングに興味を持っており、相互接続技術が鍵となる

3月10日、Appleは2022年春のカンファレンスで、M1 Maxチップのアップグレード版であるM...

表形式データでの機械学習に特徴抽出を使用する方法

データ準備の最も一般的なアプローチは、データセットを調査し、機械学習アルゴリズムの期待値を確認し、最...

人工知能が銀行業界の変革を加速します!ビッグデータにより各ユーザーの信用格付けが提供されます!

[[221188]]将来、人工知能が 380 万人以上の銀行員の仕事を全て置き換える日が来るのでし...

自動化によってセキュリティアナリストがいなくなる可能性はありますか?

否定できない現実として、私たちは自動化の時代に入り、それに伴い人工知能 (AI)、機械学習 (ML)...

機械学習で人気のアルゴリズムトップ10

現在、機械学習のためのアルゴリズムは数多く存在します。初心者にとってはかなり圧倒されるかもしれません...

最新の! 2018年中国プログラマーの給与と生活に関する調査レポート

中国インターネット情報センター(CNNIC)が発表した第41回中国インターネット発展統計報告によると...

データ構造とアルゴリズム - グラフ理論: 連結成分と強連結成分の検出

無向グラフの連結成分を見つける深さ優先探索を使用すると、グラフのすべての接続コンポーネントを簡単に見...