Skip to content

Commit

Permalink
adding kernels
Browse files Browse the repository at this point in the history
fajin-corp committed Dec 14, 2024
1 parent aeef8c1 commit 9557efc
Showing 7 changed files with 476 additions and 54 deletions.
100 changes: 100 additions & 0 deletions onnxruntime/core/mlas/lib/fp16_common.h
Original file line number Diff line number Diff line change
@@ -16,6 +16,7 @@ Module Name:

#pragma once

#include <arm_neon.h>
#include "mlas_float16.h"
#include "mlasi.h"

@@ -349,4 +350,103 @@ MlasBitwiseSelectFloat16x4(MLAS_UINT16X4 select, MLAS_FLOAT16X4 ones, MLAS_FLOAT
return vbsl_f16(select, ones, zeros);
}

MLAS_FORCEINLINE
void
Transpose8x8(MLAS_FLOAT16X8& v0, MLAS_FLOAT16X8& v1, MLAS_FLOAT16X8& v2, MLAS_FLOAT16X8& v3,
             MLAS_FLOAT16X8& v4, MLAS_FLOAT16X8& v5, MLAS_FLOAT16X8& v6, MLAS_FLOAT16X8& v7)
{
    // In-place transpose of the 8x8 fp16 tile whose row r is held in vr.
    // Writing e(r,c) for lane c of input row r, lane c of vr holds e(c,r) on return.
    // The transpose is done in three interleave passes, doubling the element
    // width each time: 16-bit lanes, then 32-bit lanes, then 64-bit lanes.

    // Pass 1 (16-bit lanes): interleave each pair of adjacent rows.
    //   pair01.val[0] = e00 e10 e02 e12 e04 e14 e06 e16
    //   pair01.val[1] = e01 e11 e03 e13 e05 e15 e07 e17
    // pair23 / pair45 / pair67 follow the same pattern for their rows.
    const float16x8x2_t pair01 = vtrnq_f16(v0, v1);
    const float16x8x2_t pair23 = vtrnq_f16(v2, v3);
    const float16x8x2_t pair45 = vtrnq_f16(v4, v5);
    const float16x8x2_t pair67 = vtrnq_f16(v6, v7);

    // Pass 2 (32-bit lanes): merge the row pairs into groups of four rows.
    //   top02.val[0] = e00 e10 e20 e30 e04 e14 e24 e34
    //   top02.val[1] = e02 e12 e22 e32 e06 e16 e26 e36
    //   top13.val[0] = e01 e11 e21 e31 e05 e15 e25 e35
    //   top13.val[1] = e03 e13 e23 e33 e07 e17 e27 e37
    // bot02 / bot13 hold the same columns taken from input rows 4..7.
    const float32x4x2_t top02 =
        vtrnq_f32(vreinterpretq_f32_f16(pair01.val[0]), vreinterpretq_f32_f16(pair23.val[0]));
    const float32x4x2_t top13 =
        vtrnq_f32(vreinterpretq_f32_f16(pair01.val[1]), vreinterpretq_f32_f16(pair23.val[1]));
    const float32x4x2_t bot02 =
        vtrnq_f32(vreinterpretq_f32_f16(pair45.val[0]), vreinterpretq_f32_f16(pair67.val[0]));
    const float32x4x2_t bot13 =
        vtrnq_f32(vreinterpretq_f32_f16(pair45.val[1]), vreinterpretq_f32_f16(pair67.val[1]));

    // Pass 3 (64-bit lanes): view every partial row as two 64-bit halves.
    // topN carries columns N and N+4 of input rows 0..3, botN the same columns of rows 4..7.
    const auto top0 = vreinterpretq_f64_f32(top02.val[0]);
    const auto top1 = vreinterpretq_f64_f32(top13.val[0]);
    const auto top2 = vreinterpretq_f64_f32(top02.val[1]);
    const auto top3 = vreinterpretq_f64_f32(top13.val[1]);
    const auto bot0 = vreinterpretq_f64_f32(bot02.val[0]);
    const auto bot1 = vreinterpretq_f64_f32(bot13.val[0]);
    const auto bot2 = vreinterpretq_f64_f32(bot02.val[1]);
    const auto bot3 = vreinterpretq_f64_f32(bot13.val[1]);

    // The low halves form output rows 0..3 (input columns 0..3) ...
    v0 = vreinterpretq_f16_f64(vtrn1q_f64(top0, bot0));
    v1 = vreinterpretq_f16_f64(vtrn1q_f64(top1, bot1));
    v2 = vreinterpretq_f16_f64(vtrn1q_f64(top2, bot2));
    v3 = vreinterpretq_f16_f64(vtrn1q_f64(top3, bot3));
    // ... and the high halves form output rows 4..7 (input columns 4..7).
    v4 = vreinterpretq_f16_f64(vtrn2q_f64(top0, bot0));
    v5 = vreinterpretq_f16_f64(vtrn2q_f64(top1, bot1));
    v6 = vreinterpretq_f16_f64(vtrn2q_f64(top2, bot2));
    v7 = vreinterpretq_f16_f64(vtrn2q_f64(top3, bot3));
}

MLAS_FORCEINLINE
void
Transpose4x8(MLAS_FLOAT16X8& v0, MLAS_FLOAT16X8& v1, MLAS_FLOAT16X8& v2, MLAS_FLOAT16X8& v3)
{
    // Treats the 4x8 fp16 tile in v0..v3 as two side-by-side 4x4 blocks
    // (columns 0..3 and columns 4..7) and transposes each block in place.
    // Writing e(r,c) for lane c of input row r, the result is:
    //   v0 = e00 e10 e20 e30 | e04 e14 e24 e34
    //   v1 = e01 e11 e21 e31 | e05 e15 e25 e35
    //   v2 = e02 e12 e22 e32 | e06 e16 e26 e36
    //   v3 = e03 e13 e23 e33 | e07 e17 e27 e37

    // 16-bit interleave of adjacent rows: val[0] gathers the even columns,
    // val[1] the odd columns, each as (upper row, lower row) pairs.
    const float16x8x2_t rows01 = vtrnq_f16(v0, v1);
    const float16x8x2_t rows23 = vtrnq_f16(v2, v3);

    // Reinterpret the pairs as 32-bit lanes so that a pair moves as one unit.
    const auto even01 = vreinterpretq_f32_f16(rows01.val[0]);
    const auto even23 = vreinterpretq_f32_f16(rows23.val[0]);
    const auto odd01 = vreinterpretq_f32_f16(rows01.val[1]);
    const auto odd23 = vreinterpretq_f32_f16(rows23.val[1]);

    // 32-bit interleave: trn1 selects columns 0/1 (+4), trn2 selects columns 2/3 (+4).
    v0 = vreinterpretq_f16_f32(vtrn1q_f32(even01, even23));
    v1 = vreinterpretq_f16_f32(vtrn1q_f32(odd01, odd23));
    v2 = vreinterpretq_f16_f32(vtrn2q_f32(even01, even23));
    v3 = vreinterpretq_f16_f32(vtrn2q_f32(odd01, odd23));
}

MLAS_FORCEINLINE
void
Transpose4x4(MLAS_FLOAT16X4& v0, MLAS_FLOAT16X4& v1, MLAS_FLOAT16X4& v2, MLAS_FLOAT16X4& v3)
{
    // In-place transpose of the 4x4 fp16 tile whose row r is held in vr.
    // Writing e(r,c) for lane c of input row r, the result is:
    //   v0 = e00 e10 e20 e30
    //   v1 = e01 e11 e21 e31
    //   v2 = e02 e12 e22 e32
    //   v3 = e03 e13 e23 e33

    // 16-bit interleave of adjacent rows: val[0] gathers columns 0 and 2,
    // val[1] gathers columns 1 and 3, each as (upper row, lower row) pairs.
    const float16x4x2_t rows01 = vtrn_f16(v0, v1);
    const float16x4x2_t rows23 = vtrn_f16(v2, v3);

    // Reinterpret the pairs as 32-bit lanes so that a pair moves as one unit.
    const auto even01 = vreinterpret_f32_f16(rows01.val[0]);
    const auto even23 = vreinterpret_f32_f16(rows23.val[0]);
    const auto odd01 = vreinterpret_f32_f16(rows01.val[1]);
    const auto odd23 = vreinterpret_f32_f16(rows23.val[1]);

    // 32-bit interleave: trn1 keeps the first pair of each operand (columns 0/1),
    // trn2 keeps the second pair (columns 2/3).
    v0 = vreinterpret_f16_f32(vtrn1_f32(even01, even23));
    v1 = vreinterpret_f16_f32(vtrn1_f32(odd01, odd23));
    v2 = vreinterpret_f16_f32(vtrn2_f32(even01, even23));
    v3 = vreinterpret_f16_f32(vtrn2_f32(odd01, odd23));
}

#endif // fp16 vector intrinsic supported
70 changes: 51 additions & 19 deletions onnxruntime/core/mlas/lib/halfgemm.cpp
Original file line number Diff line number Diff line change
@@ -332,7 +332,10 @@ MlasHGemmSupported(
) {
auto* dispatch = GetMlasPlatform().HGemmDispatch;
if (TransA == CblasNoTrans && TransB == CblasTrans) {
return dispatch && dispatch->HGemmKernel_TransposeB;
return dispatch &&
dispatch->HGemmKernel_TransposeB &&
dispatch->HTransposePackB &&
dispatch->HGemmKernel_TransposePackB;
}

return false;
@@ -342,7 +345,7 @@ void
HGemmOperation(
CBLAS_TRANSPOSE TransA,
CBLAS_TRANSPOSE TransB,
size_t K,
size_t K, // full K slice
const MLAS_HGEMM_DATA_PARAMS* DataParams,
const size_t RangeStartM,
const size_t RangeCountM,
@@ -354,31 +357,60 @@ HGemmOperation(
const size_t ldc = DataParams->ldc;
const MLAS_FP16 alpha = DataParams->alpha;
const MLAS_FP16 beta = DataParams->beta;
constexpr size_t StrideM = 2;
constexpr size_t StrideN = 32;
auto* dispatch = GetMlasPlatform().HGemmDispatch;
constexpr size_t StrideM = 2;
const auto beta_add = MLAS_FP16(1.0f);
constexpr size_t buffer_size = MLAS_HGEMM_STRIDEN * MLAS_HGEMM_STRIDEK;
MLAS_DECLSPEC_ALIGN(MLAS_FP16 PackedB[buffer_size], 16 * sizeof(_mlas_fp16_));

if (TransA == CblasNoTrans && TransB == CblasTrans) {
if (!dispatch || !dispatch->HGemmKernel_TransposeB) {
MLAS_THROW_EX(std::runtime_error, "hgemm does not have A x Transpoe(B) kernel");
}

const auto* A = DataParams->A + RangeStartM * lda;
const auto* B = DataParams->B + RangeStartN * ldb;
auto* C = DataParams->C + RangeStartM * ldc + RangeStartN;

for (size_t n = 0, countN; n < RangeCountN; n += countN) {
countN = std::min(StrideN, RangeCountN - n);
const MLAS_FP16* a = A;
MLAS_FP16* c = C;
for (size_t m = 0, countM; m < RangeCountM; m += countM) {
countM = std::min(StrideM, RangeCountM - m);
dispatch->HGemmKernel_TransposeB(a, B, c, countM, countN, K, lda, ldb, ldc, alpha, beta);
a += countM * lda;
c += countM * ldc;
if (RangeCountM <= StrideM) {
if (!dispatch || !dispatch->HGemmKernel_TransposeB) {
MLAS_THROW_EX(std::runtime_error, "hgemm does not have A x Transpoe(B) kernels");
}
// Without PackB, to utilize memory locality, iterate full K.
const size_t StrideN = 16;
for (size_t n = 0, countN; n < RangeCountN; n += countN) {
countN = std::min(StrideN, RangeCountN - n);
dispatch->HGemmKernel_TransposeB(A, B, C, RangeCountM, countN, K, lda, ldb, ldc, alpha, beta);
B += countN * ldb;
C += countN;
}
} else {
if (!dispatch || !dispatch->HTransposePackB || !dispatch->HGemmKernel_TransposePackB) {
MLAS_THROW_EX(std::runtime_error, "hgemm does not have A x Transpoe(B) kernels");
}
// 16N is the smallest pack unit.
const size_t StrideK = std::min(K, size_t(MLAS_HGEMM_STRIDEK));
const size_t StrideN = buffer_size/StrideK & (~15); // >= MLAS_HGEMM_STRIDEN
for (size_t n = 0, countN; n < RangeCountN; n += countN) {
countN = std::min(StrideN, RangeCountN - n);
const MLAS_FP16* a = A;
const MLAS_FP16* b = B;
MLAS_FP16* c = C;
for (size_t k = 0, countK; k < K; k += countK) {
countK = std::min(StrideK, K - k);
dispatch->HTransposePackB(b, PackedB, countN, countK, ldb);
const MLAS_FP16* aa = a;
MLAS_FP16* cc = c;
for (size_t m = 0, countM; m < RangeCountM; m += countM) {
countM = std::min(StrideM, RangeCountM - m);
// First K iteration, beta is applied to the whole C. In rest K iterations, use add mode.
dispatch->HGemmKernel_TransposePackB(
aa, PackedB, cc, countM, countN, countK, lda, ldc, alpha, k == 0 ? beta : beta_add);
aa += countM * lda;
cc += countM * ldc;
}
a += countK;
b += countK;
}
B += countN * ldb;
C += countN;
}
B += countN * ldb;
C += countN;
}
} else {
MLAS_THROW_EX(std::runtime_error, "hgemm currently only support A x Transpoe(B)");
74 changes: 72 additions & 2 deletions onnxruntime/core/mlas/lib/halfgemm.h
Original file line number Diff line number Diff line change
@@ -516,6 +516,14 @@ MlasHalfGemmGetDispatch()

namespace hgemm_neon {

// NEON routine installed as MLAS_HGEMM_DISPATCH::HTransposePackB by
// MlasHGemmDispatchNeon; it transposes and packs a chunk of B into PackedB.
// The packing layout is described on MLAS_HGEMM_DISPATCH::HTransposePackBKernel_Fn.
//   B        first element of the B matrix segment.
//   PackedB  [out] first element of the packed B buffer.
//   CountN   number of columns of the B chunk.
//   CountK   number of rows of the B chunk.
//   ldb      leading dimension of B (HGemmOperation advances B by CountN * ldb per N block).
void HTransposePackB_Kernel(
const MLAS_FP16* B,
MLAS_FP16* PackedB,
size_t CountN,
size_t CountK,
size_t ldb
);

void HGemm_TransposeB_Kernel(
const MLAS_FP16* A,
const MLAS_FP16* B,
@@ -530,11 +538,43 @@ void HGemm_TransposeB_Kernel(
MLAS_FP16 beta
);

// NEON routine installed as MLAS_HGEMM_DISPATCH::HGemmKernel_TransposePackB by
// MlasHGemmDispatchNeon. Per that slot's contract it computes
// C = alpha * A * Transpose(B) + beta * C, reading B from a buffer previously
// filled by HTransposePackB_Kernel.
//   A        first row of the A matrix segment.
//   PackedB  first element of the packed B buffer.
//   C        [out] first element of the output matrix segment.
//   CountM   number of rows of the A chunk (HGemmOperation passes at most 2).
//   CountN   number of columns of the B chunk.
//   CountK   number of columns of the A chunk / rows of the B chunk.
//   lda/ldc  leading dimensions of A and C.
//   alpha    scalar applied to the product.
//   beta     scalar applied to the existing contents of C.
void HGemm_TransposePackB_Kernel(
const MLAS_FP16* A,
const MLAS_FP16* PackedB,
MLAS_FP16* C,
size_t CountM,
size_t CountN,
size_t CountK,
size_t lda,
size_t ldc,
MLAS_FP16 alpha,
MLAS_FP16 beta
);

} // namespace hgemm_neon

struct MLAS_HGEMM_DISPATCH {
/**
* @brief C = alpha * A * Transpose(B) + beta * C
/**
 * @brief Transpose and pack the B matrix segment. Elements from the same row are packed contiguously.
 *        First pack CountK rows x 16 columns, then pack CountK rows x 8 columns.
 *        If there are < 8 columns left, pad the columns with 0.
 * @param B the first element of the B matrix segment. Column major.
 * @param[out] PackedB the first element of the packed B matrix segment.
 * @param CountN the number of columns of B chunk.
 * @param CountK the number of rows of B chunk.
 * @param ldb the leading dimension of B.
 */
typedef void(HTransposePackBKernel_Fn) (
const MLAS_FP16* B,
MLAS_FP16* PackedB,
size_t CountN,
size_t CountK,
size_t ldb
);

// Left as nullptr when the platform provides no packing routine.
HTransposePackBKernel_Fn* HTransposePackB = nullptr;

/**
* @brief C = alpha * A * Transpose(B) + beta * C. CountM <= 2. B is not packed. Used when M is small.
*
* @param A first row of the A matrix segment. Row major.
* @param B first column of the B matrix segment. Column major.
@@ -563,4 +603,34 @@ struct MLAS_HGEMM_DISPATCH {
);

HGemmKernel_TransposeB_Fn* HGemmKernel_TransposeB = nullptr;

/**
 * @brief C = alpha * A * Transpose(B) + beta * C. CountM <= 2. B is packed using HTransposePackBKernel_Fn.
 *        Used when M is large: HGemmOperation walks the M range in steps of at most 2 rows
 *        and the K dimension in chunks, calling this kernel once per (K chunk, M step).
 *
 * @param A first row of the A matrix segment. Row major.
 * @param PackedB first element of the packed B buffer.
 * @param[out] C first element of the output matrix segment. Row major.
 * @param CountM the number of rows of A chunk.
 * @param CountN the number of columns of B chunk.
 * @param CountK the number of columns of A chunk and the number of rows of B chunk.
 * @param lda the leading dimension of A.
 * @param ldc the leading dimension of C.
 * @param alpha the alpha scalar value.
 * @param beta the beta scalar value. HGemmOperation passes the caller's beta for the first
 *             K chunk and 1.0 for the remaining chunks so partial products accumulate into C.
 */
typedef void(HGemmKernel_TransposePackB_Fn)(
const MLAS_FP16* A,
const MLAS_FP16* PackedB,
MLAS_FP16* C,
size_t CountM,
size_t CountN,
size_t CountK,
size_t lda,
size_t ldc,
MLAS_FP16 alpha,
MLAS_FP16 beta
);

// Left as nullptr when the platform provides no packed-B kernel.
HGemmKernel_TransposePackB_Fn* HGemmKernel_TransposePackB = nullptr;
};
2 changes: 2 additions & 0 deletions onnxruntime/core/mlas/lib/halfgemm_kernel_neon.cpp
Original file line number Diff line number Diff line change
@@ -189,7 +189,9 @@ const MLAS_HALFGEMM_DISPATCH MlasHalfGemmDispatchNeon = {
// fp16 GEMM dispatch table for NEON, built once by an immediately-invoked lambda.
// Entries keep their default value unless the build targets ARM64 with fp16
// vector intrinsics, in which case all three hgemm_neon kernels are installed.
const MLAS_HGEMM_DISPATCH MlasHGemmDispatchNeon = []() {
    MLAS_HGEMM_DISPATCH table;
#if defined(MLAS_TARGET_ARM64) && defined(MLAS_F16VEC_INTRINSICS_SUPPORTED)
    // Unpacked-B kernel.
    table.HGemmKernel_TransposeB = hgemm_neon::HGemm_TransposeB_Kernel;
    // Packing routine and the kernel that consumes its output.
    table.HTransposePackB = hgemm_neon::HTransposePackB_Kernel;
    table.HGemmKernel_TransposePackB = hgemm_neon::HGemm_TransposePackB_Kernel;
#endif
    return table;
}();
Loading

0 comments on commit 9557efc

Please sign in to comment.