Pytorch の核心であるモデルの定義と構築を突破しましょう! ! !

Pytorch の核心であるモデルの定義と構築を突破しましょう! ! !

こんにちは、Xiaozhuangです!

今日はモデルの定義と構築についてお話ししましょう。初心者に最適です!

ディープラーニングに PyTorch を使用する場合、まずモデルを定義して構築する方法を理解する必要があります。この内容は非常に重要です。

PyTorch では、モデル定義は通常、torch.nn.Module から継承するクラスを作成することによって行われます。

以下は、完全に接続された 1 つの層を持つ単純なニューラル ネットワークを定義する方法の簡単な例です。

 import torch import torch.nn as nn class SimpleNN(nn.Module): def __init__(self, input_size, hidden_size, output_size): super(SimpleNN, self).__init__() self.fc1 = nn.Linear(input_size, hidden_size) self.relu = nn.ReLU() self.fc2 = nn.Linear(hidden_size, output_size) def forward(self, x): x = self.fc1(x) x = self.relu(x) x = self.fc2(x) return x

次に、この例を段階的に説明しましょう。

1. 必要なライブラリをインポートする

import torch import torch.nn as nn

ここで、PyTorch ライブラリとニューラル ネットワーク モジュールがインポートされます。

2. モデルクラスを定義する

class SimpleNN(nn.Module):

nn.Module から継承するクラスを作成します。このクラスは、ニューラル ネットワーク モデルの青写真として機能します。

3. 初期化関数

def __init__(self, input_size, hidden_size, output_size): super(SimpleNN, self).__init__() self.fc1 = nn.Linear(input_size, hidden_size) self.relu = nn.ReLU() self.fc2 = nn.Linear(hidden_size, output_size)

__init__ 関数では、モデルのさまざまなレイヤーを定義します。

この単純なニューラル ネットワークには、入力層 (input_size ディメンション)、隠し層 (hidden_​​size ディメンション)、および出力層 (output_size ディメンション) が含まれています。

nn.Linear は完全接続層を表し、nn.ReLU は活性化関数 ReLU を表します。

4. フォワードプロパゲーション機能

def forward(self, x): x = self.fc1(x) x = self.relu(x) x = self.fc2(x) return x

forward 関数では、モデル内でデータがどのように伝播されるかを定義します。

ここでの伝播順序は、入力データが最初の完全接続層を通過し、次に ReLU 活性化関数を通過し、最後に 2 番目の完全接続層を通過してモデルの出力が得られるというものです。

この簡単な例を使用して、モデルを作成し、データを入力し、次の手順でフォワード パスを実行できます。

 # 定义输入、隐藏和输出层的维度input_size = 10 hidden_size = 20 output_size = 5 # 创建模型实例model = SimpleNN(input_size, hidden_size, output_size) # 随机生成输入数据input_data = torch.randn(32, input_size) # 32是批处理大小# 进行前向传播output = model(input_data) print(output)

これは単純なケースです。同様に、PyTorch は畳み込みニューラル ネットワーク (CNN)、再帰型ニューラル ネットワーク (RNN) など、より複雑なモデルを構築できます。

<<:  7つの変革的技術トレンド:第4次産業革命をリードする

>>:  2024 年の 6 つの主要なテクノロジー トレンドを見据えて、最もホットなテクノロジーをご紹介します。

ブログ    
ブログ    

推薦する

人工知能業界の最新の開発動向を1つの記事で理解する

[[418444]]現在、新世代の人工知能に代表される科学・産業革命が起こりつつあります。デジタル化...

人工知能のコスト問題をどう解決するか?顔認識によって情報セキュリティはどのように確保されるのでしょうか?

[[422539]] 9月7日午後、第19回「海南省科学技術会議」に新たに追加されたホットトピック...

...

2019 年に TensorFlow は王座から退いたのでしょうか?

この記事では、著者は GitHub、Medium の記事、arXiv の論文、LinkedIn など...

データのラベル付けは不要、「3D理解」によるマルチモーダル事前トレーニングの時代へ! ULIPシリーズは完全にオープンソースで、SOTAをリフレッシュします

3D 形状、2D 画像、および対応する言語記述を整合させることにより、マルチモーダル事前トレーニング...

センシング、AI、想像力:視覚がモノのインターネットをどう形作るか

ビジョンは、私たちの世界を大きく変えつつあるモノのインターネットの成長において、急速に主要なセンシン...

2022年の企業向け人工知能技術の開発動向

調査によると、企業が人工知能を導入する方法が増え、開発者がユーザーに AI サービスを提供する新しい...

機械学習の落とし穴を避ける: データはアルゴリズムよりも重要

ユーザー行動分析とネットワーク脅威検出、新たな波が起こり続けています。セキュリティ データ分析は、状...

ユニサウンドがマルチモーダルAIチップ戦略を発表、同時に開発中の3つのチップを公開

昨年5月に業界初となるモノのインターネット(IoT)向けAIチップ「Swift」とそのシステムソリュ...

OpenAI が GPT-3.5 Turbo の値下げを発表、GPT-4 Turbo の「怠惰」を解消

米国時間1月26日木曜日、OpenAIは一連のメジャーアップデートを発表した。これらのアップデートは...

Google はデータセンター向けの次世代地熱エネルギーを開発するために AI を応用している

[[401455]]地熱発電は地球の地下の自然の熱を利用して電気を生み出すので、魅力的な点がたくさん...

AIは人間の感情を理解できるのか?

温かく思いやりのある、一緒にいてくれる「ダバイ」が欲しいと願う人は多いだろうが、ダバイのように人間の...

Java における equals() と == の違いと使い方

Java 開発において、一見単純な質問ですが、インターネット上には多くのトピックや質問があります。...