From 6df167648c023d0b4a3ecb1508d2480cfc345331 Mon Sep 17 00:00:00 2001 From: wejoncy Date: Thu, 26 Sep 2024 20:15:21 +0800 Subject: [PATCH] optimize qlinearsoftmax --- cmake/onnxruntime_mlas.cmake | 11 + .../cpu/quantization/qlinear_softmax.cc | 150 +-------- onnxruntime/core/mlas/inc/mlas.h | 15 + onnxruntime/core/mlas/lib/mlasi.h | 30 ++ onnxruntime/core/mlas/lib/platform.cpp | 7 +- onnxruntime/core/mlas/lib/qsoftmax.cpp | 210 +++++++++++++ .../core/mlas/lib/qsoftmax_kernel_avx2.cpp | 267 ++++++++++++++++ .../core/mlas/lib/qsoftmax_kernel_avx512.cpp | 293 ++++++++++++++++++ .../core/mlas/lib/qsoftmax_kernel_naive.cpp | 106 +++++++ .../test/mlas/bench/bench_qsoftmax.cpp | 38 +++ 10 files changed, 982 insertions(+), 145 deletions(-) create mode 100644 onnxruntime/core/mlas/lib/qsoftmax.cpp create mode 100644 onnxruntime/core/mlas/lib/qsoftmax_kernel_avx2.cpp create mode 100644 onnxruntime/core/mlas/lib/qsoftmax_kernel_avx512.cpp create mode 100644 onnxruntime/core/mlas/lib/qsoftmax_kernel_naive.cpp create mode 100644 onnxruntime/test/mlas/bench/bench_qsoftmax.cpp diff --git a/cmake/onnxruntime_mlas.cmake b/cmake/onnxruntime_mlas.cmake index 20bb1fb772189..17ba6b7e35fbd 100644 --- a/cmake/onnxruntime_mlas.cmake +++ b/cmake/onnxruntime_mlas.cmake @@ -41,6 +41,8 @@ onnxruntime_add_static_library(onnxruntime_mlas ${MLAS_SRC_DIR}/sqnbitgemm_q8_block.h ${MLAS_SRC_DIR}/flashattn.cpp ${MLAS_SRC_DIR}/cast.cpp + ${MLAS_SRC_DIR}/qsoftmax.cpp + ${MLAS_SRC_DIR}/qsoftmax_kernel_naive.cpp ) target_sources(onnxruntime_mlas PRIVATE @@ -163,6 +165,10 @@ function(setup_mlas_source_for_windows) file(GLOB_RECURSE mlas_platform_srcs_avx2 CONFIGURE_DEPENDS "${MLAS_SRC_DIR}/intrinsics/avx2/*.cpp" ) + set(mlas_platform_srcs_avx2 + ${mlas_platform_srcs_avx2} + "${MLAS_SRC_DIR}/qsoftmax_kernel_avx2.cpp" + ) set_source_files_properties(${mlas_platform_srcs_avx2} PROPERTIES COMPILE_FLAGS "/arch:AVX2") target_sources(onnxruntime_mlas PRIVATE @@ -171,6 +177,7 @@ function(setup_mlas_source_for_windows) ${mlas_platform_srcs_avx2} ${MLAS_SRC_DIR}/qgemm_kernel_amx.cpp ${MLAS_SRC_DIR}/qgemm_kernel_avx2.cpp + ${MLAS_SRC_DIR}/qsoftmax_kernel_avx512.cpp ${MLAS_SRC_DIR}/qgemm_kernel_sse.cpp ${MLAS_SRC_DIR}/qgemm_kernel_sse41.cpp ${MLAS_SRC_DIR}/intrinsics/avx512/quantize_avx512f.cpp @@ -570,6 +577,7 @@ else() ${MLAS_SRC_DIR}/intrinsics/avx2/qladd_avx2.cpp ${MLAS_SRC_DIR}/intrinsics/avx2/qdwconv_avx2.cpp ${MLAS_SRC_DIR}/sqnbitgemm_kernel_avx2.cpp + ${MLAS_SRC_DIR}/qsoftmax_kernel_avx2.cpp ) if(CMAKE_CXX_COMPILER_VERSION GREATER_EQUAL 13.1 AND NOT(APPLE)) set(mlas_platform_srcs_avx2 @@ -610,6 +618,7 @@ endif() set(mlas_platform_srcs_avx512vnni ${MLAS_SRC_DIR}/sqnbitgemm_kernel_avx512vnni.cpp + ${MLAS_SRC_DIR}/qsoftmax_kernel_avx512.cpp ) set_source_files_properties(${mlas_platform_srcs_avx512vnni} PROPERTIES COMPILE_FLAGS "-mfma -mavx512vnni -mavx512bw -mavx512dq -mavx512vl -mavx512f") @@ -619,6 +628,7 @@ endif() ${MLAS_SRC_DIR}/dgemm.cpp ${MLAS_SRC_DIR}/pooling_fp16.cpp ${MLAS_SRC_DIR}/qgemm_kernel_avx2.cpp + ${mlas_platform_srcs_sse2} ${mlas_platform_srcs_avx} ${mlas_platform_srcs_avx2} @@ -643,6 +653,7 @@ endif() ) set_source_files_properties(${MLAS_SRC_DIR}/qgemm_kernel_amx.cpp PROPERTIES COMPILE_FLAGS "-mavx2 -mavx512bw -mavx512dq -mavx512vl -mavx512f") set_source_files_properties(${MLAS_SRC_DIR}/x86_64/QgemmU8S8KernelAmx.S PROPERTIES COMPILE_FLAGS "-mavx2 -mavx512bw -mavx512dq -mavx512vl -mavx512f") + set_source_files_properties(${MLAS_SRC_DIR}/qsoftmax_kernel_avx2.cpp PROPERTIES COMPILE_FLAGS "-mavx2") endif() if(ONNXRUNTIME_MLAS_MULTI_ARCH) diff --git a/onnxruntime/contrib_ops/cpu/quantization/qlinear_softmax.cc b/onnxruntime/contrib_ops/cpu/quantization/qlinear_softmax.cc index de1798e54874f..bc381839859b3 100644 --- a/onnxruntime/contrib_ops/cpu/quantization/qlinear_softmax.cc +++ b/onnxruntime/contrib_ops/cpu/quantization/qlinear_softmax.cc @@ -36,7 +36,7 @@ void QlinearBuildLookupTableUint32(gsl::span tabl for (int32_t i = 0; i < 256; i++) { double scaled_exp_xi = exp((static_cast(i) - 255 + bit_shift) * static_cast(x_scale)); // we can't get the real max value of input tensor here, so we just assume 255-bit_shift. - // in the function of `QlinearSoftmaxCPU`, + // in the function of `MlasComputeQSoftmax`, // all numbers will have a shift (255-bit_shift-max_value) if its max value is not 255 // // if is_signed index = [1 2 3 ......126 127 -128 -127 ..... -3 -2 -1] @@ -123,136 +123,6 @@ Status QLinearSoftmax::Compute(OpKernelContext* ctx) const { } } -template -common::Status QlinearSoftmaxCPU(size_t N, - size_t D, - const T* x_data, - T* y_data, - const QLinearSoftmax::EXP_OUT_DTYPE* lookup_table, - QLinearSoftmax::EXP_OUT_DTYPE y_scale, - T yzp, - onnxruntime::concurrency::ThreadPool* thread_pool); - -template <> -common::Status QlinearSoftmaxCPU(size_t N, - size_t D, - const uint8_t* x_data, - uint8_t* y_data, - const QLinearSoftmax::EXP_OUT_DTYPE* lookup_table, - QLinearSoftmax::EXP_OUT_DTYPE y_scale, - uint8_t yzp, - onnxruntime::concurrency::ThreadPool* thread_pool) { - using onnxruntime::TensorOpCost; - using onnxruntime::concurrency::ThreadPool; - ThreadPool::TryParallelFor( - thread_pool, N, - // Read 3*N (max,sum,div) write N (div), computation=Read - TensorOpCost{static_cast(D) * 3.0, - static_cast(D), - static_cast(D) * 3.0}, - [x_data, y_data, D, y_scale, yzp, &lookup_table](std::ptrdiff_t first, std::ptrdiff_t last) { - const auto c_y_scale = y_scale; - const auto c_y_zp = yzp; - const uint8_t* x_t = x_data + first * D; - uint8_t* y_t = y_data + first * D; - for (; first < last; first++) { - // reduceMaxUint8 - uint8_t xmax = *std::max_element(x_t, x_t + D); - // we want the xmas to align with 255 for higher precision. - // as we build a lookup table with X-255. So we could use the adjustment here - // to let all numbers have a shift in the lookup table. - // 1 2 3 4 5 ...........................254 255 - // 1 3 5 ... 10 - // after the shift ---> - // 235 237 239 .. 255 - const QLinearSoftmax::EXP_OUT_DTYPE* shifted_lookuptable = lookup_table + 255 - xmax; - size_t elements_n = D; - // reduceSumUin8ToUint32: need speedup - // vsum = \sum_i{e^x_i} - QLinearSoftmax::EXP_OUT_DTYPE vsum = 0; - const uint8_t* x_t_cur = x_t; - do { - const size_t vx = *x_t_cur++; - vsum += shifted_lookuptable[vx]; - } while (--elements_n != 0); - if (vsum == 0) { - return; - } - elements_n = D; - x_t_cur = x_t; - // elementwise div, y_i=\frac{x_i}{vsum} - do { - const size_t vx = *x_t_cur++; - const QLinearSoftmax::EXP_OUT_DTYPE vt = shifted_lookuptable[vx]; - // simulate round function, and re-quant to uint8 - const uint32_t vq = static_cast(std::nearbyintf(((vt * c_y_scale)) / vsum)) + c_y_zp; - const uint8_t vy = vq > 255 ? static_cast(255) : static_cast(vq); - *y_t++ = vy; - } while (--elements_n != 0); - x_t = x_t_cur; - } - }); - - return Status::OK(); -} - -template <> -common::Status QlinearSoftmaxCPU(size_t N, - size_t D, - const int8_t* x_data, - int8_t* y_data, - const QLinearSoftmax::EXP_OUT_DTYPE* lookup_table, - QLinearSoftmax::EXP_OUT_DTYPE y_scale, - int8_t yzp, - onnxruntime::concurrency::ThreadPool* thread_pool) { - using onnxruntime::TensorOpCost; - using onnxruntime::concurrency::ThreadPool; - ThreadPool::TryParallelFor( - thread_pool, N, - // Read 3*N (max,sum,div) write N (div), computation=Read - TensorOpCost{static_cast(D) * 3.0, - static_cast(D), - static_cast(D) * 3.0}, - [x_data, y_data, D, y_scale, yzp, &lookup_table](std::ptrdiff_t first, std::ptrdiff_t last) { - const auto c_y_scale = y_scale; - const auto c_y_zp = yzp; - - const int8_t* x_t = x_data + first * D; - int8_t* y_t = y_data + first * D; - for (; first < last; first++) { - // reduceMaxInt8 - int8_t xmax = *std::max_element(x_t, x_t + D); - const int32_t adjustment = int32_t(127) - xmax; - const QLinearSoftmax::EXP_OUT_DTYPE* shifted_lookuptable = lookup_table; - size_t elements_n = D; - // reduceSumUin8ToUint32: need speedup - QLinearSoftmax::EXP_OUT_DTYPE vsum = 0; - const int8_t* x_t_cur = x_t; - do { - const uint8_t vx = uint8_t(adjustment + (*x_t_cur++)); - vsum += shifted_lookuptable[vx]; - } while (--elements_n != 0); - if (vsum == 0) { - return; - } - elements_n = D; - x_t_cur = x_t; - // elementwise div - do { - const uint8_t vx = uint8_t(adjustment + (*x_t_cur++)); - const QLinearSoftmax::EXP_OUT_DTYPE vt = shifted_lookuptable[vx]; - // simulate round function, and re-quant to Int8 - const int32_t vq = static_cast(std::nearbyintf(((vt * c_y_scale)) / vsum)) + c_y_zp; - const int8_t vy = static_cast(vq) > 255 ? static_cast(255) : static_cast(vq); - *y_t++ = vy; - } while (--elements_n != 0); - x_t = x_t_cur; - } - }); - - return Status::OK(); -} - gsl::span QLinearSoftmax::GetLookupTable( OpKernelContext* context, gsl::span lookup_table_span, @@ -270,25 +140,17 @@ gsl::span QLinearSoftmax::GetLookupTable( Status QLinearSoftmax::ComputeInternal(OpKernelContext* context, const Tensor& input, Tensor& output, gsl::span lookup_table, int axis, concurrency::ThreadPool* thread_pool) const { + const auto* X_scale_tensor = context->Input(1); const auto* Y_scale_tensor = context->Input(3); const auto* Y_zp_tensor = context->Input(4); const QLinearSoftmax::EXP_OUT_DTYPE Y_scale = std::floor(1.0F / (*(Y_scale_tensor->Data()))); const auto& X_shape = input.Shape(); const size_t N = onnxruntime::narrow(X_shape.SizeToDimension(onnxruntime::narrow(axis))); const size_t D = onnxruntime::narrow(X_shape.SizeFromDimension(onnxruntime::narrow(axis))); - common::Status status; - if (is_signed_) { - using T = int8_t; - const T Y_zp = Y_zp_tensor ? *(Y_zp_tensor->Data()) : 0; - status = QlinearSoftmaxCPU(N, D, input.Data(), output.MutableData(), - lookup_table.data(), Y_scale, Y_zp, thread_pool); - } else { - using T = uint8_t; - const T Y_zp = Y_zp_tensor ? *(Y_zp_tensor->Data()) : 0; - status = QlinearSoftmaxCPU(N, D, input.Data(), output.MutableData(), - lookup_table.data(), Y_scale, Y_zp, thread_pool); - } - return status; + const int Y_zp = Y_zp_tensor ? (is_signed_ ? *(Y_zp_tensor->Data()) : *(Y_zp_tensor->Data())) : 0; + MlasComputeQSoftmax(input.DataRaw(), output.MutableDataRaw(), N, D, lookup_table.data(), + *X_scale_tensor->Data(), Y_scale, Y_zp, is_signed_, thread_pool); + return Status::OK(); } // opset-13 and above diff --git a/onnxruntime/core/mlas/inc/mlas.h b/onnxruntime/core/mlas/inc/mlas.h index 28ae64c4d5b3e..1eb9d87a9546a 100644 --- a/onnxruntime/core/mlas/inc/mlas.h +++ b/onnxruntime/core/mlas/inc/mlas.h @@ -1018,6 +1018,21 @@ MlasComputeSoftmax( MLAS_THREADPOOL* ThreadPool ); +void +MLASCALL +MlasComputeQSoftmax( + const void* Input, + void* Output, + size_t N, + size_t D, + const float* LoopupTable, + float X_Scale, + float Scale, + int ZeroPoint, + bool is_signed, + MLAS_THREADPOOL* ThreadPool +); + void MLASCALL MlasComputeTanh( diff --git a/onnxruntime/core/mlas/lib/mlasi.h b/onnxruntime/core/mlas/lib/mlasi.h index 13ea8d96c20e4..aa0454f3aed19 100644 --- a/onnxruntime/core/mlas/lib/mlasi.h +++ b/onnxruntime/core/mlas/lib/mlasi.h @@ -56,6 +56,7 @@ Module Name: #if defined(__GNUC__) && __GNUC__ >= 12 #pragma GCC diagnostic push #pragma GCC diagnostic ignored "-Wmaybe-uninitialized" // GCC 12 warns about uninitialized variables in immintrin.h. +#pragma GCC diagnostic ignored "-Wuninitialized" // GCC 12 warns about uninitialized variables in immintrin.h. #include #pragma GCC diagnostic pop #else @@ -711,6 +712,26 @@ void float Scale, int8_t ZeroPoint); +typedef +void (MLASCALL MLAS_QUANTIZE_SOFTMAX_I8_KERNEL)( + size_t D, + const int8_t* Xdata, + int8_t* Ydata, + const float* LookupTable, + float Yscale, + int8_t YZeroPoint, + float* Buff); + +typedef +void (MLASCALL MLAS_QUANTIZE_SOFTMAX_U8_KERNEL)( + size_t D, + const uint8_t* Xdata, + uint8_t* Ydata, + const float* LookupTable, + float Yscale, + uint8_t YZeroPoint, + float* Buff); + template struct MLAS_QUANT_KERNEL { @@ -876,7 +897,13 @@ extern "C" { MLAS_QLINEAR_BINARY_OP_U8_KERNEL MlasQLinearAddU8KernelAvx2; MLAS_QUANTIZE_LINEAR_S8_KERNEL MlasQuantizeLinearS8KernelAvx512F; MLAS_QUANTIZE_LINEAR_U8_KERNEL MlasQuantizeLinearU8KernelAvx512F; + MLAS_QUANTIZE_SOFTMAX_I8_KERNEL MlasQuantizeSoftmaxI8KernelAvx2; + MLAS_QUANTIZE_SOFTMAX_I8_KERNEL MlasQuantizeSoftmaxI8KernelAvx512; + MLAS_QUANTIZE_SOFTMAX_U8_KERNEL MlasQuantizeSoftmaxU8KernelAvx2; + MLAS_QUANTIZE_SOFTMAX_U8_KERNEL MlasQuantizeSoftmaxU8KernelAvx512; #endif + MLAS_QUANTIZE_SOFTMAX_I8_KERNEL MlasQuantizeSoftmaxI8KernelNaive; + MLAS_QUANTIZE_SOFTMAX_U8_KERNEL MlasQuantizeSoftmaxU8KernelNaive; MLAS_REDUCE_MAXIMUM_FLOAT_KERNEL MlasReduceMaximumF32Kernel; MLAS_REDUCE_MINIMUM_MAXIMUM_FLOAT_KERNEL MlasReduceMinimumMaximumF32Kernel; @@ -1188,6 +1215,9 @@ struct MLAS_PLATFORM { MLAS_CAST_F16_TO_F32_KERNEL* CastF16ToF32Kernel; MLAS_CAST_F32_TO_F16_KERNEL* CastF32ToF16Kernel; + + MLAS_QUANTIZE_SOFTMAX_I8_KERNEL *QuantizeSoftmaxI8Kernel; + MLAS_QUANTIZE_SOFTMAX_U8_KERNEL *QuantizeSoftmaxU8Kernel; }; inline diff --git a/onnxruntime/core/mlas/lib/platform.cpp b/onnxruntime/core/mlas/lib/platform.cpp index 23d29fd02fa5a..c0ae041701646 100644 --- a/onnxruntime/core/mlas/lib/platform.cpp +++ b/onnxruntime/core/mlas/lib/platform.cpp @@ -246,6 +246,8 @@ Return Value: this->ConvDepthwiseS8U8Kernel = MlasConvDepthwiseKernel; this->CastF16ToF32Kernel = nullptr; this->CastF32ToF16Kernel = nullptr; + this->QuantizeSoftmaxI8Kernel = MlasQuantizeSoftmaxI8KernelNaive; + this->QuantizeSoftmaxU8Kernel = MlasQuantizeSoftmaxU8KernelNaive; #if defined(MLAS_TARGET_AMD64_IX86) @@ -258,7 +260,6 @@ Return Value: this->GemmU8U8Dispatch = &MlasGemmU8X8DispatchSse; #if defined(MLAS_TARGET_AMD64) - this->TransposePackB16x4Routine = MlasSgemmTransposePackB16x4Sse; this->GemmDoubleKernel = MlasGemmDoubleKernelSse; this->ConvNchwFloatKernel = MlasConvNchwFloatKernelSse; @@ -391,6 +392,8 @@ Return Value: this->CastF16ToF32Kernel = &MlasCastF16ToF32KernelAvx2; this->CastF32ToF16Kernel = &MlasCastF32ToF16KernelAvx2; + this->QuantizeSoftmaxI8Kernel = MlasQuantizeSoftmaxI8KernelAvx2; + this->QuantizeSoftmaxU8Kernel = MlasQuantizeSoftmaxU8KernelAvx2; // // Check if the processor supports Hybrid core architecture. @@ -460,6 +463,8 @@ Return Value: this->FpQ4GemmDispatch = &MlasFpQ4GemmDispatchAvx512; this->SQNBitGemmDispatch = &MlasSQNBitGemmDispatchAvx512; + this->QuantizeSoftmaxI8Kernel = MlasQuantizeSoftmaxI8KernelAvx512; + this->QuantizeSoftmaxU8Kernel = MlasQuantizeSoftmaxU8KernelAvx512; // // Check if the processor supports AVX512VNNI. // diff --git a/onnxruntime/core/mlas/lib/qsoftmax.cpp b/onnxruntime/core/mlas/lib/qsoftmax.cpp new file mode 100644 index 0000000000000..ba81fc6b3d0b9 --- /dev/null +++ b/onnxruntime/core/mlas/lib/qsoftmax.cpp @@ -0,0 +1,210 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + qsoftmax.cpp + +Abstract: + + This module implements miscellaneous computation routines. + + Our usage requires building platform specific versions of the algorithm to + target different instruction sets. The implementation below targets the + base instruction set (typically SSE2) while assembly implementations target + newer instruction sets (such as FMA3). + +--*/ + +#include "mlasi.h" + +struct MLAS_QSOFTMAX_WORK_BLOCK { + const void* Input; + void* Output; + size_t N; + size_t D; + const float* LoopupTable; + float Scale; + int ZeroPoint; + size_t ThreadCountN; + bool is_signed; +}; + +static void BuildLookupTable(gsl::span table, + const float x_scale, + size_t reduce_len, bool is_signed) { + // make sure sum(exp(x)) < max() + double bit_shift = log(std::numeric_limits::max() / reduce_len); + double reserve_bit = std::is_same_v ? 5 : 3; + bit_shift = std::max(0.0, bit_shift - reserve_bit) / x_scale; + + for (int32_t i = 0; i < 256; i++) { + double scaled_exp_xi = exp((static_cast(i) - 255 + bit_shift) * static_cast(x_scale)); + // we can't get the real max value of input tensor here, so we just assume 255-bit_shift. + // in the function of `MlasComputeQSoftmax`, + // all numbers will have a shift (255-bit_shift-max_value) if its max value is not 255 + // + // if is_signed index = [1 2 3 ......126 127 -128 -127 ..... -3 -2 -1] + // else [0 1 2 3 4 ..... 256] + uint8_t index = static_cast(is_signed ? i - 128 : i); + table[index] = static_cast((scaled_exp_xi)); + } +} + +void MlasComputeQSoftmaxThreaded(void* Context, ptrdiff_t Index) +/*++ + +Routine Description: + + This routine is invoked from a worker thread to execute a segment of a + softmax or log softmax operation. + +Arguments: + + Context - Supplies the pointer to the context for the threaded operation. + + ThreadId - Supplies the current index of the threaded operation. + +Return Value: + + None. + +--*/ +{ + static MLAS_QUANTIZE_SOFTMAX_I8_KERNEL* Ikernel = GetMlasPlatform().QuantizeSoftmaxI8Kernel; + static MLAS_QUANTIZE_SOFTMAX_U8_KERNEL* Ukernel = GetMlasPlatform().QuantizeSoftmaxU8Kernel; + + const auto* WorkBlock = (MLAS_QSOFTMAX_WORK_BLOCK*)Context; + + // + // Partition the operation along the N dimension. + // + + size_t n; + size_t CountN; + + MlasPartitionWork(Index, WorkBlock->ThreadCountN, WorkBlock->N, &n, &CountN); + size_t packBSize = (WorkBlock->D * sizeof(float) + ThreadedBufAlignment - 1) / ThreadedBufAlignment; + packBSize *= ThreadedBufAlignment; + + MlasThreadedBufAlloc(packBSize); + + float* temp_buff = reinterpret_cast(ThreadedBufHolder.get()); + + // + // Compute the softmax or log softmax function. + // + + const size_t D = WorkBlock->D; + const float Scale = WorkBlock->Scale; + const int ZeroPoint = WorkBlock->ZeroPoint; + const float* LoopupTable = WorkBlock->LoopupTable; + + const int8_t* Input = reinterpret_cast(WorkBlock->Input) + n * D; + int8_t* Output = reinterpret_cast(WorkBlock->Output) + n * D; + +#if defined(MLAS_SSE2_INTRINSICS) + // TODO: Use std::hardware_constructive_interference_size + constexpr size_t CacheLineSize = 64; + constexpr size_t ElementsPerCacheLine = CacheLineSize / sizeof(float); +#endif + + while (CountN > 0) { +#if defined(MLAS_SSE2_INTRINSICS) + // + // Prefetch the next row of the input buffer. + // + + for (size_t i = 0; i * ElementsPerCacheLine < D; i++) { + _mm_prefetch((char*)(Input + D) + i * CacheLineSize, _MM_HINT_T0); + } +#endif + if (WorkBlock->is_signed) { + Ikernel(D, (Input), Output, LoopupTable, Scale, static_cast(ZeroPoint), temp_buff); + } else { + Ukernel(D, reinterpret_cast(Input), reinterpret_cast(Output), LoopupTable, Scale, + static_cast(ZeroPoint), temp_buff); + } + + Input += D; + Output += D; + CountN--; + } +} + +void MLASCALL MlasComputeQSoftmax(const void* Input, void* Output, size_t N, size_t D, const float* LoopupTable, + float X_Scale, float Scale, int ZeroPoint, bool is_signed, MLAS_THREADPOOL* ThreadPool) +/*++ + +Routine Description: + + This routine computes the quantized softmax function. + + N.B. This implementation supports in place updates of the output buffer. + +Arguments: + + Input - Supplies the input buffer. + + Output - Supplies the output buffer. + + N - Supplies the number of rows to process. + + D - Supplies the number of columns per row to process. + + LoopupTable - Supplies lookup exp values. + + Scale - quantization params. + ZeroPoint - quantization params. + is_signed - int8 or uint8. + + ThreadPool - Supplies the thread pool object to use, else nullptr if the + base library threading support should be used. + +Return Value: + + None. + +--*/ +{ + MLAS_QSOFTMAX_WORK_BLOCK WorkBlock; + + // + // Capture the softmax parameters to the work block. + // + + WorkBlock.LoopupTable = LoopupTable; + WorkBlock.Scale = Scale; + WorkBlock.ZeroPoint = ZeroPoint; + WorkBlock.Input = Input; + WorkBlock.Output = Output; + WorkBlock.N = N; + WorkBlock.D = D; + WorkBlock.is_signed = is_signed; + + // shared by all threads + std::vector lookup_table(256); + if (WorkBlock.LoopupTable == nullptr) { + BuildLookupTable(lookup_table, X_Scale, D, is_signed); + WorkBlock.LoopupTable = lookup_table.data(); + } + // + // Compute the number of target threads given the complexity of the softmax + // operation. Limit the number of threads to the number of rows and try to + // keep each thread processing a minimum number of elements before using + // another thread. + // + + ptrdiff_t ThreadCountN = MlasGetMaximumThreadCount(ThreadPool); + + if (size_t(ThreadCountN) > N) { + ThreadCountN = ptrdiff_t(N); + } + + WorkBlock.ThreadCountN = ThreadCountN; + + MlasExecuteThreaded(MlasComputeQSoftmaxThreaded, &WorkBlock, ThreadCountN, ThreadPool); +} diff --git a/onnxruntime/core/mlas/lib/qsoftmax_kernel_avx2.cpp b/onnxruntime/core/mlas/lib/qsoftmax_kernel_avx2.cpp new file mode 100644 index 0000000000000..cea4ca396c6f7 --- /dev/null +++ b/onnxruntime/core/mlas/lib/qsoftmax_kernel_avx2.cpp @@ -0,0 +1,267 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + qsoftmax_kernel_avx2.cpp + +Abstract: + + This module implements the float/quantized n-bit integer matrix + multiplication kernels for x64 avx2. + +--*/ + +#include "mlasi.h" + +static uint8_t reduce_max_u8_avx2(const uint8_t* data, size_t size) { + __m256i max_vec = _mm256_set1_epi8(0); + + size_t i; + for (i = 0; i + 32 <= size; i += 32) { + __m256i vec = _mm256_loadu_si256((const __m256i*)(data + i)); + max_vec = _mm256_max_epu8(max_vec, vec); + } + + // Now reduce the 256-bit max_vec into a single max value + // First, split the 256-bit vector into two 128-bit halves and compute the max + // between them + __m128i max_128 = _mm_max_epu8(_mm256_castsi256_si128(max_vec), _mm256_extracti128_si256(max_vec, 1)); + + // Further reduce the 128-bit vector to a scalar + // Extract the upper 64-bit part and compute the max + max_128 = _mm_max_epu8(max_128, _mm_srli_si128(max_128, 8)); + // Extract the upper 32-bit part and compute the max + max_128 = _mm_max_epu8(max_128, _mm_srli_si128(max_128, 4)); + // Extract the upper 16-bit part and compute the max + max_128 = _mm_max_epu8(max_128, _mm_srli_si128(max_128, 2)); + // Extract the upper 8-bit part and compute the max + max_128 = _mm_max_epu8(max_128, _mm_srli_si128(max_128, 1)); + + // Extract the final max value + uint8_t max_value = static_cast(_mm_extract_epi8(max_128, 0)); + + for (; i < size; ++i) { + if (data[i] > max_value) { + max_value = data[i]; + } + } + + return max_value; +} + +int8_t reduce_max_i8_avx2(const int8_t* data, size_t size) { + __m256i max_vec = _mm256_set1_epi8(INT8_MIN); + + size_t i; + for (i = 0; i + 32 <= size; i += 32) { + __m256i vec = _mm256_loadu_si256((const __m256i*)(data + i)); + max_vec = _mm256_max_epi8(max_vec, vec); + } + + int8_t remaining_max = INT8_MIN; + for (; i < size; ++i) { + if (data[i] > remaining_max) { + remaining_max = data[i]; + } + } + + alignas(32) int8_t max_arr[32]; + _mm256_storeu_si256((__m256i*)max_arr, max_vec); + + int8_t max_value = max_arr[0]; + for (size_t j = 1; j < 32; ++j) { + if (max_arr[j] > max_value) { + max_value = max_arr[j]; + } + } + + if (remaining_max > max_value) { + max_value = remaining_max; + } + + return max_value; +} + +MLAS_FORCEINLINE +__m128i MlasFloatToI8Avx2(const __m256 float_val1, const __m256 float_val2) { + __m256 rounded_val1 = _mm256_round_ps(float_val1, _MM_FROUND_TO_NEAREST_INT); + __m256 rounded_val2 = _mm256_round_ps(float_val2, _MM_FROUND_TO_NEAREST_INT); + __m256i int_vec1 = _mm256_cvtps_epi32(rounded_val1); + __m256i int_vec2 = _mm256_cvtps_epi32(rounded_val2); + + __m256i packed16_1 = _mm256_packs_epi32(int_vec1, int_vec2); + __m128i packed8 = _mm_packs_epi16(_mm256_castsi256_si128(packed16_1), _mm256_extracti128_si256(packed16_1, 1)); + __m128i lanefix = _mm_castps_si128(_mm_permutevar_ps(_mm_castsi128_ps(packed8), _mm_setr_epi32(0, 2, 1, 3))); + return lanefix; +} + +MLAS_FORCEINLINE +__m128i MlasFloatToU8Avx2(const __m256 float_val1, const __m256 float_val2) { + __m256 rounded_val1 = _mm256_round_ps(float_val1, _MM_FROUND_TO_NEAREST_INT); + __m256 rounded_val2 = _mm256_round_ps(float_val2, _MM_FROUND_TO_NEAREST_INT); + __m256i int_vec1 = _mm256_cvtps_epi32(rounded_val1); + __m256i int_vec2 = _mm256_cvtps_epi32(rounded_val2); + + __m256i packed16_1 = _mm256_packus_epi32(int_vec1, int_vec2); + __m128i packed8 = _mm_packus_epi16(_mm256_castsi256_si128(packed16_1), _mm256_extracti128_si256(packed16_1, 1)); + __m128i lanefix = _mm_castps_si128(_mm_permutevar_ps(_mm_castsi128_ps(packed8), _mm_setr_epi32(0, 2, 1, 3))); + return lanefix; +} + +float exp_and_sum_i8_avx2(const float* base_addr, const int8_t* indice, size_t size, int32_t adjustment, + float* temp_out) { + __m256 sum = _mm256_setzero_ps(); + __m128i broadcast_adjustment = _mm_set1_epi8(static_cast(adjustment)); + //======================reduce sum start========================= + size_t i; + for (i = 0; i + 16 <= size; i += 16) { + __m128i index_ori = _mm_loadu_si128((const __m128i*)(indice + i)); + __m128i index = _mm_add_epi8(index_ori, broadcast_adjustment); + + __m256i vec32_low = _mm256_cvtepu8_epi32(index); + __m256 gathered = _mm256_i32gather_ps(base_addr, vec32_low, 4); + sum = _mm256_add_ps(sum, gathered); + _mm256_storeu_ps(&temp_out[i], gathered); + + __m128i vec8_high = _mm_srli_si128(index, 8); + __m256i vec32_high = _mm256_cvtepu8_epi32(vec8_high); + gathered = _mm256_i32gather_ps(base_addr, vec32_high, 4); + sum = _mm256_add_ps(sum, gathered); + _mm256_storeu_ps(&temp_out[i + 8], gathered); + } + float partial_sum = 0; + for (; i < size; ++i) { + float data = base_addr[uint8_t(indice[i] + adjustment)]; + partial_sum += data; + temp_out[i] = data; + } + alignas(32) float results[8]; + _mm256_store_ps(results, sum); + float total_sum = partial_sum; + for (size_t j = 0; j < 8; ++j) { + total_sum += results[j]; + } + return total_sum; +} + +float exp_and_sum_u8_avx2(const float* base_addr, const uint8_t* indice, size_t size, int32_t, float* temp_out) { + __m256 sum = _mm256_setzero_ps(); + //======================reduce sum start========================= + size_t i; + for (i = 0; i + 16 <= size; i += 16) { + __m128i index = _mm_loadu_si128((const __m128i*)(indice + i)); + __m256i vec32_low = _mm256_cvtepu8_epi32(index); + __m256 gathered = _mm256_i32gather_ps(base_addr, vec32_low, 4); + sum = _mm256_add_ps(sum, gathered); + _mm256_storeu_ps(&temp_out[i], gathered); + + __m128i vec8_high = _mm_srli_si128(index, 8); + __m256i vec32_high = _mm256_cvtepu8_epi32(vec8_high); + gathered = _mm256_i32gather_ps(base_addr, vec32_high, 4); + sum = _mm256_add_ps(sum, gathered); + _mm256_storeu_ps(&temp_out[i + 8], gathered); + } + float partial_sum = 0; + for (; i < size; ++i) { + float data = base_addr[indice[i]]; + partial_sum += data; + temp_out[i] = data; + } + alignas(32) float results[8]; + _mm256_store_ps(results, sum); + float total_sum = partial_sum; + for (size_t j = 0; j < 8; ++j) { + total_sum += results[j]; + } + return total_sum; +} + +int32_t normalize_sum_avx2(float total_sum, size_t size, float x_scale, float* temp_out, float yzp, int8_t* output) { + size_t i; + + //======================m scale d sum p zero start========================= + float inverse_sum = 1.0f / total_sum; + float scale = inverse_sum * x_scale; + __m256 broadcast_scale = _mm256_broadcast_ss(&scale); + __m256 broadcast_zp = _mm256_broadcast_ss(&yzp); + + // div sum + for (i = 0; i + 16 <= size; i += 16) { + __m256 vec1 = _mm256_loadu_ps(&temp_out[i]); + __m256 product1 = _mm256_mul_ps(vec1, broadcast_scale); + __m256 fma_result1 = _mm256_add_ps(product1, broadcast_zp); + + __m256 vec2 = _mm256_loadu_ps(&temp_out[i + 8]); + __m256 product2 = _mm256_mul_ps(vec2, broadcast_scale); + __m256 fma_result2 = _mm256_add_ps(product2, broadcast_zp); + + __m128i packed8 = MlasFloatToI8Avx2(fma_result1, fma_result2); + + _mm_storeu_si128((__m128i*)&output[i], packed8); + } + + constexpr uint8_t max_u8 = 255; + for (; i < size; ++i) { + int v = int32_t(std::nearbyintf(temp_out[i] * scale + yzp)); + output[i] = v > max_u8 ? static_cast(max_u8) : static_cast(v); + } + + return 0; +} + +int32_t normalize_sum_avx2(float total_sum, size_t size, float x_scale, float* temp_out, float yzp, uint8_t* output) { + size_t i; + + //======================m scale d sum p zero start========================= + float inverse_sum = 1.0f / total_sum; + float scale = inverse_sum * x_scale; + __m256 broadcast_scale = _mm256_broadcast_ss(&scale); + __m256 broadcast_zp = _mm256_broadcast_ss(&yzp); + + // div sum + for (i = 0; i + 16 <= size; i += 16) { + __m256 vec1 = _mm256_loadu_ps(&temp_out[i]); + __m256 product1 = _mm256_mul_ps(vec1, broadcast_scale); + __m256 fma_result1 = _mm256_add_ps(product1, broadcast_zp); + + __m256 vec2 = _mm256_loadu_ps(&temp_out[i + 8]); + __m256 product2 = _mm256_mul_ps(vec2, broadcast_scale); + __m256 fma_result2 = _mm256_add_ps(product2, broadcast_zp); + + __m128i packed8 = MlasFloatToU8Avx2(fma_result1, fma_result2); + + _mm_storeu_si128((__m128i*)&output[i], packed8); + } + constexpr uint8_t max_u8 = 255; + for (; i < size; ++i) { + int v = int32_t(std::nearbyintf(temp_out[i] * scale + yzp)); + output[i] = v > max_u8 ? max_u8 : static_cast(v); + } + + return 0; +} + +// compute softmax for a row with D elements +void MlasQuantizeSoftmaxI8KernelAvx2(size_t D, const int8_t* x_data, int8_t* y_data, const float* lookup_table, + float y_scale, int8_t yzp, float* tempaddr) { + constexpr int i = 0; + int32_t xmax = reduce_max_i8_avx2(x_data + i * D, D); + const int32_t adjustment = int32_t(127) - xmax; + const float* shifted_lookuptable = lookup_table; + float total_sum = exp_and_sum_i8_avx2(shifted_lookuptable, x_data + i * D, D, adjustment, (float*)tempaddr); + normalize_sum_avx2(total_sum, D, y_scale, (float*)tempaddr, yzp, y_data + i * D); +} + +void MlasQuantizeSoftmaxU8KernelAvx2(size_t D, const uint8_t* x_data, uint8_t* y_data, const float* lookup_table, + float y_scale, uint8_t yzp, float* tempaddr) { + constexpr int i = 0; + int32_t xmax = reduce_max_u8_avx2(x_data + i * D, D); + const int32_t adjustment = int32_t(255) - xmax; + const float* shifted_lookuptable = lookup_table + adjustment; + float total_sum = exp_and_sum_u8_avx2(shifted_lookuptable, x_data + i * D, D, adjustment, (float*)tempaddr); + normalize_sum_avx2(total_sum, D, y_scale, (float*)tempaddr, yzp, y_data + i * D); +} diff --git a/onnxruntime/core/mlas/lib/qsoftmax_kernel_avx512.cpp b/onnxruntime/core/mlas/lib/qsoftmax_kernel_avx512.cpp new file mode 100644 index 0000000000000..be83a147b844c --- /dev/null +++ b/onnxruntime/core/mlas/lib/qsoftmax_kernel_avx512.cpp @@ -0,0 +1,293 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + qsoftmax_kernel_avx512.cpp.h + +Abstract: + + This module implements the float/quantized n-bit integer matrix + multiplication kernels for x64 avx2. + +--*/ +#include "mlas.h" +#include "mlasi.h" + +uint8_t reduce_max_u8_avx512(const uint8_t* data, size_t size) { + // Initialize max value to the smallest possible uint8_t (0) + __m512i max_val = _mm512_set1_epi8(0); // Set the initial max value to 0 for unsigned + size_t i; + // Process data in chunks of 64 bytes (512 bits, which is 64 * 8-bit integers) + for (i = 0; i + 64 < size; i += 64) { + // Load 64 bytes into a 512-bit register + __m512i vec = _mm512_loadu_si512((__m512i*)&data[i]); + + // Compute the maximum values + max_val = _mm512_max_epu8(max_val, vec); // Use unsigned comparison + } + + // Reduce the final max_val to find the maximum element + // Extract the upper 256 bits and compare with the lower 256 bits + __m256i max256 = _mm512_extracti64x4_epi64(max_val, 1); // Extract upper 256 bits + max256 = _mm256_max_epu8(max256, + _mm512_castsi512_si256(max_val)); // Compare upper 256 with lower 256 + + // Further reduce 256-bit value + __m128i max128 = _mm256_extracti128_si256(max256, 1); // Extract upper 128 bits + max128 = _mm_max_epu8(max128, + _mm256_castsi256_si128(max256)); // Compare upper 128 with lower 128 + + // Further reduce 128-bit value + max128 = _mm_max_epu8(max128, _mm_srli_si128(max128, 8)); // Compare first 8 bytes with second 8 bytes + max128 = _mm_max_epu8(max128, _mm_srli_si128(max128, 4)); // Further reduce + max128 = _mm_max_epu8(max128, _mm_srli_si128(max128, 2)); // Further reduce + max128 = _mm_max_epu8(max128, _mm_srli_si128(max128, 1)); // Final reduce + + // The maximum value is now in the first byte of max128 + uint8_t max_value = static_cast(_mm_extract_epi8(max128, 0)); // Extract the maximum value + + for (; i < size; ++i) { + if (data[i] > max_value) { + max_value = data[i]; + } + } + + return max_value; +} + +int8_t reduce_max_i8_avx512(const int8_t* data, size_t size) { + size_t i; + __m512i max_val = _mm512_set1_epi8(INT8_MIN); // Start with the minimum signed value + + // Process data in chunks of 64 bytes (512 bits, which is 64 * 8-bit integers) + for (i = 0; i + 64 < size; i += 64) { + // Load 64 bytes into a 512-bit register + __m512i vec = _mm512_loadu_si512((__m512i*)&data[i]); + + // Compute the maximum values + max_val = _mm512_max_epi8(max_val, vec); + } + + // Reduce the final max_val to find the maximum element + // Extract the upper 256 bits and compare with the lower 256 bits + __m256i max256 = _mm512_extracti64x4_epi64(max_val, 1); // Extract upper 256 bits + max256 = _mm256_max_epi8(max256, + _mm512_castsi512_si256(max_val)); // Compare upper 256 with lower 256 + + // Further reduce 256-bit value + __m128i max128 = _mm256_extracti128_si256(max256, 1); // Extract upper 128 bits + max128 = _mm_max_epi8(max128, + _mm256_castsi256_si128(max256)); // Compare upper 128 with lower 128 + + // Further reduce 128-bit value + max128 = _mm_max_epi8(max128, _mm_srli_si128(max128, 8)); // Compare first 8 bytes with second 8 bytes + max128 = _mm_max_epi8(max128, _mm_srli_si128(max128, 4)); // Further reduce + max128 = _mm_max_epi8(max128, _mm_srli_si128(max128, 2)); // Further reduce + max128 = _mm_max_epi8(max128, _mm_srli_si128(max128, 1)); // Final reduce + + int8_t sc_max_value = static_cast(_mm_extract_epi8(max128, 0)); + for (; i < size; ++i) { + if (data[i] > sc_max_value) { + sc_max_value = data[i]; + } + } + + return sc_max_value; +} + +__m128i convert_float_to_u8_avx512bw(__m512 float_vals) { + // Apply rounding + __m512 rounded_vals = _mm512_roundscale_ps(float_vals, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC); + + // Convert float to int + __m512i int_vals = _mm512_cvttps_epi32(rounded_vals); + + __m256i f256 = _mm512_extracti64x4_epi64(int_vals, 0); + __m256i s256 = _mm512_extracti64x4_epi64(int_vals, 1); + + __m256i packed16_1 = _mm256_packus_epi32(f256, s256); + __m128i packed8 = _mm_packus_epi16(_mm256_castsi256_si128(packed16_1), _mm256_extracti128_si256(packed16_1, 1)); + __m128i lanefix = _mm_castps_si128(_mm_permutevar_ps(_mm_castsi128_ps(packed8), _mm_setr_epi32(0, 2, 1, 3))); + return lanefix; + // _mm_storeu_si128((__m128i*)&uint8_result[i], lanefix); +} + +__m128i convert_float_to_i8_avx512bw(__m512 float_vals) { + // Apply rounding + __m512 rounded_vals = _mm512_roundscale_ps(float_vals, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC); + + // Convert float to int + __m512i int_vals = _mm512_cvttps_epi32(rounded_vals); + + __m256i f256 = _mm512_extracti64x4_epi64(int_vals, 0); + __m256i s256 = _mm512_extracti64x4_epi64(int_vals, 1); + + __m256i packed16_1 = _mm256_packs_epi32(f256, s256); + __m128i packed8 = _mm_packs_epi16(_mm256_castsi256_si128(packed16_1), _mm256_extracti128_si256(packed16_1, 1)); + __m128i lanefix = _mm_castps_si128(_mm_permutevar_ps(_mm_castsi128_ps(packed8), _mm_setr_epi32(0, 2, 1, 3))); + return lanefix; + // _mm_storeu_si128((__m128i*)&uint8_result[i], lanefix); +} + +float exp_and_sum_i8_avx512(const float* base_addr, const int8_t* indice, size_t size, int32_t adjustment, + float* temp_out) { + __m512 sum = _mm512_setzero_ps(); + __m256i broadcast_adjustment = _mm256_set1_epi8(static_cast(adjustment)); + + size_t i = 0; + for (; i + 32 <= size; i += 32) { + __m256i index_ori = _mm256_loadu_si256((__m256i*)&indice[i]); + __m256i index = _mm256_add_epi8(index_ori, broadcast_adjustment); + // Extract the lower 128 bits (first half) + __m128i index_low = _mm256_extracti128_si256(index, 0); + __m512i vec32_low = _mm512_cvtepu8_epi32(index_low); + __m512 gathered_data = _mm512_i32gather_ps(vec32_low, base_addr, 4); + sum = _mm512_add_ps(sum, gathered_data); + _mm512_storeu_ps(&temp_out[i], gathered_data); + + // Extract the upper 128 bits (second half) + __m128i index_high = _mm256_extracti128_si256(index, 1); + __m512i vec32_high = _mm512_cvtepu8_epi32(index_high); + gathered_data = _mm512_i32gather_ps(vec32_high, base_addr, 4); + sum = _mm512_add_ps(sum, gathered_data); + _mm512_storeu_ps(&temp_out[i + 16], gathered_data); + } + // Reduce sum to a scalar value + // Use shuffle and add to accumulate the result within the 512-bit register + __m512 shuf = _mm512_shuffle_f32x4(sum, sum, 0b11110101); // Swap 128-bit halves + sum = _mm512_add_ps(sum, shuf); // Add swapped halves + + shuf = _mm512_shuffle_f32x4(sum, sum, 0b01001110); // Further shuffle within 128-bit lanes + sum = _mm512_add_ps(sum, shuf); // Add + + // Now reduce within the 128-bit lanes + shuf = _mm512_shuffle_ps(sum, sum, 0b10110001); // Swap pairs of elements + sum = _mm512_add_ps(sum, shuf); // Add + + shuf = _mm512_shuffle_ps(sum, sum, 0b01001110); // Further shuffle pairs + sum = _mm512_add_ps(sum, shuf); // Add + + float total = _mm_cvtss_f32(_mm512_castps512_ps128(sum)); + for (; i < size; ++i) { + float v = base_addr[uint8_t(indice[i] + adjustment)]; + temp_out[i] = v; + total += v; + } + + return total; +} + +float exp_and_sum_u8_avx512(const float* base_addr, const uint8_t* indice, size_t size, int32_t, float* temp_out) { + __m512 sum = _mm512_setzero_ps(); + + size_t i = 0; + for (; i + 32 <= size; i += 32) { + __m256i index = _mm256_loadu_si256((__m256i*)&indice[i]); + // Extract the lower 128 bits (first half) + __m128i index_low = _mm256_extracti128_si256(index, 0); + __m512i vec32_low = _mm512_cvtepu8_epi32(index_low); + __m512 gathered_data = _mm512_i32gather_ps(vec32_low, base_addr, 4); + sum = _mm512_add_ps(sum, gathered_data); + _mm512_storeu_ps(&temp_out[i], gathered_data); + + // Extract the upper 128 bits (second half) + __m128i index_high = _mm256_extracti128_si256(index, 1); + __m512i vec32_high = _mm512_cvtepu8_epi32(index_high); + gathered_data = _mm512_i32gather_ps(vec32_high, base_addr, 4); + sum = _mm512_add_ps(sum, gathered_data); + _mm512_storeu_ps(&temp_out[i + 16], gathered_data); + } + // Reduce sum to a scalar value + // Use shuffle and add to accumulate the result within the 512-bit register + __m512 shuf = _mm512_shuffle_f32x4(sum, sum, 0b11110101); // Swap 128-bit halves + sum = _mm512_add_ps(sum, shuf); // Add swapped halves + + shuf = _mm512_shuffle_f32x4(sum, sum, 0b01001110); // Further shuffle within 128-bit lanes + sum = _mm512_add_ps(sum, shuf); // Add + + // Now reduce within the 128-bit lanes + shuf = _mm512_shuffle_ps(sum, sum, 0b10110001); // Swap pairs of elements + sum = _mm512_add_ps(sum, shuf); // Add + + shuf = _mm512_shuffle_ps(sum, sum, 0b01001110); // Further shuffle pairs + sum = _mm512_add_ps(sum, shuf); // Add + + float total = _mm_cvtss_f32(_mm512_castps512_ps128(sum)); + for (; i < size; ++i) { + float v = base_addr[indice[i]]; + temp_out[i] = v; + total += v; + } + + return total; +} + +int32_t normalize_sum_avx512(float total_sum, size_t size, float x_scale, float* temp_out, float yzp, uint8_t* output) { + size_t i = 0; + float inverse_sum = 1.0f / total_sum; + float scale = inverse_sum * x_scale; + __m512 broadcast_scale = _mm512_set1_ps(scale); + __m512 broadcast_zp = _mm512_set1_ps(yzp); + + for (i = 0; i + 16 <= size; i += 16) { + __m512 a_vec = _mm512_loadu_ps(&temp_out[i]); + __m512 result_vec = _mm512_fmadd_ps(a_vec, broadcast_scale, broadcast_zp); + // __m512i fma_result1 = _mm512_cvtt_roundps_epi32(result_vec, + // _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC); + __m128i packed8 = convert_float_to_u8_avx512bw(result_vec); + _mm_storeu_si128((__m128i*)&output[i], packed8); + } + constexpr uint8_t max_u8 = 255; + for (; i < size; ++i) { + int v = int32_t(std::nearbyintf(temp_out[i] * scale + yzp)); + output[i] = v > max_u8 ? max_u8 : static_cast(v); + } + return 0; +} + +int32_t normalize_sum_avx512(float total_sum, size_t size, float x_scale, float* temp_out, float yzp, int8_t* output) { + size_t i = 0; + float inverse_sum = 1.0f / total_sum; + float scale = inverse_sum * x_scale; + __m512 broadcast_scale = _mm512_set1_ps(scale); + __m512 broadcast_zp = _mm512_set1_ps(yzp); + + for (i = 0; i + 16 <= size; i += 16) { + __m512 a_vec = _mm512_loadu_ps(&temp_out[i]); + __m512 result_vec = _mm512_fmadd_ps(a_vec, broadcast_scale, broadcast_zp); + // __m512i fma_result1 = _mm512_cvtt_roundps_epi32(result_vec, + // _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC); + __m128i packed8 = convert_float_to_i8_avx512bw(result_vec); + _mm_storeu_si128((__m128i*)&output[i], packed8); + } + constexpr uint8_t max_u8 = 255; + for (; i < size; ++i) { + int v = int32_t(std::nearbyintf(temp_out[i] * scale + yzp)); + output[i] = v > max_u8 ? static_cast(max_u8) : static_cast(v); + } + return 0; +} + +void MlasQuantizeSoftmaxI8KernelAvx512(size_t D, const int8_t* x_data, int8_t* y_data, + const float* lookup_table, float y_scale, int8_t yzp, float* tempaddr) { + constexpr int i = 0; + int32_t xmax = reduce_max_i8_avx512(x_data + i * D, D); + const int32_t adjustment = int32_t(127) - xmax; + const float* shifted_lookuptable = lookup_table; + float total_sum = exp_and_sum_i8_avx512(shifted_lookuptable, x_data + i * D, D, adjustment, (float*)tempaddr); + normalize_sum_avx512(total_sum, D, y_scale, (float*)tempaddr, yzp, y_data + i * D); +} + +void MlasQuantizeSoftmaxU8KernelAvx512(size_t D, const uint8_t* x_data, uint8_t* y_data, + const float* lookup_table, float y_scale, uint8_t yzp, float* tempaddr) { + constexpr int i = 0; + int32_t xmax = reduce_max_u8_avx512(x_data + i * D, D); + const int32_t adjustment = int32_t(255) - xmax; + const float* shifted_lookuptable = lookup_table + adjustment; + float total_sum = exp_and_sum_u8_avx512(shifted_lookuptable, x_data + i * D, D, adjustment, (float*)tempaddr); + normalize_sum_avx512(total_sum, D, y_scale, (float*)tempaddr, yzp, y_data + i * D); +} diff --git a/onnxruntime/core/mlas/lib/qsoftmax_kernel_naive.cpp b/onnxruntime/core/mlas/lib/qsoftmax_kernel_naive.cpp new file mode 100644 index 0000000000000..379f82edcbe04 --- /dev/null +++ b/onnxruntime/core/mlas/lib/qsoftmax_kernel_naive.cpp @@ -0,0 +1,106 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + qsoftmax_kernel_avx2.cpp + +Abstract: + + This module implements the float/quantized n-bit integer matrix + multiplication kernels for x64 avx2. + +--*/ + +#include +#include +#include + +#include "mlasi.h" + +void MlasQuantizeSoftmaxU8KernelNaive(size_t D, const uint8_t* x_data, uint8_t* y_data, const float* lookup_table, + float y_scale, uint8_t yzp, float*) { + constexpr size_t N = 1; + const auto c_y_scale = y_scale; + const auto c_y_zp = yzp; + const uint8_t* x_t = x_data + 0 * D; + uint8_t* y_t = y_data + 0 * D; + for (size_t first = 0; first < N; first++) { + // reduceMaxUint8 + uint8_t xmax = *std::max_element(x_t, x_t + D); + // we want the xmas to align with 255 for higher precision. + // as we build a lookup table with X-255. So we could use the adjustment here + // to let all numbers have a shift in the lookup table. + // 1 2 3 4 5 ...........................254 255 + // 1 3 5 ... 10 + // after the shift ---> + // 235 237 239 .. 255 + const float* shifted_lookuptable = lookup_table + 255 - xmax; + size_t elements_n = D; + // reduceSumUin8ToUint32: need speedup + // vsum = \sum_i{e^x_i} + float vsum = 0; + const uint8_t* x_t_cur = x_t; + do { + const size_t vx = *x_t_cur++; + vsum += shifted_lookuptable[vx]; + } while (--elements_n != 0); + if (vsum == 0) { + return; + } + elements_n = D; + x_t_cur = x_t; + // elementwise div, y_i=\frac{x_i}{vsum} + do { + const size_t vx = *x_t_cur++; + const float vt = shifted_lookuptable[vx]; + // simulate round function, and re-quant to uint8 + const uint32_t vq = static_cast(std::nearbyintf((vt * c_y_scale) / vsum)) + c_y_zp; + const uint8_t vy = vq > 255 ? static_cast(255) : static_cast(vq); + *y_t++ = vy; + } while (--elements_n != 0); + x_t = x_t_cur; + } +} + +void MlasQuantizeSoftmaxI8KernelNaive(size_t D, const int8_t* x_data, int8_t* y_data, const float* lookup_table, + float y_scale, int8_t yzp, float*) { + constexpr size_t N = 1; + const auto c_y_scale = y_scale; + const auto c_y_zp = yzp; + size_t first = 0; + const int8_t* x_t = x_data + first * D; + int8_t* y_t = y_data + first * D; + for (; first < N; first++) { + // reduceMaxInt8 + int8_t xmax = *std::max_element(x_t, x_t + D); + const int32_t adjustment = int32_t(127) - xmax; + const float* shifted_lookuptable = lookup_table; + size_t elements_n = D; + // reduceSumUin8ToUint32: need speedup + float vsum = 0; + const int8_t* x_t_cur = x_t; + do { + const uint8_t vx = uint8_t(adjustment + (*x_t_cur++)); + vsum += shifted_lookuptable[vx]; + } while (--elements_n != 0); + if (vsum == 0) { + return; + } + elements_n = D; + x_t_cur = x_t; + // elementwise div + do { + const uint8_t vx = uint8_t(adjustment + (*x_t_cur++)); + const float vt = shifted_lookuptable[vx]; + // simulate round function, and re-quant to Int8 + const int32_t vq = static_cast(std::nearbyintf(((vt * c_y_scale)) / vsum)) + c_y_zp; + const int8_t vy = static_cast(vq) > 255 ? static_cast(255) : static_cast(vq); + *y_t++ = vy; + } while (--elements_n != 0); + x_t = x_t_cur; + } +} diff --git a/onnxruntime/test/mlas/bench/bench_qsoftmax.cpp b/onnxruntime/test/mlas/bench/bench_qsoftmax.cpp new file mode 100644 index 0000000000000..0d6717457b6bf --- /dev/null +++ b/onnxruntime/test/mlas/bench/bench_qsoftmax.cpp @@ -0,0 +1,38 @@ +#include "bench_util.h" +#include "core/mlas/lib/mlasi.h" + +static const std::vector qsoftmax_bench_arg_names = {"N", "D", "is_signed"}; +//(const void* Input, void* Output, size_t N, size_t D, const float* LoopupTable,float Scale, int ZeroPoint, bool is_signed, MLAS_THREADPOOL* ThreadPool); +void BM_Qsoftmax(benchmark::State& state) { + size_t N = static_cast(state.range(0)); + size_t D = static_cast(state.range(1)); + bool is_signed = static_cast(state.range(0)); + const size_t count = N * D; + auto src = RandomVectorUniform(count, 0, 127); + auto LoopupTable = RandomVectorUniform(count, 0.f, 1e10); + auto dst = std::vector(count + 16); + auto aligned_dst = (reinterpret_cast(dst.data()) + 15) & (~15); + int8_t* dst_start = aligned ? reinterpret_cast(aligned_dst) + : reinterpret_cast(aligned_dst + 1); + + OrtThreadPoolParams tpo; + tpo.thread_pool_size = 8; + tpo.auto_set_affinity = true; + std::unique_ptr tp( + onnxruntime::concurrency::CreateThreadPool(&onnxruntime::Env::Default(), + tpo, onnxruntime::concurrency::ThreadPoolType::INTRA_OP)); + + // Warm up + MlasComputeQSoftmax(src.data(), dst_start, N, D, LoopupTable.data(), 0.1f, 1.0f, 0, is_signed, tp.get()); + + for (auto _ : state) { + MlasComputeQSoftmax(src.data(), dst_start, N, D, LoopupTable.data(), 0.1f, 1.0f, 0, is_signed, tp.get()); + } +} + +BENCHMARK(BM_Qsoftmax) + ->UseRealTime() + ->Apply([](benchmark::internal::Benchmark* b) { + b->ArgNames(qsoftmax_bench_arg_names); + b->ArgsProduct({{1971, 20}, {81, 1000}, {0, 1}}); + });