Skip to content

Commit

Permalink
implement isinf20 and isnan20 (#17874)
Browse files Browse the repository at this point in the history
  • Loading branch information
liqunfu authored Oct 24, 2023
1 parent abb3291 commit efa0cc2
Show file tree
Hide file tree
Showing 8 changed files with 389 additions and 101 deletions.
6 changes: 4 additions & 2 deletions docs/OperatorKernels.md
Original file line number Diff line number Diff line change
Expand Up @@ -156,8 +156,10 @@ Do not modify directly.*
|||[1, 10]|**B** = tensor(bool)<br/> **V** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|ImageScaler|*in* input:**T**<br> *out* output:**T**|1+|**T** = tensor(float)|
|InstanceNormalization|*in* input:**T**<br> *in* scale:**T**<br> *in* B:**T**<br> *out* output:**T**|6+|**T** = tensor(float)|
|IsInf|*in* X:**T1**<br> *out* Y:**T2**|10+|**T1** = tensor(double), tensor(float)<br/> **T2** = tensor(bool)|
|IsNaN|*in* X:**T1**<br> *out* Y:**T2**|13+|**T1** = tensor(double), tensor(float), tensor(float16)<br/> **T2** = tensor(bool)|
|IsInf|*in* X:**T1**<br> *out* Y:**T2**|20+|**T1** = tensor(double), tensor(float), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz)<br/> **T2** = tensor(bool)|
|||[10, 19]|**T1** = tensor(double), tensor(float)<br/> **T2** = tensor(bool)|
|IsNaN|*in* X:**T1**<br> *out* Y:**T2**|20+|**T1** = tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz)<br/> **T2** = tensor(bool)|
|||[13, 19]|**T1** = tensor(double), tensor(float), tensor(float16)<br/> **T2** = tensor(bool)|
|||[9, 12]|**T1** = tensor(double), tensor(float), tensor(float16)<br/> **T2** = tensor(bool)|
|LRN|*in* X:**T**<br> *out* Y:**T**|13+|**T** = tensor(float)|
|||[1, 12]|**T** = tensor(float)|
Expand Down
5 changes: 4 additions & 1 deletion include/onnxruntime/core/framework/float8.h
Original file line number Diff line number Diff line change
Expand Up @@ -208,9 +208,10 @@ struct Float8E4M3FNUZ {
val = static_cast<uint8_t>((b & 0x80000000) >> 24); // sign
if ((b & 0x7fffffff) == 0x7f800000) { // infinity
if (saturate) {
// the highest available value
val |= 0x7F;
} else {
// infinity
// NaN
val = 0x80;
}
} else if ((b & 0x7F800000) == 0x7F800000) { // NaN
Expand Down Expand Up @@ -362,8 +363,10 @@ struct Float8E5M2 {
val = (b & 0x80000000) >> 24; // sign
if ((b & 0x7FFFFFFF) == 0x7F800000) { // inf
if (saturate) {
// the highest available value
val |= 0x7B;
} else {
// the infinity
val |= 0x7C;
}
} else if ((b & 0x7F800000) == 0x7F800000) { // NaN
Expand Down
42 changes: 31 additions & 11 deletions onnxruntime/core/providers/cpu/cpu_execution_provider.cc
Original file line number Diff line number Diff line change
Expand Up @@ -365,7 +365,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain,
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, 10, Slice);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, 11, Dropout);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, 10, NonMaxSuppression);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, IsInf);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, 19, IsInf);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, 15, float, RoiAlign);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, 15, double, RoiAlign);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, ReverseSequence);
Expand Down Expand Up @@ -682,9 +682,9 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, Ga
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, 15, ScatterND);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, 15, ScatterElements);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, 13, Identity);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, float, IsNaN);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, double, IsNaN);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, MLFloat16, IsNaN);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, 19, float, IsNaN);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, 19, double, IsNaN);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, 19, MLFloat16, IsNaN);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, bool, NonZero);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, float, NonZero);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, int32_t, NonZero);
Expand Down Expand Up @@ -960,6 +960,16 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 19, Sh

// Opset 20
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, ConstantOfShape);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, float, IsNaN);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, double, IsNaN);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, MLFloat16, IsNaN);
#if !defined(DISABLE_FLOAT8_TYPES)
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, Float8E4M3FN, IsNaN);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, Float8E4M3FNUZ, IsNaN);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, Float8E5M2, IsNaN);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, Float8E5M2FNUZ, IsNaN);
#endif
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, IsInf);

// !!PLEASE READ BELOW!! Following that, add new entries above this comment

Expand Down Expand Up @@ -1492,7 +1502,7 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) {
Dropout)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, 10,
NonMaxSuppression)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, IsInf)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, 19, IsInf)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, 15, float,
RoiAlign)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, 15, double,
Expand Down Expand Up @@ -1981,12 +1991,12 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, 15, ScatterElements)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, 15, ScatterND)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, 13, Identity)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, float,
IsNaN)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, double,
IsNaN)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, MLFloat16,
IsNaN)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, 19, float,
IsNaN)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, 19, double,
IsNaN)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, 19, MLFloat16,
IsNaN)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, bool,
NonZero)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, float,
Expand Down Expand Up @@ -2389,6 +2399,16 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) {

// Opset 20
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, ConstantOfShape)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, float, IsNaN)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, double, IsNaN)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, MLFloat16, IsNaN)>,
#if !defined(DISABLE_FLOAT8_TYPES)
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, Float8E4M3FN, IsNaN)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, Float8E4M3FNUZ, IsNaN)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, Float8E5M2, IsNaN)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, Float8E5M2FNUZ, IsNaN)>,
#endif
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, IsInf)>,
};

for (auto& function_table_entry : function_table) {
Expand Down
101 changes: 92 additions & 9 deletions onnxruntime/core/providers/cpu/tensor/isinf.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,30 +14,64 @@ namespace onnxruntime {
// https://github.com/onnx/onnx/blob/main/docs/Operators.md#IsInf

namespace op_kernel_type_control {
ORT_SPECIFY_OP_KERNEL_ARG_DEFAULT_TYPES_ALL_OPSETS(
kCpuExecutionProvider, kOnnxDomain, IsInf, Input, 0,
float, double);
using IsInfTypesOpset10 = TypeList<float, double>;

ORT_SPECIFY_OP_KERNEL_ARG_DEFAULT_TYPE_LIST(
kCpuExecutionProvider, kOnnxDomain, IsInf, 10, Input, 0,
IsInfTypesOpset10);

using IsInfTypesOpset20 =
TypeList<
float,
double
#if !defined(DISABLE_FLOAT8_TYPES)
,
Float8E4M3FN, Float8E4M3FNUZ, Float8E5M2, Float8E5M2FNUZ
#endif
>;

ORT_SPECIFY_OP_KERNEL_ARG_DEFAULT_TYPE_LIST(
kCpuExecutionProvider,
kOnnxDomain,
IsInf,
20,
Input,
0,
IsInfTypesOpset20);
} // namespace op_kernel_type_control

class IsInf final : public OpKernel {
public:
using EnabledDataTypes = ORT_OP_KERNEL_ARG_ENABLED_TYPE_LIST_ALL_OPSETS(kCpuExecutionProvider, kOnnxDomain,
IsInf, Input, 0);
using EnabledDataTypes10 = ORT_OP_KERNEL_ARG_ENABLED_TYPE_LIST(kCpuExecutionProvider, kOnnxDomain,
IsInf, 10, Input, 0);
using EnabledDataTypes20 = ORT_OP_KERNEL_ARG_ENABLED_TYPE_LIST(kCpuExecutionProvider, kOnnxDomain,
IsInf, 20, Input, 0);

explicit IsInf(const OpKernelInfo& info);
Status Compute(OpKernelContext* context) const override;

private:
int64_t detect_positive_{1};
int64_t detect_negative_{1};
int opset_;
};

ONNX_CPU_OPERATOR_KERNEL(
ONNX_CPU_OPERATOR_VERSIONED_KERNEL(
IsInf,
10,
19,
KernelDefBuilder()
.TypeConstraint("T1",
BuildKernelDefConstraintsFromTypeList<IsInf::EnabledDataTypes>())
BuildKernelDefConstraintsFromTypeList<IsInf::EnabledDataTypes10>())
.TypeConstraint("T2", DataTypeImpl::GetTensorType<bool>()),
IsInf);

ONNX_CPU_OPERATOR_KERNEL(
IsInf,
20,
KernelDefBuilder()
.TypeConstraint("T1",
BuildKernelDefConstraintsFromTypeList<IsInf::EnabledDataTypes20>())
.TypeConstraint("T2", DataTypeImpl::GetTensorType<bool>()),
IsInf);

Expand All @@ -46,6 +80,7 @@ IsInf::IsInf(const OpKernelInfo& info) : OpKernel(info) {
ORT_ENFORCE(status.IsOK(), "Failed to obtain detect_positive");
status = info.GetAttr("detect_negative", &detect_negative_);
ORT_ENFORCE(status.IsOK(), "Failed to obtain detect_negative");
opset_ = info.node().SinceVersion();
}

namespace isinf_internal {
Expand Down Expand Up @@ -78,6 +113,49 @@ struct ComputeDispatchTarget {
}
}
};

#if !defined(DISABLE_FLOAT8_TYPES)
template <>
struct ComputeDispatchTarget<Float8E4M3FN> {
  // Float8E4M3FN has no encoding for infinity (its all-ones exponent space
  // is used for NaN/finite values), so every output element is false.
  void operator()(const Tensor& /*X*/, Tensor& Y, bool /*detect_positive*/, bool /*detect_negative*/) const {
    EigenMap<bool>(Y).setConstant(false);
  }
};

template <>
struct ComputeDispatchTarget<Float8E4M3FNUZ> {
  // Float8E4M3FNUZ cannot represent infinity (0x80, its only special
  // encoding, is NaN), so the result is unconditionally false.
  void operator()(const Tensor& /*X*/, Tensor& Y, bool /*detect_positive*/, bool /*detect_negative*/) const {
    EigenMap<bool>(Y).setConstant(false);
  }
};

template <>
struct ComputeDispatchTarget<Float8E5M2> {
  // Float8E5M2 encodes infinity as an all-ones exponent with a zero
  // mantissa: S.11111.00, i.e. 0x7C for +inf and 0xFC for -inf.
  void operator()(const Tensor& X, Tensor& Y, bool detect_positive, bool detect_negative) const {
    constexpr uint8_t kPositiveInfBits = 0b01111100;  // 0.11111.00 (+inf)
    constexpr uint8_t kNegativeInfBits = 0b11111100;  // 1.11111.00 (-inf)

    const auto& shape = X.Shape();
    const auto raw = ConstEigenVectorMap<uint8_t>(
        static_cast<const uint8_t*>(static_cast<const void*>(X.Data<Float8E5M2>())),
        onnxruntime::narrow<size_t>(shape.Size()));
    auto result = EigenMap<bool>(Y);

    if (detect_positive && detect_negative) {
      result.array() = raw.array() == kPositiveInfBits || raw.array() == kNegativeInfBits;
    } else if (detect_positive) {
      result.array() = raw.array() == kPositiveInfBits;
    } else if (detect_negative) {
      result.array() = raw.array() == kNegativeInfBits;
    } else {
      // Neither sign requested: nothing can match.
      result.setConstant(false);
    }
  }
};

template <>
struct ComputeDispatchTarget<Float8E5M2FNUZ> {
  // Float8E5M2FNUZ has no infinity encoding (0x80 is NaN), so the
  // output is false for every element regardless of the detect flags.
  void operator()(const Tensor& /*X*/, Tensor& Y, bool /*detect_positive*/, bool /*detect_negative*/) const {
    EigenMap<bool>(Y).setConstant(false);
  }
};
#endif
} // namespace isinf_internal

Status IsInf::Compute(OpKernelContext* context) const {
Expand All @@ -88,8 +166,13 @@ Status IsInf::Compute(OpKernelContext* context) const {

using namespace isinf_internal;

utils::MLTypeCallDispatcherFromTypeList<EnabledDataTypes> dispatcher{X.GetElementType()};
dispatcher.Invoke<ComputeDispatchTarget>(X, Y, detect_positive_ != 0, detect_negative_ != 0);
if (opset_ < 20) {
utils::MLTypeCallDispatcherFromTypeList<EnabledDataTypes10> dispatcher{X.GetElementType()};
dispatcher.Invoke<ComputeDispatchTarget>(X, Y, detect_positive_ != 0, detect_negative_ != 0);
} else {
utils::MLTypeCallDispatcherFromTypeList<EnabledDataTypes20> dispatcher{X.GetElementType()};
dispatcher.Invoke<ComputeDispatchTarget>(X, Y, detect_positive_ != 0, detect_negative_ != 0);
}

return Status::OK();
}
Expand Down
81 changes: 80 additions & 1 deletion onnxruntime/core/providers/cpu/tensor/isnan.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,20 @@ namespace onnxruntime {
.TypeConstraint("T2", DataTypeImpl::GetTensorType<bool>()), \
IsNaN<data_type>);

// Registers a CPU IsNaN kernel for the given element type, versioned to the
// opset range [13, 19]. T1 is constrained to data_type, T2 to tensor(bool).
// (Comments cannot appear inside the continuation lines of this macro.)
#define ADD_TYPED_ISNAN_OP_13(data_type)                                  \
  ONNX_CPU_OPERATOR_VERSIONED_TYPED_KERNEL(                               \
      IsNaN,                                                              \
      13, 19,                                                             \
      data_type,                                                          \
      KernelDefBuilder()                                                  \
          .TypeConstraint("T1", DataTypeImpl::GetTensorType<data_type>()) \
          .TypeConstraint("T2", DataTypeImpl::GetTensorType<bool>()),     \
      IsNaN<data_type>);

#define ADD_TYPED_ISNAN_OP(data_type) \
ONNX_CPU_OPERATOR_TYPED_KERNEL( \
IsNaN, \
13, \
20, \
data_type, \
KernelDefBuilder() \
.TypeConstraint("T1", DataTypeImpl::GetTensorType<data_type>()) \
Expand All @@ -33,10 +43,20 @@ namespace onnxruntime {
ADD_TYPED_ISNAN_OP_9(float);
ADD_TYPED_ISNAN_OP_9(double);
ADD_TYPED_ISNAN_OP_9(MLFloat16);
ADD_TYPED_ISNAN_OP_13(float);
ADD_TYPED_ISNAN_OP_13(double);
ADD_TYPED_ISNAN_OP_13(MLFloat16);
ADD_TYPED_ISNAN_OP(float);
ADD_TYPED_ISNAN_OP(double);
ADD_TYPED_ISNAN_OP(MLFloat16);

#if !defined(DISABLE_FLOAT8_TYPES)
ADD_TYPED_ISNAN_OP(Float8E4M3FN);
ADD_TYPED_ISNAN_OP(Float8E4M3FNUZ);
ADD_TYPED_ISNAN_OP(Float8E5M2);
ADD_TYPED_ISNAN_OP(Float8E5M2FNUZ);
#endif

template <typename T>
Status IsNaN<T>::Compute(OpKernelContext* context) const {
const auto* X_ptr = context->Input<Tensor>(0);
Expand Down Expand Up @@ -70,4 +90,63 @@ Status IsNaN<MLFloat16>::Compute(OpKernelContext* context) const {

return Status::OK();
}

#if !defined(DISABLE_FLOAT8_TYPES)
template <>
Status IsNaN<Float8E4M3FN>::Compute(OpKernelContext* context) const {
  const auto* input_tensor = context->Input<Tensor>(0);
  const auto& shape = input_tensor->Shape();
  auto& output_tensor = *context->Output(0, shape);

  const auto data = ConstEigenVectorMap<uint8_t>(
      static_cast<const uint8_t*>(static_cast<const void*>(input_tensor->Data<Float8E4M3FN>())),
      onnxruntime::narrow<size_t>(shape.Size()));
  auto result = EigenMap<bool>(output_tensor);

  // Float8E4M3FN encodes NaN as S.1111.111: all exponent and mantissa bits set.
  constexpr uint8_t kNaNMask = 0x7f;
  std::transform(data.begin(), data.end(), result.begin(),
                 [](uint8_t bits) { return (bits & kNaNMask) == kNaNMask; });
  return Status::OK();
}

template <>
Status IsNaN<Float8E4M3FNUZ>::Compute(OpKernelContext* context) const {
  const auto* input_tensor = context->Input<Tensor>(0);
  const auto& shape = input_tensor->Shape();
  auto& output_tensor = *context->Output(0, shape);

  const auto* raw_bytes = static_cast<const uint8_t*>(
      static_cast<const void*>(input_tensor->Data<Float8E4M3FNUZ>()));
  const auto element_count = onnxruntime::narrow<size_t>(shape.Size());

  // Float8E4M3FNUZ has exactly one NaN encoding: 1.0000.000, i.e. 0x80.
  EigenMap<bool>(output_tensor) =
      ConstEigenVectorMap<uint8_t>(raw_bytes, element_count).array() == 0x80;

  return Status::OK();
}

template <>
Status IsNaN<Float8E5M2>::Compute(OpKernelContext* context) const {
  const auto* input_tensor = context->Input<Tensor>(0);
  const auto& shape = input_tensor->Shape();
  auto& output_tensor = *context->Output(0, shape);

  const auto data = ConstEigenVectorMap<uint8_t>(
      static_cast<const uint8_t*>(static_cast<const void*>(input_tensor->Data<Float8E5M2>())),
      onnxruntime::narrow<size_t>(shape.Size()));
  auto result = EigenMap<bool>(output_tensor);

  // Float8E5M2 NaNs are S.11111.{01,10,11}: exponent all ones with a
  // non-zero mantissa (an all-ones exponent with zero mantissa is infinity).
  constexpr uint8_t kExponentMask = 0x7c;
  constexpr uint8_t kMantissaMask = 0x03;
  std::transform(data.begin(), data.end(), result.begin(), [](uint8_t bits) {
    return ((bits & kExponentMask) == kExponentMask) && ((bits & kMantissaMask) != 0);
  });
  return Status::OK();
}

template <>
Status IsNaN<Float8E5M2FNUZ>::Compute(OpKernelContext* context) const {
  const auto* input_tensor = context->Input<Tensor>(0);
  const auto& shape = input_tensor->Shape();
  auto& output_tensor = *context->Output(0, shape);

  const auto* raw_bytes = static_cast<const uint8_t*>(
      static_cast<const void*>(input_tensor->Data<Float8E5M2FNUZ>()));
  const auto element_count = onnxruntime::narrow<size_t>(shape.Size());

  // Float8E5M2FNUZ has exactly one NaN encoding: 1.0000.000, i.e. 0x80.
  EigenMap<bool>(output_tensor) =
      ConstEigenVectorMap<uint8_t>(raw_bytes, element_count).array() == 0x80;

  return Status::OK();
}
#endif
} // namespace onnxruntime
Loading

0 comments on commit efa0cc2

Please sign in to comment.