XGBoost機械学習モデルの意思決定プロセス

XGBoost機械学習モデルの意思決定プロセス

XGBoost アルゴリズムは、Kaggle やその他のデータ サイエンス コンテストで優れた結果を達成することが多いため、人気があります。この記事では、特定のデータセットを使用して、XGBoost 機械学習モデルの予測プロセスを分析します。視覚化を使用して結果を表示することで、モデルの予測プロセスをよりよく理解できます。

機械学習の産業応用が発展し続けるにつれて、機械学習モデルの動作原理を理解し、説明し、定義することがますます明らかな傾向になってきているようです。非ディープラーニングタイプの機械学習分類問題の場合、XGBoost が最も人気のあるライブラリです。 XGBoost は大規模なデータセットに適切に拡張でき、複数の言語をサポートしているため、特に商用環境で役立ちます。たとえば、XGBoost を使用すると、Python でモデルをトレーニングし、実稼働 Java 環境にデプロイすることが簡単になります。

XGBoost は非常に高い精度を達成できますが、そのような高い精度を達成するために XGBoost が決定を下すプロセスはまだ十分に透明ではありません。この透明性の欠如は、結果をクライアントに直接渡すときに重大な欠点となる可能性があります。物事が起こる理由を理解することは役に立ちます。データを理解するために機械学習を適用する企業は、モデルから得られる予測も理解する必要があります。これはますます重要になってきています。たとえば、信用調査機関が機械学習モデルを使用してユーザーの信用力を予測する際に、その予測がどのように行われたかを説明できないことを望む人はいないでしょう。

別の例として、機械学習モデルが結婚記録と出生記録が同じ人物に関連していると言っている場合(記録の関連付けタスク)、記録の日付から結婚した 2 人の当事者がそれぞれ非常に高齢の人物と非常に若い人物であったことが示唆される場合、モデルがなぜそれらを関連付けたのか疑問に思うかもしれません。このような例では、モデルがなぜその予測を行ったのかを理解することが非常に重要です。その結果、モデルは名前と場所の一意性を考慮して正しい予測を行う可能性があります。しかし、モデルの機能がプロフィール上の年齢差を適切に考慮していない可能性もあります。この場合、モデルの予測を理解することで、パフォーマンスを向上させる方法を見つけるのに役立ちます。

この記事では、XGBoost の予測プロセスをよりよく理解するためのいくつかのテクニックを紹介します。これにより、モデルの意思決定プロセスを理解しながら、勾配ブースティングのパワーを活用することができます。

これらの手法を説明するために、Titanic データセットを使用します。このデータセットには、タイタニック号に乗っていたすべての乗客に関する情報(乗客が生存したかどうかを含む)が含まれています。私たちの目標は、乗客が生き残るかどうかを予測し、その予測を行うプロセスを理解することです。このデータを使用しても、モデルの決定を理解することの重要性がわかります。最近の難破船の乗客のデータセットがあると想像してください。このような予測モデルを構築する目的は、実際には結果自体を予測することではなく、予測プロセスを理解することで、事故における生存者の数を最大化する方法を学ぶことができます。

 pandaspd としてインポートする
xgboost からXGBClassifier をインポートします
sklearn.model_selection からtrain_test_split インポートします
sklearn.metrics からaccuracy_score ​をインポートします
インポート演算子
matplotlib.pyplot plt としてインポートします
Seaborn をSNS としてインポートする
lime をインポートします。lime_tabular
sklearn.pipeline からPipeline インポートする
sklearn から. preprocessing import Imputer
numpyをnp としてインポートする
sklearn . grid_search からGridSearchCV をインポートします
% matplotlib インライン

最初にやるべきことは、Kaggle で見つけられるデータを確認することです。データ セットを取得したら、データの簡単なクリーンアップを実行します。今すぐ:

  • 名前と乗客IDを明確に記入
  • カテゴリ変数をダミー変数に変換する
  • 中央値を使用したデータの入力と削除

これらのクリーニング手法は非常に単純であり、この記事の目的はデータのクリーニングについて説明することではなく、XGBoost を説明することであるため、これらはモデルをトレーニングするための迅速かつ合理的なクリーニングです。

 データ= pd.read_csv ( "./data/titantic/train.csv" )
y = データ.Survived
X = data .drop ( [ "生存" , "名前" , "乗客ID" ] , 1 )
X = pd.get_dummies ( X ) 関数は

ここで、データセットをトレーニング セットとテスト セットに分割します。

 X_trainX_testy_trainy_test = train_test_split (
Xyテストサイズ= 0.33ランダム状態= 42 )

そして、少量のハイパーパラメータ テストを含むトレーニング パイプラインを構築します。

 パイプライン= パイプライン(
[( 'imputer' , Imputer ( 戦略= 'median' )),
( 'モデル'XGBClassifier ())])
パラメータ= dict ( model__max_depth = [ 3 , 5 , 7 ],
モデル学習率= [ .01 , .1 ],
モデルn推定値= [ 100 , 500 ])
cv = GridSearchCV ( パイプラインparam_grid = パラメータ)
cv.fit ( X_train , y_train ) の

次にテスト結果を表示します。簡単にするために、Kaggle と同じ指標である精度を使用します。

 test_predictions = cv.predict ( X_test )
print ( "テスト精度: {}" . format (
精度スコア( y_testテスト予測)))

テスト精度: 0.8101694915254237

この時点で、私たちはかなりの精度を達成し、Kaggle の約 9,000 人の競合相手の中で上位 500 位にランクインしました。したがって、さらに改善する余地はありますが、これは読者の課題として残しておきます。

モデルが何を学習したかを理解するための議論を続けます。一般的なアプローチは、XGBoost によって提供される特徴量の重要度を使用することです。特徴の重要度が高いほど、その特徴がモデルの予測の改善に大きく貢献することを示します。次に、重要度パラメータを使用して特徴をランク付けし、相対的な重要度を比較します。

 fi = list ( zip ( X . columns , cv . best_estimator_ . named_steps [ 'model' ] . feature_importances_ ))
fi .sort ( キー= 演算子.itemgetter ( 1 )、 逆順= True )
トップ10 = fi [: 10 ]
x = [ x [ 0 ] ( xtop_10 ある場合) ]
y = [ x [ 1 ] ( xtop_10 含まれる場合
top_10_chart = sns.barplot ( x , y )
plt . setp ( top_10_chart . get_xticklabels (), 回転= 90 )

上の図からわかるように、チケットの価格と年齢は非常に重要な要素です。生存/死亡とチケット価格の分布をさらに調べることができます。

 sns . barplot ( y_train , X_train [ '運賃' ] )

生き残った人の平均チケット価格は、亡くなった人の平均チケット価格よりもはるかに高かったことがはっきりとわかるので、チケット価格を重要な特徴として考慮するのは合理的かもしれません。

機能の重要度は、一般的な機能の重要度を理解するための良い方法かもしれません。このような特殊なケースが発生した場合、つまり、モデルが高額なチケットを購入した乗客は生存できないと予測した場合、高額なチケットが必ずしも生存につながるわけではないと結論付けることができます。次に、モデルが乗客が生存できないと結論付ける原因となる可能性のある他の特徴を分析します。

この種の個別レベルの分析は、実稼働の機械学習システムに非常に役立ちます。別の例として、モデルを使用して誰かがローンを取得できるかどうかを予測する例を考えてみましょう。クレジット スコアはモデルの非常に重要な機能になることはわかっていますが、クレジット スコアが高い顧客はモデルによって拒否されます。これを顧客にどのように説明すればよいでしょうか。これを管理者にどう説明すればいいでしょうか?

幸いなことに、ワシントン大学から、任意の分類器の予測プロセスを説明する最近の研究が出ています。彼らの方法は LIME と呼ばれ、GitHub (https://github.com/marcotcr/lime) でオープンソース化されています。この記事ではこれについて議論するつもりはありません。論文(https://arxiv.org/pdf/1602.04938.pdf)を参照してください。

次に、LIME をモデルに適用してみます。基本的に、まずトレーニング データを処理するインタープリターを定義する必要があります (インタープリターに渡す推定トレーニング データセットがトレーニングに使用するデータセットであることを確認する必要があります)。

 X_train_imputed = cv.best_estimator_.named_steps [ 'imputer' ] . transform ( X_train )
説明者= lime.lime_tabular.LimeTabularExplainer ( X_train_imputed ,
feature_names = X_train.columns.tolist ( ) ,
class_names = [ "生存していない" , "生存している" ],
discretize_continuous = True )

次に、特徴の配列を受け取り、各クラスの確率を含む配列を返す関数を定義する必要があります。

 モデル= cv.best_estimator_.named_steps [ ' モデル' ]
xgb_prediction ( X_array_in ) を定義します
len ( X_array_in.shape ) < 2の場合:
X_array_in = np . expand_dims ( X_array_in , 0 )
モデル返す。predict_proba( X_array_in )

最後に、インタープリターが関数を使用して特徴とラベルの数を出力する例を渡します。

 X_test_imputed = cv.best_estimator_.named_steps [ 'imputer' ] . transform ( X_test )
exp = 説明者. explain_instance (
X_test_imputed [ 1 ]、
xgb_予測
num_features = 5
トップラベル= 1 )
exp.show_in_notebook ( show_table = True ,
show_all = False )

ここでは、生存不可能な可能性が 76% の例を示します。また、どの機能がどのクラスに最も貢献しているか、またそれがどの程度重要であるかを確認したいと考えています。たとえば、性別が女性の場合、生存の可能性が高くなります。棒グラフを見てみましょう:

 sns . barplot ( X_train [ 'Sex_female' ] , y_train )

だからこれは理にかなっているように思えます。あなたが女性の場合、これによりトレーニング データで生き残る可能性が大幅に高まります。では、なぜ予測は「生き残れない」なのでしょうか? Pclass = 2.0 では生存率が大幅に低下するようです。見てみましょう:

 sns . barplot ( X_train [ 'Pclass' ], y_train )

Pclass が 2 の場合の生存率はまだ比較的低いようですので、予測結果に対する理解が深まります。 LIME に表示される上位 5 つの特徴を見ると、この人はまだ生き残れそうです。ラベルを見てみましょう。

 y_test . [ 0 ] >>> 1

その人は生き残ったので、私たちのモデルは間違っていました! LIME のおかげで、問題の原因をある程度把握できるようになりました。どうやら Pclass を放棄する必要があるようです。このアプローチは、モデルを改善する方法を見つけるのに役立つと期待されます。

この記事では、読者に XGBoost を理解するためのシンプルで効果的な方法を紹介します。これらの方法が XGBoost を活用し、モデルがより良い推論を行えるようになることを願っています。

<<:  速報です!李菲菲の一番弟子カルパシーが辞任、テスラの自動運転は危機に瀕しているのか?

>>:  機械学習がデータセンター管理をどう変えるか

ブログ    
ブログ    

推薦する

新しいプログラミングパラダイム: Spring Boot と OpenAI の出会い

2023年にはAI技術が話題となり、プログラミングを中心に多くの分野に影響を及ぼします。 Sprin...

AIの開発パターンは「データ」から「知識」へと進化している

半世紀以上前に誕生して以来、人工知能(AI)革命は全世界に大きな影響を与えてきました。特に過去10年...

デジタル変革と人工知能

[[415031]]今日のビジネスにおける変化の最大の原因は、デジタル変革と呼ばれる取り組みです。つ...

注意してください、これらの6つのアルゴリズムには落とし穴があります:中国消費者協会はビッグデータが古い顧客をターゲットにしていると指摘しています

ビッグデータの登場以来、「古い顧客を搾取する」問題はますます深刻になっています。テイクアウトでも旅行...

...

...

上海交通大学卒業生によるソロ作品! 50年間のゼロ進歩アルゴリズム問題が解決された

この記事はAI新メディアQuantum Bit(公開アカウントID:QbitAI)より許可を得て転載...

AI に役立つ 7 つの優れたオープンソース ツール

ビジネスニーズを予測するには、AI を活用し、研究開発を新たなレベルに引き上げる必要があります。この...

...

...

これまでで最も詳細なAIサイバー攻撃の分類ガイド

最近、NIST は、人工知能システムに対するサイバー攻撃に関する、おそらくこれまでで最も詳細な分類ガ...

コンピュータービジョンが小売業の在庫管理をどう変えるか

小売業の経営者は、長期的な顧客関係の構築を妨げる在庫管理の問題に直面することがよくあります。小売在庫...

イメージフリーの認識がさらに一歩前進! ScalableMap: 大規模高精度地図に向けた新しいソリューション!

この記事は、Heart of Autonomous Driving の公開アカウントから許可を得て転...

インターネットで話題! 23歳の中国人医師が22歳の歴史的弱点を治す、ネットユーザー「この話はいいね」

最近、別の若い中国人男性が、22年間存在していたバグを修正したことでインターネット上で人気を博した。...

...