diff --git a/onnxruntime/contrib_ops/cpu/quantization/neural_speed_gemm.cc b/onnxruntime/contrib_ops/cpu/quantization/neural_speed_gemm.cc
index 73aaa4ae61a6e..9c8e5717bc9bd 100644
--- a/onnxruntime/contrib_ops/cpu/quantization/neural_speed_gemm.cc
+++ b/onnxruntime/contrib_ops/cpu/quantization/neural_speed_gemm.cc
@@ -436,3 +436,50 @@ void NSSQNBitsGemmBatchPackedB(const size_t M, const size_t N, const size_t K, const size_t BatchN,
     return;
   }
 }
+
+void NSLayerNormalization(size_t norm_count, size_t norm_size, bool isrms, float epsilon, const float* FpIn,
+                          float* FpOut, const float* scale, const float* bias, float* mean_out, float* inv_msq_out,
+                          void* ThreadPool) {
+  auto inorm_count = static_cast<int>(norm_count);
+  auto inorm_size = static_cast<int>(norm_size);
+  ORTThreading orth(ThreadPool);
+  auto pth = &orth;
+  int threads = inorm_count <= 4 ? 1 : pth->num_threads();
+  parallel::Scheduler2D sch({threads, inorm_count, inorm_size, 1, inorm_size});
+  if (threads == 1) {
+    parallel::SingleThread st;
+    st.parallel_for([&](int tidx) {
+      parallel::ThreadProblem2D tp{tidx};
+      sch.getIndex(tp);
+      if (tp.valid) {
+        for (size_t i = 0; i < tp.size[0]; i++) {
+          auto rid = tp.loc[0] + i;
+          auto srcptr = FpIn + rid * inorm_size;
+          auto dstptr = FpOut + rid * inorm_size;
+          auto ret = bestla::kernel::wrapper::LayerNormalization::forward_auto(
+              srcptr, scale, bias, epsilon, inorm_size, dstptr, mean_out ? &mean_out[rid] : nullptr,
+              inv_msq_out ? &inv_msq_out[rid] : nullptr, isrms);
+          (void)(ret);
+          assert(ret == BTLA_CODE::Success);
+        }
+      }
+    });
+  } else {
+    pth->parallel_for([&](int tidx) {
+      parallel::ThreadProblem2D tp{tidx};
+      sch.getIndex(tp);
+      if (tp.valid) {
+        for (size_t i = 0; i < tp.size[0]; i++) {
+          auto rid = tp.loc[0] + i;
+          auto srcptr = FpIn + rid * inorm_size;
+          auto dstptr = FpOut + rid * inorm_size;
+          auto ret = bestla::kernel::wrapper::LayerNormalization::forward_auto(
+              srcptr, scale, bias, epsilon, inorm_size, dstptr, mean_out ? &mean_out[rid] : nullptr,
+              inv_msq_out ? &inv_msq_out[rid] : nullptr, isrms);
+          (void)(ret);
+          assert(ret == BTLA_CODE::Success);
+        }
+      }
+    });
+  }
+}
diff --git a/onnxruntime/contrib_ops/cpu/quantization/neural_speed_gemm.h b/onnxruntime/contrib_ops/cpu/quantization/neural_speed_gemm.h
index ebcb3027a209f..39c30dac6da51 100644
--- a/onnxruntime/contrib_ops/cpu/quantization/neural_speed_gemm.h
+++ b/onnxruntime/contrib_ops/cpu/quantization/neural_speed_gemm.h
@@ -127,3 +127,7 @@ size_t NSSQNBitsGemmBatchWorkspaceSize(const size_t M, const size_t N, const size_t K, const size_t BatchN,
 void NSSQNBitsGemmBatchPackedB(const size_t M, const size_t N, const size_t K, const size_t BatchN,
                                const NS_SQNBITS_GEMM_DATA_PACKED_PARAMS* DataParams, void* WorkSpace,
                                void* ThreadPool = nullptr);
+
+void NSLayerNormalization(size_t norm_count, size_t norm_size, bool isrms, float epsilon, const float* FpIn,
+                          float* FpOut, const float* scale, const float* bias, float* mean_out, float* inv_msq_out,
+                          void* ThreadPool);
\ No newline at end of file
diff --git a/onnxruntime/contrib_ops/cpu/quantization/neural_speed_wrapper.h b/onnxruntime/contrib_ops/cpu/quantization/neural_speed_wrapper.h
index d3902f9bd68c7..b85c7fc72e06e 100644
--- a/onnxruntime/contrib_ops/cpu/quantization/neural_speed_wrapper.h
+++ b/onnxruntime/contrib_ops/cpu/quantization/neural_speed_wrapper.h
@@ -31,6 +31,7 @@
 #include "bestla/bestla_prologue_a.h"
 #include "bestla/bestla_wrapper.h"
+#include "bestla/kernel_wrapper.h"
 
 #if defined(__GNUC__)
 #pragma GCC diagnostic pop
diff --git a/onnxruntime/core/providers/cpu/nn/layer_norm_impl.cc b/onnxruntime/core/providers/cpu/nn/layer_norm_impl.cc
index e01f7f27c3596..b6a11ff573568 100644
--- a/onnxruntime/core/providers/cpu/nn/layer_norm_impl.cc
+++ b/onnxruntime/core/providers/cpu/nn/layer_norm_impl.cc
@@ -8,6 +8,7 @@
 #include "core/platform/threadpool.h"
 #include "core/providers/common.h"
 #include "core/util/math_cpuonly.h"
+#include "contrib_ops/cpu/quantization/neural_speed_gemm.h"
 
 namespace onnxruntime {
 
@@ -73,7 +74,13 @@ Status ComputeImpl(OpKernelContext* p_ctx, int64_t orig_axis, float epsilon, bool simplified) {
   if (inv_std_dev != nullptr) {
     inv_std_dev_data = inv_std_dev->MutableData<U>();
   }
-
+#if defined(ORT_NEURAL_SPEED)
+  if constexpr (std::is_same_v<T, float> && std::is_same_v<U, float>) {
+    NSLayerNormalization(norm_count, norm_size, simplified, epsilon, X_data, Y_data, scale_data, bias_data, mean_data,
+                         inv_std_dev_data, p_ctx->GetOperatorThreadPool());
+    return Status::OK();
+  }
+#endif
   concurrency::ThreadPool::TryBatchParallelFor(
       p_ctx->GetOperatorThreadPool(), static_cast<ptrdiff_t>(norm_count),
       [&](ptrdiff_t task_idx) {
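
If it helps to see the control flow without the bestla `Scheduler2D`/`IThreading` machinery, the sketch below shows the same row-blocking idea in plain C++: keep tiny batches on one thread, otherwise split `norm_count` rows into contiguous blocks and normalize each row independently (RMSNorm when `isrms` is set). This is not part of the patch and does not call the Neural Speed kernels; `normalize_row` and `layer_norm_rows` are made-up names for illustration only.

```cpp
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <thread>
#include <vector>

// Normalize one row of norm_size floats; isrms == true skips mean subtraction (RMSNorm).
static void normalize_row(const float* src, float* dst, const float* scale, const float* bias,
                          std::size_t norm_size, float epsilon, bool isrms) {
  float mean = 0.f;
  if (!isrms) {
    for (std::size_t j = 0; j < norm_size; ++j) mean += src[j];
    mean /= static_cast<float>(norm_size);
  }
  float var = 0.f;
  for (std::size_t j = 0; j < norm_size; ++j) {
    const float d = src[j] - mean;
    var += d * d;
  }
  const float inv_std = 1.f / std::sqrt(var / static_cast<float>(norm_size) + epsilon);
  for (std::size_t j = 0; j < norm_size; ++j) {
    const float v = (src[j] - mean) * inv_std * scale[j];
    dst[j] = bias ? v + bias[j] : v;
  }
}

// Split norm_count rows into contiguous blocks, one per worker, analogous to how
// the patch hands row ranges to thread indices via Scheduler2D.
static void layer_norm_rows(const float* in, float* out, const float* scale, const float* bias,
                            std::size_t norm_count, std::size_t norm_size, float epsilon, bool isrms,
                            unsigned threads) {
  threads = norm_count <= 4 ? 1u : std::max(1u, threads);  // tiny batches stay single-threaded
  const std::size_t per_thread = (norm_count + threads - 1) / threads;
  std::vector<std::thread> pool;
  for (unsigned t = 0; t < threads; ++t) {
    const std::size_t begin = t * per_thread;
    const std::size_t end = std::min(norm_count, begin + per_thread);
    if (begin >= end) break;
    pool.emplace_back([=] {
      for (std::size_t r = begin; r < end; ++r)
        normalize_row(in + r * norm_size, out + r * norm_size, scale, bias, norm_size, epsilon, isrms);
    });
  }
  for (auto& th : pool) th.join();
}
```

The `norm_count <= 4` cutoff mirrors the patch: for a handful of rows, fanning work out to the thread pool costs more than the normalization itself.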