モデルの並列処理により、ビジョンタスクのパフォーマンスが向上します。しかし、現在のところ、混合精度などの他の SOTA 手法と同じくらい簡単にモデル並列処理を採用できる標準ライブラリはありません。 最近、メリーランド大学カレッジパーク校のコンピュータサイエンス学部の研究者である Kaiyu Yue 氏が、PyTorch テンソルを並列シャードにスライスするための軽量エンジンである TorchShard ツールをオープンソース化しました。 TorchShard は、モデルに多数の線形レイヤー (BERT、GPT など) や多数のクラス (数百万) がある場合に、GPU メモリを削減し、トレーニングをスケールできます。PyTorch と同じ API 設計になっています。 プロジェクトアドレス: https://github.com/KaiyuYue/torchshard BERT や GPT などの非常に大規模なモデルは、NLP 分野のアプリケーションでトレンドになりつつあります。しかし、このような大規模なモデルをトレーニングするには、メモリ制限の問題に直面します。この問題を解決するために、研究者は Megatron-LM と PyTorch-Lightning モデルの並列処理を使用してトレーニングを拡張しました。このうち、Megatron-LM は大規模なトレーニング言語モデルにのみ焦点を当てていますが、PyTorch-Lightning は DeepSpeed などのシャード化されたオプティマイザー状態と勾配のみに基づいています。 コンピューター ビジョン タスクでは、Transformer ベースのモデル、MLP モデル、または数百万のクラスのトレーニング モデルをトレーニングするときに同じ問題が発生します。 TorchShard の目標は次のとおりです。
TorchShard は、Megatron-LM の中心にあるモデル並列ユニット (MPU) を完全に書き直したものです。最も重要なのは、TorchShard は PyTorch と同じ API 設計になっていることです。つまり、すべてのサブクラスとサブ関数は PyTorch と同じままです。たとえば、元の線形レイヤー torch.nn.Linear を並列にしたい場合は、次のように torch を ts に変換し、サブクラス nn.ParallelLinear を dim パラメータで呼び出します。
これに加えて、TorchShard は DDP と併用すると、シャード チェックポイントの保存と読み込み、シャード パラメータの初期化、複数のマシンと GPU にわたるテンソルの処理など、さまざまな機能をサポートします。詳細は以下の通りです。
TorchShard を使い始めるにはどうすればいいですか?インストール要件: Python バージョン 3.6 以上 (含む) および PyTorch バージョン 1.9.0 以上 (含む)。 pip 経由で TorchShard ライブラリをインストールします。
ここでは、ImageNet での ResNet-50 のトレーニングを例として、わずか数行のコードでプロジェクトで TorchShard を使用する方法を示します。通常、ResNet-50 の設計パラダイムは、下の図 1 に示すように、畳み込みブロックと完全接続層の 2 つの部分で構成されます。一般に、データセットに応じてクラスの数が多いため、最終線形層には畳み込みブロックよりも多くのパラメーターがあります。そこで、最後の線形レイヤーをスライスして、その最大サイズを確認します。 図 1: DDP および DDP + TorchShard フォワード トレーニング フロー。 上の図 1 では、従来の DDP トレーニング パラダイムが左側に示されています。 2 つのクラスがあると仮定すると、DDP は各クラスに重複したモデル パラメータを強制的に設定させます。ただし、TorchShard はレイヤー パラメータをさまざまなレベルに分割するため、全体的な GPU メモリが削減されます。ここで、ImageNet の公式トレーニング スクリプトにいくつかのコードを追加すると、修正されたバージョンが TorchShard プロジェクトの一部になります。 まず、torchshard をインポートします。
次に、DDP プロセス グループを初期化するのと同じ方法で、モデル並列プロセス グループを初期化する必要があります。ターゲット レイヤーからスライスするシャードの数を torchshard に指示する関数パラメータを設定するだけで済みます。
次に、モデルは並列バージョンに変換され、特別な処理なしでモデル全体を変換ヘルパー関数に直接入力できるようになります。
また、入力テンソルに応じて元の PyTorch バージョンと並列バージョンを切り替えることができる損失関数 torchshard.nn.ParallelCrossEntropy も忘れないでください。たとえば、入力テンソルが torchshard 並列レイヤーによって生成される場合、torchshard.nn.ParallelCrossEntropy は損失値を並列で計算します。
モデル並列モード (TorchShard) とデータ並列モード (DDP) が連携して動作する場合、並列レイヤーの入力を処理する必要があります。パラメータとトレーニングデータはレベルごとに異なります。したがって、ResNet の並列線形レイヤーの前に入力テンソルを収集します。
同様に、損失を計算する前にターゲット テンソルを収集します。
最後に、TorchShard 関数を使用すると、チェックポイントの保存と読み込みが非常に簡単になります。 TorchShard は、チェックポイントを保存するための torchshard.collect_state_dict という基本関数と、チェックポイントを読み込むための torchshard.relocate_state_dict という基本関数を提供します。 チェックポイントを保存します:
チェックポイントをロードします:
ImageNet でのシャード トレーニング用のコードの追加が完了したので、クラス数、つまり最後の線形レイヤーの出力特徴次元を増やすことでスケールアップできます。トレーニング スクリプトは torchshard/project/imagenet にあります。次の図は、クラス数が 1,000,000 以下の 8 個の NVIDIA TITAN-XP (12196 MiB) GPU と、クラス数が 2,000,000 の 16 個の GPU での ResNet-50 トレーニングのスケーラビリティを示しています。 図 2: さまざまな並列化戦略で標準の ResNet トレーニング設定 (入力サイズ 224、バッチ サイズ 256) を使用した場合の GPU メモリ コスト。 ZeROでAMPを使用するTorchShard は、Automatic Mixed Precision AMP や ZeRO などの他の技術と、シンプルで自然な PyTorch の方法で組み合わせることができます。
図 3: 標準の ResNet トレーニング設定 (入力サイズ 224、バッチ サイズ 256) を使用したさまざまな並列戦略と AMP での GPU メモリの使用コスト。 ZeRO は DeepSpeed のコアであり、PyTorch >= 1.9.0 で使用されます。関数をテストする場合は、スクリプトの最新バージョンをインストールして実行してください。コードは次のとおりです。
図 4: さまざまな並列化戦略と ZeRO オプティマイザーを使用した標準 ResNet トレーニング セットアップ (入力サイズ 224、バッチ サイズ 256) の GPU メモリ コスト。 さらに、TorchShard は、カスタム並列レイヤーの実装を簡素化するための基本的な Python API と対応するテンプレート ファイルも提供します。 研究者たちは TorchShard の開発を継続します。たとえば、TorchShard の次の機能は、torch.utils.data.DistributedSampler の命名に続く新しいデータ サンプラー torchshard.utils.data.DistributedGroupSampler です。このサンプラーは、ユーザーが M 方向のデータ並列処理と N 方向のモデル並列処理を構築できるように設計されており、DDP の DistributedSampler と同じくらいシンプルです。ユーザーが行う必要があるのは、モデル並列グループ番号を設定することだけです。そうすると、DistributedGroupSampler によって、同じモデル並列グループ内のモジュールに同じトレーニング データが含まれるようになります。 |
<<: ニッチから人気へ: 世界的な AI イノベーションが「ソフト」になった理由
>>: 二重あごをなくすコツがある。浙江大学の2000年代生まれの大学生が、ACM SIGGRAPHで発表した新しい美容アルゴリズムを開発
統計学と機械学習は密接に関連した2つの分野です。実際のところ、この 2 つの境界線は非常に曖昧になる...
アナリスト会社ガートナーは10月13日、2026年までに企業の80%以上が生成型AIアプリケーション...
今年初めの Red Hat Summit で、Red Hat は OpenShift AI によるプ...
イスラエルとパレスチナの紛争が深刻化するにつれ、ソーシャルメディアのプラットフォーム上には現地の情景...
機械学習を学びたいですか? まずはこの 10 冊の本から始めましょう。 [[374789]] >...
大型モデルはロボット工学の分野でその地位を確立しました。 「飲み物をこぼしてしまいました。助けてくれ...
この記事は、WeChat OCR 技術紹介シリーズの一部であり、ディープ シーケンス ラーニング手法...
[[392763]]コンセプト簡単に言うと、再帰とは、毎回異なる変数を渡しながら、自身を呼び出すメ...
歴史的に、これらの国や地域は旧植民地帝国によって貧困化しており、ヨーロッパの植民地主義は土地の暴力的...
[[387145]]基本的な紹介1. スタックはFILO(先入れ後出し)順序付きリストです2. ス...