PyTorchBigGraph を使用して超大規模グラフ モデルをトレーニングする方法は?

PyTorchBigGraph を使用して超大規模グラフ モデルをトレーニングする方法は?

Facebook は、数十億のノードと数兆のエッジを持つグラフ モデルを効率的にトレーニングできる BigGraph というフレームワークを提案し、その PyTorch 実装をオープンソース化しました。この記事では、その革新性を説明し、大規模なグラフ ネットワークから効率的に知識を抽出できる理由を分析します。

グラフは、機械学習アプリケーションにおける最も基本的なデータ構造の 1 つです。具体的には、グラフ埋め込み法は、ローカル グラフ構造を使用してノードの表現を学習する教師なし学習法です。ソーシャル メディア予測、IoT パターン検出、薬物シーケンス モデリングなどの主流のシナリオにおけるトレーニング データは、グラフ構造として自然に表現できます。これらのシナリオのそれぞれでは、数十億個の接続されたノードを持つグラフが簡単に作成されます。グラフは構造が非常に豊富で、本質的にナビゲート可能であるため、機械学習モデルに適しています。それにもかかわらず、グラフ構造は非常に複雑であり、アプリケーションに合わせて拡張するのが困難です。したがって、最新のディープラーニング フレームワークでは、大規模なグラフ データ構造のサポートが依然として非常に限られています。

Facebook は PyTorch BigGraph: https://github.com/facebookresearch/PyTorch-BigGraph というフレームワークを立ち上げました。これにより、PyTorch モデル内の大規模なグラフ構造のグラフ埋め込みをより迅速かつ簡単に生成できます。

ある意味では、グラフ構造は、ノード間の接続を使用して特定の関係を推測できるため、ラベル付けされたトレーニング データセットの代替として考えることができます。このアプローチは、教師なしグラフ埋め込み法のパターンに従います。この方法では、エッジで接続されたノード ペアの埋め込みがエッジのないノード ペアの埋め込みよりも近くなるようにノード ペアの埋め込みを最適化することで、グラフ内の各ノードのベクトル表現を学習できます。これは、テキストでトレーニングされた word2vec からの単語埋め込みが機能する方法に似ています。

ほとんどのグラフ埋め込み方法は、大規模なグラフ構造に適用すると、かなり限られた結果しか示しません。たとえば、モデルに 20 億のノードがあり、各ノードに 100 個の埋め込みパラメータ (浮動小数点数として表される) がある場合、これらのパラメータを格納するためだけに 800 GB のメモリが必要になるため、多くの標準的なアプローチでは一般的なコモディティ サーバーのメモリ容量を超えてしまいます。これはディープラーニング モデルが直面している大きな課題であり、Facebook が BigGraph フレームワークを開発した理由です。

PyTorch ビッググラフ

PyTorch BigGraph (PBG) の目標は、グラフ埋め込みモデルを拡張して、数十億のノードと数兆のエッジを持つグラフを処理することです。 PBG はなぜこれができるのでしょうか? 4 つの基本的な構成要素を使用するためです。

  1. グラフのパーティション分割により、モデルをメモリに完全にロードする必要がなくなります。
  2. 各マシンでのマルチスレッドコンピューティング
  3. 複数のマシン上での分散実行(オプション)。すべての操作はグラフの切断された部分で実行されます。
  4. バッチネガティブサンプリングでは、エッジごとに 100 個のネガティブサンプルがある場合、マシンごとに 1 秒あたり 100 万を超えるエッジを処理できます。

PBG は、グラフ構造を P 個のランダムに分割されたパーティションに分割し、2 つのパーティションがメモリに収まるようにすることで、従来のグラフ埋め込み方法の欠点の一部を解決します。たとえば、エッジがパーティション p1 で始まり、パーティション p2 で終わる場合、そのエッジはバケット (p1、p2) に配置されます。次に、同じモデル内で、これらのグラフ ノードはソース ノードとターゲット ノードに応じて P2 バケットに分割されます。ノードとエッジの分割が完了したら、一度に 1 つのバケットでトレーニングを実行できます。バケット (p1、p2) のトレーニングでは、パーティション p1 と p2 の埋め込みをメモリに保存するだけで済みます。 PBG 構造により、バケットには少なくとも 1 つの以前にトレーニングされた埋め込みパーティションが含まれるようになります。

PBG のもう一つの大きな革新は、トレーニング メカニズムの並列化と分散です。 PBG は PyTorch 独自の並列化メカニズムを使用して、上記のモジュール分割構造を使用する分散トレーニング モデルを実装します。このモデルでは、各マシンが分離したバケットでのトレーニングを調整します。これは、バケットをワーカーにディスパッチする役割を果たすロック サーバーを使用し、異なるマシン間の通信を最小限に抑えます。各マシンは異なるバケットを使用してモデルを並列にトレーニングできます。

上の図では、マシン 2 の Trainer モジュールがマシン 1 のロック サーバーにバケットを要求し、バケットのパーティションをロックします。次に、トレーナーは使用しなくなったパーティションを保存し、共有パーティション サーバーから必要な新しいパーティションをロードします。この時点で、古いパーティションをロック サーバーに戻すことができます。次に、エッジは共有ファイル システムからロードされ、スレッド内同期なしで複数のスレッドでトレーニングされます。別のスレッドでは、いくつかの共有パラメータのみが共有パラメータ サーバーと継続的に同期されます。モデル チェックポイントは、トレーナーから共有ファイル システムに時々書き込まれます。このモデルでは、最大 P/2 台のマシンを使用して P 個のバケットのセットを並列化できます。

PBG のそれほど直接的ではない革新は、バッチネガティブサンプリングの使用です。従来のグラフ埋め込みモデルは、負のトレーニング例として、真の正のエッジとともにランダムな「偽の」エッジを構築します。これにより、新しい例ごとに重みのごく一部だけを更新すればよいため、トレーニングの速度が大幅に向上します。ただし、負の例はグラフ処理にパフォーマンスのオーバーヘッドをもたらし、ランダムなソース ノードまたはターゲット ノードを通じて実際のエッジを「破損」させる可能性があります。 PBG は、N 個のランダム ノードの単一バッチを再利用して、N 個のトレーニング エッジの破損した負のサンプルを取得する方法を導入します。他の埋め込み方法と比較して、この手法では、計算コストを非常に低く抑えながら、エッジごとに多数の負の例をトレーニングできます。

大規模なグラフでのメモリ効率と計算リソースを向上させるために、PBG は Bn 個のサンプリングされたソース ノードまたはターゲット ノードの単一バッチを使用して、複数の負の例を構築します。通常の設定では、PBG はトレーニング セットから B = 1000 個の正の例のバッチを取得し、それらを 50 個のエッジのブロックに分割します。各ブロックからのターゲット(ソースに相当)埋め込みは、末尾のエンティティ タイプから均一にサンプリングされた 50 個の埋め込みと連結されます。 50 個の正例と 200 個のサンプリング ノードの外積は、9900 個の負例に等しくなります。

バッチネガティブサンプリング方式は、モデルのトレーニング速度に直接影響を与える可能性があります。バッチ処理を行わない場合、トレーニングの速度は負の例の数に反比例します。バッチトレーニングにより方程式を改善し、安定したトレーニング速度を得ることができます。

Facebook は、LiveJournal、Twitter データ、YouTube ユーザー インタラクション データなどのさまざまなデータセットを使用して PBG を評価しました。さらに、PBG は、1 億 2,000 万を超えるノードと 27 億のエッジを含む Freebase ナレッジ グラフを使用してベンチマークされました。また、Freebase の小さなサブセットである FB15k でもテストしました。FB15k には 15,000 個のノードと 600,000 個のエッジが含まれており、マルチリレーション埋め込み方法のベンチマークとしてよく使用されます。 FB15k 実験では、PBG が現在の最良のグラフ埋め込みモデルと同様のパフォーマンスを発揮することが示されています。ただし、完全な Freebase データセットで評価すると、PBG はメモリ消費において 88% の改善を達成します。

PBG は、数十億のノードと数兆のエッジを含むグラフをトレーニングおよび処理できる最初のスケーラブルな方法です。 PBG の最初の実装はオープンソース化されており、将来的にはさらに興味深い貢献が出てくるでしょう。

<<:  Google、少ないパラメータでテキスト分類を行う新モデル「pQRNN」を発表、BERTに匹敵する性能

>>:  AI起業家にとって、これら4つの新たな方向性は注目に値するかもしれない

ブログ    
ブログ    
ブログ    

推薦する

小売業界におけるAIインテリジェントビデオ分析の応用

人工知能 (AI) は、情報の集合からビジネス価値のある洞察を抽出することを目的とするデータ サイエ...

NTTとシスコがAR技術を活用して生産性を向上

[[400946]]距離がチームワークを制限するべきではないメンテナンスの問題をより早く解決世界中の...

...

工業生産は変化している:機械は人間よりも製造に優れている

最近、ロボットが人気になってきました。家庭生活、ホテル経営、学校教育、医療などさまざまな場面でロボッ...

マスク氏は人気検索に頻繁に登場、テスラは「過大評価されている」

この記事はLeiphone.comから転載したものです。転載する場合は、Leiphone.com公式...

...

ディープフェイクに取って代わると期待されていますか?今年最も注目されているNeRFテクノロジーの秘密を解き明かす

え、まだNeRFを知らないの? NeRF は、今年コンピューター ビジョン分野で最も注目されている ...

AIが地震の前兆信号を識別?機械学習がデータ内の不思議な相関関係を発見、人類に地震予測の希望を与える

最近、世界中で地震が頻繁に発生しています。 1月1日、突然、マグニチュード7.6の地震が日本を襲い、...

機械学習ツールボックスには6つの重要なアルゴリズムが隠されています

1. 線形回帰フランスの数学者アドリアン・マリー・ルジャンドルは、彗星の将来の位置を予測することに常...

マイクロソフト、人間の編集者をAIに置き換え、ジャーナリスト数名を解雇

[[328414]]マイクロソフトは、マイクロソフトニュースとMSNチームから数十人のジャーナリスト...

...

MetaとMicrosoft、Nvidia GPUの代替として新しいAMD AIチップを購入することを約束

12月7日、Meta、OpenAI、Microsoftは、現地時間水曜日のAMD投資家向けイベントで...

...

...

...