sync compute_type with attribute order
luoyu-intel committed Nov 16, 2023
1 parent ac5e863 commit 2d0668c
Showing 1 changed file with 24 additions and 92 deletions.
116 changes: 24 additions & 92 deletions onnxruntime/core/mlas/inc/mlas_q4.h
@@ -20,14 +20,12 @@ Module Name:

#pragma once

#include "mlas.h"
#include "mlas_gemm_postprocessor.h"

#include <math.h>

#include <algorithm>

#include "mlas.h"
#include "mlas_gemm_postprocessor.h"

/**
* @brief Define types of block quantization
@@ -45,9 +43,9 @@ typedef enum {
typedef enum {
CompUndef = 0, /*!< undef */
CompFp32 = 1, /*!< input fp32, accumulator fp32 */
-CompInt8 = 2, /*!< input int8, accumulator int32 */
+CompFp16 = 2, /*!< input fp16, accumulator fp16 */
CompBf16 = 3, /*!< input bf16, accumulator fp32 */
-CompFp16 = 4 /*!< input fp16, accumulator fp16 */
+CompInt8 = 4 /*!< input int8, accumulator int32 */
} MLAS_COMPUTE_TYPE;
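
This swap of CompFp16 and CompInt8 is the substance of the commit: per the commit title, the enum now tracks the numeric order of the operator attribute that selects it (presumably MatMulNBits' accuracy_level, which counts 0 = unset, 1 = fp32, 2 = fp16, 3 = bf16, 4 = int8), so the attribute value can be cast straight across. A minimal mapping sketch under that assumption — the helper name is hypothetical:

#include <cstdint>

#include "mlas_q4.h"

// Hypothetical helper: valid only while MLAS_COMPUTE_TYPE mirrors the
// attribute's numeric order (0=undef, 1=fp32, 2=fp16, 3=bf16, 4=int8).
inline MLAS_COMPUTE_TYPE
ComputeTypeFromAccuracyLevel(int64_t accuracy_level)
{
    return static_cast<MLAS_COMPUTE_TYPE>(accuracy_level);
}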

/**
@@ -72,12 +70,7 @@ MlasQ4GemmPackBSize(MLAS_BLK_QUANT_TYPE QType, size_t N, size_t K);
* @param ldb leading dimension of B
*/
void MLASCALL
-MlasQ4GemmPackB(MLAS_BLK_QUANT_TYPE QType,
-void* PackedBuf,
-const float* FpData,
-size_t N,
-size_t K,
-size_t ldb);
+MlasQ4GemmPackB(MLAS_BLK_QUANT_TYPE QType, void* PackedBuf, const float* FpData, size_t N, size_t K, size_t ldb);
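
A minimal usage sketch for the size-query/pack pair, assuming the BlkQ4Sym value of MLAS_BLK_QUANT_TYPE (defined in a collapsed region above) and a row-major fp32 B so ldb == N; treating a returned size of 0 as "unsupported" is also an assumption:

#include <vector>

#include "mlas_q4.h"

std::vector<uint8_t> PackB(const float* B, size_t N, size_t K)
{
    // Query the packed-buffer size, then quantize and pack B into it.
    std::vector<uint8_t> packed(MlasQ4GemmPackBSize(BlkQ4Sym, N, K));
    if (!packed.empty()) {
        MlasQ4GemmPackB(BlkQ4Sym, packed.data(), B, N, K, /*ldb=*/N);
    }
    return packed;
}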

/**
* @brief Unpack and dequantize from int4 to fp32, reverse operation of
@@ -90,13 +83,7 @@ MlasQ4GemmPackB(MLAS_BLK_QUANT_TYPE QType,
* @param ldb leading dimension of B
*/
void MLASCALL
-MlasQ4GemmUnPackB(MLAS_BLK_QUANT_TYPE QType,
-float* FpData,
-const void* PackedBuf,
-size_t N,
-size_t K,
-size_t ldb);
-
+MlasQ4GemmUnPackB(MLAS_BLK_QUANT_TYPE QType, float* FpData, const void* PackedBuf, size_t N, size_t K, size_t ldb);

/**
* @brief Data parameters for Q4 GEMM routine
@@ -130,13 +117,7 @@ struct MLAS_Q4_GEMM_DATA_PARAMS {
* @return
*/
void MLASCALL
-MlasQ4GemmBatch(MLAS_BLK_QUANT_TYPE QType,
-const size_t M,
-const size_t N,
-const size_t K,
-const size_t BatchN,
-const MLAS_Q4_GEMM_DATA_PARAMS* DataParams,
-MLAS_THREADPOOL* ThreadPool = nullptr);
+MlasQ4GemmBatch(MLAS_BLK_QUANT_TYPE QType, const size_t M, const size_t N, const size_t K, const size_t BatchN, const MLAS_Q4_GEMM_DATA_PARAMS* DataParams, MLAS_THREADPOOL* ThreadPool = nullptr);
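
A single-batch call sketch continuing from PackB above. The MLAS_Q4_GEMM_DATA_PARAMS fields are collapsed out of this diff, so the names used here (A, lda, B, C, ldc) are assumptions:

void Q4Gemm(const float* A, const std::vector<uint8_t>& packedB, float* C,
            size_t M, size_t N, size_t K, MLAS_THREADPOOL* pool)
{
    MLAS_Q4_GEMM_DATA_PARAMS params{};
    params.A = A;               // M x K fp32, row-major
    params.lda = K;
    params.B = packedB.data();  // output of MlasQ4GemmPackB
    params.C = C;               // M x N fp32 result
    params.ldc = N;
    MlasQ4GemmBatch(BlkQ4Sym, M, N, K, /*BatchN=*/1, &params, pool);
}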

/**
* @brief Calculate the buffer size needed for int8 block quantize
@@ -161,13 +142,7 @@ MlasQ80BlkQuantSize(MLAS_BLK_QUANT_TYPE QType, size_t M, size_t K);
* @param ThreadPool
*/
void MLASCALL
-MlasQ80BlkQuant(MLAS_BLK_QUANT_TYPE QType,
-void* Qblob,
-const float* A,
-size_t M,
-size_t K,
-size_t lda,
-MLAS_THREADPOOL* ThreadPool);
+MlasQ80BlkQuant(MLAS_BLK_QUANT_TYPE QType, void* Qblob, const float* A, size_t M, size_t K, size_t lda, MLAS_THREADPOOL* ThreadPool);
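
A sketch of the activation side, producing the int8 blob consumed by MlasQ8Q4GemmBatch below; MlasQ80BlkQuantSize (signature in the hunk header above) supplies the byte count, and A is taken as row-major with lda == K:

std::vector<int8_t> QuantA(const float* A, size_t M, size_t K, MLAS_THREADPOOL* pool)
{
    std::vector<int8_t> blob(MlasQ80BlkQuantSize(BlkQ4Sym, M, K));
    MlasQ80BlkQuant(BlkQ4Sym, blob.data(), A, M, K, /*lda=*/K, pool);
    return blob;
}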

/**
* @brief Data parameters for Q8Q4 GEMM routine
@@ -200,13 +175,7 @@ struct MLAS_Q8Q4_GEMM_DATA_PARAMS {
* @return
*/
void MLASCALL
-MlasQ8Q4GemmBatch(MLAS_BLK_QUANT_TYPE QType,
-const size_t M,
-const size_t N,
-const size_t K,
-const size_t BatchN,
-const MLAS_Q8Q4_GEMM_DATA_PARAMS* DataParams,
-MLAS_THREADPOOL* ThreadPool);
+MlasQ8Q4GemmBatch(MLAS_BLK_QUANT_TYPE QType, const size_t M, const size_t N, const size_t K, const size_t BatchN, const MLAS_Q8Q4_GEMM_DATA_PARAMS* DataParams, MLAS_THREADPOOL* ThreadPool);
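
Combining the QuantA and PackB sketches above; as before, the MLAS_Q8Q4_GEMM_DATA_PARAMS field names are hidden by the collapsed hunk and are assumed here (A takes the quantized blob, which carries its own layout, so no lda):

void Q8Q4Gemm(const std::vector<int8_t>& blob, const std::vector<uint8_t>& packedB,
              float* C, size_t M, size_t N, size_t K, MLAS_THREADPOOL* pool)
{
    MLAS_Q8Q4_GEMM_DATA_PARAMS params{};
    params.A = blob.data();     // from MlasQ80BlkQuant
    params.B = packedB.data();  // from MlasQ4GemmPackB
    params.C = C;               // M x N fp32 result
    params.ldc = N;
    MlasQ8Q4GemmBatch(BlkQ4Sym, M, N, K, /*BatchN=*/1, &params, pool);
}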

////////////////////////////////////////////////////////////
// Blockwise quantization and dequantization where quantization
@@ -217,11 +186,12 @@ MlasQ8Q4GemmBatch(MLAS_BLK_QUANT_TYPE QType,
* @brief For quantization type <T, block_size, columnwise>, and
* matrix shape [rows, columns], compute the shape of the
* quantization parameter matrix [meta_rows, meta_cols]
*/
template <typename T, int qbits>
void
MlasBlockwiseQuantMetaShape(
-int block_size, bool columnwise, int rows, int columns, int& meta_rows, int& meta_cols);
+int block_size, bool columnwise, int rows, int columns, int& meta_rows, int& meta_cols
+);

/**
* @brief For quantization type <T, block_size, columnwise>, and
@@ -237,11 +207,12 @@ MlasBlockwiseQuantMetaShape(
* @param columns
* @param q_rows
* @param q_cols
*/
template <typename T, int qbits>
void
MlasBlockwiseQuantizedShape(
-int block_size, bool columnwise, int rows, int columns, int& q_rows, int& q_cols);
+int block_size, bool columnwise, int rows, int columns, int& q_rows, int& q_cols
+);
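
A sketch of the two shape queries for a 4-bit, block-size-32, column-wise quantization of a rows x columns fp32 matrix; the results size the buffers used with MlasQuantizeBlockwise below:

int meta_rows, meta_cols, q_rows, q_cols;
// Shape of the scale/zero-point parameter matrix.
MlasBlockwiseQuantMetaShape<float, 4>(32, /*columnwise=*/true, rows, columns, meta_rows, meta_cols);
// Shape of the packed quantized-element matrix.
MlasBlockwiseQuantizedShape<float, 4>(32, /*columnwise=*/true, rows, columns, q_rows, q_cols);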

/**
* @brief Compute the sizes of the quantized data and quantization parameter buffers.
@@ -269,7 +240,6 @@ MlasBlockwiseQuantizedBufferSizes(
size_t* q_zero_point_size_in_bytes
);


/**
* @brief Blockwise 4 bits quantization, resulting elements and quantization
* parameters (scales, zero points) are packed into separate matrices
@@ -295,16 +265,7 @@
*/
template <typename ElementT, int qbits>
void
-MlasQuantizeBlockwise(uint8_t* dst,
-ElementT* scales,
-uint8_t* zero_points,
-const ElementT* src,
-int block_size,
-bool columnwise,
-int rows,
-int columns,
-int leading_dimension,
-MLAS_THREADPOOL* thread_pool);
+MlasQuantizeBlockwise(uint8_t* dst, ElementT* scales, uint8_t* zero_points, const ElementT* src, int block_size, bool columnwise, int rows, int columns, int leading_dimension, MLAS_THREADPOOL* thread_pool);
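
A quantization sketch continuing from the shape queries above. The buffer sizing is an assumption (q_rows x q_cols bytes of packed elements, one scale per meta element, two 4-bit zero points per byte), and src is taken as row-major so leading_dimension == columns:

std::vector<uint8_t> q_data(static_cast<size_t>(q_rows) * q_cols);   // packed 4-bit elements
std::vector<float> scales(static_cast<size_t>(meta_rows) * meta_cols);
std::vector<uint8_t> zero_points((static_cast<size_t>(meta_rows) * meta_cols + 1) / 2);
MlasQuantizeBlockwise<float, 4>(q_data.data(), scales.data(), zero_points.data(), src,
                                /*block_size=*/32, /*columnwise=*/true, rows, columns,
                                /*leading_dimension=*/columns, pool);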

/**
* @brief Blockwise 4 bits dequantization, quantized elements and quantization
@@ -329,15 +290,7 @@ MlasQuantizeBlockwise(uint8_t* dst,
*/
template <typename ElementT, int qbits>
void
-MlasDequantizeBlockwise(ElementT* dst,
-const uint8_t* src,
-const ElementT* scales,
-const uint8_t* zero_points,
-int block_size,
-bool columnwise,
-int rows,
-int columns,
-MLAS_THREADPOOL* thread_pool);
+MlasDequantizeBlockwise(ElementT* dst, const uint8_t* src, const ElementT* scales, const uint8_t* zero_points, int block_size, bool columnwise, int rows, int columns, MLAS_THREADPOOL* thread_pool);
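
And the reverse sketch, recovering a rows x columns fp32 matrix from the packed buffers produced above (dst sized rows * columns is an assumption consistent with the quantize step):

std::vector<float> recovered(static_cast<size_t>(rows) * columns);
MlasDequantizeBlockwise<float, 4>(recovered.data(), q_data.data(), scales.data(),
                                  zero_points.data(), /*block_size=*/32,
                                  /*columnwise=*/true, rows, columns, pool);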

/**
* @brief Check if the parameter combination is supported
@@ -353,7 +306,8 @@ MlasDequantizeBlockwise(ElementT* dst,
*/
bool MLASCALL
MlasNBitsGemmPackBSupport(
-size_t N, size_t K, size_t block_size, int nbits, bool is_asym, MLAS_COMPUTE_TYPE comp_type);
+size_t N, size_t K, size_t block_size, int nbits, bool is_asym, MLAS_COMPUTE_TYPE comp_type
+);

/**
* @brief Compute the byte size of the parameter combination
@@ -369,7 +323,8 @@ MlasNBitsGemmPackBSupport(
*/
size_t MLASCALL
MlasNBitsGemmPackBSize(
-size_t N, size_t K, size_t block_size, int nbits, bool is_asym, MLAS_COMPUTE_TYPE comp_type);
+size_t N, size_t K, size_t block_size, int nbits, bool is_asym, MLAS_COMPUTE_TYPE comp_type
+);

/**
* @brief Prepack tensor data from MatMulNBits operator
@@ -390,22 +345,10 @@ MlasNBitsGemmPackBSize(
* @param thread_pool
*/
void MLASCALL
-MlasNBitsGemmPackB(void* PackedBuf,
-const uint8_t* QData,
-const float* Scale,
-const uint8_t* Zp,
-size_t N,
-size_t K,
-size_t ldb,
-size_t block_size,
-int nbits,
-bool is_asym,
-bool last_call,
-MLAS_COMPUTE_TYPE comp_type,
-MLAS_THREADPOOL* thread_pool);
+MlasNBitsGemmPackB(void* PackedBuf, const uint8_t* QData, const float* Scale, const uint8_t* Zp, size_t N, size_t K, size_t ldb, size_t block_size, int nbits, bool is_asym, bool last_call, MLAS_COMPUTE_TYPE comp_type, MLAS_THREADPOOL* thread_pool);
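
A prepack flow sketch chaining the support check, size query, and pack, using the CompInt8 value from the reordered enum at the top of this diff; the symmetric 4-bit case passing nullptr for Zp, and ldb == N, are assumptions:

std::vector<int8_t> PrepackB(const uint8_t* QData, const float* Scale,
                             size_t N, size_t K, size_t block_size,
                             MLAS_THREADPOOL* pool)
{
    if (!MlasNBitsGemmPackBSupport(N, K, block_size, /*nbits=*/4,
                                   /*is_asym=*/false, CompInt8)) {
        return {};
    }
    std::vector<int8_t> packed(MlasNBitsGemmPackBSize(N, K, block_size, /*nbits=*/4,
                                                      /*is_asym=*/false, CompInt8));
    MlasNBitsGemmPackB(packed.data(), QData, Scale, /*Zp=*/nullptr, N, K, /*ldb=*/N,
                       block_size, /*nbits=*/4, /*is_asym=*/false, /*last_call=*/true,
                       CompInt8, pool);
    return packed;
}
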
/**
* @brief Unpack and dequantize to fp32
*
* @param FpData unpacked float32 data
* @param PackedBuf int4 quantized and packed data
* @param N the number of columns of matrix B.
@@ -414,12 +357,7 @@ MlasNBitsGemmPackB(void* PackedBuf,
* @param thread_pool
*/
void MLASCALL
-MlasNBitsGemmUnPackB(float* FpData,
-const void* PackedBuf,
-size_t N,
-size_t K,
-size_t ldb,
-MLAS_THREADPOOL* thread_pool);
+MlasNBitsGemmUnPackB(float* FpData, const void* PackedBuf, size_t N, size_t K, size_t ldb, MLAS_THREADPOOL* thread_pool);
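
A short round-trip sketch for the unpack path, using the packed buffer from PrepackB above and assuming a row-major destination with ldb == N:

std::vector<float> fp(static_cast<size_t>(N) * K);
MlasNBitsGemmUnPackB(fp.data(), packed.data(), N, K, /*ldb=*/N, pool);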

/**
* @brief Batched GEMM: C = A * B
@@ -436,10 +374,4 @@ MlasNBitsGemmUnPackB(float* FpData,
* @return
*/
void MLASCALL
-MlasNBitsGemmBatch(const size_t M,
-const size_t N,
-const size_t K,
-const size_t BatchN,
-const MLAS_Q4_GEMM_DATA_PARAMS* DataParams,
-int8_t* WorkSpace,
-MLAS_THREADPOOL* ThreadPool = nullptr);
+MlasNBitsGemmBatch(const size_t M, const size_t N, const size_t K, const size_t BatchN, const MLAS_Q4_GEMM_DATA_PARAMS* DataParams, int8_t* WorkSpace, MLAS_THREADPOOL* ThreadPool = nullptr);
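
A final sketch running the batched GEMM on the prepacked weights. It reuses MLAS_Q4_GEMM_DATA_PARAMS, so the same field-name assumptions apply; the required WorkSpace size is not exposed in the visible hunks, so it is left to the caller here:

void NBitsGemm(const float* A, const std::vector<int8_t>& packedB, float* C,
               size_t M, size_t N, size_t K, int8_t* workspace,
               MLAS_THREADPOOL* pool)
{
    MLAS_Q4_GEMM_DATA_PARAMS params{};
    params.A = A;               // M x K fp32, row-major (field names assumed)
    params.lda = K;
    params.B = packedB.data();  // from MlasNBitsGemmPackB
    params.C = C;               // M x N fp32 result
    params.ldc = N;
    MlasNBitsGemmBatch(M, N, K, /*BatchN=*/1, &params, workspace, pool);
}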
