DML_BATCH_NORMALIZATION_GRAD_OPERATOR_DESC structure (directml.h)
Calcule les gradients de rétropropagation pour la normalisation par lots. DML_BATCH_NORMALIZATION_GRAD_OPERATOR_DESC effectue plusieurs calculs, qui sont détaillés dans les descriptions de sortie distinctes.
OutputScaleGradientTensor et OutputBiasGradientTensor sont calculés à l’aide de sommes sur l’ensemble de dimensions pour lesquelles les tailles MeanTensor, ScaleTensor et VarianceTensor sont égales à une.
Syntaxe
struct DML_BATCH_NORMALIZATION_GRAD_OPERATOR_DESC {
const DML_TENSOR_DESC *InputTensor;
const DML_TENSOR_DESC *InputGradientTensor;
const DML_TENSOR_DESC *MeanTensor;
const DML_TENSOR_DESC *VarianceTensor;
const DML_TENSOR_DESC *ScaleTensor;
const DML_TENSOR_DESC *OutputGradientTensor;
const DML_TENSOR_DESC *OutputScaleGradientTensor;
const DML_TENSOR_DESC *OutputBiasGradientTensor;
FLOAT Epsilon;
};
Membres
InputTensor
Type : const DML_TENSOR_DESC*
Tenseur contenant les données d’entrée. Il s’agit généralement du même tenseur que celui fourni en tant que InputTensor pour DML_BATCH_NORMALIZATION_OPERATOR_DESC dans la passe avant.
InputGradientTensor
Type : const DML_TENSOR_DESC*
Tenseur de gradient entrant. Cela est généralement obtenu à partir de la sortie de la rétropropagation d’une couche précédente.
MeanTensor
Type : const DML_TENSOR_DESC*
Tenseur contenant les données moyennes. Il s’agit généralement du même tenseur que celui fourni comme meanTensor pour DML_BATCH_NORMALIZATION_OPERATOR_DESC dans la passe avant.
VarianceTensor
Type : const DML_TENSOR_DESC*
Tenseur contenant les données de variance. Il s’agit généralement du même tenseur que celui fourni comme VarianceTensor pour DML_OPERATOR_BATCH_NORMALIZATION dans la passe avant.
ScaleTensor
Type : const DML_TENSOR_DESC*
Tenseur contenant les données d’échelle. Il s’agit généralement du même tenseur que celui fourni en tant que ScaleTensor pour DML_BATCH_NORMALIZATION_OPERATOR_DESC dans la passe avant.
OutputGradientTensor
Type : const DML_TENSOR_DESC*
Pour chaque valeur correspondante dans les entrées, OutputGradient = InputGradient * (Scale / sqrt(Variance + Epsilon))
.
OutputScaleGradientTensor
Type : const DML_TENSOR_DESC*
Le calcul suivant est effectué ou chaque valeur correspondante dans les entrées.
OutputScaleGradient = sum(InputGradient * (Input - Mean) / sqrt(Variance + Epsilon))
OutputBiasGradientTensor
Type : const DML_TENSOR_DESC*
Le calcul suivant est effectué ou chaque valeur correspondante dans les entrées.
OutputBiasGradient = sum(InputGradient)
Epsilon
Type : FLOAT
Petite valeur ajoutée à la variance pour éviter zéro.
Remarques
Disponibilité
Cet opérateur a été introduit dans DML_FEATURE_LEVEL_3_1
.
Contraintes tensoriels
- InputGradientTensor, InputTensor, MeanTensor, OutputBiasGradientTensor, OutputGradientTensor, OutputScaleGradientTensor, ScaleTensor et VarianceTensor doivent avoir les mêmes DataType et DimensionCount.
- MeanTensor, OutputBiasGradientTensor, OutputScaleGradientTensor, ScaleTensor et VarianceTensor doivent avoir les mêmes tailles.
- InputGradientTensor, InputTensor et OutputGradientTensor doivent avoir les mêmes tailles.
Prise en charge des tenseurs
Tenseur | Genre | Dimensions | Nombre de dimensions pris en charge | Types de données pris en charge |
---|---|---|---|---|
InputTensor | Entrée | { InputDimensions[] } | 1 à 8 | FLOAT32, FLOAT16 |
InputGradientTensor | Entrée | { InputDimensions[] } | 1 à 8 | FLOAT32, FLOAT16 |
MeanTensor | Entrée | { MeanDimensions[] } | 1 à 8 | FLOAT32, FLOAT16 |
VarianceTensor | Entrée | { MeanDimensions[] } | 1 à 8 | FLOAT32, FLOAT16 |
ScaleTensor | Entrée | { MeanDimensions[] } | 1 à 8 | FLOAT32, FLOAT16 |
OutputGradientTensor | Sortie | { InputDimensions[] } | 1 à 8 | FLOAT32, FLOAT16 |
OutputScaleGradientTensor | Sortie | { MeanDimensions[] } | 1 à 8 | FLOAT32, FLOAT16 |
OutputBiasGradientTensor | Sortie | { MeanDimensions[] } | 1 à 8 | FLOAT32, FLOAT16 |
Configuration requise
Condition requise | Valeur |
---|---|
Client minimal pris en charge | Windows Build 22000 |
Serveur minimal pris en charge | Windows Build 22000 |
En-tête | directml.h |