ディープラーニングにおけるバッチ正規化の落とし穴

ディープラーニングにおけるバッチ正規化の落とし穴

[[191744]]

バッチ正規化は、ディープラーニングにおいて最近登場した効果的な手法です。その有効性は広く実証されており、研究やアプリケーションに急速に応用されています。この投稿は、読者がバッチ正規化とは何かを知っており、その仕組みについてある程度理解していることを前提としています。この概念を初めて知ったり、復習する必要がある場合は、次のリンク (http://blog.csdn.net/malefactor/article/details/51476961) でバッチ正規化の簡単な概要を参照してください。

この論文では、2 つの異なる方法を使用してニューラル ネットワークを実装します。各ステップで同じデータが入力されます。ネットワークには、まったく同じ損失関数、まったく同じハイパーパラメータ、まったく同じオプティマイザーがあります。その後、まったく同じ数の GPU でトレーニングが実行されます。結果として、一方のバージョンの分類精度は他方のバージョンよりも 2% 低く、このパフォーマンスの低下は非常に安定しているように見えます。

単純な MNIST と SVHN の分類問題を例に挙げてみましょう。

最初の実装では、MNIST データのバッチと SVHN データのバッチが抽出され、結合されてからネットワークに送られます。

2 番目の実装では、ネットワークのコピーが 2 つ作成され、重みが共有されます。 1 つのコピーには MNIST データが入力され、もう 1 つのコピーには SVHN データが入力されます。

どちらの実装でも、データの半分は MNIST で、残りの半分は SVHN であることに注意してください。さらに、2 番目の実装では重みを共有するため、2 つのモデルのパラメーターの数は同じになり、同じ方法で更新されます。

単純に考えると、これら 2 つのモデルのトレーニング中の勾配は同じになるはずです。これも事実です。しかし、バッチ正規化を追加すると状況は変わります。最初の実装では、同じデータ バッチに MNIST データと SVHN データの両方が含まれています。 2 番目の方法では、モデルは 2 つのバッチでトレーニングされます。1 つのバッチは MNIST データのみでトレーニングされ、もう 1 つのバッチは SVHN データのみでトレーニングされます。

この問題の原因は、トレーニング中に 2 つのネットワークがパラメータを共有する一方で、データ セットの平均と分散の移動平均も共有されるためです。このパラメータの更新は、両方のデータセットにも適用されます。 2 番目のアプローチでは、上部のネットワークは MNIST データからの平均と分散の推定値を使用してトレーニングされ、下部のネットワークは SVHN データからの平均と分散の推定値を使用してトレーニングされます。しかし、移動平均は 2 つのネットワーク間で共有されるため、移動平均は MNIST データと SVHN データの平均に収束します。

したがって、テスト時に、テスト セットで使用されるバッチ正規化のスケールと変換 (1 つのデータセットの平均) は、モデルが期待するもの (両方のデータセットの平均) とは異なります。テスト用の正規化がトレーニング用の正規化と異なる場合、モデルは次の結果を取得します。

このグラフは、5 つのランダム シードを使用した別の類似データセット (この例では MNIST または SVHN ではありません) での最高、中央値、最低のモデル パフォーマンスを示しています。重みを共有する 2 つのネットワークを使用すると、パフォーマンスが大幅に低下するだけでなく、出力の分散も増加します。

この問題は、単一のデータ ミニバッチがデータ分布全体を代表していない場合に発生します。つまり、入力をシャッフルすることを忘れずにバッチ正規化を使用するのは危険です。これは、最近人気の高い敵対的生成ネットワーク (GAN) でも非常に重要です。 GAN の識別器は通常、偽のデータと実際のデータの混合でトレーニングされます。識別器でバッチ正規化が使用されている場合、純粋に偽のデータのバッチと純粋に実際のデータのバッチを交互に使用するのは誤りです。各小バッチには、両方が均等に混合されている必要があります (それぞれ 50%)。

実際には、バッチ正規化変数を分離し、他の変数を共有するネットワーク構造を使用すると、最良の結果が得られることに注意してください。これは実装が複雑ですが、他の方法よりも確かに効果的です (下の図を参照)。

バッチ正規化:諸悪の根源

上記の問題を考慮して、著者は、可能であればバッチ正規化を使用しないという結論に達しました。

この結論はエンジニアリングの観点から分析されます。

一般的に、コードに問題がある場合、その理由は次の 2 つに限ります。

  1. 明らかに間違いです。たとえば、間違った変数名を入力したり、関数を呼び出すのを忘れたりした可能性があります。
  2. コードには、相互作用する他のコードの動作に対する暗黙の依存関係があり、それらの依存関係の一部が満たされていません。これらのエラーは、コードがどの条件に依存しているかを把握するのに通常長い時間がかかるため、より有害になることがよくあります。

これら両方の間違いは避けられません。 2 番目のタイプのエラーは、より単純な方法を使用し、既存のコードを再利用することで軽減できます。

バッチ正規化方法には、次の 2 つの基本的なプロパティがあります。

  1. トレーニング中、単一の入力 xi の出力はミニバッチ内の他の xj の影響を受けます。
  2. テスト時に、モデルの計算パスが変更されます。正規化にはミニバッチ平均ではなく移動平均が使用されるようになったためです。

これらの特性を持つ最適化方法は他にほとんどありません。これにより、バッチ正規化コードを実装する人は、入力ミニバッチが無相関であるか、トレーニング操作とテスト操作が同じであると想定しやすくなります。このアプローチに疑問を抱く人は誰もいないだろう。

もちろん、バッチ正規化は Java 正規化のブラック ボックス バージョンと考えることができますが、これは非常にうまく機能します。しかし、実際には抽象化には常に漏れがあり、バッチ正規化も例外ではなく、その特性により漏れがさらに生じやすくなります。

なぜ人々はバッチ正規化をあきらめないのでしょうか?

コンピュータ サイエンス コミュニティには、ダイクストラの「GoTo ステートメントは有害である」という有名な記事があります。この中で、ダイクストラは、goto 文はコードを読みにくくするので避けるべきであり、goto を使用するプログラムは goto 文なしで書き直すことができると主張しています。

著者は「バッチ正規化は有害である」という見解を述べたいと思っていますが、十分な理由が見つかりません。結局のところ、バッチ正規化は非常に便利です。

はい、バッチ正規化には問題があります。しかし、すべてを正しく行えば、モデルのトレーニングははるかに速くなります。バッチ正規化の論文が 1400 回以上引用されているのには、十分な理由があります。

バッチ正規化には多くの代替手段がありますが、それらにも独自の欠点があります。レイヤー正規化は、RNN で使用するとより効果的ですが、畳み込みレイヤーで使用すると問題が発生することがあります。重み正規化とコサイン正規化はどちらも比較的新しい正規化方法です。重み正規化の記事では、バッチ正規化が機能しないいくつかの問題に重み正規化を適用できると述べられています。しかし、これらの方法は今のところあまり使われておらず、おそらく時間の問題でしょう。レイヤー正規化、重み正規化、コサイン正規化はすべて、上記のバッチ正規化の問題に対処します。新しい問題に取り組んでいてリスクを負いたい場合には、これらの正規化方法を試してみることをお勧めします。結局、どの方法を使用する場合でも、ハイパーパラメータの調整が必要になります。一度調整すると、さまざまな方法間の違いは小さくなるはずです。

(勇気があれば、バッチ再正規化を試すこともできますが、テスト時には移動平均のみが使用されます。)

バッチ正規化の使用は、ディープラーニングにおける「悪魔の契約」と見なすことができます。得られるものは効率的なトレーニングですが、失うものは異常な結果(狂気)の可能性です。全員がこの契約書に署名します。

翻訳者メモ

「バッチ正規化は有害である」および「バッチ正規化の使用はできるだけ避ける」という著者の見解は、やや極端です。しかし、この記事で言及されているバッチ正規化の罠には注意する必要があります。バッチ正規化の有効性のため、多くのディープラーニング研究者はそれを「魔法のブラックボックス」として扱い、あらゆる可能な場所に適用しています。この単純で大雑把な方法は、トレーニング速度の向上に非常に効果的だからです。しかし、精度の低下をバッチ正規化に帰することは困難です。結局のところ、バッチ正規化によってトレーニングの精度が低下するとは誰も言及していません。

しかし、トレーニング中とテスト中にデータ セットが矛盾する状況は、実際には非常に一般的です。この問題は、翻訳者がトレーニング データ セットを人工的にシミュレートするときに発生します。バッチ正規化を使用する前に、次の問題を慎重に検討することをお勧めします。

  1. トレーニング データセットの各バッチのサンプルは平均化されていますか?
  2. トレーニング データセットのバッチ平均は、テスト中の移動平均と一致していますか?

それ以外の場合は、この記事で説明されている問題を回避するために、次の方法の 1 つ以上を使用する必要があります。

  1. バッチ平均化を確実にするためにトレーニング データセットをランダムにサンプリングします。
  2. 上記の問題を回避するには、記事の例のようにモデルを変更します。
  3. バッチ正規化の代わりにレイヤー正規化、重み正規化、またはコサイン正規化を使用します。
  4. 正規化方法は使用されません。

<<:  CNN の弱点を見つけ、MNIST の「ルーチン」に注意する

>>:  時空間アルゴリズム研究に基づくビジネス意思決定分析

ブログ    
ブログ    
ブログ    

推薦する

AIの負担を軽減する時が来た。Python AIライブラリ5選のおすすめ

機械学習は興味深いものですが、作業範囲が広く複雑で困難です。開発者として学ぶべきツールはたくさんあり...

...

Java プログラミング スキル - データ構造とアルゴリズム「スレッド バイナリ ツリー」

[[388829]]まず質問を見てみましょうシーケンス{1,3,6,8,10,14}を二分木に構築...

Nature: 光コンピューティングと AI 推論を統合して高速かつ高帯域幅の AI コンピューティングを実現

電子コンピューティングと比較すると、光コンピューティングは高速、高帯域幅、低消費電力という利点があり...

本当に知っておくべき 10 の AI テクノロジートレンド

人工知能技術のトレンドは人類を前進させています。デジタル変革はあらゆる業界に広がり、人工知能は科学者...

安定性、効率性、俊敏性:適応型AIの利点

人工知能にはさまざまなものがあります。コンピューターを使って知的なことを行うこともあれば、コンピュー...

...

GPT-5 は 50,000 個の H100 で停止しています。アルトマンは、NVIDIAに代わるAIチップ帝国を築くために、緊急に数十億ドルを調達している。

サム・アルトマンは半導体ファウンドリの世界的なネットワークを構築するために数十億ドルを調達しています...

ウェブページを出力できるAIアプリが登場、早速評価してみよう

みなさんこんにちは、カソンです。最近、ウェブページ作成ツールframer[1]は、プロンプトワードに...

20 分で回路基板の組み立て方を学びましょう!オープンソースのSERLフレームワークは、精密制御において100%の成功率を誇り、人間の3倍の速さです。

近年、四足歩行、把持、器用な操作など、ロボットの強化学習技術の分野では大きな進歩が遂げられていますが...

Google AI はすべてを食べています!すべての公開コンテンツはAIトレーニングのためにクロールされ、プライバシーポリシーが更新されました

今後、インターネット上で公に話すすべての言葉が、Google によって AI のトレーニングに使用さ...

OpenAIはGPT-4が怠惰になったことを認める:当面修正することはできない

OpenAI は、ますます深刻化する GPT-4 の遅延問題に正式に対応しました。私は今でもChat...

製造および自動化アプリケーション向けの人工知能技術の選び方

人工知能 (AI) の定義は、産業オートメーションにおける生産と、研究室外の日常生活では大きく異なり...

Hive でサポートされているファイル形式と圧縮アルゴリズム

[[194194]]概要正しいファイル タイプと圧縮タイプ (Textfile+Gzip、Seque...