diff --git a/onnxruntime/core/mlas/inc/mlas_q4.h b/onnxruntime/core/mlas/inc/mlas_q4.h
index 0944e12e7115e..b11d09d674c4e 100644
--- a/onnxruntime/core/mlas/inc/mlas_q4.h
+++ b/onnxruntime/core/mlas/inc/mlas_q4.h
@@ -20,14 +20,12 @@ Module Name:
 
 #pragma once
 
-#include "mlas.h"
-#include "mlas_gemm_postprocessor.h"
-
 #include <math.h>
 #include <algorithm>
 
+#include "mlas.h"
+#include "mlas_gemm_postprocessor.h"
 
 /**
  * @brief Define types of block quantization
@@ -45,9 +43,9 @@
 typedef enum {
     CompUndef = 0, /*!< undef */
     CompFp32 = 1,  /*!< input fp32, accumulator fp32 */
-    CompInt8 = 2,  /*!< input int8, accumulator int32 */
+    CompFp16 = 2,  /*!< input fp16, accumulator fp16 */
     CompBf16 = 3,  /*!< input bf16, accumulator fp32 */
-    CompFp16 = 4   /*!< input fp16, accumulator fp16 */
+    CompInt8 = 4   /*!< input int8, accumulator int32 */
 } MLAS_COMPUTE_TYPE;
 
 /**
@@ -72,12 +70,7 @@ MlasQ4GemmPackBSize(MLAS_BLK_QUANT_TYPE QType, size_t N, size_t K);
  * @param ldb leading dimension of B
  */
 void MLASCALL
-MlasQ4GemmPackB(MLAS_BLK_QUANT_TYPE QType,
-                void* PackedBuf,
-                const float* FpData,
-                size_t N,
-                size_t K,
-                size_t ldb);
+MlasQ4GemmPackB(MLAS_BLK_QUANT_TYPE QType, void* PackedBuf, const float* FpData, size_t N, size_t K, size_t ldb);
 
 /**
  * @brief Unpack and dequantize from int4 to fp32, reverse operation of
@@ -90,13 +83,7 @@ MlasQ4GemmPackB(MLAS_BLK_QUANT_TYPE QType,
  * @param ldb leading dimension of B
  */
 void MLASCALL
-MlasQ4GemmUnPackB(MLAS_BLK_QUANT_TYPE QType,
-                  float* FpData,
-                  const void* PackedBuf,
-                  size_t N,
-                  size_t K,
-                  size_t ldb);
-
+MlasQ4GemmUnPackB(MLAS_BLK_QUANT_TYPE QType, float* FpData, const void* PackedBuf, size_t N, size_t K, size_t ldb);
 /**
  * @brief Data parameters for Q4 GEMM routine
@@ -130,13 +117,7 @@ struct MLAS_Q4_GEMM_DATA_PARAMS {
  * @return
  */
 void MLASCALL
-MlasQ4GemmBatch(MLAS_BLK_QUANT_TYPE QType,
-                const size_t M,
-                const size_t N,
-                const size_t K,
-                const size_t BatchN,
-                const MLAS_Q4_GEMM_DATA_PARAMS* DataParams,
-                MLAS_THREADPOOL* ThreadPool = nullptr);
+MlasQ4GemmBatch(MLAS_BLK_QUANT_TYPE QType, const size_t M, const size_t N, const size_t K, const size_t BatchN, const MLAS_Q4_GEMM_DATA_PARAMS* DataParams, MLAS_THREADPOOL* ThreadPool = nullptr);
 
 /**
  * @brief Calculate the buffer size needed for int8 block quantize
@@ -161,13 +142,7 @@ MlasQ80BlkQuantSize(MLAS_BLK_QUANT_TYPE QType, size_t M, size_t K);
  * @param ThreadPool
  */
 void MLASCALL
-MlasQ80BlkQuant(MLAS_BLK_QUANT_TYPE QType,
-                void* Qblob,
-                const float* A,
-                size_t M,
-                size_t K,
-                size_t lda,
-                MLAS_THREADPOOL* ThreadPool);
+MlasQ80BlkQuant(MLAS_BLK_QUANT_TYPE QType, void* Qblob, const float* A, size_t M, size_t K, size_t lda, MLAS_THREADPOOL* ThreadPool);
 
 /**
  * @brief Data parameters for Q8Q4 GEMM routine
@@ -200,13 +175,7 @@ struct MLAS_Q8Q4_GEMM_DATA_PARAMS {
  * @return
  */
 void MLASCALL
-MlasQ8Q4GemmBatch(MLAS_BLK_QUANT_TYPE QType,
-                  const size_t M,
-                  const size_t N,
-                  const size_t K,
-                  const size_t BatchN,
-                  const MLAS_Q8Q4_GEMM_DATA_PARAMS* DataParams,
-                  MLAS_THREADPOOL* ThreadPool);
+MlasQ8Q4GemmBatch(MLAS_BLK_QUANT_TYPE QType, const size_t M, const size_t N, const size_t K, const size_t BatchN, const MLAS_Q8Q4_GEMM_DATA_PARAMS* DataParams, MLAS_THREADPOOL* ThreadPool);
 
 ////////////////////////////////////////////////////////////
 // Blockwise quantization and dequantization where quantization
@@ -217,11 +186,12 @@ MlasQ8Q4GemmBatch(MLAS_BLK_QUANT_TYPE QType,
 /**
  * @brief For quantization type <T, block_size, columnwise>, and
  * matrix shape [rows, columns], compute the shape of the
  * quantization parameter matrix [meta_rows, meta_cols]
-*/
+ */
 template <typename T, int qbits>
 void
 MlasBlockwiseQuantMetaShape(
-    int block_size, bool columnwise, int rows, int columns, int& meta_rows, int& meta_cols);
+    int block_size, bool columnwise, int rows, int columns, int& meta_rows, int& meta_cols
+);
 
 /**
  * @brief For quantization type <T, block_size, columnwise>, and
@@ -237,11 +207,12 @@
  * @param columns
  * @param q_rows
 * @param q_cols
-*/
+ */
 template <typename T, int qbits>
 void
 MlasBlockwiseQuantizedShape(
-    int block_size, bool columnwise, int rows, int columns, int& q_rows, int& q_cols);
+    int block_size, bool columnwise, int rows, int columns, int& q_rows, int& q_cols
+);
 
 /**
  * @brief Compute the sizes of the quantized data and quantization parameter buffers.
@@ -269,7 +240,6 @@ MlasBlockwiseQuantizedBufferSizes(
     size_t* q_zero_point_size_in_bytes
 );
-
 /**
  * @brief Blockwise 4 bits quantization, resulting elements and quantization
  * parameters (scales, zero points) are packed into separate matrices
@@ -295,16 +265,7 @@
  */
 template <typename ElementT, int qbits>
 void
-MlasQuantizeBlockwise(uint8_t* dst,
-                      ElementT* scales,
-                      uint8_t* zero_points,
-                      const ElementT* src,
-                      int block_size,
-                      bool columnwise,
-                      int rows,
-                      int columns,
-                      int leading_dimension,
-                      MLAS_THREADPOOL* thread_pool);
+MlasQuantizeBlockwise(uint8_t* dst, ElementT* scales, uint8_t* zero_points, const ElementT* src, int block_size, bool columnwise, int rows, int columns, int leading_dimension, MLAS_THREADPOOL* thread_pool);
 
 /**
  * @brief Blockwise 4 bits dequantization, quantized elements and quantization
@@ -329,15 +290,7 @@ MlasQuantizeBlockwise(uint8_t* dst,
  */
 template <typename ElementT, int qbits>
 void
-MlasDequantizeBlockwise(ElementT* dst,
-                        const uint8_t* src,
-                        const ElementT* scales,
-                        const uint8_t* zero_points,
-                        int block_size,
-                        bool columnwise,
-                        int rows,
-                        int columns,
-                        MLAS_THREADPOOL* thread_pool);
+MlasDequantizeBlockwise(ElementT* dst, const uint8_t* src, const ElementT* scales, const uint8_t* zero_points, int block_size, bool columnwise, int rows, int columns, MLAS_THREADPOOL* thread_pool);
 
 /**
  * @brief Check if the parameter combination is supported
@@ -353,7 +306,8 @@ MlasDequantizeBlockwise(ElementT* dst,
  */
 bool MLASCALL
 MlasNBitsGemmPackBSupport(
-    size_t N, size_t K, size_t block_size, int nbits, bool is_asym, MLAS_COMPUTE_TYPE comp_type);
+    size_t N, size_t K, size_t block_size, int nbits, bool is_asym, MLAS_COMPUTE_TYPE comp_type
+);
 
 /**
  * @brief Compute the byte size of the parameter combination
@@ -369,7 +323,8 @@ MlasNBitsGemmPackBSupport(
  */
 size_t MLASCALL
 MlasNBitsGemmPackBSize(
-    size_t N, size_t K, size_t block_size, int nbits, bool is_asym, MLAS_COMPUTE_TYPE comp_type);
+    size_t N, size_t K, size_t block_size, int nbits, bool is_asym, MLAS_COMPUTE_TYPE comp_type
+);
 
 /**
  * @brief Prepack tensor data from MatMulNBits operator
@@ -390,22 +345,10 @@ MlasNBitsGemmPackBSize(
  * @param thread_pool
  */
 void MLASCALL
-MlasNBitsGemmPackB(void* PackedBuf,
-                   const uint8_t* QData,
-                   const float* Scale,
-                   const uint8_t* Zp,
-                   size_t N,
-                   size_t K,
-                   size_t ldb,
-                   size_t block_size,
-                   int nbits,
-                   bool is_asym,
-                   bool last_call,
-                   MLAS_COMPUTE_TYPE comp_type,
-                   MLAS_THREADPOOL* thread_pool);
+MlasNBitsGemmPackB(void* PackedBuf, const uint8_t* QData, const float* Scale, const uint8_t* Zp, size_t N, size_t K, size_t ldb, size_t block_size, int nbits, bool is_asym, bool last_call, MLAS_COMPUTE_TYPE comp_type, MLAS_THREADPOOL* thread_pool);
 /**
  * @brief Unpack and dequantize to fp32
- * 
+ *
  * @param FpData unpakced float32 data
  * @param PackedBuf int4 quantized and packed data
  * @param N the number of columns of matrix B.
@@ -414,12 +357,7 @@ MlasNBitsGemmPackB(void* PackedBuf,
  * @param thread_pool
  */
 void MLASCALL
-MlasNBitsGemmUnPackB(float* FpData,
-                     const void* PackedBuf,
-                     size_t N,
-                     size_t K,
-                     size_t ldb,
-                     MLAS_THREADPOOL* thread_pool);
+MlasNBitsGemmUnPackB(float* FpData, const void* PackedBuf, size_t N, size_t K, size_t ldb, MLAS_THREADPOOL* thread_pool);
 
 /**
  * @brief Batched GEMM: C = A * B
@@ -436,10 +374,4 @@ MlasNBitsGemmUnPackB(float* FpData,
  * @return
  */
 void MLASCALL
-MlasNBitsGemmBatch(const size_t M,
-                   const size_t N,
-                   const size_t K,
-                   const size_t BatchN,
-                   const MLAS_Q4_GEMM_DATA_PARAMS* DataParams,
-                   int8_t* WorkSpace,
-                   MLAS_THREADPOOL* ThreadPool = nullptr);
\ No newline at end of file
+MlasNBitsGemmBatch(const size_t M, const size_t N, const size_t K, const size_t BatchN, const MLAS_Q4_GEMM_DATA_PARAMS* DataParams, int8_t* WorkSpace, MLAS_THREADPOOL* ThreadPool = nullptr);
\ No newline at end of file
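
Usage sketch (editor's note, not part of the patch): the snippet below shows how the blockwise 4-bit entry points declared in this header fit together, namely query the packed-data and quantization-parameter shapes, quantize, then dequantize. The matrix shape, block_size, the <float, 4> instantiation, and passing nullptr for zero_points and thread_pool (assumed here to select symmetric quantization and single-threaded execution) are illustrative assumptions, not details taken from the patch.

#include <cstdint>
#include <vector>

#include "mlas_q4.h"

void BlockwiseQ4RoundTrip()
{
    // Illustrative shapes; any [rows, columns] fp32 matrix works.
    const int rows = 1024, columns = 4096, block_size = 32;
    const bool columnwise = true;  // quantization blocks run along each column

    std::vector<float> src(static_cast<size_t>(rows) * columns, 0.5f);

    // Ask MLAS for the packed-data and quantization-parameter shapes.
    int q_rows, q_cols, meta_rows, meta_cols;
    MlasBlockwiseQuantizedShape<float, 4>(block_size, columnwise, rows, columns, q_rows, q_cols);
    MlasBlockwiseQuantMetaShape<float, 4>(block_size, columnwise, rows, columns, meta_rows, meta_cols);

    std::vector<uint8_t> q_data(static_cast<size_t>(q_rows) * q_cols);
    std::vector<float> scales(static_cast<size_t>(meta_rows) * meta_cols);

    // fp32 -> packed int4 plus per-block scales; a null zero_points pointer is
    // assumed to mean symmetric quantization.
    MlasQuantizeBlockwise<float, 4>(q_data.data(), scales.data(), /*zero_points=*/nullptr,
                                    src.data(), block_size, columnwise, rows, columns,
                                    /*leading_dimension=*/columns, /*thread_pool=*/nullptr);

    // Packed int4 -> fp32; values come back rounded to the int4 grid.
    std::vector<float> dst(static_cast<size_t>(rows) * columns);
    MlasDequantizeBlockwise<float, 4>(dst.data(), q_data.data(), scales.data(),
                                      /*zero_points=*/nullptr, block_size, columnwise,
                                      rows, columns, /*thread_pool=*/nullptr);
}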