From 296435264182e09cc37cfd981b012854226ddd2c Mon Sep 17 00:00:00 2001 From: Dmitri Smirnov Date: Thu, 7 Mar 2024 15:46:11 -0800 Subject: [PATCH] Implement IsNaN-9,13,20 for CUDA along with tests (#19807) ### Description ### Motivation and Context Some models require IsNan CUDA along with training --- docs/OperatorKernels.md | 5 +- .../providers/cpu/cpu_execution_provider.cc | 4 ++ .../core/providers/cpu/tensor/isnan.cc | 19 +++++- .../core/providers/cuda/cu_inc/common.cuh | 59 ++++++++++++++++++- .../providers/cuda/cuda_execution_provider.cc | 7 ++- .../cuda/math/unary_elementwise_ops.cc | 44 ++++++++++++++ .../cuda/math/unary_elementwise_ops.h | 6 ++ .../cuda/math/unary_elementwise_ops_impl.cu | 24 +++++++- .../cuda/math/unary_elementwise_ops_impl.h | 14 +++++ .../core/providers/rocm/cu_inc/common.cuh | 57 ++++++++++++++++++ .../providers/rocm/rocm_execution_provider.cc | 6 ++ .../test/providers/cpu/tensor/isnan_test.cc | 16 ++++- 12 files changed, 252 insertions(+), 9 deletions(-) diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md index 4514a85531d6b..9f5cd4cc842dc 100644 --- a/docs/OperatorKernels.md +++ b/docs/OperatorKernels.md @@ -162,7 +162,7 @@ Do not modify directly.* |InstanceNormalization|*in* input:**T**
*in* scale:**T**
*in* B:**T**
*out* output:**T**|6+|**T** = tensor(float)| |IsInf|*in* X:**T1**
*out* Y:**T2**|20+|**T1** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz)
**T2** = tensor(bool)| |||[10, 19]|**T1** = tensor(double), tensor(float)
**T2** = tensor(bool)| -|IsNaN|*in* X:**T1**
*out* Y:**T2**|20+|**T1** = tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz)
**T2** = tensor(bool)| +|IsNaN|*in* X:**T1**
*out* Y:**T2**|20+|**T1** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz)
**T2** = tensor(bool)| |||[13, 19]|**T1** = tensor(double), tensor(float), tensor(float16)
**T2** = tensor(bool)| |||[9, 12]|**T1** = tensor(double), tensor(float), tensor(float16)
**T2** = tensor(bool)| |LRN|*in* X:**T**
*out* Y:**T**|13+|**T** = tensor(float)| @@ -633,6 +633,9 @@ Do not modify directly.* |InstanceNormalization|*in* input:**T**
*in* scale:**T**
*in* B:**T**
*out* output:**T**|6+|**T** = tensor(double), tensor(float), tensor(float16)| |IsInf|*in* X:**T1**
*out* Y:**T2**|20+|**T1** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz)
**T2** = tensor(bool)| |||[10, 19]|**T1** = tensor(double), tensor(float)
**T2** = tensor(bool)| +|IsNaN|*in* X:**T1**
*out* Y:**T2**|20+|**T1** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz)
**T2** = tensor(bool)| +|||[13, 19]|**T1** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)
**T2** = tensor(bool)| +|||[9, 12]|**T1** = tensor(double), tensor(float), tensor(float16)
**T2** = tensor(bool)| |LRN|*in* X:**T**
*out* Y:**T**|13+|**T** = tensor(double), tensor(float), tensor(float16)| |||[1, 12]|**T** = tensor(double), tensor(float), tensor(float16)| |LSTM|*in* X:**T**
*in* W:**T**
*in* R:**T**
*in* B:**T**
*in* sequence_lens:**T1**
*in* initial_h:**T**
*in* initial_c:**T**
*in* P:**T**
*out* Y:**T**
*out* Y_h:**T**
*out* Y_c:**T**|14+|**T** = tensor(double), tensor(float), tensor(float16)
**T1** = tensor(int32)| diff --git a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc index 7e0f919deb0a7..c3d5a51b636ef 100644 --- a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc +++ b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc @@ -714,6 +714,7 @@ class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDoma class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, 19, float, IsNaN); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, 19, double, IsNaN); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, 19, MLFloat16, IsNaN); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, 19, BFloat16, IsNaN); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, bool, NonZero); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, float, NonZero); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, int32_t, NonZero); @@ -1023,6 +1024,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, float, IsNaN); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, double, IsNaN); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, MLFloat16, IsNaN); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, BFloat16, IsNaN); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, Gelu); #if !defined(DISABLE_FLOAT8_TYPES) class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, Float8E4M3FN, IsNaN); @@ -2553,6 +2555,8 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, #if !defined(DISABLE_FLOAT8_TYPES) BuildKernelCreateInfo::Compute(OpKernelContext* context) const { template <> Status IsNaN::Compute(OpKernelContext* context) const { const auto* X_ptr = context->Input(0); - if (!X_ptr) { - return Status(common::ONNXRUNTIME, common::FAIL, "Null input ptr"); - } + auto X_data = X_ptr->Data(); auto& dims = X_ptr->Shape(); auto shape_size = dims.Size(); @@ -91,6 +91,19 @@ Status IsNaN::Compute(OpKernelContext* context) const { return Status::OK(); } +template <> +Status IsNaN::Compute(OpKernelContext* context) const { + const auto* X_ptr = context->Input(0); + + auto X_data = X_ptr->DataAsSpan(); + auto& Y = *context->Output(0, X_ptr->Shape()); + + std::transform(X_data.begin(), X_data.end(), Y.MutableData(), + [](BFloat16 x) { return x.IsNaN(); }); + + return Status::OK(); +} + #if !defined(DISABLE_FLOAT8_TYPES) template <> Status IsNaN::Compute(OpKernelContext* context) const { diff --git a/onnxruntime/core/providers/cuda/cu_inc/common.cuh b/onnxruntime/core/providers/cuda/cu_inc/common.cuh index bba9178348132..bed2f677166d6 100644 --- a/onnxruntime/core/providers/cuda/cu_inc/common.cuh +++ b/onnxruntime/core/providers/cuda/cu_inc/common.cuh @@ -485,7 +485,7 @@ struct IsInfTyped { #if !defined(DISABLE_FLOAT8_TYPES) -template +template struct ReturnFalse { constexpr static bool __device__ __inline__ IsInf(T) { return false; } constexpr static bool __device__ __inline__ IsInfPos(T) { return false; } @@ -532,6 +532,63 @@ struct _IsInf { } }; +// float and double +template +struct _IsNan { + __device__ __inline__ bool operator()(T a) const { + return isnan(a); + } +}; + +template <> +struct _IsNan { + __device__ __inline__ bool operator()(half a) const { + return static_cast(*reinterpret_cast(&a) & ~MLFloat16::kSignMask) + > MLFloat16::kPositiveInfinityBits; + } +}; + +template <> +struct _IsNan { + __device__ __inline__ bool operator()(BFloat16 a) const { + return static_cast(*reinterpret_cast(&a) & ~BFloat16::kSignMask) + > BFloat16::kPositiveInfinityBits; + } +}; + +#if !defined(DISABLE_FLOAT8_TYPES) + +template<> +struct _IsNan { + __device__ __inline__ bool operator()(Float8E4M3FN a) const { + return (*reinterpret_cast(&a) & 0x7f) == 0x7f; + } +}; + +template<> +struct _IsNan { + __device__ __inline__ bool operator()(Float8E4M3FNUZ a) const { + return *reinterpret_cast(&a) == 0x80; + } +}; + +template<> +struct _IsNan { + __device__ __inline__ bool operator()(Float8E5M2 a) const { + uint8_t c = *reinterpret_cast(&a); + return ((c & 0x7c) == 0x7c) && ((c & 0x03) != 0x00); + } +}; + +template<> +struct _IsNan { + __device__ __inline__ bool operator()(Float8E5M2FNUZ a) const { + return *reinterpret_cast(&a) == 0x80; + } +}; + +#endif + // We would like to use 64-bit integer to support large matrices. However, CUDA seems to support only 32-bit integer // For now, use int32_t to ensure that both Linux and Windows see this as 32 bit integer type. #ifndef CUDA_LONG diff --git a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc index bade2faf8f2e2..18c7334af6611 100644 --- a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc +++ b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc @@ -746,6 +746,7 @@ class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kO class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 9, 12, uint32_t, Cast); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 9, 12, uint64_t, Cast); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 9, 12, bool, Cast); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 9, 12, IsNaN); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 2, 10, float, Pad); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 2, 10, double, Pad); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 2, 10, MLFloat16, Pad); @@ -938,7 +939,6 @@ class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDom // OpSet 12 class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, 12, Clip); - class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, float, MaxPool); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, double, MaxPool); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, MLFloat16, MaxPool); @@ -1087,6 +1087,7 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, U class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, Concat); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, Gather); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, GatherElements); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 19, IsNaN); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, float, MatMul); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, double, MatMul); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, MLFloat16, MatMul); @@ -1368,6 +1369,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 20, double, Gelu); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 20, MLFloat16, Gelu); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 20, IsInf); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 20, IsNaN); template <> KernelCreateInfo BuildKernelCreateInfo() { @@ -1553,6 +1555,7 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -1979,6 +1982,7 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -2279,6 +2283,7 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, #endif }; diff --git a/onnxruntime/core/providers/cuda/math/unary_elementwise_ops.cc b/onnxruntime/core/providers/cuda/math/unary_elementwise_ops.cc index 00de1b37f3302..24593b255371c 100644 --- a/onnxruntime/core/providers/cuda/math/unary_elementwise_ops.cc +++ b/onnxruntime/core/providers/cuda/math/unary_elementwise_ops.cc @@ -109,6 +109,50 @@ Status IsInf::ComputeInternal(OpKernelContext* context) const { return Status::OK(); } +// IsNan +ONNX_OPERATOR_VERSIONED_KERNEL_EX( + IsNaN, + kOnnxDomain, + 9, + 12, + kCudaExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T1", BuildKernelDefConstraints()) + .TypeConstraint("T2", DataTypeImpl::GetTensorType()), + IsNaN); + +ONNX_OPERATOR_VERSIONED_KERNEL_EX( + IsNaN, + kOnnxDomain, + 13, + 19, + kCudaExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T1", BuildKernelDefConstraints()) + .TypeConstraint("T2", DataTypeImpl::GetTensorType()), + IsNaN); + +ONNX_OPERATOR_KERNEL_EX( + IsNaN, + kOnnxDomain, + 20, + kCudaExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T1", BuildKernelDefConstraints()) + .TypeConstraint("T2", DataTypeImpl::GetTensorType()), + IsNaN); + +Status IsNaN::ComputeInternal(OpKernelContext* context) const { + UnaryElementwisePreparation p; + ORT_RETURN_IF_ERROR(UnaryElementwise::Prepare(context, &p)); + + Explicit_Impl_IsNan(Stream(context), p.input_tensor->GetElementType(), p.input_tensor->DataRaw(), + p.output_tensor->MutableData(), + p.input_tensor->Shape().Size()); + + return Status::OK(); +} + #define UNARY_OP_VERSIONED_TYPED(name, startver, endver, T) \ UNARY_ELEMENTWISE_REGISTER_VERSIONED_KERNEL(name, startver, endver, T) diff --git a/onnxruntime/core/providers/cuda/math/unary_elementwise_ops.h b/onnxruntime/core/providers/cuda/math/unary_elementwise_ops.h index 3b7d6df7221b7..95d68b5e1d534 100644 --- a/onnxruntime/core/providers/cuda/math/unary_elementwise_ops.h +++ b/onnxruntime/core/providers/cuda/math/unary_elementwise_ops.h @@ -131,5 +131,11 @@ class IsInf final : public UnaryElementwise { int opset_; }; +class IsNaN : public UnaryElementwise { + public: + explicit IsNaN(const OpKernelInfo& info) : UnaryElementwise(info) {} + Status ComputeInternal(OpKernelContext* context) const override; +}; + } // namespace cuda } // namespace onnxruntime diff --git a/onnxruntime/core/providers/cuda/math/unary_elementwise_ops_impl.cu b/onnxruntime/core/providers/cuda/math/unary_elementwise_ops_impl.cu index 554d5908cf854..2cdfcda5be26a 100644 --- a/onnxruntime/core/providers/cuda/math/unary_elementwise_ops_impl.cu +++ b/onnxruntime/core/providers/cuda/math/unary_elementwise_ops_impl.cu @@ -315,13 +315,33 @@ void Explicit_Impl_IsInf(cudaStream_t stream, int op_set, if (op_set < 20) { utils::MLTypeCallDispatcher dispatcher{input_data_type}; dispatcher.Invoke(stream, input_raw, output_data, - detect_positive, detect_negative, count); + detect_positive, detect_negative, count); } else { utils::MLTypeCallDispatcher dispatcher{input_data_type}; dispatcher.Invoke(stream, input_raw, output_data, - detect_positive, detect_negative, count); + detect_positive, detect_negative, count); } } +// IsNan + +namespace isnan_details { +template +struct IsNan_Disp { + void operator()(cudaStream_t stream, const void* input_raw, bool* output_data, size_t count) const { + using CudaType = typename ToCudaType::MappedType; + const auto* input_data = reinterpret_cast(input_raw); + UnaryElementWiseImpl(stream, input_data, output_data, _IsNan{}, count); + } +}; +} // namespace isnan_details + +void Explicit_Impl_IsNan(cudaStream_t stream, int32_t input_data_type, + const void* input_raw, bool* output_data, size_t count) { + // KernelDef constraints would ensure only subset of datatypes is used. + utils::MLTypeCallDispatcher dispatcher{input_data_type}; + dispatcher.Invoke(stream, input_raw, output_data, count); +} + } // namespace cuda } // namespace onnxruntime diff --git a/onnxruntime/core/providers/cuda/math/unary_elementwise_ops_impl.h b/onnxruntime/core/providers/cuda/math/unary_elementwise_ops_impl.h index a606d479bc79b..2588f56e32c12 100644 --- a/onnxruntime/core/providers/cuda/math/unary_elementwise_ops_impl.h +++ b/onnxruntime/core/providers/cuda/math/unary_elementwise_ops_impl.h @@ -151,6 +151,20 @@ void Explicit_Impl_IsInf(cudaStream_t stream, int op_set, int32_t input_data_type, const void* input_raw, bool* output_data, size_t count); + +// IsNan +#define ISNAN_OPSET9_FLOATS float, double, MLFloat16 +#define ISNAN_OPSET13_FLOATS float, double, MLFloat16, BFloat16 +#if !defined(DISABLE_FLOAT8_TYPES) +#define ISNAN_OPSET20_FLOATS float, double, MLFloat16, BFloat16, Float8E4M3FN, Float8E4M3FNUZ, Float8E5M2, \ + Float8E5M2FNUZ +#else +#define ISNAN_OPSET20_FLOATS ISNAN_OPSET13_FLOATS +#endif + +void Explicit_Impl_IsNan(cudaStream_t stream, int32_t input_data_type, + const void* input_raw, bool* output_data, size_t count); + } // namespace cuda } // namespace onnxruntime diff --git a/onnxruntime/core/providers/rocm/cu_inc/common.cuh b/onnxruntime/core/providers/rocm/cu_inc/common.cuh index f3685606c17f5..1698e5ca8478c 100644 --- a/onnxruntime/core/providers/rocm/cu_inc/common.cuh +++ b/onnxruntime/core/providers/rocm/cu_inc/common.cuh @@ -429,6 +429,63 @@ struct _IsInf { } }; +// float and double +template +struct _IsNan { + __device__ __inline__ bool operator()(T a) const { + return isnan(a); + } +}; + +template <> +struct _IsNan { + __device__ __inline__ bool operator()(half a) const { + return static_cast(*reinterpret_cast(&a) & ~MLFloat16::kSignMask) + > MLFloat16::kPositiveInfinityBits; + } +}; + +template <> +struct _IsNan { + __device__ __inline__ bool operator()(BFloat16 a) const { + return static_cast(*reinterpret_cast(&a) & ~BFloat16::kSignMask) + > BFloat16::kPositiveInfinityBits; + } +}; + +#if !defined(DISABLE_FLOAT8_TYPES) + +template <> +struct _IsNan { + __device__ __inline__ bool operator()(Float8E4M3FN a) const { + return (*reinterpret_cast(&a) & 0x7f) == 0x7f; + } +}; + +template <> +struct _IsNan { + __device__ __inline__ bool operator()(Float8E4M3FNUZ a) const { + return *reinterpret_cast(&a) == 0x80; + } +}; + +template <> +struct _IsNan { + __device__ __inline__ bool operator()(Float8E5M2 a) const { + uint8_t c = *reinterpret_cast(&a); + return ((c & 0x7c) == 0x7c) && ((c & 0x03) != 0x00); + } +}; + +template <> +struct _IsNan { + __device__ __inline__ bool operator()(Float8E5M2FNUZ a) const { + return *reinterpret_cast(&a) == 0x80; + } +}; + +#endif + // We would like to use 64-bit integer to support large matrices. However, ROCM seems to support only 32-bit integer // For now, use int32_t to ensure that both Linux and Windows see this as 32 bit integer type. #ifndef HIP_LONG diff --git a/onnxruntime/core/providers/rocm/rocm_execution_provider.cc b/onnxruntime/core/providers/rocm/rocm_execution_provider.cc index 32be74550951e..87daaeea969ac 100644 --- a/onnxruntime/core/providers/rocm/rocm_execution_provider.cc +++ b/onnxruntime/core/providers/rocm/rocm_execution_provider.cc @@ -734,6 +734,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 9, float, Shrink); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 9, double, Shrink); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 9, MLFloat16, Shrink); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 9, 12, IsNaN); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 7, 8, float, Less); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 7, 8, double, Less); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 7, 8, MLFloat16, Less); @@ -1067,6 +1068,7 @@ class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kO class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 18, uint32_t, Cast); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 18, uint64_t, Cast); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 18, bool, Cast); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 19, IsNaN); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 13, Reshape); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 14, Shape); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, Size); @@ -1346,6 +1348,7 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 19, S // Opset 20 class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 20, IsInf); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 20, IsNaN); template <> KernelCreateInfo BuildKernelCreateInfo() { @@ -1531,6 +1534,7 @@ static Status RegisterRocmKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, // BuildKernelCreateInfo, @@ -1941,6 +1945,7 @@ static Status RegisterRocmKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -2304,6 +2309,7 @@ static Status RegisterRocmKernels(KernelRegistry& kernel_registry) { // opset 20 BuildKernelCreateInfo, + BuildKernelCreateInfo, }; for (auto& function_table_entry : function_table) { diff --git a/onnxruntime/test/providers/cpu/tensor/isnan_test.cc b/onnxruntime/test/providers/cpu/tensor/isnan_test.cc index 0f1e5c07cdd9b..3cf99fde2cce7 100644 --- a/onnxruntime/test/providers/cpu/tensor/isnan_test.cc +++ b/onnxruntime/test/providers/cpu/tensor/isnan_test.cc @@ -38,9 +38,23 @@ TEST(IsNaNOpTest, IsNaNFloat16_9) { run_is_nan_test(9, dims, input, output); } +TEST(IsNaNOpTest, IsNaNFloat16_13) { + std::vector dims{2, 2}; + std::initializer_list input = {MLFloat16::One, MLFloat16::NaN, MLFloat16(2.0f), MLFloat16::NaN}; + std::initializer_list output = {false, true, false, true}; + run_is_nan_test(13, dims, input, output); +} + TEST(IsNaNOpTest, IsNaNFloat16_20) { std::vector dims{2, 2}; - std::initializer_list input = {MLFloat16(1.0f), MLFloat16::NaN, MLFloat16(2.0f), MLFloat16::NaN}; + std::initializer_list input = {MLFloat16::One, MLFloat16::NaN, MLFloat16(2.0f), MLFloat16::NaN}; + std::initializer_list output = {false, true, false, true}; + run_is_nan_test(20, dims, input, output); +} + +TEST(IsNaNOpTest, IsNaNBFloat16_20) { + std::vector dims{2, 2}; + std::initializer_list input = {BFloat16::One, BFloat16::NaN, BFloat16(2.0f), BFloat16::NaN}; std::initializer_list output = {false, true, false, true}; run_is_nan_test(20, dims, input, output); }