diff --git a/cmake/onnxruntime_mlas.cmake b/cmake/onnxruntime_mlas.cmake index 20c26458a27dd..0ba4694c329e3 100644 --- a/cmake/onnxruntime_mlas.cmake +++ b/cmake/onnxruntime_mlas.cmake @@ -88,7 +88,7 @@ function(setup_mlas_source_for_windows) ${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon.cpp ${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon_fp32.cpp ${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon_int8.cpp - ${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon_fp16.cpp + ${MLAS_SRC_DIR}/fp16_neon_common.cpp ) set(mlas_platform_preprocess_srcs @@ -383,7 +383,7 @@ else() ${MLAS_SRC_DIR}/qgemm_kernel_smmla.cpp ${MLAS_SRC_DIR}/qgemm_kernel_ummla.cpp ${MLAS_SRC_DIR}/sbgemm_kernel_neon.cpp - ${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon_fp16.cpp + ${MLAS_SRC_DIR}/fp16_neon_common.cpp ) set_source_files_properties(${MLAS_SRC_DIR}/aarch64/HalfGemmKernelNeon.S PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ") set_source_files_properties(${MLAS_SRC_DIR}/aarch64/QgemmS8S8KernelSmmla.S PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+i8mm ") @@ -393,7 +393,7 @@ else() set_source_files_properties(${MLAS_SRC_DIR}/dwconv.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ") set_source_files_properties(${MLAS_SRC_DIR}/pooling_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ") set_source_files_properties(${MLAS_SRC_DIR}/sbgemm_kernel_neon.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+bf16 ") - set_source_files_properties(${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ") + set_source_files_properties(${MLAS_SRC_DIR}/fp16_neon_common.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ") endif() if(ONNXRUNTIME_MLAS_MULTI_ARCH) diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon_fp16.cpp b/onnxruntime/core/mlas/lib/fp16_neon_common.cpp similarity index 79% rename from onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon_fp16.cpp rename to onnxruntime/core/mlas/lib/fp16_neon_common.cpp index 89f08a4cf03c5..bd345bff187fe 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon_fp16.cpp +++ b/onnxruntime/core/mlas/lib/fp16_neon_common.cpp @@ -6,13 +6,11 @@ Licensed under the MIT License. Module Name: - sqnbitgemm_kernel_neon_fp16.cpp + fp16_neon_common.cpp Abstract: - This module implements the float/quantized n-bit integer matrix - multiplication kernels for ARM NEON specific to - input type T1 as float16. + This module implements the common kernels for ARM NEON specific to float16. --*/ @@ -64,11 +62,23 @@ MlasCastF16ToF32KernelNeon(const unsigned short* src, float* dest, size_t count) } // aligned src - for (; i + 3 < count; i += 4) + for (; i + 7 < count; i += 8) { float16x4_t fp16v4_0 = vreinterpret_f16_u16(vld1_u16(src + i)); float32x4_t fp32v4_0 = vcvt_f32_f16(fp16v4_0); vst1q_f32(dest + i, fp32v4_0); + + float16x4_t fp16v4_1 = vreinterpret_f16_u16(vld1_u16(src + i + 4)); + float32x4_t fp32v4_1 = vcvt_f32_f16(fp16v4_1); + vst1q_f32(dest + i + 4, fp32v4_1); + } + + if (i + 3 < count) + { + float16x4_t fp16v4_0 = vreinterpret_f16_u16(vld1_u16(src + i)); + float32x4_t fp32v4_0 = vcvt_f32_f16(fp16v4_0); + vst1q_f32(dest + i, fp32v4_0); + i += 4; } // Handle trailing unaligned src @@ -124,11 +134,23 @@ MlasCastF32ToF16KernelNeon(const float* src, unsigned short* dest, size_t count) } // aligned src - for (; i + 3 < count; i += 4) + for (; i + 7 < count; i += 8) + { + float32x4_t fp32v4_0 = vld1q_f32(src + i); + float16x4_t fp16v4_0 = vcvt_f16_f32(fp32v4_0); + vst1_u16(dest + i, vreinterpret_u16_f16(fp16v4_0)); + + float32x4_t fp32v4_1 = vld1q_f32(src + i + 4); + float16x4_t fp16v4_1 = vcvt_f16_f32(fp32v4_1); + vst1_u16(dest + i + 4, vreinterpret_u16_f16(fp16v4_1)); + } + + if (i + 3 < count) { float32x4_t fp32v4_0 = vld1q_f32(src + i); float16x4_t fp16v4_0 = vcvt_f16_f32(fp32v4_0); vst1_u16(dest + i, vreinterpret_u16_f16(fp16v4_0)); + i += 4; } // Handle trailing unaligned src diff --git a/onnxruntime/test/mlas/bench/bench_sqnbitgemm_neon_fp16.cpp b/onnxruntime/test/mlas/bench/bench_fp16_neon_common.cpp similarity index 98% rename from onnxruntime/test/mlas/bench/bench_sqnbitgemm_neon_fp16.cpp rename to onnxruntime/test/mlas/bench/bench_fp16_neon_common.cpp index 59dcc457c4417..1dccbe44aafaf 100644 --- a/onnxruntime/test/mlas/bench/bench_sqnbitgemm_neon_fp16.cpp +++ b/onnxruntime/test/mlas/bench/bench_fp16_neon_common.cpp @@ -41,7 +41,7 @@ BENCHMARK(BM_ConvertF16ToF32) ->UseRealTime() ->Apply([](benchmark::internal::Benchmark* b) { b->ArgNames({"aligned"}); - b->Argsproduct({{0, 1}}); + b->ArgsProduct({{0, 1}}); }); BENCHMARK(BM_ConvertF32ToF16)