[ARM CPU] Add Fp16 kernels for MatMulNBits #22651

Closed · wants to merge 56 commits

Commits (56, all by fajin-corp)
3fef98c  removed accuracy level field (Oct 4, 2024)
8209edd  added template support (Oct 5, 2024)
c3233b6  added template support for compfp16 to support different input types (Oct 7, 2024)
917b6da  finished adding interfaces (Oct 8, 2024)
d8094eb  passed build on intel (Oct 9, 2024)
26466f0  added accuracy 2 to fp16 ut (Oct 9, 2024)
97036bb  fix arm build (Oct 9, 2024)
8b662a9  fixing arm build (Oct 9, 2024)
068fd35  fixing arm build (Oct 9, 2024)
ee7d79f  fixing arm build (Oct 9, 2024)
e15304e  finished SQ4BitGemmPackQuantBData_CompFp16 (Oct 10, 2024)
b65ad4f  finished dequant b (Oct 11, 2024)
bf9db68  finished matmul kernels (Oct 16, 2024)
1dffc21  passed intel build (Oct 16, 2024)
672cb97  fixing arm build (Oct 16, 2024)
27ef55b  passed arm build (Oct 16, 2024)
8d216cc  passed arm build (Oct 17, 2024)
e461155  added prepack ut (Oct 17, 2024)
b0f3223  passed prepack ut (Oct 18, 2024)
35e5216  added ut for dequant b (Oct 18, 2024)
6e9876c  passed dequant B (Oct 19, 2024)
aac7c04  added matmul ut (Oct 23, 2024)
f39fde8  fixed linux build (Oct 23, 2024)
49bd472  fixing out buffer bug (Oct 24, 2024)
01a641d  fixed B stepping bug (Oct 24, 2024)
0391fff  modified test cases (Oct 24, 2024)
2cd4939  fixed B loop step and ref matmul (Oct 24, 2024)
b7612d0  finished mlas ut (Oct 25, 2024)
aaf1b68  fixed fp16 alloc bug (Oct 25, 2024)
b9403fc  passed UT! (Oct 26, 2024)
b670915  refactored fp16 kernels (Oct 29, 2024)
ccdadf4  fixed b step (Oct 29, 2024)
adab09a  fixing seg fault (Oct 29, 2024)
55dcbae  finished benchmarking (Oct 30, 2024)
b52a24a  fix linting (Oct 30, 2024)
e1ddf1d  add post processing (Oct 30, 2024)
8f51538  fix merge (Oct 30, 2024)
8e2000a  further speed up (Oct 31, 2024)
aeb6189  changed signatures of functions used in matmul_nbits.cc (Oct 31, 2024)
64d1c49  changed dispatch name (Oct 31, 2024)
cef21cf  renamed sqnbits and clean up code branches (Nov 1, 2024)
fd4e18f  renamed files (Nov 1, 2024)
139149d  fixing x86 template error (Nov 1, 2024)
d12377f  fixing linting (Nov 1, 2024)
e4f8fbf  fixing unused specialization (Nov 1, 2024)
6a3d3c1  fixing CI (Nov 1, 2024)
b6c6bbb  fixing ci (Nov 1, 2024)
c7f4dfb  fixing ci (Nov 1, 2024)
90e9f83  pass qnn ut (Nov 2, 2024)
3f007e2  renaming sqnb to hqnb (Nov 4, 2024)
a2067bb  fix linting (Nov 5, 2024)
e074b25  try to fix illegal instructions (Nov 7, 2024)
23a83c8  limit EP types for fp16 ut (Nov 8, 2024)
adce284  fixing fp32 (Nov 8, 2024)
8050f0a  platform change (Nov 8, 2024)
617ddfa  add runtime check for mlas ut (Nov 9, 2024)
15 changes: 9 additions & 6 deletions cmake/onnxruntime_mlas.cmake
@@ -36,8 +36,8 @@ onnxruntime_add_static_library(onnxruntime_mlas
${MLAS_SRC_DIR}/qpostprocessor.cpp
${MLAS_SRC_DIR}/qlgavgpool.cpp
${MLAS_SRC_DIR}/qdwconv_kernelsize.cpp
${MLAS_SRC_DIR}/sqnbitgemm.h
${MLAS_SRC_DIR}/sqnbitgemm.cpp
${MLAS_SRC_DIR}/qnbitgemm.h
${MLAS_SRC_DIR}/qnbitgemm.cpp
${MLAS_SRC_DIR}/sqnbitgemm_q8_block.h
${MLAS_SRC_DIR}/flashattn.cpp
${MLAS_SRC_DIR}/cast.cpp
@@ -84,11 +84,12 @@ function(setup_mlas_source_for_windows)
${MLAS_SRC_DIR}/qgemm_kernel_neon.cpp
${MLAS_SRC_DIR}/qgemm_kernel_udot.cpp
${MLAS_SRC_DIR}/qgemm_kernel_sdot.cpp
${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon.h
${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon.cpp
${MLAS_SRC_DIR}/qnbitgemm_kernel_neon.h
${MLAS_SRC_DIR}/qnbitgemm_kernel_neon.cpp
${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon_fp32.cpp
${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon_int8.cpp
${MLAS_SRC_DIR}/fp16_neon_common.cpp
${MLAS_SRC_DIR}/hqnbitgemm_kernel_neon_fp16.cpp
)

set(mlas_platform_preprocess_srcs
@@ -362,8 +363,8 @@ else()
${MLAS_SRC_DIR}/qgemm_kernel_neon.cpp
${MLAS_SRC_DIR}/qgemm_kernel_udot.cpp
${MLAS_SRC_DIR}/qgemm_kernel_sdot.cpp
${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon.h
${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon.cpp
${MLAS_SRC_DIR}/qnbitgemm_kernel_neon.h
${MLAS_SRC_DIR}/qnbitgemm_kernel_neon.cpp
${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon_fp32.cpp
${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon_int8.cpp
)
@@ -384,6 +385,7 @@ else()
${MLAS_SRC_DIR}/qgemm_kernel_ummla.cpp
${MLAS_SRC_DIR}/sbgemm_kernel_neon.cpp
${MLAS_SRC_DIR}/fp16_neon_common.cpp
${MLAS_SRC_DIR}/hqnbitgemm_kernel_neon_fp16.cpp
)
set_source_files_properties(${MLAS_SRC_DIR}/aarch64/HalfGemmKernelNeon.S PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ")
set_source_files_properties(${MLAS_SRC_DIR}/aarch64/QgemmS8S8KernelSmmla.S PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+i8mm ")
@@ -394,6 +396,7 @@ else()
set_source_files_properties(${MLAS_SRC_DIR}/pooling_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ")
set_source_files_properties(${MLAS_SRC_DIR}/sbgemm_kernel_neon.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+bf16 ")
set_source_files_properties(${MLAS_SRC_DIR}/fp16_neon_common.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ")
set_source_files_properties(${MLAS_SRC_DIR}/hqnbitgemm_kernel_neon_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ")
endif()

if(ONNXRUNTIME_MLAS_MULTI_ARCH)
232 changes: 132 additions & 100 deletions onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc

Large diffs are not rendered by default.
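
The matmul_nbits.cc diff is too large to render here, so as orientation only: the op-level change routes fp16 activations to the new half-precision (HQNBIT) path while fp32 activations keep the existing paths. The snippet below is a hedged illustration of that selection using only MlasIsQNBitGemmAvailable() and the compute-type enum from the mlas_qnbit.h diff further down; SelectComputeType and its parameters are hypothetical names, not code from this PR.

#include <cstddef>
#include "mlas_qnbit.h"  // include path as appropriate for the build

// Illustration only: pick a compute type from the activation element type.
// Not taken from the unrendered matmul_nbits.cc changes.
MLAS_QNBIT_GEMM_COMPUTE_TYPE
SelectComputeType(bool is_fp16_activation, size_t blk_bitwidth, size_t blk_len)
{
    // fp16 activations use the new half-precision kernels when the platform provides them;
    // everything else stays on the existing fp32-accumulator path.
    if (is_fp16_activation &&
        MlasIsQNBitGemmAvailable(blk_bitwidth, blk_len, HQNBIT_CompFp16)) {
        return HQNBIT_CompFp16;
    }
    return SQNBIT_CompFp32;
}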

95 changes: 49 additions & 46 deletions onnxruntime/core/mlas/inc/mlas_qnbit.h
@@ -27,51 +27,53 @@ Module Name:
* @brief Define compute types of block quantization, in order of decreasing accuracy.
*/
typedef enum {
CompUndef = 0, /*!< undef */
CompFp32, /*!< input fp32, accumulator fp32 */
CompFp16, /*!< input fp16, accumulator fp16 */
CompBf16, /*!< input bf16, accumulator fp32 */
CompInt8, /*!< input int8, accumulator int32 */
SQNBIT_CompFp32, /*!< input fp32, accumulator fp32 */
HQNBIT_CompFp16, /*!< input fp16, accumulator fp16 */
BHQNBIT_CompBf16, /*!< input bf16, accumulator fp32 */

// special values that should be the first and last actual values
// This compute type only makes sense if there is a performance gain.
SQNBIT_CompInt8, /*!< input fp32 + int8, accumulator int32. */
HQNBIT_CompInt8 /*!< input fp16 + int8, accumulator int32. This compute type may not make sense. */

CompMostAccurate = CompUndef,
CompLeastAccurate = CompInt8,
} MLAS_SQNBIT_GEMM_COMPUTE_TYPE;
} MLAS_QNBIT_GEMM_COMPUTE_TYPE;

/**
* @brief Data parameters for float/n-bit quantized int GEMM routine.
*
* @tparam T data type of input A
*/
struct MLAS_SQNBIT_GEMM_DATA_PARAMS {
const float* A = nullptr; ///< address of A (float32 matrix)
size_t lda = 0; ///< leading dimension of A
const void* QuantBDataWorkspace; ///< address of quantized B (quantized n-bit int values)
const std::byte* PackedQuantBData = nullptr; /// address of packed quantized B data
const float* QuantBScale = nullptr; ///< address of scale values of quantized B, one per block
const void* QuantBZeroPoint = nullptr; ///< optional address of zero point values of quantized B, one per block
const float* QuantBBlkSum = nullptr; ///< optional address of scale * zp, one per block
const float* Bias = nullptr; ///< optional address of Bias, vector size N
float* C = nullptr; ///< address of result matrix
size_t ldc = 0; ///< leading dimension of C
template <typename T>
struct MLAS_QNBIT_GEMM_DATA_PARAMS {
const T* A = nullptr; // address of A (float32/16 matrix)
size_t lda = 0; // leading dimension of A
const void* QuantBDataWorkspace; // address of quantized B (quantized n-bit int values)
const std::byte* PackedQuantBData = nullptr; // address of packed quantized B data
const T* QuantBScale = nullptr; // address of scale values of quantized B, one per block
const void* QuantBZeroPoint = nullptr; // optional address of zero point values of quantized B, one per block
const T* QuantBBlkSum = nullptr; // optional address of scale * zp, one per block
const T* Bias = nullptr; // optional address of Bias, vector size N
T* C = nullptr; // address of result matrix
size_t ldc = 0; // leading dimension of C

///< optional post processing to apply to result matrix
MLAS_GEMM_POSTPROCESSOR<float>* PostProcessor = nullptr;
MLAS_GEMM_POSTPROCESSOR<T>* PostProcessor = nullptr;
};

/**
* @brief Batched GEMM: C = A * B + Bias
* A must be a float32 matrix
* A must be a float32/16 matrix
* B must be a quantized and packed n-bit int matrix
*
* Call MlasIsSQNBitGemmAvailable() with the same parameters to determine whether this function may be called.
* Call MlasIsQNBitGemmAvailable() with the same parameters to determine whether this function may be called.
*
* Call MlasSQNBitGemmPackQuantBDataSize() with the same parameters to determine whether
* MLAS_SQNBIT_GEMM_DATA_PARAMS::QuantBData in `DataParams` should point to a buffer packed with
* MlasSQNBitGemmPackQuantBData().
* Call MlasQNBitGemmPackQuantBDataSize() with the same parameters to determine whether
* MLAS_QNBIT_GEMM_DATA_PARAMS::QuantBData in `DataParams` should point to a buffer packed with
* MlasQNBitGemmPackQuantBData().
*
* Call MlasSQNBitGemmBatchWorkspaceSize() with the same parameters to determine whether `Workspace` should
* Call MlasQNBitGemmBatchWorkspaceSize() with the same parameters to determine whether `Workspace` should
* point to an intermediate workspace buffer.
*
* @tparam T data type of input A in data params
* @param[in] M row size of matrix A and C
* @param[in] N column size of matrix B and C
* @param[in] K column size of matrix A and row size of matrix B
@@ -81,41 +83,42 @@ struct MLAS_SQNBIT_GEMM_DATA_PARAMS {
* @param[in] ComputeType GEMM compute type (e.g., multiplying float or int8 values)
* @param[inout] DataParams An array (size BatchN) of parameter blocks
* @param[in] Workspace Address of intermediate workspace buffer.
If MlasSQNBitGemmBatchWorkspaceSize() returns a non-zero value, this must be a
If MlasQNBitGemmBatchWorkspaceSize() returns a non-zero value, this must be a
buffer with at least that many bytes. Otherwise, it may be nullptr.
* @param[in] ThreadPool optional thread pool to use
*/
template <typename T>
void MLASCALL
MlasSQNBitGemmBatch(
MlasQNBitGemmBatch(
size_t M,
size_t N,
size_t K,
size_t BatchN,
size_t BlkBitWidth,
size_t BlkLen,
MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType,
const MLAS_SQNBIT_GEMM_DATA_PARAMS* DataParams,
MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType,
const MLAS_QNBIT_GEMM_DATA_PARAMS<T>* DataParams,
void* Workspace,
MLAS_THREADPOOL* ThreadPool = nullptr
);

/**
* @brief Determines whether a float32/quantized n-bit int GEMM implementation is available on the current platform.
* @brief Determines whether a float32/16 quantized n-bit int GEMM implementation is available on the current platform.
*
* @param[in] BlkBitWidth quantized value bit width (e.g., 4 means 4 bit ints)
* @param[in] BlkLen number of quantized values per block
* @param[in] ComputeType GEMM compute type (e.g., multiplying float or int8 values)
*/
bool MLASCALL
MlasIsSQNBitGemmAvailable(
MlasIsQNBitGemmAvailable(
size_t BlkBitWidth,
size_t BlkLen,
MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType
MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType
);

/**
* @brief Gets the size in bytes of the intermediate workspace buffer required by the float32/quantized n-bit int GEMM
* implementation. If zero, no intermediate workspace is required.
* @brief Gets the size in bytes of the intermediate workspace buffer required by the float32/16 quantized n-bit int
* GEMM implementation. If zero, no intermediate workspace is required.
*
* @param[in] M row size of matrix A and C
* @param[in] N column size of matrix B and C
@@ -126,22 +129,22 @@ MlasIsSQNBitGemmAvailable(
* @param[in] ComputeType GEMM compute type (e.g., multiplying float or int8 values)
*/
size_t MLASCALL
MlasSQNBitGemmBatchWorkspaceSize(
MlasQNBitGemmBatchWorkspaceSize(
size_t M,
size_t N,
size_t K,
size_t BatchN,
size_t BlkBitWidth,
size_t BlkLen,
MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType
MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType
);

/**
* @brief Gets the size in bytes of the packed quantized B data.
* If non-zero, the quantized B data must first be packed by calling MlasSQNBitGemmPackQuantBData() with a buffer of
* this size, and then that packed quantized B data buffer must be passed to MlasSQNBitGemmBatch().
* If zero, MlasSQNBitGemmPackQuantBData() must not be called and the quantized B data must be directly passed to
* MlasSQNBitGemmBatch().
* If non-zero, the quantized B data must first be packed by calling MlasQNBitGemmPackQuantBData() with a buffer of
* this size, and then that packed quantized B data buffer must be passed to MlasQNBitGemmBatch().
* If zero, MlasQNBitGemmPackQuantBData() must not be called and the quantized B data must be directly passed to
* MlasQNBitGemmBatch().
*
* @param[in] N column size of matrix B and C
* @param[in] K column size of matrix A and row size of matrix B
@@ -150,12 +153,12 @@ MlasSQNBitGemmBatchWorkspaceSize(
* @param[in] ComputeType GEMM compute type (e.g., multiplying float or int8 values)
*/
size_t MLASCALL
MlasSQNBitGemmPackQuantBDataSize(
MlasQNBitGemmPackQuantBDataSize(
size_t N,
size_t K,
size_t BlkBitWidth,
size_t BlkLen,
MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType
MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType
);

/**
@@ -186,12 +189,12 @@ MlasSQNBitGemmPackQuantBDataSize(
* @param[in] ThreadPool thread pool to use (no parallel if nullptr)
*/
void MLASCALL
MlasSQNBitGemmPackQuantBData(
MlasQNBitGemmPackQuantBData(
size_t N,
size_t K,
size_t BlkBitWidth,
size_t BlkLen,
MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType,
MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType,
const void* QuantBData,
void* PackedQuantBDataAndOrBlkSum,
const void* QuantBScale,
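
To make the renamed, templated API above concrete, here is a minimal end-to-end calling sketch for the fp16 path. It is an assumption-laden illustration, not code from this PR: it assumes MLAS_FP16 is the half-precision element type substituted for T, uses only the declarations visible in this diff, and leaves the MlasQNBitGemmPackQuantBData() call as a comment because its full parameter list is elided above.

#include <cstddef>
#include <vector>
#include "mlas_qnbit.h"  // include path as appropriate for the build

// Sketch: run one half-precision 4-bit GEMM, C = A * B, following the protocol in the
// doc comments above (availability check -> optional pack -> optional workspace -> batch).
void RunHqnbitGemmSketch(
    size_t M, size_t N, size_t K,
    const MLAS_FP16* A, size_t lda,
    const MLAS_FP16* QuantBScale,       // per-block scales of the 4-bit B matrix
    const std::byte* PackedQuantBData,  // B already packed with MlasQNBitGemmPackQuantBData()
    MLAS_FP16* C, size_t ldc,
    MLAS_THREADPOOL* ThreadPool)
{
    constexpr size_t BlkBitWidth = 4;
    constexpr size_t BlkLen = 32;
    constexpr auto ComputeType = HQNBIT_CompFp16;

    // 1. Is an fp16 kernel available for this block configuration on this platform?
    if (!MlasIsQNBitGemmAvailable(BlkBitWidth, BlkLen, ComputeType)) {
        return;  // caller would fall back to SQNBIT_CompFp32 or another EP
    }

    // 2. Packing: if MlasQNBitGemmPackQuantBDataSize() is non-zero, B must have been
    //    packed beforehand via MlasQNBitGemmPackQuantBData() (full signature elided above).
    const size_t PackedSize =
        MlasQNBitGemmPackQuantBDataSize(N, K, BlkBitWidth, BlkLen, ComputeType);
    (void)PackedSize;  // assumed already handled by the caller in this sketch

    // 3. Allocate the intermediate workspace if the implementation needs one.
    std::vector<std::byte> Workspace;
    const size_t WorkspaceSize =
        MlasQNBitGemmBatchWorkspaceSize(M, N, K, /*BatchN*/ 1, BlkBitWidth, BlkLen, ComputeType);
    if (WorkspaceSize > 0) {
        Workspace.resize(WorkspaceSize);
    }

    // 4. Fill the templated data params and run a batch of one GEMM.
    MLAS_QNBIT_GEMM_DATA_PARAMS<MLAS_FP16> Params;
    Params.A = A;
    Params.lda = lda;
    Params.QuantBDataWorkspace = nullptr;  // not defaulted in the struct, so set explicitly
    Params.PackedQuantBData = PackedQuantBData;
    Params.QuantBScale = QuantBScale;
    Params.C = C;
    Params.ldc = ldc;

    MlasQNBitGemmBatch(M, N, K, /*BatchN*/ 1, BlkBitWidth, BlkLen, ComputeType,
                       &Params, WorkspaceSize > 0 ? Workspace.data() : nullptr, ThreadPool);
}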
17 changes: 17 additions & 0 deletions onnxruntime/core/mlas/lib/fp16_common.h
@@ -64,6 +64,15 @@ MLAS_FORCEINLINE
MLAS_FLOAT16X4
MlasLoadFloat16x4(const _mlas_fp16_* Buffer) { return vreinterpret_f16_u16(vld1_u16(Buffer)); }

template <int lane>
MLAS_FORCEINLINE
MLAS_FLOAT16X4
MlasLoadLaneFloat16x4(const _mlas_fp16_* Buffer, MLAS_FLOAT16X4 vec) {
return vreinterpret_f16_u16(
vld1_lane_u16(Buffer, vreinterpret_u16_f16(vec), lane)
);
}

MLAS_FORCEINLINE
MLAS_FLOAT16X4
MlasLoadPartialFloat16x4(const _mlas_fp16_* Buffer, size_t len)
@@ -95,6 +104,14 @@ MlasStoreFloat16x4(_mlas_fp16_* Buffer, MLAS_FLOAT16X4 Vector)
vst1_u16(Buffer, vreinterpret_u16_f16(Vector));
}

template <int lane>
MLAS_FORCEINLINE
void
MlasStoreLaneFloat16x4(_mlas_fp16_* Buffer, MLAS_FLOAT16X4 Vector)
{
vst1_lane_u16(Buffer, vreinterpret_u16_f16(Vector), lane);
}

MLAS_FORCEINLINE
void
MlasStorePartialFloat16x4(_mlas_fp16_* Buffer, MLAS_FLOAT16X4 Vector, size_t len)
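
The new MlasLoadLaneFloat16x4 and MlasStoreLaneFloat16x4 templates wrap vld1_lane_u16/vst1_lane_u16 so the lane index stays a compile-time constant. As a hedged illustration only (not code from this PR), such helpers are typically used to gather fp16 values from non-contiguous addresses into one register and scatter them back, e.g. during packing; CopyStridedFloat16x4 and its parameters below are hypothetical names.

// Illustration only: copy four fp16 values that sit `src_stride` elements apart to
// destinations `dst_stride` elements apart, one lane at a time. Assumes ARM NEON with
// fp16 (the same requirement as fp16_common.h) and the helpers added in the diff above.
MLAS_FORCEINLINE
void
CopyStridedFloat16x4(const _mlas_fp16_* src, size_t src_stride, _mlas_fp16_* dst, size_t dst_stride)
{
    // Gather: load one lane at a time; the lane index is a template (compile-time) constant.
    MLAS_FLOAT16X4 v = vreinterpret_f16_u16(vdup_n_u16(0));  // start from a zero vector
    v = MlasLoadLaneFloat16x4<0>(src, v);
    v = MlasLoadLaneFloat16x4<1>(src + src_stride, v);
    v = MlasLoadLaneFloat16x4<2>(src + 2 * src_stride, v);
    v = MlasLoadLaneFloat16x4<3>(src + 3 * src_stride, v);

    // Scatter: store each lane back to a strided destination.
    MlasStoreLaneFloat16x4<0>(dst, v);
    MlasStoreLaneFloat16x4<1>(dst + dst_stride, v);
    MlasStoreLaneFloat16x4<2>(dst + 2 * dst_stride, v);
    MlasStoreLaneFloat16x4<3>(dst + 3 * dst_stride, v);
}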