Skip to content

Commit

Permalink
[ARM] MatMulNBits FP16 support - kernels only (microsoft#22806)
Browse files Browse the repository at this point in the history
### Description
A break down PR of microsoft#22651
Add fp16 kernels.



### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->
  • Loading branch information
fajin-corp authored and ankitm3k committed Dec 11, 2024
1 parent de64e53 commit 5bf2acb
Show file tree
Hide file tree
Showing 3 changed files with 918 additions and 0 deletions.
3 changes: 3 additions & 0 deletions cmake/onnxruntime_mlas.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ function(setup_mlas_source_for_windows)
${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon_fp32.cpp
${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon_int8.cpp
${MLAS_SRC_DIR}/fp16_neon_common.cpp
${MLAS_SRC_DIR}/hqnbitgemm_kernel_neon_fp16.cpp
)

set(mlas_platform_preprocess_srcs
Expand Down Expand Up @@ -384,6 +385,7 @@ else()
${MLAS_SRC_DIR}/qgemm_kernel_ummla.cpp
${MLAS_SRC_DIR}/sbgemm_kernel_neon.cpp
${MLAS_SRC_DIR}/fp16_neon_common.cpp
${MLAS_SRC_DIR}/hqnbitgemm_kernel_neon_fp16.cpp
)
set_source_files_properties(${MLAS_SRC_DIR}/aarch64/HalfGemmKernelNeon.S PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ")
set_source_files_properties(${MLAS_SRC_DIR}/aarch64/QgemmS8S8KernelSmmla.S PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+i8mm ")
Expand All @@ -394,6 +396,7 @@ else()
set_source_files_properties(${MLAS_SRC_DIR}/pooling_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ")
set_source_files_properties(${MLAS_SRC_DIR}/sbgemm_kernel_neon.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+bf16 ")
set_source_files_properties(${MLAS_SRC_DIR}/fp16_neon_common.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ")
set_source_files_properties(${MLAS_SRC_DIR}/hqnbitgemm_kernel_neon_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ")
endif()

if(ONNXRUNTIME_MLAS_MULTI_ARCH)
Expand Down
17 changes: 17 additions & 0 deletions onnxruntime/core/mlas/lib/fp16_common.h
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,15 @@ MLAS_FORCEINLINE
MLAS_FLOAT16X4
MlasLoadFloat16x4(const _mlas_fp16_* Buffer) { return vreinterpret_f16_u16(vld1_u16(Buffer)); }

template <int lane>
MLAS_FORCEINLINE
MLAS_FLOAT16X4
MlasLoadLaneFloat16x4(const _mlas_fp16_* Buffer, MLAS_FLOAT16X4 vec) {
return vreinterpret_f16_u16(
vld1_lane_u16(Buffer, vreinterpret_u16_f16(vec), lane)
);
}

MLAS_FORCEINLINE
MLAS_FLOAT16X4
MlasLoadPartialFloat16x4(const _mlas_fp16_* Buffer, size_t len)
Expand Down Expand Up @@ -95,6 +104,14 @@ MlasStoreFloat16x4(_mlas_fp16_* Buffer, MLAS_FLOAT16X4 Vector)
vst1_u16(Buffer, vreinterpret_u16_f16(Vector));
}

template <int lane>
MLAS_FORCEINLINE
void
MlasStoreLaneFloat16x4(_mlas_fp16_* Buffer, MLAS_FLOAT16X4 Vector)
{
vst1_lane_u16(Buffer, vreinterpret_u16_f16(Vector), lane);
}

MLAS_FORCEINLINE
void
MlasStorePartialFloat16x4(_mlas_fp16_* Buffer, MLAS_FLOAT16X4 Vector, size_t len)
Expand Down
Loading

0 comments on commit 5bf2acb

Please sign in to comment.