共用方式為


DML_DIAGONAL_MATRIX_OPERATOR_DESC結構 (directml.h)

在主要對角線上產生具有 (或其他明確值) 的類似身分識別矩陣,並在其他地方產生零。 對角線值 (可能會透過 Offset) 移OutputTensor[i, i + Offset] = 位) 移位,這表示 Offset 大於零的自變數會將所有值移位到右邊,而小於零則會將它們向左移位。 此產生器運算子適用於模型,以避免儲存大型常數張量。 最後兩個之前的任何前置維度都會視為批次計數,這表示張量會被視為 2D 矩陣的堆疊。

這個運算子會執行下列虛擬程序代碼。

for each coordinate in OutputTensor
    OutputTensor[coordinate] = if (coordinate.y + Offset == coordinate.x) then Value else 0
endfor

語法

struct DML_DIAGONAL_MATRIX_OPERATOR_DESC {
  const DML_TENSOR_DESC *OutputTensor;
  INT                   Offset;
  FLOAT                 Value;
};

成員

OutputTensor

類型: const DML_TENSOR_DESC*

要寫入結果的張量。 維度為 { Batch1, Batch2, OutputHeight, OutputWidth }。 高度和寬度不需要是正方形。

Offset

類型: INT

Value 對角線移位移的位移,正位移會將寫入的值向右/上移, (將輸出檢視為矩陣,並將左上方為 0,0) ,負位移至左/下。

Value

類型: FLOAT

要沿著 2D 對角線填滿的值。 標準值為 1.0。 請注意,如果張量的 DataType 不是 DML_TENSOR_DATA_TYPE_FLOAT16DML_TENSOR_DATA_TYPE_FLOAT32,則值可能會截斷 (例如,10.6 會變成 10) 。

範例

預設身分識別矩陣:

Offset: 0
Value: 1.0
OutputTensor: (Sizes:{1,1,3,3}, DataType:FLOAT32)
    [[[[1, 0, 0],
       [0, 1, 0],
       [0, 0, 1]]]]

往右/ 向上移位:

Offset: 1
Value: 1.0
OutputTensor: (Sizes:{1,1,3,3}, DataType:FLOAT32)
    [[[[ 0, 1, 0],
       [ 0, 0, 1],
       [ 0, 0, 0]]]]

往左/ 向下移位:

Offset: -1
Value: 1.0
OutputTensor: (Sizes:{1,1,3,2}, DataType:FLOAT32)
    [[[[0, 0],
       [1, 0],
       [0, 1]]]]

將到目前為止的對角線移位,讓所有專案變成零:

Offset: -3
Value: 1.0
OutputTensor: (Sizes:{1,1,3,2}, DataType:FLOAT32)
    [[[[0, 0],
       [0, 0],
       [0, 0]]]]

備註

可用性

這個運算子是在 中 DML_FEATURE_LEVEL_2_0引進。

Tensor 支援

DML_FEATURE_LEVEL_5_1和更新版本

種類 支援的維度計數 支援的資料類型
OutputTensor 輸出 2 到 4 FLOAT64、FLOAT32、FLOAT16、INT64、INT32、INT16、INT8、UINT64、UINT32、UINT16、UINT8

DML_FEATURE_LEVEL_4_0和更新版本

種類 支援的維度計數 支援的資料類型
OutputTensor 輸出 2 到 4 FLOAT32、FLOAT16、INT32、INT16、INT8、UINT32、UINT16、UINT8

DML_FEATURE_LEVEL_2_1和更新版本

種類 支援的維度計數 支援的資料類型
OutputTensor 輸出 4 FLOAT32、FLOAT16、INT32、INT16、INT8、UINT32、UINT16、UINT8

DML_FEATURE_LEVEL_2_0和更新版本

種類 支援的維度計數 支援的資料類型
OutputTensor 輸出 4 FLOAT32,FLOAT16

規格需求

需求
最低支援的用戶端 Windows 10 版本 2004 (10.0;組建 19041)
最低支援的伺服器 Windows Server 版本 2004 (10.0;組建 19041)
標頭 directml.h