Partager via


structure DML_ARGMAX_OPERATOR_DESC (directml.h)

Génère les index des éléments à valeur maximale dans une ou plusieurs dimensions du tenseur d’entrée.

Chaque élément de sortie est le résultat de l’application d’une réduction argmax sur un sous-ensemble du tenseur d’entrée. La fonction argmax génère l’index de l’élément à valeur maximale dans un ensemble d’éléments d’entrée. Les éléments d’entrée impliqués dans chaque réduction sont déterminés par les axes d’entrée fournis. De même, chaque index de sortie correspond aux axes d’entrée fournis. Si tous les axes d’entrée sont spécifiés, l’opérateur applique une réduction argmax unique et produit un seul élément de sortie.

Syntaxe

struct DML_ARGMAX_OPERATOR_DESC {
  const DML_TENSOR_DESC *InputTensor;
  const DML_TENSOR_DESC *OutputTensor;
  UINT                  AxisCount;
  const UINT            *Axes;
  DML_AXIS_DIRECTION    AxisDirection;
};

Membres

InputTensor

Type : const DML_TENSOR_DESC*

Tenseur à partir duquel lire.

OutputTensor

Type : const DML_TENSOR_DESC*

Tenseur dans lequel écrire les résultats. Chaque élément de sortie est le résultat d’une réduction argmax sur un sous-ensemble d’éléments du inputTensor.

  • DimensionCount doit correspondre à InputTensor.DimensionCount (le rang du tenseur d’entrée est conservé).
  • Les tailles doivent correspondre à InputTensor.Sizes, à l’exception des dimensions incluses dans les axes réduits, qui doivent être de taille 1.

AxisCount

Type : UINT

Nombre d’axes à réduire. Ce champ détermine la taille du tableau Axes .

Axes

Type : _Field_size_(AxisCount) const UINT*

Axes sur lesquels réduire. Les valeurs doivent se trouver dans la plage [0, InputTensor.DimensionCount - 1].

AxisDirection

Type : DML_AXIS_DIRECTION

Détermine l’index à sélectionner lorsque plusieurs éléments d’entrée ont la même valeur.

  • DML_AXIS_DIRECTION_INCREASING retourne l’index du premier élément à valeur maximale (par exemple, argmax({3,2,1,2,3}) = 0)
  • DML_AXIS_DIRECTION_DECREASING retourne l’index du dernier élément à valeur maximale (par exemple, argmax({3,2,1,2,3}) = 4)

Exemples

Les exemples de cette section utilisent tous ce même tenseur d’entrée à deux dimensions.

InputTensor: (Sizes:{3, 3}, DataType:FLOAT32)
[[1, 2, 3],
 [3, 0, 4],
 [2, 5, 2]]

Exemple 1. Application d’argmax aux colonnes

AxisCount: 1
Axes: {0}
AxisDirection: DML_AXIS_DIRECTION_INCREASING
OutputTensor: (Sizes:{1, 3}, DataType:UINT32)
[[1,  // argmax({1, 3, 2})
  2,  // argmax({2, 0, 5})
  1]] // argmax({3, 4, 2})

Exemple 2. Application d’argmax à des lignes

AxisCount: 1
Axes: {1}
AxisDirection: DML_AXIS_DIRECTION_INCREASING
OutputTensor: (Sizes:{3, 1}, DataType:UINT32)
[[2], // argmax({1, 2, 3})
 [2], // argmax({3, 0, 4})
 [1]] // argmax({2, 5, 2})

Exemple 3. Application d’argmax à tous les axes (le tenseur entier)

AxisCount: 2
Axes: {0, 1}
AxisDirection: DML_AXIS_DIRECTION_INCREASING
OutputTensor: (Sizes:{1, 1}, DataType:UINT32)
[[7]]  // argmax({1, 2, 3, 3, 0, 4, 2, 5, 2})

Notes

Les tailles de tenseurs de sortie doivent être identiques aux tailles de tenseurs d’entrée, à l’exception des axes réduits, qui doivent être 1.

Quand AxisDirection est DML_AXIS_DIRECTION_INCREASING, cette API équivaut à DML_REDUCE_OPERATOR_DESC avec DML_REDUCE_FUNCTION_ARGMAX.

Un sous-ensemble de cette fonctionnalité est exposé via l’opérateur DML_REDUCE_OPERATOR_DESC et est pris en charge sur les niveaux de fonctionnalités DirectML antérieurs.

Disponibilité

Cet opérateur a été introduit dans DML_FEATURE_LEVEL_3_0.

Contraintes tensoriels

InputTensor et OutputTensor doivent avoir le même DimensionCount.

Prise en charge de Tensor

DML_FEATURE_LEVEL_4_1 et versions ultérieures

Tenseur Type Nombre de dimensions pris en charge Types de données pris en charge
InputTensor Entrée 1 à 8 FLOAT32, FLOAT16, INT64, INT32, INT16, INT8, UINT64, UINT32, UINT16, UINT8
OutputTensor Output 1 à 8 INT64, INT32, UINT64, UINT32

DML_FEATURE_LEVEL_3_0 et versions ultérieures

Tenseur Type Nombre de dimensions pris en charge Types de données pris en charge
InputTensor Entrée 1 à 8 FLOAT32, FLOAT16, INT32, INT16, INT8, UINT32, UINT16, UINT8
OutputTensor Output 1 à 8 INT64, INT32, UINT64, UINT32

Spécifications

   
Client minimal pris en charge Windows 10 Build 20348
Serveur minimal pris en charge Windows 10 Build 20348
En-tête directml.h