update jblas. add comp_dtype=fp32
luoyu-intel committed Oct 31, 2023
1 parent 53411e8 commit daac4dd
Showing 29 changed files with 7,531 additions and 8,325 deletions.
2 changes: 1 addition & 1 deletion cmake/onnxruntime_mlas.cmake
@@ -47,7 +47,7 @@ set(ONNXRUNTIME_MLAS_LIBS onnxruntime_mlas)
function(add_jblas)
add_subdirectory(${MLAS_SRC_DIR}/x86_64/jblas jblas)
target_link_libraries(onnxruntime_mlas PRIVATE jblas::jblas)
target_compile_definitions(onnxruntime_mlas PRIVATE MLAS_JBLAS)
target_compile_definitions(onnxruntime_mlas PUBLIC MLAS_JBLAS)
set_target_properties(${target_name} PROPERTIES COMPILE_WARNING_AS_ERROR OFF)
endfunction()
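# Note: switching MLAS_JBLAS from PRIVATE to PUBLIC above makes the define
# propagate to every target that links onnxruntime_mlas; the new
# `#ifdef MLAS_JBLAS` guards added to mlas_q4.h in this commit rely on
# consumers of the library seeing it.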

95 changes: 75 additions & 20 deletions onnxruntime/core/mlas/inc/mlas_q4.h
Expand Up @@ -41,9 +41,11 @@ typedef enum {
* @brief Define compute types of block quantization
*/
typedef enum {
CompFp32 = 0, /*!< int4 Symmetric Block Quantization, zero_point = 0 */
CompInt8 = 1, /*!< int4 Block Quantization, zero_point is int8 type */
} BLK_QUANT_COMPUTE_TYPE;
CompFp32 = 0, /*!< input fp32, accumulator fp32 */
CompInt8 = 1, /*!< input int8, accumulator int32 */
CompBf16 = 2, /*!< input bf16, accumulator fp32 */
CompFp16 = 3, /*!< input fp16, accumulator fp16 */
} MLAS_COMPUTE_TYPE;
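// Illustrative helper (hypothetical, not part of this header): per the
// comments above, the compute type fixes the kernel's input and accumulator
// precision. With CompInt8, the fp32 activations are dynamically quantized to
// int8 before the GEMM (see the int8 activation prologue in q4_dq.cpp); the
// other modes consume the stated input dtype directly.
inline const char* ComputeTypeDesc(MLAS_COMPUTE_TYPE CompType)
{
    switch (CompType) {
        case CompFp32: return "fp32 inputs, fp32 accumulator";
        case CompInt8: return "int8 inputs, int32 accumulator";
        case CompBf16: return "bf16 inputs, fp32 accumulator";
        case CompFp16: return "fp16 inputs, fp16 accumulator";
        default:       return "unknown";
    }
}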


/**
@@ -147,23 +149,6 @@ MlasQ4GemmBatch(MLAS_BLK_QUANT_TYPE QType,
const size_t BatchN,
const MLAS_Q4_GEMM_DATA_PARAMS* DataParams,
MLAS_THREADPOOL* ThreadPool = nullptr);
/**
* @brief Calculate the buffer size needed for int8 block quantize
* @param[in] QType Type of block quantization used
* @param[in] M Number of rows of the input matrix
* @param[in] K Number of columns of the input matrix
* @return buffer size (in bytes) needed, 0 if not yet supported on current hardware
*/

void MLASCALL
JblasQ4GemmBatch(BLK_QUANT_COMPUTE_TYPE CType,
MLAS_BLK_QUANT_TYPE QType,
const size_t M,
const size_t N,
const size_t K,
const size_t BatchN,
const MLAS_Q4_GEMM_DATA_PARAMS* DataParams,
MLAS_THREADPOOL* ThreadPool = nullptr);

/**
* @brief Calculate the buffer size needed for int8 block quantize
Expand Down Expand Up @@ -235,3 +220,73 @@ MlasQ8Q4GemmBatch(MLAS_BLK_QUANT_TYPE QType,
const size_t BatchN,
const MLAS_Q8Q4_GEMM_DATA_PARAMS* DataParams,
MLAS_THREADPOOL* ThreadPool);


#ifdef MLAS_JBLAS
/**
 * @brief Computes the number of bytes required to pack and int4-quantize
 *        a weight matrix
 * @param N        the number of columns of matrix B.
 * @param K        the number of rows of matrix B.
 * @param BlkSize  number of weights sharing one quantization scale along K
 * @param isAsym   true if asymmetric quantization (with zero points) is used
 * @param CompType compute type of the GEMM that will consume the packed weights
 * @return size of the packing buffer, 0 if the operation is not yet supported.
 */
size_t MLASCALL
MlasJblasQ4GemmPackBSize(
size_t N, size_t K, size_t BlkSize, bool isAsym, MLAS_COMPUTE_TYPE CompType);

/**
 * @brief Prepack and quantize an fp32 weight tensor into int4 blocks
 *
 * @param PackedBuf  destination buffer
 * @param FpData     pointer to the fp32 weight matrix
 * @param N          the number of columns of matrix B.
 * @param K          the number of rows of matrix B.
 * @param ldb        leading dimension of B
 * @param BlkSize    quantization block size along K
 * @param isAsym     true to use asymmetric quantization (with zero points)
 * @param CompType   compute type of the GEMM that will consume the packed weights
 * @param ThreadPool thread pool used to parallelize the packing
 */
void MLASCALL
MlasJblasQ4GemmPackB(void* PackedBuf,
const float* FpData,
size_t N,
size_t K,
size_t ldb,
size_t BlkSize,
bool isAsym,
MLAS_COMPUTE_TYPE CompType,
MLAS_THREADPOOL* ThreadPool);
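// Usage sketch (illustrative, not part of this commit): query the packed
// size, allocate, then pack. BlkSize = 32 and symmetric quantization are
// assumed values for the example; requires <vector> and <cstdint>.
static inline std::vector<int8_t>
SketchPackQ4Weights(const float* B, size_t N, size_t K)
{
    size_t bytes = MlasJblasQ4GemmPackBSize(N, K, /*BlkSize*/ 32,
                                            /*isAsym*/ false, CompFp32);
    if (bytes == 0) {
        return {};  // block size / compute type unsupported on this hardware
    }
    std::vector<int8_t> packed(bytes);
    MlasJblasQ4GemmPackB(packed.data(), B, N, K, /*ldb*/ N, /*BlkSize*/ 32,
                         /*isAsym*/ false, CompFp32, /*ThreadPool*/ nullptr);
    return packed;
}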

/**
 * @brief Unpack and dequantize from int4 to fp32, the reverse operation of
 *        MlasJblasQ4GemmPackB
 * @param FpData     destination buffer for the fp32 matrix
 * @param PackedBuf  int4 quantized and packed data
 * @param N          the number of columns of matrix B.
 * @param K          the number of rows of matrix B.
 * @param ldb        leading dimension of B
 * @param ThreadPool thread pool used to parallelize the unpacking
 */
void MLASCALL
MlasJblasQ4GemmUnPackB(float* FpData,
const void* PackedBuf,
size_t N,
size_t K,
size_t ldb,
MLAS_THREADPOOL* ThreadPool);
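// Round-trip sketch (illustrative, not part of this commit): dequantize the
// packed buffer and report the worst-case elementwise error against the
// original fp32 weights; requires <vector>, <cmath>, and <algorithm>.
static inline float
SketchQ4RoundTripError(const float* B, const void* Packed, size_t N, size_t K)
{
    std::vector<float> dequant(K * N);
    MlasJblasQ4GemmUnPackB(dequant.data(), Packed, N, K, /*ldb*/ N,
                           /*ThreadPool*/ nullptr);
    float max_err = 0.0f;
    for (size_t i = 0; i < K * N; ++i) {
        max_err = std::max(max_err, std::fabs(dequant[i] - B[i]));
    }
    return max_err;
}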

/**
 * @brief Batched GEMM C[i] = A[i] * B[i], where A is fp32 and B has been
 *        int4-quantized and packed by MlasJblasQ4GemmPackB
 * @param[in] M          number of rows of matrix A and C
 * @param[in] N          number of columns of matrix B and C
 * @param[in] K          number of columns of A, number of rows of B
 * @param[in] BatchN     number of GEMM operations in the batch
 * @param[in] DataParams array of per-GEMM data pointers and leading dimensions
 * @param[in] ThreadPool optional thread pool for parallel execution
 */

void MLASCALL
MlasJblasQ4GemmBatch(const size_t M,
const size_t N,
const size_t K,
const size_t BatchN,
const MLAS_Q4_GEMM_DATA_PARAMS* DataParams,
MLAS_THREADPOOL* ThreadPool = nullptr);
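// Batch-of-one sketch (illustrative; the DataParams field names A/lda/B/C/ldc
// are assumed from the MLAS_Q4_GEMM_DATA_PARAMS declaration earlier in this
// header and should be verified against the actual struct before use).
static inline void
SketchQ4Gemm(const float* A, const void* PackedB, float* C,
             size_t M, size_t N, size_t K)
{
    MLAS_Q4_GEMM_DATA_PARAMS params;
    params.A = A;          // M x K fp32 activations, row-major
    params.lda = K;
    params.B = PackedB;    // int4 weights packed by MlasJblasQ4GemmPackB
    params.C = C;          // M x N fp32 output
    params.ldc = N;
    MlasJblasQ4GemmBatch(M, N, K, /*BatchN*/ 1, &params, /*ThreadPool*/ nullptr);
}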
#endif
168 changes: 107 additions & 61 deletions onnxruntime/core/mlas/lib/q4_dq.cpp
@@ -17,8 +17,9 @@ Module Name:
language models.
--*/
#if !defined(__APPLE__)
#include "jblas/jit_blas_weight_compression.h"
#ifdef MLAS_JBLAS
#include "mlas_jblas_defs.h"
using namespace jblas;
#endif
#include "q4common.h"

@@ -30,30 +31,115 @@ BlkQ4BufSize(size_t N, size_t K)
return N * KBlocks * T::BlobSize;
}

#if !defined(__APPLE__)
template <class T, JBLAS_ISA ISA>
using WeiS4ClipFp32PerN =
jblas::prologue::weight_comp::gemm_kblcok::WeightS4ClipScaleFp32PerN<T, ISA>;
#ifdef MLAS_JBLAS
static size_t
JblasCompFp32Q4BuSize(int block_size, size_t N, size_t K, bool isAsym)
{
auto stor = JblasAvx512fS4Fp32Fp32.mProB.createStorage(
N, K, block_size, JBLAS_DTYPE::S4_CLIP, JBLAS_DTYPE::F32, JBLAS_DTYPE::F32, isAsym);
// TODO(Yu) support more S4 quant type, scale dtype
return stor.mSize;
}

template <template <class GC, JBLAS_ISA ISA> class ProB>
using AVX512VNNIPerNFp32Fp32 = jblas::wrapper::gemm_pack_weight::GemmInterfaceParallelAB<
jblas::wrapper::gemm_pack_weight::GemmLauncherPackWeight<
JblasAVX512_VNNI,
jblas::gemm::GemmCore_Row_NN_8x48_AVX512_VNNI,
jblas::prologue::gemm::ActivationFp32AsymU8Quantize,
ProB,
jblas::epilogue::gemm::ZpDequantInt32ToFp32>,
jblas::utils::parallel::Parallel2DGemm>;
static size_t
JblasCompInt8Q4BuSize(int block_size, size_t N, size_t K, bool isAsym)
{
auto stor = JblasAvx512VnniS4Fp32Fp32.mProB.createStorage(
N, K, block_size, JBLAS_DTYPE::S4_CLIP, JBLAS_DTYPE::F32, JBLAS_DTYPE::F32, isAsym);
// TODO(Yu) support more S4 quant type, scale dtype
return stor.mSize;
}
#endif

static AVX512VNNIPerNFp32Fp32<WeiS4ClipFp32PerN> avx512vnni_s4pernkernl;
#ifdef MLAS_JBLAS
size_t MLASCALL
MlasJblasQ4GemmPackBSize(
size_t N, size_t K, size_t BlkSize, bool isAsym, MLAS_COMPUTE_TYPE CompType)
{
    switch (CompType) {
        case CompInt8:
            return JblasCompInt8Q4BuSize(int(BlkSize), N, K, isAsym);
        case CompFp32:
            return JblasCompFp32Q4BuSize(int(BlkSize), N, K, isAsym);
        case CompBf16:
        case CompFp16:
        default:
            break;
    }
    return 0;  // bf16/fp16 compute types not supported yet (see @return contract)
}

static size_t
JblasQ4BuSize(int block_size, size_t N, size_t K)
{
    if (block_size == -1) {
        return avx512vnni_s4pernkernl.getWeightPtr()->createStorage(N, K, false).mSize;
    }
    return 0;
}

template <typename T>
void
JblaQ4GemmPackB(T& JblasKernel,
                int BlkSize,
                void* PackedBuf,
                const float* FpData,
                int N,
                int K,
                bool IsAsym,
                int ldb,
                MLAS_THREADPOOL* ThreadPool)
{
    auto stor = JblasKernel.mProB.createStorage(N, K, BlkSize, JBLAS_DTYPE::S4_CLIP,
                                                JBLAS_DTYPE::F32, JBLAS_DTYPE::F32, IsAsym);
    stor.assign((int8_t*)PackedBuf);
    ORTThreading orth(ThreadPool);
    JblasKernel.mProB.packWeight(N, K, FpData, ldb, &stor, &orth);
}
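// Note on the flow above: createStorage() computes the packed layout for the
// requested block size and dtypes, assign() binds that layout to the
// caller-provided buffer (sized via MlasJblasQ4GemmPackBSize), and
// packWeight() quantizes and reorders the fp32 weights into it, parallelized
// through the ORTThreading adapter.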


void MLASCALL
MlasJblasQ4GemmPackB(void* PackedBuf,
const float* FpData,
size_t N,
size_t K,
size_t ldb,
size_t BlkSize,
bool isAsym,
MLAS_COMPUTE_TYPE CompType,
MLAS_THREADPOOL* ThreadPool)
{
switch (CompType) {
case CompInt8:
return JblaQ4GemmPackB(JblasAvxVnniS4Fp32Fp32, int(BlkSize), PackedBuf, FpData, int(N),
int(K), isAsym, int(ldb), ThreadPool);
case CompFp32:
return JblaQ4GemmPackB(JblasAvx512fS4Fp32Fp32, int(BlkSize), PackedBuf, FpData, int(N),
int(K), isAsym, int(ldb), ThreadPool);
case CompBf16:
case CompFp16:
default:
break;
}
}

void MLASCALL
MlasJblasQ4GemmUnPackB(float* FpData,
const void* PackedBuf,
size_t N,
size_t K,
size_t ldb,
MLAS_THREADPOOL* ThreadPool)
{
    auto ptr =
        jblas::storage::gemm::PackedWeightParser::deserialBuffer(const_cast<void*>(PackedBuf));
    ORTThreading orth(ThreadPool);
    if (ptr) {
        if (ptr->mPrologueID == JBLAS_PROLOGUEB_IDS::WeightKBlockS4) {
            // The packed blob records which GEMM core produced it; decode the
            // core's N-tile width and compute class to route to the kernel
            // that matches the original layout.
            auto coretype = ptr->mCoreType;
            auto NTile = uint32_t(coretype) & uint32_t(JBLAS_GEMM_CORE::NTILE_MASK);
            auto CType = uint32_t(coretype) & uint32_t(JBLAS_GEMM_CORE::COMP_MASK);
            if (NTile == 48 && CType == uint32_t(JBLAS_GEMM_CORE::COMP_FP32)) {
                JblasAvx512fS4Fp32Fp32.mProB.unpackWeight(int(N), int(K), ptr, FpData, int(ldb),
                                                          &orth);
            }
            if (NTile == 48 && CType == uint32_t(JBLAS_GEMM_CORE::COMP_INT8_US)) {
                JblasAvx512VnniS4Fp32Fp32.mProB.unpackWeight(int(N), int(K), ptr, FpData, int(ldb),
                                                             &orth);
            }
        }
        delete ptr;
    }
}
#endif

Expand All @@ -72,9 +158,6 @@ MlasQ4GemmPackBSize(MLAS_BLK_QUANT_TYPE QType, size_t N, size_t K)
case BlkQ4Sym128:
return BlkQ4BufSize<MLAS_Q4TYPE_BLK4>(N, K);
case BlkQ4SymPerN:
#if !defined(__APPLE__)
return JblasQ4BuSize(-1, N, K);
#endif
default:
return BlkQ4BufSize<MLAS_Q4TYPE_BLK1>(N, K);
}
@@ -197,19 +280,6 @@ MlasQ4GemmPackBImpl<MLAS_Q4TYPE_BLK1>(
}
}

#if !defined(__APPLE__)
void
JblasQ4GemmPackB(
int block_size, void* PackedBuf, const float* FpData, size_t N, size_t K, size_t ldb)
{
if (block_size == -1) {
auto tmpstor = avx512vnni_s4pernkernl.getWeightPtr()->createStorage(N, K, false);
tmpstor.assign((int8_t*)PackedBuf);
avx512vnni_s4pernkernl.getWeightPtr()->packWeight(N, K, FpData, ldb, &tmpstor);
}
}
#endif

void MLASCALL
MlasQ4GemmPackB(
MLAS_BLK_QUANT_TYPE QType, void* PackedBuf, const float* FpData, size_t N, size_t K, size_t ldb)
@@ -222,9 +292,6 @@ MlasQ4GemmPackB(
case BlkQ4Sym128:
return MlasQ4GemmPackBImpl<MLAS_Q4TYPE_BLK4>(PackedBuf, FpData, N, K, ldb);
case BlkQ4SymPerN:
#if !defined(__APPLE__)
return JblasQ4GemmPackB(-1, PackedBuf, FpData, N, K, ldb);
#endif
default:
return MlasQ4GemmPackBImpl<MLAS_Q4TYPE_BLK1>(PackedBuf, FpData, N, K, ldb);
}
@@ -308,24 +375,6 @@ MlasQ4GemmUnPackBImpl<MLAS_Q4TYPE_BLK1>(
}
}

#if !defined(__APPLE__)
void
JblasQ4GemmUnPackB(float* FpData, const void* PackedBuf, size_t N, size_t K, size_t ldb)
{
auto ptr = jblas::prologue::weight_comp::gemm_kblcok::PackedWeightParser::deserialBuffer(
const_cast<void*>(PackedBuf));
if (ptr) {
if (ptr->mPrologueID == int(jblas::prologue::weight_comp::gemm_kblcok::PrologueBIDs::
WeightS4ClipScaleFp32PerChannelN)) {
if (ptr->mCoreType == jblas::gemm::GemmCoreType::AVX512_VNNI_8x48) {
avx512vnni_s4pernkernl.getWeightPtr()->unpackWeight(N, K, ptr, FpData, ldb);
}
}
delete ptr;
}
}
#endif

void MLASCALL
MlasQ4GemmUnPackB(
MLAS_BLK_QUANT_TYPE QType, float* FpData, const void* PackedBuf, size_t N, size_t K, size_t ldb)
@@ -338,9 +387,6 @@ MlasQ4GemmUnPackB(
case BlkQ4Sym128:
return MlasQ4GemmUnPackBImpl<MLAS_Q4TYPE_BLK4>(FpData, PackedBuf, N, K, ldb);
case BlkQ4SymPerN:
#if !defined(__APPLE__)
return JblasQ4GemmUnPackB(FpData, PackedBuf, N, K, ldb);
#endif
default:
return MlasQ4GemmUnPackBImpl<MLAS_Q4TYPE_BLK1>(FpData, PackedBuf, N, K, ldb);
}
(Diffs for the remaining 26 changed files are not shown.)
