Add MLFloat16 support for LayerNormalization, SkipLayerNormalization (#22063)

Add `MLFloat16` support for:
- `LayerNormalization`
- `SimplifiedLayerNormalization`
- `SkipLayerNormalization`
- `SkipSimplifiedLayerNormalization`

Existing `LayerNormTest` unit tests already cover the `MLFloat16` path for
`LayerNormalization` once the `MLFloat16` kernels are registered
(for example,
[`LayerNormTest.LayerNorm_Scale_Float16Input`](https://github.com/microsoft/onnxruntime/blob/91c916f9c6263d1c437ec7eda222d7f6b8817175/onnxruntime/test/contrib_ops/layer_norm_op_test.cc#L112)).

Similarly, unit tests such as
[`SkipLayerNormTest.SkipLayerNormBatch1_Float16`](https://github.com/microsoft/onnxruntime/blob/91c916f9c6263d1c437ec7eda222d7f6b8817175/onnxruntime/test/contrib_ops/skiplayernorm_op_test.cc#L255)
cover `MLFloat16` inputs for `SkipLayerNormalization`.
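
As a rough, self-contained sketch of what these tests check (an illustration only: plain `assert`s instead of the real GTest harness, and `LayerNormRef` is a hypothetical reference helper, not ORT code), an `MLFloat16` run is compared against a float reference, with a looser tolerance than the fp32 runs:

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Reference layer normalization over one row, computed in float.
std::vector<float> LayerNormRef(const std::vector<float>& x,
                                const std::vector<float>& scale,
                                const std::vector<float>& bias,
                                float epsilon = 1e-5f) {
  float mean = 0.0f;
  for (float v : x) mean += v;
  mean /= static_cast<float>(x.size());
  float var = 0.0f;
  for (float v : x) var += (v - mean) * (v - mean);
  var /= static_cast<float>(x.size());
  const float inv_std = 1.0f / std::sqrt(var + epsilon);
  std::vector<float> y(x.size());
  for (std::size_t i = 0; i < x.size(); ++i) {
    y[i] = (x[i] - mean) * inv_std * scale[i] + bias[i];
  }
  return y;
}

int main() {
  const std::vector<float> x{1.0f, 2.0f, 3.0f, 4.0f};
  const std::vector<float> scale(4, 1.0f);
  const std::vector<float> bias(4, 0.0f);
  const auto y = LayerNormRef(x, scale, bias);
  // An fp16 kernel's output would be compared against y with a looser
  // tolerance than the one used for fp32 outputs.
  const float inv_std = 1.0f / std::sqrt(1.25f + 1e-5f);
  for (std::size_t i = 0; i < y.size(); ++i) {
    assert(std::fabs(y[i] - (x[i] - 2.5f) * inv_std) < 1e-6f);
  }
  return 0;
}
```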
amarin16 authored Sep 24, 2024
1 parent 6199633 commit eb2506d
Showing 7 changed files with 164 additions and 31 deletions.
10 changes: 5 additions & 5 deletions docs/OperatorKernels.md
@@ -175,8 +175,8 @@ Do not modify directly.*
|||[1, 12]|**T** = tensor(float)|
|LSTM|*in* X:**T**<br> *in* W:**T**<br> *in* R:**T**<br> *in* B:**T**<br> *in* sequence_lens:**T1**<br> *in* initial_h:**T**<br> *in* initial_c:**T**<br> *in* P:**T**<br> *out* Y:**T**<br> *out* Y_h:**T**<br> *out* Y_c:**T**|14+|**T** = tensor(double), tensor(float)<br/> **T1** = tensor(int32)|
|||[7, 13]|**T** = tensor(double), tensor(float)<br/> **T1** = tensor(int32)|
-|LayerNormalization|*in* X:**T**<br> *in* Scale:**T**<br> *in* B:**T**<br> *out* Y:**T**<br> *out* Mean:**U**<br> *out* InvStdDev:**U**<br><br>or<br><br>*in* X:**T**<br> *in* Scale:**V**<br> *in* B:**V**<br> *out* Y:**V**<br> *out* Mean:**U**<br> *out* InvStdDev:**U**|17+|**T** = tensor(double), tensor(float)<br/> **U** = tensor(float)|
-|||[1, 16]|**T** = tensor(double), tensor(float)<br/> **U** = tensor(double), tensor(float)<br/> **V** = tensor(double), tensor(float)|
+|LayerNormalization|*in* X:**T**<br> *in* Scale:**T**<br> *in* B:**T**<br> *out* Y:**T**<br> *out* Mean:**U**<br> *out* InvStdDev:**U**<br><br>or<br><br>*in* X:**T**<br> *in* Scale:**V**<br> *in* B:**V**<br> *out* Y:**V**<br> *out* Mean:**U**<br> *out* InvStdDev:**U**|17+|**T** = tensor(double), tensor(float), tensor(float16)<br/> **U** = tensor(float)|
+|||[1, 16]|**T** = tensor(double), tensor(float), tensor(float16)<br/> **U** = tensor(double), tensor(float), tensor(float16)<br/> **V** = tensor(double), tensor(float), tensor(float16)|
|LeakyRelu|*in* X:**T**<br> *out* Y:**T**|16+|**T** = tensor(float)|
|||[6, 15]|**T** = tensor(float)|
|Less|*in* A:**T**<br> *in* B:**T**<br> *out* C:**T1**|13+|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64)<br/> **T1** = tensor(bool)|
@@ -369,7 +369,7 @@ Do not modify directly.*
|||[6, 12]|**T** = tensor(double), tensor(float)|
|Sign|*in* input:**T**<br> *out* output:**T**|13+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|||[9, 12]|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
-|SimplifiedLayerNormalization|*in* X:**T**<br> *in* scale:**V**<br> *out* Y:**V**<br> *out* inv_std_var:**U**|1+|**T** = tensor(double), tensor(float)<br/> **U** = tensor(double), tensor(float)<br/> **V** = tensor(double), tensor(float)|
+|SimplifiedLayerNormalization|*in* X:**T**<br> *in* scale:**V**<br> *out* Y:**V**<br> *out* inv_std_var:**U**|1+|**T** = tensor(double), tensor(float), tensor(float16)<br/> **U** = tensor(double), tensor(float), tensor(float16)<br/> **V** = tensor(double), tensor(float), tensor(float16)|
|Sin|*in* input:**T**<br> *out* output:**T**|7+|**T** = tensor(double), tensor(float)|
|Sinh|*in* input:**T**<br> *out* output:**T**|9+|**T** = tensor(float)|
|Size|*in* data:**T**<br> *out* size:**T1**|21+|**T** = tensor(bool), tensor(double), tensor(float), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br/> **T1** = tensor(int64)|
@@ -511,8 +511,8 @@ Do not modify directly.*
|RotaryEmbedding|*in* input:**T**<br> *in* position_ids:**M**<br> *in* cos_cache:**T**<br> *in* sin_cache:**T**<br> *out* output:**T**|1+|**M** = tensor(int64)<br/> **T** = tensor(float), tensor(float16)|
|SampleOp|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(float)|
|Sampling|*in* input_ids:**I**<br> *in* max_length:**I**<br> *in* min_length:**I**<br> *in* repetition_penalty:**T**<br> *in* vocab_mask:**I**<br> *in* prefix_vocab_mask:**I**<br> *in* attention_mask:**I**<br> *in* presence_mask:**I**<br> *in* seed:**I**<br> *out* sequences:**I**<br> *out* filtered_logits:**T**|1+|**T** = tensor(float)|
-|SkipLayerNormalization|*in* input:**T**<br> *in* skip:**T**<br> *in* gamma:**T**<br> *in* beta:**T**<br> *in* bias:**T**<br> *out* output:**T**<br> *out* mean:**U**<br> *out* inv_std_var:**U**<br> *out* input_skip_bias_sum:**T**|1+|**T** = tensor(double), tensor(float)|
-|SkipSimplifiedLayerNormalization|*in* input:**T**<br> *in* skip:**T**<br> *in* gamma:**T**<br> *in* bias:**T**<br> *out* output:**T**<br> *out* mean:**U**<br> *out* inv_std_var:**U**<br> *out* input_skip_bias_sum:**T**|1+|**T** = tensor(double), tensor(float)|
+|SkipLayerNormalization|*in* input:**T**<br> *in* skip:**T**<br> *in* gamma:**T**<br> *in* beta:**T**<br> *in* bias:**T**<br> *out* output:**T**<br> *out* mean:**U**<br> *out* inv_std_var:**U**<br> *out* input_skip_bias_sum:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)|
+|SkipSimplifiedLayerNormalization|*in* input:**T**<br> *in* skip:**T**<br> *in* gamma:**T**<br> *in* bias:**T**<br> *out* output:**T**<br> *out* mean:**U**<br> *out* inv_std_var:**U**<br> *out* input_skip_bias_sum:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)|
|SparseAttention|*in* query:**T**<br> *in* key:**T**<br> *in* value:**T**<br> *in* past_key:**T**<br> *in* past_value:**T**<br> *in* block_row_indices:**M**<br> *in* block_col_indices:**M**<br> *in* total_sequence_length:**M**<br> *in* key_total_sequence_lengths:**M**<br> *in* cos_cache:**T**<br> *in* sin_cache:**T**<br> *out* output:**T**<br> *out* present_key:**T**<br> *out* present_value:**T**|1+|**M** = tensor(int32)<br/> **T** = tensor(float)|
|SparseToDenseMatMul|*in* A:**T**<br> *in* B:**T1**<br> *out* Y:**T1**|1+|**T** = sparse_tensor(double), sparse_tensor(float), sparse_tensor(int32), sparse_tensor(int64), sparse_tensor(uint32), sparse_tensor(uint64)<br/> **T1** = tensor(double), tensor(float), tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)|
|Tokenizer|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(string)|
8 changes: 8 additions & 0 deletions onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc
@@ -136,12 +136,16 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSNchwcDomai
// LayerNormalization is now in the ONNX spec. As the contrib op (incorrectly) used kOnnxDomain we need to version it
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 16, float, LayerNormalization);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 16, double, LayerNormalization);
+class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 16, MLFloat16, LayerNormalization);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, float, SimplifiedLayerNormalization);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, double, SimplifiedLayerNormalization);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, MLFloat16, SimplifiedLayerNormalization);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, SkipLayerNormalization);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, double, SkipLayerNormalization);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, MLFloat16, SkipLayerNormalization);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, SkipSimplifiedLayerNormalization);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, double, SkipSimplifiedLayerNormalization);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, MLFloat16, SkipSimplifiedLayerNormalization);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, Inverse);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, Trilu);

@@ -342,12 +346,16 @@ Status RegisterCpuContribKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, Scale)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 16, float, LayerNormalization)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 16, double, LayerNormalization)>,
+BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 16, MLFloat16, LayerNormalization)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, float, SimplifiedLayerNormalization)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, double, SimplifiedLayerNormalization)>,
+BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, MLFloat16, SimplifiedLayerNormalization)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, SkipLayerNormalization)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, double, SkipLayerNormalization)>,
+BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, MLFloat16, SkipLayerNormalization)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, SkipSimplifiedLayerNormalization)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, double, SkipSimplifiedLayerNormalization)>,
+BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, MLFloat16, SkipSimplifiedLayerNormalization)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, Inverse)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, Trilu)>,

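The pattern above is mechanical: each new element type needs one kernel-class declaration plus one `BuildKernelCreateInfo` entry, and the registry then dispatches on the tensor's element type. A toy sketch of that idea (the names and data structures here are invented for illustration and are not ORT internals):

```cpp
#include <cstdint>
#include <functional>
#include <iostream>
#include <map>
#include <string>
#include <utility>

// Toy registry: (op name, element type) -> kernel factory. The real CPU
// registry also keys on domain and opset range and stores KernelCreateInfo.
using Key = std::pair<std::string, std::string>;
std::map<Key, std::function<void()>>& Registry() {
  static std::map<Key, std::function<void()>> registry;
  return registry;
}

template <typename T>
void RegisterSkipLayerNorm(const std::string& type_name) {
  Registry()[{"SkipLayerNormalization", type_name}] = [type_name] {
    std::cout << "creating SkipLayerNormalization<" << type_name << ">\n";
  };
}

int main() {
  // One registration per supported element type, as in the diff above.
  RegisterSkipLayerNorm<float>("float");
  RegisterSkipLayerNorm<double>("double");
  RegisterSkipLayerNorm<uint16_t>("float16");  // stand-in for MLFloat16
  // Dispatch by the element type found in the model graph.
  Registry().at({"SkipLayerNormalization", "float16"})();
  return 0;
}
```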
1 change: 1 addition & 0 deletions onnxruntime/contrib_ops/cpu/layer_norm.cc
@@ -25,6 +25,7 @@ namespace contrib {

REGISTER_CONTRIB_KERNELS(float)
REGISTER_CONTRIB_KERNELS(double)
+REGISTER_CONTRIB_KERNELS(MLFloat16)

} // namespace contrib
} // namespace onnxruntime
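`REGISTER_CONTRIB_KERNELS(MLFloat16)` works because the registration macro stamps out the per-type kernel plumbing once per element type. A hypothetical, simplified shape of such a macro (purely illustrative, not the actual macro in layer_norm.cc, which builds real kernel registrations):

```cpp
// Hypothetical sketch of an X-macro style per-type registration.
template <typename T>
class LayerNormKernel { /* kernel implementation elided */ };

// Each invocation forces instantiation (and, in the real code,
// registration) of the kernel for one element type.
#define REGISTER_CONTRIB_KERNELS(T) template class LayerNormKernel<T>;

REGISTER_CONTRIB_KERNELS(float)
REGISTER_CONTRIB_KERNELS(double)
```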
80 changes: 69 additions & 11 deletions onnxruntime/contrib_ops/cpu/skip_layer_norm.cc
@@ -5,6 +5,7 @@
#include "core/util/math_cpuonly.h"
#include "core/providers/common.h"
#include "core/platform/threadpool.h"
#include "core/util/force_inline.h"
#include "skip_layer_norm.h"
#include "skip_layer_norm_helper.h"

@@ -33,6 +34,50 @@ namespace contrib {

REGISTER_KERNEL_TYPED(float)
REGISTER_KERNEL_TYPED(double)
+REGISTER_KERNEL_TYPED(MLFloat16)
+
+// Utility to convert from MLFloat16 to float only when the input type is MLFloat16.
+template <typename T, typename Ret>
+ORT_FORCEINLINE Ret ConvertMLFloat16ToDoubleOrFloatIfNeeded(T val);
+
+template <>
+ORT_FORCEINLINE float ConvertMLFloat16ToDoubleOrFloatIfNeeded<MLFloat16, float>(MLFloat16 val) {
+return val.ToFloat();
+}
+
+template <>
+ORT_FORCEINLINE double ConvertMLFloat16ToDoubleOrFloatIfNeeded<MLFloat16, double>(MLFloat16 val) {
+return static_cast<double>(ConvertMLFloat16ToDoubleOrFloatIfNeeded<MLFloat16, float>(val));
+}
+
+template <>
+ORT_FORCEINLINE constexpr float ConvertMLFloat16ToDoubleOrFloatIfNeeded<float, float>(float val) {
+return val;
+}
+
+template <>
+ORT_FORCEINLINE constexpr double ConvertMLFloat16ToDoubleOrFloatIfNeeded<double, double>(double val) {
+return val;
+}
+
+// Function template that only converts the input value to MLFloat16 if T is MLFloat16.
+template <typename T>
+ORT_FORCEINLINE constexpr typename std::enable_if_t<std::is_same_v<T, float> || std::is_same_v<T, double>, T>
+ConvertDoubleOrFloatToMLFloat16IfNeeded(T val) {
+return val;
+}
+
+template <typename T>
+ORT_FORCEINLINE constexpr typename std::enable_if_t<std::is_same_v<T, MLFloat16>, T>
+ConvertDoubleOrFloatToMLFloat16IfNeeded(float val) {
+return MLFloat16(val);
+}
+
+template <typename T>
+ORT_FORCEINLINE constexpr typename std::enable_if_t<std::is_same_v<T, MLFloat16>, T>
+ConvertDoubleOrFloatToMLFloat16IfNeeded(double val) {
+return MLFloat16(static_cast<float>(val));
+}

template <typename T, bool simplified>
SkipLayerNorm<T, simplified>::SkipLayerNorm(const OpKernelInfo& op_kernel_info)
@@ -91,21 +136,32 @@ Status SkipLayerNorm<T, simplified>::Compute(OpKernelContext* p_ctx) const {
T* p_output = output_data + offset;
T* p_skip_input_bias_add_output_data = skip_input_bias_add_output_data != nullptr ? skip_input_bias_add_output_data + offset : nullptr;

-T mean = 0;
-T mean_square = 0;
+using DoubleOrFloat = typename std::conditional<
+std::is_same<T, double>::value,  // If T is double
+double,  // Use double
+float  // Otherwise, use float (covers float and MLFloat16)
+>::type;
+
+DoubleOrFloat mean(0.0f);
+DoubleOrFloat mean_square(0.0f);

+std::unique_ptr<DoubleOrFloat[]> output_buffer = std::make_unique<DoubleOrFloat[]>(hidden_size);
+for (size_t h = 0; h < static_cast<size_t>(hidden_size); h++) {
+DoubleOrFloat input_value = ConvertMLFloat16ToDoubleOrFloatIfNeeded<T, DoubleOrFloat>(p_input[h]);
+DoubleOrFloat skip_value = ConvertMLFloat16ToDoubleOrFloatIfNeeded<T, DoubleOrFloat>(p_skip[h]);

-for (int64_t h = 0; h < hidden_size; h++) {
-T value = p_input[h] + p_skip[h];
+DoubleOrFloat value = input_value + skip_value;

if (nullptr != bias_data) {
-value += bias_data[h];
+value += ConvertMLFloat16ToDoubleOrFloatIfNeeded<T, DoubleOrFloat>(bias_data[h]);
}

+output_buffer[h] = value;
+T converted_value = ConvertDoubleOrFloatToMLFloat16IfNeeded<T>(value);
if (nullptr != p_skip_input_bias_add_output_data) {
-p_skip_input_bias_add_output_data[h] = value;
+p_skip_input_bias_add_output_data[h] = converted_value;
}

-p_output[h] = value;
mean += value;
mean_square += value * value;
}
@@ -117,13 +173,15 @@ Status SkipLayerNorm<T, simplified>::Compute(OpKernelContext* p_ctx) const {
mean_square = sqrt(mean_square / hidden_size - mean * mean + epsilon_);
}

-for (int64_t h = 0; h < hidden_size; h++) {
+for (size_t h = 0; h < static_cast<size_t>(hidden_size); h++) {
+DoubleOrFloat gamma_value = ConvertMLFloat16ToDoubleOrFloatIfNeeded<T, DoubleOrFloat>(gamma_data[h]);
if (simplified) {
-p_output[h] = p_output[h] / mean_square * gamma_data[h];
+p_output[h] = ConvertDoubleOrFloatToMLFloat16IfNeeded<T>(output_buffer[h] / mean_square * gamma_value);
} else if (nullptr == beta_data) {
-p_output[h] = (p_output[h] - mean) / mean_square * gamma_data[h];
+p_output[h] = ConvertDoubleOrFloatToMLFloat16IfNeeded<T>((output_buffer[h] - mean) / mean_square * gamma_value);
} else {
-p_output[h] = (p_output[h] - mean) / mean_square * gamma_data[h] + beta_data[h];
+DoubleOrFloat beta_value = ConvertMLFloat16ToDoubleOrFloatIfNeeded<T, DoubleOrFloat>(beta_data[h]);
+p_output[h] = ConvertDoubleOrFloatToMLFloat16IfNeeded<T>((output_buffer[h] - mean) / mean_square * gamma_value + beta_value);
}
}
},
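Taken together, the new Compute path reads: widen each element to float (or double), accumulate the skip/bias sum and the moments in that wider type, and narrow back to `T` only at the output boundary. A minimal, self-contained sketch of the non-simplified path (`Half` here is an illustrative stand-in for `MLFloat16`; the real kernel additionally handles the optional bias and `input_skip_bias_sum` output and processes rows in parallel):

```cpp
#include <cmath>
#include <cstddef>
#include <vector>

// Illustrative stand-in for MLFloat16: models the convert-at-the-boundary
// contract (ToFloat in, construct-from-float out) without real fp16 storage.
struct Half {
  float v;
  explicit Half(float f) : v(f) {}
  float ToFloat() const { return v; }
};

// Skip + layer norm over one row, accumulating in float as the kernel does.
std::vector<Half> SkipLayerNormRow(const std::vector<Half>& input,
                                   const std::vector<Half>& skip,
                                   const std::vector<Half>& gamma,
                                   const std::vector<Half>& beta,
                                   float epsilon) {
  const std::size_t hidden = input.size();
  std::vector<float> buffer(hidden);  // plays the role of output_buffer
  float mean = 0.0f;
  float mean_square = 0.0f;
  for (std::size_t h = 0; h < hidden; ++h) {
    const float value = input[h].ToFloat() + skip[h].ToFloat();
    buffer[h] = value;
    mean += value;
    mean_square += value * value;
  }
  mean /= static_cast<float>(hidden);
  const float inv_std =
      1.0f / std::sqrt(mean_square / static_cast<float>(hidden) - mean * mean + epsilon);
  std::vector<Half> output;
  output.reserve(hidden);
  for (std::size_t h = 0; h < hidden; ++h) {
    const float y = (buffer[h] - mean) * inv_std * gamma[h].ToFloat() + beta[h].ToFloat();
    output.push_back(Half(y));  // narrow back to half only here
  }
  return output;
}
```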
3 changes: 3 additions & 0 deletions onnxruntime/core/providers/cpu/cpu_execution_provider.cc
@@ -903,6 +903,7 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 17, Me
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 17, STFT);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 17, float, LayerNormalization);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 17, double, LayerNormalization);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 17, MLFloat16, LayerNormalization);

// Opset 18
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, 18, float, Resize);
@@ -2465,6 +2466,8 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) {
LayerNormalization)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 17, double,
LayerNormalization)>,
+BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 17, MLFloat16,
+LayerNormalization)>,

// Opset 18
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, 18,
1 change: 1 addition & 0 deletions onnxruntime/core/providers/cpu/nn/layer_norm.cc
@@ -15,5 +15,6 @@ namespace onnxruntime {

REGISTER_ONNX_KERNEL_TYPED(float)
REGISTER_ONNX_KERNEL_TYPED(double)
+REGISTER_ONNX_KERNEL_TYPED(MLFloat16)

} // namespace onnxruntime
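
With these kernels registered, a float16 model can run these ops on CPU without inserting Cast nodes. A hedged sketch of driving that from the C++ API (the model path, the input/output names "X"/"Y", and the shape are hypothetical; exact details vary by ORT version and platform, e.g. `Ort::Session` takes a wide-string path on Windows):

```cpp
#include <onnxruntime_cxx_api.h>

#include <cstdint>
#include <vector>

int main() {
  Ort::Env env(ORT_LOGGING_LEVEL_WARNING, "fp16-layernorm");
  // Hypothetical model with one float16 input "X" and one float16 output "Y".
  Ort::Session session(env, "layer_norm_fp16.onnx", Ort::SessionOptions{});

  std::vector<Ort::Float16_t> x(8);  // zero-filled float16 input for the sketch
  std::vector<int64_t> shape{1, 8};
  auto memory = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
  Ort::Value input = Ort::Value::CreateTensor<Ort::Float16_t>(
      memory, x.data(), x.size(), shape.data(), shape.size());

  const char* input_names[] = {"X"};
  const char* output_names[] = {"Y"};
  auto outputs = session.Run(Ort::RunOptions{nullptr}, input_names, &input, 1,
                             output_names, 1);
  return outputs.size() == 1 ? 0 : 1;
}
```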