[ARM] MatMulNBits fp16 support - connect kernels #22856

Merged 5 commits on Nov 15, 2024
8 changes: 4 additions & 4 deletions cmake/onnxruntime_mlas.cmake
@@ -84,8 +84,8 @@ function(setup_mlas_source_for_windows)
${MLAS_SRC_DIR}/qgemm_kernel_neon.cpp
${MLAS_SRC_DIR}/qgemm_kernel_udot.cpp
${MLAS_SRC_DIR}/qgemm_kernel_sdot.cpp
- ${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon.h
- ${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon.cpp
+ ${MLAS_SRC_DIR}/qnbitgemm_kernel_neon.h
+ ${MLAS_SRC_DIR}/qnbitgemm_kernel_neon.cpp
${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon_fp32.cpp
${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon_int8.cpp
${MLAS_SRC_DIR}/fp16_neon_common.cpp
@@ -363,8 +363,8 @@ else()
${MLAS_SRC_DIR}/qgemm_kernel_neon.cpp
${MLAS_SRC_DIR}/qgemm_kernel_udot.cpp
${MLAS_SRC_DIR}/qgemm_kernel_sdot.cpp
- ${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon.h
- ${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon.cpp
+ ${MLAS_SRC_DIR}/qnbitgemm_kernel_neon.h
+ ${MLAS_SRC_DIR}/qnbitgemm_kernel_neon.cpp
${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon_fp32.cpp
${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon_int8.cpp
)
@@ -24,7 +24,7 @@ Module Name:

#include "fp16_common.h"
#include "qnbitgemm.h"
- #include "sqnbitgemm_kernel_neon.h"
+ #include "qnbitgemm_kernel_neon.h"

namespace sqnbitgemm_neon
{
4 changes: 1 addition & 3 deletions onnxruntime/core/mlas/lib/platform.cpp
@@ -531,6 +531,7 @@ Return Value:
this->SymmQgemmDispatch = &MlasSymmQgemmS8DispatchNeon;
this->ConvSymU8S8Dispatch = &MlasConvSymU8DispatchNeon;
this->ConvSymS8S8Dispatch = &MlasConvSymS8DispatchNeon;
+ this->QNBitGemmDispatch = &MlasSQNBitGemmDispatchNeon;

//
// Check if the processor supports ASIMD dot product instructions.
@@ -560,9 +561,6 @@
this->SymmQgemmDispatch = &MlasSymmQgemmS8DispatchSdot;
this->ConvSymU8S8Dispatch = &MlasConvSymU8DispatchDot;
this->ConvSymS8S8Dispatch = &MlasConvSymS8DispatchDot;

- // MlasSQNBitGemmDispatchNeon has a dependency on dot product instructions
- this->QNBitGemmDispatch = &MlasSQNBitGemmDispatchNeon;
}

#if defined(__linux__)
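Note on the platform.cpp hunks above: previously the QNBit dispatch was registered only inside the dot-product branch (see the removed comment), so even the fp32 compute path was unavailable on ARM64 cores without ASIMD dot product. This PR installs the dispatch unconditionally in the base NEON branch and narrows the dot-product dependency to the int8 kernel at dispatch-table construction. A condensed sketch of the resulting pattern, stitched together from hunks in this diff (platform.cpp here and qnbitgemm_kernel_neon.cpp below):

// platform.cpp: registered for every ARM64 NEON build
this->QNBitGemmDispatch = &MlasSQNBitGemmDispatchNeon;

// qnbitgemm_kernel_neon.cpp: only the int8 kernel requires dot product
if (MLAS_CPUIDINFO::GetCPUIDInfo().HasArmNeonDot()) {
    d.SQ4BitGemmKernel_CompInt8 = sqnbitgemm_neon::SQ4BitGemmKernel_CompInt8;
}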
94 changes: 91 additions & 3 deletions onnxruntime/core/mlas/lib/qnbitgemm.cpp
@@ -82,7 +82,12 @@ MlasIsQNBitGemmAvailable(
switch (Variant) {
case SQNBitGemmVariant_BitWidth4_CompFp32: {
return Dispatch->SQ4BitGemmM1Kernel_CompFp32 != nullptr &&
- Dispatch->Q4BitBlkDequantBForSgemm_CompFp32 != nullptr;
+ Dispatch->SQ4BitBlkDequantBForSgemm_CompFp32 != nullptr;
}
+ case HQNBitGemmVariant_BitWidth4_CompFp16: {
+     return Dispatch->HQ4BitGemmPackQuantBData != nullptr &&
+            Dispatch->HQ4BitGemmKernel_CompFp16 != nullptr &&
+            Dispatch->HQ4BitBlkDequantBForHgemm_CompFp16 != nullptr;
+ }
case SQNBitGemmVariant_BitWidth4_CompInt8: { // SQ4BitGemmKernel_BlkSum_CompInt8
return
@@ -253,6 +258,16 @@ MlasQNBitGemmPackQuantBData(
packed_quant_b,
ThreadPool
);
+ } else if (ComputeType == HQNBIT_CompFp16 && Dispatch->HQ4BitGemmPackQuantBData != nullptr) {
+     Dispatch->HQ4BitGemmPackQuantBData(
+         N,
+         K,
+         BlkLen,
+         ComputeType,
+         static_cast<const std::byte*>(QuantBData),
+         static_cast<std::byte*>(PackedQuantBDataAndOrBlkSumWorkspace),
+         ThreadPool
+     );
} else if (Dispatch->SQ4BitGemmPackQuantBData != nullptr) {
// TODO: these assertions are true if called from matmul_nbits kernel but not from mlas tests.
//assert(QuantBScale == nullptr);
@@ -387,7 +402,7 @@ SQ4BitGemm_CompFp32(
float* c_blk = C + n;
const float* bias = (Bias == nullptr) ? nullptr : Bias + n;

- GetMlasPlatform().QNBitGemmDispatch->Q4BitBlkDequantBForSgemm_CompFp32(
+ GetMlasPlatform().QNBitGemmDispatch->SQ4BitBlkDequantBForSgemm_CompFp32(
BlkLen,
dequant_b, b_col, b_col_scale, b_col_zp, CountN, K, k_blks
);
@@ -419,6 +434,79 @@
}
}

+ void
+ HQ4BitGemm_CompFp16(
+     const size_t BlkLen,
+     const size_t K,
+     const MLAS_QNBIT_GEMM_DATA_PARAMS<MLAS_FP16>* const DataParams,
+     void* const PerGemmWorkspace,
+     const size_t RangeStartM,
+     const size_t RangeCountM,
+     const size_t RangeStartN,
+     const size_t RangeCountN
+ )
+ {
+     constexpr size_t BlkBitWidth = 4;
+     MLAS_UNREFERENCED_PARAMETER(PerGemmWorkspace);
+
+     const size_t lda = DataParams->lda;
+     const size_t ldc = DataParams->ldc;
+     const size_t k_blk_num = MlasDivRoundup(K, BlkLen);
+     const size_t qldb = k_blk_num * MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen);
+     const size_t ldb = k_blk_num * BlkLen;
+     const size_t k_zp_bytes = MlasQNBitZeroPointsForBlksSizeInBytes<BlkBitWidth>(k_blk_num);
+
+     const MLAS_FP16* A = DataParams->A + RangeStartM * lda;
+     MLAS_FP16* C = DataParams->C + RangeStartM * ldc + RangeStartN;
+     const std::byte* QuantBData = static_cast<const std::byte*>(DataParams->PackedQuantBData) + RangeStartN * qldb;
+     const MLAS_FP16* QuantBScale = DataParams->QuantBScale + RangeStartN * k_blk_num;
+     const std::byte* QuantBZeroPoint =
+         (DataParams->QuantBZeroPoint == nullptr)
+             ? nullptr
+             : static_cast<const std::byte*>(DataParams->QuantBZeroPoint) + RangeStartN * k_zp_bytes;
+     const MLAS_FP16* Bias = (DataParams->Bias == nullptr) ? nullptr : DataParams->Bias;
+
+     // 32N is the sweet spot of cache utilization. It is machine dependent though.
+     constexpr size_t StrideM = 2;
+     constexpr size_t StrideN = 32;
+
+     // TODO(fajin): move allocation up to the op.
+     size_t bufsize = ldb * StrideN * sizeof(MLAS_FP16);
+     MlasThreadedBufAlloc(bufsize);
+     auto* dequant_b = reinterpret_cast<MLAS_FP16*>(ThreadedBufHolder.get());
+
+     for (size_t n = 0, countN; n < RangeCountN; n += countN) {
+         countN = std::min(StrideN, RangeCountN - n);
+         GetMlasPlatform().QNBitGemmDispatch->HQ4BitBlkDequantBForHgemm_CompFp16(
+             BlkLen, dequant_b, QuantBData, QuantBScale, QuantBZeroPoint, countN, K, k_blk_num
+         );
+
+         const MLAS_FP16* a = A;
+         MLAS_FP16* c = C;
+         for (size_t m = 0, countM; m < RangeCountM; m += countM) {
+             countM = std::min(StrideM, RangeCountM - m);
+             GetMlasPlatform().QNBitGemmDispatch->HQ4BitGemmKernel_CompFp16(
+                 a, dequant_b, Bias, c, countM, countN, K, lda, ldb, ldc
+             );
+
+             if (DataParams->PostProcessor != nullptr) {
+                 DataParams->PostProcessor->Process(
+                     DataParams->C, RangeStartM + m, RangeStartN + n, countM, countN, ldc
+                 );
+             }
+
+             a += countM * lda;
+             c += countM * ldc;
+         }
+
+         QuantBData += countN * qldb;
+         QuantBScale += countN * k_blk_num;
+         QuantBZeroPoint = QuantBZeroPoint ? QuantBZeroPoint + countN * k_zp_bytes : nullptr;
+         Bias = Bias ? Bias + countN : nullptr;
+         C += countN;
+     }
+ }
+
void
SQ4BitGemm_CompInt8(
const size_t BlkLen,
@@ -720,7 +808,7 @@ GetQNBitGemm(QNBitGemmVariant variant)
{
switch (variant) {
case HQNBitGemmVariant_BitWidth4_CompFp16:
- return nullptr;
+ return HQ4BitGemm_CompFp16;
default:
return nullptr;
}
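For a sense of scale, the per-thread dequantization buffer that HQ4BitGemm_CompFp16 allocates above works out as follows. A minimal compilable sketch; K = 4096 and BlkLen = 32 are illustrative assumptions, and uint16_t stands in for the 2-byte MLAS_FP16:

#include <cstddef>
#include <cstdint>

constexpr size_t K = 4096;          // shared dimension (rows of B); assumed example value
constexpr size_t BlkLen = 32;       // quantized values per block; assumed example value
constexpr size_t StrideN = 32;      // B columns dequantized per tile, as in the kernel above
constexpr size_t k_blk_num = (K + BlkLen - 1) / BlkLen;       // MlasDivRoundup(K, BlkLen) == 128
constexpr size_t ldb = k_blk_num * BlkLen;                    // 4096 fp16 elements per column
constexpr size_t bufsize = ldb * StrideN * sizeof(uint16_t);  // 262144 bytes == 256 KiB
static_assert(bufsize == 256 * 1024, "one 32-column fp16 panel of B");

Each worker dequantizes a 32-column panel of B once and reuses it across every 2-row chunk of A (StrideM = 2), which is the reuse the "32N is the sweet spot of cache utilization" comment refers to.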
81 changes: 72 additions & 9 deletions onnxruntime/core/mlas/lib/qnbitgemm.h
@@ -91,17 +91,17 @@ struct MLAS_QNBIT_GEMM_DISPATCH {
//

/** Gets size of packed quantized B data containing 4-bit integers. See MlasQNBitGemmPackQuantBDataSize(). */
- typedef size_t(SQ4BitGemmPackQuantBDataSize_Fn)(
+ typedef size_t(Q4BitGemmPackQuantBDataSize_Fn)(
size_t N,
size_t K,
size_t BlkLen,
MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType
);

- SQ4BitGemmPackQuantBDataSize_Fn* Q4BitGemmPackQuantBDataSize = nullptr;
+ Q4BitGemmPackQuantBDataSize_Fn* Q4BitGemmPackQuantBDataSize = nullptr;

/** Packs quantized B data containing 4-bit integers. See MlasQNBitGemmPackQuantBData(). */
- typedef void(SQ4BitGemmPackQuantBData_Fn)(
+ typedef void(Q4BitGemmPackQuantBData_Fn)(
size_t N,
size_t K,
size_t BlkLen,
@@ -111,7 +111,8 @@ struct MLAS_QNBIT_GEMM_DISPATCH {
MLAS_THREADPOOL* ThreadPool
);

- SQ4BitGemmPackQuantBData_Fn* SQ4BitGemmPackQuantBData = nullptr;
+ Q4BitGemmPackQuantBData_Fn* SQ4BitGemmPackQuantBData = nullptr;
+ Q4BitGemmPackQuantBData_Fn* HQ4BitGemmPackQuantBData = nullptr;

typedef void(SQ4BitGemmPackQuantBDataAndSumBlk_Fn)(
size_t N,
@@ -142,28 +143,28 @@ struct MLAS_QNBIT_GEMM_DISPATCH {
* @param[in] BlkLen number of quantized values per block
* @param[in] ComputeType GEMM compute type (e.g., multiplying float or int8 values)
*/
- typedef size_t(SQ4BitGemmPerGemmWorkspaceSize_Fn)(
+ typedef size_t(Q4BitGemmPerGemmWorkspaceSize_Fn)(
size_t M,
size_t N,
size_t K,
size_t BlkLen,
MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType
);

- SQ4BitGemmPerGemmWorkspaceSize_Fn* Q4BitGemmPerGemmWorkspaceSize = nullptr;
+ Q4BitGemmPerGemmWorkspaceSize_Fn* Q4BitGemmPerGemmWorkspaceSize = nullptr;

/**
* @brief Gets the required byte alignment of the per-GEMM intermediate workspace.
*
* @param[in] BlkLen number of quantized values per block
* @param[in] ComputeType GEMM compute type (e.g., multiplying float or int8 values)
*/
- typedef size_t(SQ4BitGemmPerGemmWorkspaceAlignment_Fn)(
+ typedef size_t(Q4BitGemmPerGemmWorkspaceAlignment_Fn)(
size_t BlkLen,
MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType
);

- SQ4BitGemmPerGemmWorkspaceAlignment_Fn* Q4BitGemmPerGemmWorkspaceAlignment = nullptr;
+ Q4BitGemmPerGemmWorkspaceAlignment_Fn* Q4BitGemmPerGemmWorkspaceAlignment = nullptr;

//
// SQNBIT_CompFp32 kernel function prototypes.
@@ -229,7 +230,38 @@ struct MLAS_QNBIT_GEMM_DISPATCH {
size_t BlockStrideQuantB
);

- Q4BitBlkDequantBForSgemm_CompFp32_Fn* Q4BitBlkDequantBForSgemm_CompFp32 = nullptr;
+ Q4BitBlkDequantBForSgemm_CompFp32_Fn* SQ4BitBlkDequantBForSgemm_CompFp32 = nullptr;
+
+ /**
+  * @brief Dequantize B into the format expected by the Sgemm kernel.
+  *        B is a quantized 4-bit integer matrix that is block quantized and column major.
+  *        This is equivalent to dequantizing B and then running MlasSgemmCopyPackB.
+  *
+  * @param BlkLen              Number of values in a block.
+  * @param[out] FpData         Supplies the output buffer for the dequantized B float data.
+  *                            It should have enough space for
+  *                            (CountN + 16 - 1) / 16 * 16 * (CountK + BlkLen - 1) / BlkLen * BlkLen
+  *                            elements. Only the first (CountN + 16 - 1) / 16 * 16 * CountK elements are
+  *                            useful, but the kernel implementation can be simplified with the extra space.
+  * @param QuantBData          Supplies the quantized B matrix block data.
+  * @param QuantBScale         Supplies the quantized B matrix block scale values.
+  * @param QuantBZeroPoint     Supplies the quantized B matrix block zero point values. Optional.
+  * @param CountN              Number of columns of B.
+  * @param CountK              Number of rows of B.
+  * @param BlockStrideQuantB   Number of blocks between adjacent columns of the quantized B matrix.
+  */
+ typedef void(Q4BitBlkDequantBForSgemm_CompFp16_Fn)(
+     size_t BlkLen,
+     MLAS_FP16* FpData,
+     const std::byte* QuantBData,
+     const MLAS_FP16* QuantBScale,
+     const std::byte* QuantBZeroPoint,
+     size_t CountN,
+     size_t CountK,
+     size_t BlockStrideQuantB
+ );
+
+ Q4BitBlkDequantBForSgemm_CompFp16_Fn* HQ4BitBlkDequantBForHgemm_CompFp16 = nullptr;

//
// SQNBIT_CompInt8 kernel function prototypes.
@@ -338,4 +370,35 @@ struct MLAS_QNBIT_GEMM_DISPATCH {
float* AScaledGroupSum // scale_k * Sum_blklen(a_i)
);
QuantizeARowComputeBlkSum_CompInt8_Fn* QuantizeARowComputeBlkSum_CompInt8 = nullptr;

+ /**
+  * @brief Multiply fp16 matrix A rows with fp16 matrix B columns.
+  *        Results are written to fp16 matrix C.
+  *        If bias is provided, the bias is added to the result.
+  *
+  * @param A         first row of the A matrix segment. Row major.
+  * @param B         first column of the B matrix segment. Column major.
+  * @param Bias      the bias at the target column. Optional.
+  * @param[out] C    first element of the output matrix segment. Row major.
+  * @param CountM    the number of rows of A chunk.
+  * @param CountN    the number of columns of B chunk.
+  * @param K         the number of columns of A matrix and rows of B matrix.
+  * @param lda       the leading dimension of A.
+  * @param ldb       the leading dimension of B.
+  * @param ldc       the leading dimension of C.
+  */
+ typedef void(HQ4BitGemmKernel_CompFp16_Fn)(
+     const MLAS_FP16* A,
+     const MLAS_FP16* B,
+     const MLAS_FP16* Bias,
+     MLAS_FP16* C,
+     size_t CountM,
+     size_t CountN,
+     size_t K,
+     size_t lda,
+     size_t ldb,
+     size_t ldc
+ );
+
+ HQ4BitGemmKernel_CompFp16_Fn* HQ4BitGemmKernel_CompFp16 = nullptr;
};
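Taken together, the new dispatch members support a caller flow along the following lines. This is a minimal sketch only: MlasQNBitGemmBatch, the exact MLAS_QNBIT_GEMM_DATA_PARAMS field spellings, and passing a null workspace (HQ4BitGemm_CompFp16 ignores PerGemmWorkspace) are assumptions inferred from this diff, not verbatim API from the PR:

#include <cstddef>
#include "mlas_qnbit.h"  // assumed public MLAS header exposing the QNBit GEMM API

// Hypothetical fp16 MatMulNBits call path; member names follow qnbitgemm.h above.
void RunHQ4BitGemm(const MLAS_FP16* A, const void* packed_b, const MLAS_FP16* scales,
                   const void* zero_points, const MLAS_FP16* bias, MLAS_FP16* C,
                   size_t M, size_t N, size_t K, size_t BlkLen, MLAS_THREADPOOL* pool)
{
    constexpr size_t BlkBitWidth = 4;
    if (!MlasIsQNBitGemmAvailable(BlkBitWidth, BlkLen, HQNBIT_CompFp16)) {
        return;  // caller would fall back to dequantize + a regular fp16/fp32 GEMM
    }

    MLAS_QNBIT_GEMM_DATA_PARAMS<MLAS_FP16> params{};
    params.A = A;
    params.lda = K;
    params.PackedQuantBData = packed_b;    // produced earlier by MlasQNBitGemmPackQuantBData
    params.QuantBScale = scales;
    params.QuantBZeroPoint = zero_points;  // optional, may be nullptr
    params.Bias = bias;                    // optional, may be nullptr
    params.C = C;
    params.ldc = N;

    MlasQNBitGemmBatch(M, N, K, /*BatchN=*/1, BlkBitWidth, BlkLen,
                       HQNBIT_CompFp16, &params, /*Workspace=*/nullptr, pool);
}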
@@ -6,7 +6,7 @@ Licensed under the MIT License.

Module Name:

- sqnbitgemm_kernel_neon.cpp
+ qnbitgemm_kernel_neon.cpp

Abstract:

@@ -20,7 +20,7 @@ Module Name:
#include <cassert>

#include "qnbitgemm.h"
- #include "sqnbitgemm_kernel_neon.h"
+ #include "qnbitgemm_kernel_neon.h"
#include "sqnbitgemm_q8_block.h"

namespace sqnbitgemm_neon
@@ -185,10 +185,17 @@ const MLAS_QNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchNeon = []() {
d.Q4BitGemmPerGemmWorkspaceAlignment = sqnbitgemm_neon::Q4BitGemmPerGemmWorkspaceAlignment;

d.SQ4BitGemmM1Kernel_CompFp32 = sqnbitgemm_neon::SQ4BitGemmM1Kernel_CompFp32;
- d.Q4BitBlkDequantBForSgemm_CompFp32 = sqnbitgemm_neon::Q4BitBlkDequantBForSgemm_CompFp32;

- d.SQ4BitGemmKernel_CompInt8 = sqnbitgemm_neon::SQ4BitGemmKernel_CompInt8;
+ d.SQ4BitBlkDequantBForSgemm_CompFp32 = sqnbitgemm_neon::SQ4BitBlkDequantBForSgemm_CompFp32;
+ if (MLAS_CPUIDINFO::GetCPUIDInfo().HasArmNeonDot()) {
+     d.SQ4BitGemmKernel_CompInt8 = sqnbitgemm_neon::SQ4BitGemmKernel_CompInt8;
+ }
d.QuantizeARow_CompInt8 = sqnbitgemm_neon::QuantizeARow_CompInt8;

+ #if defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) && defined(MLAS_TARGET_ARM64)
+ d.HQ4BitGemmPackQuantBData = sqnbitgemm_neon::HQ4BitGemmPackQuantBData_CompFp16;
+ d.HQ4BitBlkDequantBForHgemm_CompFp16 = sqnbitgemm_neon::HQ4BitBlkDequantBForHgemm_CompFp16;
+ d.HQ4BitGemmKernel_CompFp16 = sqnbitgemm_neon::HQ4BitGemmKernel_CompFp16;
+ #endif // MLAS_F16VEC_INTRINSICS_SUPPORTED && MLAS_TARGET_ARM64

return d;
}();