TENSORFLOW に基づく中国語テキスト分類のための CNN と RNN

TENSORFLOW に基づく中国語テキスト分類のための CNN と RNN

[[211015]]

現在、TensorFlow のメジャーバージョンは 1.3 にアップグレードされ、多くのネットワーク層のより高度なカプセル化と実装が実現され、さらに Keras などの優れた高レベルフレームワークも統合され、使いやすさが大幅に向上しました。初期の基盤コードと比較すると、今日の実装はより簡潔でエレガントになっています。

この記事は、中国語データセットでの TensorFlow の簡略化された実装です。文字レベルの CNN と RNN を使用して中国語のテキストを分類し、良好な結果を達成しています。

データセット

この記事では、Tsinghua NLP Group が提供する THUCNews ニューステキスト分類データセットのサブセットを使用します (元のデータセットには約 740,000 件のドキュメントが含まれており、トレーニングには長い時間がかかります)。 THUCTC: 効率的な中国語テキスト分類ツールキットからデータセットをダウンロードしてください。データ プロバイダーのオープン ソース契約に従ってください。

このトレーニングでは 10 個のカテゴリが使用され、各カテゴリには 6,500 個のデータ ポイントがありました。

カテゴリーは次のとおりです。

スポーツ、金融、不動産、住宅、教育、テクノロジー、ファッション、時事問題、ゲーム、エンターテイメント

このサブセットはここからダウンロードできます: リンク: http://pan.baidu.com/s/1bpq9Eub パスワード: ycyw

データセットは次のように分割されます。

  • トレーニングセット: 5000*10
  • 検証セット: 500*10
  • テストセット: 1000*10

元のデータセットからサブセットを生成するプロセスについては、ヘルパーの下の 2 つのスクリプトを参照してください。このうち、copy_data.sh は各カテゴリから 6500 個のファイルをコピーするために使用され、cnews_group.py は複数のファイルを 1 つのファイルに統合するために使用されています。ファイルを実行すると、次の 3 つのデータ ファイルが取得されます。

  • cnews.train.txt: トレーニング セット (50,000 項目)
  • cnews.val.txt: 検証セット (5000 項目)
  • cnews.test.txt: テストセット (10000 項目)

前処理

data/cnews_loader.py はデータ前処理ファイルです。

  • read_file(): ファイルデータを読み取ります。
  • build_vocab(): 文字レベルの表現を使用して語彙を構築します。この関数は語彙を保存し、毎回繰り返し処理されないようにしています。
  • read_vocab(): 前のステップで保存された語彙を読み取り、それを {word:id} 表現に変換します。
  • read_category(): カテゴリディレクトリを修正し、{category: id} 表現に変換します。
  • to_words(): id で表されるデータをテキストに変換します。
  • preocess_file(): データセットをテキストから固定長の ID シーケンス表現に変換します。
  • batch_iter(): ニューラル ネットワークのトレーニング用にシャッフルされたデータ バッチを準備します。

データの前処理後のデータ形式は次のようになります。

CNN 畳み込みニューラルネットワーク

構成項目

CNN の設定可能なパラメータは、以下の cnn_model.py に示されています。

  1. クラス TCNNConfig(オブジェクト):
  2. CNN 構成パラメータ  
  3.  
  4. embedding_dim = 64 # 単語ベクトルの次元
  5. seq_length = 600 # シーケンスの長さ
  6. num_classes = 10 # カテゴリの数
  7. num_filters = 128 # 畳み込みカーネルの数
  8. kernel_size = 5 # 畳み込みカーネルのサイズ
  9. vocab_size = 5000 # 小さな語彙
  10.  
  11. hidden_​​dim = 128 #完全結合層ニューロン
  12.  
  13. dropout_keep_prob = 0.5 # ドロップアウト保持率
  14. learning_rate = 1e-3 # 学習率
  15.  
  16. batch_size = 64 # バッチあたりのトレーニングサイズ
  17. num_epochs = 10 # 反復回数の合計
  18.  
  19. print_per_batch = 100 # 数ラウンドごとに結果を出力します
  20. save_per_batch = 10 # テンソルボードに保存するラウンド数

CNNモデル

詳細については、cnn_model.py の実装を参照してください。

一般的な構造は次のとおりです。

トレーニングと検証

トレーニングを開始するには、python run_cnn.py train を実行します。

以前にトレーニングしたことがある場合は、TensorBoard で複数のトレーニング結果が重複しないように、tensorboard/textcnn を削除してください。

  1. CNN モデルを構成しています...
  2. TensorBoardSaver を設定しています...
  3. トレーニングおよび検証データを読み込んでいます...
  4. 使用時間: 0:00:14
  5. トレーニング評価...
  6. エポック: 1
  7. 反復: 0、列車損失: 2.3、列車精度: 10.94%、値損失: 2.3、値精度: 8.92%、時間: 0:00:01 *
  8. 反復: 100、列車損失: 0.88、列車精度: 73.44%、値損失: 1.2、値精度: 68.46%、時間: 0:00:04 *
  9. 反復: 200、列車損失: 0.38、列車精度: 92.19%、値損失: 0.75、値精度: 77.32%、時間: 0:00:07 *
  10. 反復: 300、列車損失: 0.22、列車精度: 92.19%、値損失: 0.46、値精度: 87.08%、時間: 0:00:09 *
  11. 反復: 400、列車損失: 0.24、列車精度: 90.62%、値損失: 0.4、値精度: 88.62%、時間: 0:00:12 *
  12. 反復: 500、列車損失: 0.16、列車精度: 96.88%、値損失: 0.36、値精度: 90.38%、時間: 0:00:15 *
  13. 反復: 600、列車損失: 0.084、列車精度: 96.88%、値損失: 0.35、値精度: 91.36%、時間: 0:00:17 *
  14. 反復: 700、列車損失: 0.21、列車精度: 93.75%、値損失: 0.26、値精度: 92.58%、時間: 0:00:20 *
  15. エポック: 2
  16. 反復: 800、列車損失: 0.07、列車精度: 98.44%、値損失: 0.24、値精度: 94.12%、時間: 0:00:23 *
  17. 反復: 900、列車損失: 0.092、列車精度: 96.88%、値損失: 0.27、値精度: 92.86%、時間: 0:00:25
  18. 反復: 1000、列車損失: 0.17、列車精度: 95.31%、値損失: 0.28、値精度: 92.82%、時間: 0:00:28
  19. 反復: 1100、列車損失: 0.2、列車精度: 93.75%、値損失: 0.23、値精度: 93.26%、時間: 0:00:31
  20. 反復: 1200、列車損失: 0.081、列車精度: 98.44%、値損失: 0.25、値精度: 92.96%、時間: 0:00:33
  21. 反復: 1300、列車損失: 0.052、列車精度: 100.00%、値損失: 0.24、値精度: 93.58%、時間: 0:00:36
  22. 反復: 1400、列車損失: 0.1、列車精度: 95.31%、値損失: 0.22、値精度: 94.12%、時間: 0:00:39
  23. 反復: 1500、列車損失: 0.12、列車精度: 98.44%、値損失: 0.23、値精度: 93.58%、時間: 0:00:41
  24. エポック: 3
  25. 反復: 1600、列車損失: 0.1、列車精度: 96.88%、値損失: 0.26、値精度: 92.34%、時間: 0:00:44
  26. 反復: 1700、列車損失: 0.018、列車精度: 100.00%、値損失: 0.22、値精度: 93.46%、時間: 0:00:47
  27. 反復: 1800、列車損失: 0.036、列車精度: 100.00%、値損失: 0.28、値精度: 92.72%、時間: 0:00:50
  28. 長時間最適化されず自動停止します...

検証セットでの最良の結果は 94.12% で、アルゴリズムはわずか 3 回の反復後に停止しました。

精度と誤差は図に示されています。

テスト

テスト セットをテストするには、python run_cnn.py test を実行します。

  1. CNN モデルを構成しています...
  2. テストデータを読み込んでいます...
  3. テスト中...
  4. テスト損失: 0.14、テスト精度: 96.04%
  5. 精度、再現率 F1 スコア...
  6. 精度再現率 F1スコア サポート
  7.  
  8. スポーツ 0.99 0.99 0.99 1000
  9. 金融 0.96 0.99 0.97 1000
  10. 不動産 1.00 1.00 1.00 1000
  11. ホーム 0.95 0.91 0.93 1000
  12. 教育 0.95 0.89 0.92 1000
  13. テクノロジー 0.94 0.97 0.95 1000
  14. ファッション 0.95 0.97 0.96 1000
  15. 時事 0.94 0.94 0.94 1000
  16. ゲーム 0.97 0.96 0.97 1000
  17. エンターテイメント 0.95 0.98 0.97 1000
  18.  
  19. 平均/ 合計 0.96 0.96 0.96 10000
  20.  
  21. 混同マトリックス...
  22. [[991 0 0 0 2 1 0 4 1 1]
  23. [ 0 992 0 0 2 1 0 5 0 0 ]
  24. [ 0 1 996 0 1 1 0 0 0 1 ]
  25. [ 0 14 0 912 7 15 9 29 3 11 ]
  26. [ 2 9 0 12 892 22 18 21 10 14 ]
  27. [ 0 0 0 10 1 968 4 3 12 2 ]
  28. [ 1 0 0 9 4 4 971 0 2 9]
  29. [ 1 16 0 4 18 12 1 941 1 6 ]
  30. [ 2 4 1 5 4 5 10 1 962 6 ]
  31. [ 1 0 1 6 4 3 5 0 1 979]]
  32. 使用時間: 0:00:05

テスト セットの精度は 96.04% に達し、各カテゴリの精度、再現率、f1 スコアは 0.9 を超えました。

混同行列からも分類効果が非常に優れていることがわかります。

RNN リカレント ニューラル ネットワーク

構成項目

RNN の設定可能なパラメータは、rnn_model.py に以下のように示されています。

  1. クラスTRNNConfig(オブジェクト):
  2. "" "RNN 構成パラメータ" ""  
  3.  
  4. # モデルパラメータ
  5. embedding_dim = 64 # 単語ベクトルの次元
  6. seq_length = 600 # シーケンスの長さ
  7. num_classes = 10 # カテゴリの数
  8. vocab_size = 5000 # 小さな語彙
  9.  
  10. num_layers = 2 # 隠し層の数
  11. hidden_​​dim = 128 # 隠れ層ニューロン
  12. rnn = 'gru' # lstm または gru
  13.  
  14. dropout_keep_prob = 0.8 # ドロップアウト保持率
  15. learning_rate = 1e-3 # 学習率
  16.  
  17. batch_size = 128 # バッチあたりのトレーニングサイズ
  18. num_epochs = 10 # 反復回数の合計
  19.  
  20. print_per_batch = 100 # 数ラウンドごとに結果を出力します
  21. save_per_batch = 10 # テンソルボードに保存するラウンド数

RNN モデル

詳細については、rnn_model.py の実装を参照してください。

一般的な構造は次のとおりです。

トレーニングと検証

この部分のコードは run_cnn.py と非常に似ていますが、モデルといくつかのディレクトリのみを少し変更する必要があります。

トレーニングを開始するには、python run_rnn.py train を実行します。

以前にトレーニングしたことがある場合は、TensorBoard で複数のトレーニング結果が重複しないように、tensorboard/textrnn を削除してください。

  1. RNN モデルを構成しています...
  2. TensorBoardSaver を設定しています...
  3. トレーニングおよび検証データを読み込んでいます...
  4. 使用時間: 0:00:14
  5. トレーニング評価...
  6. エポック: 1
  7. 反復: 0、列車損失: 2.3、列車精度: 8.59%、値損失: 2.3、値精度: 11.96%、時間: 0:00:08 *
  8. 反復: 100、列車損失: 0.95、列車精度: 64.06%、値損失: 1.3、値精度: 53.06%、時間: 0:01:15 *
  9. 反復: 200、列車損失: 0.61、列車精度: 79.69%、値損失: 0.94、値精度: 69.88%、時間: 0:02:22 *
  10. 反復: 300、列車損失: 0.49、列車精度: 85.16%、値損失: 0.63、値精度: 81.44%、時間: 0:03:29 *
  11. エポック: 2
  12. 反復: 400、列車損失: 0.23、列車精度: 92.97%、値損失: 0.6、値精度: 82.86%、時間: 0:04:36 *
  13. 反復: 500、列車損失: 0.27、列車精度: 92.97%、値損失: 0.47、値精度: 86.72%、時間: 0:05:43 *
  14. 反復: 600、列車損失: 0.13、列車精度: 98.44%、値損失: 0.43、値精度: 87.46%、時間: 0:06:50 *
  15. 反復: 700、列車損失: 0.24、列車精度: 91.41%、値損失: 0.46、値精度: 87.12%、時間: 0:07:57
  16. エポック: 3
  17. 反復: 800、列車損失: 0.11、列車精度: 96.09%、値損失: 0.49、値精度: 87.02%、時間: 0:09:03
  18. 反復: 900、列車損失: 0.15、列車精度: 96.09%、値損失: 0.55、値精度: 85.86%、時間: 0:10:10
  19. 反復: 1000、列車損失: 0.17、列車精度: 96.09%、値損失: 0.43、値精度: 89.44%、時間: 0:11:18 *
  20. 反復: 1100、列車損失: 0.25、列車精度: 93.75%、値損失: 0.42、値精度: 88.98%、時間: 0:12:25
  21. エポック: 4
  22. 反復: 1200、列車損失: 0.14、列車精度: 96.09%、値損失: 0.39、値精度: 89.82%、時間: 0:13:32 *
  23. 反復: 1300、列車損失: 0.2、列車精度: 96.09%、値損失: 0.43、値精度: 88.68%、時間: 0:14:38
  24. 反復: 1400、列車損失: 0.012、列車精度: 100.00%、値損失: 0.37、値精度: 90.58%、時間: 0:15:45 *
  25. 反復: 1500、列車損失: 0.15、列車精度: 96.88%、値損失: 0.39、値精度: 90.58%、時間: 0:16:52
  26. エポック: 5
  27. 反復: 1600、列車損失: 0.075、列車精度: 97.66%、値損失: 0.41、値精度: 89.90%、時間: 0:17:59
  28. 反復: 1700、列車損失: 0.042、列車精度: 98.44%、値損失: 0.41、値精度: 90.08%、時間: 0:19:06
  29. 反復: 1800、列車損失: 0.08、列車精度: 97.66%、値損失: 0.38、値精度: 91.36%、時間: 0:20:13 *
  30. 反復: 1900、列車損失: 0.089、列車精度: 98.44%、値損失: 0.39、値精度: 90.18%、時間: 0:21:20
  31. エポック: 6
  32. 反復: 2000、列車損失: 0.092、列車精度: 96.88%、値損失: 0.36、値精度: 91.42%、時間: 0:22:27 *
  33. 反復: 2100、列車損失: 0.062、列車精度: 98.44%、値損失: 0.39、値精度: 90.56%、時間: 0:23:34
  34. 反復: 2200、列車損失: 0.053、列車精度: 98.44%、値損失: 0.39、値精度: 90.02%、時間: 0:24:41
  35. 反復: 2300、列車損失: 0.12、列車精度: 96.09%、値損失: 0.37、値精度: 90.84%、時間: 0:25:48
  36. エポック: 7
  37. 反復: 2400、列車損失: 0.014、列車精度: 100.00%、値損失: 0.41、値精度: 90.38%、時間: 0:26:55
  38. 反復: 2500、列車損失: 0.14、列車精度: 96.88%、値損失: 0.37、値精度: 91.22%、時間: 0:28:01
  39. 反復: 2600、列車損失: 0.11、列車精度: 96.88%、値損失: 0.43、値精度: 89.76%、時間: 0:29:08
  40. 反復: 2700、列車損失: 0.089、列車精度: 97.66%、値損失: 0.37、値精度: 91.18%、時間: 0:30:15
  41. エポック: 8
  42. 反復: 2800、列車損失: 0.0081、列車精度: 100.00%、値損失: 0.44、値精度: 90.66%、時間: 0:31:22
  43. 反復: 2900、列車損失: 0.017、列車精度: 100.00%、値損失: 0.44、値精度: 89.62%、時間: 0:32:29
  44. 反復: 3000、列車損失: 0.061、列車精度: 96.88%、値損失: 0.43、値精度: 90.04%、時間: 0:33:36
  45. 長時間最適化されず自動停止します...

検証セットでの最高の結果は 91.42% で、8 ラウンドの反復後に停止しました。速度は CNN よりもはるかに遅いです。

精度と誤差は図に示されています。

テスト

python run_rnn.py test を実行して、テスト セットでテストを実行します。

  1. テスト中...
  2. テスト損失: 0.21、テスト精度: 94.22%
  3. 精度、再現率 F1 スコア...
  4. 精度再現率 F1スコア サポート
  5.  
  6. スポーツ 0.99 0.99 0.99 1000
  7. 金融 0.91 0.99 0.95 1000
  8. 不動産 1.00 1.00 1.00 1000
  9. ホーム 0.97 0.73 0.83 1000
  10. 教育 0.91 0.92 0.91 1000
  11. テクノロジー 0.93 0.96 0.94 1000
  12. ファッション 0.89 0.97 0.93 1000
  13. 時事 0.93 0.93 0.93 1000
  14. ゲーム 0.95 0.97 0.96 1000
  15. エンターテイメント 0.97 0.96 0.97 1000
  16.  
  17. 平均/ 合計 0.94 0.94 0.94 10000
  18.  
  19. 混同マトリックス...
  20. [[988 0 0 0 4 0 2 0 5 1]
  21. [ 0 9 9 0 1 1 1 1 0 6 0 0 ]
  22. [ 0 2 996 1 1 0 0 0 0 0 ]
  23. [ 2 71 1 731 51 20 88 28 3 5 ]
  24. [ 1 3 0 7 918 23 4 31 9 4 ]
  25. [ 1 3 0 3 0 964 3 5 21 0 ]
  26. [ 1 0 1 7 1 3 972 0 6 9]
  27. [ 0 16 0 0 22 26 0 931 2 3]
  28. [ 2 3 0 0 2 2 12 0 972 7 ]
  29. [ 0 3 1 1 7 3 11 5 9 960]]
  30. 使用時間: 0:00:33

テスト セットの精度は 94.22% に達し、ホーム カテゴリを除く各カテゴリの精度、再現率、f1 スコアは 0.9 を超えました。

混同行列から、分類効果が非常に優れていることがわかります。

2 つのモデルを比較すると、家庭用家具の分類のパフォーマンスを除いて、他のカテゴリでの RNN のパフォーマンスは CNN とそれほど変わらないことがわかります。

パラメータをさらに調整することで、より良い結果を得ることもできます。

<<:  ディープラーニングを使用してNBAの試合結果を予測する

>>:  ウェブデザインに人工知能を活用する10の方法

ブログ    
ブログ    
ブログ    
ブログ    

推薦する

DGX-2 および SXM3 カードが GTC 2018 で発表されました

最近、GTC 2018 で、Vicor チームは NVIDIA DGX-2 の発表を目撃しました。 ...

欧州の新しいAI法は倫理監査を強化する

EU があらゆる業界での AI および機械学習技術の使用を効果的に規制する AI 法の施行に向けて...

機械学習向けのテキスト注釈ツールとサービスのトップ 10: どれを選びますか?

[[347945]] [51CTO.com クイック翻訳] 現在、検索エンジンや感情分析から仮想ア...

Tensorflowを使用して畳み込みニューラルネットワークを構築する

1. 畳み込みニューラルネットワーク畳み込みニューラル ネットワーク (CNN) は、人工ニューロン...

銀行におけるクラウドコンピューティングと人工知能の利点

クラウド コンピューティング プロバイダーは、データを分析し、スキルの低いユーザー (または予算が限...

...

...

AIが悪になる危険性を排除する方法

AI テクノロジーを悪とみなす個人、政府、企業が増えるにつれ、AI が善良な存在であることを保証する...

Google、AIアシスタント「Gemini」の修正を加速、拒否率を半減

2月18日、Googleは人工知能プロジェクトを大幅にアップデートし、BardをGeminiに改名し...

...

世界的なサプライチェーンの混乱はロボットの導入をどのように促進するのでしょうか?

企業がより強力な管理を維持し、コストのかかる混乱を回避しようとする中、製造拠点の国内移転とサプライチ...

Microsoft XiaoIceが第7世代にアップグレードされ、ユーザーの権限を強化するアバターフレームワークがリリースされました

[51CTO.comよりオリジナル記事] 8月15日、マイクロソフト(アジア)インターネットエンジニ...

大規模言語モデル評価における信頼性の低いデータに注意: Flan-T5 に基づくプロンプト選択のケーススタディ

翻訳者|朱 仙中レビュー | Chonglou導入信頼性の高いモデル評価はMLOP と LLMop ...

世界図書デー: スマートテクノロジーがいかにして優れた読書環境を作り出すか

4月23日は第25回「世界本の日」です!今日は本を読みましたか?ゴーリキーはかつてこう言った。「本は...

...