DML_DIAGONAL_MATRIX_OPERATOR_DESC結構 (directml.h)
在主要對角線上產生具有 (或其他明確值) 的類似身分識別矩陣,並在其他地方產生零。 對角線值 (可能會透過 Offset) 移OutputTensor[i, i + Offset]
= 位) 移位,這表示 Offset 大於零的自變數會將所有值移位到右邊,而小於零則會將它們向左移位。 此產生器運算子適用於模型,以避免儲存大型常數張量。 最後兩個之前的任何前置維度都會視為批次計數,這表示張量會被視為 2D 矩陣的堆疊。
這個運算子會執行下列虛擬程序代碼。
for each coordinate in OutputTensor
OutputTensor[coordinate] = if (coordinate.y + Offset == coordinate.x) then Value else 0
endfor
語法
struct DML_DIAGONAL_MATRIX_OPERATOR_DESC {
const DML_TENSOR_DESC *OutputTensor;
INT Offset;
FLOAT Value;
};
成員
OutputTensor
類型: const DML_TENSOR_DESC*
要寫入結果的張量。 維度為 { Batch1, Batch2, OutputHeight, OutputWidth }
。 高度和寬度不需要是正方形。
Offset
類型: INT
將 Value 對角線移位移的位移,正位移會將寫入的值向右/上移, (將輸出檢視為矩陣,並將左上方為 0,0) ,負位移至左/下。
Value
類型: FLOAT
要沿著 2D 對角線填滿的值。 標準值為 1.0。 請注意,如果張量的 DataType 不是 DML_TENSOR_DATA_TYPE_FLOAT16 或 DML_TENSOR_DATA_TYPE_FLOAT32,則值可能會截斷 (例如,10.6 會變成 10) 。
範例
預設身分識別矩陣:
Offset: 0
Value: 1.0
OutputTensor: (Sizes:{1,1,3,3}, DataType:FLOAT32)
[[[[1, 0, 0],
[0, 1, 0],
[0, 0, 1]]]]
往右/ 向上移位:
Offset: 1
Value: 1.0
OutputTensor: (Sizes:{1,1,3,3}, DataType:FLOAT32)
[[[[ 0, 1, 0],
[ 0, 0, 1],
[ 0, 0, 0]]]]
往左/ 向下移位:
Offset: -1
Value: 1.0
OutputTensor: (Sizes:{1,1,3,2}, DataType:FLOAT32)
[[[[0, 0],
[1, 0],
[0, 1]]]]
將到目前為止的對角線移位,讓所有專案變成零:
Offset: -3
Value: 1.0
OutputTensor: (Sizes:{1,1,3,2}, DataType:FLOAT32)
[[[[0, 0],
[0, 0],
[0, 0]]]]
備註
可用性
這個運算子是在 中 DML_FEATURE_LEVEL_2_0
引進。
Tensor 支援
DML_FEATURE_LEVEL_5_1和更新版本
張 | 種類 | 支援的維度計數 | 支援的資料類型 |
---|---|---|---|
OutputTensor | 輸出 | 2 到 4 | FLOAT64、FLOAT32、FLOAT16、INT64、INT32、INT16、INT8、UINT64、UINT32、UINT16、UINT8 |
DML_FEATURE_LEVEL_4_0和更新版本
張 | 種類 | 支援的維度計數 | 支援的資料類型 |
---|---|---|---|
OutputTensor | 輸出 | 2 到 4 | FLOAT32、FLOAT16、INT32、INT16、INT8、UINT32、UINT16、UINT8 |
DML_FEATURE_LEVEL_2_1和更新版本
張 | 種類 | 支援的維度計數 | 支援的資料類型 |
---|---|---|---|
OutputTensor | 輸出 | 4 | FLOAT32、FLOAT16、INT32、INT16、INT8、UINT32、UINT16、UINT8 |
DML_FEATURE_LEVEL_2_0和更新版本
張 | 種類 | 支援的維度計數 | 支援的資料類型 |
---|---|---|---|
OutputTensor | 輸出 | 4 | FLOAT32,FLOAT16 |
規格需求
需求 | 值 |
---|---|
最低支援的用戶端 | Windows 10 版本 2004 (10.0;組建 19041) |
最低支援的伺服器 | Windows Server 版本 2004 (10.0;組建 19041) |
標頭 | directml.h |