ディープラーニングにおけるバッチ正規化の落とし穴

ディープラーニングにおけるバッチ正規化の落とし穴

[[191744]]

バッチ正規化は、ディープラーニングにおいて最近登場した効果的な手法です。その有効性は広く実証されており、研究やアプリケーションに急速に応用されています。この投稿は、読者がバッチ正規化とは何かを知っており、その仕組みについてある程度理解していることを前提としています。この概念を初めて知ったり、復習する必要がある場合は、次のリンク (http://blog.csdn.net/malefactor/article/details/51476961) でバッチ正規化の簡単な概要を参照してください。

この論文では、2 つの異なる方法を使用してニューラル ネットワークを実装します。各ステップで同じデータが入力されます。ネットワークには、まったく同じ損失関数、まったく同じハイパーパラメータ、まったく同じオプティマイザーがあります。その後、まったく同じ数の GPU でトレーニングが実行されます。結果として、一方のバージョンの分類精度は他方のバージョンよりも 2% 低く、このパフォーマンスの低下は非常に安定しているように見えます。

単純な MNIST と SVHN の分類問題を例に挙げてみましょう。

最初の実装では、MNIST データのバッチと SVHN データのバッチが抽出され、結合されてからネットワークに送られます。

2 番目の実装では、ネットワークのコピーが 2 つ作成され、重みが共有されます。 1 つのコピーには MNIST データが入力され、もう 1 つのコピーには SVHN データが入力されます。

どちらの実装でも、データの半分は MNIST で、残りの半分は SVHN であることに注意してください。さらに、2 番目の実装では重みを共有するため、2 つのモデルのパラメーターの数は同じになり、同じ方法で更新されます。

単純に考えると、これら 2 つのモデルのトレーニング中の勾配は同じになるはずです。これも事実です。しかし、バッチ正規化を追加すると状況は変わります。最初の実装では、同じデータ バッチに MNIST データと SVHN データの両方が含まれています。 2 番目の方法では、モデルは 2 つのバッチでトレーニングされます。1 つのバッチは MNIST データのみでトレーニングされ、もう 1 つのバッチは SVHN データのみでトレーニングされます。

この問題の原因は、トレーニング中に 2 つのネットワークがパラメータを共有する一方で、データ セットの平均と分散の移動平均も共有されるためです。このパラメータの更新は、両方のデータセットにも適用されます。 2 番目のアプローチでは、上部のネットワークは MNIST データからの平均と分散の推定値を使用してトレーニングされ、下部のネットワークは SVHN データからの平均と分散の推定値を使用してトレーニングされます。しかし、移動平均は 2 つのネットワーク間で共有されるため、移動平均は MNIST データと SVHN データの平均に収束します。

したがって、テスト時に、テスト セットで使用されるバッチ正規化のスケールと変換 (1 つのデータセットの平均) は、モデルが期待するもの (両方のデータセットの平均) とは異なります。テスト用の正規化がトレーニング用の正規化と異なる場合、モデルは次の結果を取得します。

このグラフは、5 つのランダム シードを使用した別の類似データセット (この例では MNIST または SVHN ではありません) での最高、中央値、最低のモデル パフォーマンスを示しています。重みを共有する 2 つのネットワークを使用すると、パフォーマンスが大幅に低下するだけでなく、出力の分散も増加します。

この問題は、単一のデータ ミニバッチがデータ分布全体を代表していない場合に発生します。つまり、入力をシャッフルすることを忘れずにバッチ正規化を使用するのは危険です。これは、最近人気の高い敵対的生成ネットワーク (GAN) でも非常に重要です。 GAN の識別器は通常、偽のデータと実際のデータの混合でトレーニングされます。識別器でバッチ正規化が使用されている場合、純粋に偽のデータのバッチと純粋に実際のデータのバッチを交互に使用するのは誤りです。各小バッチには、両方が均等に混合されている必要があります (それぞれ 50%)。

実際には、バッチ正規化変数を分離し、他の変数を共有するネットワーク構造を使用すると、最良の結果が得られることに注意してください。これは実装が複雑ですが、他の方法よりも確かに効果的です (下の図を参照)。

バッチ正規化:諸悪の根源

上記の問題を考慮して、著者は、可能であればバッチ正規化を使用しないという結論に達しました。

この結論はエンジニアリングの観点から分析されます。

一般的に、コードに問題がある場合、その理由は次の 2 つに限ります。

  1. 明らかに間違いです。たとえば、間違った変数名を入力したり、関数を呼び出すのを忘れたりした可能性があります。
  2. コードには、相互作用する他のコードの動作に対する暗黙の依存関係があり、それらの依存関係の一部が満たされていません。これらのエラーは、コードがどの条件に依存しているかを把握するのに通常長い時間がかかるため、より有害になることがよくあります。

これら両方の間違いは避けられません。 2 番目のタイプのエラーは、より単純な方法を使用し、既存のコードを再利用することで軽減できます。

バッチ正規化方法には、次の 2 つの基本的なプロパティがあります。

  1. トレーニング中、単一の入力 xi の出力はミニバッチ内の他の xj の影響を受けます。
  2. テスト時に、モデルの計算パスが変更されます。正規化にはミニバッチ平均ではなく移動平均が使用されるようになったためです。

これらの特性を持つ最適化方法は他にほとんどありません。これにより、バッチ正規化コードを実装する人は、入力ミニバッチが無相関であるか、トレーニング操作とテスト操作が同じであると想定しやすくなります。このアプローチに疑問を抱く人は誰もいないだろう。

もちろん、バッチ正規化は Java 正規化のブラック ボックス バージョンと考えることができますが、これは非常にうまく機能します。しかし、実際には抽象化には常に漏れがあり、バッチ正規化も例外ではなく、その特性により漏れがさらに生じやすくなります。

なぜ人々はバッチ正規化をあきらめないのでしょうか?

コンピュータ サイエンス コミュニティには、ダイクストラの「GoTo ステートメントは有害である」という有名な記事があります。この中で、ダイクストラは、goto 文はコードを読みにくくするので避けるべきであり、goto を使用するプログラムは goto 文なしで書き直すことができると主張しています。

著者は「バッチ正規化は有害である」という見解を述べたいと思っていますが、十分な理由が見つかりません。結局のところ、バッチ正規化は非常に便利です。

はい、バッチ正規化には問題があります。しかし、すべてを正しく行えば、モデルのトレーニングははるかに速くなります。バッチ正規化の論文が 1400 回以上引用されているのには、十分な理由があります。

バッチ正規化には多くの代替手段がありますが、それらにも独自の欠点があります。レイヤー正規化は、RNN で使用するとより効果的ですが、畳み込みレイヤーで使用すると問題が発生することがあります。重み正規化とコサイン正規化はどちらも比較的新しい正規化方法です。重み正規化の記事では、バッチ正規化が機能しないいくつかの問題に重み正規化を適用できると述べられています。しかし、これらの方法は今のところあまり使われておらず、おそらく時間の問題でしょう。レイヤー正規化、重み正規化、コサイン正規化はすべて、上記のバッチ正規化の問題に対処します。新しい問題に取り組んでいてリスクを負いたい場合には、これらの正規化方法を試してみることをお勧めします。結局、どの方法を使用する場合でも、ハイパーパラメータの調整が必要になります。一度調整すると、さまざまな方法間の違いは小さくなるはずです。

(勇気があれば、バッチ再正規化を試すこともできますが、テスト時には移動平均のみが使用されます。)

バッチ正規化の使用は、ディープラーニングにおける「悪魔の契約」と見なすことができます。得られるものは効率的なトレーニングですが、失うものは異常な結果(狂気)の可能性です。全員がこの契約書に署名します。

翻訳者メモ

「バッチ正規化は有害である」および「バッチ正規化の使用はできるだけ避ける」という著者の見解は、やや極端です。しかし、この記事で言及されているバッチ正規化の罠には注意する必要があります。バッチ正規化の有効性のため、多くのディープラーニング研究者はそれを「魔法のブラックボックス」として扱い、あらゆる可能な場所に適用しています。この単純で大雑把な方法は、トレーニング速度の向上に非常に効果的だからです。しかし、精度の低下をバッチ正規化に帰することは困難です。結局のところ、バッチ正規化によってトレーニングの精度が低下するとは誰も言及していません。

しかし、トレーニング中とテスト中にデータ セットが矛盾する状況は、実際には非常に一般的です。この問題は、翻訳者がトレーニング データ セットを人工的にシミュレートするときに発生します。バッチ正規化を使用する前に、次の問題を慎重に検討することをお勧めします。

  1. トレーニング データセットの各バッチのサンプルは平均化されていますか?
  2. トレーニング データセットのバッチ平均は、テスト中の移動平均と一致していますか?

それ以外の場合は、この記事で説明されている問題を回避するために、次の方法の 1 つ以上を使用する必要があります。

  1. バッチ平均化を確実にするためにトレーニング データセットをランダムにサンプリングします。
  2. 上記の問題を回避するには、記事の例のようにモデルを変更します。
  3. バッチ正規化の代わりにレイヤー正規化、重み正規化、またはコサイン正規化を使用します。
  4. 正規化方法は使用されません。

<<:  CNN の弱点を見つけ、MNIST の「ルーチン」に注意する

>>:  時空間アルゴリズム研究に基づくビジネス意思決定分析

ブログ    
ブログ    
ブログ    

推薦する

70億のオープンソース数学モデルがGPT-4に勝利、中国チーム

7B オープンソースモデル、その数学的能力は数千億規模の GPT-4 を超えます。その性能はオープン...

...

...

機械学習トランスフォーマーアーキテクチャの謎を解く

翻訳者|朱 仙中レビュー | Chonglou Transformers は 2017 年の発売以来...

Appleの会話型AI予算は1日あたり数百万ドルに拡大

海外メディアは9月7日、事情に詳しい関係者の話として、アップルが人工知能の構築に必要なコンピューティ...

2021 年のトップ 10 機械学習ライブラリ

今は人工知能爆発の時代です。AIと機械学習は広く普及しています。もちろん、機械学習の分野で最も人気の...

OpenAI の「地震」の中心人物である Ilya を見てみましょう。彼は AI についてどう考えているのでしょうか?

OpenAIのCEOサム・アルトマン氏は先週金曜日に解雇され、もはや同社を率いていない。投資家たち...

5 年以内に、8,000 万の仕事が機械に置き換えられるでしょう。インダストリアル インターネットは治療薬でしょうか、それとも毒でしょうか?

時代の発展は常に要求と矛盾の中で発展しています。あらゆる産業革命は発展の力をもたらすだけでなく、大き...

Google BardとChatGPT、どちらを選ぶべきでしょうか?

こんにちは、ルガです。今日は、人工知能 (AI) エコシステムに関連するテクノロジーについて、Goo...

Google がオールラウンドな音楽転写 AI を発表: 曲を一度聴くだけでピアノとバイオリンの楽譜がすべて手に入る

この記事はAI新メディアQuantum Bit(公開アカウントID:QbitAI)より許可を得て転載...

...

ソラは人間の代わりにはなれない!アマゾンのエンジニアは主張:実際の仕事上の対立はAIでは解決できない

今週、OpenAIのビデオAIツール「Sora」が登場するや否や大きな話題を呼んだ。 「仕事を失う」...

Yann LeCun 氏は衝撃的な発言をしました。「ディープラーニングは死んだ、微分可能プログラミング万歳!」

ディープラーニングの分野で最も有名な学者の一人であるヤン・ルカン氏が本日、自身のFacebookに投...

顔認識と指紋認識のどちらがより定量化しやすいでしょうか?

顔認証と指紋認証は、携帯電話のロックを解除する主な 2 つの方法です。私たちは、日常の仕事でも公共の...