Skip to content

Commit

Permalink
resolve comments
Browse files: browse the repository at this point in the history
  • Loading branch information
fajin-corp committed Sep 24, 2024
1 parent 66323eb commit a1b86ac
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 10 deletions.
6 changes: 3 additions & 3 deletions cmake/onnxruntime_mlas.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ function(setup_mlas_source_for_windows)
${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon.cpp
${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon_fp32.cpp
${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon_int8.cpp
${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon_fp16.cpp
${MLAS_SRC_DIR}/fp16_neon_common.cpp
)

set(mlas_platform_preprocess_srcs
Expand Down Expand Up @@ -383,7 +383,7 @@ else()
${MLAS_SRC_DIR}/qgemm_kernel_smmla.cpp
${MLAS_SRC_DIR}/qgemm_kernel_ummla.cpp
${MLAS_SRC_DIR}/sbgemm_kernel_neon.cpp
${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon_fp16.cpp
${MLAS_SRC_DIR}/fp16_neon_common.cpp
)
set_source_files_properties(${MLAS_SRC_DIR}/aarch64/HalfGemmKernelNeon.S PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ")
set_source_files_properties(${MLAS_SRC_DIR}/aarch64/QgemmS8S8KernelSmmla.S PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+i8mm ")
Expand All @@ -393,7 +393,7 @@ else()
set_source_files_properties(${MLAS_SRC_DIR}/dwconv.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ")
set_source_files_properties(${MLAS_SRC_DIR}/pooling_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ")
set_source_files_properties(${MLAS_SRC_DIR}/sbgemm_kernel_neon.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+bf16 ")
set_source_files_properties(${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ")
set_source_files_properties(${MLAS_SRC_DIR}/fp16_neon_common.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ")
endif()

if(ONNXRUNTIME_MLAS_MULTI_ARCH)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,11 @@ Licensed under the MIT License.
Module Name:
sqnbitgemm_kernel_neon_fp16.cpp
fp16_neon_common.cpp
Abstract:
This module implements the float/quantized n-bit integer matrix
multiplication kernels for ARM NEON specific to
input type T1 as float16.
This module implements the common kernels for ARM NEON specific to float16.
--*/

Expand Down Expand Up @@ -64,11 +62,23 @@ MlasCastF16ToF32KernelNeon(const unsigned short* src, float* dest, size_t count)
}

// aligned src
for (; i + 3 < count; i += 4)
for (; i + 7 < count; i += 8)
{
float16x4_t fp16v4_0 = vreinterpret_f16_u16(vld1_u16(src + i));
float32x4_t fp32v4_0 = vcvt_f32_f16(fp16v4_0);
vst1q_f32(dest + i, fp32v4_0);

float16x4_t fp16v4_1 = vreinterpret_f16_u16(vld1_u16(src + i + 4));
float32x4_t fp32v4_1 = vcvt_f32_f16(fp16v4_1);
vst1q_f32(dest + i + 4, fp32v4_1);
}

if (i + 3 < count)
{
float16x4_t fp16v4_0 = vreinterpret_f16_u16(vld1_u16(src + i));
float32x4_t fp32v4_0 = vcvt_f32_f16(fp16v4_0);
vst1q_f32(dest + i, fp32v4_0);
i += 4;
}

// Handle trailing unaligned src
Expand Down Expand Up @@ -124,11 +134,23 @@ MlasCastF32ToF16KernelNeon(const float* src, unsigned short* dest, size_t count)
}

// aligned src
for (; i + 3 < count; i += 4)
for (; i + 7 < count; i += 8)
{
float32x4_t fp32v4_0 = vld1q_f32(src + i);
float16x4_t fp16v4_0 = vcvt_f16_f32(fp32v4_0);
vst1_u16(dest + i, vreinterpret_u16_f16(fp16v4_0));

float32x4_t fp32v4_1 = vld1q_f32(src + i + 4);
float16x4_t fp16v4_1 = vcvt_f16_f32(fp32v4_1);
vst1_u16(dest + i + 4, vreinterpret_u16_f16(fp16v4_1));
}

if (i + 3 < count)
{
float32x4_t fp32v4_0 = vld1q_f32(src + i);
float16x4_t fp16v4_0 = vcvt_f16_f32(fp32v4_0);
vst1_u16(dest + i, vreinterpret_u16_f16(fp16v4_0));
i += 4;
}

// Handle trailing unaligned src
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ BENCHMARK(BM_ConvertF16ToF32)
->UseRealTime()
->Apply([](benchmark::internal::Benchmark* b) {
b->ArgNames({"aligned"});
b->Argsproduct({{0, 1}});
b->ArgsProduct({{0, 1}});
});

BENCHMARK(BM_ConvertF32ToF16)
Expand Down

0 comments on commit a1b86ac

Please sign in to comment.