PyTorch のデータセット Torchvision と Torchtext

PyTorch のデータセット Torchvision と Torchtext

[[421061]]

PyTorch がさまざまな種類のデータを読み込んで処理できるように、公式では torchvision と torchtext が提供されています。

以前は、torchDataLoader クラスを使用して画像を直接読み込み、テンソルに変換していました。ここでは、torchvisionとtorchtextと組み合わせてtorchに組み込まれているデータセットを紹介します。

Torchvisionのデータセット

MNIST

MNIST は、正規化され中央が切り取られた手書き画像で構成されるデータセットです。 60,000 枚以上のトレーニング画像と 10,000 枚以上のテスト画像があります。これは、学習や実験の目的で最もよく使用されるデータセットの 1 つです。データセットを読み込んで使用するには、次の構文を使用してインポートします: torchvision.datasets.MNIST()。

ファッションMNIST

Fashion MNIST データセットは MNIST に似ていますが、このデータセットには手書きの数字の代わりに T シャツ、パンツ、バッグなどの衣料品が含まれており、トレーニング サンプルとテスト サンプルの数はそれぞれ 60,000 と 10,000 です。データセットを読み込んで使用するには、次の構文を使用してインポートします: torchvision.datasets.FashionMNIST()

シーファー

CIFAR データセットには、CIFAR10 と CIFAR100 の 2 つのバージョンがあります。 CIFAR10 は 10 個の異なるラベルを持つ画像で構成され、CIFAR100 には 100 個の異なるクラスがあります。これらには、トラック、カエル、ボート、車、鹿などの一般的な画像が含まれます。

  1. torchvision.datasets.CIFAR10()
  2. torchvision.datasets.CIFAR100()

ココ

COCO データセットには、人、ボトル、文房具、本など、100,000 を超える日常的なオブジェクトが含まれています。この画像データセットは、オブジェクト検出や画像キャプション作成アプリケーションに広く使用されています。 COCO をロードできる場所は次のとおりです: torchvision.datasets.CocoCaptions()

エムニスト

EMNIST データセットは、MNIST データセットの高度なバージョンです。数字や文字を含む画像で構成されています。画像からテキストを認識する問題に取り組んでいる場合、EMNIST は良い選択です。 EMNIST をロードできる場所は次のとおりです: torchvision.datasets.EMNIST()

イメージネット

ImageNet は、高度なニューラル ネットワークをトレーニングするための主要なデータセットの 1 つです。 10,000 のカテゴリに分散された 120 万枚以上の画像で構成されています。通常、このデータセットは、単一の CPU ではこのような大規模なデータセットを処理できないため、ハイエンドのハードウェア システムにロードされます。 ImageNetデータセットをロードするためのクラスは次のとおりです: torchvision.datasets.ImageNet()

Torchtext のデータセット

IMDB

IMDB は感情分類用のデータセットで、トレーニング用に 25,000 件の非常に極端な映画レビューのセット、テスト用にさらに 25,000 件のレビューが含まれています。 torchtext クラスを使用してこれらのデータをロードします: torchtext.datasets.IMDB()

ウィキテキスト2

WikiText2 言語モデリング データセットは、1 億を超えるトークンのコレクションです。これは Wikipedia から抽出されたもので、句読点と実際の大文字と小文字が保持されています。長期的な依存関係を伴うアプリケーションで広く使用されています。このデータはtorchtextから読み込むことができます: torchtext.datasets.WikiText2()

上記の 2 つの人気データセットに加えて、SST、TREC、SNLI、MultiNLI、WikiText-2、WikiText103、PennTreebank、Multi30k など、torchtext ライブラリで利用できるデータセットが他にもあります。

MNISTデータセットを詳しく見る

MNIST は最も人気のあるデータセットの 1 つです。ここで、PyTorch が pytorch/vision リポジトリから MNIST データセットをロードする方法を確認します。まずデータセットをダウンロードし、data_trainという変数にロードしましょう。

  1. torchvision.datasetsからMNISTをインポートする
  2.  
  3. # MNISTをダウンロード
  4. data_train = MNIST( '~/mnist_data' 、 train= True 、 download= True )
  5.  
  6. matplotlib.pyplot をpltとしてインポートします。
  7.  
  8. ランダム画像 = データトレーニング[0][0]
  9. ランダム画像ラベル = データトレーニング[0][1]
  10.  
  11. # Matplotlib を使用して画像を印刷する
  12. plt.imshow(ランダム画像)
  13. print( "画像のラベルは:" , random_image_label)

DataLoaderはMNISTをロードします

次に、以下に示すように、DataLoader クラスを使用してデータセットを読み込みます。

  1. 輸入トーチ
  2. torchvisionから変換をインポート
  3.  
  4. data_train = torch.utils.data.DataLoader(
  5. MNIST(
  6. '~/mnist_data' 、トレーニング= True 、ダウンロード= True
  7. 変換 = transforms.Compose([
  8. 変換.ToTensor()
  9. ]))、
  10. バッチサイズ=64、
  11. シャッフル= 
  12.  
  13. batch_idxの場合enumerate(data_train)サンプル:
  14. print(batch_idx, サンプル)

CUDA 読み込み

GPU を有効にすると、モデルのトレーニングが高速化されます。ここで、CUDA (PyTorch の GPU サポート) を使用してデータをロードするときに使用できる構成を使用しましょう。

  1. デバイス = "cuda" 、torch.cuda.is_available() の場合、そうでない場合  "CPU"  
  2. kwargs = { 'num_workers' : 1, 'pin_memory' : True } デバイス == 'cuda'の場合 それ以外{}
  3.  
  4. トレーニングローダー = torch.utils.data.DataLoader(
  5. torchvision.datasets.MNIST( '/files/' 、トレーニング= True 、ダウンロード= True )、
  6. batch_size=batch_size_train、**kwargs)
  7.  
  8. test_loader = torch.utils.data.DataLoader(
  9. torchvision.datasets.MNIST( 'files/' 、トレーニング= False 、ダウンロード= True )、
  10. batch_size=batch_size、**kwargs)

画像フォルダ

ImageFolder は、独自の画像データセットを読み込むのに役立つ、一般的なデータ ローダー クラス torchvision です。分類問題を取り上げ、与えられた画像がリンゴかオレンジかを識別するニューラル ネットワークを構築します。 PyTorch でこれを行うには、まず次のように、画像をデフォルトのフォルダー構造に配置する必要があります。

  1. ├──オレンジ
  2. │ ├── オレンジ_image1.png
  3. │ └── オレンジ_image1.png
  4. ├── リンゴ
  5. │ └── apple_image1.png
  6. │ └── apple_image2.png
  7. │ └── apple_image3.png

これらの画像はすべて、ImageLoader クラスを使用して読み込むことができます。

  1. torchvision.datasets.ImageFolder(ルート、変換)

変換する

PyTorch Transforms は、データセット全体を独自の形式に変換できるシンプルな画像変換手法を定義します。

異なる解像度の異なる車の画像を含むデータセットの場合、トレーニング中は、トレーニング データセット内のすべての画像が同じ解像度サイズである必要があります。すべての画像を必要な入力サイズに手動で変換すると時間がかかるため、変換を使用できます。数行の PyTorch コードで、データセット内のすべての画像を必要な入力サイズと解像度に変換できます。

ここで、CIFAR10torchvision.datasets をロードし、次の変換を適用します。

  • すべての画像を32×32にサイズ変更します
  • 画像に中央切り抜き変換を適用します
  • 切り取った画像をテンソルに変換する
  • 標準化された画像
  1. 輸入トーチ
  2. torchvision をインポートする
  3. torchvision.transforms をtransformsとしてインポートします
  4. matplotlib.pyplot をpltとしてインポートします。
  5. numpyをnpとしてインポートする
  6.  
  7. 変換 = transforms.Compose([
  8. # サイズを 32×32 に変更
  9. 変換.サイズ変更(32)
  10. # センタークロップのクロッピング変換
  11. 変換.CenterCrop(32)、
  12. #から -テンソル
  13. 変換.ToTensor()、
  14. # 正規化
  15. 変換します。正規化します([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
  16. ])
  17.  
  18. トレインセット = torchvision.datasets.CIFAR10(ルート = './data' 、トレイン = True
  19. ダウンロード= True 、変換=変換)
  20. トレインローダー = torch.utils.data.DataLoader(トレインセット、バッチサイズ=4、
  21. シャッフル = False )

PyTorch でカスタムデータセットを作成する

次に、数字とテキストで構成される単純なカスタム データセットを作成します。 Dataset クラスで __getitem__() メソッドと __len__() メソッドをカプセル化する必要があります。

  • __getitem__() メソッドは、インデックスによってデータセット内の選択された例を返します。
  • __len__() メソッドはデータセットの合計サイズを返します。

以下は、FruitImagesDataset データセットをカプセル化するコードです。これは基本的に、PyTorch でカスタム データセットを作成するための優れたテンプレートです。

  1. インポートOS
  2. numpyをnpとしてインポートする
  3. cv2をインポート
  4. 輸入トーチ
  5. matplotlib.patches をパッチとしてインポートする
  6. アルバムをAとしてインポートする
  7. albumentations.pytorch.transformsからToTensorV2 をインポートします
  8. matplotlibからpyplotをpltとしてインポートします
  9. torch.utils.dataからデータセットをインポート
  10. xml.etreeからElementTree をElementTreeとしてインポートします
  11. torchvisionからtorchtransとして変換をインポートします
  12.  
  13. クラス FruitImagesDataset(torch.utils.data.Dataset):
  14. def __init__(self, files_dir, width, height, transforms=None):
  15. self.transforms = 変換
  16. self.files_dir = ファイルディレクトリ
  17. self.height = 高さ
  18. 自己.幅 = 幅
  19.  
  20.  
  21. self.imgs = [画像for image in sorted(os.listdir(files_dir))
  22. 画像[-4:] == '.jpg'の場合]
  23.  
  24. self.classes = [ '_' 'リンゴ' 'バナナ' 'オレンジ' ]
  25.  
  26. __getitem__(self, idx)を定義します。
  27.  
  28. img_name = self.imgs[idx]
  29. image_path = os.path.join (self.files_dir、img_name) です。
  30.  
  31. # 画像を読み取り正しいサイズ変換する 
  32. img = cv2.imread(画像パス)
  33. img_rgb = cv2.cvtColor(img、cv2.COLOR_BGR2RGB).astype(np.float32) で、
  34. img_res = cv2.resize(img_rgb, (self.width, self.height), cv2.INTER_AREA)
  35. # 255によるダイビング
  36. 画像解像度 /= 255.0
  37.  
  38. # 注釈ファイル
  39. annot_filename = img_name[:-4] + '.xml'  
  40. annot_file_path = os.path.join (self.files_dir、annot_filename)
  41.  
  42. ボックス = []
  43. ラベル = []
  44. ツリー = et.parse(annot_file_path)
  45. ルート = tree.getroot()
  46.  
  47. # cv2 イメージはサイズを示します ×幅
  48. 重量 = 画像の形状[1]
  49. イメージシェイプ[0]
  50.  
  51. # xmlファイルボックス座標が抽出され指定された画像サイズに合わせ修正されます
  52. root.findall( 'object' )内のメンバーの場合:
  53. ラベルを追加します(self.classes.index ( member.find( 'name' ).text))
  54.  
  55. # 境界ボックス
  56. xmin = int (member.find( 'bndbox' ).find( 'xmin' ).text)
  57. xmax = int (member.find( 'bndbox' ).find( 'xmax' ).text)
  58.  
  59. ymin = int (member.find( 'bndbox' ).find( 'ymin' ).text)
  60. ymax = int (member.find( 'bndbox' ).find( 'ymax' ).text)
  61.  
  62. xmin_corr = (xmin / wt) * 自己幅
  63. xmax_corr = (xmax / wt) * 自己幅
  64. ymin_corr = (ymin / ht) * 自己高さ
  65. ymax_corr = (ymax / ht) * 自己高さ
  66.  
  67. ボックスを追加します([xmin_corr, ymin_corr, xmax_corr, ymax_corr])
  68.  
  69. #ボックスをtorch.Tensor変換する
  70. ボックス = torch.as_tensor(ボックス、dtype=torch.float32)
  71.  
  72. #ボックス面積を取得する
  73. 面積 = (ボックス[:, 3] - ボックス[:, 1]) * (ボックス[:, 2] - ボックス[:, 0])
  74.  
  75. #すべてのインスタンスが混雑していないと仮定
  76. iscrowd = torch.zeros((boxes.shape[0],), dtype=torch.int64)
  77.  
  78. ラベル = torch.as_tensor(ラベル、dtype=torch.int64)
  79.  
  80. ターゲット = {}
  81. target[ "boxes" ] = ボックス
  82. target[ "labels" ] = ラベル
  83. target[ "area" ] = エリア
  84. ターゲット[ "iscrowd" ] = iscrowd
  85. #画像ID
  86. イメージID = torch.tensor([idx])
  87. ターゲット[ "image_id" ] = イメージID
  88.  
  89. self.transformsの場合:
  90. サンプル = self.transforms(画像 = img_res,
  91. bboxes=ターゲット[ 'ボックス' ],
  92. ラベル=ラベル)
  93.  
  94. img_res = サンプル[ '画像' ]
  95. ターゲット[ 'boxes' ] = torch.Tensor(サンプル[ 'bboxes' ])
  96. img_res、ターゲットを返す
  97. __len__(自分)を定義します:
  98. len(self.imgs)を返す
  99.  
  100. get_transform(train)を定義します。
  101. 電車の場合:
  102. A.Compose([を返す
  103. A.水平反転(0.5)、
  104. テンソルV2(p=1.0)
  105. ], bbox_params={ 'format' : 'pascal_voc' , 'label_fields' : [ 'labels' ]})
  106. それ以外
  107. A.Compose([を返す
  108. テンソルV2(p=1.0)
  109. ], bbox_params={ 'format' : 'pascal_voc' , 'label_fields' : [ 'labels' ]})
  110.  
  111. files_dir = '../input/fruit-images-for-object-detection/train_zip/train'  
  112. test_dir = '../input/fruit-images-for-object-detection/test_zip/test'  
  113.  
  114. データセット = FruitImagesDataset(train_dir, 480, 480)

<<:  ロボットが人間の「仲間」となり、人間と機械の関係が変化する。これは良いことなのか、悪いことなのか?

>>:  Kubernetes にディープラーニング モデルをデプロイする方法

ブログ    
ブログ    

推薦する

機械学習アルゴリズムが NDA の法的分析テストで 20 人の弁護士に勝利

ロボット工学と人工知能の発展により、多くの仕事が機械に置き換えられるでしょう。機械は、一部のタスク、...

Python アルゴリズムの時間計算量

アルゴリズムを実装する場合、アルゴリズムの複雑さは通常、時間の複雑さと空間の複雑さという 2 つの側...

転移学習の限界を突破せよ! Googleが新しいNLPモデル「T5」を提案、複数のベンチマークでSOTAに到達

[[316154]]過去数年間、転移学習は NLP 分野に実りある成果をもたらし、新たな発展の波を...

認知分析について知っておくべきことすべて

コンテキストを提供し、大量の情報に隠された答えを発見するために、コグニティブ コンピューティングはさ...

消費者がリアルなAIを信頼しない理由

Amazon Alexaのような音声アシスタントの台頭にもかかわらず、人々は本物そっくりのAIに不安...

ベイジアンアルゴリズムは「アプリチケット詐欺」を打破する良い方法となるだろう

最近、世間を騒がせた360 Appランキング操作事件とその背後にある闇産業チェーンの出現により、Ap...

ガートナーレポート: 世界のカスタマーサービスセンターが会話型 AI を導入、今年の支出は 16.2% 増加

8月1日、市場調査会社ガートナーが発表した最新のレポートによると、世界中のカスタマーサービスセンター...

業界最高品質の AI データを作成するにはどうすればよいでしょうか?クラウドデータの成功の秘密を明かす

[[344160]] AIの実装が加速する中、AIデータのラベリングは人工知能産業の実装における重要...

私の国はAIや5Gを含む多くの技術で米国を上回っており、米国が私たちを絞め殺すことはますます困難になっています。

世界大国として、中国と米国は多くの分野、特に科学技術分野で競争している。中国は科学技術分野で比較的目...

...

機械学習が近い将来教育を変える5つの方法

テクノロジーは私たちの生活、仕事、遊び方を変えており、教育も例外ではありません。機械学習は他の分野を...

...