RTX 4090が制限されている時代に、大規模モデルにRLHFを使用するより効率的な方法が登場

RTX 4090が制限されている時代に、大規模モデルにRLHFを使用するより効率的な方法が登場

  • 論文リンク: https://arxiv.org/abs/2310.10505
  • 著者: Li Ziniu、Xu Tian、​​Zhang Yushun、Yu Yang、Sun Ruoyu、Luo Zhiquan
  • 機関: 香港中文大学、深圳、深圳ビッグデータ研究所、南京大学、南京仙学院
  • オープンソースコード: https://github.com/liziniu/ReMax

特に記載がない限り、すべての画像は新聞からのものです。

背景

今年は、ChatGPT が主導する大規模言語モデル (LLM) があらゆる面で注目を集め、学術界やビジネス界で GPU などのコンピューティング リソースの需要が急増しました。

左の写真はDALL・E3、右の写真はDALL・E3

たとえば、Llama2-7B モデルの教師あり微調整 (SFT) には 80 GB を超えるメモリが必要です。しかし、多くの場合、これだけでは十分ではありません。人間と一致するためには、大規模な言語モデルも RLHF (人間からのフィードバックによる強化学習) でトレーニングする必要があります。 RLHF の GPU 消費量は SFT の 2 倍以上になることが多く、トレーニング時間は 6 倍以上になることがあります。

最近、米国政府は、H100やH800などのNvidia GPU製品の中国市場への参入を制限すると発表しました。この規定は間違いなく、中国の大規模言語モデル(LLM)と人工知能の開発に大きな抵抗を加えることになるだろう。 RLHF のトレーニング コスト (GPU 消費量とトレーニング時間) を削減することは、LLM の開発にとって非常に重要です。

モチベーション

RLHF は次の 3 つの段階から構成されます。

1. 教師あり微調整 (SFT)

2. 比較データから報酬モデルを学習します。

3. 強化学習 (RL) アルゴリズムを使用して報酬を最大化します。

画像出典: InstructGPT 論文

RLHF の主な計算オーバーヘッドは、第 3 段階 (報酬の最大化) から発生することがわかります。 DeepSpeed-Chat レポートから、第 3 ステージのトレーニング時間は最初の 2 つのステージの合計時間の 4 倍以上であることがわかります。さらに、私たちの経験によれば、第 3 ステージの GPU 消費量は、最初の 2 ステージの 2 倍以上になります。

DeepSpeed-Chat テクニカルレポートからの画像

現在、RLHF フェーズ 3 の主な計算上のボトルネックは何ですか?

この段階での計算ボトルネックの主な原因は、現在使用されている RL アルゴリズム、つまり PPO アルゴリズムであることがわかりました。 PPO アルゴリズムは、普遍的な RL 問題を解決するための最も人気のあるアルゴリズムの 1 つであり、成功例も数多くあります。ここでは PPO の技術的な詳細は省略し、PPO の主要コンポーネントである価値モデルに焦点を当てます。価値モデルは、特定の戦略の期待される長期リターンを効果的に推定するためにトレーニングする必要があるニューラル ネットワークです。価値モデルは PPO に優れたパフォーマンスをもたらしますが、RLHF タスクに大きな計算オーバーヘッドも生じます。たとえば、人間の好みに合わせるために、PPO の価値モデルは通常 LLM とサイズが似ており、ストレージ要件が 2 倍になります。さらに、価値モデルをトレーニングするには、その勾配、アクティベーション、およびオプティマイザーの状態を保存する必要があり、これにより GPU ストレージ要件がさらに 4 倍近く増加します。要約すると、PPO とその価値モデル (およびそのトレーニング関連部分) は、RLHF の報酬最大化段階における主な計算上の障害となっています。

PPOと比較すると、ReMaxは軽量なアルゴリズムである。

アイデア

PPO よりも RLHF に適したアルゴリズムを見つけることは可能ですか?

私たちが出した答えは「はい」です。これは、PPO と価値モデルが、RLHF のような特定の問題ではなく、一般的な RL 問題向けに設計されているためです (RLHF は RL 問題のサブクラスにすぎません)。興味深いことに、RLHF には PPO では使用されていない 3 つの重要な構造があることがわかりました。

1. 高速シミュレーション: 軌跡 (つまり、LLM での応答全体) は、時間のオーバーヘッドをほとんどかけずに、非常に短時間 (1 秒未満) で実行できます。

2. 決定論的遷移: コンテキストは過去のトークンと現在生成されているトークンに決定論的に依存します。

3. 軌道レベルの報酬: 報酬モデルは、応答が完了した場合にのみ報酬値を提供します。

これら 3 つの観察から、RLHF 問題において価値モデルが「冗長」であることは容易にわかります。これは、価値モデル設計の本来の意図が、ランダム環境でのサンプル効率と、低速シミュレーション環境での計算効率を達成することにあるためです。ただし、RLHF ではこれは必要ありません。

ReMax は RLHF 用に設計されたアルゴリズムですが、PPO は一般的な RL 用に設計されたアルゴリズムです。

方法

リマックス

ReMax アルゴリズムは、古いポリシー勾配アルゴリズム REINFORCE に基づいています。REINFORCE で使用されるポリシー勾配推定器を次の図に示します。

勾配推定器の強化

REINFORCE は、最適化に応答報酬を直接使用し、一般的な RL アルゴリズムのように中間ステップの報酬と価値関数を知る必要がないため、計算レベルで RLHF タスクの 3 つの特性を活用できます。ただし、戦略のランダム性により、REINFORCE 勾配推定器には高分散の問題があり (Richard Sutton の RL 書籍で指摘されています)、モデルトレーニングの有効性に影響します。そのため、以下の 2 つの図に示すように、REINFORCE は RLHF タスクでパフォーマンスが低下します。

REINFORCEは計算コストは​​低いがパフォーマンスは低い


REINFORCEの(ランダムな)勾配はReMaxの勾配よりもはるかに大きい。

この問題を解決するために、ReMax は貪欲応答の報酬をベースライン値として使用して勾配推定器を構築します。具体的な式は次のとおりです。

ReMax勾配推定器

貪欲な応答に対する報酬は、期待される報酬の良い近似値として見ることができることに注意してください。理想的なケース ( ) では、ランダム変数 に対してとなるため、推定値の分散は小さくなることが期待できます

下の図はReMaxのアルゴリズムフローを示しており、赤いボックスはコアアルゴリズムの変更を示しています。

ReMaxアルゴリズムプロセス

理論上の保証

ReMax で使用される勾配推定量は、依然として真のポリシー勾配の不偏推定量であることを示します。

詳細な理論的紹介については論文を参照してください。

アルゴリズムの利点

  • ReMax のコアは 6 行のコードで実装できます。対照的に、PPO では、重要度サンプリング、一般化利点推定 (GAE)、価値モデル学習などの追加モジュールが導入されています。
  • ReMax にはハイパーパラメータがほとんどありません。対照的に、PPO には、重要度サンプリング クリッピング比、GAE 係数、価値モデル学習率、オフポリシー トレーニング エポックなどの追加のハイパーパラメータがあります。これらのハイパーパラメータの調整には多くの時間が必要です。
  • ReMax は理論的にはメモリを約 50% 節約できます。 PPO と比較すると、ReMax は価値モデルに関連するすべてのコンポーネントを正常に削除し、メモリのオーバーヘッドを大幅に削減します。計算により、ReMax は PPO と比較して約 50% のメモリを節約できることがわかりました。

効果

効果

  • ReMaxはPPOと同様に効果的に報酬を最大化できます

OPT-1.3Bでは、ReMaxは効果的に報酬を最大化することができます

OPT-1.3BではReMaxトレーニングは非常に安定しています

  • GPT-4評価(LIMAテスト問題)では、ReMaxによって得られた戦略はSFTやPPOよりも優れている。

GPT4スコアリングでは、ReMaxによって得られたモデルの方が優れていることが示されています。

効率

  • ReMax は GPU メモリを約 50% 節約できます。 ReMax は、価値モデルとそのトレーニング部分 (勾配、オプティマイザー、アクティベーション) を削除するため、GPU メモリ要件が大幅に削減されます。 Llama2-7B を考慮すると、PPO は 8xA100-40GB マシンでは実行できませんが、ReMax は実行できます。

Llama2-7Bでは、ReMaxはGPUメモリを約50%節約できる

  • ReMax はトレーニングを 2 倍高速化できます。各ラウンドで、ReMax は 2 世代と 1 回のバックプロパゲーションを呼び出しますが、PPO は 1 世代と 2 回のバックプロパゲーションを使用します。大規模なモデルの場合、生成時間はバックプロパゲーション時間よりも短くなるため、ReMax は理論的にはトレーニングの約 2 倍の高速化を実現できます。

汎用性

RLHF タスクに加えて、RL アルゴリズムとしての ReMax は、従来の NLP タスクにも適用できます。この論文では、報酬モデルが比較データから学習されない GPT-2 上の映画レビュー継続タスクを検討します。実験的観察によると、ReMax は 2.2 倍のトレーニング高速化と 60% の GPU メモリ節約を実現できます。

従来の NLP タスク (テキスト継続) では、ReMax は PPO と比較して 2.2 倍の高速化を達成しました。

要約する

最後に、私たちの実験から得た PPO に対する ReMax の主な利点を簡単にまとめます。

  • よりシンプルな実装: ReMax のコアは 6 行のコードで実装できます。これは、PPO の多くの複雑なコード構成要素とはまったく対照的です。
  • メモリ オーバーヘッドの削減: 価値モデルとそのトレーニング コンポーネント全体が削除されたため、ReMax は PPO と比較して GPU メモリを約 50% 節約します。
  • ハイパーパラメータの削減: ReMax は、GAE 係数、価値モデルの学習率、重要度サンプリング エポック、ミニバッチ サイズなど、価値モデルのトレーニングに関連するすべてのハイパーパラメータを正常に削除します。これらのハイパーパラメータは、多くの場合、問題に敏感であり、調整が困難です。 ReMax は RLHF 研究者にとってより親しみやすいものであると考えています。
  • より高速なトレーニング速度: GPT2 (137M) の実験では、実際の実行時間に関して、ReMax は PPO と比較して 2.2 倍高速であることが確認されました。高速化は、各反復における ReMax の計算オーバーヘッドが低いことから生まれます。私たちの計算によると、この高速化の利点は、より大きなモデルでも維持されます (PPO が十分に大きなメモリに正常に展開できると仮定)。
  • 優れたパフォーマンス: 上記のように、ReMax は中規模の実験で PPO と同等のパフォーマンスを達成し、場合によっては PPO を上回るパフォーマンスを発揮します (おそらく、ReMax に適切なハイパーパラメータを見つけるのが簡単なためです)。この優れたパフォーマンスは、より大きなモデルにも拡張できると推測されます。

<<: 

>>:  OpenAIがついにオープン:DALL-E 3の論文が発表され、ChatGPTが開始、著者の半数が中国人

ブログ    
ブログ    
ブログ    
ブログ    

推薦する

...

...

...

【ディープラーニング連載】畳み込みニューラルネットワークの徹底解説(第2回)~畳み込みニューラルネットワークを手書きで書いてみる~

前回の記事では、畳み込みニューラルネットワークの基本原理について、いくつかの基本層の定義、動作ルール...

中国チームがボストン・ダイナミクスに対抗する四足歩行ロボットを発表

本日、Yushu Technology は、中国で正式に一般に公開される初の四足歩行ロボットとなる四...

...

不妊治療の新たな夜明け:AI

世界初の試験管ベビーは1978年に英国で誕生した。それ以来、人工生殖技術は継続的に改良されてきました...

今後10年間で、人工知能とロボットは雇用に7つの影響を与える

[[202532]]編集者注: この記事はNetEase Intelligenceからのもので、著者...

マスク氏、さらに 4 人の「民間」宇宙飛行士を宇宙に送り出す!スペースXは12回の有人ミッションを成功させた

北京時間の今朝早く、SpaceXは再び人類を宇宙に送り出すことに成功した。これは、米国の民間航空宇宙...

2024 年のコンテナ技術予測: パフォーマンス、AI、セキュリティの採用

パフォーマンス重視のコンテナ技術向けのツールとサービスを提供する Sylabs は、2024 年まで...

fBox アルゴリズムを使用して、高度に隠蔽された詐欺ユーザーを検出する方法

[51CTO.com クイック翻訳] インターネットの活発な発展とインターネットユーザーの継続的な増...

科学者は人工知能を使って新素材を発見する

米国の科学者チームは、人工知能を利用して非常に短期間で新たな鉄鋼の代替品を発見したいと考えている。そ...

2018 年の人工知能の商業化に関する 5 つの洞察

[[252389]]人工知能囲碁プログラム「AlphaGo」が囲碁の世界チャンピオンを破って以来、人...

C# 暗号化におけるハッシュ アルゴリズムの適用に関する簡単な分析

ハッシュ アルゴリズムは C# 暗号化でよく使用される方法ですが、ハッシュ アルゴリズムとは何でしょ...

Googleの「AIが写真を推測」アプリがWeChat Momentsで人気:ユーザーの参加でよりスマートに

Google 初の WeChat ミニプログラム「絵を当てよう」アプリは、リリースから 1 日で、一...