Skip to content

Commit

Permalink
optimize qlinearsoftmax
Browse files Browse the repository at this point in the history
  • Loading branch information
wejoncy committed Nov 1, 2024
1 parent 7ad7873 commit f7d97e1
Show file tree
Hide file tree
Showing 10 changed files with 981 additions and 145 deletions.
11 changes: 11 additions & 0 deletions cmake/onnxruntime_mlas.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@ onnxruntime_add_static_library(onnxruntime_mlas
${MLAS_SRC_DIR}/sqnbitgemm_q8_block.h
${MLAS_SRC_DIR}/flashattn.cpp
${MLAS_SRC_DIR}/cast.cpp
${MLAS_SRC_DIR}/qsoftmax.cpp
${MLAS_SRC_DIR}/qsoftmax_kernel_naive.cpp
)

target_sources(onnxruntime_mlas PRIVATE
Expand Down Expand Up @@ -163,6 +165,10 @@ function(setup_mlas_source_for_windows)
file(GLOB_RECURSE mlas_platform_srcs_avx2 CONFIGURE_DEPENDS
"${MLAS_SRC_DIR}/intrinsics/avx2/*.cpp"
)
set(mlas_platform_srcs_avx2
${mlas_platform_srcs_avx2}
"${MLAS_SRC_DIR}/qsoftmax_kernel_avx2.cpp"
)
set_source_files_properties(${mlas_platform_srcs_avx2} PROPERTIES COMPILE_FLAGS "/arch:AVX2")

target_sources(onnxruntime_mlas PRIVATE
Expand All @@ -171,6 +177,7 @@ function(setup_mlas_source_for_windows)
${mlas_platform_srcs_avx2}
${MLAS_SRC_DIR}/qgemm_kernel_amx.cpp
${MLAS_SRC_DIR}/qgemm_kernel_avx2.cpp
${MLAS_SRC_DIR}/qsoftmax_kernel_avx512.cpp
${MLAS_SRC_DIR}/qgemm_kernel_sse.cpp
${MLAS_SRC_DIR}/qgemm_kernel_sse41.cpp
${MLAS_SRC_DIR}/intrinsics/avx512/quantize_avx512f.cpp
Expand Down Expand Up @@ -570,6 +577,7 @@ else()
${MLAS_SRC_DIR}/intrinsics/avx2/qladd_avx2.cpp
${MLAS_SRC_DIR}/intrinsics/avx2/qdwconv_avx2.cpp
${MLAS_SRC_DIR}/sqnbitgemm_kernel_avx2.cpp
${MLAS_SRC_DIR}/qsoftmax_kernel_avx2.cpp
)
if(CMAKE_CXX_COMPILER_VERSION GREATER_EQUAL 13.1 AND NOT(APPLE))
set(mlas_platform_srcs_avx2
Expand Down Expand Up @@ -610,6 +618,7 @@ endif()

set(mlas_platform_srcs_avx512vnni
${MLAS_SRC_DIR}/sqnbitgemm_kernel_avx512vnni.cpp
${MLAS_SRC_DIR}/qsoftmax_kernel_avx512.cpp
)
set_source_files_properties(${mlas_platform_srcs_avx512vnni} PROPERTIES COMPILE_FLAGS "-mfma -mavx512vnni -mavx512bw -mavx512dq -mavx512vl -mavx512f")

Expand All @@ -619,6 +628,7 @@ endif()
${MLAS_SRC_DIR}/dgemm.cpp
${MLAS_SRC_DIR}/pooling_fp16.cpp
${MLAS_SRC_DIR}/qgemm_kernel_avx2.cpp

${mlas_platform_srcs_sse2}
${mlas_platform_srcs_avx}
${mlas_platform_srcs_avx2}
Expand All @@ -643,6 +653,7 @@ endif()
)
set_source_files_properties(${MLAS_SRC_DIR}/qgemm_kernel_amx.cpp PROPERTIES COMPILE_FLAGS "-mavx2 -mavx512bw -mavx512dq -mavx512vl -mavx512f")
set_source_files_properties(${MLAS_SRC_DIR}/x86_64/QgemmU8S8KernelAmx.S PROPERTIES COMPILE_FLAGS "-mavx2 -mavx512bw -mavx512dq -mavx512vl -mavx512f")
set_source_files_properties(${MLAS_SRC_DIR}/qsoftmax_kernel_avx2.cpp PROPERTIES COMPILE_FLAGS "-mavx2")
endif()

if(ONNXRUNTIME_MLAS_MULTI_ARCH)
Expand Down
150 changes: 6 additions & 144 deletions onnxruntime/contrib_ops/cpu/quantization/qlinear_softmax.cc
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ void QlinearBuildLookupTableUint32(gsl::span<QLinearSoftmax::EXP_OUT_DTYPE> tabl
for (int32_t i = 0; i < 256; i++) {
double scaled_exp_xi = exp((static_cast<double>(i) - 255 + bit_shift) * static_cast<double>(x_scale));
// we can't get the real max value of input tensor here, so we just assume 255-bit_shift.
// in the function of `QlinearSoftmaxCPU`,
// in the function of `MlasComputeQSoftmax`,
// all numbers will have a shift (255-bit_shift-max_value) if its max value is not 255
//
// if is_signed index = [1 2 3 ......126 127 -128 -127 ..... -3 -2 -1]
Expand Down Expand Up @@ -123,136 +123,6 @@ Status QLinearSoftmax::Compute(OpKernelContext* ctx) const {
}
}

template <typename T>
common::Status QlinearSoftmaxCPU(size_t N,
size_t D,
const T* x_data,
T* y_data,
const QLinearSoftmax::EXP_OUT_DTYPE* lookup_table,
QLinearSoftmax::EXP_OUT_DTYPE y_scale,
T yzp,
onnxruntime::concurrency::ThreadPool* thread_pool);

template <>
common::Status QlinearSoftmaxCPU<uint8_t>(size_t N,
size_t D,
const uint8_t* x_data,
uint8_t* y_data,
const QLinearSoftmax::EXP_OUT_DTYPE* lookup_table,
QLinearSoftmax::EXP_OUT_DTYPE y_scale,
uint8_t yzp,
onnxruntime::concurrency::ThreadPool* thread_pool) {
using onnxruntime::TensorOpCost;
using onnxruntime::concurrency::ThreadPool;
ThreadPool::TryParallelFor(
thread_pool, N,
// Read 3*N (max,sum,div) write N (div), computation=Read
TensorOpCost{static_cast<double>(D) * 3.0,
static_cast<double>(D),
static_cast<double>(D) * 3.0},
[x_data, y_data, D, y_scale, yzp, &lookup_table](std::ptrdiff_t first, std::ptrdiff_t last) {
const auto c_y_scale = y_scale;
const auto c_y_zp = yzp;
const uint8_t* x_t = x_data + first * D;
uint8_t* y_t = y_data + first * D;
for (; first < last; first++) {
// reduceMaxUint8
uint8_t xmax = *std::max_element(x_t, x_t + D);
// we want the xmas to align with 255 for higher precision.
// as we build a lookup table with X-255. So we could use the adjustment here
// to let all numbers have a shift in the lookup table.
// 1 2 3 4 5 ...........................254 255
// 1 3 5 ... 10
// after the shift --->
// 235 237 239 .. 255
const QLinearSoftmax::EXP_OUT_DTYPE* shifted_lookuptable = lookup_table + 255 - xmax;
size_t elements_n = D;
// reduceSumUin8ToUint32: need speedup
// vsum = \sum_i{e^x_i}
QLinearSoftmax::EXP_OUT_DTYPE vsum = 0;
const uint8_t* x_t_cur = x_t;
do {
const size_t vx = *x_t_cur++;
vsum += shifted_lookuptable[vx];
} while (--elements_n != 0);
if (vsum == 0) {
return;
}
elements_n = D;
x_t_cur = x_t;
// elementwise div, y_i=\frac{x_i}{vsum}
do {
const size_t vx = *x_t_cur++;
const QLinearSoftmax::EXP_OUT_DTYPE vt = shifted_lookuptable[vx];
// simulate round function, and re-quant to uint8
const uint32_t vq = static_cast<uint32_t>(std::nearbyintf(((vt * c_y_scale)) / vsum)) + c_y_zp;
const uint8_t vy = vq > 255 ? static_cast<uint8_t>(255) : static_cast<uint8_t>(vq);
*y_t++ = vy;
} while (--elements_n != 0);
x_t = x_t_cur;
}
});

return Status::OK();
}

template <>
common::Status QlinearSoftmaxCPU<int8_t>(size_t N,
size_t D,
const int8_t* x_data,
int8_t* y_data,
const QLinearSoftmax::EXP_OUT_DTYPE* lookup_table,
QLinearSoftmax::EXP_OUT_DTYPE y_scale,
int8_t yzp,
onnxruntime::concurrency::ThreadPool* thread_pool) {
using onnxruntime::TensorOpCost;
using onnxruntime::concurrency::ThreadPool;
ThreadPool::TryParallelFor(
thread_pool, N,
// Read 3*N (max,sum,div) write N (div), computation=Read
TensorOpCost{static_cast<double>(D) * 3.0,
static_cast<double>(D),
static_cast<double>(D) * 3.0},
[x_data, y_data, D, y_scale, yzp, &lookup_table](std::ptrdiff_t first, std::ptrdiff_t last) {
const auto c_y_scale = y_scale;
const auto c_y_zp = yzp;

const int8_t* x_t = x_data + first * D;
int8_t* y_t = y_data + first * D;
for (; first < last; first++) {
// reduceMaxInt8
int8_t xmax = *std::max_element(x_t, x_t + D);
const int32_t adjustment = int32_t(127) - xmax;
const QLinearSoftmax::EXP_OUT_DTYPE* shifted_lookuptable = lookup_table;
size_t elements_n = D;
// reduceSumUin8ToUint32: need speedup
QLinearSoftmax::EXP_OUT_DTYPE vsum = 0;
const int8_t* x_t_cur = x_t;
do {
const uint8_t vx = uint8_t(adjustment + (*x_t_cur++));
vsum += shifted_lookuptable[vx];
} while (--elements_n != 0);
if (vsum == 0) {
return;
}
elements_n = D;
x_t_cur = x_t;
// elementwise div
do {
const uint8_t vx = uint8_t(adjustment + (*x_t_cur++));
const QLinearSoftmax::EXP_OUT_DTYPE vt = shifted_lookuptable[vx];
// simulate round function, and re-quant to Int8
const int32_t vq = static_cast<int32_t>(std::nearbyintf(((vt * c_y_scale)) / vsum)) + c_y_zp;
const int8_t vy = static_cast<int32_t>(vq) > 255 ? static_cast<int8_t>(255) : static_cast<int8_t>(vq);
*y_t++ = vy;
} while (--elements_n != 0);
x_t = x_t_cur;
}
});

return Status::OK();
}

gsl::span<const QLinearSoftmax::EXP_OUT_DTYPE> QLinearSoftmax::GetLookupTable(
OpKernelContext* context,
gsl::span<EXP_OUT_DTYPE> lookup_table_span,
Expand All @@ -270,25 +140,17 @@ gsl::span<const QLinearSoftmax::EXP_OUT_DTYPE> QLinearSoftmax::GetLookupTable(
Status QLinearSoftmax::ComputeInternal(OpKernelContext* context, const Tensor& input, Tensor& output,
gsl::span<const EXP_OUT_DTYPE> lookup_table, int axis,
concurrency::ThreadPool* thread_pool) const {
const auto* X_scale_tensor = context->Input<Tensor>(1);
const auto* Y_scale_tensor = context->Input<Tensor>(3);
const auto* Y_zp_tensor = context->Input<Tensor>(4);
const QLinearSoftmax::EXP_OUT_DTYPE Y_scale = std::floor(1.0F / (*(Y_scale_tensor->Data<float>())));
const auto& X_shape = input.Shape();
const size_t N = onnxruntime::narrow<size_t>(X_shape.SizeToDimension(onnxruntime::narrow<size_t>(axis)));
const size_t D = onnxruntime::narrow<size_t>(X_shape.SizeFromDimension(onnxruntime::narrow<size_t>(axis)));
common::Status status;
if (is_signed_) {
using T = int8_t;
const T Y_zp = Y_zp_tensor ? *(Y_zp_tensor->Data<T>()) : 0;
status = QlinearSoftmaxCPU<T>(N, D, input.Data<T>(), output.MutableData<T>(),
lookup_table.data(), Y_scale, Y_zp, thread_pool);
} else {
using T = uint8_t;
const T Y_zp = Y_zp_tensor ? *(Y_zp_tensor->Data<T>()) : 0;
status = QlinearSoftmaxCPU<T>(N, D, input.Data<T>(), output.MutableData<T>(),
lookup_table.data(), Y_scale, Y_zp, thread_pool);
}
return status;
const int Y_zp = Y_zp_tensor ? (is_signed_ ? *(Y_zp_tensor->Data<int8_t>()) : *(Y_zp_tensor->Data<uint8_t>())) : 0;
MlasComputeQSoftmax(input.DataRaw(), output.MutableDataRaw(), N, D, lookup_table.data(),
*X_scale_tensor->Data<float>(), Y_scale, Y_zp, is_signed_, thread_pool);
return Status::OK();
}

// opset-13 and above
Expand Down
15 changes: 15 additions & 0 deletions onnxruntime/core/mlas/inc/mlas.h
Original file line number Diff line number Diff line change
Expand Up @@ -1018,6 +1018,21 @@ MlasComputeSoftmax(
MLAS_THREADPOOL* ThreadPool
);

void
MLASCALL
MlasComputeQSoftmax(
const void* Input,
void* Output,
size_t N,
size_t D,
const float* LoopupTable,
float X_Scale,
float Scale,
int ZeroPoint,
bool is_signed,
MLAS_THREADPOOL* ThreadPool
);

void
MLASCALL
MlasComputeTanh(
Expand Down
32 changes: 32 additions & 0 deletions onnxruntime/core/mlas/lib/mlasi.h
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ Module Name:
#if defined(__GNUC__) && __GNUC__ >= 12
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wmaybe-uninitialized" // GCC 12 warns about uninitialized variables in immintrin.h.
#pragma GCC diagnostic ignored "-Wuninitialized" // GCC 12 warns about uninitialized variables in immintrin.h.
#include <immintrin.h>
#pragma GCC diagnostic pop
#else
Expand Down Expand Up @@ -711,6 +712,28 @@ void
float Scale,
int8_t ZeroPoint);

typedef
void
(MLASCALL MLAS_QUANTIZE_SOFTMAX_I8_KERNEL)(
size_t D,
const int8_t* Xdata,
int8_t* Ydata,
const float* LookupTable,
float Yscale,
int8_t YZeroPoint,
float* Buff);

typedef
void
(MLASCALL MLAS_QUANTIZE_SOFTMAX_U8_KERNEL)(
size_t D,
const uint8_t* Xdata,
uint8_t* Ydata,
const float* LookupTable,
float Yscale,
uint8_t YZeroPoint,
float* Buff);

template<typename InputType, typename FilterType>
struct MLAS_QUANT_KERNEL
{
Expand Down Expand Up @@ -876,7 +899,13 @@ extern "C" {
MLAS_QLINEAR_BINARY_OP_U8_KERNEL MlasQLinearAddU8KernelAvx2;
MLAS_QUANTIZE_LINEAR_S8_KERNEL MlasQuantizeLinearS8KernelAvx512F;
MLAS_QUANTIZE_LINEAR_U8_KERNEL MlasQuantizeLinearU8KernelAvx512F;
MLAS_QUANTIZE_SOFTMAX_I8_KERNEL MlasQuantizeSoftmaxI8KernelAvx2;
MLAS_QUANTIZE_SOFTMAX_I8_KERNEL MlasQuantizeSoftmaxI8KernelAvx512;
MLAS_QUANTIZE_SOFTMAX_U8_KERNEL MlasQuantizeSoftmaxU8KernelAvx2;
MLAS_QUANTIZE_SOFTMAX_U8_KERNEL MlasQuantizeSoftmaxU8KernelAvx512;
#endif
MLAS_QUANTIZE_SOFTMAX_I8_KERNEL MlasQuantizeSoftmaxI8KernelNaive;
MLAS_QUANTIZE_SOFTMAX_U8_KERNEL MlasQuantizeSoftmaxU8KernelNaive;

MLAS_REDUCE_MAXIMUM_FLOAT_KERNEL MlasReduceMaximumF32Kernel;
MLAS_REDUCE_MINIMUM_MAXIMUM_FLOAT_KERNEL MlasReduceMinimumMaximumF32Kernel;
Expand Down Expand Up @@ -1188,6 +1217,9 @@ struct MLAS_PLATFORM {

MLAS_CAST_F16_TO_F32_KERNEL* CastF16ToF32Kernel;
MLAS_CAST_F32_TO_F16_KERNEL* CastF32ToF16Kernel;

MLAS_QUANTIZE_SOFTMAX_I8_KERNEL *QuantizeSoftmaxI8Kernel;
MLAS_QUANTIZE_SOFTMAX_U8_KERNEL *QuantizeSoftmaxU8Kernel;
};

inline
Expand Down
7 changes: 6 additions & 1 deletion onnxruntime/core/mlas/lib/platform.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -246,6 +246,8 @@ Return Value:
this->ConvDepthwiseS8U8Kernel = MlasConvDepthwiseKernel<int8_t, uint8_t>;
this->CastF16ToF32Kernel = nullptr;
this->CastF32ToF16Kernel = nullptr;
this->QuantizeSoftmaxI8Kernel = MlasQuantizeSoftmaxI8KernelNaive;
this->QuantizeSoftmaxU8Kernel = MlasQuantizeSoftmaxU8KernelNaive;

#if defined(MLAS_TARGET_AMD64_IX86)

Expand All @@ -258,7 +260,6 @@ Return Value:
this->GemmU8U8Dispatch = &MlasGemmU8X8DispatchSse;

#if defined(MLAS_TARGET_AMD64)

this->TransposePackB16x4Routine = MlasSgemmTransposePackB16x4Sse;
this->GemmDoubleKernel = MlasGemmDoubleKernelSse;
this->ConvNchwFloatKernel = MlasConvNchwFloatKernelSse;
Expand Down Expand Up @@ -391,6 +392,8 @@ Return Value:
this->CastF16ToF32Kernel = &MlasCastF16ToF32KernelAvx2;
this->CastF32ToF16Kernel = &MlasCastF32ToF16KernelAvx2;

this->QuantizeSoftmaxI8Kernel = MlasQuantizeSoftmaxI8KernelAvx2;
this->QuantizeSoftmaxU8Kernel = MlasQuantizeSoftmaxU8KernelAvx2;

//
// Check if the processor supports Hybrid core architecture.
Expand Down Expand Up @@ -460,6 +463,8 @@ Return Value:
this->FpQ4GemmDispatch = &MlasFpQ4GemmDispatchAvx512;
this->SQNBitGemmDispatch = &MlasSQNBitGemmDispatchAvx512;

this->QuantizeSoftmaxI8Kernel = MlasQuantizeSoftmaxI8KernelAvx512;
this->QuantizeSoftmaxU8Kernel = MlasQuantizeSoftmaxU8KernelAvx512;
//
// Check if the processor supports AVX512VNNI.
//
Expand Down
Loading

0 comments on commit f7d97e1

Please sign in to comment.