
Commit ca89422
add AMX_INT8 kernels. move jblas option to the top.
luoyu-intel committed Nov 15, 2023
1 parent e517583 commit ca89422
Showing 11 changed files with 448 additions and 363 deletions.
1 change: 1 addition & 0 deletions cmake/CMakeLists.txt
@@ -87,6 +87,7 @@ option(onnxruntime_USE_QNN "Build with QNN support" OFF)
 option(onnxruntime_USE_SNPE "Build with SNPE support" OFF)
 option(onnxruntime_USE_RKNPU "Build with RKNPU support" OFF)
 option(onnxruntime_USE_DNNL "Build with DNNL support" OFF)
+option(onnxruntime_USE_JLAS "Build MLAS with JBLAS support" OFF)
 option(onnxruntime_USE_JSEP "Build with JavaScript implemented kernels support" OFF)
 option(onnxruntime_BUILD_UNIT_TESTS "Build ONNXRuntime unit tests" ON)
 option(onnxruntime_BUILD_CSHARP "Build C# library" OFF)
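Note: with the option promoted to the top-level cmake/CMakeLists.txt, it is toggled at configure time like any other onnxruntime_USE_* switch. Assuming the standard build driver (the surrounding flags are illustrative, not part of this commit):

    ./build.sh --config Release --parallel --cmake_extra_defines onnxruntime_USE_JLAS=ON

When invoking CMake directly, the equivalent is passing -Donnxruntime_USE_JLAS=ON.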
7 changes: 3 additions & 4 deletions cmake/onnxruntime_mlas.cmake
@@ -2,7 +2,6 @@
 # Licensed under the MIT License.
 
 set(MLAS_SRC_DIR ${ONNXRUNTIME_ROOT}/core/mlas/lib)
-set(MLAS_WITH_JBLAS ON)
 #
 # All hardware agnostic source files here
 # hardware specific files would cause trouble in
@@ -47,7 +46,7 @@ set(ONNXRUNTIME_MLAS_LIBS onnxruntime_mlas)
 function(add_jblas)
 add_subdirectory(${MLAS_SRC_DIR}/x86_64/jblas jblas)
 target_link_libraries(onnxruntime_mlas PRIVATE jblas::jblas)
-target_compile_definitions(onnxruntime_mlas PUBLIC MLAS_JBLAS)
+target_compile_definitions(onnxruntime_mlas PRIVATE MLAS_JBLAS)
 set_target_properties(${target_name} PROPERTIES COMPILE_WARNING_AS_ERROR OFF)
 endfunction()
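Note on the PUBLIC -> PRIVATE change above: a PRIVATE compile definition is seen only by the sources of onnxruntime_mlas itself, while PUBLIC would also propagate MLAS_JBLAS to every target that links against it. That tighter scoping matches the header change later in this commit, which drops the #ifdef MLAS_JBLAS guard so consumers no longer depend on the macro. A minimal illustration of the CMake rule, with generic target names rather than anything from this repo:

    add_library(mylib mylib.cpp)
    # PRIVATE: only mylib's own sources are compiled with -DINTERNAL_ONLY.
    target_compile_definitions(mylib PRIVATE INTERNAL_ONLY)
    # PUBLIC: mylib and every target that links mylib get -DVISIBLE_TO_USERS.
    target_compile_definitions(mylib PUBLIC VISIBLE_TO_USERS)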

@@ -204,7 +203,7 @@ function(setup_mlas_source_for_windows)
 ${MLAS_SRC_DIR}/q4gemm_avx512.cpp
 )
 endif()
-if(MLAS_WITH_JBLAS)
+if(onnxruntime_USE_JLAS)
 add_jblas()
 endif()
 else()
@@ -569,7 +568,7 @@ else()
 )
 set_source_files_properties(${MLAS_SRC_DIR}/qgemm_kernel_amx.cpp PROPERTIES COMPILE_FLAGS "-mavx2 -mavx512bw -mavx512dq -mavx512vl -mavx512f")
 set_source_files_properties(${MLAS_SRC_DIR}/x86_64/QgemmU8S8KernelAmx.S PROPERTIES COMPILE_FLAGS "-mavx2 -mavx512bw -mavx512dq -mavx512vl -mavx512f")
-if(MLAS_WITH_JBLAS)
+if(onnxruntime_USE_JLAS)
 add_jblas()
 endif()
 endif()
16 changes: 8 additions & 8 deletions onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc
@@ -51,29 +51,29 @@ Status MatMulNBits::PrePack(const Tensor& tensor, int input_idx, /*out*/ AllocatorPtr alloc,
 /*out*/ bool& is_packed,
 /*out*/ PrePackedWeights* prepacked_weights) {
 is_packed = false;
 
 if (accuracy_level_ > 0 && nbits_ == 4) {
+auto compt_type = static_cast<MLAS_COMPUTE_TYPE>(accuracy_level_);
+if (MlasNBitsGemmPackBSupport(N_, K_, block_size_, nbits_, is_asym_, compt_type)) {
 // TODO use threadpool here
 MLAS_THREADPOOL* pool = NULL;
 if (input_idx == 1) {
 auto qptr = tensor.Data<uint8_t>();
-packed_b_size_ = MlasJblasQ4GemmPackBSize(N_, K_, block_size_, is_asym_, static_cast<MLAS_COMPUTE_TYPE>(accuracy_level_));
+packed_b_size_ = MlasNBitsGemmPackBSize(N_, K_, block_size_, nbits_, is_asym_, compt_type);
 packed_b_ = IAllocator::MakeUniquePtr<void>(alloc, packed_b_size_, true);
-MlasJblasNBitsGemmPackB(packed_b_.get(), qptr, nullptr, nullptr, N_, K_, K_, block_size_, is_asym_, false, static_cast<MLAS_COMPUTE_TYPE>(accuracy_level_), pool);
+MlasNBitsGemmPackB(packed_b_.get(), qptr, nullptr, nullptr, N_, K_, K_, block_size_, nbits_, is_asym_, false, compt_type, pool);
 prepacked_weights->buffers_.push_back(std::move(packed_b_));
 prepacked_weights->buffer_sizes_.push_back(packed_b_size_);
 is_packed = true;
 }
 if (input_idx == 2) {
 auto sptr = tensor.Data<float>();
-MlasJblasNBitsGemmPackB(packed_b_.get(), nullptr, sptr, nullptr, N_, K_, K_, block_size_, is_asym_, !is_asym_, static_cast<MLAS_COMPUTE_TYPE>(accuracy_level_), pool);
+MlasNBitsGemmPackB(packed_b_.get(), nullptr, sptr, nullptr, N_, K_, K_, block_size_, nbits_, is_asym_, !is_asym_, compt_type, pool);
 prepacked_weights->buffers_.push_back(std::move(packed_b_));
 prepacked_weights->buffer_sizes_.push_back(packed_b_size_);
 is_packed = true;
 }
 if (input_idx == 3) {
 auto zptr = tensor.Data<uint8_t>();
-MlasJblasNBitsGemmPackB(packed_b_.get(), nullptr, nullptr, zptr, N_, K_, K_, block_size_, is_asym_, is_asym_, static_cast<MLAS_COMPUTE_TYPE>(accuracy_level_), pool);
+MlasNBitsGemmPackB(packed_b_.get(), nullptr, nullptr, zptr, N_, K_, K_, block_size_, nbits_, is_asym_, is_asym_, compt_type, pool);
 prepacked_weights->buffers_.push_back(std::move(packed_b_));
 prepacked_weights->buffer_sizes_.push_back(packed_b_size_);
 is_packed = true;
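Reading the flags in PrePack above: the kernel receives its constant inputs one at a time (input 1 = quantized weight bytes, input 2 = scales, input 3 = zero points), and each call streams one tensor into the same packed buffer. lastCall marks the final tensor so packing can be finalized: for symmetric weights the scales arrive last (!is_asym_), for asymmetric weights the zero points do (is_asym_). A condensed sketch of the same protocol, not repo code:

    // Three-pass packing into one buffer, mirroring PrePack above (nbits = 4).
    void PackAll(void* packed, const uint8_t* qweight, const float* scales,
                 const uint8_t* zp, size_t N, size_t K, size_t blk,
                 bool is_asym, MLAS_COMPUTE_TYPE ct, MLAS_THREADPOOL* pool) {
      // Pass 1: quantized weight bytes; more tensors still to come.
      MlasNBitsGemmPackB(packed, qweight, nullptr, nullptr, N, K, K, blk, 4,
                         is_asym, /*lastCall=*/false, ct, pool);
      // Pass 2: scales; this is the final pass when the weights are symmetric.
      MlasNBitsGemmPackB(packed, nullptr, scales, nullptr, N, K, K, blk, 4,
                         is_asym, /*lastCall=*/!is_asym, ct, pool);
      if (is_asym) {
        // Pass 3: zero points close out the asymmetric case.
        MlasNBitsGemmPackB(packed, nullptr, nullptr, zp, N, K, K, blk, 4,
                           is_asym, /*lastCall=*/true, ct, pool);
      }
    }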
@@ -87,7 +87,7 @@ Status MatMulNBits::UseSharedPrePackedBuffers(std::vector<BufferUniquePtr>& prepacked_buffers,
 int input_idx,
 /*out*/ bool& used_shared_buffers) {
 used_shared_buffers = false;
-
+// Pack three tensors into one buffer
 if (input_idx == 1) {
 used_shared_buffers = true;
 packed_b_ = std::move(prepacked_buffers[0]);
@@ -140,7 +140,7 @@ Status MatMulNBits::Compute(OpKernelContext* ctx) const {
 gemm_params[i].C = y_data + helper.OutputOffsets()[i];
 gemm_params[i].ldc = N;
 }
-MlasJblasQ4GemmBatch(M, N, K, max_len, gemm_params.data(), (int8_t*)ws_ptr.get(), thread_pool);
+MlasNBitsGemmBatch(M, N, K, max_len, gemm_params.data(), (int8_t*)ws_ptr.get(), thread_pool);
 return Status::OK();
 }

92 changes: 36 additions & 56 deletions onnxruntime/core/mlas/inc/mlas_q4.h
@@ -30,10 +30,10 @@ Module Name:
 * @brief Define types of block quantization
 */
 typedef enum {
-BlkQ4Sym = 0, /*!< int4 Symmetric Block Quantization, zero_point = 0 */
-BlkQ4Zp8 = 1, /*!< int4 Block Quantization, zero_point is int8 type */
-BlkQ4Sym64 = 2, /*!< int4 Symmetric Block Quantization, 64 values per block*/
-BlkQ4Sym128 = 4, /*!< int4 Symmetric Block Quantization, 128 values per block*/
+BlkQ4Sym = 0,     /*!< int4 Symmetric Block Quantization, zero_point = 0 */
+BlkQ4Zp8 = 1,     /*!< int4 Block Quantization, zero_point is int8 type */
+BlkQ4Sym64 = 2,   /*!< int4 Symmetric Block Quantization, 64 values per block*/
+BlkQ4Sym128 = 4,  /*!< int4 Symmetric Block Quantization, 128 values per block*/
 } MLAS_BLK_QUANT_TYPE;
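For intuition about these block formats, here is a toy example of the symmetric scheme (BlkQ4Sym): every block shares one fp32 scale and quantizes with zero_point = 0. The exact rounding and clamping inside MLAS may differ; this only shows the arithmetic shape:

    #include <cmath>
    #include <cstdio>

    int main() {
      // One block of weights sharing a single scale; int4 range is ~[-8, 7].
      float block[4] = {1.2f, -3.5f, 0.6f, 2.0f};
      float amax = 3.5f;            // max |x| over the block
      float scale = amax / 7.0f;    // = 0.5
      for (float x : block) {
        int q = (int)std::lrintf(x / scale);                        // quantize, zp = 0
        std::printf("%+.2f -> q=%+d -> %+.2f\n", x, q, q * scale);  // dequantize
      }
      return 0;
    }

For 1.2f this gives q = 2 and a dequantized value of 1.0f, i.e. a quantization error of 0.2.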

 /**
@@ -323,7 +323,10 @@ MlasDequantizeBlockwise(ElementT* dst,
 int columns,
 MLAS_THREADPOOL* thread_pool);
 
-#ifdef MLAS_JBLAS
+bool MLASCALL
+MlasNBitsGemmPackBSupport(
+size_t N, size_t K, size_t BlkSize, int nbits, bool isAsym, MLAS_COMPUTE_TYPE CompType);
+
 /**
 * @brief Computes the number of bytes required to pack and int4-quantize
 * a weight matrix
@@ -333,8 +336,8 @@ MlasDequantizeBlockwise(ElementT* dst,
 * @return size of the packing buffer, 0 if the operation is not yet supported.
 */
 size_t MLASCALL
-MlasJblasQ4GemmPackBSize(
-size_t N, size_t K, size_t BlkSize, bool isAsym, MLAS_COMPUTE_TYPE CompType);
+MlasNBitsGemmPackBSize(
+size_t N, size_t K, size_t BlkSize, int nbits, bool isAsym, MLAS_COMPUTE_TYPE CompType);
 
 /**
 * @brief Prepack and Quantize fp32 weight tensor to int4 blocks
@@ -347,40 +350,19 @@ MlasJblasQ4GemmPackBSize(
 * @param ldb leading dimension of B
 */
 void MLASCALL
-MlasJblasQ4GemmPackB(void* PackedBuf,
-const float* FpData,
-size_t N,
-size_t K,
-size_t ldb,
-size_t BlkSize,
-bool isAsym,
-MLAS_COMPUTE_TYPE CompType,
-MLAS_THREADPOOL* ThreadPool);
-
-/**
- * @brief Prepack and Quantize fp32 weight tensor to int4 blocks
- *
- * @param QType type of block quantization
- * @param PackedBuf destination buffer
- * @param FpData the pointer to fp32 matrix
- * @param N the number of columns of matrix B.
- * @param K the number of rows of matrix B.
- * @param ldb leading dimension of B
- */
-void MLASCALL
-MlasJblasNBitsGemmPackB(void* PackedBuf,
-const uint8_t* QData,
-const float* Scale,
-const uint8_t* Zp,
-size_t N,
-size_t K,
-size_t ldb,
-size_t BlkSize,
-bool isAsym,
-bool lastCall,
-MLAS_COMPUTE_TYPE CompType,
-MLAS_THREADPOOL* ThreadPool);
-
+MlasNBitsGemmPackB(void* PackedBuf,
+const uint8_t* QData,
+const float* Scale,
+const uint8_t* Zp,
+size_t N,
+size_t K,
+size_t ldb,
+size_t BlkSize,
+int nbits,
+bool isAsym,
+bool lastCall,
+MLAS_COMPUTE_TYPE CompType,
+MLAS_THREADPOOL* ThreadPool);
 /**
 * @brief Unpack and dequantize from int4 to fp32, reverse operation of
 * MlasQ4GemmPackB
@@ -392,12 +374,12 @@ MlasJblasNBitsGemmPackB(void* PackedBuf,
 * @param ldb leading dimension of B
 */
 void MLASCALL
-MlasJblasQ4GemmUnPackB(float* FpData,
-const void* PackedBuf,
-size_t N,
-size_t K,
-size_t ldb,
-MLAS_THREADPOOL* ThreadPool);
+MlasNBitsGemmUnPackB(float* FpData,
+const void* PackedBuf,
+size_t N,
+size_t K,
+size_t ldb,
+MLAS_THREADPOOL* ThreadPool);
 
 /**
 * @brief Calculate the buffer size needed for int8 block quantize
@@ -407,13 +389,11 @@ MlasJblasQ4GemmUnPackB(float* FpData,
 * @return buffer size (in bytes) needed, 0 if not yet supported on current
 * hardware
 */
-
 void MLASCALL
-MlasJblasQ4GemmBatch(const size_t M,
-const size_t N,
-const size_t K,
-const size_t BatchN,
-const MLAS_Q4_GEMM_DATA_PARAMS* DataParams,
-int8_t* WorkSpace,
-MLAS_THREADPOOL* ThreadPool = nullptr);
-#endif
+MlasNBitsGemmBatch(const size_t M,
+const size_t N,
+const size_t K,
+const size_t BatchN,
+const MLAS_Q4_GEMM_DATA_PARAMS* DataParams,
+int8_t* WorkSpace,
+MLAS_THREADPOOL* ThreadPool = nullptr);
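Taken together, the renamed declarations read as a pipeline: check support, size the packed buffer, pack, then run the batched kernel. A rough end-to-end sketch, not repo code: the A/lda/B members of MLAS_Q4_GEMM_DATA_PARAMS are assumptions (only C, ldc and the int8_t* workspace are visible in this commit), the workspace sizing is left to the caller, and packing all three tensors in a single lastCall=true call is assumed to be allowed (MatMulNBits above streams them one at a time):

    #include <cstdint>
    #include <vector>

    bool RunNBitsGemm(const float* a, const uint8_t* qweight, const float* scales,
                      size_t M, size_t N, size_t K, size_t blk,
                      int accuracy_level, float* c, int8_t* workspace) {
      auto ct = static_cast<MLAS_COMPUTE_TYPE>(accuracy_level);
      if (!MlasNBitsGemmPackBSupport(N, K, blk, /*nbits=*/4, /*isAsym=*/false, ct))
        return false;  // caller should fall back to the unpacked path
      std::vector<uint8_t> packed(MlasNBitsGemmPackBSize(N, K, blk, 4, false, ct));
      MlasNBitsGemmPackB(packed.data(), qweight, scales, /*Zp=*/nullptr,
                         N, K, /*ldb=*/K, blk, 4, /*isAsym=*/false,
                         /*lastCall=*/true, ct, /*ThreadPool=*/nullptr);
      MLAS_Q4_GEMM_DATA_PARAMS p{};
      p.A = a;                // assumed member: fp32 activations, M x K
      p.lda = K;              // assumed member
      p.B = packed.data();    // assumed member: the packed weight buffer
      p.C = c;
      p.ldc = N;
      MlasNBitsGemmBatch(M, N, K, /*BatchN=*/1, &p, workspace, nullptr);
      return true;
    }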
