RTX 4090が制限されている時代に、大規模モデルにRLHFを使用するより効率的な方法が登場

RTX 4090が制限されている時代に、大規模モデルにRLHFを使用するより効率的な方法が登場

  • 論文リンク: https://arxiv.org/abs/2310.10505
  • 著者: Li Ziniu、Xu Tian、​​Zhang Yushun、Yu Yang、Sun Ruoyu、Luo Zhiquan
  • 機関: 香港中文大学、深圳、深圳ビッグデータ研究所、南京大学、南京仙学院
  • オープンソースコード: https://github.com/liziniu/ReMax

特に記載がない限り、すべての画像は新聞からのものです。

背景

今年は、ChatGPT が主導する大規模言語モデル (LLM) があらゆる面で注目を集め、学術界やビジネス界で GPU などのコンピューティング リソースの需要が急増しました。

左の写真はDALL・E3、右の写真はDALL・E3

たとえば、Llama2-7B モデルの教師あり微調整 (SFT) には 80 GB を超えるメモリが必要です。しかし、多くの場合、これだけでは十分ではありません。人間と一致するためには、大規模な言語モデルも RLHF (人間からのフィードバックによる強化学習) でトレーニングする必要があります。 RLHF の GPU 消費量は SFT の 2 倍以上になることが多く、トレーニング時間は 6 倍以上になることがあります。

最近、米国政府は、H100やH800などのNvidia GPU製品の中国市場への参入を制限すると発表しました。この規定は間違いなく、中国の大規模言語モデル(LLM)と人工知能の開発に大きな抵抗を加えることになるだろう。 RLHF のトレーニング コスト (GPU 消費量とトレーニング時間) を削減することは、LLM の開発にとって非常に重要です。

モチベーション

RLHF は次の 3 つの段階から構成されます。

1. 教師あり微調整 (SFT)

2. 比較データから報酬モデルを学習します。

3. 強化学習 (RL) アルゴリズムを使用して報酬を最大化します。

画像出典: InstructGPT 論文

RLHF の主な計算オーバーヘッドは、第 3 段階 (報酬の最大化) から発生することがわかります。 DeepSpeed-Chat レポートから、第 3 ステージのトレーニング時間は最初の 2 つのステージの合計時間の 4 倍以上であることがわかります。さらに、私たちの経験によれば、第 3 ステージの GPU 消費量は、最初の 2 ステージの 2 倍以上になります。

DeepSpeed-Chat テクニカルレポートからの画像

現在、RLHF フェーズ 3 の主な計算上のボトルネックは何ですか?

この段階での計算ボトルネックの主な原因は、現在使用されている RL アルゴリズム、つまり PPO アルゴリズムであることがわかりました。 PPO アルゴリズムは、普遍的な RL 問題を解決するための最も人気のあるアルゴリズムの 1 つであり、成功例も数多くあります。ここでは PPO の技術的な詳細は省略し、PPO の主要コンポーネントである価値モデルに焦点を当てます。価値モデルは、特定の戦略の期待される長期リターンを効果的に推定するためにトレーニングする必要があるニューラル ネットワークです。価値モデルは PPO に優れたパフォーマンスをもたらしますが、RLHF タスクに大きな計算オーバーヘッドも生じます。たとえば、人間の好みに合わせるために、PPO の価値モデルは通常 LLM とサイズが似ており、ストレージ要件が 2 倍になります。さらに、価値モデルをトレーニングするには、その勾配、アクティベーション、およびオプティマイザーの状態を保存する必要があり、これにより GPU ストレージ要件がさらに 4 倍近く増加します。要約すると、PPO とその価値モデル (およびそのトレーニング関連部分) は、RLHF の報酬最大化段階における主な計算上の障害となっています。

PPOと比較すると、ReMaxは軽量なアルゴリズムである。

アイデア

PPO よりも RLHF に適したアルゴリズムを見つけることは可能ですか?

私たちが出した答えは「はい」です。これは、PPO と価値モデルが、RLHF のような特定の問題ではなく、一般的な RL 問題向けに設計されているためです (RLHF は RL 問題のサブクラスにすぎません)。興味深いことに、RLHF には PPO では使用されていない 3 つの重要な構造があることがわかりました。

1. 高速シミュレーション: 軌跡 (つまり、LLM での応答全体) は、時間のオーバーヘッドをほとんどかけずに、非常に短時間 (1 秒未満) で実行できます。

2. 決定論的遷移: コンテキストは過去のトークンと現在生成されているトークンに決定論的に依存します。

3. 軌道レベルの報酬: 報酬モデルは、応答が完了した場合にのみ報酬値を提供します。

これら 3 つの観察から、RLHF 問題において価値モデルが「冗長」であることは容易にわかります。これは、価値モデル設計の本来の意図が、ランダム環境でのサンプル効率と、低速シミュレーション環境での計算効率を達成することにあるためです。ただし、RLHF ではこれは必要ありません。

ReMax は RLHF 用に設計されたアルゴリズムですが、PPO は一般的な RL 用に設計されたアルゴリズムです。

方法

リマックス

ReMax アルゴリズムは、古いポリシー勾配アルゴリズム REINFORCE に基づいています。REINFORCE で使用されるポリシー勾配推定器を次の図に示します。

勾配推定器の強化

REINFORCE は、最適化に応答報酬を直接使用し、一般的な RL アルゴリズムのように中間ステップの報酬と価値関数を知る必要がないため、計算レベルで RLHF タスクの 3 つの特性を活用できます。ただし、戦略のランダム性により、REINFORCE 勾配推定器には高分散の問題があり (Richard Sutton の RL 書籍で指摘されています)、モデルトレーニングの有効性に影響します。そのため、以下の 2 つの図に示すように、REINFORCE は RLHF タスクでパフォーマンスが低下します。

REINFORCEは計算コストは​​低いがパフォーマンスは低い


REINFORCEの(ランダムな)勾配はReMaxの勾配よりもはるかに大きい。

この問題を解決するために、ReMax は貪欲応答の報酬をベースライン値として使用して勾配推定器を構築します。具体的な式は次のとおりです。

ReMax勾配推定器

貪欲な応答に対する報酬は、期待される報酬の良い近似値として見ることができることに注意してください。理想的なケース ( ) では、ランダム変数 に対してとなるため、推定値の分散は小さくなることが期待できます

下の図はReMaxのアルゴリズムフローを示しており、赤いボックスはコアアルゴリズムの変更を示しています。

ReMaxアルゴリズムプロセス

理論上の保証

ReMax で使用される勾配推定量は、依然として真のポリシー勾配の不偏推定量であることを示します。

詳細な理論的紹介については論文を参照してください。

アルゴリズムの利点

  • ReMax のコアは 6 行のコードで実装できます。対照的に、PPO では、重要度サンプリング、一般化利点推定 (GAE)、価値モデル学習などの追加モジュールが導入されています。
  • ReMax にはハイパーパラメータがほとんどありません。対照的に、PPO には、重要度サンプリング クリッピング比、GAE 係数、価値モデル学習率、オフポリシー トレーニング エポックなどの追加のハイパーパラメータがあります。これらのハイパーパラメータの調整には多くの時間が必要です。
  • ReMax は理論的にはメモリを約 50% 節約できます。 PPO と比較すると、ReMax は価値モデルに関連するすべてのコンポーネントを正常に削除し、メモリのオーバーヘッドを大幅に削減します。計算により、ReMax は PPO と比較して約 50% のメモリを節約できることがわかりました。

効果

効果

  • ReMaxはPPOと同様に効果的に報酬を最大化できます

OPT-1.3Bでは、ReMaxは効果的に報酬を最大化することができます

OPT-1.3BではReMaxトレーニングは非常に安定しています

  • GPT-4評価(LIMAテスト問題)では、ReMaxによって得られた戦略はSFTやPPOよりも優れている。

GPT4スコアリングでは、ReMaxによって得られたモデルの方が優れていることが示されています。

効率

  • ReMax は GPU メモリを約 50% 節約できます。 ReMax は、価値モデルとそのトレーニング部分 (勾配、オプティマイザー、アクティベーション) を削除するため、GPU メモリ要件が大幅に削減されます。 Llama2-7B を考慮すると、PPO は 8xA100-40GB マシンでは実行できませんが、ReMax は実行できます。

Llama2-7Bでは、ReMaxはGPUメモリを約50%節約できる

  • ReMax はトレーニングを 2 倍高速化できます。各ラウンドで、ReMax は 2 世代と 1 回のバックプロパゲーションを呼び出しますが、PPO は 1 世代と 2 回のバックプロパゲーションを使用します。大規模なモデルの場合、生成時間はバックプロパゲーション時間よりも短くなるため、ReMax は理論的にはトレーニングの約 2 倍の高速化を実現できます。

汎用性

RLHF タスクに加えて、RL アルゴリズムとしての ReMax は、従来の NLP タスクにも適用できます。この論文では、報酬モデルが比較データから学習されない GPT-2 上の映画レビュー継続タスクを検討します。実験的観察によると、ReMax は 2.2 倍のトレーニング高速化と 60% の GPU メモリ節約を実現できます。

従来の NLP タスク (テキスト継続) では、ReMax は PPO と比較して 2.2 倍の高速化を達成しました。

要約する

最後に、私たちの実験から得た PPO に対する ReMax の主な利点を簡単にまとめます。

  • よりシンプルな実装: ReMax のコアは 6 行のコードで実装できます。これは、PPO の多くの複雑なコード構成要素とはまったく対照的です。
  • メモリ オーバーヘッドの削減: 価値モデルとそのトレーニング コンポーネント全体が削除されたため、ReMax は PPO と比較して GPU メモリを約 50% 節約します。
  • ハイパーパラメータの削減: ReMax は、GAE 係数、価値モデルの学習率、重要度サンプリング エポック、ミニバッチ サイズなど、価値モデルのトレーニングに関連するすべてのハイパーパラメータを正常に削除します。これらのハイパーパラメータは、多くの場合、問題に敏感であり、調整が困難です。 ReMax は RLHF 研究者にとってより親しみやすいものであると考えています。
  • より高速なトレーニング速度: GPT2 (137M) の実験では、実際の実行時間に関して、ReMax は PPO と比較して 2.2 倍高速であることが確認されました。高速化は、各反復における ReMax の計算オーバーヘッドが低いことから生まれます。私たちの計算によると、この高速化の利点は、より大きなモデルでも維持されます (PPO が十分に大きなメモリに正常に展開できると仮定)。
  • 優れたパフォーマンス: 上記のように、ReMax は中規模の実験で PPO と同等のパフォーマンスを達成し、場合によっては PPO を上回るパフォーマンスを発揮します (おそらく、ReMax に適切なハイパーパラメータを見つけるのが簡単なためです)。この優れたパフォーマンスは、より大きなモデルにも拡張できると推測されます。

<<: 

>>:  OpenAIがついにオープン:DALL-E 3の論文が発表され、ChatGPTが開始、著者の半数が中国人

ブログ    
ブログ    

推薦する

...

2017 年の機械学習開発に関するトップ 10 の予測: 悲観的か現実的か?

「分析の時代」はまだ始まったばかりですが、私たちには多くの刺激的なアイデアと期待がもたらされていま...

ガートナー、2022年の銀行・投資サービスにおける3つの注目のテクノロジートレンドを発表

ガートナーは、2022年の銀行および投資サービス業界における3つの注目の技術トレンドとして、生成型人...

...

数十人の国内NLP専門家が協力し、事前学習済みモデルの過去、現在、未来を検討した。

[[422361]] BERT や GPT などの大規模な事前トレーニング済みモデル (PTM) ...

通信ネットワーク運用イベントのナレッジグラフの構築

1. 通信ネットワーク運用シナリオまず、通信ネットワーク運用の背景についてご紹介します。通信ネットワ...

個人情報保護における人工知能データの役割

世界中で人工知能の大規模な構築と応用の発展が加速する中、近年、人工知能ガバナンスの問題が社会の関心を...

最も孤独なニューラル ネットワーク: たった 1 つのニューロンですが、「クローンをシャドウ」することができます

世界で最も先進的なニューラルネットワークモデルは何ですか?それは人間の脳に違いない。人間の脳には86...

お金は人を幸せにできるのでしょうか?機械学習を使って答えを見つける方法を教えます

機械学習システムを分類する 1 つの方法は、一般化の程度によって分類することです。ほとんどの機械学習...

速報、AI専門家のJing Kun氏がBaiduを退社! CIOの李英がXiaoduのCEOに就任

この記事はAI新メディアQuantum Bit(公開アカウントID:QbitAI)より許可を得て転載...

GPT-4 MATHの精度は84.3%まで上昇しました!香港中文大学や清華大学を含むトップ7大学が新たなCSV方式を提案

大規模言語モデル (LLM) は常識理解やコード生成などのタスクでは大きな進歩を遂げていますが、数学...

Java プログラミング スキル - データ構造とアルゴリズム「ハフマン ツリー」

[[389315]]基本的な紹介n 個のリーフ ノードとして n 個の重みが与えられ、バイナリ ツ...

初の科学ニュース執筆ロボット「小科」が発売

[[272541]] 8月1日、初の科学ニュース執筆ロボット「小科」が正式に就任し、その最初の一連の...

AV-TESTに再び認定されました! Sangfor EDRは中国で初めて満点を獲得したエンタープライズレベルのエンドポイントセキュリティ製品となる

検出能力6点!パフォーマンス消費6ポイント!使いやすさ6点!先日、国際的に権威のある評価機関 AV-...

AIはイノベーションを通じて気候への影響を補うことができるでしょうか?

最も熱心な気候変動監視者でさえ希望を抱いている。なぜなら、人類の革新と技術が私たちをこの混乱に陥れた...