diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md
index 734506681ab60..28228d530ba4a 100644
--- a/docs/OperatorKernels.md
+++ b/docs/OperatorKernels.md
@@ -175,8 +175,8 @@ Do not modify directly.*
|||[1, 12]|**T** = tensor(float)|
|LSTM|*in* X:**T**<br> *in* W:**T**<br> *in* R:**T**<br> *in* B:**T**<br> *in* sequence_lens:**T1**<br> *in* initial_h:**T**<br> *in* initial_c:**T**<br> *in* P:**T**<br> *out* Y:**T**<br> *out* Y_h:**T**<br> *out* Y_c:**T**|14+|**T** = tensor(double), tensor(float)<br> **T1** = tensor(int32)|
|||[7, 13]|**T** = tensor(double), tensor(float)<br> **T1** = tensor(int32)|
-|LayerNormalization|*in* X:**T**<br> *in* Scale:**T**<br> *in* B:**T**<br> *out* Y:**T**<br> *out* Mean:**U**<br> *out* InvStdDev:**U**<br><br>or<br><br>*in* X:**T**<br> *in* Scale:**V**<br> *in* B:**V**<br> *out* Y:**V**<br> *out* Mean:**U**<br> *out* InvStdDev:**U**|17+|**T** = tensor(double), tensor(float)<br> **U** = tensor(float)|
-|||[1, 16]|**T** = tensor(double), tensor(float)<br> **U** = tensor(double), tensor(float)<br> **V** = tensor(double), tensor(float)|
+|LayerNormalization|*in* X:**T**<br> *in* Scale:**T**<br> *in* B:**T**<br> *out* Y:**T**<br> *out* Mean:**U**<br> *out* InvStdDev:**U**<br><br>or<br><br>*in* X:**T**<br> *in* Scale:**V**<br> *in* B:**V**<br> *out* Y:**V**<br> *out* Mean:**U**<br> *out* InvStdDev:**U**|17+|**T** = tensor(double), tensor(float), tensor(float16)<br> **U** = tensor(float)|
+|||[1, 16]|**T** = tensor(double), tensor(float), tensor(float16)<br> **U** = tensor(double), tensor(float), tensor(float16)<br> **V** = tensor(double), tensor(float), tensor(float16)|
|LeakyRelu|*in* X:**T**<br> *out* Y:**T**|16+|**T** = tensor(float)|
|||[6, 15]|**T** = tensor(float)|
|Less|*in* A:**T**<br> *in* B:**T**<br> *out* C:**T1**|13+|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64)<br> **T1** = tensor(bool)|
@@ -369,7 +369,7 @@ Do not modify directly.*
|||[6, 12]|**T** = tensor(double), tensor(float)|
|Sign|*in* input:**T**<br> *out* output:**T**|13+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|||[9, 12]|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
-|SimplifiedLayerNormalization|*in* X:**T**<br> *in* scale:**V**<br> *out* Y:**V**<br> *out* inv_std_var:**U**|1+|**T** = tensor(double), tensor(float)<br> **U** = tensor(double), tensor(float)<br> **V** = tensor(double), tensor(float)|
+|SimplifiedLayerNormalization|*in* X:**T**<br> *in* scale:**V**<br> *out* Y:**V**<br> *out* inv_std_var:**U**|1+|**T** = tensor(double), tensor(float), tensor(float16)<br> **U** = tensor(double), tensor(float), tensor(float16)<br> **V** = tensor(double), tensor(float), tensor(float16)|
|Sin|*in* input:**T**<br> *out* output:**T**|7+|**T** = tensor(double), tensor(float)|
|Sinh|*in* input:**T**<br> *out* output:**T**|9+|**T** = tensor(float)|
|Size|*in* data:**T**<br> *out* size:**T1**|21+|**T** = tensor(bool), tensor(double), tensor(float), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br> **T1** = tensor(int64)|
@@ -511,8 +511,8 @@ Do not modify directly.*
|RotaryEmbedding|*in* input:**T**<br> *in* position_ids:**M**<br> *in* cos_cache:**T**<br> *in* sin_cache:**T**<br> *out* output:**T**|1+|**M** = tensor(int64)<br> **T** = tensor(float), tensor(float16)|
|SampleOp|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(float)|
|Sampling|*in* input_ids:**I**<br> *in* max_length:**I**<br> *in* min_length:**I**<br> *in* repetition_penalty:**T**<br> *in* vocab_mask:**I**<br> *in* prefix_vocab_mask:**I**<br> *in* attention_mask:**I**<br> *in* presence_mask:**I**<br> *in* seed:**I**<br> *out* sequences:**I**<br> *out* filtered_logits:**T**|1+|**T** = tensor(float)|
-|SkipLayerNormalization|*in* input:**T**<br> *in* skip:**T**<br> *in* gamma:**T**<br> *in* beta:**T**<br> *in* bias:**T**<br> *out* output:**T**<br> *out* mean:**U**<br> *out* inv_std_var:**U**<br> *out* input_skip_bias_sum:**T**|1+|**T** = tensor(double), tensor(float)|
-|SkipSimplifiedLayerNormalization|*in* input:**T**<br> *in* skip:**T**<br> *in* gamma:**T**<br> *in* bias:**T**<br> *out* output:**T**<br> *out* mean:**U**<br> *out* inv_std_var:**U**<br> *out* input_skip_bias_sum:**T**|1+|**T** = tensor(double), tensor(float)|
+|SkipLayerNormalization|*in* input:**T**<br> *in* skip:**T**<br> *in* gamma:**T**<br> *in* beta:**T**<br> *in* bias:**T**<br> *out* output:**T**<br> *out* mean:**U**<br> *out* inv_std_var:**U**<br> *out* input_skip_bias_sum:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)|
+|SkipSimplifiedLayerNormalization|*in* input:**T**<br> *in* skip:**T**<br> *in* gamma:**T**<br> *in* bias:**T**<br> *out* output:**T**<br> *out* mean:**U**<br> *out* inv_std_var:**U**<br> *out* input_skip_bias_sum:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)|
|SparseAttention|*in* query:**T**<br> *in* key:**T**<br> *in* value:**T**<br> *in* past_key:**T**<br> *in* past_value:**T**<br> *in* block_row_indices:**M**<br> *in* block_col_indices:**M**<br> *in* total_sequence_length:**M**<br> *in* key_total_sequence_lengths:**M**<br> *in* cos_cache:**T**<br> *in* sin_cache:**T**<br> *out* output:**T**<br> *out* present_key:**T**<br> *out* present_value:**T**|1+|**M** = tensor(int32)<br> **T** = tensor(float)|
|SparseToDenseMatMul|*in* A:**T**<br> *in* B:**T1**<br> *out* Y:**T1**|1+|**T** = sparse_tensor(double), sparse_tensor(float), sparse_tensor(int32), sparse_tensor(int64), sparse_tensor(uint32), sparse_tensor(uint64)<br> **T1** = tensor(double), tensor(float), tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)|
|Tokenizer|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(string)|
diff --git a/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc b/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc
index e75d485830ca5..6ffe861d19931 100644
--- a/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc
+++ b/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc
@@ -136,12 +136,16 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSNchwcDomai
// LayerNormalization is now in the ONNX spec. As the contrib op (incorrectly) used kOnnxDomain we need to version it
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 16, float, LayerNormalization);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 16, double, LayerNormalization);
+class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 16, MLFloat16, LayerNormalization);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, float, SimplifiedLayerNormalization);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, double, SimplifiedLayerNormalization);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, MLFloat16, SimplifiedLayerNormalization);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, SkipLayerNormalization);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, double, SkipLayerNormalization);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, MLFloat16, SkipLayerNormalization);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, SkipSimplifiedLayerNormalization);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, double, SkipSimplifiedLayerNormalization);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, MLFloat16, SkipSimplifiedLayerNormalization);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, Inverse);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, Trilu);
@@ -342,12 +346,16 @@ Status RegisterCpuContribKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 16, float, LayerNormalization)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 16, double, LayerNormalization)>,
+ BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 16, MLFloat16, LayerNormalization)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, float, SimplifiedLayerNormalization)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, double, SimplifiedLayerNormalization)>,
+ BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, MLFloat16, SimplifiedLayerNormalization)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, SkipLayerNormalization)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, double, SkipLayerNormalization)>,
+ BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, MLFloat16, SkipLayerNormalization)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, SkipSimplifiedLayerNormalization)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, double, SkipSimplifiedLayerNormalization)>,
+ BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, MLFloat16, SkipSimplifiedLayerNormalization)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, Inverse)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, Trilu)>,
diff --git a/onnxruntime/contrib_ops/cpu/layer_norm.cc b/onnxruntime/contrib_ops/cpu/layer_norm.cc
index 94f32360bd2f4..c949fcddad093 100644
--- a/onnxruntime/contrib_ops/cpu/layer_norm.cc
+++ b/onnxruntime/contrib_ops/cpu/layer_norm.cc
@@ -25,6 +25,7 @@ namespace contrib {
REGISTER_CONTRIB_KERNELS(float)
REGISTER_CONTRIB_KERNELS(double)
+REGISTER_CONTRIB_KERNELS(MLFloat16)
} // namespace contrib
} // namespace onnxruntime
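The code changes that follow all rely on one pattern: MLFloat16 elements are widened to float (double inputs keep accumulating in double) before any arithmetic, and only narrowed back when results are stored. A minimal, self-contained sketch of that accumulate-in-a-wider-type idea; the `Half` type below is a stand-in for onnxruntime's `MLFloat16`, not the real class:

```cpp
#include <cstddef>
#include <iostream>
#include <type_traits>

// Stand-in for onnxruntime::MLFloat16, only so the sketch compiles on its own;
// a real half type would pack the value into 16 bits.
struct Half {
  float stored;
  explicit Half(float f) : stored(f) {}
  float ToFloat() const { return stored; }
};

// Same accumulation-type choice the patch makes: double stays double,
// float and the half type both accumulate in float.
template <typename T>
using AccumType = std::conditional_t<std::is_same_v<T, double>, double, float>;

template <typename T>
AccumType<T> ToAccum(T v) {
  if constexpr (std::is_same_v<T, Half>) {
    return v.ToFloat();  // widen half to float before doing arithmetic
  } else {
    return v;
  }
}

template <typename T>
AccumType<T> MeanOf(const T* data, std::size_t n) {
  AccumType<T> sum = 0;
  for (std::size_t i = 0; i < n; ++i) {
    sum += ToAccum(data[i]);  // never accumulate in half precision
  }
  return sum / static_cast<AccumType<T>>(n);
}

int main() {
  const Half h[3] = {Half(1.0f), Half(2.0f), Half(3.0f)};
  std::cout << MeanOf(h, 3) << "\n";  // prints 2
}
```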
diff --git a/onnxruntime/contrib_ops/cpu/skip_layer_norm.cc b/onnxruntime/contrib_ops/cpu/skip_layer_norm.cc
index 4e103c2556a7a..faf78cae80ee1 100644
--- a/onnxruntime/contrib_ops/cpu/skip_layer_norm.cc
+++ b/onnxruntime/contrib_ops/cpu/skip_layer_norm.cc
@@ -5,6 +5,7 @@
#include "core/util/math_cpuonly.h"
#include "core/providers/common.h"
#include "core/platform/threadpool.h"
+#include "core/util/force_inline.h"
#include "skip_layer_norm.h"
#include "skip_layer_norm_helper.h"
@@ -33,6 +34,50 @@ namespace contrib {
REGISTER_KERNEL_TYPED(float)
REGISTER_KERNEL_TYPED(double)
+REGISTER_KERNEL_TYPED(MLFloat16)
+
+// Utility to convert from MLFloat16 to float only when the input type is MLFloat16.
+template <typename T, typename Ret>
+ORT_FORCEINLINE Ret ConvertMLFloat16ToDoubleOrFloatIfNeeded(T val);
+
+template <>
+ORT_FORCEINLINE float ConvertMLFloat16ToDoubleOrFloatIfNeeded<MLFloat16, float>(MLFloat16 val) {
+ return val.ToFloat();
+}
+
+template <>
+ORT_FORCEINLINE double ConvertMLFloat16ToDoubleOrFloatIfNeeded<MLFloat16, double>(MLFloat16 val) {
+ return static_cast<double>(ConvertMLFloat16ToDoubleOrFloatIfNeeded<MLFloat16, float>(val));
+}
+
+template <>
+ORT_FORCEINLINE constexpr float ConvertMLFloat16ToDoubleOrFloatIfNeeded<float, float>(float val) {
+ return val;
+}
+
+template <>
+ORT_FORCEINLINE constexpr double ConvertMLFloat16ToDoubleOrFloatIfNeeded<double, double>(double val) {
+ return val;
+}
+
+// Function template that only converts the input value to MLFloat16 if T is MLFloat16.
+template <typename T>
+ORT_FORCEINLINE constexpr typename std::enable_if_t<std::is_same_v<T, float> || std::is_same_v<T, double>, T>
+ConvertDoubleOrFloatToMLFloat16IfNeeded(T val) {
+ return val;
+}
+
+template <typename T>
+ORT_FORCEINLINE constexpr typename std::enable_if_t<std::is_same_v<T, MLFloat16>, T>
+ConvertDoubleOrFloatToMLFloat16IfNeeded(float val) {
+ return MLFloat16(val);
+}
+
+template <typename T>
+ORT_FORCEINLINE constexpr typename std::enable_if_t<std::is_same_v<T, MLFloat16>, T>
+ConvertDoubleOrFloatToMLFloat16IfNeeded(double val) {
+ return MLFloat16(static_cast<float>(val));
+}
template <typename T, bool simplified>
SkipLayerNorm<T, simplified>::SkipLayerNorm(const OpKernelInfo& op_kernel_info)
@@ -91,21 +136,32 @@ Status SkipLayerNorm<T, simplified>::Compute(OpKernelContext* p_ctx) const {
T* p_output = output_data + offset;
T* p_skip_input_bias_add_output_data = skip_input_bias_add_output_data != nullptr ? skip_input_bias_add_output_data + offset : nullptr;
- T mean = 0;
- T mean_square = 0;
+ using DoubleOrFloat = typename std::conditional<
+     std::is_same<T, double>::value,  // If T is double
+     double,                          // Use double
+     float                            // Otherwise, use float (covers float and MLFloat16)
+ >::type;
+
+ DoubleOrFloat mean(0.0f);
+ DoubleOrFloat mean_square(0.0f);
+
+ std::unique_ptr<DoubleOrFloat[]> output_buffer = std::make_unique<DoubleOrFloat[]>(hidden_size);
+ for (size_t h = 0; h < static_cast<size_t>(hidden_size); h++) {
+ DoubleOrFloat input_value = ConvertMLFloat16ToDoubleOrFloatIfNeeded<T, DoubleOrFloat>(p_input[h]);
+ DoubleOrFloat skip_value = ConvertMLFloat16ToDoubleOrFloatIfNeeded<T, DoubleOrFloat>(p_skip[h]);
- for (int64_t h = 0; h < hidden_size; h++) {
- T value = p_input[h] + p_skip[h];
+ DoubleOrFloat value = input_value + skip_value;
if (nullptr != bias_data) {
- value += bias_data[h];
+ value += ConvertMLFloat16ToDoubleOrFloatIfNeeded<T, DoubleOrFloat>(bias_data[h]);
}
+ output_buffer[h] = value;
+ T converted_value = ConvertDoubleOrFloatToMLFloat16IfNeeded<T>(value);
if (nullptr != p_skip_input_bias_add_output_data) {
- p_skip_input_bias_add_output_data[h] = value;
+ p_skip_input_bias_add_output_data[h] = converted_value;
}
- p_output[h] = value;
mean += value;
mean_square += value * value;
}
@@ -117,13 +173,15 @@ Status SkipLayerNorm::Compute(OpKernelContext* p_ctx) const {
mean_square = sqrt(mean_square / hidden_size - mean * mean + epsilon_);
}
- for (int64_t h = 0; h < hidden_size; h++) {
+ for (size_t h = 0; h < static_cast<size_t>(hidden_size); h++) {
+ DoubleOrFloat gamma_value = ConvertMLFloat16ToDoubleOrFloatIfNeeded<T, DoubleOrFloat>(gamma_data[h]);
if (simplified) {
- p_output[h] = p_output[h] / mean_square * gamma_data[h];
+ p_output[h] = ConvertDoubleOrFloatToMLFloat16IfNeeded<T>(output_buffer[h] / mean_square * gamma_value);
} else if (nullptr == beta_data) {
- p_output[h] = (p_output[h] - mean) / mean_square * gamma_data[h];
+ p_output[h] = ConvertDoubleOrFloatToMLFloat16IfNeeded<T>((output_buffer[h] - mean) / mean_square * gamma_value);
} else {
- p_output[h] = (p_output[h] - mean) / mean_square * gamma_data[h] + beta_data[h];
+ DoubleOrFloat beta_value = ConvertMLFloat16ToDoubleOrFloatIfNeeded<T, DoubleOrFloat>(beta_data[h]);
+ p_output[h] = ConvertDoubleOrFloatToMLFloat16IfNeeded<T>((output_buffer[h] - mean) / mean_square * gamma_value + beta_value);
}
}
},
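For reference, the per-row statistics computed above follow the single-pass formulation: accumulate the sums of x and x*x, then take mean = sum(x)/n and use sqrt(sum(x*x)/n - mean*mean + epsilon) as the divisor. A plain-float sketch of that arithmetic (illustrative only, not the kernel code):

```cpp
#include <cmath>
#include <cstdio>

// One pass accumulates E[x] and E[x^2]; the denominator is then
// sqrt(E[x^2] - E[x]^2 + epsilon), matching the mean_square reuse in Compute().
int main() {
  const float x[4] = {1.f, 2.f, 3.f, 4.f};
  const float gamma = 1.f, beta = 0.f, epsilon = 1e-5f;
  float mean = 0.f, mean_square = 0.f;
  for (float v : x) {
    mean += v;
    mean_square += v * v;
  }
  mean /= 4.f;
  mean_square = std::sqrt(mean_square / 4.f - mean * mean + epsilon);
  for (float v : x) {
    std::printf("%f\n", (v - mean) / mean_square * gamma + beta);  // normalized output
  }
}
```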
diff --git a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc
index 7b1b136eb091e..424bee63511ad 100644
--- a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc
+++ b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc
@@ -903,6 +903,7 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 17, Me
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 17, STFT);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 17, float, LayerNormalization);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 17, double, LayerNormalization);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 17, MLFloat16, LayerNormalization);
// Opset 18
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, 18, float, Resize);
@@ -2465,6 +2466,8 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) {
LayerNormalization)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 17, double, LayerNormalization)>,
+ BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 17, MLFloat16, LayerNormalization)>,
// Opset 18
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, 18, float, Resize)>,
+template <typename T, typename Ret>
+ORT_FORCEINLINE Ret ConvertMLFloat16ToDoubleOrFloatIfNeeded(T val);
+
+template <>
+ORT_FORCEINLINE float ConvertMLFloat16ToDoubleOrFloatIfNeeded<MLFloat16, float>(MLFloat16 val) {
+ return val.ToFloat();
+}
+
+template <>
+ORT_FORCEINLINE double ConvertMLFloat16ToDoubleOrFloatIfNeeded<MLFloat16, double>(MLFloat16 val) {
+ return double(ConvertMLFloat16ToDoubleOrFloatIfNeeded<MLFloat16, float>(val));
+}
+
+template <>
+ORT_FORCEINLINE constexpr float ConvertMLFloat16ToDoubleOrFloatIfNeeded<float, float>(float val) {
+ return val;
+}
+
+template <>
+ORT_FORCEINLINE constexpr double ConvertMLFloat16ToDoubleOrFloatIfNeeded<double, double>(double val) {
+ return val;
+}
+
+ORT_FORCEINLINE constexpr float ConvertToFloatIfNeeded(float val) {
+ return val;
+}
+
+ORT_FORCEINLINE constexpr float ConvertToFloatIfNeeded(double val) {
+ // ONNX spec doesn't support 'double' for 'Ret' so when 'T' == double, 'Ret' == float and we need to narrow
+ return gsl::narrow_cast<float>(val);
+}
+
+// Function template that only converts the input value to MLFloat16 if T is MLFloat16.
+template <typename T>
+ORT_FORCEINLINE constexpr typename std::enable_if_t<std::is_same_v<T, float> || std::is_same_v<T, double>, float>
+ConvertToMLFloat16IfNeeded(float val) {
+ return val;
+}
+
+template <typename T>
+ORT_FORCEINLINE constexpr typename std::enable_if_t<std::is_same_v<T, MLFloat16>, MLFloat16>
+ConvertToMLFloat16IfNeeded(float val) {
+ return MLFloat16(val);
+}
+
+template <typename T>
+ORT_FORCEINLINE constexpr double ConvertToMLFloat16IfNeeded(double val) {
+ return val;
+}
+
LayerNormImpl::LayerNormImpl(const OpKernelInfo& op_kernel_info, bool simplified, bool contrib_op)
: OpKernel(op_kernel_info), simplified_{simplified}, contrib_op_{contrib_op} {
ORT_ENFORCE(op_kernel_info.GetAttr("axis", &axis_).IsOK());
@@ -24,14 +76,14 @@ Status ComputeImpl(OpKernelContext* p_ctx, int64_t orig_axis, float epsilon, boo
const Tensor* X = p_ctx->Input<Tensor>(0);
const Tensor* scale = p_ctx->Input<Tensor>(1);
const Tensor* bias = p_ctx->Input<Tensor>(2);
- auto X_data = X->Data<T>();
- auto scale_data = scale->Data<T>();
- auto bias_data = (simplified || nullptr == bias) ? nullptr : bias->Data<T>();
+ const T* X_data = X->Data<T>();
+ const T* scale_data = scale->Data<T>();
+ const T* bias_data = (simplified || nullptr == bias) ? nullptr : bias->Data<T>();
const TensorShape& x_shape = X->Shape();
const int64_t axis = HandleNegativeAxis(orig_axis, x_shape.NumDimensions());
- auto norm_count = x_shape.SizeToDimension(onnxruntime::narrow<size_t>(axis));
- auto norm_size = x_shape.SizeFromDimension(onnxruntime::narrow<size_t>(axis));
+ int64_t norm_count = x_shape.SizeToDimension(onnxruntime::narrow<size_t>(axis));
+ int64_t norm_size = x_shape.SizeFromDimension(onnxruntime::narrow<size_t>(axis));
const auto scale_size = scale->Shape().Size();
const auto bias_size = (bias_data) ? bias->Shape().Size() : 0;
@@ -80,12 +132,19 @@ Status ComputeImpl(OpKernelContext* p_ctx, int64_t orig_axis, float epsilon, boo
const T* p_input = X_data + task_idx * norm_size;
T* p_output = Y_data + task_idx * norm_size;
- T mean = 0;
- T mean_square = 0;
+ using DoubleOrFloat = typename std::conditional<
+     std::is_same<T, double>::value,  // If T is double
+     double,                          // Use double
+     float                            // Otherwise, use float (covers float and MLFloat16)
+ >::type;
+
+ DoubleOrFloat mean(0.0f);
+ DoubleOrFloat mean_square(0.0f);
for (int64_t h = 0; h < norm_size; h++) {
- mean += p_input[h];
- mean_square += p_input[h] * p_input[h];
+ DoubleOrFloat input_value = ConvertMLFloat16ToDoubleOrFloatIfNeeded<T, DoubleOrFloat>(p_input[h]);
+ mean += input_value;
+ mean_square += input_value * input_value;
}
mean = mean / norm_size;
@@ -96,22 +155,25 @@ Status ComputeImpl(OpKernelContext* p_ctx, int64_t orig_axis, float epsilon, boo
}
for (int64_t h = 0; h < norm_size; h++) {
+ DoubleOrFloat input_value = ConvertMLFloat16ToDoubleOrFloatIfNeeded<T, DoubleOrFloat>(p_input[h]);
+ DoubleOrFloat scale_value = ConvertMLFloat16ToDoubleOrFloatIfNeeded<T, DoubleOrFloat>(scale_data[h]);
if (simplified) {
- p_output[h] = p_input[h] / mean_square * scale_data[h];
+ p_output[h] = ConvertToMLFloat16IfNeeded<T>(input_value / mean_square * scale_value);
} else if (nullptr == bias) {
- p_output[h] = (p_input[h] - mean) / mean_square * scale_data[h];
+ p_output[h] = ConvertToMLFloat16IfNeeded<T>((input_value - mean) / mean_square * scale_value);
} else {
- p_output[h] = (p_input[h] - mean) / mean_square * scale_data[h] + bias_data[h];
+ DoubleOrFloat bias_value = ConvertMLFloat16ToDoubleOrFloatIfNeeded<T, DoubleOrFloat>(bias_data[h]);
+ p_output[h] = ConvertToMLFloat16IfNeeded<T>((input_value - mean) / mean_square * scale_value + bias_value);
}
}
if (mean_data != nullptr) {
// ONNX spec doesn't support 'double' for 'U' so when 'T' == double, 'U' == float and we need to narrow
- mean_data[task_idx] = gsl::narrow_cast<U>(mean);
+ mean_data[task_idx] = ConvertToMLFloat16IfNeeded<U>(ConvertToFloatIfNeeded(mean));
}
if (inv_std_dev_data != nullptr) {
- inv_std_dev_data[task_idx] = gsl::narrow_cast<U>(1 / mean_square);
+ inv_std_dev_data[task_idx] = ConvertToMLFloat16IfNeeded<U>(ConvertToFloatIfNeeded(1 / mean_square));
}
},
0);
@@ -141,7 +203,7 @@ struct SrcDispatcher {
Status LayerNormImpl::Compute(OpKernelContext* p_ctx) const {
const auto elem_type = p_ctx->Input<Tensor>(0)->GetElementType();
- using SupportedTypeList = boost::mp11::mp_list<float, double>;
+ using SupportedTypeList = boost::mp11::mp_list<float, double, MLFloat16>;
utils::MLTypeCallDispatcherFromTypeList<SupportedTypeList> t_disp(elem_type);
return t_disp.InvokeRet<Status, SrcDispatcher>(p_ctx, axis_, epsilon_, simplified_, contrib_op_);
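The `SupportedTypeList` change is what actually routes MLFloat16 tensors into the templated implementation at runtime. Below is a heavily reduced sketch of that dispatch-over-a-type-list idea; `TagOf`, `Dispatch`, and `PrintSize` are hypothetical names used for illustration and are not part of ORT's `MLTypeCallDispatcher` API:

```cpp
#include <cstdint>
#include <iostream>

// Stand-in element types and tags; ORT keys dispatch off the tensor's element type.
struct Half { std::uint16_t bits; };

enum class ElemType { kFloat, kDouble, kHalf };

template <typename T> constexpr ElemType TagOf();
template <> constexpr ElemType TagOf<float>() { return ElemType::kFloat; }
template <> constexpr ElemType TagOf<double>() { return ElemType::kDouble; }
template <> constexpr ElemType TagOf<Half>() { return ElemType::kHalf; }

struct PrintSize {
  template <typename T>
  void Invoke() const { std::cout << "element size: " << sizeof(T) << " bytes\n"; }
};

template <typename Fn, typename... Types>
bool Dispatch(ElemType tag, const Fn& fn) {
  // Walk the compile-time type list; call Fn::Invoke<T>() for the first T whose tag matches.
  return ((tag == TagOf<Types>() && (fn.template Invoke<Types>(), true)) || ...);
}

int main() {
  // Supporting a new element type amounts to appending it to this list,
  // which is what the SupportedTypeList change does for MLFloat16.
  Dispatch<PrintSize, float, double, Half>(ElemType::kHalf, PrintSize{});
}
```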