Implement IsNaN-9,13,20 for CUDA along with tests (#19807)
### Description


### Motivation and Context
Some models require the IsNaN operator on the CUDA execution provider, including during training.
yuslepukhin authored Mar 7, 2024
1 parent 33578cc commit 2964352
Showing 12 changed files with 252 additions and 9 deletions.
5 changes: 4 additions & 1 deletion docs/OperatorKernels.md
@@ -162,7 +162,7 @@ Do not modify directly.*
|InstanceNormalization|*in* input:**T**<br> *in* scale:**T**<br> *in* B:**T**<br> *out* output:**T**|6+|**T** = tensor(float)|
|IsInf|*in* X:**T1**<br> *out* Y:**T2**|20+|**T1** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz)<br/> **T2** = tensor(bool)|
|||[10, 19]|**T1** = tensor(double), tensor(float)<br/> **T2** = tensor(bool)|
-|IsNaN|*in* X:**T1**<br> *out* Y:**T2**|20+|**T1** = tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz)<br/> **T2** = tensor(bool)|
+|IsNaN|*in* X:**T1**<br> *out* Y:**T2**|20+|**T1** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz)<br/> **T2** = tensor(bool)|
|||[13, 19]|**T1** = tensor(double), tensor(float), tensor(float16)<br/> **T2** = tensor(bool)|
|||[9, 12]|**T1** = tensor(double), tensor(float), tensor(float16)<br/> **T2** = tensor(bool)|
|LRN|*in* X:**T**<br> *out* Y:**T**|13+|**T** = tensor(float)|
@@ -633,6 +633,9 @@ Do not modify directly.*
|InstanceNormalization|*in* input:**T**<br> *in* scale:**T**<br> *in* B:**T**<br> *out* output:**T**|6+|**T** = tensor(double), tensor(float), tensor(float16)|
|IsInf|*in* X:**T1**<br> *out* Y:**T2**|20+|**T1** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz)<br/> **T2** = tensor(bool)|
|||[10, 19]|**T1** = tensor(double), tensor(float)<br/> **T2** = tensor(bool)|
|IsNaN|*in* X:**T1**<br> *out* Y:**T2**|20+|**T1** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz)<br/> **T2** = tensor(bool)|
|||[13, 19]|**T1** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)<br/> **T2** = tensor(bool)|
|||[9, 12]|**T1** = tensor(double), tensor(float), tensor(float16)<br/> **T2** = tensor(bool)|
|LRN|*in* X:**T**<br> *out* Y:**T**|13+|**T** = tensor(double), tensor(float), tensor(float16)|
|||[1, 12]|**T** = tensor(double), tensor(float), tensor(float16)|
|LSTM|*in* X:**T**<br> *in* W:**T**<br> *in* R:**T**<br> *in* B:**T**<br> *in* sequence_lens:**T1**<br> *in* initial_h:**T**<br> *in* initial_c:**T**<br> *in* P:**T**<br> *out* Y:**T**<br> *out* Y_h:**T**<br> *out* Y_c:**T**|14+|**T** = tensor(double), tensor(float), tensor(float16)<br/> **T1** = tensor(int32)|
4 changes: 4 additions & 0 deletions onnxruntime/core/providers/cpu/cpu_execution_provider.cc
@@ -714,6 +714,7 @@ class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDoma
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, 19, float, IsNaN);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, 19, double, IsNaN);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, 19, MLFloat16, IsNaN);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, 19, BFloat16, IsNaN);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, bool, NonZero);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, float, NonZero);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, int32_t, NonZero);
@@ -1023,6 +1024,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain,
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, float, IsNaN);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, double, IsNaN);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, MLFloat16, IsNaN);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, BFloat16, IsNaN);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, Gelu);
#if !defined(DISABLE_FLOAT8_TYPES)
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, Float8E4M3FN, IsNaN);
@@ -2553,6 +2555,8 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, double, IsNaN)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, MLFloat16,
IsNaN)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, BFloat16,
IsNaN)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, Gelu)>,
#if !defined(DISABLE_FLOAT8_TYPES)
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, Float8E4M3FN,
19 changes: 16 additions & 3 deletions onnxruntime/core/providers/cpu/tensor/isnan.cc
@@ -46,9 +46,11 @@ ADD_TYPED_ISNAN_OP_9(MLFloat16);
ADD_TYPED_ISNAN_OP_13(float);
ADD_TYPED_ISNAN_OP_13(double);
ADD_TYPED_ISNAN_OP_13(MLFloat16);
ADD_TYPED_ISNAN_OP_13(BFloat16);
ADD_TYPED_ISNAN_OP(float);
ADD_TYPED_ISNAN_OP(double);
ADD_TYPED_ISNAN_OP(MLFloat16);
ADD_TYPED_ISNAN_OP(BFloat16);

#if !defined(DISABLE_FLOAT8_TYPES)
ADD_TYPED_ISNAN_OP(Float8E4M3FN);
@@ -75,9 +77,7 @@ Status IsNaN<T>::Compute(OpKernelContext* context) const {
template <>
Status IsNaN<MLFloat16>::Compute(OpKernelContext* context) const {
const auto* X_ptr = context->Input<Tensor>(0);
-  if (!X_ptr) {
-    return Status(common::ONNXRUNTIME, common::FAIL, "Null input ptr");
-  }

auto X_data = X_ptr->Data<MLFloat16>();
auto& dims = X_ptr->Shape();
auto shape_size = dims.Size();
@@ -91,6 +91,19 @@ Status IsNaN<MLFloat16>::Compute(OpKernelContext* context) const {
return Status::OK();
}

template <>
Status IsNaN<BFloat16>::Compute(OpKernelContext* context) const {
const auto* X_ptr = context->Input<Tensor>(0);

auto X_data = X_ptr->DataAsSpan<BFloat16>();
auto& Y = *context->Output(0, X_ptr->Shape());

std::transform(X_data.begin(), X_data.end(), Y.MutableData<bool>(),
[](BFloat16 x) { return x.IsNaN(); });

return Status::OK();
}
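
The per-element test delegates to `BFloat16::IsNaN()`. For reference, a bfloat16 NaN is any bit pattern with an all-ones exponent and a nonzero mantissa; a minimal host-side sketch of such a predicate (constants written out for illustration — this is not ORT's actual implementation):

```cpp
#include <cstdint>

// Sketch: NaN test on raw bfloat16 bits.
// Exponent mask 0x7F80 (all ones => inf or NaN); mantissa mask 0x007F.
inline bool Bf16IsNaN(uint16_t bits) {
  return (bits & 0x7F80) == 0x7F80 && (bits & 0x007F) != 0;
}
// Bf16IsNaN(0x7FC0) == true (quiet NaN); Bf16IsNaN(0x7F80) == false (+inf).
```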

#if !defined(DISABLE_FLOAT8_TYPES)
template <>
Status IsNaN<Float8E4M3FN>::Compute(OpKernelContext* context) const {
59 changes: 58 additions & 1 deletion onnxruntime/core/providers/cuda/cu_inc/common.cuh
@@ -485,7 +485,7 @@ struct IsInfTyped<BFloat16> {

#if !defined(DISABLE_FLOAT8_TYPES)

-template<typename T>
+template <typename T>
struct ReturnFalse {
constexpr static bool __device__ __inline__ IsInf(T) { return false; }
constexpr static bool __device__ __inline__ IsInfPos(T) { return false; }
@@ -532,6 +532,63 @@ struct _IsInf {
}
};

// float and double
template <typename T>
struct _IsNan {
__device__ __inline__ bool operator()(T a) const {
return isnan(a);
}
};

template <>
struct _IsNan<half> {
__device__ __inline__ bool operator()(half a) const {
return static_cast<uint16_t>(*reinterpret_cast<const uint16_t*>(&a) & ~MLFloat16::kSignMask)
> MLFloat16::kPositiveInfinityBits;
}
};

template <>
struct _IsNan<BFloat16> {
__device__ __inline__ bool operator()(BFloat16 a) const {
return static_cast<uint16_t>(*reinterpret_cast<const uint16_t*>(&a) & ~BFloat16::kSignMask)
> BFloat16::kPositiveInfinityBits;
}
};
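
The comparison above exploits the IEEE bit layout: clearing the sign bit leaves (exponent, mantissa) as an unsigned integer, and exactly the NaN patterns compare strictly greater than positive infinity's bits. Because bfloat16 is the top 16 bits of a float32, the claim can be verified exhaustively on the host; a sketch, assuming `kSignMask == 0x8000` and `kPositiveInfinityBits == 0x7F80`:

```cpp
#include <cassert>
#include <cmath>
#include <cstdint>
#include <cstring>

int main() {
  const uint32_t kSignMask = 0x8000;  // assumed value of BFloat16::kSignMask
  const uint32_t kPosInf = 0x7F80;    // assumed value of BFloat16::kPositiveInfinityBits
  for (uint32_t b = 0; b <= 0xFFFF; ++b) {
    // A bfloat16 is a float32 with the low 16 mantissa bits dropped, so b << 16
    // reconstructs a float32 of the same class (finite / inf / NaN).
    uint32_t f32_bits = b << 16;
    float f;
    std::memcpy(&f, &f32_bits, sizeof(f));
    bool kernel_says_nan = (b & ~kSignMask) > kPosInf;
    assert(kernel_says_nan == std::isnan(f));
  }
  return 0;
}
```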

#if !defined(DISABLE_FLOAT8_TYPES)

template<>
struct _IsNan<Float8E4M3FN> {
__device__ __inline__ bool operator()(Float8E4M3FN a) const {
return (*reinterpret_cast<const uint8_t*>(&a) & 0x7f) == 0x7f;
}
};

template<>
struct _IsNan<Float8E4M3FNUZ> {
__device__ __inline__ bool operator()(Float8E4M3FNUZ a) const {
return *reinterpret_cast<const uint8_t*>(&a) == 0x80;
}
};

template<>
struct _IsNan<Float8E5M2> {
__device__ __inline__ bool operator()(Float8E5M2 a) const {
uint8_t c = *reinterpret_cast<const uint8_t*>(&a);
return ((c & 0x7c) == 0x7c) && ((c & 0x03) != 0x00);
}
};

template<>
struct _IsNan<Float8E5M2FNUZ> {
__device__ __inline__ bool operator()(Float8E5M2FNUZ a) const {
return *reinterpret_cast<const uint8_t*>(&a) == 0x80;
}
};

#endif
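
Each 8-bit float fits in one byte, so the NaN predicates reduce to byte comparisons, but the encodings differ per format: E4M3FN reclaims the infinity encodings for extra range and keeps only S.1111.111 as NaN; the FNUZ variants reserve the single pattern sign=1/rest=0; E5M2 stays IEEE-like (all-ones exponent, nonzero mantissa). A host-side sketch restating the predicates used above:

```cpp
#include <cassert>
#include <cstdint>

inline bool IsNanE4M3FN(uint8_t b)   { return (b & 0x7f) == 0x7f; }               // S.1111.111
inline bool IsNanE4M3FNUZ(uint8_t b) { return b == 0x80; }                        // single NaN
inline bool IsNanE5M2(uint8_t b)     { return (b & 0x7c) == 0x7c && (b & 0x03) != 0; }
inline bool IsNanE5M2FNUZ(uint8_t b) { return b == 0x80; }                        // single NaN

int main() {
  assert(IsNanE4M3FN(0x7f) && IsNanE4M3FN(0xff) && !IsNanE4M3FN(0x7e));
  assert(IsNanE4M3FNUZ(0x80) && !IsNanE4M3FNUZ(0x00));
  assert(IsNanE5M2(0x7d) && IsNanE5M2(0xfe) && !IsNanE5M2(0x7c));  // 0x7c is +inf
  assert(IsNanE5M2FNUZ(0x80) && !IsNanE5M2FNUZ(0x7f));
  return 0;
}
```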

// We would like to use 64-bit integer to support large matrices. However, CUDA seems to support only 32-bit integer
// For now, use int32_t to ensure that both Linux and Windows see this as 32 bit integer type.
#ifndef CUDA_LONG
7 changes: 6 additions & 1 deletion onnxruntime/core/providers/cuda/cuda_execution_provider.cc
@@ -746,6 +746,7 @@ class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kO
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 9, 12, uint32_t, Cast);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 9, 12, uint64_t, Cast);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 9, 12, bool, Cast);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 9, 12, IsNaN);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 2, 10, float, Pad);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 2, 10, double, Pad);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 2, 10, MLFloat16, Pad);
@@ -938,7 +939,6 @@ class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDom

// OpSet 12
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, 12, Clip);

class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, float, MaxPool);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, double, MaxPool);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, MLFloat16, MaxPool);
@@ -1087,6 +1087,7 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, U
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, Concat);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, Gather);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, GatherElements);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 19, IsNaN);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, float, MatMul);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, double, MatMul);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, MLFloat16, MatMul);
@@ -1368,6 +1369,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain,
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 20, double, Gelu);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 20, MLFloat16, Gelu);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 20, IsInf);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 20, IsNaN);

template <>
KernelCreateInfo BuildKernelCreateInfo<void>() {
@@ -1553,6 +1555,7 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 9, 12, float, Erf)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 9, 12, double, Erf)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 9, 12, MLFloat16, Erf)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 9, 12, IsNaN)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, bool, Not)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 7, 8, float, BatchNormalization)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 7, 8, double, BatchNormalization)>,
@@ -1979,6 +1982,7 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 18, uint32_t, Cast)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 18, uint64_t, Cast)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 18, bool, Cast)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 19, IsNaN)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 13, Reshape)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 14, Shape)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, Size)>,
@@ -2279,6 +2283,7 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 20, double, Gelu)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 20, MLFloat16, Gelu)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 20, IsInf)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 20, IsNaN)>,
#endif
};

44 changes: 44 additions & 0 deletions onnxruntime/core/providers/cuda/math/unary_elementwise_ops.cc
@@ -109,6 +109,50 @@ Status IsInf::ComputeInternal(OpKernelContext* context) const {
return Status::OK();
}

// IsNaN
ONNX_OPERATOR_VERSIONED_KERNEL_EX(
IsNaN,
kOnnxDomain,
9,
12,
kCudaExecutionProvider,
(*KernelDefBuilder::Create())
.TypeConstraint("T1", BuildKernelDefConstraints<ISNAN_OPSET9_FLOATS>())
.TypeConstraint("T2", DataTypeImpl::GetTensorType<bool>()),
IsNaN);

ONNX_OPERATOR_VERSIONED_KERNEL_EX(
IsNaN,
kOnnxDomain,
13,
19,
kCudaExecutionProvider,
(*KernelDefBuilder::Create())
.TypeConstraint("T1", BuildKernelDefConstraints<ISNAN_OPSET13_FLOATS>())
.TypeConstraint("T2", DataTypeImpl::GetTensorType<bool>()),
IsNaN);

ONNX_OPERATOR_KERNEL_EX(
IsNaN,
kOnnxDomain,
20,
kCudaExecutionProvider,
(*KernelDefBuilder::Create())
.TypeConstraint("T1", BuildKernelDefConstraints<ISNAN_OPSET20_FLOATS>())
.TypeConstraint("T2", DataTypeImpl::GetTensorType<bool>()),
IsNaN);

Status IsNaN::ComputeInternal(OpKernelContext* context) const {
UnaryElementwisePreparation p;
ORT_RETURN_IF_ERROR(UnaryElementwise::Prepare(context, &p));

Explicit_Impl_IsNan(Stream(context), p.input_tensor->GetElementType(), p.input_tensor->DataRaw(),
p.output_tensor->MutableData<bool>(),
p.input_tensor->Shape().Size());

return Status::OK();
}

#define UNARY_OP_VERSIONED_TYPED(name, startver, endver, T) \
UNARY_ELEMENTWISE_REGISTER_VERSIONED_KERNEL(name, startver, endver, T)

6 changes: 6 additions & 0 deletions onnxruntime/core/providers/cuda/math/unary_elementwise_ops.h
@@ -131,5 +131,11 @@ class IsInf final : public UnaryElementwise {
int opset_;
};

class IsNaN : public UnaryElementwise {
public:
explicit IsNaN(const OpKernelInfo& info) : UnaryElementwise(info) {}
Status ComputeInternal(OpKernelContext* context) const override;
};

} // namespace cuda
} // namespace onnxruntime
24 changes: 22 additions & 2 deletions onnxruntime/core/providers/cuda/math/unary_elementwise_ops_impl.cu
@@ -315,13 +315,33 @@ void Explicit_Impl_IsInf(cudaStream_t stream, int op_set,
if (op_set < 20) {
utils::MLTypeCallDispatcher<float, double> dispatcher{input_data_type};
dispatcher.Invoke<isinf_details::IsInf_DispFunc>(stream, input_raw, output_data,
                                                 detect_positive, detect_negative, count);
} else {
utils::MLTypeCallDispatcher<ISINF_OPSET20_ALL_FLOATS> dispatcher{input_data_type};
dispatcher.Invoke<isinf_details::IsInf_DispFunc>(stream, input_raw, output_data,
                                                 detect_positive, detect_negative, count);
}
}

// IsNaN

namespace isnan_details {
template <typename T>
struct IsNan_Disp {
void operator()(cudaStream_t stream, const void* input_raw, bool* output_data, size_t count) const {
using CudaType = typename ToCudaType<T>::MappedType;
const auto* input_data = reinterpret_cast<const CudaType*>(input_raw);
UnaryElementWiseImpl(stream, input_data, output_data, _IsNan<CudaType>{}, count);
}
};
} // namespace isnan_details

void Explicit_Impl_IsNan(cudaStream_t stream, int32_t input_data_type,
const void* input_raw, bool* output_data, size_t count) {
// KernelDef constraints ensure that only the supported subset of data types reaches this point.
utils::MLTypeCallDispatcher<ISNAN_OPSET20_FLOATS> dispatcher{input_data_type};
dispatcher.Invoke<isnan_details::IsNan_Disp>(stream, input_raw, output_data, count);
}
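
`MLTypeCallDispatcher` converts the tensor's runtime element-type id into a call on the matching template instantiation of `IsNan_Disp`. A stripped-down analog of the idiom (a sketch of the pattern, not ORT's actual dispatcher):

```cpp
#include <cmath>
#include <cstddef>
#include <iostream>
#include <utility>

enum class ElemType { kFloat, kDouble };

// Map a runtime type tag onto an instantiation of a templated functor.
template <template <typename> class Fn, typename... Args>
void Dispatch(ElemType t, Args&&... args) {
  switch (t) {
    case ElemType::kFloat:  Fn<float>{}(std::forward<Args>(args)...); break;
    case ElemType::kDouble: Fn<double>{}(std::forward<Args>(args)...); break;
  }
}

template <typename T>
struct CountNaNs {
  void operator()(const void* raw, size_t n, size_t* out) const {
    const T* p = static_cast<const T*>(raw);
    *out = 0;
    for (size_t i = 0; i < n; ++i) *out += std::isnan(p[i]) ? 1 : 0;
  }
};

int main() {
  double d[] = {1.0, std::nan(""), 2.0};
  size_t nans = 0;
  Dispatch<CountNaNs>(ElemType::kDouble, d, size_t{3}, &nans);
  std::cout << nans << "\n";  // prints 1
}
```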

} // namespace cuda
} // namespace onnxruntime
14 changes: 14 additions & 0 deletions onnxruntime/core/providers/cuda/math/unary_elementwise_ops_impl.h
@@ -151,6 +151,20 @@ void Explicit_Impl_IsInf(cudaStream_t stream, int op_set,
int32_t input_data_type,
const void* input_raw, bool* output_data,
size_t count);

// IsNaN
#define ISNAN_OPSET9_FLOATS float, double, MLFloat16
#define ISNAN_OPSET13_FLOATS float, double, MLFloat16, BFloat16
#if !defined(DISABLE_FLOAT8_TYPES)
#define ISNAN_OPSET20_FLOATS float, double, MLFloat16, BFloat16, Float8E4M3FN, Float8E4M3FNUZ, Float8E5M2, \
Float8E5M2FNUZ
#else
#define ISNAN_OPSET20_FLOATS ISNAN_OPSET13_FLOATS
#endif

void Explicit_Impl_IsNan(cudaStream_t stream, int32_t input_data_type,
const void* input_raw, bool* output_data, size_t count);

} // namespace cuda

} // namespace onnxruntime
57 changes: 57 additions & 0 deletions onnxruntime/core/providers/rocm/cu_inc/common.cuh
@@ -429,6 +429,63 @@ struct _IsInf {
}
};

// float and double
template <typename T>
struct _IsNan {
__device__ __inline__ bool operator()(T a) const {
return isnan(a);
}
};

template <>
struct _IsNan<half> {
__device__ __inline__ bool operator()(half a) const {
return static_cast<uint16_t>(*reinterpret_cast<const uint16_t*>(&a) & ~MLFloat16::kSignMask)
> MLFloat16::kPositiveInfinityBits;
}
};

template <>
struct _IsNan<BFloat16> {
__device__ __inline__ bool operator()(BFloat16 a) const {
return static_cast<uint16_t>(*reinterpret_cast<const uint16_t*>(&a) & ~BFloat16::kSignMask)
> BFloat16::kPositiveInfinityBits;
}
};

#if !defined(DISABLE_FLOAT8_TYPES)

template <>
struct _IsNan<Float8E4M3FN> {
__device__ __inline__ bool operator()(Float8E4M3FN a) const {
return (*reinterpret_cast<const uint8_t*>(&a) & 0x7f) == 0x7f;
}
};

template <>
struct _IsNan<Float8E4M3FNUZ> {
__device__ __inline__ bool operator()(Float8E4M3FNUZ a) const {
return *reinterpret_cast<const uint8_t*>(&a) == 0x80;
}
};

template <>
struct _IsNan<Float8E5M2> {
__device__ __inline__ bool operator()(Float8E5M2 a) const {
uint8_t c = *reinterpret_cast<const uint8_t*>(&a);
return ((c & 0x7c) == 0x7c) && ((c & 0x03) != 0x00);
}
};

template <>
struct _IsNan<Float8E5M2FNUZ> {
__device__ __inline__ bool operator()(Float8E5M2FNUZ a) const {
return *reinterpret_cast<const uint8_t*>(&a) == 0x80;
}
};

#endif

// We would like to use 64-bit integer to support large matrices. However, ROCM seems to support only 32-bit integer
// For now, use int32_t to ensure that both Linux and Windows see this as 32 bit integer type.
#ifndef HIP_LONG
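
The commit also adds unit tests, per the commit title. A representative IsNaN test in the style of ONNX Runtime's `OpTester` might look like the following — a sketch, where helper names such as `MLFloat16::NaN` and the multi-provider behavior of `Run()` are assumptions rather than the commit's actual test code:

```cpp
#include "gtest/gtest.h"
#include "test/providers/provider_test_utils.h"

namespace onnxruntime {
namespace test {

TEST(IsNaNOpTest, Float16_Opset20) {
  OpTester test("IsNaN", 20, kOnnxDomain);
  std::vector<int64_t> dims{2, 2};
  test.AddInput<MLFloat16>("X", dims,
                           {MLFloat16(1.0f), MLFloat16::NaN, MLFloat16(2.0f), MLFloat16::NaN});
  test.AddOutput<bool>("Y", dims, {false, true, false, true});
  test.Run();  // exercises every registered execution provider, including CUDA when enabled
}

}  // namespace test
}  // namespace onnxruntime
```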