34 個の事前トレーニング済みモデルを比較して再現します。PyTorch と Keras のどちらを選択しますか?

34 個の事前トレーニング済みモデルを比較して再現します。PyTorch と Keras のどちらを選択しますか?

Keras と PyTorch は確かに最も初心者に優しいディープラーニング フレームワークであり、アーキテクチャを記述するシンプルな言語のように機能し、どのレイヤーで何を使用するかをフレームワークに指示します。これにより、静的計算グラフの設計、各テンソルの次元と内容の個別の定義など、多くの抽象的な作業が削減されます。

しかし、どのフレームワークが優れているのでしょうか? もちろん、開発者や研究者によって好みや意見は異なります。この記事では、抽象化とパフォーマンスの観点から PyTorch と Keras を比較し、2 つのフレームワークのすべての事前トレーニング済みモデルを再現して比較する新しいベンチマークを紹介します。

Keras と PyTorch のベンチマーク プロジェクトでは、MIT の博士課程の学生 Curtis G. Northcutt が 34 個の事前トレーニング済みモデルを再現しました。このベンチマークは、Keras と PyTorch を組み合わせて 1 つのフレームワークに統合することで、2 つのフレームワークを比較し、さまざまなモデルにどちらのフレームワークが適しているかを知ることができます。たとえば、プロジェクト作成者は、ResNet アーキテクチャ モデルでは Keras よりも PyTorch の方が適しており、Inception アーキテクチャ モデルでは PyTorch よりも Keras の方が適していると述べています。

Keras と PyTorch ベンチマーク プロジェクト: https://github.com/cgnorthcutt/benchmarking-keras-pytorch

1. 2つの主要フレームワークのパフォーマンスと使いやすさ

Keras は TensorFlow の高度にカプセル化されたバージョンであるため、抽象化のレベルが非常に高く、多くの API の詳細が隠されています。 PyTorch は TensorFlow の静的計算グラフよりも使いやすいですが、Keras では全体的に詳細が隠されています。パフォーマンスに関しては、実際には各フレームワークで多くの最適化が行われており、その違いはあまり明白ではなく、主な選択基準にはなりません。

1. 使いやすさ

Keras は、一般的に使用されるディープラーニングのレイヤーと操作を便利なビルディング ブロックにカプセル化し、ビルディング ブロックのように複雑なモデルを構築する、より高レベルのフレームワークです。開発者や研究者は、ディープラーニングの複雑さを考慮する必要がありません。

PyTorch は比較的低レベルの実験環境を提供し、ユーザーがカスタム レイヤーを記述したり、数値最適化タスクを探索したりする自由度を高めます。たとえば、PyTorch 1.0 では、コンパイル ツール torch.jit に、Python のサブ言語である Torch Script という言語が含まれています。開発者はこれを使用して、モデルをさらに最適化できます。

単純な畳み込みネットワークを定義すると、両方の使いやすさがわかります。

  1. モデル=シーケンシャル()
  2. model.add(Conv2D(32, (3, 3), activation = 'relu' , input_shape =(32, 32, 3)))
  3. モデルを追加します(MaxPool2D())
  4. model.add(Conv2D(16, (3, 3),アクティベーション= 'relu' ))
  5. モデルを追加します(MaxPool2D())
  6. モデルを追加します(フラット化())
  7. model.add(Dense(10, activation = 'softmax' ))

上記のように、Keras はこのように定義されています。多くの場合、操作はパラメータとして API に埋め込まれているため、コードは非常に簡潔になります。以下はPyTorchの定義方法です。一般的にはクラスとインスタンスを通じて定義され、特定の操作の多くの次元のパラメータを定義する必要があります。

  1. クラスNet(nn.Module):
  2. __init__(self)を定義します。
  3. super(Net, self).__init__()
  4.  
  5. 自己.conv1 = nn.Conv2d (3, 32, 3)
  6. 自己.conv2 = nn.Conv2d (32, 16, 3)
  7. 自己.fc1 = nn.線形(16 * 6 * 6, 10)
  8. 自己プール= nn.MaxPool2d (2, 2)
  9.  
  10. def forward(self, x):
  11. x = self .pool(F.relu(self.conv1(x)))
  12. x = self .pool(F.relu(self.conv2(x)))
  13. x x = x.view(-1, 16 * 6 * 6)
  14. x = F .log_softmax(self.fc1(x),次元=-1)
  15.  
  16. xを返す
  17.  
  18. モデル=ネット()

Keras は PyTorch よりも使いやすいように感じますが、両者の違いは大きくなく、どちらもモデルがより便利に記述されることが期待されます。

2. パフォーマンス

さまざまなフレームワークのパフォーマンスを比較する実験は数多くあり、PyTorch のトレーニング速度が Keras よりも速いことが示されています。次の 2 つのグラフは、さまざまなハードウェアおよびモデル タイプでのさまざまなフレームワークのパフォーマンスを示しています。

次の 2 つの図も、PyTorch および Keras フレームワークでのさまざまなモデルのパフォーマンスを示しています。これら 18 年間のテストはどちらも、PyTorch が Keras よりもわずかに高速であることを示しています。

これら 2 つの比較の詳細については、以下を参照してください。

  • https://github.com/ilkarman/DeepLearningFrameworks/
  • https://wrosinski.github.io/deep-learning-frameworks/

2. Keras と PyTorch ベンチマーク

さて、事前学習済みモデルの観点から見ると、異なるフレームワーク上での同じモデルの検証セットの精度はどの程度でしょうか。このプロジェクトでは、著者は 2 つのフレームワークを使用して合計 34 の事前学習済みモデルを再現し、すべての事前学習済みモデルの検証精度を示しました。したがって、このプロジェクトは比較の基盤としてだけでなく、学習リソースとしても使用できます。クラシック モデル コードを直接学習するより良い方法があるでしょうか?

1. 事前トレーニング済みのモデルはすでに再現可能ではないのですか?

PyTorch ではこのように動作します。しかし、Keras ユーザーの中には、再現が非常に難しいと感じる人もいます。遭遇する問題は、次の 3 つのカテゴリに分けられます。

  • サンプルコードを正確にコピーしたとしても、Keras の公開されたベンチマーク結果を再現することはできません。実際、報告されている精度(2019 年 2 月現在)は、実際の精度よりもわずかに高いことがよくあります。
  • 事前トレーニング済みの Keras モデルの中には、サーバーにデプロイしたり、他の Keras モデルと連続して実行したりすると、一貫性がなかったり、精度が低くなったりするものもあります。
  • バッチ正規化 (BN) を使用する Keras モデルは信頼できない可能性があります。一部のモデルでは、順方向伝播評価によって推論フェーズ中に重みが変更されることがあります。

これらの問題は実際に存在し、元の​​ GitHub プロジェクトでは各問題へのリンクが提供されています。プロジェクト作成者の目標の 1 つは、Keras の事前トレーニング済みモデルの再現可能なベンチマークを作成することで、上記の問題の一部に対処することです。解決策は、Keras で実行する必要がある次の 3 つの側面に分けることができます。推論中にバッチを回避する。

これは非常に遅く、一度に 1 つの例を実行しますが、各モデルに対して再現可能な出力が得られます。

次のモデルがロードされるときに前のモデルから何もメモリに保持されないようにするには、モデルをローカル関数またはステートメントでのみ実行します。

2. 事前学習済みモデルの再現結果

以下は、Keras と PyTorch の「実際の」検証セットの精度の表です (macOS 10.11.6、Linux Debian 9、Ubuntu 18.04 で検証済み)。

3. 複製方法

まず、50,000 枚の画像を含む ImageNet 2012 検証セットをダウンロードする必要があります。 ILSVRC2012_img_val.tar のダウンロードが完了したら、次のコマンド ラインを実行して検証セットを前処理/抽出します。

  1. # Soumith の功績: https://github.com/soumith/imagenet-multiGPU.torch
  2. $ cd ../ && mkdir val && mv ILSVRC2012_img_val.tar val/ && cd val && tar -xvf ILSVRC2012_img_val.tar
  3. $ wget -qO- https://raw.githubusercontent.com/soumith/imagenetloader.torch/master/valprep.sh | バッシュ

ImageNet 検証セットの各例の上位 5 つの予測が推定されました。次のコマンドラインを実行すると、これらの事前計算された結果が直接使用され、Keras および PyTorch ベンチマークが数秒で再現されます。

  1. $ git クローン https://github.com:cgnorthcutt/imagenet-benchmarking.git
  2. $ cd ベンチマーク-keras-pytorch
  3. $ python imagenet_benchmarking.py /path/to/imagenet_val_data

Keras と PyTorch のそれぞれの推論出力は、事前に計算されたデータを使用せずに再現できます。 Keras での推論には長い時間 (5 ~ 10 時間) がかかります。これは、フォワード パスが一度に 1 つの例ごとに計算され、ベクトル計算が回避されるためです。同じ精度を確実に再現したい場合、これがこれまでに見つかった最良の方法です。 PyTorch での推論は非常に高速です (1 時間未満)。再現するコードは次のとおりです。

  1. $ git クローン https://github.com:cgnorthcutt/imagenet-benchmarking.git
  2. $ cd ベンチマーク-keras-pytorch
  3. $ # PyTorch モデルの出力を計算する (1 時間)
  4. $ ./imagenet_pytorch_get_predictions.py /path/to/imagenet_val_data
  5. $ # Keras モデルの出力を計算する (5~10 時間)
  6. $ ./imagenet_keras_get_predictions.py /path/to/imagenet_val_data
  7. $ # ベンチマーク結果を表示
  8. $ ./imagenet_benchmarking.py /path/to/imagenet_val_data

GPU の使用状況、バッチ サイズ、出力ストレージ ディレクトリなどを制御できます。コマンドライン オプションを表示するには、-h フラグを付けて実行します。

記事を読んで、どちらが好きになりましたか?

オリジナルリンク: http://l7.curtisnorthcutt.com/towards-reproducibility-benchmarking-keras-pytorch

[この記事は51CTOコラム「Machine Heart」、WeChatパブリックアカウント「Machine Heart(id:almosthuman2014)」によるオリジナル翻訳です]

この著者の他の記事を読むにはここをクリックしてください

<<:  AIエンジニアの年収はわずか50万元程度で、年間100万元を稼ぐには長年の経験が必要です。

>>:  自分に最適なオープンソース フレームワークを選択するにはどうすればよいでしょうか?

ブログ    

推薦する

テラデータ、Vantage Customer ExperienceとVantage Analystを発表

ユビキタス データ インテリジェンス テクノロジーを提供する世界唯一のプロバイダーである Terad...

...

...

この記事では、さまざまな教師なしクラスタリングアルゴリズムのPython実装について簡単に説明します。

教師なし学習は、データ内のパターンを見つけるために使用される機械学習技術の一種です。教師なし学習アル...

クロスモーダルトランスフォーマー: 高速かつ堅牢な 3D オブジェクト検出に向けて

この記事は、Heart of Autonomous Driving の公開アカウントから許可を得て転...

...

...

目から涙が溢れてきました!ビクーニャのデジタルツインは10年前の自分を再現し、10年間の対話は数え切れないほどの人々に影響を与えた

Reddit のネットユーザーが何か新しいことをやっている。彼は、自身のオンラインフットプリントデー...

...

Verdict、2020年第1四半期のTwitterにおけるIoTトレンドトップ5を発表

私たちは、企業や専門家が IoT についてどう考えているかを知りたいと思っていますが、一般の人々はど...

...

人工知能分野で急成長を遂げている企業の主な問題点

AI 分野で急成長しているビジネスを運営し、成長させるには、プロセスの構築、顧客の成功、人材の獲得、...

K8S向け機械学習ツール「Kubeflow」の詳しい解説

[51CTO.com オリジナル記事] Kubeflowには多くのコンポーネントがあり、各コンポーネ...

人工知能が台頭しています。インテリジェントセキュリティの開発はどのように進んでいますか?

セキュリティ業界は、人工知能の市場を長く有する業界として、人工知能の発展に対する理解がより明確で、そ...

...