Githubには13,000個のスターがある。JAXの急速な発展はTensorFlowやPyTorchに匹敵する

Githubには13,000個のスターがある。JAXの急速な発展はTensorFlowやPyTorchに匹敵する

  [[416349]]

機械学習の分野では、TensorFlow と PyTorch は誰もがよく知っているかもしれませんが、これら 2 つのフレームワークに加えて、Google が立ち上げた JAX という新たな勢力も見逃せません。多くの研究者は、TensorFlow などの多くの機械学習フレームワークを置き換えることができると期待し、大きな期待を寄せています。

JAX はもともと、Google Brain チームの Matt Johnson、Roy Frostig、Dougal Maclaurin、Chris Leary によって開始されました。

現在、JAX は GitHub で 13.7K 個のスターを獲得しています。

プロジェクトアドレス: https://github.com/google/jax

JAXの急速な発展

JAX の前身は Autograd です。Autograd の更新版の助けを借りて、XLA と組み合わせることで、Python プログラムと NumPy 操作の自動微分を実行し、ループ、分岐、再帰、クロージャ関数の導出、および 3 次導関数をサポートできます。XLA に依存することで、JAX は GPU と TPU で NumPy プログラムをコンパイルして実行できます。grad を通じて、自動モードのバックプロパゲーションとフォワードプロパゲーションをサポートでき、2 つを任意の順序で組み合わせることができます。

JAX 開発の出発点は何でしたか?これについて言えば、NumPy について触れなければなりません。 NumPy は Python の基本的な数値計算ライブラリであり、広く使用されています。ただし、NumPy は GPU やその他のハードウェア アクセラレータをサポートしておらず、バックプロパゲーションのサポートも組み込まれていません。さらに、Python 自体の速度制限により NumPy の使用が妨げられるため、NumPy を直接使用してディープラーニング モデルを実稼働環境でトレーニングまたは展開する研究者はほとんどいません。

このような状況の中で、PyTorch、TensorFlow など、数多くのディープラーニング フレームワークが登場しました。ただし、numpy には柔軟性、デバッグの容易さ、安定した API などの独自の利点があります。 JAX の主な出発点は、numpy の上記の利点とハードウェア アクセラレーションを組み合わせることです。

現在、JAX をベースにした優れたオープンソース プロジェクトが数多く存在します。たとえば、Google のニューラル ネットワーク ライブラリ チームは、Jax 用のディープラーニング コード ライブラリである Haiku を開発しました。Haiku を通じて、ユーザーは Jax 上でオブジェクト指向開発を行うことができます。もう 1 つの例は、Jax をベースにした強化学習ライブラリである RLax です。ユーザーは RLax を使用して Q 学習モデルを構築およびトレーニングできます。さらに、1 行のコードで計算グラフを定義し、GPU アクセラレーションを実行できる JAX ベースのディープラーニング ライブラリ JAXnet もあります。ここ数年、JAXはディープラーニング研究に旋風を巻き起こし、科学研究の急速な発展を促進してきたと言えます。

JAX のインストール

JAX の使い方は?まず、Python 環境または Google Colab に JAX をインストールする必要があります。pip を使用してインストールします。

  1. $ pip インストール --upgrade jax jaxlib

上記のインストール方法は、CPU 上での実行のみをサポートしていることに注意してください。プログラムを GPU 上で実行する場合は、まず CUDA と cuDNN が必要で、その後次のコマンドを実行します (jaxlib バージョンを CUDA バージョンにマッピングするようにしてください)。

  1. $ pip インストール --upgrade jax jaxlib == 0.1.61 +cuda110 -f https://storage.googleapis.com/jax-releases/jax_releases.html

次に、Numpy とともに JAX をインポートします。

  1. jaxをインポートする
  2. jax.numpyをjnpとしてインポートする
  3. numpyをnpとしてインポートする

JAXの機能

grad() 関数を使用した自動微分: これはバックプロパゲーションの実行を容易にするため、ディープラーニング アプリケーションに非常に役立ちます。以下は、単純な 2 次関数の例で、ポイント 1.0 で導関数を取得します。

  1. jaxインポートgrad から
  2. 定義f(x):
  3. 3 *x** 2 + 2 *x + 5を返す
  4. f_prime(x)を定義します:
  5. 6 *x + 2を返す
  6. 卒業率( 1.0 )
  7. # デバイス配列( 8. , dtype=float32)
  8. f_prime( 1.0 )
  9. # 8.0

jit (ジャストインタイム): XLA のパワーを活用するには、コードを XLA カーネルにコンパイルする必要があります。ここで JIT が役立ちます。 XLA と jit を使用するには、ユーザーは jit() 関数または @jit アノテーションを使用できます。

  1. jaxからjitをインポート
  2. x = np.random.rand( 1000 , 1000 )
  3. y = jnp.array(x)
  4. 定義f(x):
  5. _ が範囲( 10 )内にある場合:
  6. x = 0.5 * x + 0.1 * jnp.sin(x)
  7. xを返す
  8. g = jit(f)
  9. %timeit -n 5 -r 5 f(y).block_until_ready()
  10. # 5ループ、ベスト5 : ループあたり10.8ミリ秒
  11. %timeit -n 5 -r 5 g(y).block_until_ready()
  12. # 5ループ、ベスト5 : ループあたり341 µs

pmap: 現在のすべてのデバイスに計算を自動的に分散し、それらの間のすべての通信を処理します。 JAX は pmap 変換を通じて大規模なデータ並列処理をサポートし、単一のプロセッサでは処理できない大規模なデータを処理します。利用可能なデバイスを確認するには、jax.devices() を実行します。

  1. jaxからpmapをインポート
  2. 定義f(x):
  3. jnp.sin(x) + x** 2を返す
  4. f(np.arange( 4 )) は、
  5. #デバイス配列([ 0 . , 1.841471 , 4.9092975 , 9.14112 ], dtype=float32)
  6. pmap(f)(np.arange( 4 ))
  7. #ShardedDeviceArray([ 0 . , 1.841471 , 4.9092975 , 9.14112 ], dtype=float32)

vmap: 関数変換です。JAX は vmap 変換による自動ベクトル化アルゴリズムを提供します。これにより、このタイプの計算が大幅に簡素化され、研究者はバッチの問題に悩まされることなく新しいアルゴリズムを扱えるようになります。次に例を示します。

  1. jaxからvmapをインポート
  2. 定義f(x):
  3. jnp.square(x)を返す
  4. f(jnp.arange( 10 ))
  5. #デバイス配列([ 0 , 1 , 4 , 9 , 16 , 25 , 36 , 49 , 64 , 81 ], dtype=int32)
  6. vmap(f)(jnp.arange( 10 ))
  7. #デバイス配列([ 0 , 1 , 4 , 9 , 16 , 25 , 36 , 49 , 64 , 81 ], dtype=int32)

TensorFlow 対 PyTorch 対 Jax

ディープラーニングの分野には巨大企業がいくつもあり、彼らが提案するフレームワークは多くの研究者に利用されています。たとえば、Google の TensorFlow、Facebook の PyTorch、Microsoft の CNTK、Amazon AWS の MXnet などです。

各フレームワークには長所と短所があり、自分のニーズに応じて選択する必要があります。

Python の 3 つの主要なディープラーニング フレームワーク (TensorFlow、PyTorch、Jax) を比較します。これらのフレームワークは異なりますが、共通点が 2 つあります。

  • それらはオープンソースです。つまり、ライブラリにバグがある場合、ユーザーは GitHub で問題を報告して修正してもらうことができ、また、独自の機能をライブラリに追加することもできます。
  • Python は、グローバル インタープリタ ロックが原因で内部的に遅く実行されます。したがって、これらのフレームワークは、すべての計算と並列プロセスを処理するために、バックエンドとして C/C++ を使用します。

では、どのような点が異なるのでしょうか?次の表は、TensorFlow、PyTorch、JAX の 3 つのフレームワークの比較を示しています。

テンソルフロー

TensorFlow は Google によって開発され、その最初のバージョンは 2015 年のオープンソースの TensorFlow0.1 にまで遡ります。それ以来、着実に発展し、強力なユーザーベースを持ち、最も人気のあるディープラーニング フレームワークになりました。しかし、使用してみると、API の安定性が不十分であったり、静的計算グラフ プログラミングが複雑であったりするなど、TensorFlow の欠点も明らかになりました。そのため、TensorFlow 2.0 バージョンでは、Google が Keras を組み込み、tf.keras になりました。

TensorFlow の主な機能は次のとおりです。

  • これは非常にユーザーフレンドリーなフレームワークです。高レベル API-Keras が利用できるため、モデル レイヤーの定義、損失関数、モデルの作成が非常に簡単になります。
  • TensorFlow 2.0 には Eager Execution が付属しており、これによりライブラリがよりユーザーフレンドリーになり、以前のバージョンから大幅にアップグレードされています。
  • この高レベル インターフェースには、いくつかの欠点があります。TensorFlow は、エンド ユーザーの利便性のためだけに、多くの基礎となるメカニズムを抽象化しているため、研究者はモデルを処理する自由度が低くなります。
  • Tensorflow は、実際には Tensorflow 視覚化ツールキットである TensorBoard を提供します。これにより、研究者は損失関数、モデルグラフ、モデル分析などを視覚化できます。

パイトーチ

PyTorch (Python-Torch) は、Facebook の機械学習ライブラリです。 TensorFlow か PyTorch か? 1 年前、この質問には異論はなく、ほとんどの研究者が TensorFlow を選択しました。しかし、今では状況は大きく変わり、PyTorch を使用する研究者が増えています。 PyTorch の最も重要な機能には次のようなものがあります。

  • TensorFlow とは異なり、PyTorch は動的型グラフを使用します。つまり、実行グラフはオンザフライで作成されます。いつでもグラフの内部構造を変更したり検査したりすることができます。
  • PyTorch には、ユーザーフレンドリーな高レベル API に加えて、機械学習モデルをより細かく制御できるように慎重に構築された低レベル API も含まれています。トレーニング中に、モデルの前方パスと後方パスの両方の出力を検査および変更できます。これは、グラデーション クリッピングとニューラル スタイル転送に非常に効果的であることが示されています。
  • PyTorch を使用すると、ユーザーはコードを拡張して、新しい損失関数やユーザー定義のレイヤーを簡単に追加できます。 PyTorch の Autograd モジュールは、ディープラーニング アルゴリズムにバックプロパゲーション微分を実装します。Tensor クラスのすべての操作に対して、Autograd は微分を自動的に提供し、手動で微分を計算する複雑なプロセスを簡素化します。
  • PyTorch は、データ並列処理と GPU の使用を幅広くサポートしています。
  • PyTorch は TensorFlow よりも Python 的です。 PyTorch は Python エコシステムにうまく適合し、Python のようなデバッガー ツールを使用して PyTorch コードをデバッグできます。

ジャックス

JAX は、Google の比較的新しい機械学習ライブラリです。これは、ネイティブ Python と NumPy コードを区別できる autograd ライブラリのようなものです。 JAX の主な機能は次のとおりです。

  • 公式ウェブサイトに記載されているように、JAX は Python + NumPy プログラムの構成可能な変換 (ベクトル化、JIT から GPU/TPU など) を実行できます。
  • PyTorch と比較した JAX の最も重要な側面は、勾配の計算方法です。 Torch では、グラフはフォワード パス中に作成され、勾配はバックワード パス中に計算されますが、一方、JAX では計算は関数として表現されます。関数に grad() を使用すると、指定された入力に対する関数の勾配を直接計算する勾配関数が返されます。
  • JAX は自動グレード ツールであり、単独での使用は推奨されません。 JAX ベースの機械学習ライブラリはさまざまありますが、その中でも注目すべきものとしては ObJax、Flax、Elegy などがあります。これらはすべて同じコアを使用し、インターフェースは JAX ライブラリのラッパーにすぎないため、同じ括弧内に配置できます。
  • Flax はもともと PyTorch エコシステムの下で開発され、使用の柔軟性に重点を置いていました。一方、Elegy は Keras からインスピレーションを受けています。 ObJAX は、シンプルさとわかりやすさを重視し、主に研究指向の目的で設計されています。

<<:  人工知能の「想像力」を実現する

>>:  人工知能に関する世界インターネット会議の8つの視点のレビュー

ブログ    

推薦する

...

音声認識データベースが人工知能の中核となる

音声認識データベースと音声合成データベースは、人工知能の重要な技術です。機械が人間のように聞き、話し...

デジタル経済は新たな時代へ:インターネットが主導権を握り、ビッグデータと人工知能が注目の的

[[208505]]強固な経済基盤がなければ、豊かな国と強い国民は実現できません。中国共産党第19回...

...

ジェネレーティブ AI と自動化: 未来のデータ センターを加速

自動化と生成型人工知能 (GenAI) の時代において、「データセンター」の本当の意味を再考する時が...

ニューロモルフィックコンピューティングを理解する: 基本原理から実験的検証まで

人間の脳は、効率的な生体エネルギーによって計算能力を部分的にサポートし、ニューロンを基本的な発火単位...

OpenAI憲章中国語版

この文書は、OpenAI 内外の多くの人々からのフィードバックを含め、過去 2 年間にわたって改良し...

AIが脳波を80%以上の精度で解読!あなたの目の中で最も美しいtaを高度に復元します

千人の人々の目には千のハムレットがいる。主観的な違いにより、人間には何千万通りもの異なる美的嗜好が存...

...

エンジニアリングパフォーマンスを分析してデータ駆動型チームを構築

Gigster の副社長 Cory Hymel 氏は、2024 年にさらなる適応力と成功を実現するた...

インターネットと自動車の大手企業が「自動運転」に賭けているのはなぜでしょうか?

米国現地時間の水曜日、マスク氏はソーシャルメディア上で、同社が今週、一部の選ばれた顧客に対して初の「...

エッジAIの進歩が次世代ドローンのイノベーションをどう推進するか

ここ数年、ドローンをめぐる革新は数多くありました。 いくつかの企業はすでに、荷物や食品の配達のほか、...

...

CAPとPaxosコンセンサスアルゴリズムについての簡単な説明

CAPとはCAP理論についてはすでに多くの背景情報が語られているので、ここでは詳しくは触れません。ど...

ロボットが家庭に入り、人工知能の夢はもはや高価ではない

[[221538]]人工知能とは何ですか? 「第一次産業革命における蒸気機関、第二次産業革命における...