From 9cdd9ef0165085521f6819f3ec6985e939b6e2d1 Mon Sep 17 00:00:00 2001
From: Patrice Vignola
Date: Tue, 10 Oct 2023 20:11:08 -0700
Subject: [PATCH] [DML EP] Enable more MHA masks

---
 .../DmlOperatorMultiHeadAttention.cpp         | 30 ++++++++++++++++---
 1 file changed, 26 insertions(+), 4 deletions(-)

diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorMultiHeadAttention.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorMultiHeadAttention.cpp
index 9c1a7baeaa8df..03500d0ee86a9 100644
--- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorMultiHeadAttention.cpp
+++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorMultiHeadAttention.cpp
@@ -205,12 +205,34 @@ class DmlOperatorMultiHeadAttention : public DmlOperator
         else
         {
             const auto keyPaddingMaskTensorShape = m_inputTensorDescs[dmlMaskIndex].GetSizes();
-            ML_CHECK_VALID_ARGUMENT(keyPaddingMaskTensorShape.size() == 2);
+            size_t maskDimCount = keyPaddingMaskTensorShape.size();
+            ML_CHECK_VALID_ARGUMENT(maskDimCount >= 2 && maskDimCount <= 4);
             ML_CHECK_VALID_ARGUMENT(keyPaddingMaskTensorShape[0] == batchSize);
-            ML_CHECK_VALID_ARGUMENT(keyPaddingMaskTensorShape[1] == kvSequenceLength);
 
-            const uint32_t actualShape[4] = {batchSize, 1, 1, kvSequenceLength};
-            const uint32_t desiredShape[4] = {batchSize, numHeads, sequenceLength, kvSequenceLength};
+            std::array<uint32_t, 4> actualShape{};
+            std::array<uint32_t, 4> desiredShape{};
+
+            if (maskDimCount == 2)
+            {
+                ML_CHECK_VALID_ARGUMENT(keyPaddingMaskTensorShape[1] == kvSequenceLength);
+                actualShape = {batchSize, 1, 1, kvSequenceLength};
+                desiredShape = {batchSize, numHeads, sequenceLength, kvSequenceLength};
+            }
+            else if (maskDimCount == 3)
+            {
+                ML_CHECK_VALID_ARGUMENT(keyPaddingMaskTensorShape[1] == sequenceLength);
+                ML_CHECK_VALID_ARGUMENT(keyPaddingMaskTensorShape[2] == totalSequenceLength);
+                actualShape = {batchSize, 1, sequenceLength, totalSequenceLength};
+                desiredShape = {batchSize, numHeads, sequenceLength, totalSequenceLength};
+            }
+            else if (maskDimCount == 4)
+            {
+                ML_CHECK_VALID_ARGUMENT(keyPaddingMaskTensorShape[1] == numHeads);
+                ML_CHECK_VALID_ARGUMENT(keyPaddingMaskTensorShape[2] == sequenceLength);
+                ML_CHECK_VALID_ARGUMENT(keyPaddingMaskTensorShape[3] == totalSequenceLength);
+                actualShape = {batchSize, numHeads, sequenceLength, totalSequenceLength};
+                desiredShape = {batchSize, numHeads, sequenceLength, totalSequenceLength};
+            }
 
             m_inputTensorDescs[dmlMaskIndex] = TensorDesc::ConstructBroadcastedTensorDesc(
                 m_inputTensorDescs[dmlMaskIndex].GetMlOperatorDataType(),