From c4f1dfcfaa1d68567606d2ad919330e2daecd742 Mon Sep 17 00:00:00 2001 From: Chen Fu <1316708+chenfucn@users.noreply.github.com> Date: Fri, 28 Jan 2022 09:26:52 -0800 Subject: [PATCH] Cfu s8s8 (#10413) Adding S8S8 kernels for symmetric quantized indirect conv and depthwise conv. Perf number with single thread: Nokia G10 (baseline / new) in ms Pixel 4 (baseline/new) in ms mobilenet_edgetpu 220 / 213 18.5 / 17.6 cartoongan 8537 / 8521 967 / 928 Co-authored-by: Chen Fu --- cmake/onnxruntime_mlas.cmake | 10 +- .../mlas/lib/aarch64/ConvSymS8KernelDot.S | 603 +++++++++++++++ .../mlas/lib/aarch64/ConvSymS8KernelNeon.S | 441 +++++++++++ .../mlas/lib/aarch64/ConvSymU8KernelDot.S | 2 +- .../mlas/lib/aarch64/ConvSymU8KernelNeon.S | 4 +- .../aarch64/DepthwiseQConvSymS8KernelNeon.S | 692 +++++++++++++++++ ...Neon.S => DepthwiseQConvSymU8KernelNeon.S} | 4 +- .../mlas/lib/arm64/ConvSymS8KernelDot.asm | 605 +++++++++++++++ .../mlas/lib/arm64/ConvSymS8KernelNeon.asm | 423 +++++++++++ .../mlas/lib/arm64/ConvSymU8KernelDot.asm | 4 +- .../mlas/lib/arm64/ConvSymU8KernelNeon.asm | 6 +- .../arm64/DepthwiseQConvSymS8KernelNeon.asm | 693 ++++++++++++++++++ ....asm => DepthwiseQConvSymU8KernelNeon.asm} | 6 +- onnxruntime/core/mlas/lib/convsym.cpp | 172 ++++- onnxruntime/core/mlas/lib/mlasi.h | 19 +- onnxruntime/core/mlas/lib/platform.cpp | 6 +- .../core/mlas/lib/qdwconv_kernelsize.cpp | 224 ++---- .../providers/cpu/nn/qlinearconv_op_test.cc | 9 - 18 files changed, 3701 insertions(+), 222 deletions(-) create mode 100644 onnxruntime/core/mlas/lib/aarch64/ConvSymS8KernelDot.S create mode 100644 onnxruntime/core/mlas/lib/aarch64/ConvSymS8KernelNeon.S create mode 100644 onnxruntime/core/mlas/lib/aarch64/DepthwiseQConvSymS8KernelNeon.S rename onnxruntime/core/mlas/lib/aarch64/{DepthwiseConvSymKernelNeon.S => DepthwiseQConvSymU8KernelNeon.S} (99%) create mode 100644 onnxruntime/core/mlas/lib/arm64/ConvSymS8KernelDot.asm create mode 100644 onnxruntime/core/mlas/lib/arm64/ConvSymS8KernelNeon.asm create mode 100644 onnxruntime/core/mlas/lib/arm64/DepthwiseQConvSymS8KernelNeon.asm rename onnxruntime/core/mlas/lib/arm64/{DepthwiseConvSymKernelNeon.asm => DepthwiseQConvSymU8KernelNeon.asm} (99%) diff --git a/cmake/onnxruntime_mlas.cmake b/cmake/onnxruntime_mlas.cmake index e552ff85b5020..cfb2b6c62f61a 100644 --- a/cmake/onnxruntime_mlas.cmake +++ b/cmake/onnxruntime_mlas.cmake @@ -47,9 +47,12 @@ function(setup_mlas_source_for_windows) ) set(mlas_platform_preprocess_srcs + ${MLAS_SRC_DIR}/arm64/ConvSymS8KernelDot.asm ${MLAS_SRC_DIR}/arm64/ConvSymU8KernelDot.asm + ${MLAS_SRC_DIR}/arm64/ConvSymS8KernelNeon.asm ${MLAS_SRC_DIR}/arm64/ConvSymU8KernelNeon.asm - ${MLAS_SRC_DIR}/arm64/DepthwiseConvsymKernelNeon.asm + ${MLAS_SRC_DIR}/arm64/DepthwiseQConvSymS8KernelNeon.asm + ${MLAS_SRC_DIR}/arm64/DepthwiseQConvSymU8KernelNeon.asm ${MLAS_SRC_DIR}/arm64/DepthwiseQConvKernelSize9Neon.asm ${MLAS_SRC_DIR}/arm64/QgemmU8X8KernelNeon.asm ${MLAS_SRC_DIR}/arm64/QgemmS8S8KernelNeon.asm @@ -272,9 +275,12 @@ else() if(ARM64 AND MLAS_SOURCE_IS_NOT_SET ) enable_language(ASM) set(mlas_platform_srcs + ${MLAS_SRC_DIR}/aarch64/ConvSymS8KernelDot.S ${MLAS_SRC_DIR}/aarch64/ConvSymU8KernelDot.S + ${MLAS_SRC_DIR}/aarch64/ConvSymS8KernelNeon.S ${MLAS_SRC_DIR}/aarch64/ConvSymU8KernelNeon.S - ${MLAS_SRC_DIR}/aarch64/DepthwiseConvSymKernelNeon.S + ${MLAS_SRC_DIR}/aarch64/DepthwiseQConvSymS8KernelNeon.S + ${MLAS_SRC_DIR}/aarch64/DepthwiseQConvSymU8KernelNeon.S ${MLAS_SRC_DIR}/aarch64/DepthwiseQConvKernelSize9Neon.S ${MLAS_SRC_DIR}/aarch64/QgemmU8X8KernelNeon.S ${MLAS_SRC_DIR}/aarch64/QgemmS8S8KernelNeon.S diff --git a/onnxruntime/core/mlas/lib/aarch64/ConvSymS8KernelDot.S b/onnxruntime/core/mlas/lib/aarch64/ConvSymS8KernelDot.S new file mode 100644 index 0000000000000..1a0bd7731dae6 --- /dev/null +++ b/onnxruntime/core/mlas/lib/aarch64/ConvSymS8KernelDot.S @@ -0,0 +1,603 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + ConvSymS8KernelDot.S + +Abstract: + + This module implements the kernels for the symmetric quantized integer + convolution operation. + +--*/ + +#include "asmmacro.h" +#include "AssembleDotProduct.h" + + .equ .LMLAS_CONV_SYM_FLAG_INPUT_DIRECT, 1 + .equ .LMLAS_CONV_SYM_FLAG_PER_CHANNEL_SCALE, 2 + +// +// Stack frame layout for the symmetric convolution kernel. +// d8-d15, x19-x30 need to be preserved if used +// + .equ .LConvSymFrame_SavedRegisters, (6 * 8) + .equ .LConvSymFrame_PostProcessParams, 0 + .LConvSymFrame_SavedRegisters + .equ .LConvSymFrame_KernelFlags, 8 + .LConvSymFrame_SavedRegisters + + .equ .LConvSymPostProcessParams_Bias, 0 + .equ .LConvSymPostProcessParams_Scale, 8 + .equ .LConvSymPostProcessParams_Min, 16 + .equ .LConvSymPostProcessParams_Max, 20 + .equ .LConvSymPostProcessParams_ZeroPoint, 24 + + .text + +/*++ + +Routine Description: + + This routine is the inner kernel to compute a convolution for the elements + of an output row for a set of filter rows. + +Arguments: + + Input (x0) - Points to the input buffer. + + If MLAS_CONV_SYM_FLAG_INPUT_DIRECT is set, then the input buffer points + directly at the input tensor. + + If MLAS_CONV_SYM_FLAG_INPUT_DIRECT is clear, then the input buffer is an + indirection buffer. Every pointer in the indirection buffer points at a + InputChannels length vector (either from the input tensor or a vector of + padding values). These are grouped in batches of length KernelSize. + These batches are then repeated OutputCount times. + + Filter (x1) - Points to the filter buffer. + + Output (x2) - Points the output buffer. + + KernelSize (x3/x9) - Size of the kernel (most commonly. 3x3=9, 5x5=25). + + If MLAS_CONV_SYM_FLAG_INPUT_DIRECT is set, then kernel size should be 1. + + InputChannels (x4/x7) - Number of input channels. + + OutputChannels (x5) - Number of output channels. + + ChannelCount (x6) - Number of output channels this iteration produces. + + OutputCount (x7) - Number of output elements this iteration produces. + + This implementation requires the count to be no larger than 4. + + PostProcessParams (x8) - Points to the post process parameter block. + + KernelFlags - (w10) Additional flags controlling the operation. + +Return Value: + + None. + +--*/ + FUNCTION_ENTRY MlasConvSymS8KernelDot + + stp d8,d9,[sp,#-.LConvSymFrame_SavedRegisters]! + ldr x8,[sp,#.LConvSymFrame_PostProcessParams] + ldr w10,[sp,#.LConvSymFrame_KernelFlags] + stp d10,d11,[sp,#16] + stp x19,x20,[sp,#32] + + cmp x7,2 // OutputCount < 2 ? + add x16,x2,x5 // x16 -> C1 + lsl x3,x3,#3 // KernelSize * sizeof(int8_t*) + csel x16,x2,x16,lo // if OutputCount < 2 x16/C1 -> C0 + mov x20,x4 + add x4,x4,3 // InputChannels align to 4 + add x17,x16,x5 // x17 -> C2 + ldr x11,[x8,#.LConvSymPostProcessParams_Bias] + csel x17,x16,x17,ls // if OutputCount <= 2 x17/C2 -> C1 + bic x4,x4,3 + cmp x7,4 // OutputCount < 4 ? + add x5,x17,x5 // x5 -> C3 + ldr x19,[x8,#.LConvSymPostProcessParams_Scale] + csel x5,x17,x5,lo // if OutputCount < 4 x5/C3 -> C2 + + // TODO!! tiptoe around loading biases if we need to support + // output channels none divisible by 16 +OutputChannelLoop: + ldp q16,q20,[x11],32 // Init accumulators with biases + mov v17.16b,v16.16b + mov v18.16b,v16.16b + ldp q24,q28,[x11],32 + mov v19.16b,v16.16b + mov v21.16b,v20.16b + mov v22.16b,v20.16b + mov v23.16b,v20.16b + mov v25.16b,v24.16b + mov v26.16b,v24.16b + mov v27.16b,v24.16b + mov v29.16b,v28.16b + mov v30.16b,v28.16b + mov v31.16b,v28.16b + mov x9,x3 // restore KernelSize * sizeof(int8_t*) + +KernelSizeLoop: + tst w10,#.LMLAS_CONV_SYM_FLAG_INPUT_DIRECT + beq InputIndirection + +InputDirect: + cmp x16,x2 + mov x12,x0 // x12 -> A0 + add x13,x0,x20 // x13 -> A1 = A0 + input channels + csel x13,x0,x13,eq + cmp x17,x16 + add x14,x0,x20,lsl#1 // x14 -> A2 + csel x14,x13,x14,eq + cmp x5,x17 + add x15,x13,x20,lsl#1 // x15 -> A3 + csel x15,x14,x15,eq + b FinishLoadAPtr + +InputIndirection: + ldr x12,[x0] // x12 -> A0 + cmp x16,x2 + b.eq SkipLoadA1 // C1==C0 -> A0=A1=A2=A3 + cmp x17,x16 + lsl x14,x3,#1 + ldr x13,[x0,x3] // x13 -> A1 + b.eq SkipLoadA2 // C2==C1 -> A1=A2=A3 + cmp x5,x17 + add x15,x3,x3,lsl#1 + ldr x14,[x0,x14] // x14 -> A2 + b.eq SkipLoadA3 // C3==C2 -> A2=A3 + ldr x15,[x0,x15] // x15 -> A3 + b FinishLoadAPtr +SkipLoadA1: + mov x13,x12 +SkipLoadA2: + mov x14,x13 +SkipLoadA3: + mov x15,x14 + +// Register Usage +// B (x1) -> 4x16 +// ---------------------------------------------------------------------------- +// |v4.b[0]..v4.b[12] v5.b[0]..v5.b[12] v6.b[0]..v6.b[12] v7.b[0]..v7.b[12]| +// | ... ... ... ... ... ... ... ... | +// |v4.b[3]..v4.b[15] v5.b[3]..v5.b[15] v6.b[3]..v6.b[15] v7.b[3]..v7.b[15]| +// A 4x4 ---------------------------------------------------------------------------- +// ------------------ ---------------------------------------------------------------------------- +// x12 |v0.b[0]..v0.b[3]| |v16.s[0]_v16.s[3] v20.s[0]_v20.s[3] v24.s[0]_v24.s[3] v28.s[0]_v28.s[3]| x2 +// x13 |v1.b[0]..v1.b[3]| |v17.s[0]_v17.s[3] v21.s[0]_v21.s[3] v25.s[0]_v25.s[3] v29.s[0]_v29.s[3]| x16 +// x14 |v2.b[0]..v2.b[3]| |v18.s[0]_v18.s[3] v22.s[0]_v23.s[3] v26.s[0]_v26.s[3] v30.s[0]_v31.s[3]| x17 +// x15 |v3.b[0]..v3.b[3]| |v19.s[0]_v19.s[3] v23.s[0]_v23.s[3] v27.s[0]_v27.s[3] v31.s[0]_v31.s[3]| x5 +// ------------------ ---------------------------------------------------------------------------- + +FinishLoadAPtr: + subs x7,x4,16 // Need 16 input channels for loop + add x0,x0,8 // indirect A advance to next pointer, prepare for kernel size loop + b.lo InChannels8 + + ldr d0,[x12],8 + ldr q4,[x1],16 + ldr d1,[x13],8 + subs x7,x7,16 + ldr d2,[x14],8 + ldr d3,[x15],8 + ldr q5,[x1],16 + ldr q6,[x1],16 + ldr q7,[x1],16 + b.lo InChLoopEpilogue // Need 32 input channels for main loop + +InputChannelLoop: + SdotByElement 16, 4, 0,0 + SdotByElement 17, 4, 1,0 + ldr d8,[x12],8 + SdotByElement 18, 4, 2,0 + SdotByElement 19, 4, 3,0 + ldr q4,[x1],16 + SdotByElement 20, 5, 0,0 + SdotByElement 21, 5, 1,0 + ldr d9,[x13],8 + SdotByElement 22, 5, 2,0 + SdotByElement 23, 5, 3,0 + ldr q5,[x1],16 + SdotByElement 24, 6, 0,0 + SdotByElement 25, 6, 1,0 + ldr d10,[x14],8 + SdotByElement 26, 6, 2,0 + SdotByElement 27, 6, 3,0 + ldr q6,[x1],16 + SdotByElement 28, 7, 0,0 + SdotByElement 29, 7, 1,0 + ldr d11,[x15],8 + SdotByElement 30, 7, 2,0 + SdotByElement 31, 7, 3,0 + ldr q7,[x1],16 + SdotByElement 16, 4, 0,1 + SdotByElement 17, 4, 1,1 + SdotByElement 18, 4, 2,1 + SdotByElement 19, 4, 3,1 + ldr q4,[x1],16 + SdotByElement 20, 5, 0,1 + SdotByElement 21, 5, 1,1 + SdotByElement 22, 5, 2,1 + SdotByElement 23, 5, 3,1 + ldr q5,[x1],16 + SdotByElement 24, 6, 0,1 + SdotByElement 25, 6, 1,1 + SdotByElement 26, 6, 2,1 + SdotByElement 27, 6, 3,1 + ldr q6,[x1],16 + SdotByElement 28, 7, 0,1 + SdotByElement 29, 7, 1,1 + SdotByElement 30, 7, 2,1 + SdotByElement 31, 7, 3,1 + ldr q7,[x1],16 + SdotByElement 16, 4, 8,0 + SdotByElement 17, 4, 9,0 + ldr d0,[x12],8 + SdotByElement 18, 4,10,0 + SdotByElement 19, 4,11,0 + ldr q4,[x1],16 + SdotByElement 20, 5, 8,0 + SdotByElement 21, 5, 9,0 + ldr d1,[x13],8 + SdotByElement 22, 5,10,0 + SdotByElement 23, 5,11,0 + ldr q5,[x1],16 + SdotByElement 24, 6, 8,0 + SdotByElement 25, 6, 9,0 + ldr d2,[x14],8 + SdotByElement 26, 6,10,0 + SdotByElement 27, 6,11,0 + ldr q6,[x1],16 + SdotByElement 28, 7, 8,0 + SdotByElement 29, 7, 9,0 + ldr d3,[x15],8 + SdotByElement 30, 7,10,0 + SdotByElement 31, 7,11,0 + ldr q7,[x1],16 + SdotByElement 16, 4, 8,1 + SdotByElement 17, 4, 9,1 + SdotByElement 18, 4,10,1 + SdotByElement 19, 4,11,1 + ldr q4,[x1],16 + SdotByElement 20, 5, 8,1 + SdotByElement 21, 5, 9,1 + SdotByElement 22, 5,10,1 + SdotByElement 23, 5,11,1 + ldr q5,[x1],16 + SdotByElement 24, 6, 8,1 + SdotByElement 25, 6, 9,1 + SdotByElement 26, 6,10,1 + SdotByElement 27, 6,11,1 + ldr q6,[x1],16 + SdotByElement 28, 7, 8,1 + SdotByElement 29, 7, 9,1 + subs x7,x7,16 // InputChannels -= 16 + SdotByElement 30, 7,10,1 + SdotByElement 31, 7,11,1 + ldr q7,[x1],16 + b.hs InputChannelLoop + +InChLoopEpilogue: + SdotByElement 16, 4, 0,0 + SdotByElement 17, 4, 1,0 + ldr d8,[x12],8 + SdotByElement 18, 4, 2,0 + SdotByElement 19, 4, 3,0 + ldr q4,[x1],16 + SdotByElement 20, 5, 0,0 + SdotByElement 21, 5, 1,0 + ldr d9,[x13],8 + SdotByElement 22, 5, 2,0 + SdotByElement 23, 5, 3,0 + ldr q5,[x1],16 + SdotByElement 24, 6, 0,0 + SdotByElement 25, 6, 1,0 + ldr d10,[x14],8 + SdotByElement 26, 6, 2,0 + SdotByElement 27, 6, 3,0 + ldr q6,[x1],16 + SdotByElement 28, 7, 0,0 + SdotByElement 29, 7, 1,0 + ldr d11,[x15],8 + SdotByElement 30, 7, 2,0 + SdotByElement 31, 7, 3,0 + ldr q7,[x1],16 + SdotByElement 16, 4, 0,1 + SdotByElement 17, 4, 1,1 + SdotByElement 18, 4, 2,1 + SdotByElement 19, 4, 3,1 + ldr q4,[x1],16 + SdotByElement 20, 5, 0,1 + SdotByElement 21, 5, 1,1 + SdotByElement 22, 5, 2,1 + SdotByElement 23, 5, 3,1 + ldr q5,[x1],16 + SdotByElement 24, 6, 0,1 + SdotByElement 25, 6, 1,1 + SdotByElement 26, 6, 2,1 + SdotByElement 27, 6, 3,1 + ldr q6,[x1],16 + SdotByElement 28, 7, 0,1 + SdotByElement 29, 7, 1,1 + SdotByElement 30, 7, 2,1 + SdotByElement 31, 7, 3,1 + ldr q7,[x1],16 + SdotByElement 16, 4, 8,0 + SdotByElement 17, 4, 9,0 + SdotByElement 18, 4,10,0 + SdotByElement 19, 4,11,0 + ldr q4,[x1],16 + SdotByElement 20, 5, 8,0 + SdotByElement 21, 5, 9,0 + SdotByElement 22, 5,10,0 + SdotByElement 23, 5,11,0 + ldr q5,[x1],16 + SdotByElement 24, 6, 8,0 + SdotByElement 25, 6, 9,0 + SdotByElement 26, 6,10,0 + SdotByElement 27, 6,11,0 + ldr q6,[x1],16 + SdotByElement 28, 7, 8,0 + SdotByElement 29, 7, 9,0 + SdotByElement 30, 7,10,0 + SdotByElement 31, 7,11,0 + ldr q7,[x1],16 + SdotByElement 16, 4, 8,1 + SdotByElement 17, 4, 9,1 + SdotByElement 18, 4,10,1 + SdotByElement 19, 4,11,1 + SdotByElement 20, 5, 8,1 + SdotByElement 21, 5, 9,1 + SdotByElement 22, 5,10,1 + SdotByElement 23, 5,11,1 + SdotByElement 24, 6, 8,1 + SdotByElement 25, 6, 9,1 + SdotByElement 26, 6,10,1 + SdotByElement 27, 6,11,1 + SdotByElement 28, 7, 8,1 + SdotByElement 29, 7, 9,1 + SdotByElement 30, 7,10,1 + SdotByElement 31, 7,11,1 + + tst x7,15 + b.ne InChannels8 // 4 ~ 12 InputChannels + + subs x9,x9,8 // KernelSize-=1 + b.hi KernelSizeLoop + +Requantize: + tst w10,#.LMLAS_CONV_SYM_FLAG_PER_CHANNEL_SCALE + ldr w13,[x8,#.LConvSymPostProcessParams_ZeroPoint] + beq BroadcastScaleValue + ldp q0,q1,[x19],32 // load scale vector + ldp q2,q3,[x19],32 + b AccumulatorsToFloat + +BroadcastScaleValue: + ld1r {v0.4s},[x19] // load scale Value + mov v1.16b, v0.16b + mov v2.16b, v0.16b + mov v3.16b, v0.16b + +AccumulatorsToFloat: + scvtf v16.4s,v16.4s // convert to float + scvtf v17.4s,v17.4s + scvtf v18.4s,v18.4s + scvtf v19.4s,v19.4s + scvtf v20.4s,v20.4s + scvtf v21.4s,v21.4s + scvtf v22.4s,v22.4s + scvtf v23.4s,v23.4s + scvtf v24.4s,v24.4s + scvtf v25.4s,v25.4s + scvtf v26.4s,v26.4s + scvtf v27.4s,v27.4s + scvtf v28.4s,v28.4s + scvtf v29.4s,v29.4s + scvtf v30.4s,v30.4s + scvtf v31.4s,v31.4s + fmul v16.4s,v16.4s,v0.4s // multiply by scale + fmul v17.4s,v17.4s,v0.4s + fmul v18.4s,v18.4s,v0.4s + fmul v19.4s,v19.4s,v0.4s + fmul v20.4s,v20.4s,v1.4s + fmul v21.4s,v21.4s,v1.4s + fmul v22.4s,v22.4s,v1.4s + fmul v23.4s,v23.4s,v1.4s + fmul v24.4s,v24.4s,v2.4s + fmul v25.4s,v25.4s,v2.4s + fmul v26.4s,v26.4s,v2.4s + fmul v27.4s,v27.4s,v2.4s + fmul v28.4s,v28.4s,v3.4s + fmul v29.4s,v29.4s,v3.4s + fmul v30.4s,v30.4s,v3.4s + fmul v31.4s,v31.4s,v3.4s + fcvtns v16.4s,v16.4s // convert to int + fcvtns v17.4s,v17.4s + fcvtns v18.4s,v18.4s + fcvtns v19.4s,v19.4s + fcvtns v20.4s,v20.4s + fcvtns v21.4s,v21.4s + fcvtns v22.4s,v22.4s + fcvtns v23.4s,v23.4s + fcvtns v24.4s,v24.4s + fcvtns v25.4s,v25.4s + fcvtns v26.4s,v26.4s + fcvtns v27.4s,v27.4s + fcvtns v28.4s,v28.4s + fcvtns v29.4s,v29.4s + fcvtns v30.4s,v30.4s + fcvtns v31.4s,v31.4s + + sqxtn v16.4h,v16.4s + sqxtn v17.4h,v17.4s + sqxtn v18.4h,v18.4s + sqxtn v19.4h,v19.4s + sqxtn v24.4h,v24.4s + sqxtn v25.4h,v25.4s + sqxtn v26.4h,v26.4s + sqxtn v27.4h,v27.4s + dup v4.8h,w13 // zero point + sqxtn2 v16.8h,v20.4s + sqxtn2 v17.8h,v21.4s + sqxtn2 v18.8h,v22.4s + sqxtn2 v19.8h,v23.4s + sqxtn2 v24.8h,v28.4s + sqxtn2 v25.8h,v29.4s + sqxtn2 v26.8h,v30.4s + sqxtn2 v27.8h,v31.4s + sqadd v16.8h,v16.8h,v4.8h + sqadd v17.8h,v17.8h,v4.8h + sqadd v18.8h,v18.8h,v4.8h + sqadd v19.8h,v19.8h,v4.8h + sqadd v24.8h,v24.8h,v4.8h + sqadd v25.8h,v25.8h,v4.8h + sqadd v26.8h,v26.8h,v4.8h + sqadd v27.8h,v27.8h,v4.8h + sqxtn v0.8b,v16.8h + sqxtn v1.8b,v17.8h + sqxtn v2.8b,v18.8h + sqxtn v3.8b,v19.8h + sqxtn2 v0.16b,v24.8h + sqxtn2 v1.16b,v25.8h + subs x6,x6,16 // processed 16 output channels + sqxtn2 v2.16b,v26.8h + sqxtn2 v3.16b,v27.8h + b.lo PartialStore + + st1 {v3.16b},[x5],16 // Store full 4 x 16 + st1 {v2.16b},[x17],16 + sub x0,x0,x3 // Restore pointer to A: a -= ks + st1 {v1.16b},[x16],16 + st1 {v0.16b},[x2],16 + b.hi OutputChannelLoop + +ExitKernel: + ldp x19,x20,[sp,#32] + ldp d10,d11,[sp,#16] + ldp d8,d9,[sp],#.LConvSymFrame_SavedRegisters + ret + +InChannels8: + tbz x7,3,InChannels4 + ldr d0,[x12],8 + ldr q4,[x1],16 + ldr d1,[x13],8 + ldr d2,[x14],8 + ldr d3,[x15],8 + ldr q5,[x1],16 + SdotByElement 16, 4, 0,0 + SdotByElement 17, 4, 1,0 + ldp q6, q7, [x1], 32 + SdotByElement 18, 4, 2,0 + SdotByElement 19, 4, 3,0 + SdotByElement 20, 5, 0,0 + SdotByElement 21, 5, 1,0 + SdotByElement 22, 5, 2,0 + SdotByElement 23, 5, 3,0 + SdotByElement 24, 6, 0,0 + SdotByElement 25, 6, 1,0 + ldp q4, q5, [x1], 32 + SdotByElement 26, 6, 2,0 + SdotByElement 27, 6, 3,0 + SdotByElement 28, 7, 0,0 + SdotByElement 29, 7, 1,0 + SdotByElement 30, 7, 2,0 + SdotByElement 31, 7, 3,0 + SdotByElement 16, 4, 0,1 + SdotByElement 17, 4, 1,1 + ldp q6, q7, [x1], 32 + SdotByElement 18, 4, 2,1 + SdotByElement 19, 4, 3,1 + SdotByElement 20, 5, 0,1 + SdotByElement 21, 5, 1,1 + SdotByElement 22, 5, 2,1 + SdotByElement 23, 5, 3,1 + SdotByElement 24, 6, 0,1 + SdotByElement 25, 6, 1,1 + SdotByElement 26, 6, 2,1 + SdotByElement 27, 6, 3,1 + SdotByElement 28, 7, 0,1 + SdotByElement 29, 7, 1,1 + SdotByElement 30, 7, 2,1 + SdotByElement 31, 7, 3,1 + tbz x7,2,SkipInCh4 + +InChannels4: + ldr s0,[x12],4 + ldr q4,[x1],16 + ldr s1,[x13],4 + ldr s2,[x14],4 + ldr s3,[x15],4 + ldr q5, [x1], 16 + SdotByElement 16, 4, 0,0 + SdotByElement 17, 4, 1,0 + ldp q6, q7, [x1], 32 + SdotByElement 18, 4, 2,0 + SdotByElement 19, 4, 3,0 + SdotByElement 20, 5, 0,0 + SdotByElement 21, 5, 1,0 + SdotByElement 22, 5, 2,0 + SdotByElement 23, 5, 3,0 + SdotByElement 24, 6, 0,0 + SdotByElement 25, 6, 1,0 + SdotByElement 26, 6, 2,0 + SdotByElement 27, 6, 3,0 + SdotByElement 28, 7, 0,0 + SdotByElement 29, 7, 1,0 + SdotByElement 30, 7, 2,0 + SdotByElement 31, 7, 3,0 + +SkipInCh4: + subs x9,x9,8 // ks -= 1 + b.hi KernelSizeLoop + b Requantize + +PartialStore: + tbz x6,3,LT8Store + str d3,[x5],8 // no less than 8 channels + str d2,[x17],8 + dup d3,v3.d[1] + dup d2,v2.d[1] + str d1,[x16],8 + str d0,[x2],8 + dup d1,v1.d[1] + dup d0,v0.d[1] +LT8Store: + tbz x6,2,LT4Store + str s3,[x5],4 + str s2,[x17],4 + dup s3,v3.s[1] + dup s2,v2.s[1] + str s1,[x16],4 + str s0,[x2],4 + dup s1,v1.s[1] + dup s0,v0.s[1] +LT4Store: + tbz x6,1, LT2Store + str h3,[x5],2 + str h2,[x17],2 + dup h3,v3.h[1] + dup h2,v2.h[1] + str h1,[x16],2 + str h0,[x2],2 + dup h1,v1.h[1] + dup h0,v0.h[1] +LT2Store: + tbz x6,0,ExitKernel + str b3,[x5] + str b2,[x17] + str b1,[x16] + str b0,[x2] + b ExitKernel + + .end diff --git a/onnxruntime/core/mlas/lib/aarch64/ConvSymS8KernelNeon.S b/onnxruntime/core/mlas/lib/aarch64/ConvSymS8KernelNeon.S new file mode 100644 index 0000000000000..1bbe1f166bf44 --- /dev/null +++ b/onnxruntime/core/mlas/lib/aarch64/ConvSymS8KernelNeon.S @@ -0,0 +1,441 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + ConvSymS8KernelNeon.S + +Abstract: + + This module implements the kernels for the symmetric quantized integer + convolution operation. + +--*/ + +#include "asmmacro.h" + + .equ .LMLAS_CONV_SYM_FLAG_INPUT_DIRECT, 1 + .equ .LMLAS_CONV_SYM_FLAG_PER_CHANNEL_SCALE, 2 + +// +// Stack frame layout for the symmetric convolution kernel. +// d8-d15, x19-x30 need to be preserved if used +// + .equ .LConvSymFrame_SavedNeonRegisters, (8 * 8) + .equ .LConvSymFrame_SavedRegisters, .LConvSymFrame_SavedNeonRegisters + .equ .LConvSymFrame_PostProcessParams, 0 + .LConvSymFrame_SavedRegisters + .equ .LConvSymFrame_KernelFlags, 8 + .LConvSymFrame_SavedRegisters + + .equ .LConvSymPostProcessParams_Bias, 0 + .equ .LConvSymPostProcessParams_Scale, 8 + .equ .LConvSymPostProcessParams_Min, 16 + .equ .LConvSymPostProcessParams_Max, 20 + .equ .LConvSymPostProcessParams_ZeroPoint, 24 + + .text + +/*++ + +Routine Description: + + This routine is the inner kernel to compute a convolution for the elements + of an output row for a set of filter rows. + +Arguments: + + Input (x0) - Supplies the address of the input buffer. + + If MLAS_CONV_SYM_FLAG_INPUT_DIRECT is set, then the input buffer points + directly at the input tensor. + + If MLAS_CONV_SYM_FLAG_INPUT_DIRECT is clear, then the input buffer is an + indirection buffer. Every pointer in the indirection buffer points at a + InputChannels length vector (either from the input tensor or a vector of + padding values). These are grouped in batches of length KernelSize. + These batches are then repeated OutputCount times. + + Filter (x1) - Supplies the address of the filter buffer. + + Output (x2) - Supplies the address of the output buffer. + + KernelSize (x3) - Supplies the size of the kernel. + + If MLAS_CONV_SYM_FLAG_INPUT_DIRECT is set, then kernel size should be 1. + + InputChannels (x4) - Supplies the number of input channels. + + This implementation requires the count to be a multiple of 8. + + OutputChannels (x5) - Supplies the number of output channels. + + ChannelCount (x6) - Supplies the number of channels this iteration produces. + + This implementation requires the count to be 8. + + OutputCount (x7) - Supplies the number of output elements this iteration produces. + + This implementation requires the count to be 1 or 2. + + PostProcessParams - Supplies the address of the post process parameter block. + + KernelFlags - Supplies additional flags controlling the operation. + +Return Value: + + None. + +--*/ + FUNCTION_ENTRY MlasConvSymS8KernelNeon + + stp d8,d9,[sp,#-64]! + ldr x8,[sp,#.LConvSymFrame_PostProcessParams] + ldrb w10,[sp,#.LConvSymFrame_KernelFlags] + stp d10,d11,[sp,#16] + stp d12,d13,[sp,#32] + stp d14,d15,[sp,#48] + mov x9,x3 // save kernel size + ldr x11,[x8,#.LConvSymPostProcessParams_Bias] + mov x16,x4 // save input channels + ldr x12,[x8,#.LConvSymPostProcessParams_Scale] + cmp x7,2 // if OutputCount < 2 + add x5,x2,x5 // c1 = c0 + ldc + add x4,x4,7 // kc = (kc + 7) & ~7 + csel x5,x2,x5,lo // if OutputCount < 2 c1 = c0 + bic x4,x4,7 + ldp s16,s18,[x11],8 // init accumulators with bias + ldp s20,s22,[x11],8 + ldp s24,s26,[x11],8 + ldp s28,s30,[x11],8 + mov v17.16b,v16.16b + mov v19.16b,v18.16b + mov v21.16b,v20.16b + mov v23.16b,v22.16b + mov v25.16b,v24.16b + mov v27.16b,v26.16b + mov v29.16b,v28.16b + mov v31.16b,v30.16b + +// Nested loops, inner loop: input channel; outter loop: kernel size +// Each inner iteration processes 8 input channels, 2 output pixels, 8 output channels. +// +// B 8x8 +// ------------------------------------------------------------------ +// |v4.b[0] v5.b[0] v4.b[0] v5.b[0] v4.b[0] v5.b[0] v4.b[0] v5.b[0] | +// | ... ... ... ... ... ... ... ... | +// |v4.b[7] v5.b[7] v4.b[7] v5.b[7] v4.b[7] v5.b[7] v4.b[7] v5.b[7] | +// A 2x8 ------------------------------------------------------------------ +// ------------------ ------------------------------------------------------------------ +// x13-> |v0.b[0]..v0.b[7]| |v16.4s v18.4s v20.4s v22.4s v24.4s v26.4s v28.4s v30.4s | +// x15-> |v1.b[0]..v1.b[7]| |v17.4s v19.4s v21.4s v23.4s v25.4s v27.4s v29.4s v31.4s | +// ------------------ ------------------------------------------------------------------ +// When Input Channels greater than 16, unroll: +// A registers v6 v7, +// B registers v8 v9 +// + +.LConvSym.KernelSizeLoop: + + # Load next 2 A pointers + tst w10,#.LMLAS_CONV_SYM_FLAG_INPUT_DIRECT + ldr d4,[x1] + ldr d5,[x1,8] + beq .LConvSym.InputIndirection + +.LConvSym.InputDirect: + mov x13,x0 // x13 -> A0 + add x15,x0,x16 // x15 -> A1 = A0 + input channels + b .LConvSym.BlockLoopPrologue + +.LConvSym.InputIndirection: + cmp x7,2 // test if OutputCount < 2 + ldr x13,[x0] // x13 -> A0 + blo .LConvSym.SkipLoadA1 + ldr x15,[x0,x3,lsl#3] // x15 -> A1 +.LConvSym.SkipLoadA1: + +.LConvSym.BlockLoopPrologue: + cmp x7,2 // test if OutputCount < 2 + add x0,x0,8 // indirect A advance to next pointer, prepare for kernel size loop + csel x15,x13,x15,lo // if OutputCount < 2 x15 -> A0 + subs x14,x4,16 // input channel - 16 + blo .LConvSym.8InputChannels // less than 16 deep, no unroll + + ldr d0,[x13],8 + ldr d1,[x15],8 + ldr d8,[x1,64] + ldr d9,[x1,72] + ldr d6,[x13],8 + subs x14,x14,16 // input channel - 16 + ldr d7,[x15],8 + blo .LConvSym.BlockLoopEpilogue // need 32 input channel for full unrolled loop + +.LConvSym.Blockloop: + smull v2.8h,v4.8b,v0.8b + smull v3.8h,v4.8b,v1.8b + ldr d4,[x1,16] + smull v10.8h,v5.8b,v0.8b + smull v11.8h,v5.8b,v1.8b + ldr d5,[x1,24] + smlal v2.8h,v8.8b,v6.8b + smlal v3.8h,v8.8b,v7.8b + ldr d8,[x1,80] + smlal v10.8h,v9.8b,v6.8b + smlal v11.8h,v9.8b,v7.8b + ldr d9,[x1,88] + smull v12.8h,v4.8b,v0.8b + sadalp v16.4s,v2.8h + smull v13.8h,v4.8b,v1.8b + ldr d4,[x1,32] + sadalp v17.4s,v3.8h + smull v14.8h,v5.8b,v0.8b + sadalp v18.4s,v10.8h + smull v15.8h,v5.8b,v1.8b + ldr d5,[x1,40] + sadalp v19.4s,v11.8h + smlal v12.8h,v8.8b,v6.8b + smlal v13.8h,v8.8b,v7.8b + ldr d8,[x1,96] + smlal v14.8h,v9.8b,v6.8b + smlal v15.8h,v9.8b,v7.8b + ldr d9,[x1,104] + smull v2.8h,v4.8b,v0.8b + sadalp v20.4s,v12.8h + smull v3.8h,v4.8b,v1.8b + ldr d4,[x1,48] + sadalp v21.4s,v13.8h + smull v10.8h,v5.8b,v0.8b + sadalp v22.4s,v14.8h + smull v11.8h,v5.8b,v1.8b + ldr d5,[x1,56] + sadalp v23.4s, v15.8h + smlal v2.8h,v8.8b,v6.8b + smlal v3.8h,v8.8b,v7.8b + ldr d8,[x1,112] + smlal v10.8h,v9.8b,v6.8b + smlal v11.8h,v9.8b,v7.8b + ldr d9,[x1,120] + smull v12.8h,v4.8b,v0.8b + add x1,x1,128 + sadalp v24.4s,v2.8h + smull v13.8h,v4.8b,v1.8b + ldr d4,[x1] // Read B + sadalp v25.4s,v3.8h + smull v14.8h,v5.8b,v0.8b + ldr d0,[x13],8 // Read A0 + sadalp v26.4s,v10.8h + smull v15.8h,v5.8b,v1.8b + ldr d1,[x15],8 // Read A1 + sadalp v27.4s,v11.8h + smlal v12.8h,v8.8b,v6.8b + ldr d5,[x1,8] // Read B + smlal v13.8h,v8.8b,v7.8b + ldr d8,[x1,64] // Read B + smlal v14.8h,v9.8b,v6.8b + ldr d6,[x13],8 // Read A0 + smlal v15.8h,v9.8b,v7.8b + ldr d7,[x15],8 // Read A1 + sadalp v28.4s,v12.8h + ldr d9,[x1,72] // Read B + sadalp v29.4s,v13.8h + subs x14,x14,16 + sadalp v30.4s,v14.8h + sadalp v31.4s,v15.8h + b.hs .LConvSym.Blockloop + +.LConvSym.BlockLoopEpilogue: // remaining 16 input channels + smull v2.8h,v4.8b,v0.8b + smull v3.8h,v4.8b,v1.8b + ldr d4,[x1,16] + smull v10.8h,v5.8b,v0.8b + smull v11.8h,v5.8b,v1.8b + ldr d5,[x1,24] + smlal v2.8h,v8.8b,v6.8b + smlal v3.8h,v8.8b,v7.8b + ldr d8,[x1,80] + smlal v10.8h,v9.8b,v6.8b + smlal v11.8h,v9.8b,v7.8b + ldr d9,[x1,88] + smull v12.8h,v4.8b,v0.8b + sadalp v16.4s,v2.8h + smull v13.8h,v4.8b,v1.8b + ldr d4,[x1,32] + sadalp v17.4s,v3.8h + smull v14.8h,v5.8b,v0.8b + sadalp v18.4s,v10.8h + smull v15.8h,v5.8b,v1.8b + sadalp v19.4s,v11.8h + ldr d5,[x1,40] + smlal v12.8h,v8.8b,v6.8b + smlal v13.8h,v8.8b,v7.8b + ldr d8,[x1,96] + smlal v14.8h,v9.8b,v6.8b + smlal v15.8h,v9.8b,v7.8b + ldr d9,[x1,104] + smull v2.8h,v4.8b,v0.8b + sadalp v20.4s,v12.8h + smull v3.8h,v4.8b,v1.8b + ldr d4,[x1,48] + sadalp v21.4s,v13.8h + smull v10.8h,v5.8b,v0.8b + sadalp v22.4s,v14.8h + smull v11.8h,v5.8b,v1.8b + sadalp v23.4s,v15.8h + ldr d5,[x1,56] + smlal v2.8h,v8.8b,v6.8b + smlal v3.8h,v8.8b,v7.8b + ldr d8,[x1,112] + smlal v10.8h,v9.8b,v6.8b + smlal v11.8h,v9.8b,v7.8b + ldr d9,[x1,120] + smull v12.8h,v4.8b,v0.8b + sadalp v24.4s,v2.8h + smull v13.8h,v4.8b,v1.8b + sadalp v25.4s,v3.8h + smull v14.8h,v5.8b,v0.8b + sadalp v26.4s,v10.8h + smull v15.8h,v5.8b,v1.8b + sadalp v27.4s,v11.8h + smlal v12.8h,v8.8b,v6.8b + smlal v13.8h,v8.8b,v7.8b + smlal v14.8h,v9.8b,v6.8b + smlal v15.8h,v9.8b,v7.8b + add x1,x1,128 + + sadalp v28.4s,v12.8h + sadalp v29.4s,v13.8h + sadalp v30.4s,v14.8h + sadalp v31.4s,v15.8h + tbnz x14,3,.LConvSym.8InputChannels + + subs x9,x9,1 + b.hi .LConvSym.KernelSizeLoop + +.LConvSym.Requantize: + ldr w11, [x8, #.LConvSymPostProcessParams_ZeroPoint] + tst w10,#.LMLAS_CONV_SYM_FLAG_PER_CHANNEL_SCALE + beq .LConvSym.BroadcastScaleValue + ld1 {v4.4s,v5.4s},[x12] // load scale vector + b .LConvSym.AccumulatorsToFloat + +.LConvSym.BroadcastScaleValue: + ld1r {v4.4s},[x12] // load scale Value + mov v5.16b, v4.16b + +.LConvSym.AccumulatorsToFloat: + addp v16.4s,v16.4s,v18.4s + addp v20.4s,v20.4s,v22.4s + addp v24.4s,v24.4s,v26.4s + addp v28.4s,v28.4s,v30.4s + addp v17.4s,v17.4s,v19.4s + addp v21.4s,v21.4s,v23.4s + addp v25.4s,v25.4s,v27.4s + addp v29.4s,v29.4s,v31.4s + addp v0.4s,v16.4s,v20.4s + addp v1.4s,v24.4s,v28.4s + addp v2.4s,v17.4s,v21.4s + addp v3.4s,v25.4s,v29.4s + scvtf v0.4s,v0.4s // convert to float + scvtf v1.4s,v1.4s + scvtf v2.4s,v2.4s + scvtf v3.4s,v3.4s + fmul v0.4s,v0.4s,v4.4s // multiply by scale + fmul v1.4s,v1.4s,v5.4s + fmul v2.4s,v2.4s,v4.4s + fmul v3.4s,v3.4s,v5.4s + fcvtns v0.4s,v0.4s // convert to int + fcvtns v1.4s,v1.4s + dup v9.8h,w11 + fcvtns v2.4s,v2.4s + fcvtns v3.4s,v3.4s + sqxtn v0.4h,v0.4s + sqxtn2 v0.8h,v1.4s + sqxtn v2.4h,v2.4s + sqxtn2 v2.8h,v3.4s + subs x6, x6, 8 + sqadd v0.8h,v0.8h,v9.8h + sqadd v2.8h,v2.8h,v9.8h + sqxtn v0.8b,v0.8h // shorten to int8 + sqxtn2 v0.16b,v2.8h + b.lo .LConvSym.PartialStore + + st1 {v0.d}[1],[x5] // full 2x8 store to c + st1 {v0.8b},[x2] + +.LConvSym.ExitKernel: + ldp d14,d15,[sp,#48] + ldp d12,d13,[sp,#32] + ldp d10,d11,[sp,#16] + ldp d8,d9,[sp],#64 + ret + +.LConvSym.8InputChannels: + ldr d0,[x13] + ldr d1,[x15] + ldr d4,[x1] + ldr d5,[x1,8] + ldr d6,[x1,16] + ldr d7,[x1,24] + smull v2.8h,v4.8b,v0.8b + smull v3.8h,v4.8b,v1.8b + ldr d4,[x1,32] + smull v10.8h,v5.8b,v0.8b + smull v11.8h,v5.8b,v1.8b + ldr d5,[x1,40] + smull v12.8h,v6.8b,v0.8b + sadalp v16.4s,v2.8h + smull v13.8h,v6.8b,v1.8b + ldr d6,[x1,48] + sadalp v17.4s,v3.8h + smull v14.8h,v7.8b,v0.8b + sadalp v18.4s,v10.8h + smull v15.8h,v7.8b,v1.8b + ldr d7,[x1,56] + sadalp v19.4s,v11.8h + smull v2.8h,v4.8b,v0.8b + sadalp v20.4s,v12.8h + smull v3.8h,v4.8b,v1.8b + sadalp v21.4s,v13.8h + smull v10.8h,v5.8b,v0.8b + sadalp v22.4s,v14.8h + smull v11.8h,v5.8b,v1.8b + sadalp v23.4s,v15.8h + smull v12.8h,v6.8b,v0.8b + sadalp v24.4s,v2.8h + smull v13.8h,v6.8b,v1.8b + sadalp v25.4s,v3.8h + smull v14.8h,v7.8b,v0.8b + sadalp v26.4s,v10.8h + smull v15.8h,v7.8b,v1.8b + sadalp v27.4s,v11.8h + add x1,x1,64 + sadalp v28.4s,v12.8h + sadalp v29.4s,v13.8h + sadalp v30.4s,v14.8h + sadalp v31.4s,v15.8h + + # ks loop + subs x9,x9,1 + b.hi .LConvSym.KernelSizeLoop + b .LConvSym.Requantize + +.LConvSym.PartialStore: + tbz x6,2,.LConvSym.Store2 + st1 {v0.s}[2],[x5],4 + str s0,[x2],4 + EXT v0.16b,v0.16b,v0.16b,4 + +.LConvSym.Store2: + tbz x6, 1, .LConvSym.Store1 + st1 {v0.h}[4], [x5], 2 + str h0, [x2], 2 + EXT v0.16b,v0.16b,v0.16b,2 +.LConvSym.Store1: + tbz x6,0,.LConvSym.ExitKernel + st1 {v0.b}[8],[x5] + str b0,[x2] + b .LConvSym.ExitKernel + + .end diff --git a/onnxruntime/core/mlas/lib/aarch64/ConvSymU8KernelDot.S b/onnxruntime/core/mlas/lib/aarch64/ConvSymU8KernelDot.S index a040427c95172..dc11ad93c6349 100644 --- a/onnxruntime/core/mlas/lib/aarch64/ConvSymU8KernelDot.S +++ b/onnxruntime/core/mlas/lib/aarch64/ConvSymU8KernelDot.S @@ -84,7 +84,7 @@ Return Value: None. --*/ - FUNCTION_ENTRY MlasConvSymKernelNeonDot + FUNCTION_ENTRY MlasConvSymU8KernelDot stp d8,d9,[sp,#-.LConvSymFrame_SavedRegisters]! ldr x8,[sp,#.LConvSymFrame_PostProcessParams] diff --git a/onnxruntime/core/mlas/lib/aarch64/ConvSymU8KernelNeon.S b/onnxruntime/core/mlas/lib/aarch64/ConvSymU8KernelNeon.S index 4cb68827f0a10..fd16b2cbae2cd 100644 --- a/onnxruntime/core/mlas/lib/aarch64/ConvSymU8KernelNeon.S +++ b/onnxruntime/core/mlas/lib/aarch64/ConvSymU8KernelNeon.S @@ -6,7 +6,7 @@ Licensed under the MIT License. Module Name: - ConvSymKernelNeon.s + ConvSymU8KernelNeon.S Abstract: @@ -88,7 +88,7 @@ Return Value: None. --*/ - FUNCTION_ENTRY MlasConvSymKernelNeon + FUNCTION_ENTRY MlasConvSymU8KernelNeon stp d8,d9,[sp,#-64]! ldr x8,[sp,#.LConvSymFrame_PostProcessParams] diff --git a/onnxruntime/core/mlas/lib/aarch64/DepthwiseQConvSymS8KernelNeon.S b/onnxruntime/core/mlas/lib/aarch64/DepthwiseQConvSymS8KernelNeon.S new file mode 100644 index 0000000000000..7a27b9c92a3e5 --- /dev/null +++ b/onnxruntime/core/mlas/lib/aarch64/DepthwiseQConvSymS8KernelNeon.S @@ -0,0 +1,692 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + DepthwiseQConvSymS8KernelNeon.S + +Abstract: + + This module implements the kernels for the depthwise convolution + operation with symmetrically quantized integer values + +--*/ + +#include "asmmacro.h" + +// +// Stack frame layout for the depthwise conv kernel. +// d8-d15, x19-x30 need to be preserved if used +// + + .equ .LConvSymDepthwiseKernelFrame_SavedRegisters, (4 * 8) + .equ .LConvSymDepthwiseKernelFrame_PostProcessParams, 0 + .LConvSymDepthwiseKernelFrame_SavedRegisters + .equ .LConvSymDepthwiseKernelFrame_KernelFlags, 8 + .LConvSymDepthwiseKernelFrame_SavedRegisters + + .equ .LConvSymDepthwisePostProcessParams_Bias, 0 + .equ .LConvSymDepthwisePostProcessParams_Scale, 8 + .equ .LConvSymDepthwisePostProcessParams_Min, 16 + .equ .LConvSymDepthwisePostProcessParams_Max, 20 + .equ .LConvSymDepthwisePostProcessParams_ZeroPoint, 24 + + .equ MLAS_CONV_SYM_FLAG_INPUT_DIRECT, 1 + .equ MLAS_CONV_SYM_FLAG_PER_CHANNEL_SCALE, 2 + + .text + +/*++ + +Routine Description: + + This routine is the inner kernel to compute a depthwise convolution for the + elements of an output row for a set of filter rows. + +Arguments: + + Input (x0) - Supplies the address of the indirection buffer. + + Filter (x1) - Supplies the address of the filter buffer. + + Output (x2) - Supplies the address of the output buffer. + + KernelSize (x3) - Supplies the size of the kernel. + + Channels (x4) - Supplies the number of input and output channels. + + ChannelOffset (x5) - Supplies the byte offset from the indirection buffer base + address for this iteration. + + ChannelCount (x6) - Supplies the number of channels this iteration produces. + + This implementation requires the count to be 16 or 8 + + OutputCount (x7)- Supplies the number of output elements this iteration produces. + + This implementation requires the count to be in the range 1 to 2. + + PostProcessParams - Supplies the address of the post process parameter block. + + KernelFlags - Supplies additional flags controlling the operation. + +Return Value: + + None. + +--*/ + + FUNCTION_ENTRY MlasConvSymDepthwiseS8KernelNeon + + stp d12,d13,[sp,#-.LConvSymDepthwiseKernelFrame_SavedRegisters]! + ldr x8,[sp,#.LConvSymDepthwiseKernelFrame_PostProcessParams] + stp d14,d15,[sp,#16] + cmp x7,2 + add x9,x0,x3,lsl#3 // x9 -> &A1 + add x14,x0,x3,lsl#4 // x14 -> &A2 + add x15,x9,x3,lsl#4 // x15 -> &A3 + ldr x16,[x8,#.LConvSymDepthwisePostProcessParams_Bias] + csel x9,x0,x9,lo // x9 -> &A0 if OutputCount < 2 + csel x14,x0,x14,ls // x14 -> &A0 if OutputCount <= 2 + ldr x11,[x9],#8 // x11 -> A1 iter 0 + cmp x7,4 + ldp q24,q25,[x16],#32 // init accumulators with bias + csel x15,x0,x15,lo // x15 -> &A0 if OutputCount < 4 + cmp x6,16 + ldr x10,[x0],#8 // x10 -> A0 iter 0 + b.lo .LProcess8Channels + +// +// Process an input block of length Channels for each element of the kernel. +// +// Filter: v0, +// v1 // unroll +// Input: +// x0 -> x10 -> v4 +// -> x12 -> v2 // unroll +// x9 -> x11 -> v6 +// -> x13 -> v3 // unroll +// x14 -> x10 -> v4 +// -> x12 -> v2 // unroll +// x15 -> x11 -> v6 +// -> x13 -> v3 // unroll +// + +.LProcess16Channels: + cmp x3,1 + ldp q26,q27,[x16] + b.eq .LProcC16P1 + + ldr x12,[x0],#8 // x12 -> A0 iter 1 + ldr x13,[x9],#8 // x13 -> A1 iter 1 + mov v28.16b,v24.16b + mov v29.16b,v25.16b + ld1 {v0.16b},[x1],x4 // filter iter 0 + ld1 {v1.16b},[x1],x4 // filter iter 1 + mov v16.16b,v24.16b + mov v17.16b,v25.16b + ldr q4,[x10,x5] // A0 iter 0 + mov v20.16b,v24.16b + ldr x10,[x14],#8 // x10 -> A2 iter 0 + mov v21.16b,v25.16b + ldr q6,[x11,x5] // A1 iter 0 + mov v30.16b,v26.16b + ldr x11,[x15],#8 // x11 -> A3 iter 0 + mov v31.16b,v27.16b + ldr q2,[x12,x5] // A0 iter 1 + subs x3,x3,2 // decrement input blocks remaining + mov v18.16b,v26.16b + ldr x12,[x14],#8 // x12 -> A2 iter 1 + mov v19.16b,v27.16b + ldr q3,[x13,x5] // A1 iter 1 + mov v22.16b,v26.16b + ldr x13,[x15],#8 // x13 -> A3 iter 1 + mov v23.16b,v27.16b + +.LBlockLoopC16: + + // + // Process 2 pixels, and load next two pixels + // + smull v12.8h,v0.8b,v4.8b + smull2 v13.8h,v0.16b,v4.16b + ldr q4,[x10,x5] // A2 iter 0 + b.eq .LEpilogueC16P2 + smull v14.8h,v0.8b,v6.8b + ldr x10,[x0],#8 // x10 -> A0 iter 2 + smull2 v15.8h,v0.16b,v6.16b + cmp x3,1 + ldr q6,[x11,x5] // A3 iter 0 + smlal v12.8h,v1.8b,v2.8b + ldr x11,[x9],#8 // x11 -> A1 iter 2 + smlal2 v13.8h,v1.16b,v2.16b + b.eq .LEpilogueC16P3 // 3 pixel remains + ldr q2,[x12,x5] // A2 iter 1 + smlal v14.8h,v1.8b,v3.8b + ldr x12,[x0],#8 // x12 -> A0 iter 3 + smlal2 v15.8h,v1.16b,v3.16b + ldr q3,[x13,x5] // A3 iter 1 + saddw v24.4s,v24.4s,v12.4h + saddw2 v25.4s,v25.4s,v12.8h + ldr x13,[x9],#8 // x13 -> A1 iter 3 + saddw v26.4s,v26.4s,v13.4h + saddw2 v27.4s,v27.4s,v13.8h + saddw v28.4s,v28.4s,v14.4h + saddw2 v29.4s,v29.4s,v14.8h + saddw v30.4s,v30.4s,v15.4h + saddw2 v31.4s,v31.4s,v15.8h + subs x3,x3,2 // decrement input blocks remaining + smull v12.8h,v0.8b,v4.8b + smull2 v13.8h,v0.16b,v4.16b + ldr q4,[x10,x5] // A0 iter 2 + smull v14.8h,v0.8b,v6.8b + ldr x10,[x14],#8 // x10 -> A2 iter 2 + smull2 v15.8h,v0.16b,v6.16b + ldr q6,[x11,x5] // A1 iter 2 + ld1 {v0.16b},[x1],x4 // filter iter 2 + smlal v12.8h,v1.8b,v2.8b + ldr x11,[x15],#8 // x11 -> A3 iter 2 + smlal2 v13.8h,v1.16b,v2.16b + ldr q2,[x12,x5] // A0 iter 3 + smlal v14.8h,v1.8b,v3.8b + ldr x12,[x14],#8 // x12 -> A2 iter 3 + smlal2 v15.8h,v1.16b,v3.16b + ldr q3,[x13,x5] // A1 iter 3 + saddw v16.4s,v16.4s,v12.4h + saddw2 v17.4s,v17.4s,v12.8h + ld1 {v1.16b},[x1],x4 // filter iter 3 + saddw v18.4s,v18.4s,v13.4h + saddw2 v19.4s,v19.4s,v13.8h + ldr x13,[x15],#8 // x13 -> A3 iter 3 + saddw v20.4s,v20.4s,v14.4h + saddw2 v21.4s,v21.4s,v14.8h + saddw v22.4s,v22.4s,v15.4h + saddw2 v23.4s,v23.4s,v15.8h + b .LBlockLoopC16 + +.LEpilogueC16P2: + // + // Loop epilogue (process last 2 pixels) mixed + // with loading of dequantization params + // + smull v14.8h,v0.8b,v6.8b + smull2 v15.8h,v0.16b,v6.16b + ldr q6,[x11,x5] // A3 iter 0 + smlal v12.8h,v1.8b,v2.8b + smlal2 v13.8h,v1.16b,v2.16b + ldr q2,[x12,x5] // A2 iter 1 + smlal v14.8h,v1.8b,v3.8b + smlal2 v15.8h,v1.16b,v3.16b + ldr q3,[x13,x5] // A3 iter 1 + saddw v24.4s,v24.4s,v12.4h + saddw2 v25.4s,v25.4s,v12.8h + saddw v26.4s,v26.4s,v13.4h + saddw2 v27.4s,v27.4s,v13.8h + saddw v28.4s,v28.4s,v14.4h + saddw2 v29.4s,v29.4s,v14.8h + saddw v30.4s,v30.4s,v15.4h + saddw2 v31.4s,v31.4s,v15.8h + ldr w9,[sp,#.LConvSymDepthwiseKernelFrame_KernelFlags] + ldr x12,[x8,#.LConvSymDepthwisePostProcessParams_Scale] + smull v12.8h,v0.8b,v4.8b + smull2 v13.8h,v0.16b,v4.16b + ldr w15,[x8,#.LConvSymDepthwisePostProcessParams_ZeroPoint] + smull v14.8h,v0.8b,v6.8b + smull2 v15.8h,v0.16b,v6.16b + smlal v12.8h,v1.8b,v2.8b + smlal2 v13.8h,v1.16b,v2.16b + smlal v14.8h,v1.8b,v3.8b + smlal2 v15.8h,v1.16b,v3.16b + tst w9,#MLAS_CONV_SYM_FLAG_PER_CHANNEL_SCALE + ld1r {v4.4s},[x12] // load scale val + b.eq .LSkipScaleVecLoad2 + ldp q4,q5,[x12],#32 // load scale vector if per channel + ldp q6,q3,[x12] +.LSkipScaleVecLoad2: + saddw v16.4s,v16.4s,v12.4h + saddw2 v17.4s,v17.4s,v12.8h + saddw v18.4s,v18.4s,v13.4h + saddw2 v19.4s,v19.4s,v13.8h + saddw v20.4s,v20.4s,v14.4h + saddw2 v21.4s,v21.4s,v14.8h + saddw v22.4s,v22.4s,v15.4h + saddw2 v23.4s,v23.4s,v15.8h + b .LDequantization + +.LProcC16P1: + // + // Channel 16 kernel size 1 + // TODO!! is this reachable at all? + // + ldr x12,[x14],#8 // x12 -> A2 + ldr x13,[x15],#8 // x13 -> A3 + mov v28.16b,v24.16b + mov v29.16b,v25.16b + ld1 {v0.16b},[x1] + mov v16.16b,v24.16b + mov v17.16b,v25.16b + ldr q4,[x10,x5] + mov v20.16b,v24.16b + mov v21.16b,v25.16b + ldr q6,[x11,x5] + mov v30.16b,v26.16b + mov v31.16b,v27.16b + ldr q2,[x12,x5] + subs x3,x3,2 // decrement input blocks remaining + mov v18.16b,v26.16b + mov v19.16b,v27.16b + ldr q3,[x13,x5] + mov v22.16b,v26.16b + mov v23.16b,v27.16b + b .LEpilogueC16P1 + +.LEpilogueC16P3: + // + // Loop epilogue (process last 2 pixels) mixed + // with loading of dequantization params + // + ldr q2,[x12,x5] // A2 iter 1 + smlal v14.8h,v1.8b,v3.8b + ldr x12,[x14],#8 // x12 -> A2 iter 2 + smlal2 v15.8h,v1.16b,v3.16b + ldr q3,[x13,x5] // A3 iter 1 + saddw v24.4s,v24.4s,v12.4h + saddw2 v25.4s,v25.4s,v12.8h + ldr x13,[x15],#8 // x13 -> A3 iter 2 + saddw v26.4s,v26.4s,v13.4h + saddw2 v27.4s,v27.4s,v13.8h + saddw v28.4s,v28.4s,v14.4h + saddw2 v29.4s,v29.4s,v14.8h + saddw v30.4s,v30.4s,v15.4h + saddw2 v31.4s,v31.4s,v15.8h + smull v12.8h,v0.8b,v4.8b + smull2 v13.8h,v0.16b,v4.16b + ldr q4,[x10,x5] // A0 iter 2 + smull v14.8h,v0.8b,v6.8b + smull2 v15.8h,v0.16b,v6.16b + ld1 {v0.16b},[x1] // filter iter 2 + ldr q6,[x11,x5] // A1 iter 2 + smlal v12.8h,v1.8b,v2.8b + smlal2 v13.8h,v1.16b,v2.16b + ldr q2,[x12,x5] // A2 iter 2 + smlal v14.8h,v1.8b,v3.8b + smlal2 v15.8h,v1.16b,v3.16b + ldr q3,[x13,x5] // A3 iter 2 + saddw v16.4s,v16.4s,v12.4h + saddw2 v17.4s,v17.4s,v12.8h + saddw v18.4s,v18.4s,v13.4h + saddw2 v19.4s,v19.4s,v13.8h + saddw v20.4s,v20.4s,v14.4h + saddw2 v21.4s,v21.4s,v14.8h + saddw v22.4s,v22.4s,v15.4h + saddw2 v23.4s,v23.4s,v15.8h + +.LEpilogueC16P1: + // + // Loop epilogue (process last single pixel) mixed with loading of dequantization params + // + ldr w9,[sp,#.LConvSymDepthwiseKernelFrame_KernelFlags] + ldr x12,[x8,#.LConvSymDepthwisePostProcessParams_Scale] + smull v12.8h,v0.8b,v4.8b + smull2 v13.8h,v0.16b,v4.16b + ldr w15,[x8,#.LConvSymDepthwisePostProcessParams_ZeroPoint] + smull v14.8h,v0.8b,v6.8b + smull2 v15.8h,v0.16b,v6.16b + saddw v24.4s,v24.4s,v12.4h + saddw2 v25.4s,v25.4s,v12.8h + saddw v26.4s,v26.4s,v13.4h + saddw2 v27.4s,v27.4s,v13.8h + saddw v28.4s,v28.4s,v14.4h + saddw2 v29.4s,v29.4s,v14.8h + saddw v30.4s,v30.4s,v15.4h + saddw2 v31.4s,v31.4s,v15.8h + smull v12.8h,v0.8b,v2.8b + smull2 v13.8h,v0.16b,v2.16b + smull v14.8h,v0.8b,v3.8b + smull2 v15.8h,v0.16b,v3.16b + tst w9,#MLAS_CONV_SYM_FLAG_PER_CHANNEL_SCALE + ld1r {v4.4s},[x12] // load scale val + b.eq .LSkipScaleVecLoad + ldp q4,q5,[x12],#32 // load scale vector if per channel + ldp q6,q3,[x12] +.LSkipScaleVecLoad: + saddw v16.4s,v16.4s,v12.4h + saddw2 v17.4s,v17.4s,v12.8h + saddw v18.4s,v18.4s,v13.4h + saddw2 v19.4s,v19.4s,v13.8h + saddw v20.4s,v20.4s,v14.4h + saddw2 v21.4s,v21.4s,v14.8h + saddw v22.4s,v22.4s,v15.4h + saddw2 v23.4s,v23.4s,v15.8h + +.LDequantization: + scvtf v24.4s,v24.4s // convert to float + scvtf v25.4s,v25.4s + scvtf v26.4s,v26.4s + scvtf v27.4s,v27.4s + scvtf v28.4s,v28.4s + scvtf v29.4s,v29.4s + scvtf v30.4s,v30.4s + scvtf v31.4s,v31.4s + scvtf v16.4s,v16.4s + scvtf v17.4s,v17.4s + scvtf v18.4s,v18.4s + scvtf v19.4s,v19.4s + scvtf v20.4s,v20.4s + scvtf v21.4s,v21.4s + scvtf v22.4s,v22.4s + scvtf v23.4s,v23.4s + b.ne .LSkipScaleBroadcast + mov v5.16b,v4.16b // broadcast scale val if not per channel + mov v6.16b,v4.16b + mov v3.16b,v4.16b +.LSkipScaleBroadcast: + fmul v24.4s,v24.4s,v4.4s // multiply by scale + fmul v25.4s,v25.4s,v5.4s + fmul v26.4s,v26.4s,v6.4s + fmul v27.4s,v27.4s,v3.4s + fmul v28.4s,v28.4s,v4.4s + fmul v29.4s,v29.4s,v5.4s + fmul v30.4s,v30.4s,v6.4s + fmul v31.4s,v31.4s,v3.4s + fmul v16.4s,v16.4s,v4.4s + fmul v17.4s,v17.4s,v5.4s + fmul v18.4s,v18.4s,v6.4s + fmul v19.4s,v19.4s,v3.4s + fmul v20.4s,v20.4s,v4.4s + fmul v21.4s,v21.4s,v5.4s + fmul v22.4s,v22.4s,v6.4s + fmul v23.4s,v23.4s,v3.4s + fcvtns v24.4s,v24.4s // convert to int + fcvtns v25.4s,v25.4s + fcvtns v26.4s,v26.4s + fcvtns v27.4s,v27.4s + fcvtns v28.4s,v28.4s + fcvtns v29.4s,v29.4s + fcvtns v30.4s,v30.4s + fcvtns v31.4s,v31.4s + fcvtns v16.4s,v16.4s + fcvtns v17.4s,v17.4s + fcvtns v18.4s,v18.4s + fcvtns v19.4s,v19.4s + fcvtns v20.4s,v20.4s + fcvtns v21.4s,v21.4s + fcvtns v22.4s,v22.4s + fcvtns v23.4s,v23.4s + sqxtn v24.4h,v24.4s // shorten to int16 + sqxtn v26.4h,v26.4s + sqxtn2 v24.8h,v25.4s + sqxtn2 v26.8h,v27.4s + sqxtn v28.4h,v28.4s + sqxtn v30.4h,v30.4s + sqxtn2 v28.8h,v29.4s + sqxtn2 v30.8h,v31.4s + dup v0.8h,w15 + sqxtn v16.4h,v16.4s + sqxtn v18.4h,v18.4s + sqxtn2 v16.8h,v17.4s + sqxtn2 v18.8h,v19.4s + sqxtn v20.4h,v20.4s + sqxtn v22.4h,v22.4s + sqxtn2 v20.8h,v21.4s + sqxtn2 v22.8h,v23.4s + sqadd v24.8h,v24.8h,v0.8h // add zero point + sqadd v26.8h,v26.8h,v0.8h + sqadd v28.8h,v28.8h,v0.8h + sqadd v30.8h,v30.8h,v0.8h + sqadd v16.8h,v16.8h,v0.8h + sqadd v18.8h,v18.8h,v0.8h + sqadd v20.8h,v20.8h,v0.8h + sqadd v22.8h,v22.8h,v0.8h + sqxtn v24.8b,v24.8h // shorten to int8 + sqxtn2 v24.16b,v26.8h + sqxtn v28.8b,v28.8h + sqxtn2 v28.16b,v30.8h + sqxtn v16.8b,v16.8h + sqxtn2 v16.16b,v18.8h + sqxtn v20.8b,v20.8h + sqxtn2 v20.16b,v22.8h + cmp x7,2 // OutputCount < 2 ? + st1 {v24.16b},[x2],x4 + b.lo .LExitKernel // exit if OutputCount < 2 + st1 {v28.16b},[x2],x4 + b.ls .LExitKernel // exit if OutputCount <=2 + cmp x7,4 // OutputCount < 4 ? + st1 {v16.16b},[x2],x4 + b.lo .LExitKernel // exit if OutputCount < 4 + str q20,[x2] + +.LExitKernel: + ldp d14,d15,[sp,#16] + ldp d12,d13,[sp],#.LConvSymDepthwiseKernelFrame_SavedRegisters + ret + +.LProcess8Channels: + cmp x3,1 + b.eq .LProcC8P1 + + ldr x12,[x0],#8 // x12 -> A0 iter 1 + ldr x13,[x9],#8 // x13 -> A1 iter 1 + ld1 {v0.8b},[x1],x4 // filter iter 0 + ld1 {v1.8b},[x1],x4 // filter iter 1 + ldr d4,[x10,x5] // A0 iter 0 + ldr x10,[x14],#8 // x10 -> A2 iter 0 + mov v28.16b,v24.16b + ldr d6,[x11,x5] // A1 iter 0 + mov v29.16b,v25.16b + ldr x11,[x15],#8 // x11 -> A3 iter 0 + mov v16.16b,v24.16b + ldr d2,[x12,x5] // A0 iter 1 + mov v17.16b,v25.16b + ldr x12,[x14],#8 // x12 -> A2 iter 1 + subs x3,x3,2 // decrement input blocks remaining + ldr d3,[x13,x5] // A1 iter 1 + mov v20.16b,v24.16b + ldr x13,[x15],#8 // x13 -> A3 iter 1 + mov v21.16b,v25.16b + +.LBlockLoopC8: + // + // Process 2 pixels, and load next two pixels + // + smull v12.8h,v0.8b,v4.8b + ldr d4,[x10,x5] // A2 iter 0 + smull v14.8h,v0.8b,v6.8b + b.eq .LEpilogueC8P2 + ldr x10,[x0],#8 // x10 -> A0 iter 2 + ldr d6,[x11,x5] // A3 iter 0 + cmp x3,1 + smlal v12.8h,v1.8b,v2.8b + ldr x11,[x9],#8 // x11 -> A1 iter 2 + smlal v14.8h,v1.8b,v3.8b + ldr d2,[x12,x5] // A2 iter 1 + b.eq .LEpilogueC8P3 // 3 pixel remains + ldr d3,[x13,x5] // A3 iter 1 + saddw v24.4s,v24.4s,v12.4h + ldr x12,[x0],#8 // x12 -> A0 iter 3 + saddw2 v25.4s,v25.4s,v12.8h + ldr x13,[x9],#8 // x13 -> A1 iter 3 + saddw v28.4s,v28.4s,v14.4h + saddw2 v29.4s,v29.4s,v14.8h + subs x3,x3,2 // decrement input blocks remaining + smull v12.8h,v0.8b,v4.8b + ldr d4,[x10,x5] // A0 iter 2 + smull v14.8h,v0.8b,v6.8b + ldr x10,[x14],#8 // x10 -> A2 iter 2 + ldr d6,[x11,x5] // A1 iter 2 + ld1 {v0.8b},[x1],x4 // filter iter 2 + smlal v12.8h,v1.8b,v2.8b + ldr x11,[x15],#8 // x11 -> A3 iter 2 + ldr d2,[x12,x5] // A0 iter 3 + smlal v14.8h,v1.8b,v3.8b + ldr x12,[x14],#8 // x12 -> A2 iter 3 + saddw v16.4s,v16.4s,v12.4h + ldr d3,[x13,x5] // A1 iter 3 + saddw2 v17.4s,v17.4s,v12.8h + ld1 {v1.8b},[x1],x4 // filter iter 3 + saddw v20.4s,v20.4s,v14.4h + ldr x13,[x15],#8 // x13 -> A3 iter 3 + saddw2 v21.4s,v21.4s,v14.8h + b .LBlockLoopC8 + +.LEpilogueC8P2: + // + // Loop epilogue (process last 2 pixels) mixed + // with loading of dequantization params + // + ldr d6,[x11,x5] // A3 iter 0 + smlal v12.8h,v1.8b,v2.8b + ldr d2,[x12,x5] // A2 iter 1 + smlal v14.8h,v1.8b,v3.8b + ldr d3,[x13,x5] // A3 iter 1 + saddw v24.4s,v24.4s,v12.4h + saddw2 v25.4s,v25.4s,v12.8h + saddw v28.4s,v28.4s,v14.4h + saddw2 v29.4s,v29.4s,v14.8h + ldr w9,[sp,#.LConvSymDepthwiseKernelFrame_KernelFlags] + smull v12.8h,v0.8b,v4.8b + ldr x12,[x8,#.LConvSymDepthwisePostProcessParams_Scale] + smull v14.8h,v0.8b,v6.8b + ldr w15,[x8,#.LConvSymDepthwisePostProcessParams_ZeroPoint] + smlal v12.8h,v1.8b,v2.8b + smlal v14.8h,v1.8b,v3.8b + tst w9,#MLAS_CONV_SYM_FLAG_PER_CHANNEL_SCALE + ld1r {v4.4s},[x12] // load scale val + b.eq .LSkipScaleVecLoad2C8 + ldp q4,q5,[x12],#32 // load scale vector if per channel +.LSkipScaleVecLoad2C8: + saddw v16.4s,v16.4s,v12.4h + saddw2 v17.4s,v17.4s,v12.8h + saddw v20.4s,v20.4s,v14.4h + saddw2 v21.4s,v21.4s,v14.8h + b .LDequantC8 + +.LProcC8P1: + // + // Channel 8 kernel size 1 + // TODO!! is this reachable at all? + // + ldr x12,[x14],#8 // x12 -> A2 + mov v28.16b,v24.16b + ldr x13,[x15],#8 // x13 -> A3 + mov v29.16b,v25.16b + ld1 {v0.8b},[x1] + mov v16.16b,v24.16b + ldr d4,[x10,x5] + mov v17.16b,v25.16b + ldr d6,[x11,x5] + mov v20.16b,v24.16b + ldr d2,[x12,x5] + subs x3,x3,2 // decrement input blocks remaining + ldr d3,[x13,x5] + mov v21.16b,v25.16b + b .LEpilogueC8P1 + +.LEpilogueC8P3: + // + // Loop epilogue (process 2 of last 3 pixels) + // + ldr x12,[x14],#8 // x12 -> A2 iter 2 + ldr d3,[x13,x5] // A3 iter 1 + saddw v24.4s,v24.4s,v12.4h + saddw2 v25.4s,v25.4s,v12.8h + ldr x13,[x15],#8 // x13 -> A3 iter 2 + saddw v28.4s,v28.4s,v14.4h + saddw2 v29.4s,v29.4s,v14.8h + smull v12.8h,v0.8b,v4.8b + ldr d4,[x10,x5] // A0 iter 2 + smull v14.8h,v0.8b,v6.8b + ld1 {v0.8b},[x1] // filter iter 2 + ldr d6,[x11,x5] // A1 iter 2 + smlal v12.8h,v1.8b,v2.8b + ldr d2,[x12,x5] // A2 iter 2 + smlal v14.8h,v1.8b,v3.8b + ldr d3,[x13,x5] // A3 iter 2 + saddw v16.4s,v16.4s,v12.4h + saddw2 v17.4s,v17.4s,v12.8h + saddw v20.4s,v20.4s,v14.4h + saddw2 v21.4s,v21.4s,v14.8h + +.LEpilogueC8P1: + // + // Loop epilogue (process last single pixel) mixed with loading of dequantization params + // + ldr w9,[sp,#.LConvSymDepthwiseKernelFrame_KernelFlags] + ldr x12,[x8,#.LConvSymDepthwisePostProcessParams_Scale] + smull v12.8h,v0.8b,v4.8b + ldr w15,[x8,#.LConvSymDepthwisePostProcessParams_ZeroPoint] + smull v14.8h,v0.8b,v6.8b + saddw v24.4s,v24.4s,v12.4h + saddw2 v25.4s,v25.4s,v12.8h + saddw v28.4s,v28.4s,v14.4h + saddw2 v29.4s,v29.4s,v14.8h + smull v12.8h,v0.8b,v2.8b + smull v14.8h,v0.8b,v3.8b + tst w9,#MLAS_CONV_SYM_FLAG_PER_CHANNEL_SCALE + ld1r {v4.4s},[x12] // load scale val + b.eq .LSkipScaleVecLoadC8 + ldp q4,q5,[x12] // load scale vector if per channel +.LSkipScaleVecLoadC8: + saddw v16.4s,v16.4s,v12.4h + saddw2 v17.4s,v17.4s,v12.8h + saddw v20.4s,v20.4s,v14.4h + saddw2 v21.4s,v21.4s,v14.8h + +.LDequantC8: + scvtf v24.4s,v24.4s // convert to float + scvtf v25.4s,v25.4s + scvtf v28.4s,v28.4s + scvtf v29.4s,v29.4s + scvtf v16.4s,v16.4s + scvtf v17.4s,v17.4s + scvtf v20.4s,v20.4s + scvtf v21.4s,v21.4s + b.ne .LSkipScaleBroadcastC8 + mov v5.16b,v4.16b // broadcast scale val if not per channel +.LSkipScaleBroadcastC8: + fmul v24.4s,v24.4s,v4.4s // multiply by scale + fmul v25.4s,v25.4s,v5.4s + fmul v28.4s,v28.4s,v4.4s + fmul v29.4s,v29.4s,v5.4s + fmul v16.4s,v16.4s,v4.4s + fmul v17.4s,v17.4s,v5.4s + fmul v20.4s,v20.4s,v4.4s + fmul v21.4s,v21.4s,v5.4s + fcvtns v24.4s,v24.4s // convert to int + fcvtns v25.4s,v25.4s + fcvtns v28.4s,v28.4s + fcvtns v29.4s,v29.4s + fcvtns v16.4s,v16.4s + fcvtns v17.4s,v17.4s + fcvtns v20.4s,v20.4s + fcvtns v21.4s,v21.4s + dup v0.8h,w15 + sqxtn v24.4h,v24.4s // shorten to int16 + sqxtn2 v24.8h,v25.4s + sqxtn v28.4h,v28.4s + sqxtn2 v28.8h,v29.4s + sqxtn v16.4h,v16.4s + sqxtn2 v16.8h,v17.4s + sqxtn v20.4h,v20.4s + sqxtn2 v20.8h,v21.4s + sqadd v24.8h,v24.8h,v0.8h // add zero point + sqadd v28.8h,v28.8h,v0.8h + sqadd v16.8h,v16.8h,v0.8h + sqadd v20.8h,v20.8h,v0.8h + sqxtn v24.8b,v24.8h // shorten to int8 + sqxtn v28.8b,v28.8h + sqxtn v16.8b,v16.8h + sqxtn v20.8b,v20.8h + cmp x7,2 // OutputCount < 2 ? + st1 {v24.8b},[x2],x4 + b.lo .LExitKernel // exit if OutputCount < 2 + st1 {v28.8b},[x2],x4 + b.ls .LExitKernel // exit if OutputCount <=2 + cmp x7,4 // OutputCount < 4 ? + st1 {v16.8b},[x2],x4 + b.lo .LExitKernel // exit if OutputCount < 4 + str d20,[x2] + b .LExitKernel + + .end diff --git a/onnxruntime/core/mlas/lib/aarch64/DepthwiseConvSymKernelNeon.S b/onnxruntime/core/mlas/lib/aarch64/DepthwiseQConvSymU8KernelNeon.S similarity index 99% rename from onnxruntime/core/mlas/lib/aarch64/DepthwiseConvSymKernelNeon.S rename to onnxruntime/core/mlas/lib/aarch64/DepthwiseQConvSymU8KernelNeon.S index 379358f230476..05a51253b0c8a 100644 --- a/onnxruntime/core/mlas/lib/aarch64/DepthwiseConvSymKernelNeon.S +++ b/onnxruntime/core/mlas/lib/aarch64/DepthwiseQConvSymU8KernelNeon.S @@ -6,7 +6,7 @@ Licensed under the MIT License. Module Name: - DepthwiseConvSymKernelNeon.S + DepthwiseQConvSymU8KernelNeon.S Abstract: @@ -78,7 +78,7 @@ Return Value: --*/ - FUNCTION_ENTRY MlasConvSymDepthwiseKernelNeon + FUNCTION_ENTRY MlasConvSymDepthwiseU8KernelNeon stp d8,d9,[sp,#-64]! ldr x8,[sp,#.LConvSymDepthwiseKernelFrame_PostProcessParams] diff --git a/onnxruntime/core/mlas/lib/arm64/ConvSymS8KernelDot.asm b/onnxruntime/core/mlas/lib/arm64/ConvSymS8KernelDot.asm new file mode 100644 index 0000000000000..ddbff20cfbd28 --- /dev/null +++ b/onnxruntime/core/mlas/lib/arm64/ConvSymS8KernelDot.asm @@ -0,0 +1,605 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + ConvSymS8KernelDot.asm + +Abstract: + + This module implements the kernels for the symmetric quantized integer + convolution operation. + +--*/ + +#include "kxarm64.h" + +#define MLAS_CONV_SYM_FLAG_INPUT_DIRECT 1 +#define MLAS_CONV_SYM_FLAG_PER_CHANNEL_SCALE 2 + +// +// Stack frame layout for the symmetric convolution kernel. +// d8-d15, x19-x30 need to be preserved if used +// +#define ConvSymFrame_SavedRegisters (6 * 8) +#define ConvSymFrame_PostProcessParams 0 + ConvSymFrame_SavedRegisters +#define ConvSymFrame_KernelFlags 8 + ConvSymFrame_SavedRegisters + +#define ConvSymPostProcessParams_Bias 0 +#define ConvSymPostProcessParams_Scale 8 +#define ConvSymPostProcessParams_Min 16 +#define ConvSymPostProcessParams_Max 20 +#define ConvSymPostProcessParams_ZeroPoint 24 + + TEXTAREA + +/*++ + +Routine Description: + + This routine is the inner kernel to compute a convolution for the elements + of an output row for a set of filter rows. + +Arguments: + + Input (x0) - Points to the input buffer. + + If MLAS_CONV_SYM_FLAG_INPUT_DIRECT is set, then the input buffer points + directly at the input tensor. + + If MLAS_CONV_SYM_FLAG_INPUT_DIRECT is clear, then the input buffer is an + indirection buffer. Every pointer in the indirection buffer points at a + InputChannels length vector (either from the input tensor or a vector of + padding values). These are grouped in batches of length KernelSize. + These batches are then repeated OutputCount times. + + Filter (x1) - Points to the filter buffer. + + Output (x2) - Points the output buffer. + + KernelSize (x3/x9) - Size of the kernel (most commonly. 3x3=9, 5x5=25). + + If MLAS_CONV_SYM_FLAG_INPUT_DIRECT is set, then kernel size should be 1. + + InputChannels (x4/x7) - Number of input channels. + + OutputChannels (x5) - Number of output channels. + + ChannelCount (x6) - Number of output channels this iteration produces. + + OutputCount (x7) - Number of output elements this iteration produces. + + This implementation requires the count to be no larger than 4. + + PostProcessParams (x8) - Points to the post process parameter block. + + KernelFlags - (w10) Additional flags controlling the operation. + +Return Value: + + None. + +--*/ + NESTED_ENTRY MlasConvSymS8KernelDot + + PROLOG_SAVE_REG_PAIR d8,d9,#-ConvSymFrame_SavedRegisters! + PROLOG_NOP ldr x8,[sp,#ConvSymFrame_PostProcessParams] + PROLOG_NOP ldr w10,[sp,#ConvSymFrame_KernelFlags] + PROLOG_SAVE_REG_PAIR d10,d11,#16 + PROLOG_SAVE_REG_PAIR x19,x20,#32 + + // compute C pointers: x2, x16, x17, x5 + cmp x7,2 // OutputCount < 2 ? + add x16,x2,x5 // x16 -> C1 + lsl x3,x3,#3 // KernelSize * sizeof(int8_t*) + csel x16,x2,x16,lo // if OutputCount < 2 x16/C1 -> C0 + mov x20,x4 + add x4,x4,3 // InputChannels align to 4 + add x17,x16,x5 // x17 -> C2 + ldr x11,[x8,#ConvSymPostProcessParams_Bias] + csel x17,x16,x17,ls // if OutputCount <= 2 x17/C2 -> C1 + bic x4,x4,3 + cmp x7,4 // OutputCount < 4 ? + add x5,x17,x5 // x5 -> C3 + ldr x19,[x8,#ConvSymPostProcessParams_Scale] + csel x5,x17,x5,lo // if OutputCount < 4 x5/C3 -> C2 + + // TODO!! tiptoe around loading biases if we need to support + // output channels none divisible by 16 +OutputChannelLoop + ldp q16,q20,[x11],32 // Init accumulators with biases + mov v17.16b,v16.16b + mov v18.16b,v16.16b + ldp q24,q28,[x11],32 + mov v19.16b,v16.16b + mov v21.16b,v20.16b + mov v22.16b,v20.16b + mov v23.16b,v20.16b + mov v25.16b,v24.16b + mov v26.16b,v24.16b + mov v27.16b,v24.16b + mov v29.16b,v28.16b + mov v30.16b,v28.16b + mov v31.16b,v28.16b + mov x9,x3 // restore KernelSize * sizeof(int8_t*) + +KernelSizeLoop + tst w10,#MLAS_CONV_SYM_FLAG_INPUT_DIRECT + beq InputIndirection + +InputDirect + cmp x16,x2 + mov x12,x0 // x12 -> A0 + add x13,x0,x20 // x13 -> A1 = A0 + input channels + csel x13,x0,x13,eq + cmp x17,x16 + add x14,x0,x20,lsl#1 // x14 -> A2 + csel x14,x13,x14,eq + cmp x5,x17 + add x15,x13,x20,lsl#1 // x15 -> A3 + csel x15,x14,x15,eq + b FinishLoadAPtr + +InputIndirection + ldr x12,[x0] // x12 -> A0 + cmp x16,x2 + b.eq SkipLoadA1 // C1==C0 -> A0=A1=A2=A3 + cmp x17,x16 + lsl x14,x3,#1 + ldr x13,[x0,x3] // x13 -> A1 + b.eq SkipLoadA2 // C2==C1 -> A1=A2=A3 + cmp x5,x17 + add x15,x3,x3,lsl#1 + ldr x14,[x0,x14] // x14 -> A2 + b.eq SkipLoadA3 // C3==C2 -> A2=A3 + ldr x15,[x0,x15] // x15 -> A3 + b FinishLoadAPtr +SkipLoadA1 + mov x13,x12 +SkipLoadA2 + mov x14,x13 +SkipLoadA3 + mov x15,x14 + +// Register Usage +// B (x1) -> 4x16 +// ---------------------------------------------------------------------------- +// |v4.b[0]..v4.b[12] v5.b[0]..v5.b[12] v6.b[0]..v6.b[12] v7.b[0]..v7.b[12]| +// | ... ... ... ... ... ... ... ... | +// |v4.b[3]..v4.b[15] v5.b[3]..v5.b[15] v6.b[3]..v6.b[15] v7.b[3]..v7.b[15]| +// A 4x4 ---------------------------------------------------------------------------- +// ------------------ ---------------------------------------------------------------------------- +// x12 |v0.b[0]..v0.b[3]| |v16.s[0]_v16.s[3] v20.s[0]_v20.s[3] v24.s[0]_v24.s[3] v28.s[0]_v28.s[3]| x2 +// x13 |v1.b[0]..v1.b[3]| |v17.s[0]_v17.s[3] v21.s[0]_v21.s[3] v25.s[0]_v25.s[3] v29.s[0]_v29.s[3]| x16 +// x14 |v2.b[0]..v2.b[3]| |v18.s[0]_v18.s[3] v22.s[0]_v23.s[3] v26.s[0]_v26.s[3] v30.s[0]_v31.s[3]| x17 +// x15 |v3.b[0]..v3.b[3]| |v19.s[0]_v19.s[3] v23.s[0]_v23.s[3] v27.s[0]_v27.s[3] v31.s[0]_v31.s[3]| x5 +// ------------------ ---------------------------------------------------------------------------- + +FinishLoadAPtr + subs x7,x4,16 // Need 16 input channels for loop + add x0,x0,8 // indirect A advance to next pointer, prepare for kernel size loop + b.lo InChannels8 + + ldr d0,[x12],8 + ldr q4,[x1],16 + ldr d1,[x13],8 + subs x7,x7,16 + ldr d2,[x14],8 + ldr d3,[x15],8 + ldr q5,[x1],16 + ldr q6,[x1],16 + ldr q7,[x1],16 + b.lo InChLoopEpilogue // Need 32 input channels for main loop + +InputChannelLoop + sdot v16.4s,v4.16b,v0.4b[0] + sdot v17.4s,v4.16b,v1.4b[0] + ldr d8,[x12],8 + sdot v18.4s,v4.16b,v2.4b[0] + sdot v19.4s,v4.16b,v3.4b[0] + ldr q4,[x1],16 + sdot v20.4s,v5.16b,v0.4b[0] + sdot v21.4s,v5.16b,v1.4b[0] + ldr d9,[x13],8 + sdot v22.4s,v5.16b,v2.4b[0] + sdot v23.4s,v5.16b,v3.4b[0] + ldr q5,[x1],16 + sdot v24.4s,v6.16b,v0.4b[0] + sdot v25.4s,v6.16b,v1.4b[0] + ldr d10,[x14],8 + sdot v26.4s,v6.16b,v2.4b[0] + sdot v27.4s,v6.16b,v3.4b[0] + ldr q6,[x1],16 + sdot v28.4s,v7.16b,v0.4b[0] + sdot v29.4s,v7.16b,v1.4b[0] + ldr d11,[x15],8 + sdot v30.4s,v7.16b,v2.4b[0] + sdot v31.4s,v7.16b,v3.4b[0] + ldr q7,[x1],16 + sdot v16.4s,v4.16b,v0.4b[1] + sdot v17.4s,v4.16b,v1.4b[1] + sdot v18.4s,v4.16b,v2.4b[1] + sdot v19.4s,v4.16b,v3.4b[1] + ldr q4,[x1],16 + sdot v20.4s,v5.16b,v0.4b[1] + sdot v21.4s,v5.16b,v1.4b[1] + sdot v22.4s,v5.16b,v2.4b[1] + sdot v23.4s,v5.16b,v3.4b[1] + ldr q5,[x1],16 + sdot v24.4s,v6.16b,v0.4b[1] + sdot v25.4s,v6.16b,v1.4b[1] + sdot v26.4s,v6.16b,v2.4b[1] + sdot v27.4s,v6.16b,v3.4b[1] + ldr q6,[x1],16 + sdot v28.4s,v7.16b,v0.4b[1] + sdot v29.4s,v7.16b,v1.4b[1] + sdot v30.4s,v7.16b,v2.4b[1] + sdot v31.4s,v7.16b,v3.4b[1] + ldr q7,[x1],16 + sdot v16.4s,v4.16b,v8.4b[0] + sdot v17.4s,v4.16b,v9.4b[0] + ldr d0,[x12],8 + sdot v18.4s,v4.16b,v10.4b[0] + sdot v19.4s,v4.16b,v11.4b[0] + ldr q4,[x1],16 + sdot v20.4s,v5.16b,v8.4b[0] + sdot v21.4s,v5.16b,v9.4b[0] + ldr d1,[x13],8 + sdot v22.4s,v5.16b,v10.4b[0] + sdot v23.4s,v5.16b,v11.4b[0] + ldr q5,[x1],16 + sdot v24.4s,v6.16b,v8.4b[0] + sdot v25.4s,v6.16b,v9.4b[0] + ldr d2,[x14],8 + sdot v26.4s,v6.16b,v10.4b[0] + sdot v27.4s,v6.16b,v11.4b[0] + ldr q6,[x1],16 + sdot v28.4s,v7.16b,v8.4b[0] + sdot v29.4s,v7.16b,v9.4b[0] + ldr d3,[x15],8 + sdot v30.4s,v7.16b,v10.4b[0] + sdot v31.4s,v7.16b,v11.4b[0] + ldr q7,[x1],16 + sdot v16.4s,v4.16b,v8.4b[1] + sdot v17.4s,v4.16b,v9.4b[1] + sdot v18.4s,v4.16b,v10.4b[1] + sdot v19.4s,v4.16b,v11.4b[1] + ldr q4,[x1],16 + sdot v20.4s,v5.16b,v8.4b[1] + sdot v21.4s,v5.16b,v9.4b[1] + sdot v22.4s,v5.16b,v10.4b[1] + sdot v23.4s,v5.16b,v11.4b[1] + ldr q5,[x1],16 + sdot v24.4s,v6.16b,v8.4b[1] + sdot v25.4s,v6.16b,v9.4b[1] + sdot v26.4s,v6.16b,v10.4b[1] + sdot v27.4s,v6.16b,v11.4b[1] + ldr q6,[x1],16 + sdot v28.4s,v7.16b,v8.4b[1] + sdot v29.4s,v7.16b,v9.4b[1] + subs x7,x7,16 // InputChannels -= 16 + sdot v30.4s,v7.16b,v10.4b[1] + sdot v31.4s,v7.16b,v11.4b[1] + ldr q7,[x1],16 + b.hs InputChannelLoop + +InChLoopEpilogue + sdot v16.4s,v4.16b,v0.4b[0] + sdot v17.4s,v4.16b,v1.4b[0] + ldr d8,[x12],8 + sdot v18.4s,v4.16b,v2.4b[0] + sdot v19.4s,v4.16b,v3.4b[0] + ldr q4,[x1],16 + sdot v20.4s,v5.16b,v0.4b[0] + sdot v21.4s,v5.16b,v1.4b[0] + ldr d9,[x13],8 + sdot v22.4s,v5.16b,v2.4b[0] + sdot v23.4s,v5.16b,v3.4b[0] + ldr q5,[x1],16 + sdot v24.4s,v6.16b,v0.4b[0] + sdot v25.4s,v6.16b,v1.4b[0] + ldr d10,[x14],8 + sdot v26.4s,v6.16b,v2.4b[0] + sdot v27.4s,v6.16b,v3.4b[0] + ldr q6,[x1],16 + sdot v28.4s,v7.16b,v0.4b[0] + sdot v29.4s,v7.16b,v1.4b[0] + ldr d11,[x15],8 + sdot v30.4s,v7.16b,v2.4b[0] + sdot v31.4s,v7.16b,v3.4b[0] + ldr q7,[x1],16 + sdot v16.4s,v4.16b,v0.4b[1] + sdot v17.4s,v4.16b,v1.4b[1] + sdot v18.4s,v4.16b,v2.4b[1] + sdot v19.4s,v4.16b,v3.4b[1] + ldr q4,[x1],16 + sdot v20.4s,v5.16b,v0.4b[1] + sdot v21.4s,v5.16b,v1.4b[1] + sdot v22.4s,v5.16b,v2.4b[1] + sdot v23.4s,v5.16b,v3.4b[1] + ldr q5,[x1],16 + sdot v24.4s,v6.16b,v0.4b[1] + sdot v25.4s,v6.16b,v1.4b[1] + sdot v26.4s,v6.16b,v2.4b[1] + sdot v27.4s,v6.16b,v3.4b[1] + ldr q6,[x1],16 + sdot v28.4s,v7.16b,v0.4b[1] + sdot v29.4s,v7.16b,v1.4b[1] + sdot v30.4s,v7.16b,v2.4b[1] + sdot v31.4s,v7.16b,v3.4b[1] + ldr q7,[x1],16 + sdot v16.4s,v4.16b,v8.4b[0] + sdot v17.4s,v4.16b,v9.4b[0] + sdot v18.4s,v4.16b,v10.4b[0] + sdot v19.4s,v4.16b,v11.4b[0] + ldr q4,[x1],16 + sdot v20.4s,v5.16b,v8.4b[0] + sdot v21.4s,v5.16b,v9.4b[0] + sdot v22.4s,v5.16b,v10.4b[0] + sdot v23.4s,v5.16b,v11.4b[0] + ldr q5,[x1],16 + sdot v24.4s,v6.16b,v8.4b[0] + sdot v25.4s,v6.16b,v9.4b[0] + sdot v26.4s,v6.16b,v10.4b[0] + sdot v27.4s,v6.16b,v11.4b[0] + ldr q6,[x1],16 + sdot v28.4s,v7.16b,v8.4b[0] + sdot v29.4s,v7.16b,v9.4b[0] + sdot v30.4s,v7.16b,v10.4b[0] + sdot v31.4s,v7.16b,v11.4b[0] + ldr q7,[x1],16 + sdot v16.4s,v4.16b,v8.4b[1] + sdot v17.4s,v4.16b,v9.4b[1] + sdot v18.4s,v4.16b,v10.4b[1] + sdot v19.4s,v4.16b,v11.4b[1] + sdot v20.4s,v5.16b,v8.4b[1] + sdot v21.4s,v5.16b,v9.4b[1] + sdot v22.4s,v5.16b,v10.4b[1] + sdot v23.4s,v5.16b,v11.4b[1] + sdot v24.4s,v6.16b,v8.4b[1] + sdot v25.4s,v6.16b,v9.4b[1] + sdot v26.4s,v6.16b,v10.4b[1] + sdot v27.4s,v6.16b,v11.4b[1] + sdot v28.4s,v7.16b,v8.4b[1] + sdot v29.4s,v7.16b,v9.4b[1] + sdot v30.4s,v7.16b,v10.4b[1] + sdot v31.4s,v7.16b,v11.4b[1] + + TST x7,15 + B.NE InChannels8 // 4 ~ 12 InputChannels + + subs x9,x9,8 // KernelSize-=1 + b.hi KernelSizeLoop + +Requantize + tst w10,#MLAS_CONV_SYM_FLAG_PER_CHANNEL_SCALE + ldr w13,[x8,#ConvSymPostProcessParams_ZeroPoint] + beq BroadcastScaleValue + ldp q0,q1,[x19],32 // load scale vector + ldp q2,q3,[x19],32 + b AccumulatorsToFloat + +BroadcastScaleValue + ld1r {v0.4s},[x19] // load scale Value + mov v1.16b, v0.16b + mov v2.16b, v0.16b + mov v3.16b, v0.16b + +AccumulatorsToFloat + scvtf v16.4s,v16.4s // convert to float + scvtf v17.4s,v17.4s + scvtf v18.4s,v18.4s + scvtf v19.4s,v19.4s + scvtf v20.4s,v20.4s + scvtf v21.4s,v21.4s + scvtf v22.4s,v22.4s + scvtf v23.4s,v23.4s + scvtf v24.4s,v24.4s + scvtf v25.4s,v25.4s + scvtf v26.4s,v26.4s + scvtf v27.4s,v27.4s + scvtf v28.4s,v28.4s + scvtf v29.4s,v29.4s + scvtf v30.4s,v30.4s + scvtf v31.4s,v31.4s + fmul v16.4s,v16.4s,v0.4s // multiply by scale + fmul v17.4s,v17.4s,v0.4s + fmul v18.4s,v18.4s,v0.4s + fmul v19.4s,v19.4s,v0.4s + fmul v20.4s,v20.4s,v1.4s + fmul v21.4s,v21.4s,v1.4s + fmul v22.4s,v22.4s,v1.4s + fmul v23.4s,v23.4s,v1.4s + fmul v24.4s,v24.4s,v2.4s + fmul v25.4s,v25.4s,v2.4s + fmul v26.4s,v26.4s,v2.4s + fmul v27.4s,v27.4s,v2.4s + fmul v28.4s,v28.4s,v3.4s + fmul v29.4s,v29.4s,v3.4s + fmul v30.4s,v30.4s,v3.4s + fmul v31.4s,v31.4s,v3.4s + fcvtns v16.4s,v16.4s // convert to int + fcvtns v17.4s,v17.4s + fcvtns v18.4s,v18.4s + fcvtns v19.4s,v19.4s + fcvtns v20.4s,v20.4s + fcvtns v21.4s,v21.4s + fcvtns v22.4s,v22.4s + fcvtns v23.4s,v23.4s + fcvtns v24.4s,v24.4s + fcvtns v25.4s,v25.4s + fcvtns v26.4s,v26.4s + fcvtns v27.4s,v27.4s + fcvtns v28.4s,v28.4s + fcvtns v29.4s,v29.4s + fcvtns v30.4s,v30.4s + fcvtns v31.4s,v31.4s + + sqxtn v16.4h,v16.4s + sqxtn v17.4h,v17.4s + sqxtn v18.4h,v18.4s + sqxtn v19.4h,v19.4s + sqxtn v24.4h,v24.4s + sqxtn v25.4h,v25.4s + sqxtn v26.4h,v26.4s + sqxtn v27.4h,v27.4s + dup v4.8h,w13 // zero point + sqxtn2 v16.8h,v20.4s + sqxtn2 v17.8h,v21.4s + sqxtn2 v18.8h,v22.4s + sqxtn2 v19.8h,v23.4s + sqxtn2 v24.8h,v28.4s + sqxtn2 v25.8h,v29.4s + sqxtn2 v26.8h,v30.4s + sqxtn2 v27.8h,v31.4s + sqadd v16.8h,v16.8h,v4.8h + sqadd v17.8h,v17.8h,v4.8h + sqadd v18.8h,v18.8h,v4.8h + sqadd v19.8h,v19.8h,v4.8h + sqadd v24.8h,v24.8h,v4.8h + sqadd v25.8h,v25.8h,v4.8h + sqadd v26.8h,v26.8h,v4.8h + sqadd v27.8h,v27.8h,v4.8h + sqxtn v0.8b,v16.8h + sqxtn v1.8b,v17.8h + sqxtn v2.8b,v18.8h + sqxtn v3.8b,v19.8h + sqxtn2 v0.16b,v24.8h + sqxtn2 v1.16b,v25.8h + subs x6,x6,16 // processed 16 output channels + sqxtn2 v2.16b,v26.8h + sqxtn2 v3.16b,v27.8h + b.lo PartialStore + + st1 {v3.16b},[x5],16 // Store full 4 x 16 + st1 {v2.16b},[x17],16 + sub x0,x0,x3 // Restore pointer to A: a -= ks + st1 {v1.16b},[x16],16 + st1 {v0.16b},[x2],16 + b.hi OutputChannelLoop + +ExitKernel + EPILOG_RESTORE_REG_PAIR x19,x20,#32 + EPILOG_RESTORE_REG_PAIR d10,d11,#16 + EPILOG_RESTORE_REG_PAIR d8,d9,#ConvSymFrame_SavedRegisters! + EPILOG_RETURN + +InChannels8 + tbz x7,3,InChannels4 + ldr d0,[x12],8 + ldr q4,[x1],16 + ldr d1,[x13],8 + ldr d2,[x14],8 + ldr d3,[x15],8 + ldr q5,[x1],16 + sdot v16.4s,v4.16b,v0.4b[0] + sdot v17.4s,v4.16b,v1.4b[0] + ldp q6,q7,[x1],32 + sdot v18.4s,v4.16b,v2.4b[0] + sdot v19.4s,v4.16b,v3.4b[0] + sdot v20.4s,v5.16b,v0.4b[0] + sdot v21.4s,v5.16b,v1.4b[0] + sdot v22.4s,v5.16b,v2.4b[0] + sdot v23.4s,v5.16b,v3.4b[0] + sdot v24.4s,v6.16b,v0.4b[0] + sdot v25.4s,v6.16b,v1.4b[0] + ldp q4,q5,[x1],32 + sdot v26.4s,v6.16b,v2.4b[0] + sdot v27.4s,v6.16b,v3.4b[0] + sdot v28.4s,v7.16b,v0.4b[0] + sdot v29.4s,v7.16b,v1.4b[0] + sdot v30.4s,v7.16b,v2.4b[0] + sdot v31.4s,v7.16b,v3.4b[0] + sdot v16.4s,v4.16b,v0.4b[1] + sdot v17.4s,v4.16b,v1.4b[1] + ldp q6,q7,[x1],32 + sdot v18.4s,v4.16b,v2.4b[1] + sdot v19.4s,v4.16b,v3.4b[1] + sdot v20.4s,v5.16b,v0.4b[1] + sdot v21.4s,v5.16b,v1.4b[1] + sdot v22.4s,v5.16b,v2.4b[1] + sdot v23.4s,v5.16b,v3.4b[1] + sdot v24.4s,v6.16b,v0.4b[1] + sdot v25.4s,v6.16b,v1.4b[1] + sdot v26.4s,v6.16b,v2.4b[1] + sdot v27.4s,v6.16b,v3.4b[1] + sdot v28.4s,v7.16b,v0.4b[1] + sdot v29.4s,v7.16b,v1.4b[1] + sdot v30.4s,v7.16b,v2.4b[1] + sdot v31.4s,v7.16b,v3.4b[1] + tbz x7,2,SkipInCh4 + +InChannels4 + ldr s0,[x12],4 + ldr q4,[x1],16 + ldr s1,[x13],4 + ldr s2,[x14],4 + ldr s3,[x15],4 + ldr q5,[x1],16 + sdot v16.4s,v4.16b,v0.4b[0] + sdot v17.4s,v4.16b,v1.4b[0] + ldp q6,q7,[x1],32 + sdot v18.4s,v4.16b,v2.4b[0] + sdot v19.4s,v4.16b,v3.4b[0] + sdot v20.4s,v5.16b,v0.4b[0] + sdot v21.4s,v5.16b,v1.4b[0] + sdot v22.4s,v5.16b,v2.4b[0] + sdot v23.4s,v5.16b,v3.4b[0] + sdot v24.4s,v6.16b,v0.4b[0] + sdot v25.4s,v6.16b,v1.4b[0] + sdot v26.4s,v6.16b,v2.4b[0] + sdot v27.4s,v6.16b,v3.4b[0] + sdot v28.4s,v7.16b,v0.4b[0] + sdot v29.4s,v7.16b,v1.4b[0] + sdot v30.4s,v7.16b,v2.4b[0] + sdot v31.4s,v7.16b,v3.4b[0] + +SkipInCh4 + subs x9,x9,8 // ks -= 1 + b.hi KernelSizeLoop + b Requantize + +PartialStore + tbz x6,3,LT8Store + str d3,[x5],8 // no less than 8 channels + str d2,[x17],8 + dup d3,v3.d[1] + dup d2,v2.d[1] + str d1,[x16],8 + str d0,[x2],8 + dup d1,v1.d[1] + dup d0,v0.d[1] +LT8Store + tbz x6,2,LT4Store + str s3,[x5],4 + str s2,[x17],4 + dup s3,v3.s[1] + dup s2,v2.s[1] + str s1,[x16],4 + str s0,[x2],4 + dup s1,v1.s[1] + dup s0,v0.s[1] +LT4Store + tbz x6,1, LT2Store + str h3,[x5],2 + str h2,[x17],2 + dup h3,v3.h[1] + dup h2,v2.h[1] + str h1,[x16],2 + str h0,[x2],2 + dup h1,v1.h[1] + dup h0,v0.h[1] +LT2Store + tbz x6,0,ExitKernel + str b3,[x5] + str b2,[x17] + str b1,[x16] + str b0,[x2] + b ExitKernel + + NESTED_END MlasConvSymS8KernelDot + + END diff --git a/onnxruntime/core/mlas/lib/arm64/ConvSymS8KernelNeon.asm b/onnxruntime/core/mlas/lib/arm64/ConvSymS8KernelNeon.asm new file mode 100644 index 0000000000000..15db1b31bf013 --- /dev/null +++ b/onnxruntime/core/mlas/lib/arm64/ConvSymS8KernelNeon.asm @@ -0,0 +1,423 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + ConvSymS8KernelNeon.asm + +Abstract: + + This module implements the kernels for the symmetric quantized integer + convolution operation. + +--*/ + +#include "kxarm64.h" + +#define MLAS_CONV_SYM_FLAG_INPUT_DIRECT 1 +#define MLAS_CONV_SYM_FLAG_PER_CHANNEL_SCALE 2 + +// +// Stack frame layout for the symmetric convolution kernel. +// d8-d15, x19-x30 need to be preserved if used +// +#define ConvSymFrame_SavedNeonRegisters (8 * 8) +#define ConvSymFrame_SavedRegisters ConvSymFrame_SavedNeonRegisters +#define ConvSymFrame_PostProcessParams 0 + ConvSymFrame_SavedRegisters +#define ConvSymFrame_KernelFlags 8 + ConvSymFrame_SavedRegisters + +#define ConvSymPostProcessParams_Bias 0 +#define ConvSymPostProcessParams_Scale 8 +#define ConvSymPostProcessParams_Min 16 +#define ConvSymPostProcessParams_Max 20 +#define ConvSymPostProcessParams_ZeroPoint 24 + + TEXTAREA + +/*++ + +Routine Description: + + This routine is the inner kernel to compute a convolution for the elements + of an output row for a set of filter rows. + +Arguments: + + Input (x0) - Supplies the address of the input buffer. + + If MLAS_CONV_SYM_FLAG_INPUT_DIRECT is set, then the input buffer points + directly at the input tensor. + + If MLAS_CONV_SYM_FLAG_INPUT_DIRECT is clear, then the input buffer is an + indirection buffer. Every pointer in the indirection buffer points at a + InputChannels length vector (either from the input tensor or a vector of + padding values). These are grouped in batches of length KernelSize. + These batches are then repeated OutputCount times. + + Filter (x1) - Supplies the address of the filter buffer. + + Output (x2) - Supplies the address of the output buffer. + + KernelSize (x3) - Supplies the size of the kernel. + + If MLAS_CONV_SYM_FLAG_INPUT_DIRECT is set, then kernel size should be 1. + + InputChannels (x4) - Supplies the number of input channels. + + This implementation requires the count to be a multiple of 8. + + OutputChannels (x5) - Supplies the number of output channels. + + ChannelCount (x6) - Supplies the number of channels this iteration produces. + + This implementation requires the count to be 8. + + OutputCount (x7) - Supplies the number of output elements this iteration produces. + + This implementation requires the count to be 1 or 2. + + PostProcessParams - Supplies the address of the post process parameter block. + + KernelFlags - Supplies additional flags controlling the operation. + +Return Value: + + None. + +--*/ + NESTED_ENTRY MlasConvSymS8KernelNeon + + PROLOG_SAVE_REG_PAIR d8,d9,#-64! + PROLOG_NOP ldr x8,[sp,#ConvSymFrame_PostProcessParams] + PROLOG_NOP ldrb w10,[sp,#ConvSymFrame_KernelFlags] + PROLOG_SAVE_REG_PAIR d10,d11,#16 + PROLOG_SAVE_REG_PAIR d12,d13,#32 + PROLOG_SAVE_REG_PAIR d14,d15,#48 + mov x9,x3 // save kernel size + ldr x11,[x8,#ConvSymPostProcessParams_Bias] + mov x16,x4 // save input channels + ldr x12,[x8,#ConvSymPostProcessParams_Scale] + cmp x7,2 // if OutputCount < 2 + add x5,x2,x5 // c1 = c0 + ldc + add x4,x4,7 // kc = (kc + 7) & ~7 + csel x5,x2,x5,lo // if OutputCount < 2 c1 = c0 + bic x4,x4,7 + ldp s16,s18,[x11],8 // init accumulators with bias + ldp s20,s22,[x11],8 + ldp s24,s26,[x11],8 + ldp s28,s30,[x11],8 + mov v17.16b,v16.16b + mov v19.16b,v18.16b + mov v21.16b,v20.16b + mov v23.16b,v22.16b + mov v25.16b,v24.16b + mov v27.16b,v26.16b + mov v29.16b,v28.16b + mov v31.16b,v30.16b + +// Nested loops, inner loop: input channel; outter loop: kernel size +// Each inner iteration processes 8 input channels, 2 output pixels, 8 output channels. +// +// B 8x8 +// ------------------------------------------------------------------ +// |v4.b[0] v5.b[0] v4.b[0] v5.b[0] v4.b[0] v5.b[0] v4.b[0] v5.b[0] | +// | ... ... ... ... ... ... ... ... | +// |v4.b[7] v5.b[7] v4.b[7] v5.b[7] v4.b[7] v5.b[7] v4.b[7] v5.b[7] | +// A 2x8 ------------------------------------------------------------------ +// ------------------ ------------------------------------------------------------------ +// x13-> |v0.b[0]..v0.b[7]| |v16.4s v18.4s v20.4s v22.4s v24.4s v26.4s v28.4s v30.4s | +// x15-> |v1.b[0]..v1.b[7]| |v17.4s v19.4s v21.4s v23.4s v25.4s v27.4s v29.4s v31.4s | +// ------------------ ------------------------------------------------------------------ +// When Input Channels greater than 16, unroll: +// A registers v6 v7, +// B registers v8 v9 +// + +KernelSizeLoop + + // Load next 2 A pointers + tst w10,#MLAS_CONV_SYM_FLAG_INPUT_DIRECT + ldr d4,[x1] + ldr d5,[x1,8] + beq InputIndirection + +InputDirect + mov x13,x0 // x13 -> A0 + add x15,x0,x16 // x15 -> A1 = A0 + input channels + b BlockLoopPrologue + +InputIndirection + cmp x7,2 // test if OutputCount < 2 + ldr x13,[x0] // x13 -> A0 + blo SkipLoadA1 + ldr x15,[x0,x3,lsl#3] // x15 -> A1 +SkipLoadA1 + +BlockLoopPrologue + cmp x7,2 // test if OutputCount < 2 + add x0,x0,8 // indirect A advance to next pointer, prepare for kernel size loop + csel x15,x13,x15,lo // if OutputCount < 2 x15 -> A0 + subs x14,x4,16 // input channel - 16 + blo InputChannel8 // less than 16 deep, no unroll + + ldr d0,[x13],8 + ldr d1,[x15],8 + ldr d8,[x1,64] + ldr d9,[x1,72] + ldr d6,[x13],8 + subs x14,x14,16 // input channel - 16 + ldr d7,[x15],8 + blo BlockLoopEpilogue // need 32 input channel for full unrolled loop + +Blockloop + smull v2.8h,v4.8b,v0.8b + smull v3.8h,v4.8b,v1.8b + ldr d4,[x1,16] + smull v10.8h,v5.8b,v0.8b + smull v11.8h,v5.8b,v1.8b + ldr d5,[x1,24] + smlal v2.8h,v8.8b,v6.8b + smlal v3.8h,v8.8b,v7.8b + ldr d8,[x1,80] + smlal v10.8h,v9.8b,v6.8b + smlal v11.8h,v9.8b,v7.8b + ldr d9,[x1,88] + smull v12.8h,v4.8b,v0.8b + sadalp v16.4s,v2.8h + smull v13.8h,v4.8b,v1.8b + ldr d4,[x1,32] + sadalp v17.4s,v3.8h + smull v14.8h,v5.8b,v0.8b + sadalp v18.4s,v10.8h + smull v15.8h,v5.8b,v1.8b + ldr d5,[x1,40] + sadalp v19.4s,v11.8h + smlal v12.8h,v8.8b,v6.8b + smlal v13.8h,v8.8b,v7.8b + ldr d8,[x1,96] + smlal v14.8h,v9.8b,v6.8b + smlal v15.8h,v9.8b,v7.8b + ldr d9,[x1,104] + smull v2.8h,v4.8b,v0.8b + sadalp v20.4s,v12.8h + smull v3.8h,v4.8b,v1.8b + ldr d4,[x1,48] + sadalp v21.4s,v13.8h + smull v10.8h,v5.8b,v0.8b + sadalp v22.4s,v14.8h + smull v11.8h,v5.8b,v1.8b + ldr d5,[x1,56] + sadalp v23.4s, v15.8h + smlal v2.8h,v8.8b,v6.8b + smlal v3.8h,v8.8b,v7.8b + ldr d8,[x1,112] + smlal v10.8h,v9.8b,v6.8b + smlal v11.8h,v9.8b,v7.8b + ldr d9,[x1,120] + smull v12.8h,v4.8b,v0.8b + add x1,x1,128 + sadalp v24.4s,v2.8h + smull v13.8h,v4.8b,v1.8b + ldr d4,[x1] // Read B + sadalp v25.4s,v3.8h + smull v14.8h,v5.8b,v0.8b + ldr d0,[x13],8 // Read A0 + sadalp v26.4s,v10.8h + smull v15.8h,v5.8b,v1.8b + ldr d1,[x15],8 // Read A1 + sadalp v27.4s,v11.8h + smlal v12.8h,v8.8b,v6.8b + ldr d5,[x1,8] // Read B + smlal v13.8h,v8.8b,v7.8b + ldr d8,[x1,64] // Read B + smlal v14.8h,v9.8b,v6.8b + ldr d6,[x13],8 // Read A0 + smlal v15.8h,v9.8b,v7.8b + ldr d7,[x15],8 // Read A1 + sadalp v28.4s,v12.8h + ldr d9,[x1,72] // Read B + sadalp v29.4s,v13.8h + subs x14,x14,16 + sadalp v30.4s,v14.8h + sadalp v31.4s,v15.8h + b.hs Blockloop + +BlockLoopEpilogue // remaining 16 input channels + smull v2.8h,v4.8b,v0.8b + smull v3.8h,v4.8b,v1.8b + ldr d4,[x1,16] + smull v10.8h,v5.8b,v0.8b + smull v11.8h,v5.8b,v1.8b + ldr d5,[x1,24] + smlal v2.8h,v8.8b,v6.8b + smlal v3.8h,v8.8b,v7.8b + ldr d8,[x1,80] + smlal v10.8h,v9.8b,v6.8b + smlal v11.8h,v9.8b,v7.8b + ldr d9,[x1,88] + smull v12.8h,v4.8b,v0.8b + sadalp v16.4s,v2.8h + smull v13.8h,v4.8b,v1.8b + ldr d4,[x1,32] + sadalp v17.4s,v3.8h + smull v14.8h,v5.8b,v0.8b + sadalp v18.4s,v10.8h + smull v15.8h,v5.8b,v1.8b + sadalp v19.4s,v11.8h + ldr d5,[x1,40] + smlal v12.8h,v8.8b,v6.8b + smlal v13.8h,v8.8b,v7.8b + ldr d8,[x1,96] + smlal v14.8h,v9.8b,v6.8b + smlal v15.8h,v9.8b,v7.8b + ldr d9,[x1,104] + smull v2.8h,v4.8b,v0.8b + sadalp v20.4s,v12.8h + smull v3.8h,v4.8b,v1.8b + ldr d4,[x1,48] + sadalp v21.4s,v13.8h + smull v10.8h,v5.8b,v0.8b + sadalp v22.4s,v14.8h + smull v11.8h,v5.8b,v1.8b + sadalp v23.4s,v15.8h + ldr d5,[x1,56] + smlal v2.8h,v8.8b,v6.8b + smlal v3.8h,v8.8b,v7.8b + ldr d8,[x1,112] + smlal v10.8h,v9.8b,v6.8b + smlal v11.8h,v9.8b,v7.8b + ldr d9,[x1,120] + smull v12.8h,v4.8b,v0.8b + sadalp v24.4s,v2.8h + smull v13.8h,v4.8b,v1.8b + sadalp v25.4s,v3.8h + smull v14.8h,v5.8b,v0.8b + sadalp v26.4s,v10.8h + smull v15.8h,v5.8b,v1.8b + sadalp v27.4s,v11.8h + smlal v12.8h,v8.8b,v6.8b + smlal v13.8h,v8.8b,v7.8b + smlal v14.8h,v9.8b,v6.8b + smlal v15.8h,v9.8b,v7.8b + add x1,x1,128 + + sadalp v28.4s,v12.8h + sadalp v29.4s,v13.8h + sadalp v30.4s,v14.8h + sadalp v31.4s,v15.8h + tbnz x14,3,InputChannel8 + + subs x9,x9,1 + b.hi KernelSizeLoop + +Requantize + ldr w11,[x8,#ConvSymPostProcessParams_ZeroPoint] + tst w10,#MLAS_CONV_SYM_FLAG_PER_CHANNEL_SCALE + beq BroadcastScaleValue + ld1 {v4.4s,v5.4s},[x12] // load scale vector + b AccumulatorsToFloat + +BroadcastScaleValue + ld1r {v4.4s},[x12] // load scale Value + mov v5.16b, v4.16b + +AccumulatorsToFloat + addp v16.4s,v16.4s,v18.4s + addp v20.4s,v20.4s,v22.4s + addp v24.4s,v24.4s,v26.4s + addp v28.4s,v28.4s,v30.4s + addp v17.4s,v17.4s,v19.4s + addp v21.4s,v21.4s,v23.4s + addp v25.4s,v25.4s,v27.4s + addp v29.4s,v29.4s,v31.4s + addp v0.4s,v16.4s,v20.4s + addp v1.4s,v24.4s,v28.4s + addp v2.4s,v17.4s,v21.4s + addp v3.4s,v25.4s,v29.4s + scvtf v0.4s,v0.4s // convert to float + scvtf v1.4s,v1.4s + scvtf v2.4s,v2.4s + scvtf v3.4s,v3.4s + fmul v0.4s,v0.4s,v4.4s // multiply by scale + fmul v1.4s,v1.4s,v5.4s + fmul v2.4s,v2.4s,v4.4s + fmul v3.4s,v3.4s,v5.4s + fcvtns v0.4s,v0.4s // convert to int + fcvtns v1.4s,v1.4s + dup v9.8h,w11 + fcvtns v2.4s,v2.4s + fcvtns v3.4s,v3.4s + sqxtn v0.4h,v0.4s + sqxtn2 v0.8h,v1.4s + sqxtn v2.4h,v2.4s + sqxtn2 v2.8h,v3.4s + sqadd v0.8h,v0.8h,v9.8h + sqadd v2.8h,v2.8h,v9.8h + sqxtn v0.8b,v0.8h // shorten to int8 + sqxtn2 v0.16b,v2.8h + st1 {v0.d}[1],[x5] // full 2x8 store to c + st1 {v0.8b},[x2] + +ExitKernel + EPILOG_RESTORE_REG_PAIR d14,d15,#48 + EPILOG_RESTORE_REG_PAIR d12,d13,#32 + EPILOG_RESTORE_REG_PAIR d10,d11,#16 + EPILOG_RESTORE_REG_PAIR d8,d9,#64! + EPILOG_RETURN + +InputChannel8 + ldr d0,[x13] + ldr d1,[x15] + ldr d4,[x1] + ldr d5,[x1,8] + ldr d6,[x1,16] + ldr d7,[x1,24] + smull v2.8h,v4.8b,v0.8b + smull v3.8h,v4.8b,v1.8b + ldr d4,[x1,32] + smull v10.8h,v5.8b,v0.8b + smull v11.8h,v5.8b,v1.8b + ldr d5,[x1,40] + smull v12.8h,v6.8b,v0.8b + sadalp v16.4s,v2.8h + smull v13.8h,v6.8b,v1.8b + ldr d6,[x1,48] + sadalp v17.4s,v3.8h + smull v14.8h,v7.8b,v0.8b + sadalp v18.4s,v10.8h + smull v15.8h,v7.8b,v1.8b + ldr d7,[x1,56] + sadalp v19.4s,v11.8h + smull v2.8h,v4.8b,v0.8b + sadalp v20.4s,v12.8h + smull v3.8h,v4.8b,v1.8b + sadalp v21.4s,v13.8h + smull v10.8h,v5.8b,v0.8b + sadalp v22.4s,v14.8h + smull v11.8h,v5.8b,v1.8b + sadalp v23.4s,v15.8h + smull v12.8h,v6.8b,v0.8b + sadalp v24.4s,v2.8h + smull v13.8h,v6.8b,v1.8b + sadalp v25.4s,v3.8h + smull v14.8h,v7.8b,v0.8b + sadalp v26.4s,v10.8h + smull v15.8h,v7.8b,v1.8b + sadalp v27.4s,v11.8h + add x1,x1,64 + sadalp v28.4s,v12.8h + sadalp v29.4s,v13.8h + sadalp v30.4s,v14.8h + sadalp v31.4s,v15.8h + + // ks loop + subs x9,x9,1 + b.hi KernelSizeLoop + b Requantize + + NESTED_END MlasConvSymS8KernelNeon + + END diff --git a/onnxruntime/core/mlas/lib/arm64/ConvSymU8KernelDot.asm b/onnxruntime/core/mlas/lib/arm64/ConvSymU8KernelDot.asm index ecb6c6578b574..7b0917b95551c 100644 --- a/onnxruntime/core/mlas/lib/arm64/ConvSymU8KernelDot.asm +++ b/onnxruntime/core/mlas/lib/arm64/ConvSymU8KernelDot.asm @@ -84,7 +84,7 @@ Return Value: None. --*/ - NESTED_ENTRY MlasConvSymKernelNeonDot + NESTED_ENTRY MlasConvSymU8KernelDot PROLOG_SAVE_REG_PAIR d8,d9,#-64! PROLOG_NOP ldr x8,[sp,#ConvSymFrame_PostProcessParams] @@ -626,6 +626,6 @@ LT2Store str b0,[x2] b ExitKernel - NESTED_END MlasConvSymKernelNeonDot + NESTED_END MlasConvSymU8KernelDot END diff --git a/onnxruntime/core/mlas/lib/arm64/ConvSymU8KernelNeon.asm b/onnxruntime/core/mlas/lib/arm64/ConvSymU8KernelNeon.asm index 5b3c4f5d9e2e1..a8a11fe6209d1 100644 --- a/onnxruntime/core/mlas/lib/arm64/ConvSymU8KernelNeon.asm +++ b/onnxruntime/core/mlas/lib/arm64/ConvSymU8KernelNeon.asm @@ -6,7 +6,7 @@ Licensed under the MIT License. Module Name: - ConvSymKernelNeon.s + ConvSymU8KernelNeon.asm Abstract: @@ -88,7 +88,7 @@ Return Value: None. --*/ - NESTED_ENTRY MlasConvSymKernelNeon + NESTED_ENTRY MlasConvSymU8KernelNeon PROLOG_SAVE_REG_PAIR d8,d9,#-64! PROLOG_NOP ldr x8,[sp,#ConvSymFrame_PostProcessParams] @@ -431,6 +431,6 @@ InputChannel8 b.hi KernelSizeLoop b Requantize - NESTED_END MlasConvSymKernelNeon + NESTED_END MlasConvSymU8KernelNeon END diff --git a/onnxruntime/core/mlas/lib/arm64/DepthwiseQConvSymS8KernelNeon.asm b/onnxruntime/core/mlas/lib/arm64/DepthwiseQConvSymS8KernelNeon.asm new file mode 100644 index 0000000000000..ba565ca587e0c --- /dev/null +++ b/onnxruntime/core/mlas/lib/arm64/DepthwiseQConvSymS8KernelNeon.asm @@ -0,0 +1,693 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + DepthwiseQConvSymS8KernelNeon.asm + +Abstract: + + This module implements the kernels for the depthwise convolution + operation with symmetrically quantized integer values + +--*/ + +#include "kxarm64.h" + +// +// Stack frame layout for the depthwise conv kernel. +// d8-d15, x19-x30 need to be preserved if used +// + +#define ConvSymDepthwiseKernelFrame_SavedRegisters (4 * 8) +#define ConvSymDepthwiseKernelFrame_PostProcessParams 0 + ConvSymDepthwiseKernelFrame_SavedRegisters +#define ConvSymDepthwiseKernelFrame_KernelFlags 8 + ConvSymDepthwiseKernelFrame_SavedRegisters + +#define ConvSymDepthwisePostProcessParams_Bias 0 +#define ConvSymDepthwisePostProcessParams_Scale 8 +#define ConvSymDepthwisePostProcessParams_Min 16 +#define ConvSymDepthwisePostProcessParams_Max 20 +#define ConvSymDepthwisePostProcessParams_ZeroPoint 24 + +#define MLAS_CONV_SYM_FLAG_INPUT_DIRECT 1 +#define MLAS_CONV_SYM_FLAG_PER_CHANNEL_SCALE 2 + + TEXTAREA + +/*++ + +Routine Description: + + This routine is the inner kernel to compute a depthwise convolution for the + elements of an output row for a set of filter rows. + +Arguments: + + Input (x0) - Supplies the address of the indirection buffer. + + Filter (x1) - Supplies the address of the filter buffer. + + Output (x2) - Supplies the address of the output buffer. + + KernelSize (x3) - Supplies the size of the kernel. + + Channels (x4) - Supplies the number of input and output channels. + + ChannelOffset (x5) - Supplies the byte offset from the indirection buffer base + address for this iteration. + + ChannelCount (x6) - Supplies the number of channels this iteration produces. + + This implementation requires the count to be 16 or 8 + + OutputCount (x7)- Supplies the number of output elements this iteration produces. + + This implementation requires the count to be in the range 1 to 2. + + PostProcessParams - Supplies the address of the post process parameter block. + + KernelFlags - Supplies additional flags controlling the operation. + +Return Value: + + None. + +--*/ + + NESTED_ENTRY MlasConvSymDepthwiseS8KernelNeon + + PROLOG_SAVE_REG_PAIR d12,d13,#-ConvSymDepthwiseKernelFrame_SavedRegisters! + PROLOG_NOP ldr x8,[sp,#ConvSymDepthwiseKernelFrame_PostProcessParams] + PROLOG_SAVE_REG_PAIR d14,d15,#16 + cmp x7,2 + add x9,x0,x3,lsl#3 // x9 -> &A1 + add x14,x0,x3,lsl#4 // x14 -> &A2 + add x15,x9,x3,lsl#4 // x15 -> &A3 + ldr x16,[x8,#ConvSymDepthwisePostProcessParams_Bias] + csel x9,x0,x9,lo // x9 -> &A0 if OutputCount < 2 + csel x14,x0,x14,ls // x14 -> &A0 if OutputCount <= 2 + ldr x11,[x9],#8 // x11 -> A1 iter 0 + cmp x7,4 + ldp q24,q25,[x16],#32 // init accumulators with bias + csel x15,x0,x15,lo // x15 -> &A0 if OutputCount < 4 + cmp x6,16 + ldr x10,[x0],#8 // x10 -> A0 iter 0 + b.lo Process8Channels + +// +// Process an input block of length Channels for each element of the kernel. +// +// Filter: v0, +// v1 // unroll +// Input: +// x0 -> x10 -> v4 +// -> x12 -> v2 // unroll +// x9 -> x11 -> v6 +// -> x13 -> v3 // unroll +// x14 -> x10 -> v4 +// -> x12 -> v2 // unroll +// x15 -> x11 -> v6 +// -> x13 -> v3 // unroll +// + +Process16Channels + cmp x3,1 + ldp q26,q27,[x16] + b.eq ProcC16P1 + + ldr x12,[x0],#8 // x12 -> A0 iter 1 + ldr x13,[x9],#8 // x13 -> A1 iter 1 + mov v28.16b,v24.16b + mov v29.16b,v25.16b + ld1 {v0.16b},[x1],x4 // filter iter 0 + ld1 {v1.16b},[x1],x4 // filter iter 1 + mov v16.16b,v24.16b + mov v17.16b,v25.16b + ldr q4,[x10,x5] // A0 iter 0 + mov v20.16b,v24.16b + ldr x10,[x14],#8 // x10 -> A2 iter 0 + mov v21.16b,v25.16b + ldr q6,[x11,x5] // A1 iter 0 + mov v30.16b,v26.16b + ldr x11,[x15],#8 // x11 -> A3 iter 0 + mov v31.16b,v27.16b + ldr q2,[x12,x5] // A0 iter 1 + subs x3,x3,2 // decrement input blocks remaining + mov v18.16b,v26.16b + ldr x12,[x14],#8 // x12 -> A2 iter 1 + mov v19.16b,v27.16b + ldr q3,[x13,x5] // A1 iter 1 + mov v22.16b,v26.16b + ldr x13,[x15],#8 // x13 -> A3 iter 1 + mov v23.16b,v27.16b + +BlockLoopC16 + + // + // Process 2 pixels, and load next two pixels + // + smull v12.8h,v0.8b,v4.8b + smull2 v13.8h,v0.16b,v4.16b + ldr q4,[x10,x5] // A2 iter 0 + b.eq EpilogueC16P2 + smull v14.8h,v0.8b,v6.8b + ldr x10,[x0],#8 // x10 -> A0 iter 2 + smull2 v15.8h,v0.16b,v6.16b + cmp x3,1 + ldr q6,[x11,x5] // A3 iter 0 + smlal v12.8h,v1.8b,v2.8b + ldr x11,[x9],#8 // x11 -> A1 iter 2 + smlal2 v13.8h,v1.16b,v2.16b + b.eq EpilogueC16P3 // 3 pixel remains + ldr q2,[x12,x5] // A2 iter 1 + smlal v14.8h,v1.8b,v3.8b + ldr x12,[x0],#8 // x12 -> A0 iter 3 + smlal2 v15.8h,v1.16b,v3.16b + ldr q3,[x13,x5] // A3 iter 1 + saddw v24.4s,v24.4s,v12.4h + saddw2 v25.4s,v25.4s,v12.8h + ldr x13,[x9],#8 // x13 -> A1 iter 3 + saddw v26.4s,v26.4s,v13.4h + saddw2 v27.4s,v27.4s,v13.8h + saddw v28.4s,v28.4s,v14.4h + saddw2 v29.4s,v29.4s,v14.8h + saddw v30.4s,v30.4s,v15.4h + saddw2 v31.4s,v31.4s,v15.8h + subs x3,x3,2 // decrement input blocks remaining + smull v12.8h,v0.8b,v4.8b + smull2 v13.8h,v0.16b,v4.16b + ldr q4,[x10,x5] // A0 iter 2 + smull v14.8h,v0.8b,v6.8b + ldr x10,[x14],#8 // x10 -> A2 iter 2 + smull2 v15.8h,v0.16b,v6.16b + ldr q6,[x11,x5] // A1 iter 2 + ld1 {v0.16b},[x1],x4 // filter iter 2 + smlal v12.8h,v1.8b,v2.8b + ldr x11,[x15],#8 // x11 -> A3 iter 2 + smlal2 v13.8h,v1.16b,v2.16b + ldr q2,[x12,x5] // A0 iter 3 + smlal v14.8h,v1.8b,v3.8b + ldr x12,[x14],#8 // x12 -> A2 iter 3 + smlal2 v15.8h,v1.16b,v3.16b + ldr q3,[x13,x5] // A1 iter 3 + saddw v16.4s,v16.4s,v12.4h + saddw2 v17.4s,v17.4s,v12.8h + ld1 {v1.16b},[x1],x4 // filter iter 3 + saddw v18.4s,v18.4s,v13.4h + saddw2 v19.4s,v19.4s,v13.8h + ldr x13,[x15],#8 // x13 -> A3 iter 3 + saddw v20.4s,v20.4s,v14.4h + saddw2 v21.4s,v21.4s,v14.8h + saddw v22.4s,v22.4s,v15.4h + saddw2 v23.4s,v23.4s,v15.8h + b BlockLoopC16 + +EpilogueC16P2 + // + // Loop epilogue (process last 2 pixels) mixed + // with loading of dequantization params + // + smull v14.8h,v0.8b,v6.8b + smull2 v15.8h,v0.16b,v6.16b + ldr q6,[x11,x5] // A3 iter 0 + smlal v12.8h,v1.8b,v2.8b + smlal2 v13.8h,v1.16b,v2.16b + ldr q2,[x12,x5] // A2 iter 1 + smlal v14.8h,v1.8b,v3.8b + smlal2 v15.8h,v1.16b,v3.16b + ldr q3,[x13,x5] // A3 iter 1 + saddw v24.4s,v24.4s,v12.4h + saddw2 v25.4s,v25.4s,v12.8h + saddw v26.4s,v26.4s,v13.4h + saddw2 v27.4s,v27.4s,v13.8h + saddw v28.4s,v28.4s,v14.4h + saddw2 v29.4s,v29.4s,v14.8h + saddw v30.4s,v30.4s,v15.4h + saddw2 v31.4s,v31.4s,v15.8h + ldr w9,[sp,#ConvSymDepthwiseKernelFrame_KernelFlags] + ldr x12,[x8,#ConvSymDepthwisePostProcessParams_Scale] + smull v12.8h,v0.8b,v4.8b + smull2 v13.8h,v0.16b,v4.16b + ldr w15,[x8,#ConvSymDepthwisePostProcessParams_ZeroPoint] + smull v14.8h,v0.8b,v6.8b + smull2 v15.8h,v0.16b,v6.16b + smlal v12.8h,v1.8b,v2.8b + smlal2 v13.8h,v1.16b,v2.16b + smlal v14.8h,v1.8b,v3.8b + smlal2 v15.8h,v1.16b,v3.16b + tst w9,#MLAS_CONV_SYM_FLAG_PER_CHANNEL_SCALE + ld1r {v4.4s},[x12] // load scale val + b.eq SkipScaleVecLoad2 + ldp q4,q5,[x12],#32 // load scale vector if per channel + ldp q6,q3,[x12] +SkipScaleVecLoad2 + saddw v16.4s,v16.4s,v12.4h + saddw2 v17.4s,v17.4s,v12.8h + saddw v18.4s,v18.4s,v13.4h + saddw2 v19.4s,v19.4s,v13.8h + saddw v20.4s,v20.4s,v14.4h + saddw2 v21.4s,v21.4s,v14.8h + saddw v22.4s,v22.4s,v15.4h + saddw2 v23.4s,v23.4s,v15.8h + b Dequantization + +ProcC16P1 + // + // Channel 16 kernel size 1 + // TODO!! is this reachable at all? + // + ldr x12,[x14],#8 // x12 -> A2 + ldr x13,[x15],#8 // x13 -> A3 + mov v28.16b,v24.16b + mov v29.16b,v25.16b + ld1 {v0.16b},[x1] + mov v16.16b,v24.16b + mov v17.16b,v25.16b + ldr q4,[x10,x5] + mov v20.16b,v24.16b + mov v21.16b,v25.16b + ldr q6,[x11,x5] + mov v30.16b,v26.16b + mov v31.16b,v27.16b + ldr q2,[x12,x5] + subs x3,x3,2 // decrement input blocks remaining + mov v18.16b,v26.16b + mov v19.16b,v27.16b + ldr q3,[x13,x5] + mov v22.16b,v26.16b + mov v23.16b,v27.16b + b EpilogueC16P1 + +EpilogueC16P3 + // + // Loop epilogue (process last 2 pixels) mixed + // with loading of dequantization params + // + ldr q2,[x12,x5] // A2 iter 1 + smlal v14.8h,v1.8b,v3.8b + ldr x12,[x14],#8 // x12 -> A2 iter 2 + smlal2 v15.8h,v1.16b,v3.16b + ldr q3,[x13,x5] // A3 iter 1 + saddw v24.4s,v24.4s,v12.4h + saddw2 v25.4s,v25.4s,v12.8h + ldr x13,[x15],#8 // x13 -> A3 iter 2 + saddw v26.4s,v26.4s,v13.4h + saddw2 v27.4s,v27.4s,v13.8h + saddw v28.4s,v28.4s,v14.4h + saddw2 v29.4s,v29.4s,v14.8h + saddw v30.4s,v30.4s,v15.4h + saddw2 v31.4s,v31.4s,v15.8h + smull v12.8h,v0.8b,v4.8b + smull2 v13.8h,v0.16b,v4.16b + ldr q4,[x10,x5] // A0 iter 2 + smull v14.8h,v0.8b,v6.8b + smull2 v15.8h,v0.16b,v6.16b + ld1 {v0.16b},[x1] // filter iter 2 + ldr q6,[x11,x5] // A1 iter 2 + smlal v12.8h,v1.8b,v2.8b + smlal2 v13.8h,v1.16b,v2.16b + ldr q2,[x12,x5] // A2 iter 2 + smlal v14.8h,v1.8b,v3.8b + smlal2 v15.8h,v1.16b,v3.16b + ldr q3,[x13,x5] // A3 iter 2 + saddw v16.4s,v16.4s,v12.4h + saddw2 v17.4s,v17.4s,v12.8h + saddw v18.4s,v18.4s,v13.4h + saddw2 v19.4s,v19.4s,v13.8h + saddw v20.4s,v20.4s,v14.4h + saddw2 v21.4s,v21.4s,v14.8h + saddw v22.4s,v22.4s,v15.4h + saddw2 v23.4s,v23.4s,v15.8h + +EpilogueC16P1 + // + // Loop epilogue (process last single pixel) mixed with loading of dequantization params + // + ldr w9,[sp,#ConvSymDepthwiseKernelFrame_KernelFlags] + ldr x12,[x8,#ConvSymDepthwisePostProcessParams_Scale] + smull v12.8h,v0.8b,v4.8b + smull2 v13.8h,v0.16b,v4.16b + ldr w15,[x8,#ConvSymDepthwisePostProcessParams_ZeroPoint] + smull v14.8h,v0.8b,v6.8b + smull2 v15.8h,v0.16b,v6.16b + saddw v24.4s,v24.4s,v12.4h + saddw2 v25.4s,v25.4s,v12.8h + saddw v26.4s,v26.4s,v13.4h + saddw2 v27.4s,v27.4s,v13.8h + saddw v28.4s,v28.4s,v14.4h + saddw2 v29.4s,v29.4s,v14.8h + saddw v30.4s,v30.4s,v15.4h + saddw2 v31.4s,v31.4s,v15.8h + smull v12.8h,v0.8b,v2.8b + smull2 v13.8h,v0.16b,v2.16b + smull v14.8h,v0.8b,v3.8b + smull2 v15.8h,v0.16b,v3.16b + tst w9,#MLAS_CONV_SYM_FLAG_PER_CHANNEL_SCALE + ld1r {v4.4s},[x12] // load scale val + b.eq SkipScaleVecLoad + ldp q4,q5,[x12],#32 // load scale vector if per channel + ldp q6,q3,[x12] +SkipScaleVecLoad + saddw v16.4s,v16.4s,v12.4h + saddw2 v17.4s,v17.4s,v12.8h + saddw v18.4s,v18.4s,v13.4h + saddw2 v19.4s,v19.4s,v13.8h + saddw v20.4s,v20.4s,v14.4h + saddw2 v21.4s,v21.4s,v14.8h + saddw v22.4s,v22.4s,v15.4h + saddw2 v23.4s,v23.4s,v15.8h + +Dequantization + scvtf v24.4s,v24.4s // convert to float + scvtf v25.4s,v25.4s + scvtf v26.4s,v26.4s + scvtf v27.4s,v27.4s + scvtf v28.4s,v28.4s + scvtf v29.4s,v29.4s + scvtf v30.4s,v30.4s + scvtf v31.4s,v31.4s + scvtf v16.4s,v16.4s + scvtf v17.4s,v17.4s + scvtf v18.4s,v18.4s + scvtf v19.4s,v19.4s + scvtf v20.4s,v20.4s + scvtf v21.4s,v21.4s + scvtf v22.4s,v22.4s + scvtf v23.4s,v23.4s + b.ne SkipScaleBroadcast + mov v5.16b,v4.16b // broadcast scale val if not per channel + mov v6.16b,v4.16b + mov v3.16b,v4.16b +SkipScaleBroadcast + fmul v24.4s,v24.4s,v4.4s // multiply by scale + fmul v25.4s,v25.4s,v5.4s + fmul v26.4s,v26.4s,v6.4s + fmul v27.4s,v27.4s,v3.4s + fmul v28.4s,v28.4s,v4.4s + fmul v29.4s,v29.4s,v5.4s + fmul v30.4s,v30.4s,v6.4s + fmul v31.4s,v31.4s,v3.4s + fmul v16.4s,v16.4s,v4.4s + fmul v17.4s,v17.4s,v5.4s + fmul v18.4s,v18.4s,v6.4s + fmul v19.4s,v19.4s,v3.4s + fmul v20.4s,v20.4s,v4.4s + fmul v21.4s,v21.4s,v5.4s + fmul v22.4s,v22.4s,v6.4s + fmul v23.4s,v23.4s,v3.4s + fcvtns v24.4s,v24.4s // convert to int + fcvtns v25.4s,v25.4s + fcvtns v26.4s,v26.4s + fcvtns v27.4s,v27.4s + fcvtns v28.4s,v28.4s + fcvtns v29.4s,v29.4s + fcvtns v30.4s,v30.4s + fcvtns v31.4s,v31.4s + fcvtns v16.4s,v16.4s + fcvtns v17.4s,v17.4s + fcvtns v18.4s,v18.4s + fcvtns v19.4s,v19.4s + fcvtns v20.4s,v20.4s + fcvtns v21.4s,v21.4s + fcvtns v22.4s,v22.4s + fcvtns v23.4s,v23.4s + sqxtn v24.4h,v24.4s // shorten to int16 + sqxtn v26.4h,v26.4s + sqxtn2 v24.8h,v25.4s + sqxtn2 v26.8h,v27.4s + sqxtn v28.4h,v28.4s + sqxtn v30.4h,v30.4s + sqxtn2 v28.8h,v29.4s + sqxtn2 v30.8h,v31.4s + dup v0.8h,w15 + sqxtn v16.4h,v16.4s + sqxtn v18.4h,v18.4s + sqxtn2 v16.8h,v17.4s + sqxtn2 v18.8h,v19.4s + sqxtn v20.4h,v20.4s + sqxtn v22.4h,v22.4s + sqxtn2 v20.8h,v21.4s + sqxtn2 v22.8h,v23.4s + sqadd v24.8h,v24.8h,v0.8h // add zero point + sqadd v26.8h,v26.8h,v0.8h + sqadd v28.8h,v28.8h,v0.8h + sqadd v30.8h,v30.8h,v0.8h + sqadd v16.8h,v16.8h,v0.8h + sqadd v18.8h,v18.8h,v0.8h + sqadd v20.8h,v20.8h,v0.8h + sqadd v22.8h,v22.8h,v0.8h + sqxtn v24.8b,v24.8h // shorten to int8 + sqxtn2 v24.16b,v26.8h + sqxtn v28.8b,v28.8h + sqxtn2 v28.16b,v30.8h + sqxtn v16.8b,v16.8h + sqxtn2 v16.16b,v18.8h + sqxtn v20.8b,v20.8h + sqxtn2 v20.16b,v22.8h + cmp x7,2 // OutputCount < 2 ? + st1 {v24.16b},[x2],x4 + b.lo ExitKernel // exit if OutputCount < 2 + st1 {v28.16b},[x2],x4 + b.ls ExitKernel // exit if OutputCount <=2 + cmp x7,4 // OutputCount < 4 ? + st1 {v16.16b},[x2],x4 + b.lo ExitKernel // exit if OutputCount < 4 + str q20,[x2] + +ExitKernel + EPILOG_RESTORE_REG_PAIR d14,d15,#16 + EPILOG_RESTORE_REG_PAIR d12,d13,#ConvSymDepthwiseKernelFrame_SavedRegisters! + EPILOG_RETURN + +Process8Channels + cmp x3,1 + b.eq ProcC8P1 + + ldr x12,[x0],#8 // x12 -> A0 iter 1 + ldr x13,[x9],#8 // x13 -> A1 iter 1 + ld1 {v0.8b},[x1],x4 // filter iter 0 + ld1 {v1.8b},[x1],x4 // filter iter 1 + ldr d4,[x10,x5] // A0 iter 0 + ldr x10,[x14],#8 // x10 -> A2 iter 0 + mov v28.16b,v24.16b + ldr d6,[x11,x5] // A1 iter 0 + mov v29.16b,v25.16b + ldr x11,[x15],#8 // x11 -> A3 iter 0 + mov v16.16b,v24.16b + ldr d2,[x12,x5] // A0 iter 1 + mov v17.16b,v25.16b + ldr x12,[x14],#8 // x12 -> A2 iter 1 + subs x3,x3,2 // decrement input blocks remaining + ldr d3,[x13,x5] // A1 iter 1 + mov v20.16b,v24.16b + ldr x13,[x15],#8 // x13 -> A3 iter 1 + mov v21.16b,v25.16b + +BlockLoopC8 + // + // Process 2 pixels, and load next two pixels + // + smull v12.8h,v0.8b,v4.8b + ldr d4,[x10,x5] // A2 iter 0 + smull v14.8h,v0.8b,v6.8b + b.eq EpilogueC8P2 + ldr x10,[x0],#8 // x10 -> A0 iter 2 + ldr d6,[x11,x5] // A3 iter 0 + cmp x3,1 + smlal v12.8h,v1.8b,v2.8b + ldr x11,[x9],#8 // x11 -> A1 iter 2 + smlal v14.8h,v1.8b,v3.8b + ldr d2,[x12,x5] // A2 iter 1 + b.eq EpilogueC8P3 // 3 pixel remains + ldr d3,[x13,x5] // A3 iter 1 + saddw v24.4s,v24.4s,v12.4h + ldr x12,[x0],#8 // x12 -> A0 iter 3 + saddw2 v25.4s,v25.4s,v12.8h + ldr x13,[x9],#8 // x13 -> A1 iter 3 + saddw v28.4s,v28.4s,v14.4h + saddw2 v29.4s,v29.4s,v14.8h + subs x3,x3,2 // decrement input blocks remaining + smull v12.8h,v0.8b,v4.8b + ldr d4,[x10,x5] // A0 iter 2 + smull v14.8h,v0.8b,v6.8b + ldr x10,[x14],#8 // x10 -> A2 iter 2 + ldr d6,[x11,x5] // A1 iter 2 + ld1 {v0.8b},[x1],x4 // filter iter 2 + smlal v12.8h,v1.8b,v2.8b + ldr x11,[x15],#8 // x11 -> A3 iter 2 + ldr d2,[x12,x5] // A0 iter 3 + smlal v14.8h,v1.8b,v3.8b + ldr x12,[x14],#8 // x12 -> A2 iter 3 + saddw v16.4s,v16.4s,v12.4h + ldr d3,[x13,x5] // A1 iter 3 + saddw2 v17.4s,v17.4s,v12.8h + ld1 {v1.8b},[x1],x4 // filter iter 3 + saddw v20.4s,v20.4s,v14.4h + ldr x13,[x15],#8 // x13 -> A3 iter 3 + saddw2 v21.4s,v21.4s,v14.8h + b BlockLoopC8 + +EpilogueC8P2 + // + // Loop epilogue (process last 2 pixels) mixed + // with loading of dequantization params + // + ldr d6,[x11,x5] // A3 iter 0 + smlal v12.8h,v1.8b,v2.8b + ldr d2,[x12,x5] // A2 iter 1 + smlal v14.8h,v1.8b,v3.8b + ldr d3,[x13,x5] // A3 iter 1 + saddw v24.4s,v24.4s,v12.4h + saddw2 v25.4s,v25.4s,v12.8h + saddw v28.4s,v28.4s,v14.4h + saddw2 v29.4s,v29.4s,v14.8h + ldr w9,[sp,#ConvSymDepthwiseKernelFrame_KernelFlags] + smull v12.8h,v0.8b,v4.8b + ldr x12,[x8,#ConvSymDepthwisePostProcessParams_Scale] + smull v14.8h,v0.8b,v6.8b + ldr w15,[x8,#ConvSymDepthwisePostProcessParams_ZeroPoint] + smlal v12.8h,v1.8b,v2.8b + smlal v14.8h,v1.8b,v3.8b + tst w9,#MLAS_CONV_SYM_FLAG_PER_CHANNEL_SCALE + ld1r {v4.4s},[x12] // load scale val + b.eq SkipScaleVecLoad2C8 + ldp q4,q5,[x12],#32 // load scale vector if per channel +SkipScaleVecLoad2C8 + saddw v16.4s,v16.4s,v12.4h + saddw2 v17.4s,v17.4s,v12.8h + saddw v20.4s,v20.4s,v14.4h + saddw2 v21.4s,v21.4s,v14.8h + b DequantC8 + +ProcC8P1 + // + // Channel 8 kernel size 1 + // TODO!! is this reachable at all? + // + ldr x12,[x14],#8 // x12 -> A2 + mov v28.16b,v24.16b + ldr x13,[x15],#8 // x13 -> A3 + mov v29.16b,v25.16b + ld1 {v0.8b},[x1] + mov v16.16b,v24.16b + ldr d4,[x10,x5] + mov v17.16b,v25.16b + ldr d6,[x11,x5] + mov v20.16b,v24.16b + ldr d2,[x12,x5] + subs x3,x3,2 // decrement input blocks remaining + ldr d3,[x13,x5] + mov v21.16b,v25.16b + b EpilogueC8P1 + +EpilogueC8P3 + // + // Loop epilogue (process 2 of last 3 pixels) + // + ldr x12,[x14],#8 // x12 -> A2 iter 2 + ldr d3,[x13,x5] // A3 iter 1 + saddw v24.4s,v24.4s,v12.4h + saddw2 v25.4s,v25.4s,v12.8h + ldr x13,[x15],#8 // x13 -> A3 iter 2 + saddw v28.4s,v28.4s,v14.4h + saddw2 v29.4s,v29.4s,v14.8h + smull v12.8h,v0.8b,v4.8b + ldr d4,[x10,x5] // A0 iter 2 + smull v14.8h,v0.8b,v6.8b + ld1 {v0.8b},[x1] // filter iter 2 + ldr d6,[x11,x5] // A1 iter 2 + smlal v12.8h,v1.8b,v2.8b + ldr d2,[x12,x5] // A2 iter 2 + smlal v14.8h,v1.8b,v3.8b + ldr d3,[x13,x5] // A3 iter 2 + saddw v16.4s,v16.4s,v12.4h + saddw2 v17.4s,v17.4s,v12.8h + saddw v20.4s,v20.4s,v14.4h + saddw2 v21.4s,v21.4s,v14.8h + +EpilogueC8P1 + // + // Loop epilogue (process last single pixel) mixed with loading of dequantization params + // + ldr w9,[sp,#ConvSymDepthwiseKernelFrame_KernelFlags] + ldr x12,[x8,#ConvSymDepthwisePostProcessParams_Scale] + smull v12.8h,v0.8b,v4.8b + ldr w15,[x8,#ConvSymDepthwisePostProcessParams_ZeroPoint] + smull v14.8h,v0.8b,v6.8b + saddw v24.4s,v24.4s,v12.4h + saddw2 v25.4s,v25.4s,v12.8h + saddw v28.4s,v28.4s,v14.4h + saddw2 v29.4s,v29.4s,v14.8h + smull v12.8h,v0.8b,v2.8b + smull v14.8h,v0.8b,v3.8b + tst w9,#MLAS_CONV_SYM_FLAG_PER_CHANNEL_SCALE + ld1r {v4.4s},[x12] // load scale val + b.eq SkipScaleVecLoadC8 + ldp q4,q5,[x12] // load scale vector if per channel +SkipScaleVecLoadC8 + saddw v16.4s,v16.4s,v12.4h + saddw2 v17.4s,v17.4s,v12.8h + saddw v20.4s,v20.4s,v14.4h + saddw2 v21.4s,v21.4s,v14.8h + +DequantC8 + scvtf v24.4s,v24.4s // convert to float + scvtf v25.4s,v25.4s + scvtf v28.4s,v28.4s + scvtf v29.4s,v29.4s + scvtf v16.4s,v16.4s + scvtf v17.4s,v17.4s + scvtf v20.4s,v20.4s + scvtf v21.4s,v21.4s + b.ne SkipScaleBroadcastC8 + mov v5.16b,v4.16b // broadcast scale val if not per channel +SkipScaleBroadcastC8 + fmul v24.4s,v24.4s,v4.4s // multiply by scale + fmul v25.4s,v25.4s,v5.4s + fmul v28.4s,v28.4s,v4.4s + fmul v29.4s,v29.4s,v5.4s + fmul v16.4s,v16.4s,v4.4s + fmul v17.4s,v17.4s,v5.4s + fmul v20.4s,v20.4s,v4.4s + fmul v21.4s,v21.4s,v5.4s + fcvtns v24.4s,v24.4s // convert to int + fcvtns v25.4s,v25.4s + fcvtns v28.4s,v28.4s + fcvtns v29.4s,v29.4s + fcvtns v16.4s,v16.4s + fcvtns v17.4s,v17.4s + fcvtns v20.4s,v20.4s + fcvtns v21.4s,v21.4s + dup v0.8h,w15 + sqxtn v24.4h,v24.4s // shorten to int16 + sqxtn2 v24.8h,v25.4s + sqxtn v28.4h,v28.4s + sqxtn2 v28.8h,v29.4s + sqxtn v16.4h,v16.4s + sqxtn2 v16.8h,v17.4s + sqxtn v20.4h,v20.4s + sqxtn2 v20.8h,v21.4s + sqadd v24.8h,v24.8h,v0.8h // add zero point + sqadd v28.8h,v28.8h,v0.8h + sqadd v16.8h,v16.8h,v0.8h + sqadd v20.8h,v20.8h,v0.8h + sqxtn v24.8b,v24.8h // shorten to int8 + sqxtn v28.8b,v28.8h + sqxtn v16.8b,v16.8h + sqxtn v20.8b,v20.8h + cmp x7,2 // OutputCount < 2 ? + st1 {v24.8b},[x2],x4 + b.lo ExitKernel // exit if OutputCount < 2 + st1 {v28.8b},[x2],x4 + b.ls ExitKernel // exit if OutputCount <=2 + cmp x7,4 // OutputCount < 4 ? + st1 {v16.8b},[x2],x4 + b.lo ExitKernel // exit if OutputCount < 4 + str d20,[x2] + b ExitKernel + NESTED_END MlasConvSymDepthwiseS8KernelNeon + + END diff --git a/onnxruntime/core/mlas/lib/arm64/DepthwiseConvSymKernelNeon.asm b/onnxruntime/core/mlas/lib/arm64/DepthwiseQConvSymU8KernelNeon.asm similarity index 99% rename from onnxruntime/core/mlas/lib/arm64/DepthwiseConvSymKernelNeon.asm rename to onnxruntime/core/mlas/lib/arm64/DepthwiseQConvSymU8KernelNeon.asm index a18324ec97500..e9f75f1be5bfd 100644 --- a/onnxruntime/core/mlas/lib/arm64/DepthwiseConvSymKernelNeon.asm +++ b/onnxruntime/core/mlas/lib/arm64/DepthwiseQConvSymU8KernelNeon.asm @@ -6,7 +6,7 @@ Licensed under the MIT License. Module Name: - DepthwiseConvSymKernelNeon.asm + DepthwiseQConvSymU8KernelNeon.asm Abstract: @@ -78,7 +78,7 @@ Return Value: --*/ - NESTED_ENTRY MlasConvSymDepthwiseKernelNeon + NESTED_ENTRY MlasConvSymDepthwiseU8KernelNeon PROLOG_SAVE_REG_PAIR d8,d9,#-64! PROLOG_NOP ldr x8,[sp,#ConvSymDepthwiseKernelFrame_PostProcessParams] @@ -740,6 +740,6 @@ SkipScaleBroadcastC8 b.lo ExitKernel // exit if OutputCount < 4 str d20,[x2] b ExitKernel - NESTED_END MlasConvSymDepthwiseKernelNeon + NESTED_END MlasConvSymDepthwiseU8KernelNeon END diff --git a/onnxruntime/core/mlas/lib/convsym.cpp b/onnxruntime/core/mlas/lib/convsym.cpp index b5c82044eaecc..9fa580aea6a77 100644 --- a/onnxruntime/core/mlas/lib/convsym.cpp +++ b/onnxruntime/core/mlas/lib/convsym.cpp @@ -51,6 +51,21 @@ void unsigned KernelFlags ); +// +// Processor for common kernel sized (e.g. 3x3, 5x5) +// +typedef +void +(MLASCALL MLAS_SYMM_QCONV_DEPTHWISE_FIXFILTER_PROC)( + void const* const* InputIndirection, + int8_t const* Filter, + size_t Channels, + void* Output, + size_t OutputCount, + MLAS_CONV_SYM_POST_PROCESS_PARAMS const* PostProcessParams, + unsigned KernelFlags + ); + extern "C" { @@ -64,22 +79,73 @@ extern "C" { MLAS_CONV_SYM_KERNEL MlasConvSymKernelAvx512Vnni; MLAS_CONV_SYM_DEPTHWISE_KERNEL MlasConvSymDepthwiseKernelAvx512Vnni; #elif defined(MLAS_TARGET_ARM64) - MLAS_CONV_SYM_KERNEL MlasConvSymKernelNeon; - MLAS_CONV_SYM_KERNEL MlasConvSymKernelNeonDot; - MLAS_CONV_SYM_DEPTHWISE_KERNEL MlasConvSymDepthwiseKernelNeon; - MLAS_CONV_SYM_DEPTHWISE_ROUTINE_KERNELSIZE MlasConvSymDepthwiseKernelSize9Arm64; - MLAS_CONV_SYM_DEPTHWISE_ROUTINE_KERNELSIZE MlasConvSymDepthwiseKernelSize25Arm; -#endif - -} + MLAS_CONV_SYM_KERNEL MlasConvSymS8KernelNeon; + MLAS_CONV_SYM_KERNEL MlasConvSymU8KernelNeon; + MLAS_CONV_SYM_KERNEL MlasConvSymS8KernelDot; + MLAS_CONV_SYM_KERNEL MlasConvSymU8KernelDot; + MLAS_CONV_SYM_DEPTHWISE_KERNEL MlasConvSymDepthwiseU8KernelNeon; + MLAS_CONV_SYM_DEPTHWISE_KERNEL MlasConvSymDepthwiseS8KernelNeon; // +// Specialized depthwise conv kernels for 3x3 and 5x5 filters // -// + +void +MLASCALL +MlasConvSymDepthwiseKernelSize9Arm64U8S8( + void const* const* InputIndirection, + int8_t const* Filter, + size_t Channels, + void* Output, + size_t OutputCount, + MLAS_CONV_SYM_POST_PROCESS_PARAMS const* PostProcessParams, + unsigned KernelFlags + ); + +void +MLASCALL +MlasConvSymDepthwiseKernelSize9Arm64S8S8( + void const* const* InputIndirection, + int8_t const* Filter, + size_t Channels, + void* Output, + size_t OutputCount, + MLAS_CONV_SYM_POST_PROCESS_PARAMS const* PostProcessParams, + unsigned KernelFlags + ); + +void +MLASCALL +MlasConvSymDepthwiseKernelSize25ArmS8S8( + void const* const* InputIndirection, + int8_t const* Filter, + size_t Channels, + void* Output, + size_t OutputCount, + MLAS_CONV_SYM_POST_PROCESS_PARAMS const* PostProcessParams, + unsigned KernelFlags + ); + +void +MLASCALL +MlasConvSymDepthwiseKernelSize25ArmU8S8( + void const* const* InputIndirection, + int8_t const* Filter, + size_t Channels, + void* Output, + size_t OutputCount, + MLAS_CONV_SYM_POST_PROCESS_PARAMS const* PostProcessParams, + unsigned KernelFlags + ); + +#endif +} struct MLAS_CONV_SYM_DISPATCH { MLAS_CONV_SYM_KERNEL* Kernel; MLAS_CONV_SYM_DEPTHWISE_KERNEL* DepthwiseKernel; + MLAS_SYMM_QCONV_DEPTHWISE_FIXFILTER_PROC* Depthwise3x3Proc; + MLAS_SYMM_QCONV_DEPTHWISE_FIXFILTER_PROC* Depthwise5x5Proc; uint8_t FilterInputChannelPackCount; uint8_t FilterOutputChannelPackCount; uint8_t KernelChannelCount; @@ -96,6 +162,8 @@ struct MLAS_CONV_SYM_DISPATCH { const MLAS_CONV_SYM_DISPATCH MlasConvSymDispatchAvx2 = { MlasConvSymKernelAvx2, MlasConvSymDepthwiseKernelAvx2, + nullptr, + nullptr, 4, // FilterInputChannelPackCount 16, // FilterOutputChannelPackCount 16, // KernelChannelCount @@ -110,6 +178,8 @@ const MLAS_CONV_SYM_DISPATCH MlasConvSymDispatchAvx2 = { const MLAS_CONV_SYM_DISPATCH MlasConvSymDispatchAvxVnni = { MlasConvSymKernelAvxVnni, MlasConvSymDepthwiseKernelAvxVnni, + nullptr, + nullptr, 4, // FilterInputChannelPackCount 16, // FilterOutputChannelPackCount 16, // KernelChannelCount @@ -126,6 +196,8 @@ const MLAS_CONV_SYM_DISPATCH MlasConvSymDispatchAvxVnni = { const MLAS_CONV_SYM_DISPATCH MlasConvSymDispatchAvx512Core = { MlasConvSymKernelAvx512Core, MlasConvSymDepthwiseKernelAvx512Core, + nullptr, + nullptr, 4, // FilterInputChannelPackCount 16, // FilterOutputChannelPackCount 64, // KernelChannelCount @@ -140,6 +212,8 @@ const MLAS_CONV_SYM_DISPATCH MlasConvSymDispatchAvx512Core = { const MLAS_CONV_SYM_DISPATCH MlasConvSymDispatchAvx512Vnni = { MlasConvSymKernelAvx512Vnni, MlasConvSymDepthwiseKernelAvx512Vnni, + nullptr, + nullptr, 4, // FilterInputChannelPackCount 16, // FilterOutputChannelPackCount 64, // KernelChannelCount @@ -154,9 +228,11 @@ const MLAS_CONV_SYM_DISPATCH MlasConvSymDispatchAvx512Vnni = { #endif // ORT_MINIMAL_BUILD #elif defined(MLAS_TARGET_ARM64) -const MLAS_CONV_SYM_DISPATCH MlasConvSymDispatchNeon = { - MlasConvSymKernelNeon, - MlasConvSymDepthwiseKernelNeon, +const MLAS_CONV_SYM_DISPATCH MlasConvSymU8DispatchNeon = { + MlasConvSymU8KernelNeon, + MlasConvSymDepthwiseU8KernelNeon, + MlasConvSymDepthwiseKernelSize9Arm64U8S8, + MlasConvSymDepthwiseKernelSize25ArmU8S8, 8, // FilterInputChannelPackCount 8, // FilterOutputChannelPackCount 8, // KernelChannelCount @@ -168,20 +244,53 @@ const MLAS_CONV_SYM_DISPATCH MlasConvSymDispatchNeon = { true }; -const MLAS_CONV_SYM_DISPATCH MlasConvSymDispatchDot = { - MlasConvSymKernelNeonDot, - MlasConvSymDepthwiseKernelNeon, +const MLAS_CONV_SYM_DISPATCH MlasConvSymS8DispatchNeon = { + MlasConvSymS8KernelNeon, + MlasConvSymDepthwiseS8KernelNeon, + MlasConvSymDepthwiseKernelSize9Arm64S8S8, + MlasConvSymDepthwiseKernelSize25ArmS8S8, + 8, // FilterInputChannelPackCount + 8, // FilterOutputChannelPackCount + 8, // KernelChannelCount + 2, // KernelOutputCount + 8, // KernelInputChannelAlignment + 8, // KernelOutputChannelAlignment + 16, // KernelDepthwiseChannelCount + 4, // KernelDepthwiseOutputCount + false +}; + +const MLAS_CONV_SYM_DISPATCH MlasConvSymU8DispatchDot = { + MlasConvSymU8KernelDot, + MlasConvSymDepthwiseU8KernelNeon, + MlasConvSymDepthwiseKernelSize9Arm64U8S8, + MlasConvSymDepthwiseKernelSize25ArmU8S8, 4, // FilterInputChannelPackCount 16, // FilterOutputChannelPackCount 0, // KernelChannelCount 4, // KernelOutputCount 4, // KernelInputChannelAlignment - 1, // KernelOutputChannelAlignment + 16, // KernelOutputChannelAlignment 16, // KernelDepthwiseChannelCount 4, // KernelDepthwiseOutputCount true }; +const MLAS_CONV_SYM_DISPATCH MlasConvSymS8DispatchDot = { + MlasConvSymS8KernelDot, + MlasConvSymDepthwiseS8KernelNeon, + MlasConvSymDepthwiseKernelSize9Arm64S8S8, + MlasConvSymDepthwiseKernelSize25ArmS8S8, + 4, // FilterInputChannelPackCount + 16, // FilterOutputChannelPackCount + 0, // KernelChannelCount + 4, // KernelOutputCount + 4, // KernelInputChannelAlignment + 16, // KernelOutputChannelAlignment + 16, // KernelDepthwiseChannelCount + 4, // KernelDepthwiseOutputCount + false +}; #endif // MLAS_TARGET_AMD64 MLAS_FORCEINLINE @@ -249,7 +358,8 @@ MlasConvSymPackWSize( #ifdef MLAS_TARGET_ARM64 // TODO!! remove this for functional testing! // TODO!! is there a way to know whether this is called by tests? - if (InputChannels < 128) { + if (KernelSize <= 1) return 0; + if (InputChannels < 64) { // Shallow indirect conv runs slower. // TODO!! for DOT arch, threshold should be 32 for better perf return 0; @@ -443,27 +553,23 @@ MlasConvSymDepthwise( MlasConvSymSetOutputZeroPoint(PostProcessParams, Params.OutputZeroPoint, Params.InputIsSigned); -#if defined(MLAS_TARGET_ARM64) - - if ((Params.KernelSize == 9 || Params.KernelSize == 25) && (Params.OutputChannels & 15) == 0) { + if ((Params.OutputChannels & 15) == 0) { PostProcessParams.Bias = Params.Bias; PostProcessParams.Scale = Params.Scale; - if (Params.KernelSize == 9) { - MlasConvSymDepthwiseKernelSize9Arm64( - Params.InputIndirection, (int8_t const*)Params.Filter, Params.OutputChannels, - Params.Output, Params.OutputCount, &PostProcessParams, KernelFlags, Params.InputIsSigned - ); - } else { - MlasConvSymDepthwiseKernelSize25Arm( - Params.InputIndirection, (int8_t const*)Params.Filter, Params.OutputChannels, - Params.Output, Params.OutputCount, &PostProcessParams, KernelFlags, Params.InputIsSigned - ); + if (ConvSymDispatch->Depthwise3x3Proc && Params.KernelSize == 9) { + ConvSymDispatch->Depthwise3x3Proc(Params.InputIndirection, (int8_t const*)Params.Filter, + Params.OutputChannels, Params.Output, + Params.OutputCount, &PostProcessParams, KernelFlags); + return; + } + if (ConvSymDispatch->Depthwise5x5Proc && Params.KernelSize == 25) { + ConvSymDispatch->Depthwise5x5Proc(Params.InputIndirection, (int8_t const*)Params.Filter, + Params.OutputChannels, Params.Output, + Params.OutputCount, &PostProcessParams, KernelFlags); + return; } - return; } -#endif - const size_t KernelChannelCount = ConvSymDispatch->KernelDepthwiseChannelCount; const size_t KernelOutputCount = ConvSymDispatch->KernelDepthwiseOutputCount; diff --git a/onnxruntime/core/mlas/lib/mlasi.h b/onnxruntime/core/mlas/lib/mlasi.h index d7e4fd55c7f16..ca9fdef577f7b 100644 --- a/onnxruntime/core/mlas/lib/mlasi.h +++ b/onnxruntime/core/mlas/lib/mlasi.h @@ -714,8 +714,10 @@ extern const MLAS_CONV_SYM_DISPATCH MlasConvSymDispatchAvx2; extern const MLAS_CONV_SYM_DISPATCH MlasConvSymDispatchAvxVnni; extern const MLAS_CONV_SYM_DISPATCH MlasConvSymDispatchAvx512Core; extern const MLAS_CONV_SYM_DISPATCH MlasConvSymDispatchAvx512Vnni; -extern const MLAS_CONV_SYM_DISPATCH MlasConvSymDispatchNeon; -extern const MLAS_CONV_SYM_DISPATCH MlasConvSymDispatchDot; +extern const MLAS_CONV_SYM_DISPATCH MlasConvSymU8DispatchNeon; +extern const MLAS_CONV_SYM_DISPATCH MlasConvSymS8DispatchNeon; +extern const MLAS_CONV_SYM_DISPATCH MlasConvSymU8DispatchDot; +extern const MLAS_CONV_SYM_DISPATCH MlasConvSymS8DispatchDot; // // Quantized depthwise convolution kernels. @@ -768,19 +770,6 @@ struct MLAS_CONV_SYM_POST_PROCESS_PARAMS { int32_t OutputZeroPoint; }; -typedef -void -(MLASCALL MLAS_CONV_SYM_DEPTHWISE_ROUTINE_KERNELSIZE)( - void const* const* InputIndirection, - int8_t const* Filter, - size_t Channels, - void* Output, - size_t OutputCount, - MLAS_CONV_SYM_POST_PROCESS_PARAMS const* PostProcessParams, - unsigned KernelFlags, - bool InputIsSigned - ); - // // Environment information class. // diff --git a/onnxruntime/core/mlas/lib/platform.cpp b/onnxruntime/core/mlas/lib/platform.cpp index 675e6625d5e9f..6101355b21b21 100644 --- a/onnxruntime/core/mlas/lib/platform.cpp +++ b/onnxruntime/core/mlas/lib/platform.cpp @@ -352,8 +352,9 @@ Return Value: #if defined(MLAS_TARGET_ARM64) this->GemmU8X8Dispatch = &MlasGemmU8X8DispatchNeon; - this->ConvSymU8S8Dispatch = &MlasConvSymDispatchNeon; this->SymmQgemmDispatch = &MlasSymmQgemmS8DispatchNeon; + this->ConvSymU8S8Dispatch = &MlasConvSymU8DispatchNeon; + this->ConvSymS8S8Dispatch = &MlasConvSymS8DispatchNeon; // // Check if the processor supports ASIMD dot product instructions. @@ -371,8 +372,9 @@ Return Value: if (HasDotProductInstructions) { this->GemmU8X8Dispatch = &MlasGemmU8X8DispatchUdot; - this->ConvSymU8S8Dispatch = &MlasConvSymDispatchDot; this->SymmQgemmDispatch = &MlasSymmQgemmS8DispatchSdot; + this->ConvSymU8S8Dispatch = &MlasConvSymU8DispatchDot; + this->ConvSymS8S8Dispatch = &MlasConvSymS8DispatchDot; } #endif // MLAS_TARGET_ARM64 diff --git a/onnxruntime/core/mlas/lib/qdwconv_kernelsize.cpp b/onnxruntime/core/mlas/lib/qdwconv_kernelsize.cpp index ce7f30a1ad275..4985f91b64f36 100644 --- a/onnxruntime/core/mlas/lib/qdwconv_kernelsize.cpp +++ b/onnxruntime/core/mlas/lib/qdwconv_kernelsize.cpp @@ -43,20 +43,24 @@ Module Name: #include "mlasi.h" +extern "C" { + #if defined(MLAS_TARGET_ARM64) -static void +MLASCALL MlasConvSymDepthwiseKernelSize25ArmU8S8( - uint8_t const* const* InputIndirection, + void const* const* InputIndirection, int8_t const* Filter, size_t Channels, - uint8_t* Output, + void* Output, size_t OutputCount, MLAS_CONV_SYM_POST_PROCESS_PARAMS const* PostProcessParams, unsigned KernelFlags ) { + uint8_t const* const* IndirectBuf = (uint8_t const* const*)InputIndirection; + uint8_t* OutBuf = (uint8_t*)Output; const uint8x16_t vu128 = vdupq_n_u8(128); const int16x8_t voutput_zero_point = vld1q_dup_s16((int16_t const*)&PostProcessParams->OutputZeroPoint); float32x4_t vscale_0123, vscale_4567, vscale_89AB, vscale_CDEF; @@ -64,35 +68,35 @@ MlasConvSymDepthwiseKernelSize25ArmU8S8( // Init them anyway due to some compiler will generate uninitialized warnings. vscale_0123 = vscale_4567 = vscale_89AB = vscale_CDEF = vld1q_dup_f32(PostProcessParams->Scale); while (OutputCount-- > 0) { - const uint8_t* i00 = InputIndirection[0]; - const uint8_t* i01 = InputIndirection[1]; - const uint8_t* i02 = InputIndirection[2]; - const uint8_t* i03 = InputIndirection[3]; - const uint8_t* i04 = InputIndirection[4]; - const uint8_t* i05 = InputIndirection[5]; - const uint8_t* i06 = InputIndirection[6]; - const uint8_t* i07 = InputIndirection[7]; - const uint8_t* i08 = InputIndirection[8]; - const uint8_t* i09 = InputIndirection[9]; - - const uint8_t* i10 = InputIndirection[10]; - const uint8_t* i11 = InputIndirection[11]; - const uint8_t* i12 = InputIndirection[12]; - const uint8_t* i13 = InputIndirection[13]; - const uint8_t* i14 = InputIndirection[14]; - const uint8_t* i15 = InputIndirection[15]; - const uint8_t* i16 = InputIndirection[16]; - const uint8_t* i17 = InputIndirection[17]; - const uint8_t* i18 = InputIndirection[18]; - const uint8_t* i19 = InputIndirection[19]; - - const uint8_t* i20 = InputIndirection[20]; - const uint8_t* i21 = InputIndirection[21]; - const uint8_t* i22 = InputIndirection[22]; - const uint8_t* i23 = InputIndirection[23]; - const uint8_t* i24 = InputIndirection[24]; - - InputIndirection += 25; + const uint8_t* i00 = IndirectBuf[0]; + const uint8_t* i01 = IndirectBuf[1]; + const uint8_t* i02 = IndirectBuf[2]; + const uint8_t* i03 = IndirectBuf[3]; + const uint8_t* i04 = IndirectBuf[4]; + const uint8_t* i05 = IndirectBuf[5]; + const uint8_t* i06 = IndirectBuf[6]; + const uint8_t* i07 = IndirectBuf[7]; + const uint8_t* i08 = IndirectBuf[8]; + const uint8_t* i09 = IndirectBuf[9]; + + const uint8_t* i10 = IndirectBuf[10]; + const uint8_t* i11 = IndirectBuf[11]; + const uint8_t* i12 = IndirectBuf[12]; + const uint8_t* i13 = IndirectBuf[13]; + const uint8_t* i14 = IndirectBuf[14]; + const uint8_t* i15 = IndirectBuf[15]; + const uint8_t* i16 = IndirectBuf[16]; + const uint8_t* i17 = IndirectBuf[17]; + const uint8_t* i18 = IndirectBuf[18]; + const uint8_t* i19 = IndirectBuf[19]; + + const uint8_t* i20 = IndirectBuf[20]; + const uint8_t* i21 = IndirectBuf[21]; + const uint8_t* i22 = IndirectBuf[22]; + const uint8_t* i23 = IndirectBuf[23]; + const uint8_t* i24 = IndirectBuf[24]; + + IndirectBuf += 25; int32_t const* bias = PostProcessParams->Bias; float const* scale = PostProcessParams->Scale; for (size_t c = 0; c < Channels; c += 16) { @@ -322,8 +326,8 @@ MlasConvSymDepthwiseKernelSize25ArmU8S8( const int16x8_t vacc_89ABCDEF = vqaddq_s16(vqmovn_high_s32(vqmovn_s32(vacc_89AB), vacc_CDEF), voutput_zero_point); uint8x16_t vout = vqmovun_high_s16(vqmovun_s16(vacc_01234567), vacc_89ABCDEF); - vst1q_u8(Output, vout); - Output += 16; + vst1q_u8(OutBuf, vout); + OutBuf += 16; } } } @@ -331,50 +335,53 @@ MlasConvSymDepthwiseKernelSize25ArmU8S8( void MLASCALL MlasConvSymDepthwiseKernelSize25ArmS8S8( - int8_t const* const* InputIndirection, + void const* const* InputIndirection, int8_t const* Filter, size_t Channels, - int8_t* Output, + void* Output, size_t OutputCount, MLAS_CONV_SYM_POST_PROCESS_PARAMS const* PostProcessParams, unsigned KernelFlags ) { - const int16x8_t voutput_zero_point = vld1q_dup_s16((int16_t const*)&PostProcessParams->OutputZeroPoint); + int8_t const* const* IndirectBuf = (int8_t const* const*)InputIndirection; + int8_t* OutBuf = (int8_t*)Output; + const int16x8_t voutput_zero_point = + vld1q_dup_s16((int16_t const*)&PostProcessParams->OutputZeroPoint); float32x4_t vscale_0123, vscale_4567, vscale_89AB, vscale_CDEF; const bool is_per_channel = ((KernelFlags & MLAS_CONV_SYM_FLAG_PER_CHANNEL_SCALE) != 0); // Init them anyway due to some compiler will generate uninitialized warnings. vscale_0123 = vscale_4567 = vscale_89AB = vscale_CDEF = vld1q_dup_f32(PostProcessParams->Scale); while (OutputCount-- > 0) { - const int8_t* i00 = InputIndirection[0]; - const int8_t* i01 = InputIndirection[1]; - const int8_t* i02 = InputIndirection[2]; - const int8_t* i03 = InputIndirection[3]; - const int8_t* i04 = InputIndirection[4]; - const int8_t* i05 = InputIndirection[5]; - const int8_t* i06 = InputIndirection[6]; - const int8_t* i07 = InputIndirection[7]; - const int8_t* i08 = InputIndirection[8]; - const int8_t* i09 = InputIndirection[9]; - - const int8_t* i10 = InputIndirection[10]; - const int8_t* i11 = InputIndirection[11]; - const int8_t* i12 = InputIndirection[12]; - const int8_t* i13 = InputIndirection[13]; - const int8_t* i14 = InputIndirection[14]; - const int8_t* i15 = InputIndirection[15]; - const int8_t* i16 = InputIndirection[16]; - const int8_t* i17 = InputIndirection[17]; - const int8_t* i18 = InputIndirection[18]; - const int8_t* i19 = InputIndirection[19]; - - const int8_t* i20 = InputIndirection[20]; - const int8_t* i21 = InputIndirection[21]; - const int8_t* i22 = InputIndirection[22]; - const int8_t* i23 = InputIndirection[23]; - const int8_t* i24 = InputIndirection[24]; - - InputIndirection += 25; + const int8_t* i00 = IndirectBuf[0]; + const int8_t* i01 = IndirectBuf[1]; + const int8_t* i02 = IndirectBuf[2]; + const int8_t* i03 = IndirectBuf[3]; + const int8_t* i04 = IndirectBuf[4]; + const int8_t* i05 = IndirectBuf[5]; + const int8_t* i06 = IndirectBuf[6]; + const int8_t* i07 = IndirectBuf[7]; + const int8_t* i08 = IndirectBuf[8]; + const int8_t* i09 = IndirectBuf[9]; + + const int8_t* i10 = IndirectBuf[10]; + const int8_t* i11 = IndirectBuf[11]; + const int8_t* i12 = IndirectBuf[12]; + const int8_t* i13 = IndirectBuf[13]; + const int8_t* i14 = IndirectBuf[14]; + const int8_t* i15 = IndirectBuf[15]; + const int8_t* i16 = IndirectBuf[16]; + const int8_t* i17 = IndirectBuf[17]; + const int8_t* i18 = IndirectBuf[18]; + const int8_t* i19 = IndirectBuf[19]; + + const int8_t* i20 = IndirectBuf[20]; + const int8_t* i21 = IndirectBuf[21]; + const int8_t* i22 = IndirectBuf[22]; + const int8_t* i23 = IndirectBuf[23]; + const int8_t* i24 = IndirectBuf[24]; + + IndirectBuf += 25; int32_t const* bias = PostProcessParams->Bias; float const* scale = PostProcessParams->Scale; for (size_t c = 0; c < Channels; c += 16) { @@ -604,90 +611,11 @@ MlasConvSymDepthwiseKernelSize25ArmS8S8( const int16x8_t vacc_89ABCDEF = vqaddq_s16(vqmovn_high_s32(vqmovn_s32(vacc_89AB), vacc_CDEF), voutput_zero_point); int8x16_t vout = vqmovn_high_s16(vqmovn_s16(vacc_01234567), vacc_89ABCDEF); - vst1q_s8(Output, vout); - Output += 16; + vst1q_s8(OutBuf, vout); + OutBuf += 16; } } } -extern "C" { - -void -MLASCALL -MlasConvSymDepthwiseKernelSize25Arm( - void const* const* InputIndirection, - int8_t const* Filter, - size_t Channels, - void* Output, - size_t OutputCount, - MLAS_CONV_SYM_POST_PROCESS_PARAMS const* PostProcessParams, - unsigned KernelFlags, - bool IsInputSigned - ) -{ - if (IsInputSigned) { - MlasConvSymDepthwiseKernelSize25ArmS8S8( - (int8_t const* const*)InputIndirection, Filter, Channels, (int8_t*)Output, OutputCount, - PostProcessParams, KernelFlags - ); - } else { - MlasConvSymDepthwiseKernelSize25ArmU8S8( - (uint8_t const* const*)InputIndirection, Filter, Channels, (uint8_t*)Output, OutputCount, - PostProcessParams, KernelFlags - ); - } -} - -void -MLASCALL -MlasConvSymDepthwiseKernelSize9Arm64U8S8( - uint8_t const* const* InputIndirection, - int8_t const* Filter, - size_t Channels, - uint8_t* Output, - size_t OutputCount, - MLAS_CONV_SYM_POST_PROCESS_PARAMS const* PostProcessParams, - unsigned KernelFlags - ); - -void -MLASCALL -MlasConvSymDepthwiseKernelSize9Arm64S8S8( - int8_t const* const* InputIndirection, - int8_t const* Filter, - size_t Channels, - int8_t* Output, - size_t OutputCount, - MLAS_CONV_SYM_POST_PROCESS_PARAMS const* PostProcessParams, - unsigned KernelFlags - ); - -void -MLASCALL -MlasConvSymDepthwiseKernelSize9Arm64( - void const* const* InputIndirection, - int8_t const* Filter, - size_t Channels, - void* Output, - size_t OutputCount, - MLAS_CONV_SYM_POST_PROCESS_PARAMS const* PostProcessParams, - unsigned KernelFlags, - bool IsInputSigned - ) -{ - if (IsInputSigned) { - MlasConvSymDepthwiseKernelSize9Arm64S8S8( - (int8_t const* const*)InputIndirection, Filter, Channels, (int8_t*)Output, OutputCount, - PostProcessParams, KernelFlags - ); - } else { - MlasConvSymDepthwiseKernelSize9Arm64U8S8( - (uint8_t const* const*)InputIndirection, Filter, Channels, (uint8_t*)Output, OutputCount, - PostProcessParams, KernelFlags - ); - } -} - -} - #endif +} \ No newline at end of file diff --git a/onnxruntime/test/providers/cpu/nn/qlinearconv_op_test.cc b/onnxruntime/test/providers/cpu/nn/qlinearconv_op_test.cc index e68168efea734..01f3c0e027f65 100644 --- a/onnxruntime/test/providers/cpu/nn/qlinearconv_op_test.cc +++ b/onnxruntime/test/providers/cpu/nn/qlinearconv_op_test.cc @@ -1039,15 +1039,6 @@ TEST(QLinearConvTest, Conv2D_S8S8_Sym_M64_C64) { test.Run(); } -TEST(QLinearConvTest, Conv2D_S8S8_Sym_M16_C4) { - QLinearConvOpTester test; - test.GenerateRandomInput({1, 4, 3, 3}, .05f, 4); - test.GenerateRandomWeights({16, 4, 3, 3}, .125f, 0); - test.GenerateRandomBias(); - test.SetPads({0, 0, 0, 0}); - test.SetOutputScaleAndZeroPoint(.55f, 54); - test.Run(); -} TEST(QLinearConvTest, Conv2D_S8S8_Sym_M16_C4_Bias) { QLinearConvOpTester test;