Skip to content

Commit

Permalink
[DML EP] Fix Clip clamping
Browse files Browse the repository at this point in the history
  • Loading branch information
PatriceVignola committed Sep 27, 2024
1 parent 6e3163f commit 4677d4b
Showing 1 changed file with 13 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -451,8 +451,19 @@ class DmlOperatorElementwiseClip11 : public DmlOperator
// logic for some corner test case
// Same applies to min and max value.
opDesc.MinMaxDataType = this->m_inputTensorDescs[0].GetDmlDataType();
CastToClampedScalarUnion<double>(opDesc.MinMaxDataType, -DBL_MAX, /*out*/&opDesc.Min);
CastToClampedScalarUnion<double>(opDesc.MinMaxDataType, DBL_MAX, /*out*/&opDesc.Max);

if (opDesc.MinMaxDataType == DML_TENSOR_DATA_TYPE_FLOAT16 || opDesc.MinMaxDataType == DML_TENSOR_DATA_TYPE_FLOAT32 || opDesc.MinMaxDataType == DML_TENSOR_DATA_TYPE_FLOAT64)
{
CastToClampedScalarUnion<double>(opDesc.MinMaxDataType, -DBL_MAX, /*out*/&opDesc.Min);
CastToClampedScalarUnion<double>(opDesc.MinMaxDataType, DBL_MAX, /*out*/&opDesc.Max);
}
else
{
// It's not safe to use DBL_MAX for non-float datatypes because not all integer can be represented in the range.
// For example, static_cast<int64_t>(static_cast<double>(INT64_MAX)) will yield a negative number.
CastToClampedScalarUnion<int64_t>(opDesc.MinMaxDataType, -INT64_MAX, /*out*/&opDesc.Min);
CastToClampedScalarUnion<uint64_t>(opDesc.MinMaxDataType, UINT64_MAX, /*out*/&opDesc.Max);
}

if (kernelInfo.IsInputValid(1))
{
Expand Down

0 comments on commit 4677d4b

Please sign in to comment.