diff --git a/onnxruntime/core/mlas/inc/mlas_q4.h b/onnxruntime/core/mlas/inc/mlas_q4.h index a630c098d0f77..c4fe1bb56c70c 100644 --- a/onnxruntime/core/mlas/inc/mlas_q4.h +++ b/onnxruntime/core/mlas/inc/mlas_q4.h @@ -20,21 +20,20 @@ Module Name: #pragma once -#include - -#include - #include "mlas.h" #include "mlas_gemm_postprocessor.h" +#include +#include + /** * @brief Define types of block quantization */ typedef enum { - BlkQ4Sym = 0, /*!< int4 Symmetric Block Quantization, zero_point = 0 */ - BlkQ4Zp8 = 1, /*!< int4 Block Quantization, zero_point is int8 type */ - BlkQ4Sym64 = 2, /*!< int4 Symmetric Block Quantization, 64 values per block*/ - BlkQ4Sym128 = 4 /*!< int4 Symmetric Block Quantization, 128 values per block*/ + BlkQ4Sym = 0, /*!< int4 Symmetric Block Quantization, zero_point = 0 */ + BlkQ4Zp8 = 1, /*!< int4 Block Quantization, zero_point is int8 type */ + BlkQ4Sym64 = 2, /*!< int4 Symmetric Block Quantization, 64 values per block*/ + BlkQ4Sym128 = 4 /*!< int4 Symmetric Block Quantization, 128 values per block*/ } MLAS_BLK_QUANT_TYPE; /** @@ -44,9 +43,14 @@ typedef enum { * @param N the number of columns of matrix B. * @param K the number of rows of matrix B. * @return size of the packing buffer, 0 if the operation is not yet supported. - */ -size_t MLASCALL -MlasQ4GemmPackBSize(MLAS_BLK_QUANT_TYPE QType, size_t N, size_t K); +*/ +size_t +MLASCALL +MlasQ4GemmPackBSize( + MLAS_BLK_QUANT_TYPE QType, + size_t N, + size_t K + ); /** * @brief Prepack and Quantize fp32 weight tensor to int4 blocks @@ -57,9 +61,18 @@ MlasQ4GemmPackBSize(MLAS_BLK_QUANT_TYPE QType, size_t N, size_t K); * @param N the number of columns of matrix B. * @param K the number of rows of matrix B. * @param ldb leading dimension of B - */ -void MLASCALL -MlasQ4GemmPackB(MLAS_BLK_QUANT_TYPE QType, void* PackedBuf, const float* FpData, size_t N, size_t K, size_t ldb); +*/ +void +MLASCALL +MlasQ4GemmPackB( + MLAS_BLK_QUANT_TYPE QType, + void* PackedBuf, + const float* FpData, + size_t N, + size_t K, + size_t ldb + ); + /** * @brief Unpack and dequantize from int4 to fp32, reverse operation of @@ -71,8 +84,17 @@ MlasQ4GemmPackB(MLAS_BLK_QUANT_TYPE QType, void* PackedBuf, const float* FpData, * @param K the number of rows of matrix B. * @param ldb leading dimension of B */ -void MLASCALL -MlasQ4GemmUnPackB(MLAS_BLK_QUANT_TYPE QType, float* FpData, const void* PackedBuf, size_t N, size_t K, size_t ldb); +void +MLASCALL +MlasQ4GemmUnPackB( + MLAS_BLK_QUANT_TYPE QType, + float* FpData, + const void* PackedBuf, + size_t N, + size_t K, + size_t ldb + ); + /** * @brief Data parameters for Q4 GEMM routine @@ -82,12 +104,12 @@ MlasQ4GemmUnPackB(MLAS_BLK_QUANT_TYPE QType, float* FpData, const void* PackedBu * All except C are [in] parameters */ struct MLAS_Q4_GEMM_DATA_PARAMS { - const float* A = nullptr; /**< address of A (float32 matrix)*/ - const void* B = nullptr; /**< address of B (quantized and packed int4 blob)*/ - const float* Bias = nullptr; /**< address of Bias, vector size N */ - float* C = nullptr; /**< address of result matrix */ - size_t lda = 0; /**< leading dimension of A */ - size_t ldc = 0; /**< leading dimension of C*/ + const float* A = nullptr; /**< address of A (float32 matrix)*/ + const void* B = nullptr; /**< address of B (quantized and packed int4 blob)*/ + const float* Bias = nullptr; /**< address of Bias, vector size N */ + float* C = nullptr; /**< address of result matrix */ + size_t lda = 0; /**< leading dimension of A */ + size_t ldc = 0; /**< leading dimension of C*/ const MLAS_GEMM_POSTPROCESSOR* OutputProcessor = nullptr; }; @@ -114,7 +136,8 @@ MlasQ4GemmBatch( const size_t BatchN, const MLAS_Q4_GEMM_DATA_PARAMS* DataParams, MLAS_THREADPOOL* ThreadPool = nullptr -); + ); + /** * @brief Calculate the buffer size needed for int8 block quantize @@ -122,9 +145,9 @@ MlasQ4GemmBatch( * @param[in] M Number of rows of the input matrix * @param[in] K Number of columns of the input matrix * @return buffer size (in bytes) needed, 0 if not yet supported on current hardware - */ - -size_t MLASCALL +*/ +size_t +MLASCALL MlasQ80BlkQuantSize(MLAS_BLK_QUANT_TYPE QType, size_t M, size_t K); /** @@ -137,11 +160,19 @@ MlasQ80BlkQuantSize(MLAS_BLK_QUANT_TYPE QType, size_t M, size_t K); * @param K Number of columns of the input matrix * @param lda leading dimension of the input matrix * @param ThreadPool - */ -void MLASCALL +*/ +void +MLASCALL MlasQ80BlkQuant( - MLAS_BLK_QUANT_TYPE QType, void* Qblob, const float* A, size_t M, size_t K, size_t lda, MLAS_THREADPOOL* ThreadPool -); + MLAS_BLK_QUANT_TYPE QType, + void* Qblob, + const float* A, + size_t M, + size_t K, + size_t lda, + MLAS_THREADPOOL* ThreadPool + ); + /** * @brief Data parameters for Q8Q4 GEMM routine @@ -182,7 +213,8 @@ MlasQ8Q4GemmBatch( const size_t BatchN, const MLAS_Q8Q4_GEMM_DATA_PARAMS* DataParams, MLAS_THREADPOOL* ThreadPool -); + ); + //////////////////////////////////////////////////////////// // Blockwise quantization and dequantization where quantization @@ -193,10 +225,17 @@ MlasQ8Q4GemmBatch( * @brief For quantization type , and * matrix shape [rows, columns], compute the shape of the * quantization parameter matrix [meta_rows, meta_cols] - */ +*/ template void -MlasBlockwiseQuantMetaShape(int block_size, bool columnwise, int rows, int columns, int& meta_rows, int& meta_cols); +MlasBlockwiseQuantMetaShape( + int block_size, + bool columnwise, + int rows, + int columns, + int& meta_rows, + int& meta_cols + ); /** * @brief For quantization type , and @@ -212,10 +251,17 @@ MlasBlockwiseQuantMetaShape(int block_size, bool columnwise, int rows, int colum * @param columns * @param q_rows * @param q_cols - */ +*/ template void -MlasBlockwiseQuantizedShape(int block_size, bool columnwise, int rows, int columns, int& q_rows, int& q_cols); +MlasBlockwiseQuantizedShape( + int block_size, + bool columnwise, + int rows, + int columns, + int& q_rows, + int& q_cols + ); /** * @brief Compute the sizes of the quantized data and quantization parameter buffers. @@ -243,6 +289,7 @@ MlasBlockwiseQuantizedBufferSizes( size_t* q_zero_point_size_in_bytes ); + /** * @brief Blockwise 4 bits quantization, resulting elements and quantization * parameters (scales, zero points) are packed into separate matrices @@ -255,17 +302,14 @@ MlasBlockwiseQuantizedBufferSizes( * @param dst points to the quantized matrix, shape [rows, columns] column major * @param scales points to the scales matrix, column major * @param zero_points points to the zero_points matrix, column major - * @param src points to the floating point matrix, to be quantized, row major - * shape [rows, columns] - * @param block_size size of the block to quantize, elements from the same block share - * the same scale and zero point - * @param columnwise true when elements in a block are from the same column, false when - * elements in a block are from the same row + * @param src points to the floating point matrix, to be quantized, row major shape [rows, columns] + * @param block_size size of the block to quantize, elements from the same block share the same scale and zero point + * @param columnwise true when elements in a block are from the same column, false when elements in a block are from the same row * @param rows * @param columns * @param leading_dimension * @param thread_pool - */ +*/ template void MlasQuantizeBlockwise( @@ -279,7 +323,8 @@ MlasQuantizeBlockwise( int columns, int leading_dimension, MLAS_THREADPOOL* thread_pool -); + ); + /** * @brief Blockwise 4 bits dequantization, quantized elements and quantization @@ -294,14 +339,12 @@ MlasQuantizeBlockwise( * @param src points to quantized matrix, column major * @param scales points to quantization scales, column major * @param zero_points points to quantization zero points, column major - * @param block_size size of the block to quantize, elements from the same block share the same - * scale and zero point - * @param columnwise true when elements in a block are from the same column, false when elements - * in a block are from the same row + * @param block_size size of the block to quantize, elements from the same block share the same scale and zero point + * @param columnwise true when elements in a block are from the same column, false when elements in a block are from the same row * @param rows * @param columns * @param thread_pool - */ +*/ template void MlasDequantizeBlockwise( @@ -314,4 +357,4 @@ MlasDequantizeBlockwise( int rows, int columns, MLAS_THREADPOOL* thread_pool -); + ); \ No newline at end of file diff --git a/onnxruntime/core/mlas/lib/q4_dq.cpp b/onnxruntime/core/mlas/lib/q4_dq.cpp index f10a258929518..aa560638decee 100644 --- a/onnxruntime/core/mlas/lib/q4_dq.cpp +++ b/onnxruntime/core/mlas/lib/q4_dq.cpp @@ -17,17 +17,20 @@ Module Name: language models. --*/ + #include "q4common.h" -template -constexpr size_t +template +constexpr +size_t BlkQ4BufSize(size_t N, size_t K) { const size_t KBlocks = MlasDivRoundup(K, T::BlkLen); return N * KBlocks * T::BlobSize; } -size_t MLASCALL +size_t +MLASCALL MlasQ4GemmPackBSize(MLAS_BLK_QUANT_TYPE QType, size_t N, size_t K) { if (GetMlasPlatform().FpQ4GemmDispatch == nullptr) { @@ -46,18 +49,20 @@ MlasQ4GemmPackBSize(MLAS_BLK_QUANT_TYPE QType, size_t N, size_t K) } } -template -MLAS_FORCEINLINE void + +template +MLAS_FORCEINLINE +void MlasQ4GemmPackBImpl(void* PackedBuf, const float* FpData, size_t N, size_t K, size_t ldb) { auto* dst_ptr = reinterpret_cast(PackedBuf); - for (size_t n = 0; n < N; n++) { - const float* src = FpData; // starting from top of the column + for (size_t n = 0; n < N; n ++) { + const float* src = FpData; // starting from top of the column for (size_t k = 0; k < K; k += T::BlkLen) { size_t klen = std::min(size_t(T::BlkLen), K - k); - float amax = 0.0f; // abs(max) + float amax = 0.0f; // abs(max) float max = 0.0f; for (size_t l = 0; l < klen; l++) { @@ -93,18 +98,20 @@ MlasQ4GemmPackBImpl(void* PackedBuf, const float* FpData, size_t N, size_t K, si src += ldb * klen; } - FpData++; // move to next column + FpData++; // move to next column } } -template <> -MLAS_FORCEINLINE void -MlasQ4GemmPackBImpl(void* PackedBuf, const float* FpData, size_t N, size_t K, size_t ldb) +template<> +MLAS_FORCEINLINE +void +MlasQ4GemmPackBImpl( + void* PackedBuf, const float* FpData, size_t N, size_t K, size_t ldb) { auto* dst_ptr = reinterpret_cast(PackedBuf); for (size_t n = 0; n < N; n++) { - const float* src = FpData; // starting from top of the column + const float* src = FpData; // starting from top of the column for (size_t k = 0; k < K; k += MLAS_Q4TYPE_BLK1::BlkLen) { size_t klen = std::min(MLAS_Q4TYPE_BLK1::BlkLen, K - k); @@ -142,11 +149,13 @@ MlasQ4GemmPackBImpl(void* PackedBuf, const float* FpData, size size_t kklen = std::min((size_t)32, klen - kk); for (size_t l = 0; l < 16; l++) { const float v0 = l < kklen ? src[ldb * (kk + l)] : 0; - const uint8_t vi0 = (uint8_t)std::min(15.0f, std::max(0.0f, roundf(v0 * reciprocal_scale + zp))); + const uint8_t vi0 = (uint8_t)std::min( + 15.0f, std::max(0.0f, roundf(v0 * reciprocal_scale + zp))); const size_t l1 = l + 16; const float v1 = (l1 < kklen) ? src[ldb * (kk + l1)] : 0; - const uint8_t vi1 = (uint8_t)std::min(15.0f, std::max(0.0f, roundf(v1 * reciprocal_scale + zp))); + const uint8_t vi1 = (uint8_t)std::min( + 15.0f, std::max(0.0f, roundf(v1 * reciprocal_scale + zp))); data[l] = vi0 | (vi1 << 4); } @@ -156,12 +165,20 @@ MlasQ4GemmPackBImpl(void* PackedBuf, const float* FpData, size dst_ptr += MLAS_Q4TYPE_BLK1::BlobSize; src += ldb * klen; } - FpData++; // move to next column + FpData++; // move to next column } } -void MLASCALL -MlasQ4GemmPackB(MLAS_BLK_QUANT_TYPE QType, void* PackedBuf, const float* FpData, size_t N, size_t K, size_t ldb) +void +MLASCALL +MlasQ4GemmPackB( + MLAS_BLK_QUANT_TYPE QType, + void* PackedBuf, + const float* FpData, + size_t N, + size_t K, + size_t ldb + ) { switch (QType) { case BlkQ4Sym: @@ -175,8 +192,9 @@ MlasQ4GemmPackB(MLAS_BLK_QUANT_TYPE QType, void* PackedBuf, const float* FpData, } } -template -MLAS_FORCEINLINE void +template +MLAS_FORCEINLINE +void MlasQ4GemmUnPackBImpl(float* FpData, const void* PackedBuf, size_t N, size_t K, size_t ldb) { const auto* src = reinterpret_cast(PackedBuf); @@ -213,9 +231,11 @@ MlasQ4GemmUnPackBImpl(float* FpData, const void* PackedBuf, size_t N, size_t K, } } -template <> -MLAS_FORCEINLINE void -MlasQ4GemmUnPackBImpl(float* FpData, const void* PackedBuf, size_t N, size_t K, size_t ldb) +template<> +MLAS_FORCEINLINE +void +MlasQ4GemmUnPackBImpl( + float* FpData, const void* PackedBuf, size_t N, size_t K, size_t ldb) { const auto* src = reinterpret_cast(PackedBuf); for (size_t n = 0; n < N; n++) { @@ -252,8 +272,16 @@ MlasQ4GemmUnPackBImpl(float* FpData, const void* PackedBuf, si } } -void MLASCALL -MlasQ4GemmUnPackB(MLAS_BLK_QUANT_TYPE QType, float* FpData, const void* PackedBuf, size_t N, size_t K, size_t ldb) +void +MLASCALL +MlasQ4GemmUnPackB( + MLAS_BLK_QUANT_TYPE QType, + float* FpData, + const void* PackedBuf, + size_t N, + size_t K, + size_t ldb + ) { switch (QType) { case BlkQ4Sym: @@ -267,11 +295,14 @@ MlasQ4GemmUnPackB(MLAS_BLK_QUANT_TYPE QType, float* FpData, const void* PackedBu } } + + /*************************************************************** * The quantization format that pack data and quantization * parameters into separate buffers. */ + template < int Row_, ///< rows of a matrix int Column_ ///< columns of a matrix @@ -282,6 +313,7 @@ struct Shape2D { static int const kCount = Row_ * Column_; ///< total number of elements in a matrix }; + template struct BitsTraits { static_assert(qbits <= 8, "Only BitsTraits are for small number of bits!"); @@ -296,6 +328,7 @@ struct BitsTraits { static_assert(kPackSize != 0, "Packing to whole bytes not supported for this qbits!"); }; + /** * @brief Rectify min/max from a set of weights, and convert to scale and zero point * for quantization @@ -307,7 +340,8 @@ struct BitsTraits { * @param[out] zp */ template -MLAS_FORCEINLINE void +MLAS_FORCEINLINE +void range2scalezp(float min, float max, ScaleT& scale, uint8_t& zp) { constexpr int zp_max = BitsTraits::kMax; @@ -334,7 +368,8 @@ range2scalezp(float min, float max, ScaleT& scale, uint8_t& zp) } template -MLAS_FORCEINLINE void +MLAS_FORCEINLINE +void range2scale(float min, float max, ScaleT& scale) { constexpr int mid_v = BitsTraits::kMid; @@ -345,6 +380,7 @@ range2scale(float min, float max, ScaleT& scale) scale = ScaleT(max / mid_fp); }; + /** * @brief Blockwise quantization methods * @tparam ElementT source data type, e.g. fp32/fp16 @@ -353,7 +389,11 @@ range2scale(float min, float max, ScaleT& scale) * @tparam Columnwise true: elements in a block come from one single column * false: elements in a block come from one single row */ -template +template < + typename ElementT, + int32_t block_size, + int32_t qbits, + bool Columnwise> struct BlockwiseQuantizer { // To support other qbits, need to add bit packing code for // storing to dst and zero points @@ -362,14 +402,17 @@ struct BlockwiseQuantizer { using QuantBlk = std::conditional_t, Shape2D<1, block_size>>; using ThreadBlk = Shape2D::kPackSize, QuantBlk::kColumn>; - static MLAS_FORCEINLINE void quantizeMetaShape(int rows, int columns, int& meta_rows, int& meta_cols) + static + MLAS_FORCEINLINE + void quantizeMetaShape(int rows, int columns, int& meta_rows, int& meta_cols) { meta_rows = (rows + QuantBlk::kRow - 1) / QuantBlk::kRow; meta_cols = (columns + QuantBlk::kColumn - 1) / QuantBlk::kColumn; } - static MLAS_FORCEINLINE void quantizedShape(int rows, int columns, int& q_rows, int& q_cols) - { + static + MLAS_FORCEINLINE + void quantizedShape(int rows, int columns, int& q_rows, int& q_cols) { int meta_rows; int meta_cols; quantizeMetaShape(rows, columns, meta_rows, meta_cols); @@ -401,8 +444,7 @@ struct BlockwiseQuantizer { * @brief Quantized a Matrix shape [rows, columns], resulting quantized * and packed data are stored in column major (transposed) * @param[out] dst pointer to the quantized weights, column major: [columns, rows] - * @param[out] scale pointer to the scales, column major: [columns/QuantBlk::kColumn, - * rows/QuantBlk::kRow] + * @param[out] scale pointer to the scales, column major: [columns/QuantBlk::kColumn, rows/QuantBlk::kRow] * @param[out] zero_points pointer to the zero points, same shape as scale * @param[in] src pointer to the source matrix, row major: [rows, columns] * @param rows @@ -417,8 +459,7 @@ struct BlockwiseQuantizer { int32_t rows, int32_t columns, int32_t leadingDimension, - MLAS_THREADPOOL* thread_pool - ) + MLAS_THREADPOOL* thread_pool) { // Thread partitioning const auto thrd_row_blks = (rows + ThreadBlk::kRow - 1) / ThreadBlk::kRow; @@ -430,84 +471,88 @@ struct BlockwiseQuantizer { int q_rows, q_cols; quantizedShape(rows, columns, q_rows, q_cols); - MlasTryBatchParallel(thread_pool, total_thrd_blks, [&](ptrdiff_t block_idx) { - uint8_t zp_bytes[BitsTraits::kPackSize]; - std::fill_n(zp_bytes, BitsTraits::kPackSize, (uint8_t)8); - - const int32_t r_blk_idx = static_cast(block_idx / thrd_col_blks); - const int32_t c_blk_idx = static_cast(block_idx % thrd_col_blks); - - const int32_t r = r_blk_idx * ThreadBlk::kRow; - const int32_t c = c_blk_idx * ThreadBlk::kColumn; - - const int32_t r_end = std::min(r + ThreadBlk::kRow, rows); - const int32_t c_end = std::min(c + ThreadBlk::kColumn, columns); - - const int meta_row = r / QuantBlk::kRow; - const int meta_col = c / QuantBlk::kColumn; - - // compute scale and zero point - for (int kpack = 0; kpack < BitsTraits::kPackSize; kpack++) { - // scan a single block to extract range [min, max] - float min = std::numeric_limits::max(); - float max = -min; - const int row_start = r + kpack * QuantBlk::kRow; - const int row_end = std::min(row_start + QuantBlk::kRow, r_end); - for (int i = row_start; i < row_end; ++i) { - for (int j = c; j < c_end; ++j) { - const float v = static_cast(src[i * leadingDimension + j]); - if (v < min) min = v; - if (v > max) max = v; + MlasTryBatchParallel( + thread_pool, total_thrd_blks, + [&](ptrdiff_t block_idx) { + uint8_t zp_bytes[BitsTraits::kPackSize]; + std::fill_n(zp_bytes, BitsTraits::kPackSize, (uint8_t)8); + + const int32_t r_blk_idx = static_cast(block_idx / thrd_col_blks); + const int32_t c_blk_idx = static_cast(block_idx % thrd_col_blks); + + const int32_t r = r_blk_idx * ThreadBlk::kRow; + const int32_t c = c_blk_idx * ThreadBlk::kColumn; + + const int32_t r_end = std::min(r + ThreadBlk::kRow, rows); + const int32_t c_end = std::min(c + ThreadBlk::kColumn, columns); + + const int meta_row = r / QuantBlk::kRow; + const int meta_col = c / QuantBlk::kColumn; + + // compute scale and zero point + for (int kpack = 0; kpack < BitsTraits::kPackSize; kpack++) { + + // scan a single block to extract range [min, max] + float min = std::numeric_limits::max(); + float max = -min; + const int row_start = r + kpack * QuantBlk::kRow; + const int row_end = std::min(row_start + QuantBlk::kRow, r_end); + for (int i = row_start; i < row_end; ++i) { + for (int j = c; j < c_end; ++j) { + const float v = static_cast(src[i * leadingDimension + j]); + if (v < min) min = v; + if (v > max) max = v; + } } - } - // store scale and zero point at quant parameter matrix position - if (row_start < row_end) { - const int32_t meta_idx = meta_col * row_blks + meta_row + kpack; - if (zero_points == nullptr) { - range2scale(min, max, scales[meta_idx]); - } else { - range2scalezp(min, max, scales[meta_idx], zp_bytes[kpack]); + // store scale and zero point at quant parameter matrix position + if (row_start < row_end) { + const int32_t meta_idx = meta_col * row_blks + meta_row + kpack; + if (zero_points == nullptr) { + range2scale(min, max, scales[meta_idx]); + } else { + range2scalezp(min, max, scales[meta_idx], zp_bytes[kpack]); + } } } - } - // !! 4b specific code as we need to pack 2 4b numbers into one byte - if (zero_points != nullptr) { - const int32_t meta_idx = meta_col * ((row_blks + 1) / 2) + meta_row / 2; - zero_points[meta_idx] = (zp_bytes[0] & 0xf) | (zp_bytes[1] << 4); - } + // !! 4b specific code as we need to pack 2 4b numbers into one byte + if (zero_points != nullptr) { + const int32_t meta_idx = meta_col * ((row_blks + 1) / 2) + meta_row / 2; + zero_points[meta_idx] = (zp_bytes[0] & 0xf) | (zp_bytes[1] << 4); + } - for (int32_t j = c; j < c_end; ++j) { - const int32_t meta_c = j / QuantBlk::kColumn; - for (int32_t i = r; i < r_end; i += 2) { - const int32_t meta_r = i / QuantBlk::kRow; - const float scale = static_cast(scales[meta_c * row_blks + meta_r]); - const float reciprocal_scale = scale ? 1.0f / scale : 0.0f; - const int8_t zp = zp_bytes[meta_r & 1]; - const int8_t zp1 = zp_bytes[((i + 1) / QuantBlk::kRow) & 1]; - - const float v0 = static_cast(src[i * leadingDimension + j]); - const uint8_t vi0 = - (uint8_t)std::clamp(roundf(v0 * reciprocal_scale + zp), 0.0f, BitsTraits::kMaxFp); - - uint8_t vi1 = (uint8_t)zp; - if (i + 1 < r_end) { - float reciprocal_scale1 = reciprocal_scale; - if constexpr (QuantBlk::kRow == 1) { - const float scale1 = static_cast(scales[meta_c * row_blks + meta_r + 1]); - reciprocal_scale1 = scale1 ? 1.0f / scale1 : 0.0f; + for (int32_t j = c; j < c_end; ++j) { + const int32_t meta_c = j / QuantBlk::kColumn; + for (int32_t i = r; i < r_end; i += 2) { + const int32_t meta_r = i / QuantBlk::kRow; + const float scale = static_cast(scales[meta_c * row_blks + meta_r]); + const float reciprocal_scale = scale ? 1.0f / scale : 0.0f; + const int8_t zp = zp_bytes[meta_r & 1]; + const int8_t zp1 = zp_bytes[((i + 1) / QuantBlk::kRow) & 1]; + + const float v0 = static_cast(src[i * leadingDimension + j]); + const uint8_t vi0 = (uint8_t)std::clamp(roundf(v0 * reciprocal_scale + zp), + 0.0f, BitsTraits::kMaxFp); + + uint8_t vi1 = (uint8_t)zp; + if (i + 1 < r_end) { + float reciprocal_scale1 = reciprocal_scale; + if constexpr (QuantBlk::kRow == 1) { + const float scale1 = + static_cast(scales[meta_c * row_blks + meta_r + 1]); + reciprocal_scale1 = scale1 ? 1.0f / scale1 : 0.0f; + } + const float v1 = static_cast(src[(i + 1) * leadingDimension + j]); + vi1 = (uint8_t)std::clamp(roundf(v1 * reciprocal_scale1 + zp1), 0.0f, + BitsTraits::kMaxFp); } - const float v1 = static_cast(src[(i + 1) * leadingDimension + j]); - vi1 = - (uint8_t)std::clamp(roundf(v1 * reciprocal_scale1 + zp1), 0.0f, BitsTraits::kMaxFp); - } - // !! 4b specific code - dst[j * q_rows + i / 2] = (vi0 & 0xf) | (vi1 << 4); + // !! 4b specific code + dst[j * q_rows + i / 2] = (vi0 & 0xf) | (vi1 << 4); + } } - } - }); + }); } /** @@ -528,8 +573,7 @@ struct BlockwiseQuantizer { const uint8_t* zero_points, int32_t rows, int32_t columns, - MLAS_THREADPOOL* thread_pool - ) + MLAS_THREADPOOL* thread_pool) { // Thread partitioning const auto thrd_row_blks = (rows + ThreadBlk::kRow - 1) / ThreadBlk::kRow; @@ -541,56 +585,70 @@ struct BlockwiseQuantizer { int q_rows, q_cols; quantizedShape(rows, columns, q_rows, q_cols); - MlasTryBatchParallel(thread_pool, total_thrd_blks, [&](ptrdiff_t block_idx) { - int32_t r_blk_idx = static_cast(block_idx / thrd_col_blks); - int32_t c_blk_idx = static_cast(block_idx % thrd_col_blks); - - int32_t r = r_blk_idx * ThreadBlk::kRow; - int32_t c = c_blk_idx * ThreadBlk::kColumn; - - int32_t r_end = std::min(r + ThreadBlk::kRow, rows); - int32_t c_end = std::min(c + ThreadBlk::kColumn, columns); - - for (int32_t j = c; j < c_end; ++j) { - const int32_t meta_col = j / QuantBlk::kColumn; + MlasTryBatchParallel( + thread_pool, total_thrd_blks, + [&](ptrdiff_t block_idx) { + int32_t r_blk_idx = static_cast(block_idx / thrd_col_blks); + int32_t c_blk_idx = static_cast(block_idx % thrd_col_blks); - // !! 4b specific code - // the whole loop is 4b specific due to sub 8 bit packing - // and unpacking. We can potentially make this qbits generic - // by wraping the packing/unpacking code like cutlass::Array - for (int32_t i = r; i < r_end; i += 2) { - const int32_t meta_row = i / QuantBlk::kRow; + int32_t r = r_blk_idx * ThreadBlk::kRow; + int32_t c = c_blk_idx * ThreadBlk::kColumn; - const float scale0 = static_cast(scales[meta_col * row_blks + meta_row]); + int32_t r_end = std::min(r + ThreadBlk::kRow, rows); + int32_t c_end = std::min(c + ThreadBlk::kColumn, columns); - const int zp_pair = - (zero_points == nullptr) ? 0x88 : zero_points[meta_col * ((row_blks + 1) / 2) + meta_row / 2]; - const int zp0 = (meta_row & 1) ? (zp_pair >> 4) : (zp_pair & 0xf); + for (int32_t j = c; j < c_end; ++j) { + const int32_t meta_col = j / QuantBlk::kColumn; - const uint8_t vi0 = weights[j * q_rows + i / 2] & 0xf; - const float v0 = (static_cast(vi0) - zp0) * scale0; - - dst[j * rows + i] = static_cast(v0); - if ((i + 1) < r_end) { - float scale1 = scale0; - int zp1 = zp0; - if constexpr (QuantBlk::kRow == 1) { - scale1 = static_cast(scales[meta_col * row_blks + meta_row + 1]); - zp1 = (zp_pair >> 4) & 0xf; + // !! 4b specific code + // the whole loop is 4b specific due to sub 8 bit packing + // and unpacking. We can potentially make this qbits generic + // by wraping the packing/unpacking code like cutlass::Array + for (int32_t i = r; i < r_end; i += 2) { + const int32_t meta_row = i / QuantBlk::kRow; + + const float scale0 = + static_cast(scales[meta_col * row_blks + meta_row]); + + const int zp_pair = + (zero_points == nullptr) + ? 0x88 + : zero_points[meta_col * ((row_blks + 1) / 2) + meta_row / 2]; + const int zp0 = (meta_row & 1) ? (zp_pair >> 4) : (zp_pair & 0xf); + + const uint8_t vi0 = weights[j * q_rows + i / 2] & 0xf; + const float v0 = (static_cast(vi0) - zp0) * scale0; + + dst[j * rows + i] = static_cast(v0); + if ((i + 1) < r_end) { + float scale1 = scale0; + int zp1 = zp0; + if constexpr (QuantBlk::kRow == 1) { + scale1 = + static_cast(scales[meta_col * row_blks + meta_row + 1]); + zp1 = (zp_pair >> 4) & 0xf; + } + const uint8_t vi1 = weights[j * q_rows + i / 2] >> 4; + const float v1 = (static_cast(vi1) - zp1) * scale1; + dst[j * rows + (i + 1)] = static_cast(v1); } - const uint8_t vi1 = weights[j * q_rows + i / 2] >> 4; - const float v1 = (static_cast(vi1) - zp1) * scale1; - dst[j * rows + (i + 1)] = static_cast(v1); } } - } - }); + }); } }; + template void -MlasBlockwiseQuantMetaShape(int block_size, bool columnwise, int rows, int columns, int& meta_rows, int& meta_cols) +MlasBlockwiseQuantMetaShape( + int block_size, + bool columnwise, + int rows, + int columns, + int& meta_rows, + int& meta_cols + ) { switch (block_size) { case 16: { @@ -605,31 +663,38 @@ MlasBlockwiseQuantMetaShape(int block_size, bool columnwise, int rows, int colum if (columnwise) { BlockwiseQuantizer::quantizeMetaShape(rows, columns, meta_rows, meta_cols); } else { - BlockwiseQuantizer::quantizeMetaShape(rows, columns, meta_rows, meta_cols); + BlockwiseQuantizer::quantizeMetaShape( + rows, columns, meta_rows, meta_cols); } break; } case 64: { if (columnwise) { - BlockwiseQuantizer::quantizeMetaShape(rows, columns, meta_rows, meta_cols); + BlockwiseQuantizer::quantizeMetaShape(rows, columns, meta_rows, + meta_cols); } else { - BlockwiseQuantizer::quantizeMetaShape(rows, columns, meta_rows, meta_cols); + BlockwiseQuantizer::quantizeMetaShape(rows, columns, meta_rows, + meta_cols); } break; } case 128: { if (columnwise) { - BlockwiseQuantizer::quantizeMetaShape(rows, columns, meta_rows, meta_cols); + BlockwiseQuantizer::quantizeMetaShape(rows, columns, meta_rows, + meta_cols); } else { - BlockwiseQuantizer::quantizeMetaShape(rows, columns, meta_rows, meta_cols); + BlockwiseQuantizer::quantizeMetaShape(rows, columns, meta_rows, + meta_cols); } break; } case 256: { if (columnwise) { - BlockwiseQuantizer::quantizeMetaShape(rows, columns, meta_rows, meta_cols); + BlockwiseQuantizer::quantizeMetaShape(rows, columns, meta_rows, + meta_cols); } else { - BlockwiseQuantizer::quantizeMetaShape(rows, columns, meta_rows, meta_cols); + BlockwiseQuantizer::quantizeMetaShape(rows, columns, meta_rows, + meta_cols); } break; } @@ -640,9 +705,18 @@ MlasBlockwiseQuantMetaShape(int block_size, bool columnwise, int rows, int colum } } + + template void -MlasBlockwiseQuantizedShape(int block_size, bool columnwise, int rows, int columns, int& q_rows, int& q_cols) +MlasBlockwiseQuantizedShape( + int block_size, + bool columnwise, + int rows, + int columns, + int& q_rows, + int& q_cols + ) { switch (block_size) { case 16: { @@ -657,7 +731,8 @@ MlasBlockwiseQuantizedShape(int block_size, bool columnwise, int rows, int colum if (columnwise) { BlockwiseQuantizer::quantizedShape(rows, columns, q_rows, q_cols); } else { - BlockwiseQuantizer::quantizedShape(rows, columns, q_rows, q_cols); + BlockwiseQuantizer::quantizedShape( + rows, columns, q_rows, q_cols); } break; } @@ -692,13 +767,29 @@ MlasBlockwiseQuantizedShape(int block_size, bool columnwise, int rows, int colum } } -template void + +template +void MlasBlockwiseQuantMetaShape( - int block_size, bool columnwise, int rows, int columns, int& meta_rows, int& meta_cols -); + int block_size, + bool columnwise, + int rows, + int columns, + int& meta_rows, + int& meta_cols + ); + +template +void +MlasBlockwiseQuantizedShape( + int block_size, + bool columnwise, + int rows, + int columns, + int& q_rows, + int& q_cols + ); -template void -MlasBlockwiseQuantizedShape(int block_size, bool columnwise, int rows, int columns, int& q_rows, int& q_cols); void MLASCALL MlasBlockwiseQuantizedBufferSizes( @@ -786,6 +877,7 @@ MlasBlockwiseQuantizedBufferSizes( } } + template void MlasQuantizeBlockwise( @@ -799,66 +891,56 @@ MlasQuantizeBlockwise( int columns, int leading_dimension, MLAS_THREADPOOL* thread_pool -) + ) { switch (block_size) { case 16: if (columnwise) { BlockwiseQuantizer::quantizeAndTranspose( - dst, scales, zero_points, src, rows, columns, leading_dimension, thread_pool - ); + dst, scales, zero_points, src, rows, columns, leading_dimension, thread_pool); } else { BlockwiseQuantizer::quantizeAndTranspose( - dst, scales, zero_points, src, rows, columns, leading_dimension, thread_pool - ); + dst, scales, zero_points, src, rows, columns, leading_dimension, thread_pool); } break; case 32: if (columnwise) { BlockwiseQuantizer::quantizeAndTranspose( - dst, scales, zero_points, src, rows, columns, leading_dimension, thread_pool - ); + dst, scales, zero_points, src, rows, columns, leading_dimension, thread_pool); } else { BlockwiseQuantizer::quantizeAndTranspose( - dst, scales, zero_points, src, rows, columns, leading_dimension, thread_pool - ); + dst, scales, zero_points, src, rows, columns, leading_dimension, thread_pool); } break; case 64: if (columnwise) { BlockwiseQuantizer::quantizeAndTranspose( - dst, scales, zero_points, src, rows, columns, leading_dimension, thread_pool - ); + dst, scales, zero_points, src, rows, columns, leading_dimension, thread_pool); } else { BlockwiseQuantizer::quantizeAndTranspose( - dst, scales, zero_points, src, rows, columns, leading_dimension, thread_pool - ); + dst, scales, zero_points, src, rows, columns, leading_dimension, thread_pool); } break; case 128: if (columnwise) { BlockwiseQuantizer::quantizeAndTranspose( - dst, scales, zero_points, src, rows, columns, leading_dimension, thread_pool - ); + dst, scales, zero_points, src, rows, columns, leading_dimension, thread_pool); } else { BlockwiseQuantizer::quantizeAndTranspose( - dst, scales, zero_points, src, rows, columns, leading_dimension, thread_pool - ); + dst, scales, zero_points, src, rows, columns, leading_dimension, thread_pool); } break; case 256: if (columnwise) { BlockwiseQuantizer::quantizeAndTranspose( - dst, scales, zero_points, src, rows, columns, leading_dimension, thread_pool - ); + dst, scales, zero_points, src, rows, columns, leading_dimension, thread_pool); } else { BlockwiseQuantizer::quantizeAndTranspose( - dst, scales, zero_points, src, rows, columns, leading_dimension, thread_pool - ); + dst, scales, zero_points, src, rows, columns, leading_dimension, thread_pool); } break; @@ -868,7 +950,8 @@ MlasQuantizeBlockwise( } } -template void +template +void MlasQuantizeBlockwise( uint8_t* dst, float* scales, @@ -880,9 +963,10 @@ MlasQuantizeBlockwise( int columns, int leading_dimension, MLAS_THREADPOOL* thread_pool -); + ); -template void +template +void MlasQuantizeBlockwise( uint8_t* dst, MLAS_FP16* scales, @@ -894,7 +978,8 @@ MlasQuantizeBlockwise( int columns, int leading_dimension, MLAS_THREADPOOL* thread_pool -); + ); + template void @@ -908,62 +993,52 @@ MlasDequantizeBlockwise( int rows, int columns, MLAS_THREADPOOL* thread_pool -) + ) { switch (block_size) { case 16: if (columnwise) { - BlockwiseQuantizer::dequantize( - dst, src, scales, zero_points, rows, columns, thread_pool - ); + BlockwiseQuantizer::dequantize(dst, src, scales, zero_points, rows, + columns, thread_pool); } else { - BlockwiseQuantizer::dequantize( - dst, src, scales, zero_points, rows, columns, thread_pool - ); + BlockwiseQuantizer::dequantize(dst, src, scales, zero_points, rows, + columns, thread_pool); } break; case 32: if (columnwise) { - BlockwiseQuantizer::dequantize( - dst, src, scales, zero_points, rows, columns, thread_pool - ); + BlockwiseQuantizer::dequantize(dst, src, scales, zero_points, rows, + columns, thread_pool); } else { - BlockwiseQuantizer::dequantize( - dst, src, scales, zero_points, rows, columns, thread_pool - ); + BlockwiseQuantizer::dequantize(dst, src, scales, zero_points, rows, + columns, thread_pool); } break; case 64: if (columnwise) { - BlockwiseQuantizer::dequantize( - dst, src, scales, zero_points, rows, columns, thread_pool - ); + BlockwiseQuantizer::dequantize(dst, src, scales, zero_points, rows, + columns, thread_pool); } else { - BlockwiseQuantizer::dequantize( - dst, src, scales, zero_points, rows, columns, thread_pool - ); + BlockwiseQuantizer::dequantize(dst, src, scales, zero_points, rows, + columns, thread_pool); } break; case 128: if (columnwise) { - BlockwiseQuantizer::dequantize( - dst, src, scales, zero_points, rows, columns, thread_pool - ); + BlockwiseQuantizer::dequantize(dst, src, scales, zero_points, rows, + columns, thread_pool); } else { - BlockwiseQuantizer::dequantize( - dst, src, scales, zero_points, rows, columns, thread_pool - ); + BlockwiseQuantizer::dequantize(dst, src, scales, zero_points, + rows, columns, thread_pool); } break; case 256: if (columnwise) { - BlockwiseQuantizer::dequantize( - dst, src, scales, zero_points, rows, columns, thread_pool - ); + BlockwiseQuantizer::dequantize(dst, src, scales, zero_points, rows, + columns, thread_pool); } else { - BlockwiseQuantizer::dequantize( - dst, src, scales, zero_points, rows, columns, thread_pool - ); + BlockwiseQuantizer::dequantize(dst, src, scales, zero_points, + rows, columns, thread_pool); } break; default: @@ -972,7 +1047,8 @@ MlasDequantizeBlockwise( } } -template void +template +void MlasDequantizeBlockwise( float* dst, const uint8_t* src, @@ -983,4 +1059,4 @@ MlasDequantizeBlockwise( int rows, int columns, MLAS_THREADPOOL* thread_pool -); + ); \ No newline at end of file diff --git a/onnxruntime/core/mlas/lib/q4gemm.cpp b/onnxruntime/core/mlas/lib/q4gemm.cpp index fe76c2ed85fd7..289c8d0f3d985 100644 --- a/onnxruntime/core/mlas/lib/q4gemm.cpp +++ b/onnxruntime/core/mlas/lib/q4gemm.cpp @@ -18,7 +18,9 @@ Module Name: #include "q4gemm.h" -size_t MLASCALL + +size_t +MLASCALL MlasQ80BlkQuantSize(MLAS_BLK_QUANT_TYPE QType, size_t M, size_t K) { if (GetMlasPlatform().Q8Q4GemmDispatch == nullptr) { @@ -36,17 +38,27 @@ MlasQ80BlkQuantSize(MLAS_BLK_QUANT_TYPE QType, size_t M, size_t K) } } -void MLASCALL + +void +MLASCALL MlasQ80BlkQuant( - MLAS_BLK_QUANT_TYPE QType, void* Qblob, const float* A, size_t M, size_t K, size_t lda, MLAS_THREADPOOL* ThreadPool -) + MLAS_BLK_QUANT_TYPE QType, + void* Qblob, + const float* A, + size_t M, + size_t K, + size_t lda, + MLAS_THREADPOOL* ThreadPool + ) { auto* dispatch = GetMlasPlatform().Q8Q4GemmDispatch; dispatch->Quants[QType](Qblob, A, M, K, lda, ThreadPool); } -template -MLAS_FORCEINLINE void + +template +MLAS_FORCEINLINE +void MlasQ4GemmBatchDriver( MLAS_BLK_QUANT_TYPE QType, const size_t M, @@ -55,16 +67,18 @@ MlasQ4GemmBatchDriver( const size_t BatchN, const ParamBlockType* DataParams, MLAS_THREADPOOL* ThreadPool -) + ) { - // const MLAS_Q4GEMM_DISPATCH* dispatch = MlasQ4GemmGetDispatch(); - // MLAS_Q4GEMM_OPERATION* operation = dispatch->Operation; - void (*operation)(const size_t, const ParamBlockType*, const size_t, const size_t, const size_t, const size_t) = - nullptr; + //const MLAS_Q4GEMM_DISPATCH* dispatch = MlasQ4GemmGetDispatch(); + //MLAS_Q4GEMM_OPERATION* operation = dispatch->Operation; + void (*operation)(const size_t, const ParamBlockType*, const size_t, const size_t, const size_t, + const size_t) = nullptr; - if constexpr (std::is_same_v) { + if constexpr (std::is_same_v) + { operation = GetMlasPlatform().FpQ4GemmDispatch->Operations[QType]; - } else { + } + else { operation = GetMlasPlatform().Q8Q4GemmDispatch->Operations[QType]; } @@ -105,8 +119,8 @@ MlasQ4GemmBatchDriver( const size_t BlockedM = MlasDivRoundup(M, StrideM); const size_t max_nc = MlasDivRoundup(N * BlockedM, ThreadsPerGemm); if (max_nc < nc) { - nc = - std::min(nc, MlasDivRoundup(max_nc, MLAS_QGEMM_STRIDEN_THREAD_ALIGN) * MLAS_QGEMM_STRIDEN_THREAD_ALIGN); + nc = std::min(nc, MlasDivRoundup(max_nc, MLAS_QGEMM_STRIDEN_THREAD_ALIGN) * + MLAS_QGEMM_STRIDEN_THREAD_ALIGN); } } const size_t StrideN = nc; @@ -133,7 +147,9 @@ MlasQ4GemmBatchDriver( }); } -void MLASCALL + +void +MLASCALL MlasQ4GemmBatch( MLAS_BLK_QUANT_TYPE QType, const size_t M, @@ -142,12 +158,13 @@ MlasQ4GemmBatch( const size_t BatchN, const MLAS_Q4_GEMM_DATA_PARAMS* DataParams, MLAS_THREADPOOL* ThreadPool -) + ) { MlasQ4GemmBatchDriver(QType, M, N, K, BatchN, DataParams, ThreadPool); } -void MLASCALL +void +MLASCALL MlasQ8Q4GemmBatch( MLAS_BLK_QUANT_TYPE QType, const size_t M, @@ -156,7 +173,7 @@ MlasQ8Q4GemmBatch( const size_t BatchN, const MLAS_Q8Q4_GEMM_DATA_PARAMS* DataParams, MLAS_THREADPOOL* ThreadPool -) + ) { MlasQ4GemmBatchDriver(QType, M, N, K, BatchN, DataParams, ThreadPool); -} +} \ No newline at end of file