From 4677d4bdcabe520d39123bb888bcc9df1dce9b54 Mon Sep 17 00:00:00 2001 From: Patrice Vignola Date: Fri, 27 Sep 2024 12:17:40 -0700 Subject: [PATCH] [DML EP] Fix Clip clamping --- .../src/Operators/DmlOperatorElementWise.cpp | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorElementWise.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorElementWise.cpp index 412207fd3cbd2..d4d7ee1311874 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorElementWise.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorElementWise.cpp @@ -451,8 +451,19 @@ class DmlOperatorElementwiseClip11 : public DmlOperator // logic for some corner test case // Same applies to min and max value. opDesc.MinMaxDataType = this->m_inputTensorDescs[0].GetDmlDataType(); - CastToClampedScalarUnion(opDesc.MinMaxDataType, -DBL_MAX, /*out*/&opDesc.Min); - CastToClampedScalarUnion(opDesc.MinMaxDataType, DBL_MAX, /*out*/&opDesc.Max); + + if (opDesc.MinMaxDataType == DML_TENSOR_DATA_TYPE_FLOAT16 || opDesc.MinMaxDataType == DML_TENSOR_DATA_TYPE_FLOAT32 || opDesc.MinMaxDataType == DML_TENSOR_DATA_TYPE_FLOAT64) + { + CastToClampedScalarUnion(opDesc.MinMaxDataType, -DBL_MAX, /*out*/&opDesc.Min); + CastToClampedScalarUnion(opDesc.MinMaxDataType, DBL_MAX, /*out*/&opDesc.Max); + } + else + { + // It's not safe to use DBL_MAX for non-float datatypes because not all integer can be represented in the range. + // For example, static_cast(static_cast(INT64_MAX)) will yield a negative number. + CastToClampedScalarUnion(opDesc.MinMaxDataType, -INT64_MAX, /*out*/&opDesc.Min); + CastToClampedScalarUnion(opDesc.MinMaxDataType, UINT64_MAX, /*out*/&opDesc.Max); + } if (kernelInfo.IsInputValid(1)) {