diff --git a/onnxruntime/core/mlas/lib/amd64/QgemmU8X8KernelAvx2.asm b/onnxruntime/core/mlas/lib/amd64/QgemmU8X8KernelAvx2.asm index 1705a15fa4dc7..e65e43d93e671 100644 --- a/onnxruntime/core/mlas/lib/amd64/QgemmU8X8KernelAvx2.asm +++ b/onnxruntime/core/mlas/lib/amd64/QgemmU8X8KernelAvx2.asm @@ -41,7 +41,7 @@ GemmInt8KernelFrame STRUCT SavedXmm13 OWORD ? SavedXmm14 OWORD ? SavedXmm15 OWORD ? - Padding QWORD ? + SavedR14 QWORD ? SavedR13 QWORD ? SavedR12 QWORD ? SavedRdi QWORD ? @@ -165,6 +165,42 @@ ENDIF ENDM +; Macro Description: +; +; This macro generates the appropriate vpdp instruction based on the ASigned +; and BSigned values. +; +; Arguments: +; +; ASigned - sign of A. +; +; BSigned - sign of B. +; +; reg1 - Output register for vpdp instruction +; +; reg2 - Second input register for vpdp instruction +; +; reg3 - First input register for vpdp instruction +; + +VpdpYmmYmmYmm MACRO ASigned, BSigned, reg1, reg2, reg3 + + IF ASigned EQ 1 + IF BSigned EQ 1 + VpdpbssdYmmYmmYmm reg1, reg2, reg3 + ELSE + VpdpbsudYmmYmmYmm reg1, reg2, reg3 + ENDIF + ELSE + IF BSigned EQ 1 + VpdpbusdYmmYmmYmm reg1, reg2, reg3 + ELSE + VpdpbuudYmmYmmYmm reg1, reg2, reg3 + ENDIF + ENDIF + + ENDM + ; ; Macro Description: ; @@ -190,41 +226,21 @@ ENDIF ; ymm2 - Supplies the broadcast value loaded from matrix A. 
; -MultiplyAccumulateRowAvxVnni MACRO ColumnCount, Vec1Reg, Vec2Reg, ASigned, BSigned +MultiplyAccumulateRowAvxVnni MACRO ColumnCount, ASigned, BSigned, Vec1Reg, Vec2Reg, Vec3Reg, Vec4Reg -IF ASigned EQ 1 - IF BSigned EQ 1 - IF ColumnCount EQ 16 - VpdpbssdYmmYmmYmm Vec1Reg,ymm2,ymm0 - VpdpbssdYmmYmmYmm Vec2Reg,ymm2,ymm1 - ELSE - VpdpbssdYmmYmmYmm Vec2Reg,ymm2,ymm0 + IF ColumnCount EQ 32 + VpdpYmmYmmYmm ASigned, BSigned, Vec1Reg, ymm2, ymm0 + VpdpYmmYmmYmm ASigned, BSigned, Vec2Reg, ymm2, ymm1 + VpdpYmmYmmYmm ASigned, BSigned, Vec3Reg, ymm2, ymm14 + VpdpYmmYmmYmm ASigned, BSigned, Vec4Reg, ymm2, ymm15 ENDIF - ELSE IF ColumnCount EQ 16 - VpdpbsudYmmYmmYmm Vec1Reg,ymm2,ymm0 - VpdpbsudYmmYmmYmm Vec2Reg,ymm2,ymm1 - ELSE - VpdpbsudYmmYmmYmm Vec2Reg,ymm2,ymm0 + VpdpYmmYmmYmm ASigned, BSigned, Vec1Reg, ymm2, ymm0 + VpdpYmmYmmYmm ASigned, BSigned, Vec2Reg, ymm2, ymm1 ENDIF - ENDIF -ELSE - IF BSigned EQ 1 - IF ColumnCount EQ 16 - VpdpbusdYmmYmmYmm Vec1Reg,ymm2,ymm0 - VpdpbusdYmmYmmYmm Vec2Reg,ymm2,ymm1 - ELSE - VpdpbusdYmmYmmYmm Vec2Reg,ymm2,ymm0 + IF ColumnCount EQ 8 + VpdpYmmYmmYmm ASigned, BSigned, Vec2Reg, ymm2, ymm0 ENDIF - ELSE - IF ColumnCount EQ 16 - VpdpbuudYmmYmmYmm Vec1Reg,ymm2,ymm0 - VpdpbuudYmmYmmYmm Vec2Reg,ymm2,ymm1 - ELSE - VpdpbuudYmmYmmYmm Vec2Reg,ymm2,ymm0 - ENDIF - ENDIF -ENDIF ENDM @@ -261,18 +277,20 @@ ComputeBlockAvxVnni MACRO ColumnCount, RowCount, VectorOffset, BroadcastOffset, vmovdqu ymm0,YMMWORD PTR [rdx+VectorOffset] EmitIfCountGE ColumnCount, 16, + EmitIfCount2EQ ColumnCount, 32, RowCount, 1, + EmitIfCount2EQ ColumnCount, 32, RowCount, 1, EmitIfCountGE RowCount, 1, - EmitIfCountGE RowCount, 1, + EmitIfCountGE RowCount, 1, EmitIfCountGE RowCount, 2, - EmitIfCountGE RowCount, 2, + EmitIfCountGE RowCount, 2, EmitIfCountGE RowCount, 3, - EmitIfCountGE RowCount, 3, + EmitIfCountGE RowCount, 3, EmitIfCountGE RowCount, 4, - EmitIfCountGE RowCount, 4, + EmitIfCountGE RowCount, 4, EmitIfCountGE RowCount, 5, - EmitIfCountGE RowCount, 5, + EmitIfCountGE 
RowCount, 5, EmitIfCountGE RowCount, 6, - EmitIfCountGE RowCount, 6, + EmitIfCountGE RowCount, 6, ENDM @@ -312,7 +330,8 @@ ComputeBlockLoop MACRO Isa, ColumnCount, RowCount, ASigned, BSigned mov rsi,r9 ; reload row length remaining -IF (ColumnCount EQ 16) AND (RowCount EQ 1) +IF (ColumnCount EQ 16) OR (ColumnCount EQ 32) +IF (RowCount EQ 1) sub rsi,4*4 jb ProcessRemainingBlocks @@ -329,7 +348,8 @@ ComputeBlockBy4Loop: ProcessRemainingBlocks: add rsi,4*4 ; correct for over-subtract above jz ComputeBlockLoopExit -ENDIF +ENDIF ; RowCount == 1 +ENDIF ; ColumnCount == 16/32 ComputeBlockBy1Loop: ComputeBlock&Isa& ColumnCount, RowCount, 0, 0, ASigned, BSigned @@ -552,24 +572,44 @@ ProduceOutputBlock MACRO ColumnCount, RowCount, ASigned, BSigned EmitIfCountGE RowCount, 4, EmitIfCountGE RowCount, 5, EmitIfCountGE RowCount, 6, +IF ColumnCount EQ 32 + vmovdqu ymm0,YMMWORD PTR [r12] + vmovdqu ymm1,YMMWORD PTR [r12+32] + vmovdqu ymm14,YMMWORD PTR [r12+64] + vmovdqu ymm15,YMMWORD PTR [r12+96] + add r12,32*4 ; advance ColumnSumBuffer by 32 columns +ENDIF IF ColumnCount EQ 16 vmovdqu ymm0,YMMWORD PTR [r12] vmovdqu ymm1,YMMWORD PTR [r12+32] add r12,16*4 ; advance ColumnSumBuffer by 16 columns -ELSE +ENDIF +IF ColumnCount EQ 8 vmovdqu ymm1,YMMWORD PTR [r12] ENDIF test r13,r13 ; per column zero points? 
        jz      SkipScaleByZeroPointB
+IF ColumnCount EQ 32
+        vmovdqu ymm2,YMMWORD PTR [r13]
+        vmovdqu ymm3,YMMWORD PTR [r13+32]
+        vmovdqu ymm12,YMMWORD PTR [r13+64]
+        vmovdqu ymm13,YMMWORD PTR [r13+96]
+        add     r13,32*4                    ; advance ZeroPointB by 32 columns
+ENDIF
 IF ColumnCount EQ 16
         vmovdqu ymm2,YMMWORD PTR [r13]
         vmovdqu ymm3,YMMWORD PTR [r13+32]
         add     r13,16*4                    ; advance ZeroPointB by 16 columns
-ELSE
+ENDIF
+IF ColumnCount EQ 8
         vmovdqu ymm3,YMMWORD PTR [r13]
 ENDIF
+        EmitIfCount2EQ RowCount, 1, ColumnCount, 32, <vpmulld ymm6,ymm5,ymm12>
+        EmitIfCount2EQ RowCount, 1, ColumnCount, 32, <vpmulld ymm7,ymm5,ymm13>
         EmitIfCount2GE RowCount, 1, ColumnCount, 16, <vpmulld ymm4,ymm5,ymm2>
         EmitIfCountGE RowCount, 1, <vpmulld ymm5,ymm5,ymm3>
+        EmitIfCount2EQ RowCount, 1, ColumnCount, 32, <vpaddd ymm6,ymm14,ymm6>
+        EmitIfCount2EQ RowCount, 1, ColumnCount, 32, <vpaddd ymm7,ymm15,ymm7>
         EmitIfCount2GE RowCount, 1, ColumnCount, 16, <vpaddd ymm4,ymm0,ymm4>
         EmitIfCountGE RowCount, 1, <vpaddd ymm5,ymm1,ymm5>
         EmitIfCount2GE RowCount, 2, ColumnCount, 16, <vpmulld ymm6,ymm7,ymm2>
@@ -595,6 +635,8 @@ ENDIF
         jmp     AccumulatorsInitialized

SkipScaleByZeroPointB:
+        EmitIfCount2EQ RowCount, 1, ColumnCount, 32, <vpaddd ymm6,ymm5,ymm14>
+        EmitIfCount2EQ RowCount, 1, ColumnCount, 32, <vpaddd ymm7,ymm5,ymm15>
         EmitIfCount2GE RowCount, 1, ColumnCount, 16, <vpaddd ymm4,ymm5,ymm0>
         EmitIfCountGE RowCount, 1, <vpaddd ymm5,ymm5,ymm1>
         EmitIfCount2GE RowCount, 2, ColumnCount, 16, <vpaddd ymm6,ymm7,ymm0>
@@ -810,6 +852,177 @@ SkipAccumulateOutputMasked8xNBlock:

         ENDM

+
+;
+; Section Description:
+;
+;   This macro generates code to compute matrix multiplication for a single
+;   row. When processing just one row, there are more ymm registers available
+;   for us to unroll the main kernel further to benefit from better pipelining
+;   the dot product instruction.
+;
+; Arguments: None
+;
+; Implicit Arguments: Same as ProcessCountM
+;
+;
+
+ProcessCount1AvxVnni MACRO RowCount, ASigned, BSigned, Fallthrough
+
+        LOCAL   LProcessNextColumnLoop32xN1
+        LOCAL   LSkipAccumulateOutput32xNBlock1
+        LOCAL   LProcessNextColumnLoop16xN1
+        LOCAL   LSkipAccumulateOutput16xNBlock1
+        LOCAL   LProcessRemainingCountN1
+        LOCAL   LSkipAccumulateOutput8xNBlock1
+        LOCAL   LExitProcessCountM1
+        LOCAL   LOutputMasked32xNBlock1
+        LOCAL   LSkipAccumulateOutputMasked32xNBlock1
+        LOCAL   LOutputMasked24xNBlock1
+        LOCAL   LSkipAccumulateOutputMasked24xNBlock1
+        LOCAL   LOutputMasked16xNBlock1
+        LOCAL   LSkipAccumulateOutputMasked16xNBlock1
+        LOCAL   LOutputMasked8xNBlock1
+        LOCAL   LSkipAccumulateOutputMasked8xNBlock1
+
+        cmp     rbp,8
+        jbe     LProcessRemainingCountN1    ; num of cols <= 8?: process the tail
+        cmp     rbp,16
+        jbe     LProcessNextColumnLoop16xN1 ; num of cols <= 16?: process 16 at a time:
+
+LProcessNextColumnLoop32xN1:                ; Output loop to process 32 cols at a time:
+        ProduceOutputBlock 32, 1, ASigned, BSigned
+        add     rdx,r14
+        sub     rbp,32
+        jb      LOutputMasked32xNBlock1     ; if numcols < 32 (& > 16), use write using masked output and exit
+        test    r10b,r10b                   ; ZeroMode?
+ jnz LSkipAccumulateOutput32xNBlock1 + vpaddd ymm4,ymm4,YMMWORD PTR [r8] + vpaddd ymm5,ymm5,YMMWORD PTR [r8+32] + vpaddd ymm6,ymm6,YMMWORD PTR [r8+64] + vpaddd ymm7,ymm7,YMMWORD PTR [r8+96] + +LSkipAccumulateOutput32xNBlock1: + vmovdqu YMMWORD PTR [r8],ymm4 + vmovdqu YMMWORD PTR [r8+32],ymm5 + vmovdqu YMMWORD PTR [r8+64],ymm6 + vmovdqu YMMWORD PTR [r8+96],ymm7 + add r8,32*4 ; advance matrix C by 32 columns + mov rcx,rdi ; reload matrix A + cmp rbp,0 + je LExitProcessCountM1 + cmp rbp,8 + jle LProcessRemainingCountN1 ; num of cols < 8 + cmp rbp,16 + ja LProcessNextColumnLoop32xN1 ; num of cols > 16?: process 32 at a time: + +LProcessNextColumnLoop16xN1: ; num of cols > 8 and <= 16 + ProduceOutputBlock 16, 1, ASigned, BSigned + sub rbp,16 + jb LOutputMasked16xNBlock1 ; if numcols < 16 (& > 8), use write using masked output and exit + test r10b,r10b ; ZeroMode? + jnz LSkipAccumulateOutput16xNBlock1 + vpaddd ymm4,ymm4,YMMWORD PTR [r8] + vpaddd ymm5,ymm5,YMMWORD PTR [r8+32] + +LSkipAccumulateOutput16xNBlock1: + vmovdqu YMMWORD PTR [r8],ymm4 + vmovdqu YMMWORD PTR [r8+32],ymm5 + add r8,16*4 ; advance matrix C by 16 columns + mov rcx,rdi ; reload matrix A + cmp rbp,0 + je LExitProcessCountM1 + cmp rbp,8 + ja LProcessNextColumnLoop16xN1 ; num of cols > 8?: process 16 at a time: + +; Loop if num of cols <= 8 +LProcessRemainingCountN1: + ProduceOutputBlock 8, 1, ASigned, BSigned + cmp rbp,8 + jb LOutputMasked8xNBlock1 ; if numcols < 8, use write using masked output and exit + test r10b,r10b ; ZeroMode? 
+ jnz LSkipAccumulateOutput8xNBlock1 + vpaddd ymm5,ymm5,YMMWORD PTR [r8] + +LSkipAccumulateOutput8xNBlock1: + vmovdqu YMMWORD PTR [r8],ymm5 + +LExitProcessCountM1: ; num of cols = 0, we are done + mov eax, 1 + jmp ExitKernel + +;; -- Section to write final tail of C matrix and exit -- ;; +;; write <= 32 elements ;; +LOutputMasked32xNBlock1: + add rbp,32 + cmp rbp,24 + jle LOutputMasked24xNBlock1 + sub rbp,24 + neg rbp + lea rcx,MlasMaskMoveTableAvx+8*4 + vmovdqu ymm0,YMMWORD PTR [rcx+rbp*4] + test r10b,r10b ; ZeroMode? + jnz LSkipAccumulateOutputMasked32xNBlock1 + vpaddd ymm4,ymm4,YMMWORD PTR [r8] + vpaddd ymm5,ymm5,YMMWORD PTR [r8+32] + vpaddd ymm6,ymm6,YMMWORD PTR [r8+64] + vpmaskmovd ymm8,ymm0,YMMWORD PTR [r8+96] + vpaddd ymm7,ymm7,ymm8 + +; First write 16 cols using regular mov and then maskmov for the rest < 8 cols +LSkipAccumulateOutputMasked32xNBlock1: + vmovdqu YMMWORD PTR [r8],ymm4 + vmovdqu YMMWORD PTR [r8+32],ymm5 + vmovdqu YMMWORD PTR [r8+64],ymm6 + vpmaskmovd YMMWORD PTR [r8+96],ymm0,ymm7 + jmp LExitProcessCountM1 + +;; write <= 24 elements ;; +LOutputMasked24xNBlock1: + sub rbp,16 + neg rbp + lea rcx,MlasMaskMoveTableAvx+8*4 + vmovdqu ymm0,YMMWORD PTR [rcx+rbp*4] + test r10b,r10b ; ZeroMode? + jnz LSkipAccumulateOutputMasked24xNBlock1 + vpaddd ymm4,ymm4,YMMWORD PTR [r8] + vpaddd ymm5,ymm5,YMMWORD PTR [r8+32] + vpmaskmovd ymm8,ymm0,YMMWORD PTR [r8+64] + vpaddd ymm6,ymm6,ymm8 + +; First write 16 cols using regular mov and then maskmov for the rest < 8 cols +LSkipAccumulateOutputMasked24xNBlock1: + vmovdqu YMMWORD PTR [r8],ymm4 + vmovdqu YMMWORD PTR [r8+32],ymm5 + vpmaskmovd YMMWORD PTR [r8+64],ymm0,ymm6 + jmp LExitProcessCountM1 + +;; write <= 16 elements ;; +LOutputMasked16xNBlock1: + add rbp,16 + test r10b,r10b ; ZeroMode? 
+ jnz LSkipAccumulateOutputMasked16xNBlock1 + vpaddd ymm4,ymm4,YMMWORD PTR [r8] + +LSkipAccumulateOutputMasked16xNBlock1: + vmovdqu YMMWORD PTR [r8],ymm4 + add r8,8*4 ; advance matrix C by 8 columns + sub rbp,8 + +; at this point, rbp should be the value of num elements left to write +LOutputMasked8xNBlock1: + neg rbp + lea rcx,MlasMaskMoveTableAvx+8*4 + vmovdqu ymm0,YMMWORD PTR [rcx+rbp*4] + test r10b,r10b ; ZeroMode? + jnz LSkipAccumulateOutputMasked8xNBlock1 + vpmaskmovd ymm4,ymm0,YMMWORD PTR [r8] + vpaddd ymm5,ymm5,ymm4 + +LSkipAccumulateOutputMasked8xNBlock1: + vpmaskmovd YMMWORD PTR [r8],ymm0,ymm5 + jmp LExitProcessCountM1 + + ENDM ;++ ; @@ -870,7 +1083,8 @@ MlasGemmInt8KernelAvx2 MACRO ASigned, BSigned push_reg rdi push_reg r12 push_reg r13 - alloc_stack (GemmInt8KernelFrame.SavedR13) + push_reg r14 + alloc_stack (GemmInt8KernelFrame.SavedR14) save_xmm128 xmm6,GemmInt8KernelFrame.SavedXmm6 save_xmm128 xmm7,GemmInt8KernelFrame.SavedXmm7 save_xmm128 xmm8,GemmInt8KernelFrame.SavedXmm8 @@ -897,6 +1111,8 @@ MlasGemmInt8KernelAvx2 MACRO ASigned, BSigned mov r13,GemmInt8KernelFrame.ZeroPointB[rsp] vpcmpeqw ymm12,ymm12,ymm12 ; generate 256-bit word vector [0xFFFF] vpsrlw ymm12,ymm12,15 ; generate 256-bit word vector [0x0001] + lea r14,[r9*8] + lea r14,[r14*2] cmp DWORD PTR GemmInt8KernelFrame.PreviousP1Home[rsp],0 je CheckCountM4OrMore ; U8S8 AVX2 kernel requires extra registers @@ -941,10 +1157,11 @@ ExitKernel: movaps xmm13,GemmInt8KernelFrame.SavedXmm13[rsp] movaps xmm14,GemmInt8KernelFrame.SavedXmm14[rsp] movaps xmm15,GemmInt8KernelFrame.SavedXmm15[rsp] - add rsp,(GemmInt8KernelFrame.SavedR13) + add rsp,(GemmInt8KernelFrame.SavedR14) BEGIN_EPILOGUE + pop r14 pop r13 pop r12 pop rdi @@ -954,8 +1171,13 @@ ExitKernel: ret ProcessCountM1: + cmp DWORD PTR GemmInt8KernelFrame.PreviousP1Home[rsp],-1 + je ProcessCountM1AvxVnni ProcessCountM 1, ASigned, BSigned +ProcessCountM1AvxVnni: + ProcessCount1AvxVnni 1, ASigned, BSigned + ProcessCountM3: ProcessCountM 3, ASigned, 
BSigned diff --git a/onnxruntime/core/mlas/lib/amd64/mlasi.inc b/onnxruntime/core/mlas/lib/amd64/mlasi.inc index 2db3147168727..a4f58c1060a8a 100644 --- a/onnxruntime/core/mlas/lib/amd64/mlasi.inc +++ b/onnxruntime/core/mlas/lib/amd64/mlasi.inc @@ -93,6 +93,15 @@ ENDIF ENDM +EmitIfCount2EQ MACRO Count1, Value1, Count2, Value2, Statement + +IF (Count1 EQ Value1) AND (Count2 EQ Value2) + Statement +ENDIF + + ENDM + + ; ; Macro Description: ; diff --git a/onnxruntime/core/mlas/lib/x86_64/QgemmU8X8KernelAvx2.S b/onnxruntime/core/mlas/lib/x86_64/QgemmU8X8KernelAvx2.S index af2a475ea0c59..ef98afcbbc8e1 100644 --- a/onnxruntime/core/mlas/lib/x86_64/QgemmU8X8KernelAvx2.S +++ b/onnxruntime/core/mlas/lib/x86_64/QgemmU8X8KernelAvx2.S @@ -28,16 +28,17 @@ Abstract: // .equ .LGemmInt8KernelFrame_type, -8 - .equ .LGemmInt8KernelFrame_SavedR13, 0 - .equ .LGemmInt8KernelFrame_SavedR12, 8 - .equ .LGemmInt8KernelFrame_SavedRbx, 16 - .equ .LGemmInt8KernelFrame_SavedRbp, 24 - .equ .LGemmInt8KernelFrame_ReturnAddress, 32 - .equ .LGemmInt8KernelFrame_ldc, 40 - .equ .LGemmInt8KernelFrame_RowSumBuffer, 48 - .equ .LGemmInt8KernelFrame_ColumnSumBuffer, 56 - .equ .LGemmInt8KernelFrame_ZeroPointB, 64 - .equ .LGemmInt8KernelFrame_ZeroMode, 72 + .equ .LGemmInt8KernelFrame_SavedR14, 0 + .equ .LGemmInt8KernelFrame_SavedR13, 8 + .equ .LGemmInt8KernelFrame_SavedR12, 16 + .equ .LGemmInt8KernelFrame_SavedRbx, 24 + .equ .LGemmInt8KernelFrame_SavedRbp, 32 + .equ .LGemmInt8KernelFrame_ReturnAddress, 40 + .equ .LGemmInt8KernelFrame_ldc, 48 + .equ .LGemmInt8KernelFrame_RowSumBuffer, 56 + .equ .LGemmInt8KernelFrame_ColumnSumBuffer, 64 + .equ .LGemmInt8KernelFrame_ZeroPointB, 72 + .equ .LGemmInt8KernelFrame_ZeroMode, 80 /*++ @@ -145,6 +146,44 @@ Implicit Arguments: .endm +/*++ +Macro Description: + + This macro generates the appropriate vpdp instruction based on the ASigned + and BSigned values. + +Arguments: + + ASigned - sign of A. + + BSigned - sign of B. 
+ + reg1 - Output register for vpdp instruction + + reg2 - Second input register for vpdp instruction + + reg3 - First input register for vpdp instruction + +--*/ + + .macro VpdpYmmYmmYmm ASigned, BSigned, reg1, reg2, reg3 + + .if \ASigned\() == 1 + .if \BSigned\() == 1 + VpdpbssdYmmYmmYmm \reg1\(),\reg2\(),\reg3\() + .else + VpdpbsudYmmYmmYmm \reg1\(),\reg2\(),\reg3\() + .endif + .else + .if \BSigned\() == 1 + VpdpbusdYmmYmmYmm \reg1\(),\reg2\(),\reg3\() + .else + VpdpbuudYmmYmmYmm \reg1\(),\reg2\(),\reg3\() + .endif + .endif + + .endm + /*++ Macro Description: @@ -171,41 +210,21 @@ Implicit Arguments: --*/ - .macro MultiplyAccumulateRowAvxVnni ColumnCount, Vec1Reg, Vec2Reg, ASigned, BSigned - -.if \ASigned\() == 1 - .if \BSigned\() == 1 - .if \ColumnCount\() == 16 - VpdpbssdYmmYmmYmm \Vec1Reg\(),ymm2,ymm0 - VpdpbssdYmmYmmYmm \Vec2Reg\(),ymm2,ymm1 - .else - VpdpbssdYmmYmmYmm \Vec2Reg\(),ymm2,ymm0 - .endif - .else - .if \ColumnCount\() == 16 - VpdpbsudYmmYmmYmm \Vec1Reg\(),ymm2,ymm0 - VpdpbsudYmmYmmYmm \Vec2Reg\(),ymm2,ymm1 - .else - VpdpbsudYmmYmmYmm \Vec2Reg\(),ymm2,ymm0 - .endif - .endif -.else - .if \BSigned\() == 1 - .if \ColumnCount\() == 16 - VpdpbusdYmmYmmYmm \Vec1Reg\(),ymm2,ymm0 - VpdpbusdYmmYmmYmm \Vec2Reg\(),ymm2,ymm1 - .else - VpdpbusdYmmYmmYmm \Vec2Reg\(),ymm2,ymm0 - .endif - .else - .if \ColumnCount\() == 16 - VpdpbuudYmmYmmYmm \Vec1Reg\(),ymm2,ymm0 - VpdpbuudYmmYmmYmm \Vec2Reg\(),ymm2,ymm1 - .else - VpdpbuudYmmYmmYmm \Vec2Reg\(),ymm2,ymm0 - .endif - .endif -.endif + .macro MultiplyAccumulateRowAvxVnni ColumnCount, ASigned, BSigned, Vec1Reg, Vec2Reg, Vec3Reg, Vec4Reg + + .if \ColumnCount\() == 32 + VpdpYmmYmmYmm \ASigned\(), \BSigned\(), \Vec1Reg\(), ymm2, ymm0 + VpdpYmmYmmYmm \ASigned\(), \BSigned\(), \Vec2Reg\(), ymm2, ymm1 + VpdpYmmYmmYmm \ASigned\(), \BSigned\(), \Vec3Reg\(), ymm2, ymm14 + VpdpYmmYmmYmm \ASigned\(), \BSigned\(), \Vec4Reg\(), ymm2, ymm15 + .endif + .if \ColumnCount\() == 16 + VpdpYmmYmmYmm \ASigned\(), \BSigned\(), \Vec1Reg\(), 
ymm2, ymm0 + VpdpYmmYmmYmm \ASigned\(), \BSigned\(), \Vec2Reg\(), ymm2, ymm1 + .endif + .if \ColumnCount\() == 8 + VpdpYmmYmmYmm \ASigned\(), \BSigned\(), \Vec2Reg\(), ymm2, ymm0 + .endif .endm @@ -244,18 +263,20 @@ Implicit Arguments: vmovdqu ymm0,YMMWORD PTR [rsi+\VectorOffset\()] EmitIfCountGE \ColumnCount\(), 16, "vmovdqu ymm1,YMMWORD PTR [rsi+\VectorOffset\()+32]" + EmitIfCount2EQ \ColumnCount\(), 32, \RowCount\(), 1, "vmovdqu ymm14,YMMWORD PTR [rsi+r14+\VectorOffset\()]" + EmitIfCount2EQ \ColumnCount\(), 32, \RowCount\(), 1, "vmovdqu ymm15,YMMWORD PTR [rsi+r14+\VectorOffset\()+32]" EmitIfCountGE \RowCount\(), 1, "vpbroadcastd ymm2,DWORD PTR [rdi+\BroadcastOffset\()]" - EmitIfCountGE \RowCount\(), 1, "MultiplyAccumulateRowAvxVnni \ColumnCount\(), ymm4, ymm5, \ASigned\(), \BSigned\()" + EmitIfCountGE \RowCount\(), 1, "MultiplyAccumulateRowAvxVnni \ColumnCount\(), \ASigned\(), \BSigned\(), ymm4, ymm5, ymm6, ymm7" EmitIfCountGE \RowCount\(), 2, "vpbroadcastd ymm2,DWORD PTR [rdi+rcx+\BroadcastOffset\()]" - EmitIfCountGE \RowCount\(), 2, "MultiplyAccumulateRowAvxVnni \ColumnCount\(), ymm6, ymm7, \ASigned\(), \BSigned\()" + EmitIfCountGE \RowCount\(), 2, "MultiplyAccumulateRowAvxVnni \ColumnCount\(), \ASigned\(), \BSigned\(), ymm6, ymm7" EmitIfCountGE \RowCount\(), 3, "vpbroadcastd ymm2,DWORD PTR [rdi+rcx*2+\BroadcastOffset\()]" - EmitIfCountGE \RowCount\(), 3, "MultiplyAccumulateRowAvxVnni \ColumnCount\(), ymm8, ymm9, \ASigned\(), \BSigned\()" + EmitIfCountGE \RowCount\(), 3, "MultiplyAccumulateRowAvxVnni \ColumnCount\(), \ASigned\(), \BSigned\(), ymm8, ymm9" EmitIfCountGE \RowCount\(), 4, "vpbroadcastd ymm2,DWORD PTR [r8+\BroadcastOffset\()]" - EmitIfCountGE \RowCount\(), 4, "MultiplyAccumulateRowAvxVnni \ColumnCount\(), ymm10, ymm11, \ASigned\(), \BSigned\()" + EmitIfCountGE \RowCount\(), 4, "MultiplyAccumulateRowAvxVnni \ColumnCount\(), \ASigned\(), \BSigned\(), ymm10, ymm11" EmitIfCountGE \RowCount\(), 5, "vpbroadcastd ymm2,DWORD PTR 
[r8+rcx+\BroadcastOffset\()]" - EmitIfCountGE \RowCount\(), 5, "MultiplyAccumulateRowAvxVnni \ColumnCount\(), ymm12, ymm13, \ASigned\(), \BSigned\()" + EmitIfCountGE \RowCount\(), 5, "MultiplyAccumulateRowAvxVnni \ColumnCount\(), \ASigned\(), \BSigned\(), ymm12, ymm13" EmitIfCountGE \RowCount\(), 6, "vpbroadcastd ymm2,DWORD PTR [r8+rcx*2+\BroadcastOffset\()]" - EmitIfCountGE \RowCount\(), 6, "MultiplyAccumulateRowAvxVnni \ColumnCount\(), ymm14, ymm15, \ASigned\(), \BSigned\()" + EmitIfCountGE \RowCount\(), 6, "MultiplyAccumulateRowAvxVnni \ColumnCount\(), \ASigned\(), \BSigned\(), ymm14, ymm15" .endm @@ -292,7 +313,7 @@ Implicit Arguments: mov rbp,rcx # reload row length remaining -.if (\ColumnCount\() == 16) && (\RowCount\() == 1) +.if (\ColumnCount\() >= 16) && (\RowCount\() == 1) sub rbp,4*4 jb .LProcessRemainingBlocks\@ @@ -527,24 +548,42 @@ Implicit Arguments: EmitIfCountGE \RowCount\(), 4, "vpbroadcastd ymm11,DWORD PTR [r11+12]" EmitIfCountGE \RowCount\(), 5, "vpbroadcastd ymm13,DWORD PTR [r11+16]" EmitIfCountGE \RowCount\(), 6, "vpbroadcastd ymm15,DWORD PTR [r11+20]" -.if \ColumnCount\() == 16 +.if \ColumnCount\() >= 16 +.if \ColumnCount\() == 32 vmovdqu ymm0,YMMWORD PTR [r12] vmovdqu ymm1,YMMWORD PTR [r12+32] - add r12,16*4 # advance ColumnSumBuffer by 16 columns + vmovdqu ymm14,YMMWORD PTR [r12+64] + vmovdqu ymm15,YMMWORD PTR [r12+96] +.else + vmovdqu ymm0,YMMWORD PTR [r12] + vmovdqu ymm1,YMMWORD PTR [r12+32] +.endif + add r12,\ColumnCount\()*4 # advance ColumnSumBuffer by 16/32 columns .else vmovdqu ymm1,YMMWORD PTR [r12] .endif test r13,r13 # per column zero points? 
jz .LSkipScaleByZeroPointB\@ -.if \ColumnCount\() == 16 +.if \ColumnCount\() >= 16 +.if \ColumnCount\() == 32 vmovdqu ymm2,YMMWORD PTR [r13] vmovdqu ymm3,YMMWORD PTR [r13+32] - add r13,16*4 # advance ZeroPointB by 16 columns + vmovdqu ymm12,YMMWORD PTR [r13+64] + vmovdqu ymm13,YMMWORD PTR [r13+96] +.else + vmovdqu ymm2,YMMWORD PTR [r13] + vmovdqu ymm3,YMMWORD PTR [r13+32] +.endif + add r13,\ColumnCount\()*4 # advance ZeroPointB by 16/32 columns .else vmovdqu ymm3,YMMWORD PTR [r13] .endif + EmitIfCount2EQ \RowCount\(), 1, \ColumnCount\(), 32, "vpmulld ymm6,ymm5,ymm12" + EmitIfCount2EQ \RowCount\(), 1, \ColumnCount\(), 32, "vpmulld ymm7,ymm5,ymm13" EmitIfCount2GE \RowCount\(), 1, \ColumnCount\(), 16, "vpmulld ymm4,ymm5,ymm2" EmitIfCountGE \RowCount\(), 1, "vpmulld ymm5,ymm5,ymm3" + EmitIfCount2EQ \RowCount\(), 1, \ColumnCount\(), 32, "vpaddd ymm6,ymm14,ymm6" + EmitIfCount2EQ \RowCount\(), 1, \ColumnCount\(), 32, "vpaddd ymm7,ymm15,ymm7" EmitIfCount2GE \RowCount\(), 1, \ColumnCount\(), 16, "vpaddd ymm4,ymm0,ymm4" EmitIfCountGE \RowCount\(), 1, "vpaddd ymm5,ymm1,ymm5" EmitIfCount2GE \RowCount\(), 2, \ColumnCount\(), 16, "vpmulld ymm6,ymm7,ymm2" @@ -570,6 +609,8 @@ Implicit Arguments: jmp .LAccumulatorsInitialized\@ .LSkipScaleByZeroPointB\@: + EmitIfCount2EQ \RowCount\(), 1, \ColumnCount\(), 32, "vpaddd ymm6,ymm5,ymm14" + EmitIfCount2EQ \RowCount\(), 1, \ColumnCount\(), 32, "vpaddd ymm7,ymm5,ymm15" EmitIfCount2GE \RowCount\(), 1, \ColumnCount\(), 16, "vpaddd ymm4,ymm5,ymm0" EmitIfCountGE \RowCount\(), 1, "vpaddd ymm5,ymm5,ymm1" EmitIfCount2GE \RowCount\(), 2, \ColumnCount\(), 16, "vpaddd ymm6,ymm7,ymm0" @@ -777,6 +818,159 @@ Implicit Arguments: /*++ +Section Description: + This macro generates code to compute matrix multiplication for a single + row. When processing just one row, there are more ymm registers available + for us to unroll the main kernel further to benefit from better pipelining + the dot product instruction. 
+Arguments: None
+Implicit Arguments: Same as ProcessCountM
+
+--*/
+
+        .macro ProcessCount1AvxVnni ASigned, BSigned
+        cmp     r9,8
+        jbe     .LProcessRemainingCountN1\@ # num of cols <= 8?: process the tail
+        cmp     r9,16
+        jbe     .LProcessNextColumnLoop16xN1\@ # num of cols <= 16?: process 16 at a time:
+
+.LProcessNextColumnLoop32xN1\@:             # Output loop to process 32 cols at a time:
+        ProduceOutputBlock 32, 1, \ASigned\(), \BSigned\()
+        add     rsi,r14
+        sub     r9,32
+        jb      .LOutputMasked32xNBlock1\@  # if numcols < 32 (& > 16), use write using masked output and exit
+        test    r10b,r10b                   # ZeroMode?
+        jnz     .LSkipAccumulateOutput32xNBlock1\@
+        vpaddd  ymm4,ymm4,YMMWORD PTR [rdx]
+        vpaddd  ymm5,ymm5,YMMWORD PTR [rdx+32]
+        vpaddd  ymm6,ymm6,YMMWORD PTR [rdx+64]
+        vpaddd  ymm7,ymm7,YMMWORD PTR [rdx+96]
+
+.LSkipAccumulateOutput32xNBlock1\@:
+        vmovdqu YMMWORD PTR [rdx],ymm4
+        vmovdqu YMMWORD PTR [rdx+32],ymm5
+        vmovdqu YMMWORD PTR [rdx+64],ymm6
+        vmovdqu YMMWORD PTR [rdx+96],ymm7
+        add     rdx,32*4                    # advance matrix C by 32 columns
+        mov     rdi,rbx                     # reload matrix A
+        cmp     r9,0
+        je      .LExitProcessCountM1\@
+        cmp     r9,8
+        jle     .LProcessRemainingCountN1\@ # num of cols < 8
+        cmp     r9,16
+        ja      .LProcessNextColumnLoop32xN1\@ # num of cols > 16?: process 32 at a time:
+
+.LProcessNextColumnLoop16xN1\@:             # num of cols > 8 and <= 16
+        ProduceOutputBlock 16, 1, \ASigned\(), \BSigned\()
+        sub     r9,16
+        jb      .LOutputMasked16xNBlock1\@  # if numcols < 16 (& > 8), use write using masked output and exit
+        test    r10b,r10b                   # ZeroMode?
+ jnz .LSkipAccumulateOutput16xNBlock1\@ + vpaddd ymm4,ymm4,YMMWORD PTR [rdx] + vpaddd ymm5,ymm5,YMMWORD PTR [rdx+32] + +.LSkipAccumulateOutput16xNBlock1\@: + vmovdqu YMMWORD PTR [rdx],ymm4 + vmovdqu YMMWORD PTR [rdx+32],ymm5 + add rdx,16*4 # advance matrix C by 16 columns + mov rdi,rbx # reload matrix A + cmp r9,0 + je .LExitProcessCountM1\@ + cmp r9,8 + ja .LProcessNextColumnLoop16xN1\@ # num of cols > 8?: process 16 at a time: + +# Loop if num of cols <= 8 +.LProcessRemainingCountN1\@: + ProduceOutputBlock 8, 1 \ASigned\(), \BSigned\() + cmp r9,8 + jb .LOutputMasked8xNBlock1\@ # if numcols < 8, use write using masked output and exit + test r10b,r10b # ZeroMode? + jnz .LSkipAccumulateOutput8xNBlock1\@ + vpaddd ymm5,ymm5,YMMWORD PTR [rdx] + +.LSkipAccumulateOutput8xNBlock1\@: + vmovdqu YMMWORD PTR [rdx],ymm5 + +.LExitProcessCountM1\@: # num of cols = 0, we are done + mov eax, 1 + jmp .LExitKernel + +## -- Section to write final tail of C matrix and exit -- ## +## write <= 32 elements ## +.LOutputMasked32xNBlock1\@: + add r9,32 + cmp r9,24 + jle .LOutputMasked24xNBlock1\@ + sub r9,24 + neg r9 + lea rdi,C_UNDERSCORE(MlasMaskMoveTableAvx)[rip+8*4] + vmovdqu ymm0,YMMWORD PTR [rdi+r9*4] + test r10b,r10b # ZeroMode? + jnz .LSkipAccumulateOutputMasked32xNBlock1\@ + vpaddd ymm4,ymm4,YMMWORD PTR [rdx] + vpaddd ymm5,ymm5,YMMWORD PTR [rdx+32] + vpaddd ymm6,ymm6,YMMWORD PTR [rdx+64] + vpmaskmovd ymm8,ymm0,YMMWORD PTR [rdx+96] + vpaddd ymm7,ymm7,ymm8 + +# First write 16 cols using regular mov and then maskmov for the rest < 8 cols +.LSkipAccumulateOutputMasked32xNBlock1\@: + vmovdqu YMMWORD PTR [rdx],ymm4 + vmovdqu YMMWORD PTR [rdx+32],ymm5 + vmovdqu YMMWORD PTR [rdx+64],ymm6 + vpmaskmovd YMMWORD PTR [rdx+96],ymm0,ymm7 + jmp .LExitProcessCountM1\@ + +## write <= 24 elements ## +.LOutputMasked24xNBlock1\@: + sub r9,16 + neg r9 + lea rdi,C_UNDERSCORE(MlasMaskMoveTableAvx)[rip+8*4] + vmovdqu ymm0,YMMWORD PTR [rdi+r9*4] + test r10b,r10b # ZeroMode? 
+ jnz .LSkipAccumulateOutputMasked24xNBlock1\@ + vpaddd ymm4,ymm4,YMMWORD PTR [rdx] + vpaddd ymm5,ymm5,YMMWORD PTR [rdx+32] + vpmaskmovd ymm8,ymm0,YMMWORD PTR [rdx+64] + vpaddd ymm6,ymm6,ymm8 + +# First write 16 cols using regular mov and then maskmov for the rest < 8 cols +.LSkipAccumulateOutputMasked24xNBlock1\@: + vmovdqu YMMWORD PTR [rdx],ymm4 + vmovdqu YMMWORD PTR [rdx+32],ymm5 + vpmaskmovd YMMWORD PTR [rdx+64],ymm0,ymm6 + jmp .LExitProcessCountM1\@ + +## write <= 16 elements ## +.LOutputMasked16xNBlock1\@: + add r9,16 + test r10b,r10b # ZeroMode? + jnz .LSkipAccumulateOutputMasked16xNBlock1\@ + vpaddd ymm4,ymm4,YMMWORD PTR [rdx] + +.LSkipAccumulateOutputMasked16xNBlock1\@: + vmovdqu YMMWORD PTR [rdx],ymm4 + add rdx,8*4 # advance matrix C by 8 columns + sub r9,8 + +# at this point, r9 should be the value of num elements left to write +.LOutputMasked8xNBlock1\@: + neg r9 + lea rdi,C_UNDERSCORE(MlasMaskMoveTableAvx)[rip+8*4] + vmovdqu ymm0,YMMWORD PTR [rdi+r9*4] + test r10b,r10b # ZeroMode? 
+ jnz .LSkipAccumulateOutputMasked8xNBlock1\@ + vpmaskmovd ymm4,ymm0,YMMWORD PTR [rdx] + vpaddd ymm5,ymm5,ymm4 + +.LSkipAccumulateOutputMasked8xNBlock1\@: + vpmaskmovd YMMWORD PTR [rdx],ymm0,ymm5 + jmp .LExitProcessCountM1\@ + + .endm + +/*++ + Routine Description: This routine is an inner kernel to compute matrix multiplication for a @@ -832,6 +1026,7 @@ Return Value: push rbx push r12 push r13 + push r14 mov DWORD PTR .LGemmInt8KernelFrame_type[rsp],eax mov rbx,rdi @@ -844,6 +1039,8 @@ Return Value: mov r13,.LGemmInt8KernelFrame_ZeroPointB[rsp] vpcmpeqw ymm12,ymm12,ymm12 # generate 256-bit word vector [0xFFFF] vpsrlw ymm12,ymm12,15 # generate 256-bit word vector [0x0001] + lea rbp,[rcx*8] + lea r14,[rbp*2] cmp DWORD PTR .LGemmInt8KernelFrame_type[rsp],0 je .LCheckCountM4OrMore\@ # U8S8 AVX2 kernel requires extra registers @@ -873,8 +1070,13 @@ Return Value: ProcessCountM 6, \ASigned\(), \BSigned\() .LProcessCountM1\@: + cmp DWORD PTR .LGemmInt8KernelFrame_type[rsp],-1 + je .LProcessCountM1AvxVnni\@ ProcessCountM 1, \ASigned\(), \BSigned\() +.LProcessCountM1AvxVnni\@: + ProcessCount1AvxVnni \ASigned\(), \BSigned\() + .LProcessCountM3\@: ProcessCountM 3, \ASigned\(), \BSigned\() @@ -890,6 +1092,7 @@ Return Value: .LExitKernel: vzeroupper + pop r14 pop r13 pop r12 pop rbx diff --git a/onnxruntime/core/mlas/lib/x86_64/asmmacro.h b/onnxruntime/core/mlas/lib/x86_64/asmmacro.h index 7d7b3079a5132..7ef836c5701f3 100644 --- a/onnxruntime/core/mlas/lib/x86_64/asmmacro.h +++ b/onnxruntime/core/mlas/lib/x86_64/asmmacro.h @@ -97,6 +97,28 @@ Macro Description: .endm + +/*++ +Macro Description: + This macro conditionally emits the statement if Count1 is equal to Value1 + and Count2 is equal to Value2. +Arguments: + Count1 - Supplies the variable used in the comparison. + Value1 - Supplies the static used in the comparison. + Count2 - Supplies the variable used in the comparison. + Value2 - Supplies the static used in the comparison. 
+ Statement - Supplies the statement to conditionally emit. +--*/ + + .macro EmitIfCount2EQ Count1, Value1, Count2, Value2, Statement + +.if (\Count1\() == \Value1\()) && (\Count2\() == \Value2\()) + \Statement\() +.endif + + .endm + + /*++ Macro Description: