update descriptions for new functions
luoyu-intel committed Nov 16, 2023
1 parent 66d2532 commit ac5e863
86 changes: 58 additions & 28 deletions onnxruntime/core/mlas/inc/mlas_q4.h
@@ -339,31 +339,55 @@ MlasDequantizeBlockwise(ElementT* dst,
                         int columns,
                         MLAS_THREADPOOL* thread_pool);
 
+/**
+ * @brief Check if the parameter combination is supported
+ *
+ * @param N the number of columns of matrix B.
+ * @param K the number of rows of matrix B.
+ * @param block_size size of the block to quantize, elements from the same block share the same
+ *        scale and zero point
+ * @param nbits number of bits used for weight quantization (default 4)
+ * @param is_asym flag for asymmetric quantization
+ * @param comp_type specify input data type and accumulator data type
+ * @return support flag, true if the combination is supported.
+ */
 bool MLASCALL
 MlasNBitsGemmPackBSupport(
-    size_t N, size_t K, size_t BlkSize, int nbits, bool isAsym, MLAS_COMPUTE_TYPE CompType);
+    size_t N, size_t K, size_t block_size, int nbits, bool is_asym, MLAS_COMPUTE_TYPE comp_type);
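A minimal caller sketch for the support check above (a sketch, not part of the
commit). The "mlas_q4.h" include path and the CompFp32 enumerator are
assumptions based on this header's location and MLAS naming conventions.

    #include <cstddef>
    #include "mlas_q4.h"  // assumed include for the declarations in this diff

    // Gate the packed n-bit GEMM path on the support check; callers fall back
    // to a plain fp32 GEMM when the combination has no kernel.
    bool CanUsePackedNBitsGemm(size_t N, size_t K) {
        const size_t block_size = 32;  // hypothetical blocking choice
        const int nbits = 4;           // 4-bit weights
        const bool is_asym = false;    // symmetric quantization
        return MlasNBitsGemmPackBSupport(N, K, block_size, nbits, is_asym,
                                         CompFp32 /* assumed enumerator */);
    }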

 /**
- * @brief Computes the number of bytes required to pack and int4-quantize
- *        a weight matrix
- * @param QType type of block quantization
+ * @brief Compute the byte size of the parameter combination
+ *
  * @param N the number of columns of matrix B.
  * @param K the number of rows of matrix B.
+ * @param block_size size of the block to quantize, elements from the same block share the same
+ *        scale and zero point
+ * @param nbits number of bits used for weight quantization (default 4)
+ * @param is_asym flag for asymmetric quantization
+ * @param comp_type specify input data type and accumulator data type
  * @return size of the packing buffer, 0 if the operation is not yet supported.
  */
 size_t MLASCALL
 MlasNBitsGemmPackBSize(
-    size_t N, size_t K, size_t BlkSize, int nbits, bool isAsym, MLAS_COMPUTE_TYPE CompType);
+    size_t N, size_t K, size_t block_size, int nbits, bool is_asym, MLAS_COMPUTE_TYPE comp_type);
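The sizing call pairs with the support check: a short sketch, under the same
hypothetical settings as above, that treats a zero return as "not supported".

    #include <cstdint>
    #include <vector>

    // Query the packed-buffer size before packing; 0 means the combination
    // has no kernel yet and MlasNBitsGemmPackB must not be called.
    std::vector<uint8_t> AllocPackedB(size_t N, size_t K) {
        const size_t bytes =
            MlasNBitsGemmPackBSize(N, K, /*block_size*/ 32, /*nbits*/ 4,
                                   /*is_asym*/ false, CompFp32);
        return std::vector<uint8_t>(bytes);  // empty when unsupported
    }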

 /**
- * @brief Prepack and Quantize fp32 weight tensor to int4 blocks
+ * @brief Prepack tensor data from MatMulNBits operator
  *
- * @param QType type of block quantization
- * @param PackedBuf destination buffer
- * @param FpData the pointer to fp32 matrix
- * @param N the number of columns of matrix B.
- * @param K the number of rows of matrix B.
- * @param ldb leading dimension of B
+ * @param PackedBuf packed data buffer
+ * @param QData quantized data buffer
+ * @param Scale scale pointer
+ * @param Zp zero point pointer
+ * @param N the number of columns of matrix B.
+ * @param K the number of rows of matrix B.
+ * @param ldb leading dimension of B
+ * @param block_size size of the block to quantize, elements from the same block share the same
+ *        scale and zero point
+ * @param nbits number of bits used for weight quantization (default 4)
+ * @param is_asym flag for asymmetric quantization
+ * @param comp_type specify input data type and accumulator data type
+ * @param last_call flag to activate the epilogue process of packB
+ * @param thread_pool
  */
 void MLASCALL
 MlasNBitsGemmPackB(void* PackedBuf,
@@ -373,37 +397,43 @@ MlasNBitsGemmPackB(void* PackedBuf,
                    size_t N,
                    size_t K,
                    size_t ldb,
-                   size_t BlkSize,
+                   size_t block_size,
                    int nbits,
-                   bool isAsym,
-                   bool lastCall,
-                   MLAS_COMPUTE_TYPE CompType,
-                   MLAS_THREADPOOL* ThreadPool);
+                   bool is_asym,
+                   bool last_call,
+                   MLAS_COMPUTE_TYPE comp_type,
+                   MLAS_THREADPOOL* thread_pool);
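The QData/Scale/Zp parameter lines of this declaration are folded out of the
hunk above, so their pointer types below are assumptions (uint8 quantized
weights and zero points, float scales) inferred from the @param descriptions;
a null MLAS_THREADPOOL* is the usual MLAS way to run single-threaded.

    #include <cstdint>

    // Pack pre-quantized MatMulNBits weight data into the internal layout.
    // last_call=true activates the packB epilogue on the final invocation.
    void PackQuantizedB(void* packed_buf,
                        const uint8_t* q_data,  // assumed pointer type
                        const float* scales,    // assumed pointer type
                        const uint8_t* zp,      // assumed pointer type
                        size_t N, size_t K, size_t ldb) {
        MlasNBitsGemmPackB(packed_buf, q_data, scales, zp, N, K, ldb,
                           /*block_size*/ 32, /*nbits*/ 4, /*is_asym*/ true,
                           /*last_call*/ true, CompFp32,
                           /*thread_pool*/ nullptr);
    }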
 /**
- * @brief Unpack and dequantize from int4 to fp32, reverse operation of
- *        MlasQ4GemmPackB
- * @param QType type of block quantization
- * @param FpData destination buffer, the fp32 matrix
+ * @brief Unpack and dequantize to fp32
+ *
+ * @param FpData unpacked float32 data
  * @param PackedBuf int4 quantized and packed data
  * @param N the number of columns of matrix B.
  * @param K the number of rows of matrix B.
  * @param ldb leading dimension of B
+ * @param thread_pool
  */
 void MLASCALL
 MlasNBitsGemmUnPackB(float* FpData,
                      const void* PackedBuf,
                      size_t N,
                      size_t K,
                      size_t ldb,
-                     MLAS_THREADPOOL* ThreadPool);
+                     MLAS_THREADPOOL* thread_pool);
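A round-trip sketch for the unpack direction, assuming B is stored as K rows
with leading dimension ldb (consistent with the @param text above); useful for
validating the packed blob against the original weights.

    #include <vector>

    std::vector<float> UnpackB(const void* packed_buf,
                               size_t N, size_t K, size_t ldb) {
        std::vector<float> fp_data(K * ldb);  // assumed K x ldb fp32 layout
        MlasNBitsGemmUnPackB(fp_data.data(), packed_buf, N, K, ldb,
                             /*thread_pool*/ nullptr);
        return fp_data;
    }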

 /**
- * @brief Calculate the buffer size needed for int8 block quantize
- * @param[in] QType Type of block quantization used
- * @param[in] M Number of rows of the input matrix
- * @param[in] K Number of columns of the input matrix
- * @return buffer size (in bytes) needed, 0 if not yet supported on current
- *         hardware
+ * @brief Batched GEMM: C = A * B
+ *        A, C must be a float32 matrix
+ *        B must be a packed nbits blob
+ *
+ * @param[in] M row size of matrix A and C
+ * @param[in] N column size of matrix B and C
+ * @param[in] K column size of matrix A and row size of matrix B
+ * @param[in] BatchN number of batches
+ * @param[inout] DataParams An array (size BatchN) of parameter blocks
+ * @param[in] WorkSpace temporary buffer
+ * @param[in] ThreadPool
+ * @return
  */
 void MLASCALL
 MlasNBitsGemmBatch(const size_t M,
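The declaration is truncated by the fold here, and the per-batch parameter
block type is not visible in this diff; the struct below is a hypothetical
stand-in that only illustrates the call shape implied by the @param list
(fp32 A and C, packed B). The real struct name, its fields, and the workspace
sizing must be taken from the full header.

    #include <cstddef>

    struct NBitsGemmParamsSketch {  // hypothetical, not the actual MLAS type
        const float* A;  // fp32 activations, M x K
        size_t lda;      // leading dimension of A
        const void* B;   // packed n-bit blob produced by MlasNBitsGemmPackB
        float* C;        // fp32 output, M x N
        size_t ldc;      // leading dimension of C
    };

    // Conceptual call, mirroring the documented parameter order:
    //   MlasNBitsGemmBatch(M, N, K, BatchN, data_params, workspace, thread_pool);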
