次の方法で共有


TensorFlowEstimator クラス

定義

これは TensorFlowTransformer 、次の 2 つのシナリオで使用されます。

  1. 事前トレーニング済みの TensorFlow モデルを使用したスコア付け: このモードでは、変換は事前トレーニング済みの Tensorflow モデルから非表示レイヤーの値を抽出し、出力を ML.Net パイプラインの機能として使用します。
  2. TensorFlow モデルの再トレーニング: このモードでは、変換は、パイプラインを介して渡されたユーザー データを使用して TensorFlow モデル ML.Net 再トレーニングします。 モデルがトレーニングされると、その出力をスコアリングの特徴として使用できます。
public sealed class TensorFlowEstimator : Microsoft.ML.IEstimator<Microsoft.ML.Transforms.TensorFlowTransformer>
type TensorFlowEstimator = class
    interface IEstimator<TensorFlowTransformer>
Public NotInheritable Class TensorFlowEstimator
Implements IEstimator(Of TensorFlowTransformer)
継承
TensorFlowEstimator
実装

注釈

TensorFlowTransform は、事前トレーニング済みの Tensorflow モデルを使用して、指定された出力を抽出します。 必要に応じて、ユーザー データの TensorFlow モデルをさらに再トレーニングして、ユーザー データのモデル パラメーターを調整できます ("Transfer Learning" とも呼ばれます)。

スコアリングの場合、変換は、事前トレーニング済みの Tensorflow モデル、入力ノードの名前、および抽出する値を持つ出力ノードの名前を入力として受け取ります。 再トレーニングの場合、変換には、TensorFlow グラフ内の最適化操作の名前、グラフ内の学習率操作の名前とその値、計算損失とパフォーマンス メトリックに対するグラフ内の操作の名前などのトレーニング関連パラメーターも必要です。

この変換を行うには、 Microsoft.ML.TensorFlow nuget をインストールする必要があります。 TensorFlowTransform には、入力、出力、データの処理、再トレーニングに関する次の前提条件があります。

  1. 入力モデルの場合、現在、TensorFlowTransform では 、固定モデル 形式と SavedModel 形式の両方がサポートされています。 ただし、モデルの再トレーニングは SavedModel 形式でのみ可能です。 現在、チェックポイント 形式はスコアリングと再トレーニングのどちらもサポートされていません。これは、読み込みに対する TensorFlow C-API のサポートがないためです。
  2. 変換では、一度に 1 つの例のみをスコアリングできます。 ただし、再トレーニングはバッチで実行できます。
  3. TensorFlow C-API を使用したモデル内でのネットワーク/グラフ操作のサポートがないため、高度な転送学習/微調整シナリオ (ネットワークへのレイヤーの追加、入力の形状の変更、再トレーニングプロセス中に更新する必要のないレイヤーの凍結など) は現在不可能です。
  4. 入力列の名前は、TensorFlow モデルの入力の名前と一致する必要があります。
  5. 各出力列の名前は、TensorFlow グラフ内のいずれかの操作と一致する必要があります。
  6. 現在、double、float、long、int、short、sbyte、ulong、uint、ushort、byte、bool は、入力/出力に許容されるデータ型です。
  7. 成功すると、変換によって、指定された各出力列に IDataView 対応する新しい列が導入されます。

TensorFlow モデルの入力と出力は、ツールまたはsummarize_graph ツールをGetModelSchema()使用して取得できます。

メソッド

Fit(IDataView)

をトレーニングして返します TensorFlowTransformer

GetOutputSchema(SchemaShape)

SchemaShapeトランスフォーマーによって生成されるスキーマの値を返します。 パイプラインでのスキーマの伝達と検証に使用されます。

拡張メソッド

AppendCacheCheckpoint<TTrans>(IEstimator<TTrans>, IHostEnvironment)

推定チェーンに "キャッシュ チェックポイント" を追加します。 これにより、ダウンストリーム推定器がキャッシュされたデータに対してトレーニングされるようになります。 複数のデータを受け取るトレーナーの前にキャッシュ チェックポイントを設定すると便利です。

WithOnFitDelegate<TTransformer>(IEstimator<TTransformer>, Action<TTransformer>)

エスティメーターを指定すると、デリゲートが呼 Fit(IDataView) び出されると呼び出されるラップ オブジェクトを返します。 多くの場合、エスティメーターが適合した内容に関する情報を返すことが重要です。そのため Fit(IDataView) 、メソッドは一般的 ITransformerなオブジェクトではなく、具体的に型指定されたオブジェクトを返します。 ただし、同時に、 IEstimator<TTransformer> 多くのオブジェクトを含むパイプラインに形成されることが多いため、トランスフォーマーを取得する推定器がこのチェーンのどこかに埋もれている場所を介して EstimatorChain<TLastTransformer> 、推定器のチェーンを構築する必要がある場合があります。 このシナリオでは、このメソッドを使用して、fit が呼び出されると呼び出されるデリゲートをアタッチできます。

適用対象