diff --git a/onnxruntime/core/mlas/lib/q4_dq.cpp b/onnxruntime/core/mlas/lib/q4_dq.cpp index cefe98930c832..dc5efeb5814c6 100644 --- a/onnxruntime/core/mlas/lib/q4_dq.cpp +++ b/onnxruntime/core/mlas/lib/q4_dq.cpp @@ -152,13 +152,32 @@ MlasJblasNBitsGemmPackB(void* PackedBuf, MLAS_COMPUTE_TYPE CompType, MLAS_THREADPOOL* ThreadPool) { + GetCPUDevice(); switch (CompType) { case CompInt8: - return JblaNBitsGemmPackB(JblasAvxVnniS4Fp32Fp32, PackedBuf, int(BlkSize), QData, Scale, - Zp, int(N), int(K), isAsym, lastCall, int(ldb), ThreadPool); + if (_cd->AVX512_VNNI()) { + return JblaNBitsGemmPackB(JblasAvx512VnniS4Fp32Fp32, PackedBuf, int(BlkSize), QData, + Scale, Zp, int(N), int(K), isAsym, lastCall, int(ldb), + ThreadPool); + } + if (_cd->AVX_VNNI()) { + return JblaNBitsGemmPackB(JblasAvxVnniS4Fp32Fp32, PackedBuf, int(BlkSize), QData, + Scale, Zp, int(N), int(K), isAsym, lastCall, int(ldb), + ThreadPool); + } + break; case CompFp32: - return JblaNBitsGemmPackB(JblasAvx512fS4Fp32Fp32, PackedBuf, int(BlkSize), QData, Scale, - Zp, int(N), int(K), isAsym, lastCall, int(ldb), ThreadPool); + if (_cd->AVX512F()) { + return JblaNBitsGemmPackB(JblasAvx512fS4Fp32Fp32, PackedBuf, int(BlkSize), QData, + Scale, Zp, int(N), int(K), isAsym, lastCall, int(ldb), + ThreadPool); + } + if (_cd->AVX2()) { + return JblaNBitsGemmPackB(JblasAvx2S4Fp32Fp32, PackedBuf, int(BlkSize), QData, + Scale, Zp, int(N), int(K), isAsym, lastCall, int(ldb), + ThreadPool); + } + break; case CompBf16: case CompFp16: default: @@ -177,18 +196,37 @@ MlasJblasQ4GemmUnPackB(float* FpData, auto ptr = jblas::storage::gemm::PackedWeightParser::deserialBuffer(const_cast(PackedBuf)); ORTThreading orth(ThreadPool); + GetCPUDevice(); if (ptr) { if (ptr->mPrologueID == JBLAS_PROLOGUEB_IDS::WeightKBlockS4) { - auto coretype = ptr->mCoreType; - auto NTile = uint32_t(coretype) & uint32_t(JBLAS_GEMM_CORE::NTILE_MASK); - auto CType = uint32_t(coretype) & uint32_t(JBLAS_GEMM_CORE::COMP_MASK); - if (NTile == 48 && CType == uint32_t(JBLAS_GEMM_CORE::COMP_FP32)) { - JblasAvx512fS4Fp32Fp32.mProB.unpackWeight(int(N), int(K), ptr, FpData, int(ldb), - &orth); + auto NTile = + jblas::gemm::CoreAttr::get_mask_val(ptr->mCoreId, jblas::gemm::CoreAttr::NTILE_MASK, + jblas::gemm::CoreAttr::NTILE_SHIFT); + auto CType = jblas::gemm::CoreAttr::get_mask_val( + ptr->mCoreId, jblas::gemm::CoreAttr::COMP_MASK, jblas::gemm::CoreAttr::COMP_SHIFT); + if (CType == uint32_t(jblas::gemm::CompType::COMP_FP32)) { + if (NTile == 48 && _cd->AVX512F()) { + JblasAvx512fS4Fp32Fp32.mProB.unpackWeight(int(N), int(K), ptr, FpData, int(ldb), + &orth); + return; + } + if (NTile == 24 && _cd->AVX2()) { + JblasAvx2S4Fp32Fp32.mProB.unpackWeight(int(N), int(K), ptr, FpData, int(ldb), + &orth); + return; + } } - if (NTile == 48 && CType == uint32_t(JBLAS_GEMM_CORE::COMP_INT8_US)) { - JblasAvx512VnniS4Fp32Fp32.mProB.unpackWeight(int(N), int(K), ptr, FpData, int(ldb), - &orth); + if (CType == uint32_t(jblas::gemm::CompType::COMP_INT8_US_INT32)) { + if (NTile == 48 && _cd->AVX512_VNNI()) { + JblasAvx512VnniS4Fp32Fp32.mProB.unpackWeight(int(N), int(K), ptr, FpData, + int(ldb), &orth); + return; + } + if (NTile == 24 && _cd->AVX_VNNI()) { + JblasAvxVnniS4Fp32Fp32.mProB.unpackWeight(int(N), int(K), ptr, FpData, int(ldb), + &orth); + return; + } } } delete ptr; diff --git a/onnxruntime/core/mlas/lib/q4gemm.cpp b/onnxruntime/core/mlas/lib/q4gemm.cpp index b003038d3b9fa..891c346c14466 100644 --- a/onnxruntime/core/mlas/lib/q4gemm.cpp +++ b/onnxruntime/core/mlas/lib/q4gemm.cpp @@ -155,30 +155,6 @@ 
jblas::ORTThreading::parallel_for(const jblas::parallel::thread_func& func) [&](ptrdiff_t tid) { func(int(tid)); }); } -template -void -GemmKBlockRun(Launch_T& launcher, - const typename Launch_T::Param& args, - parallel::IThreading* threading) -{ - device::CpuBase cb; - Parallel_T para({ - threading->num_threads(), - cb.mL2Cache, - args.M, - args.N, - args.K, - args.KBlock, - }); - threading->parallel_for([&](int tidx) { - typename Parallel_T::ThreadProblem thdp{tidx}; - para.getIndex(thdp); - if (thdp.valid) { - launcher.run(args, thdp); - } - }); -} - template void JblasQ4GemmCompF32(const int M, @@ -205,7 +181,7 @@ JblasQ4GemmCompF32(const int M, reduceA.template get(), reduceA.lda}; typename Launcher::Param args{M, N, K, B->mBlockSize, {A, K}, {B}, blkargs, {C, N}}; - GemmKBlockRun(kernel, args, th); + jblas::parallel::GemmKBlockRun(kernel, args, th); } template @@ -236,10 +212,10 @@ JblasQ4GemmCompInt8(const int M, {A, K, &quanA}, {B}, {B->template SPtr(), B->mScaT, B->mCStep, quanA.template SPtr(), - quanA.mCStep, quanA.template ZPtr(), B->template RPtr(), + quanA.mCStep, quanA.template ZPtr(), B->template RPtr(), B->mRedT, B->template ZPtr(), quanA.template RPtr(), B->mBlockSize}, {C, N}}; - GemmKBlockRun(kernel, args, th); + jblas::parallel::GemmKBlockRun(kernel, args, th); } void @@ -259,35 +235,39 @@ JblasQ4GemmBatchDriver(const size_t M, if (ptr) { if (ptr->mPrologueID == JBLAS_PROLOGUEB_IDS::WeightKBlockS4) { auto kptr = (jblas::storage::gemm::StorageWeightKBlockS4*)ptr; - auto coretype = ptr->mCoreType; - auto NTile = uint32_t(coretype) & uint32_t(JBLAS_GEMM_CORE::NTILE_MASK); - auto CType = uint32_t(coretype) & uint32_t(JBLAS_GEMM_CORE::COMP_MASK); - if (NTile == 48 && CType == uint32_t(JBLAS_GEMM_CORE::COMP_FP32)) { - if (_cd->AVX512F()) { - JblasQ4GemmCompF32( + auto coretype = ptr->mCoreId; + auto NTile = jblas::gemm::CoreAttr::get_mask_val( + ptr->mCoreId, jblas::gemm::CoreAttr::NTILE_MASK, + jblas::gemm::CoreAttr::NTILE_SHIFT); + auto CType = jblas::gemm::CoreAttr::get_mask_val(ptr->mCoreId, + jblas::gemm::CoreAttr::COMP_MASK, + jblas::gemm::CoreAttr::COMP_SHIFT); + if (CType == uint32_t(gemm::CompType::COMP_FP32)) { + if (NTile == 48 && _cd->AVX512F()) { + JblasQ4GemmCompF32>( M, N, K, DataParams[i].A, DataParams[i].lda, (jblas::storage::gemm::StorageWeightKBlockS4*)ptr, DataParams[i].C, DataParams[i].ldc, WorkSpace, &orth); return; } - if (_cd->AVX2()) { - JblasQ4GemmCompF32( + if (NTile == 24 && _cd->AVX2()) { + JblasQ4GemmCompF32>( M, N, K, DataParams[i].A, DataParams[i].lda, (jblas::storage::gemm::StorageWeightKBlockS4*)ptr, DataParams[i].C, DataParams[i].ldc, WorkSpace, &orth); return; } } - if (NTile == 48 && CType == uint32_t(JBLAS_GEMM_CORE::COMP_INT8_US)) { - if (_cd->AVX512_VNNI()) { - JblasQ4GemmCompInt8( + if (CType == uint32_t(gemm::CompType::COMP_INT8_US_INT32)) { + if (NTile == 48 && _cd->AVX512_VNNI()) { + JblasQ4GemmCompInt8>( M, N, K, DataParams[i].A, DataParams[i].lda, (jblas::storage::gemm::StorageWeightKBlockS4*)ptr, DataParams[i].C, DataParams[i].ldc, WorkSpace, &orth); return; } - if (_cd->AVX_VNNI()) { - JblasQ4GemmCompInt8( + if (NTile == 24 && _cd->AVX_VNNI()) { + JblasQ4GemmCompInt8>( M, N, K, DataParams[i].A, DataParams[i].lda, (jblas::storage::gemm::StorageWeightKBlockS4*)ptr, DataParams[i].C, DataParams[i].ldc, WorkSpace, &orth); diff --git a/onnxruntime/core/mlas/lib/x86_64/jblas/jblas/jit_base.h b/onnxruntime/core/mlas/lib/x86_64/jblas/jblas/jit_base.h index d2ef3308fdccd..c2f08a0fc9770 100644 --- 
a/onnxruntime/core/mlas/lib/x86_64/jblas/jblas/jit_base.h +++ b/onnxruntime/core/mlas/lib/x86_64/jblas/jblas/jit_base.h @@ -49,6 +49,21 @@ class JitBase : protected Xbyak::CodeGenerator { #endif } + void padto_le(const Xbyak::Reg64& _src, int padding) { + // _src=_src/padding*padding + if (padding == 1) { + return; + } + for (int i = 1; i < 16; i++) { + if ((1 << i) == padding) { + shr(_src, i); + shl(_src, i); + return; + } + } + assert(0); + } + void generate_Nbitsmask(const Xbyak::Opmask& _msk, const Xbyak::Reg64& _pos, const Xbyak::Address& _total, const Xbyak::Reg64& _tmp, const Xbyak::Reg64& _tmp1, int N) { inLocalLabel(); @@ -86,6 +101,8 @@ class JitBase : protected Xbyak::CodeGenerator { class JitAvx : protected JitBase { protected: static int constexpr VBits = 256; + static int constexpr VecBytes = VBits / 8; + static int constexpr RegCount = 16; typedef Xbyak::Ymm vreg_t; }; @@ -93,6 +110,7 @@ class JitAvx2 : protected JitAvx { protected: static int constexpr VBits = 256; typedef Xbyak::Ymm vreg_t; + void vxor(const vreg_t& x1, const vreg_t& x2, const Xbyak::Operand& op) { vpxor(x1, x2, op); } void loadbf16_f32(const Xbyak::Ymm& dst, const Xbyak::Address& addr) { vpmovzxwd(dst, addr); @@ -103,8 +121,12 @@ class JitAvx2 : protected JitAvx { class JitAvx512f : protected JitAvx2 { protected: static int constexpr VBits = 512; + static int constexpr VecBytes = VBits / 8; + static int constexpr RegCount = 32; typedef Xbyak::Zmm vreg_t; + void vxor(const vreg_t& x1, const vreg_t& x2, const Xbyak::Operand& op) { vpxorq(x1, x2, op); } + void interleave_2rows_4regs(Xbyak::Zmm* src_2regs, Xbyak::Zmm* tmp_2reg) { vpunpcklwd(tmp_2reg[0], src_2regs[0], src_2regs[1]); vpunpckhwd(tmp_2reg[1], src_2regs[0], src_2regs[1]); @@ -191,18 +213,20 @@ class JitAvx512f : protected JitAvx2 { } }; +class JitAvx512_bf16 : protected JitAvx512f {}; + class JitAvx512_fp16 : protected JitAvx512f {}; class JitAvx512vnni : protected JitAvx512f { protected: - void vpdpbusds_evex(const Xbyak::Xmm& x1, const Xbyak::Xmm& x2, const Xbyak::Operand& op) { + void vpdpbusds_(const Xbyak::Xmm& x1, const Xbyak::Xmm& x2, const Xbyak::Operand& op) { vpdpbusds(x1, x2, op, Xbyak::EvexEncoding); } }; class JitAvxvnni : protected JitAvx2 { protected: - void vpdpbusds_vex(const Xbyak::Xmm& x1, const Xbyak::Xmm& x2, const Xbyak::Operand& op) { + void vpdpbusds_(const Xbyak::Xmm& x1, const Xbyak::Xmm& x2, const Xbyak::Operand& op) { vpdpbusds(x1, x2, op, Xbyak::VexEncoding); } }; @@ -215,6 +239,15 @@ class JitAmxtile : protected JitAvx512f { uint16_t colb[16]; uint8_t rows[16]; }; + static int constexpr TileCount = 8; + + typedef long long (*configure_t)(void*); + + static void generate_config(Xbyak::CodeGenerator* g) { + Xbyak::util::StackFrame st(g, 1, 0, 0); + auto& parambase = st.p[0]; + g->ldtilecfg(g->ptr[parambase]); + } static void configure_tiles(tileconfig_t& tc, int TILE_M, int TILE_N, int TILE_K, int elesize, int ANum, int BNum, int CNum) { diff --git a/onnxruntime/core/mlas/lib/x86_64/jblas/jblas/jit_blas.h b/onnxruntime/core/mlas/lib/x86_64/jblas/jblas/jit_blas.h index badb571b3926b..2ecb6630d3da0 100644 --- a/onnxruntime/core/mlas/lib/x86_64/jblas/jblas/jit_blas.h +++ b/onnxruntime/core/mlas/lib/x86_64/jblas/jblas/jit_blas.h @@ -30,6 +30,7 @@ enum JBLAS_ISA : uint32_t { JblasAMX_BF16, JblasAMX_INT8, JblasAVX512_FP16, + JblasAVX512_BF16, }; enum class JBLAS_DTYPE : uint32_t { EleBitsMask = 0xff, @@ -80,57 +81,6 @@ enum JBLAS_ELTWISEOP { LINEAR, }; -enum class JBLAS_GEMM_CORE : uint32_t { - // 
INT32=LSB|**8bits:NTile**||**8bits:PackRow**||**8bits:CompType**||**8bits:Reserve**| - Undef = 0, - NTILE_MASK = 0xff, - NTILE_SHIFT = 0, - NTILE_24 = 24, - NTILE_48 = 48, - NTILE_64 = 64, - NTILE_96 = 96, - PACKROW_MASK = 0xff00, - PACKROW_SHIFT = 8, - PACKROW_1 = 1 << PACKROW_SHIFT, - PACKROW_2 = 2 << PACKROW_SHIFT, - PACKROW_4 = 4 << PACKROW_SHIFT, - COMP_MASK = 0xff0000, - COMP_SHIFT = 16, - COMP_FP32 = 0 << COMP_SHIFT, - COMP_BF16 = 1 << COMP_SHIFT, - COMP_FP16 = 2 << COMP_SHIFT, - COMP_INT_START = 3 << COMP_SHIFT, - COMP_INT8_US = COMP_INT_START, - COMP_INT8_SS = 4 << COMP_SHIFT, - COMP_INT8_SU = 5 << COMP_SHIFT, - COMP_INT16_SS = 6 << COMP_SHIFT, - COMP_INT8_US_FP32 = 7 << COMP_SHIFT, - COMP_INT8_SS_FP32 = 8 << COMP_SHIFT, - COMP_INT8_SU_FP32 = 9 << COMP_SHIFT, - ISA_MASK = 0xff000000, - ISA_SHIFT = 24, - ISA_AVX2 = (uint32_t)JBLAS_ISA::JblasAVX2 << ISA_SHIFT, - ISA_AVX512F = (uint32_t)JBLAS_ISA::JblasAVX512F << ISA_SHIFT, - ISA_AVX_VNNI = (uint32_t)JBLAS_ISA::JblasAVX_VNNI << ISA_SHIFT, - ISA_AVX512_VNNI = (uint32_t)JBLAS_ISA::JblasAVX512_VNNI << ISA_SHIFT, - ISA_AMX_INT8 = (uint32_t)JBLAS_ISA::JblasAMX_INT8 << ISA_SHIFT, - ISA_AMX_BF16 = (uint32_t)JBLAS_ISA::JblasAMX_BF16 << ISA_SHIFT, - ISA_AVX512_FP16 = (uint32_t)JBLAS_ISA::JblasAVX512_FP16 << ISA_SHIFT, - AVX2_4X24 = NTILE_24 | PACKROW_1 | COMP_FP32 | ISA_AVX2, - AVX2_2X48 = NTILE_48 | PACKROW_1 | COMP_FP32 | ISA_AVX2, - AVX512F_8x48 = NTILE_48 | PACKROW_1 | COMP_FP32 | ISA_AVX512F, - AMX_BF16_16x64 = NTILE_64 | PACKROW_2 | COMP_BF16 | ISA_AMX_BF16, - AMX_BF16_16x48 = NTILE_48 | PACKROW_2 | COMP_BF16 | ISA_AMX_BF16, - AVX512_FP16_8x64 = NTILE_64 | PACKROW_2 | COMP_FP16 | ISA_AVX512_FP16, - AVX512_FP16_8x96 = NTILE_96 | PACKROW_2 | COMP_FP16 | ISA_AVX512_FP16, - AVX_VNNI_2x48 = NTILE_48 | PACKROW_4 | COMP_INT8_US | ISA_AVX_VNNI, - AVX_VNNI_4x24 = NTILE_24 | PACKROW_4 | COMP_INT8_US | ISA_AVX_VNNI, - AVX512_VNNI_8x48 = NTILE_48 | PACKROW_4 | COMP_INT8_US | ISA_AVX512_VNNI, - AMX_INT8_16x64_US = NTILE_64 | PACKROW_4 | COMP_INT8_US | ISA_AMX_INT8, - AMX_INT8_16x64_SS = NTILE_64 | PACKROW_4 | COMP_INT8_SS | ISA_AMX_INT8, - AMX_INT8_16x48_US = NTILE_48 | PACKROW_4 | COMP_INT8_US | ISA_AMX_INT8, - AMX_INT8_16x48_SS = NTILE_48 | PACKROW_4 | COMP_INT8_SS | ISA_AMX_INT8, -}; enum class JBLAS_PROLOGUEB_IDS : uint32_t { Undef = (uint32_t)-1, Begin = 0, diff --git a/onnxruntime/core/mlas/lib/x86_64/jblas/jblas/jit_blas_device.h b/onnxruntime/core/mlas/lib/x86_64/jblas/jblas/jit_blas_device.h index 0b4a21d7d9e53..bf3f45db18f62 100644 --- a/onnxruntime/core/mlas/lib/x86_64/jblas/jblas/jit_blas_device.h +++ b/onnxruntime/core/mlas/lib/x86_64/jblas/jblas/jit_blas_device.h @@ -209,6 +209,7 @@ class CpuDevice { } } inline int getThreads() { return numthreads; } + inline int getCores() { return numcores; } inline uint32_t getL2CacheSize() { return L2Cache; } inline uint32_t getL1CacheSize() { return L1Cache; } inline bool AVX() { return mHasAVX; } @@ -261,15 +262,15 @@ class CpuDevice { #define GetCPUDevice() auto _cd = jblas::device::CpuDevice::getInstance(); - class CpuBase { public: CpuBase() { GetCPUDevice(); mL2Cache = _cd->getL2CacheSize(); + mL1Cache = _cd->getL1CacheSize(); mNumThreads = _cd->getThreads(); } - size_t mL2Cache; + size_t mL2Cache, mL1Cache; int mNumThreads; }; } // namespace device diff --git a/onnxruntime/core/mlas/lib/x86_64/jblas/jblas/jit_blas_epilogue.h b/onnxruntime/core/mlas/lib/x86_64/jblas/jblas/jit_blas_epilogue.h index d12dbac990578..1cab348459a0c 100644 --- 
a/onnxruntime/core/mlas/lib/x86_64/jblas/jblas/jit_blas_epilogue.h +++ b/onnxruntime/core/mlas/lib/x86_64/jblas/jblas/jit_blas_epilogue.h @@ -79,6 +79,8 @@ class CustomAccumulatorWriteBackWithEltop { template using AccumulatorWriteBackFp32 = AccumulatorWriteBack; template +using AccumulatorWriteBackInt32 = AccumulatorWriteBack; +template using AccumulatorWriteBackBf16 = AccumulatorWriteBack; template using AccumulatorWriteBackFp16 = AccumulatorWriteBack; @@ -133,7 +135,7 @@ class CompFp32BlockEpilogue { (float*)_param.scales + K_offset * _param.ldsb + N_offset, srcptr, cachestep, dstptr, cachestep, M, N); assert(ret == JblasSuccess); if (_param.zps != nullptr) { - ret = kernel::wrapper::RemoveZeroPointBias::forward( + ret = kernel::wrapper::RemoveZeroPointBias::forward_wei( dstptr, cachestep, M, N, _param.zps + K_offset * _param.ldsb + N_offset, (float*)_param.scales + K_offset * _param.ldsb + N_offset, _param.ldra, _param.reduce + M_offset * _param.ldra + K_offset); @@ -184,7 +186,8 @@ class CompInt8BlockEpilogue { int ldsa; // optional if A asym uint8_t* zpA = nullptr; - float* reduceB = nullptr; + void* reduceB = nullptr; + JBLAS_DTYPE reduceBdtype = JBLAS_DTYPE::F32; // optional if B asym int8_t* zpB = nullptr; float* reduceA = nullptr; @@ -194,46 +197,56 @@ class CompInt8BlockEpilogue { const int K_offset, const int M, const int N, const Param& _param, void* tmpcache, size_t cachesize) { JBLAS_CODE ret = JblasNotSupport; - if (_param.scaleBdtype == JBLAS_DTYPE::F32) { - ret = kernel::wrapper::DequanS32Fp32::template forward( - srcptr, cachestep, (float*)srcptr, cachestep, M, N, _param.scalesA + M_offset * _param.ldsa + K_offset, - _param.ldsa, (float*)_param.scalesB + N_offset + K_offset * _param.ldsb); - assert(ret == JblasSuccess); - } else if (_param.scaleBdtype == JBLAS_DTYPE::BF16) { - ret = kernel::wrapper::DequanS32Fp32::template forward( - srcptr, cachestep, (float*)srcptr, cachestep, M, N, _param.scalesA + M_offset * _param.ldsa + K_offset, - _param.ldsa, (utils::bf16*)_param.scalesB + N_offset + K_offset * _param.ldsb); + float* scab = nullptr; + size_t ScaleBTmpSize = N * sizeof(float); + size_t ReduceBTmpSize = N * sizeof(float); + assert(cachesize >= (ScaleBTmpSize + ReduceBTmpSize)); + if (_param.scaleBdtype == JBLAS_DTYPE::BF16) { + auto scache = (float*)tmpcache; + ret = kernel::wrapper::Memcpy2DBf16CvtFp32::template forward( + (utils::bf16*)_param.scalesB + N_offset + K_offset * _param.ldsb, scache, 1, N, N, N, false); assert(ret == JblasSuccess); + scab = scache; + } else if (_param.scaleBdtype == JBLAS_DTYPE::F32) { + scab = (float*)_param.scalesB + N_offset + K_offset * _param.ldsb; } - ret = kernel::wrapper::AccumulateFp32::template forward((float*)srcptr, cachestep, dstptr, cachestep, M, N); - if (ret != JblasSuccess) { - assert(0); - return ret; + float* redb = nullptr; + if (_param.reduceB) { + if (_param.reduceBdtype == JBLAS_DTYPE::BF16) { + auto rcache = (float*)((char*)tmpcache + ScaleBTmpSize); + ret = kernel::wrapper::Memcpy2DBf16CvtFp32::template forward( + (utils::bf16*)_param.reduceB + N_offset + K_offset * _param.ldsb, rcache, 1, N, N, N, false); + assert(ret == JblasSuccess); + redb = rcache; + } else if (_param.reduceBdtype == JBLAS_DTYPE::F32) { + redb = (float*)_param.reduceB + N_offset + K_offset * _param.ldsb; + } } + ret = kernel::wrapper::DequanS32Fp32::template forward(srcptr, cachestep, (float*)srcptr, cachestep, M, N, + _param.scalesA + M_offset * _param.ldsa + K_offset, + _param.ldsa, scab); + assert(ret == JblasSuccess); + ret = 
kernel::wrapper::AccumulateFp32::template forward((float*)srcptr, cachestep, dstptr, cachestep, M, N); + assert(ret == JblasSuccess); - if (_param.zpA == nullptr && _param.zpB == nullptr) { - return ret; - } else if (_param.zpA != nullptr && _param.zpB == nullptr) { - ret = kernel::wrapper::RemoveZeroPointBias::template forward( - dstptr, cachestep, M, N, _param.zpA + M_offset * _param.ldsa + K_offset, - _param.scalesA + M_offset * _param.ldsa + K_offset, _param.ldsa, - _param.reduceB + N_offset + K_offset * _param.ldsb); - - } else if (_param.zpA == nullptr && _param.zpB != nullptr) { - if (_param.scaleBdtype == JBLAS_DTYPE::F32) { - ret = kernel::wrapper::RemoveZeroPointBias::template forward( - dstptr, cachestep, M, N, _param.zpB + N_offset + K_offset * _param.ldsb, - (float*)_param.scalesB + N_offset + K_offset * _param.ldsb, _param.ldsa, + if (_param.zpA == nullptr) { + if (_param.zpB == nullptr) { + return ret; + } else { + ret = kernel::wrapper::RemoveZeroPointBias::template forward_wei( + dstptr, cachestep, M, N, _param.zpB + N_offset + K_offset * _param.ldsb, scab, _param.ldsa, _param.reduceA + M_offset * _param.ldsa + K_offset); } - } else { - if (_param.scaleBdtype == JBLAS_DTYPE::F32) { - ret = kernel::wrapper::RemoveZeroPointBias::template forward( + if (_param.zpB == nullptr) { + ret = kernel::wrapper::RemoveZeroPointBias::template forward_act( + dstptr, cachestep, M, N, _param.zpA + M_offset * _param.ldsa + K_offset, + _param.scalesA + M_offset * _param.ldsa + K_offset, _param.ldsa, redb); + } else { + ret = kernel::wrapper::RemoveZeroPointBias::template forward_both( dstptr, cachestep, M, N, _param.zpA + M_offset * _param.ldsa + K_offset, - _param.zpB + N_offset + K_offset * _param.ldsb, _param.scalesA + M_offset * _param.ldsa + K_offset, - (float*)_param.scalesB + N_offset + K_offset * _param.ldsb, _param.ldsa, _param.K, - _param.reduceA + M_offset * _param.ldsa + K_offset, _param.reduceB + N_offset + K_offset * _param.ldsb); + _param.zpB + N_offset + K_offset * _param.ldsb, _param.scalesA + M_offset * _param.ldsa + K_offset, scab, + _param.ldsa, _param.K, _param.reduceA + M_offset * _param.ldsa + K_offset, redb); } } return ret; @@ -271,15 +284,15 @@ class ZpDequantInt32ToFp32 { if (_param.zpA == nullptr && _param.zpB == nullptr) { return ret; } else if (_param.zpA != nullptr && _param.zpB == nullptr) { - ret = kernel::wrapper::RemoveZeroPointBias::template forward( + ret = kernel::wrapper::RemoveZeroPointBias::template forward_act( cptr, _param.ldc, M, N, _param.zpA + M_offset * _param.ldsa, _param.scalesA + M_offset * _param.ldsa, _param.ldsa, _param.reduceB + N_offset); } else if (_param.zpA == nullptr && _param.zpB != nullptr) { - ret = kernel::wrapper::RemoveZeroPointBias::template forward(cptr, _param.ldc, M, N, _param.zpB + N_offset, - _param.scalesB + N_offset, _param.ldsa, - _param.reduceA + M_offset * _param.ldsa); + ret = kernel::wrapper::RemoveZeroPointBias::template forward_wei( + cptr, _param.ldc, M, N, _param.zpB + N_offset, _param.scalesB + N_offset, _param.ldsa, + _param.reduceA + M_offset * _param.ldsa); } else { - ret = kernel::wrapper::RemoveZeroPointBias::template forward( + ret = kernel::wrapper::RemoveZeroPointBias::template forward_both( cptr, _param.ldc, M, N, _param.zpA + M_offset * _param.ldsa, _param.zpB + N_offset, _param.scalesA + M_offset * _param.ldsa, _param.scalesB + N_offset, _param.ldsa, _param.K, _param.reduceA + M_offset * _param.ldsa, _param.reduceB + N_offset); diff --git 
a/onnxruntime/core/mlas/lib/x86_64/jblas/jblas/jit_blas_gemm.h b/onnxruntime/core/mlas/lib/x86_64/jblas/jblas/jit_blas_gemm.h index 5fefbca23b477..e024dcd45643e 100644 --- a/onnxruntime/core/mlas/lib/x86_64/jblas/jblas/jit_blas_gemm.h +++ b/onnxruntime/core/mlas/lib/x86_64/jblas/jblas/jit_blas_gemm.h @@ -19,2852 +19,2152 @@ namespace jblas { namespace gemm { +enum class CompType : uint32_t { + COMP_FP32 = 0, + COMP_BF16_FP32 = 1, + COMP_FP16_FP16 = 2, + COMP_INT_START = 3, + COMP_INT8_US_INT32 = COMP_INT_START, + COMP_INT8_UU_INT32 = 4, + COMP_INT8_SS_INT32 = 5, + COMP_INT8_SU_INT32 = 6, + COMP_INT16_SS_INT32 = 7, + COMP_INT8_US_FP32 = 8, + COMP_INT8_UU_FP32 = 9, + COMP_INT8_SS_FP32 = 10, + COMP_INT8_SU_FP32 = 11, +}; -class GemmCore_Row_NN_4x24_AVX2 { +class CoreAttr { public: - struct params { - float *matA, *matB, *matC; - int k, nsize; - int astep, bstep, cstep; - int kpos; - }; - typedef long long (*func_t)(params*); - typedef float AType; - typedef float BType; - typedef float CType; - static JBLAS_GEMM_CORE constexpr TYPE = JBLAS_GEMM_CORE::AVX2_4X24; - static JBLAS_ISA constexpr ISA = - utils::gemm_core_mask(); - static int constexpr NTILE = 24, MTILE = 4, KTILE = 4 / sizeof(BType); - static int constexpr KUNROLL = 2; - static int constexpr PACK_ROW = 1; - static int constexpr PREFERED_N = 144; - class MicroKernel : protected jblas::xbyak::JitAvx2 { - public: - MicroKernel() {} - static int constexpr VecBytes = 32; - static int constexpr VecElements = VecBytes / sizeof(CType); - int CRegCount = 12, BRegCount = 3, ARegCount = 1; - int CReg = 0, BReg = 12, AReg = 15, TmpReg = BReg; - int const NRegs = NTILE / VecElements; - static int constexpr BKStepSize = KTILE * NTILE * sizeof(BType); - static int constexpr AKStepSize = KTILE * sizeof(AType); - - void generate_code(int _mtile) { - reset(); - generate_mtile(_mtile); - ready(); - mKernel = getCode(); - } - func_t mKernel = nullptr; + // INT32=LSB|**8bits:NTile**||**8bits:PackRow**||**8bits:CompType**||**8bits:Reserve**| + static uint32_t constexpr NTILE_MASK = 0xff, NTILE_SHIFT = 0, PACKROW_MASK = 0xff00, PACKROW_SHIFT = 8, + COMP_MASK = 0xff0000, COMP_SHIFT = 16, ISA_MASK = 0xff000000, ISA_SHIFT = 24; - protected: - void generate_mtile(int _mtile) { - CRegCount = _mtile * NRegs; - BRegCount = NRegs; - BReg = CReg + CRegCount; - AReg = BReg + BRegCount; - TmpReg = AReg + ARegCount; - inLocalLabel(); // use local label for multiple instance - Xbyak::util::StackFrame st(this, 1, 11, 16 * 10); - parambase = st.p[0]; - reg_matAptr = st.t[0]; - reg_matBptr = st.t[1]; - reg_matCptr = st.t[0]; - reg_ksize = st.t[2]; - reg_nsize = st.t[9]; - reg_cstep = st.t[3]; - reg_astep = st.t[5]; - reg_iterk = st.t[4]; - reg_itern = st.t[7]; - reg_tmp = st.t[6]; - reg_tmp1 = st.t[8]; - reg_tmp2 = st.t[10]; - reg_ret = rax; - - vreg_push(rsp); - - mov(reg_matBptr, ptr[parambase + OFFSET(matB)]); - load32(reg_ksize, ptr[parambase + OFFSET(k)]); - load32(reg_nsize, ptr[parambase + OFFSET(nsize)]); - load32(reg_astep, ptr[parambase + OFFSET(astep)]); - - xor_(reg_itern, reg_itern); - L(".nloop"); - for (int i = 0; i < _mtile; i++) { - for (int j = 0; j < NRegs; j++) { - vpxor(Xbyak::Ymm(CReg + i * NRegs + j), Xbyak::Ymm(CReg + i * NRegs + j), Xbyak::Ymm(CReg + i * NRegs + j)); - } - } - mov(reg_matAptr, ptr[parambase + OFFSET(matA)]); - mov(reg_tmp1, reg_matBptr); - - xor_(reg_iterk, reg_iterk); - - mov(reg_tmp, reg_nsize); - sub(reg_tmp, reg_itern); - cmp(reg_tmp, NTILE); - jl(".n16", T_NEAR); - generate_kloop(_mtile, NRegs); - write_back(_mtile, NRegs); 
- load32(reg_tmp, ptr[parambase + OFFSET(bstep)]); - imul(reg_tmp, reg_tmp, NTILE); - add(reg_matBptr, reg_tmp); - add(reg_itern, NTILE); - jmp(".nend", T_NEAR); - - L(".n16"); - cmp(reg_tmp, 16); - jl(".n8", T_NEAR); - generate_kloop(_mtile, 2); - write_back(_mtile, 2); - add(reg_itern, 16); - add(reg_matBptr, 16 * sizeof(BType)); - jmp(".nend", T_NEAR); - - L(".n8"); - xor_(reg_iterk, reg_iterk); - generate_kloop(_mtile, 1); - write_back(_mtile, 1); - add(reg_itern, 8); - add(reg_matBptr, 8 * sizeof(BType)); - L(".nend"); - cmp(reg_itern, reg_nsize); - jb(".nloop"); - - mov(reg_ret, 0); - vreg_pop(rsp); - - outLocalLabel(); // end of local label - } - - void generate_kloop(int _mtile, int _nregs) { - inLocalLabel(); - L(".kloop"); - mov(reg_tmp, reg_ksize); - sub(reg_tmp, reg_iterk); - cmp(reg_tmp, KUNROLL * KTILE); - jl(".k1loop", T_NEAR); - generate_fma(_mtile, _nregs, KUNROLL, reg_tmp1); - add(reg_matAptr, KUNROLL * AKStepSize); - add(reg_tmp1, KUNROLL * BKStepSize); - add(reg_iterk, KUNROLL * KTILE); - jmp(".kloopend", T_NEAR); - - L(".k1loop"); - generate_fma(_mtile, _nregs, 1, reg_tmp1); - add(reg_matAptr, 1 * AKStepSize); - add(reg_tmp1, 1 * BKStepSize); - add(reg_iterk, 1 * KTILE); - L(".kloopend"); - cmp(reg_iterk, reg_ksize); // k iteration variable - jb(".kloop"); - outLocalLabel(); - } - - void generate_fma(int _mtile, int _NRegs, int _ktile, const Xbyak::Reg64& _reg_matBptr) { - for (int kk = 0; kk < _ktile; kk++) { - lea(reg_tmp, ptr[reg_matAptr + kk * AKStepSize]); - for (int i = 0; i < _NRegs; i++) { - vmovups(Xbyak::Ymm(BReg + i), ptr[_reg_matBptr + kk * BKStepSize + i * VecBytes]); - } - for (int mm = 0; mm < _mtile; mm++) { - vbroadcastss(Xbyak::Ymm(AReg), ptr[reg_tmp]); - add(reg_tmp, reg_astep); - for (int i = 0; i < _NRegs; i++) { - vfmadd231ps(Xbyak::Ymm(CReg + mm * NRegs + i), Xbyak::Ymm(BReg + i), Xbyak::Ymm(AReg)); - } - } - } - } - - void write_back(int _mtile, int _NRegs) { - inLocalLabel(); - load32(reg_matCptr, ptr[parambase + OFFSET(kpos)]); - cmp(reg_matCptr, 0); - jg(".LACC", T_NEAR); - mov(reg_matCptr, ptr[parambase + OFFSET(matC)]); - lea(reg_matCptr, ptr[reg_matCptr + reg_itern * sizeof(CType)]); - load32(reg_cstep, ptr[parambase + OFFSET(cstep)]); - for (int i = 0; i < _mtile; i++) { - for (int j = 0; j < _NRegs; j++) { - vmovups(ptr[reg_matCptr + j * VecBytes], Xbyak::Ymm(CReg + i * NRegs + j)); - } - add(reg_matCptr, reg_cstep); - } - jmp(".LEND", T_NEAR); - L(".LACC"); - mov(reg_matCptr, ptr[parambase + OFFSET(matC)]); - lea(reg_matCptr, ptr[reg_matCptr + reg_itern * sizeof(CType)]); - load32(reg_cstep, ptr[parambase + OFFSET(cstep)]); - for (int i = 0; i < _mtile; i++) { - for (int j = 0; j < _NRegs; j++) { - vaddps(Xbyak::Ymm(CReg + i * NRegs + j), ptr[reg_matCptr + j * VecBytes]); - vmovups(ptr[reg_matCptr + j * VecBytes], Xbyak::Ymm(CReg + i * NRegs + j)); - } - add(reg_matCptr, reg_cstep); - } - L(".LEND"); - nop(); - outLocalLabel(); - } - - private: - Xbyak::Reg64 parambase; - Xbyak::Reg64 reg_matAptr; - Xbyak::Reg64 reg_matBptr; - Xbyak::Reg64 reg_matCptr; - Xbyak::Reg64 reg_ksize; - Xbyak::Reg64 reg_nsize; - Xbyak::Reg64 reg_cstep; - Xbyak::Reg64 reg_astep; - Xbyak::Reg64 reg_iterk; - Xbyak::Reg64 reg_itern; - Xbyak::Reg64 reg_tmp; - Xbyak::Reg64 reg_tmp1; - Xbyak::Reg64 reg_tmp2; - Xbyak::Reg64 reg_ret = rax; - Xbyak::Opmask msk_wr = k1; - }; + static inline uint32_t get_mask_val(uint32_t raw, uint32_t mask, uint32_t shift) { return (raw & mask) >> shift; } + static constexpr uint32_t make_core_id(uint32_t NTile, uint32_t PackRow, 
uint32_t CompType, uint32_t ISA) { + return (NTile << NTILE_SHIFT) | (PackRow << PACKROW_SHIFT) | (CompType << COMP_SHIFT) | (ISA << ISA_SHIFT); + } - public: - GemmCore_Row_NN_4x24_AVX2() { - for (int i = 0; i < MTILE; i++) { - mCodes[i].generate_code(i + 1); - } + static void parse_id(uint32_t id, uint32_t* vals) { + vals[0] = get_mask_val(id, NTILE_MASK, NTILE_SHIFT); + vals[1] = get_mask_val(id, PACKROW_MASK, PACKROW_SHIFT); + vals[2] = get_mask_val(id, COMP_MASK, COMP_SHIFT); + vals[3] = get_mask_val(id, ISA_MASK, ISA_SHIFT); } - void forward(float* matA, float* matB, float* matC, int _m, int _n, int _k, int _astride, int _bstride, int _cstride, - int kpos, void* tmpcache, size_t cachesize) { - auto param = params{matA, matB, matC, _k, _n, _astride, _bstride, _cstride, kpos}; - if (_m <= MTILE) { - mCodes[_m - 1].mKernel(¶m); - } else { - assert(0); - } + static const char* to_str(uint32_t id) { + static char tmp[128]; + uint32_t vals[4]; + parse_id(id, vals); + sprintf(tmp, "N%d_PACK%d_COMP%d_ISA%d", vals[0], vals[1], vals[2], vals[3]); + return tmp; } - private: - std::array mCodes; + static inline size_t get_bsize(uint32_t id) { + auto packrow = get_mask_val(id, PACKROW_MASK, PACKROW_SHIFT); + return size_t(4 / packrow); + } }; -class GemmCore_Row_NN_2x48_AVX2 { +namespace code { + +template +class Avx2N8P1 : protected jblas::xbyak::JitAvx2 { public: - struct params { - float *matA, *matB, *matC; - int k, nsize; - int astep, bstep, cstep; - int kpos; - }; - typedef long long (*func_t)(params*); + static int constexpr RegLen = 8, PackRow = 1; + static_assert(_NTILE % RegLen == 0); + static int constexpr NRegs = _NTILE / RegLen; + static int constexpr MRegs = _MTILE == 0 ? (RegCount - 1) / NRegs : _MTILE; + static_assert(NRegs * MRegs <= RegCount - 1); + static int constexpr NTILE = RegLen * NRegs, MTILE = MRegs, KTILE = 1; + static int constexpr KUNROLL = 2; + static uint32_t constexpr ISA = (uint32_t)JBLAS_ISA::JblasAVX2; + static uint32_t constexpr COMPUTE = (uint32_t)CompType::COMP_FP32; typedef float AType; typedef float BType; typedef float CType; - static JBLAS_GEMM_CORE constexpr TYPE = JBLAS_GEMM_CORE::AVX2_2X48; - static JBLAS_ISA constexpr ISA = - utils::gemm_core_mask(); - static int constexpr NTILE = 48, MTILE = 2, KTILE = 4 / sizeof(BType); - static int constexpr KUNROLL = 2; - static int constexpr PACK_ROW = 1; - static int constexpr PREFERED_N = 144; - class MicroKernel : protected jblas::xbyak::JitAvx2 { - public: - MicroKernel() {} - static int constexpr VecBytes = 32; - static int constexpr VecElements = VecBytes / sizeof(CType); - int CRegCount = 12, BRegCount = 1, ARegCount = 2; - int CReg = 0, BReg = 12, AReg = 13, TmpReg = BReg; - int const NRegs = NTILE / VecElements; - static int constexpr BKStepSize = KTILE * NTILE * sizeof(BType); - static int constexpr AKStepSize = KTILE * sizeof(AType); - - void generate_code(int _mtile) { - reset(); - generate_mtile(_mtile); - ready(); - mKernel = getCode(); - } - func_t mKernel = nullptr; - - protected: - void generate_mtile(int _mtile) { - CRegCount = _mtile * NRegs; - BRegCount = 1; - ARegCount = _mtile; - BReg = CReg + CRegCount; - AReg = BReg + BRegCount; - TmpReg = AReg + ARegCount; - inLocalLabel(); // use local label for multiple instance - Xbyak::util::StackFrame st(this, 1, 11, 16 * 10); - parambase = st.p[0]; - reg_matAptr = st.t[0]; - reg_matBptr = st.t[1]; - reg_matCptr = st.t[0]; - reg_ksize = st.t[2]; - reg_nsize = st.t[9]; - reg_cstep = st.t[3]; - reg_astep = st.t[5]; - reg_iterk = st.t[4]; - reg_itern = 
st.t[7]; - reg_tmp = st.t[6]; - reg_tmp1 = st.t[8]; - reg_tmp2 = st.t[10]; - reg_ret = rax; - - vreg_push(rsp); - - mov(reg_matBptr, ptr[parambase + OFFSET(matB)]); - load32(reg_ksize, ptr[parambase + OFFSET(k)]); - load32(reg_nsize, ptr[parambase + OFFSET(nsize)]); - load32(reg_astep, ptr[parambase + OFFSET(astep)]); - - xor_(reg_itern, reg_itern); - L(".nloop"); - for (int i = 0; i < _mtile; i++) { - for (int j = 0; j < NRegs; j++) { - vpxor(Xbyak::Ymm(CReg + i * NRegs + j), Xbyak::Ymm(CReg + i * NRegs + j), Xbyak::Ymm(CReg + i * NRegs + j)); - } - } - mov(reg_matAptr, ptr[parambase + OFFSET(matA)]); - mov(reg_tmp1, reg_matBptr); - - xor_(reg_iterk, reg_iterk); - - generate_kloop(_mtile, NRegs); - write_back(_mtile, NRegs); - load32(reg_tmp, ptr[parambase + OFFSET(bstep)]); - imul(reg_tmp, reg_tmp, NTILE); - add(reg_matBptr, reg_tmp); - add(reg_itern, NTILE); - - cmp(reg_itern, reg_nsize); - jb(".nloop"); - - mov(reg_ret, 0); - vreg_pop(rsp); - outLocalLabel(); // end of local label - } - - void generate_kloop(int _mtile, int _nregs) { - inLocalLabel(); - L(".kloop"); - mov(reg_tmp, reg_ksize); - sub(reg_tmp, reg_iterk); - cmp(reg_tmp, KUNROLL * KTILE); - jl(".k1loop", T_NEAR); - generate_fma(_mtile, _nregs, KUNROLL, reg_tmp1); - add(reg_matAptr, KUNROLL * AKStepSize); - add(reg_tmp1, KUNROLL * BKStepSize); - add(reg_iterk, KUNROLL * KTILE); - jmp(".kloopend", T_NEAR); - - L(".k1loop"); - generate_fma(_mtile, _nregs, 1, reg_tmp1); - add(reg_matAptr, 1 * AKStepSize); - add(reg_tmp1, 1 * BKStepSize); - add(reg_iterk, 1 * KTILE); - L(".kloopend"); - cmp(reg_iterk, reg_ksize); // k iteration variable - jb(".kloop"); - outLocalLabel(); - } - - void generate_fma(int _mtile, int _NRegs, int _ktile, const Xbyak::Reg64& _reg_matBptr) { - for (int kk = 0; kk < _ktile; kk++) { - lea(reg_tmp, ptr[reg_matAptr + kk * AKStepSize]); - for (int i = 0; i < _mtile; ++i) - vbroadcastss(Xbyak::Ymm(AReg + i), ptr[reg_matAptr + kk * AKStepSize + reg_astep * i]); - for (int j = 0; j < _NRegs; j++) { - vmovups(Xbyak::Ymm(BReg), ptr[_reg_matBptr + kk * BKStepSize + j * VecBytes]); - for (int i = 0; i < _mtile; ++i) - vfmadd231ps(Xbyak::Ymm(CReg + i * NRegs + j), Xbyak::Ymm(BReg), Xbyak::Ymm(AReg + i)); - } - } - } - - void write_back(int _mtile, int _NRegs) { - inLocalLabel(); - load32(reg_matCptr, ptr[parambase + OFFSET(kpos)]); - cmp(reg_matCptr, 0); - jg(".LACC", T_NEAR); - mov(reg_matCptr, ptr[parambase + OFFSET(matC)]); - lea(reg_matCptr, ptr[reg_matCptr + reg_itern * sizeof(CType)]); - load32(reg_cstep, ptr[parambase + OFFSET(cstep)]); - for (int i = 0; i < _mtile; i++) { - for (int j = 0; j < _NRegs; j++) { - vmovups(ptr[reg_matCptr + j * VecBytes], Xbyak::Ymm(CReg + i * NRegs + j)); - } - add(reg_matCptr, reg_cstep); - } - jmp(".LEND", T_NEAR); - L(".LACC"); - mov(reg_matCptr, ptr[parambase + OFFSET(matC)]); - lea(reg_matCptr, ptr[reg_matCptr + reg_itern * sizeof(CType)]); - load32(reg_cstep, ptr[parambase + OFFSET(cstep)]); - for (int i = 0; i < _mtile; i++) { - for (int j = 0; j < _NRegs; j++) { - vaddps(Xbyak::Ymm(CReg + i * NRegs + j), ptr[reg_matCptr + j * VecBytes]); - vmovups(ptr[reg_matCptr + j * VecBytes], Xbyak::Ymm(CReg + i * NRegs + j)); - } - add(reg_matCptr, reg_cstep); - } - L(".LEND"); - nop(); - outLocalLabel(); - } - - private: - Xbyak::Reg64 parambase; - Xbyak::Reg64 reg_matAptr; - Xbyak::Reg64 reg_matBptr; - Xbyak::Reg64 reg_matCptr; - Xbyak::Reg64 reg_ksize; - Xbyak::Reg64 reg_nsize; - Xbyak::Reg64 reg_cstep; - Xbyak::Reg64 reg_astep; - Xbyak::Reg64 reg_iterk; - Xbyak::Reg64 
reg_itern; - Xbyak::Reg64 reg_tmp; - Xbyak::Reg64 reg_tmp1; - Xbyak::Reg64 reg_tmp2; - Xbyak::Reg64 reg_ret = rax; - Xbyak::Opmask msk_wr = k1; + struct params { + AType* matA; + int astride; + BType* matB; + int bstride; + CType* matC; + int cstride; + int k; + int n; + int init; }; + typedef long long (*func_t)(params*); - public: - GemmCore_Row_NN_2x48_AVX2() { - for (int i = 0; i < MTILE; i++) { - mCodes[i].generate_code(i + 1); + int CRegCount = 0, BRegCount = 0, ARegCount = 0, TmpRegCount = 0; + int CReg = 0, BReg = 0, AReg = 0, TmpReg = 0; + static int constexpr BKStepSize = KTILE * NTILE * sizeof(BType); + static int constexpr AKStepSize = KTILE * sizeof(AType); + + void generate_code(int _mtile) { + assign_regs(); + reset(); + generate_mtile(_mtile); + ready(); + mKernel = getCode(); + } + func_t mKernel = nullptr; + + protected: + Xbyak::Reg64 parambase; + Xbyak::Reg64 reg_matAptr; + Xbyak::Reg64 reg_matBptr; + Xbyak::Reg64 reg_matCptr; + Xbyak::Reg64 reg_ksize; + Xbyak::Reg64 reg_nsize; + Xbyak::Reg64 reg_cstride; + Xbyak::Reg64 reg_astride; + Xbyak::Reg64 reg_iterk; + Xbyak::Reg64 reg_itern; + Xbyak::Reg64 reg_tmp; + Xbyak::Reg64 reg_tmp1; + Xbyak::Reg64 reg_tmp2; + Xbyak::Reg64 reg_ret = rax; + Xbyak::Opmask msk_wr = k1; + + void assign_regs() { + CRegCount = MRegs * NRegs; + ARegCount = 1; + BRegCount = RegCount - ARegCount - CRegCount; + if (BRegCount < NRegs) { + BRegCount = 0; + ARegCount = BRegCount + 1; + } + if (BRegCount > NRegs) { + BRegCount = NRegs; } + CReg = 0; + BReg = CReg + CRegCount; + AReg = BReg + BRegCount; + TmpReg = AReg + ARegCount; + assert(TmpReg <= RegCount); + TmpRegCount = RegCount - TmpReg; } - void forward(float* matA, float* matB, float* matC, int _m, int _n, int _k, int _astride, int _bstride, int _cstride, - int kpos, void* tmpcache, size_t cachesize) { - auto param = params{matA, matB, matC, _k, _n, _astride, _bstride, _cstride, kpos}; - if (_m <= MTILE) { - mCodes[_m - 1].mKernel(¶m); - } else { - assert(0); - } + void generate_mtile(int _mtile) { + inLocalLabel(); // use local label for multiple instance + Xbyak::util::StackFrame st(this, 1, 10, 16 * 10); + parambase = st.p[0]; + reg_matAptr = st.t[0]; + reg_matBptr = st.t[1]; + reg_matCptr = st.t[0]; + reg_ksize = st.t[2]; + reg_astride = st.t[3]; + reg_cstride = st.t[3]; + reg_iterk = st.t[4]; + reg_tmp = st.t[5]; + reg_tmp1 = st.t[6]; + reg_tmp2 = st.t[7]; + reg_nsize = st.t[8]; + reg_itern = st.t[9]; + reg_ret = rax; + + vreg_push(rsp); + + load32(reg_ksize, ptr[parambase + OFFSET(k)]); + load32(reg_nsize, ptr[parambase + OFFSET(n)]); + xor_(reg_itern, reg_itern); + L(".nloop"); + init_regs(_mtile); + mov(reg_matAptr, ptr[parambase + OFFSET(matA)]); + load32(reg_astride, ptr[parambase + OFFSET(astride)]); + mov(reg_matBptr, ptr[parambase + OFFSET(matB)]); + load32(reg_tmp, ptr[parambase + OFFSET(bstride)]); + imul(reg_tmp, reg_itern); + lea(reg_matBptr, ptr[reg_matBptr + reg_tmp]); + xor_(reg_iterk, reg_iterk); + generate_kloop(_mtile); + write_back(_mtile); + add(reg_itern, NTILE); + cmp(reg_itern, reg_nsize); + jb(".nloop"); + mov(reg_ret, 0); + vreg_pop(rsp); + + outLocalLabel(); // end of local label } - private: - std::array mCodes; -}; + void generate_kloop(int _mtile) { + inLocalLabel(); + mov(reg_tmp, reg_ksize); + padto_le(reg_tmp, KUNROLL * KTILE); + cmp(reg_tmp, 0); + jz(".kloop", T_NEAR); + L(".unkloop"); + generate_fma(_mtile, KUNROLL); + add(reg_matAptr, KUNROLL * AKStepSize); + add(reg_matBptr, KUNROLL * BKStepSize); + add(reg_iterk, KUNROLL * KTILE); + cmp(reg_iterk, 
reg_tmp); // k iteration variable + jb(".unkloop"); + cmp(reg_tmp, reg_ksize); + jge(".kend", T_NEAR); + L(".kloop"); + generate_fma(_mtile, 1); + add(reg_matAptr, 1 * AKStepSize); + add(reg_matBptr, 1 * BKStepSize); + add(reg_iterk, 1 * KTILE); + cmp(reg_iterk, reg_ksize); // k iteration variable + jb(".kloop"); + L(".kend"); + outLocalLabel(); + } -class GemmCore_Row_NN_4x24_AVX_VNNI { - public: - struct params { - uint8_t* matA; - int8_t* matB; - int32_t* matC; - int k, nsize; - int astep, bstep, cstep; - int kpos; - }; - typedef long long (*func_t)(params*); - typedef uint8_t AType; - typedef int8_t BType; - typedef int32_t CType; - static JBLAS_GEMM_CORE constexpr TYPE = JBLAS_GEMM_CORE::AVX_VNNI_4x24; - static JBLAS_ISA constexpr ISA = - utils::gemm_core_mask(); - static int constexpr NTILE = 24, MTILE = 4, KTILE = 4 / sizeof(BType); - static int constexpr PACK_ROW = KTILE; - static int constexpr KUNROLL = 2; - static int constexpr PREFERED_N = 192; - - class MicroKernel : protected jblas::xbyak::JitAvxvnni { - public: - MicroKernel() {} - int CRegCount = 12, BRegCount = 1, ARegCount = 1; - int CReg = 0, BReg = 12, AReg = 13, TmpReg = 14; - int const NRegs = 3; - static int constexpr BKStepSize = KTILE * NTILE * sizeof(BType); - static int constexpr AKStepSize = KTILE * sizeof(AType); - static int constexpr VecBytes = 32; - - void generate_code(int _mtile) { - reset(); - generate_mtile(_mtile); - ready(); - mKernel = getCode(); - } - func_t mKernel = nullptr; - - protected: - void generate_mtile(int _mtile) { - CRegCount = _mtile * NRegs; - BReg = CReg + CRegCount; - AReg = BReg + BRegCount; - TmpReg = AReg + ARegCount; - inLocalLabel(); // use local label for multiple instance - Xbyak::util::StackFrame st(this, 1, 11, 16 * 10); - parambase = st.p[0]; - reg_matAptr = st.t[0]; - reg_matBptr = st.t[1]; - reg_matCptr = st.t[0]; - reg_ksize = st.t[2]; - reg_nsize = st.t[9]; - reg_cstep = st.t[3]; - reg_astep = st.t[5]; - reg_iterk = st.t[4]; - reg_itern = st.t[7]; - reg_tmp = st.t[6]; - reg_tmp1 = st.t[8]; - reg_tmp2 = st.t[10]; - reg_ret = rax; - - vreg_push(rsp); - - mov(reg_matBptr, ptr[parambase + OFFSET(matB)]); - load32(reg_ksize, ptr[parambase + OFFSET(k)]); - load32(reg_nsize, ptr[parambase + OFFSET(nsize)]); - load32(reg_astep, ptr[parambase + OFFSET(astep)]); - - xor_(reg_itern, reg_itern); - L(".nloop"); - for (int i = 0; i < _mtile; i++) { - for (int j = 0; j < NRegs; j++) { - vxorps(Xbyak::Ymm(CReg + i * NRegs + j), Xbyak::Ymm(CReg + i * NRegs + j), Xbyak::Ymm(CReg + i * NRegs + j)); + void generate_fma(int _mtile, int _ktile) { + for (int kk = 0; kk < _ktile; kk++) { + lea(reg_tmp1, ptr[reg_matAptr + kk * AKStepSize]); + if (BRegCount == NRegs) { + for (int i = 0; i < NRegs; i++) { + vmovups(vreg_t(BReg + i), ptr[reg_matBptr + kk * BKStepSize + i * VecBytes]); } - } - mov(reg_matAptr, ptr[parambase + OFFSET(matA)]); - mov(reg_tmp1, reg_matBptr); - - xor_(reg_iterk, reg_iterk); - - mov(reg_tmp, reg_nsize); - sub(reg_tmp, reg_itern); - cmp(reg_tmp, NTILE); - - generate_kloop(_mtile, NRegs); - write_back(_mtile, NRegs); - load32(reg_tmp, ptr[parambase + OFFSET(bstep)]); - imul(reg_tmp, reg_tmp, NTILE); - add(reg_matBptr, reg_tmp); - add(reg_itern, NTILE); - - cmp(reg_itern, reg_nsize); - jb(".nloop"); - - mov(reg_ret, 0); - vreg_pop(rsp); - - outLocalLabel(); // end of local label - } - - void generate_kloop(int _mtile, int _nregs) { - inLocalLabel(); - L(".kloop"); - mov(reg_tmp, reg_ksize); - sub(reg_tmp, reg_iterk); - cmp(reg_tmp, KTILE * KUNROLL); - jl(".k1loop", T_NEAR); 
- generate_fma(_mtile, _nregs, KUNROLL, reg_tmp1); - add(reg_matAptr, AKStepSize * KUNROLL); - add(reg_tmp1, BKStepSize * KUNROLL); - add(reg_iterk, KTILE * KUNROLL); - jmp(".kloopend", T_NEAR); - - L(".k1loop"); - generate_fma(_mtile, _nregs, 1, reg_tmp1); - add(reg_matAptr, 1 * AKStepSize); - add(reg_tmp1, 1 * BKStepSize); - add(reg_iterk, 1 * KTILE); - L(".kloopend"); - cmp(reg_iterk, reg_ksize); // k iteration variable - jb(".kloop"); - outLocalLabel(); - } - - void generate_fma(int _mtile, int _NRegs, int _kunroll, const Xbyak::Reg64& _reg_matBptr) { - for (int kk = 0; kk < _kunroll; kk++) { - lea(reg_tmp, ptr[reg_matAptr + kk * AKStepSize]); for (int mm = 0; mm < _mtile; mm++) { - vpbroadcastd(Xbyak::Ymm(AReg), ptr[reg_tmp]); - add(reg_tmp, reg_astep); - for (int i = 0; i < _NRegs; i++) { - vpdpbusds_vex(Xbyak::Ymm(CReg + mm * NRegs + i), Xbyak::Ymm(AReg), - ptr[_reg_matBptr + kk * BKStepSize + i * VecBytes]); + vbroadcastss(vreg_t(AReg), ptr[reg_tmp1]); + add(reg_tmp1, reg_astride); + for (int i = 0; i < NRegs; i++) { + vfmadd231ps(vreg_t(CReg + mm * NRegs + i), vreg_t(AReg), vreg_t(BReg + i)); } } - } - } - - void write_back(int _mtile, int _NRegs) { - inLocalLabel(); - load32(reg_matCptr, ptr[parambase + OFFSET(kpos)]); - cmp(reg_matCptr, 0); - jg(".LACC", T_NEAR); - mov(reg_matCptr, ptr[parambase + OFFSET(matC)]); - lea(reg_matCptr, ptr[reg_matCptr + reg_itern * sizeof(CType)]); - load32(reg_cstep, ptr[parambase + OFFSET(cstep)]); - for (int i = 0; i < _mtile; i++) { - for (int j = 0; j < _NRegs; j++) { - vmovups(ptr[reg_matCptr + j * VecBytes], Xbyak::Ymm(CReg + i * NRegs + j)); - } - add(reg_matCptr, reg_cstep); - } - jmp(".LEND", T_NEAR); - L(".LACC"); - mov(reg_matCptr, ptr[parambase + OFFSET(matC)]); - lea(reg_matCptr, ptr[reg_matCptr + reg_itern * sizeof(CType)]); - load32(reg_cstep, ptr[parambase + OFFSET(cstep)]); - for (int i = 0; i < _mtile; i++) { - for (int j = 0; j < _NRegs; j++) { - vpaddd(Xbyak::Ymm(CReg + i * NRegs + j), Xbyak::Ymm(CReg + i * NRegs + j), ptr[reg_matCptr + j * VecBytes]); - vmovups(ptr[reg_matCptr + j * VecBytes], Xbyak::Ymm(CReg + i * NRegs + j)); + } else if (BRegCount == 0) { + for (int mm = 0; mm < _mtile; mm += ARegCount) { + int mm_re = utils::remainsize(mm, _mtile, ARegCount); + for (int imm = 0; imm < mm_re; imm++) { + vbroadcastss(vreg_t(AReg + imm), ptr[reg_tmp1]); + add(reg_tmp1, reg_astride); + for (int i = 0; i < NRegs; i++) { + vfmadd231ps(vreg_t(CReg + mm * NRegs + i), vreg_t(AReg + imm), + ptr[reg_matBptr + kk * BKStepSize + i * VecBytes]); + } + } } - add(reg_matCptr, reg_cstep); + } else { + assert(0); } - L(".LEND"); - nop(); - outLocalLabel(); - } - - private: - Xbyak::Reg64 parambase; - Xbyak::Reg64 reg_matAptr; - Xbyak::Reg64 reg_matBptr; - Xbyak::Reg64 reg_matCptr; - Xbyak::Reg64 reg_ksize; - Xbyak::Reg64 reg_nsize; - Xbyak::Reg64 reg_cstep; - Xbyak::Reg64 reg_astep; - Xbyak::Reg64 reg_iterk; - Xbyak::Reg64 reg_itern; - Xbyak::Reg64 reg_tmp; - Xbyak::Reg64 reg_tmp1; - Xbyak::Reg64 reg_tmp2; - Xbyak::Reg64 reg_ret = rax; - }; - - public: - GemmCore_Row_NN_4x24_AVX_VNNI() { - for (int i = 0; i < MTILE; i++) { - mCodes[i].generate_code(i + 1); } } - void forward(AType* matA, BType* matB, CType* matC, int _m, int _n, int _k, int _astride, int _bstride, int _cstride, - int kpos, void* tmpcache, size_t cachesize) { - auto param = params{matA, matB, matC, _k, _n, _astride, _bstride, _cstride, kpos}; - if (_m <= MTILE) { - mCodes[_m - 1].mKernel(¶m); - } else { - assert(0); - } - } - - private: - std::array mCodes; -}; - -class 
GemmCore_Row_NN_2x48_AVX_VNNI { - public: - struct params { - uint8_t* matA; - int8_t* matB; - int32_t* matC; - int k, nsize; - int astep, bstep, cstep; - int kpos; - }; - typedef long long (*func_t)(params*); - typedef uint8_t AType; - typedef int8_t BType; - typedef int32_t CType; - static JBLAS_GEMM_CORE constexpr TYPE = JBLAS_GEMM_CORE::AVX_VNNI_2x48; - static JBLAS_ISA constexpr ISA = - utils::gemm_core_mask(); - static int constexpr NTILE = 48, MTILE = 2, KTILE = 4 / sizeof(BType); - static int constexpr PACK_ROW = KTILE; - static int constexpr KUNROLL = 2; - static int constexpr PREFERED_N = 192; - - class MicroKernel : protected jblas::xbyak::JitAvxvnni { - public: - MicroKernel() {} - int CRegCount = 12, BRegCount = 1, ARegCount = 1; - int CReg = 0, BReg = 12, AReg = 13, TmpReg = 14; - int const NRegs = 6; - static int constexpr BKStepSize = KTILE * NTILE * sizeof(BType); - static int constexpr AKStepSize = KTILE * sizeof(AType); - static int constexpr VecBytes = 32; - - void generate_code(int _mtile) { - reset(); - generate_mtile(_mtile); - ready(); - mKernel = getCode(); - } - func_t mKernel = nullptr; - - protected: - void generate_mtile(int _mtile) { - CRegCount = _mtile * NRegs; - BReg = CReg + CRegCount; - AReg = BReg + BRegCount; - TmpReg = AReg + ARegCount; - inLocalLabel(); // use local label for multiple instance - Xbyak::util::StackFrame st(this, 1, 11, 16 * 10); - parambase = st.p[0]; - reg_matAptr = st.t[0]; - reg_matBptr = st.t[1]; - reg_matCptr = st.t[0]; - reg_ksize = st.t[2]; - reg_nsize = st.t[9]; - reg_cstep = st.t[3]; - reg_astep = st.t[5]; - reg_iterk = st.t[4]; - reg_itern = st.t[7]; - reg_tmp = st.t[6]; - reg_tmp1 = st.t[8]; - reg_tmp2 = st.t[10]; - reg_ret = rax; - - vreg_push(rsp); - - mov(reg_matBptr, ptr[parambase + OFFSET(matB)]); - load32(reg_ksize, ptr[parambase + OFFSET(k)]); - load32(reg_nsize, ptr[parambase + OFFSET(nsize)]); - load32(reg_astep, ptr[parambase + OFFSET(astep)]); - - xor_(reg_itern, reg_itern); - L(".nloop"); - for (int i = 0; i < _mtile; i++) { - for (int j = 0; j < NRegs; j++) { - vxorps(Xbyak::Ymm(CReg + i * NRegs + j), Xbyak::Ymm(CReg + i * NRegs + j), Xbyak::Ymm(CReg + i * NRegs + j)); - } - } - mov(reg_matAptr, ptr[parambase + OFFSET(matA)]); - mov(reg_tmp1, reg_matBptr); - - xor_(reg_iterk, reg_iterk); - - mov(reg_tmp, reg_nsize); - sub(reg_tmp, reg_itern); - cmp(reg_tmp, NTILE); - - generate_kloop(_mtile, NRegs); - write_back(_mtile, NRegs); - load32(reg_tmp, ptr[parambase + OFFSET(bstep)]); - imul(reg_tmp, reg_tmp, NTILE); - add(reg_matBptr, reg_tmp); - add(reg_itern, NTILE); - - cmp(reg_itern, reg_nsize); - jb(".nloop"); - - mov(reg_ret, 0); - vreg_pop(rsp); - - outLocalLabel(); // end of local label - } - - void generate_kloop(int _mtile, int _nregs) { - inLocalLabel(); - L(".kloop"); - mov(reg_tmp, reg_ksize); - sub(reg_tmp, reg_iterk); - cmp(reg_tmp, KTILE * KUNROLL); - jl(".k1loop", T_NEAR); - generate_fma(_mtile, _nregs, KUNROLL, reg_tmp1); - add(reg_matAptr, AKStepSize * KUNROLL); - add(reg_tmp1, BKStepSize * KUNROLL); - add(reg_iterk, KTILE * KUNROLL); - jmp(".kloopend", T_NEAR); - - L(".k1loop"); - generate_fma(_mtile, _nregs, 1, reg_tmp1); - add(reg_matAptr, 1 * AKStepSize); - add(reg_tmp1, 1 * BKStepSize); - add(reg_iterk, 1 * KTILE); - L(".kloopend"); - cmp(reg_iterk, reg_ksize); // k iteration variable - jb(".kloop"); - outLocalLabel(); - } - - void generate_fma(int _mtile, int _NRegs, int _kunroll, const Xbyak::Reg64& _reg_matBptr) { - for (int kk = 0; kk < _kunroll; kk++) { - lea(reg_tmp, ptr[reg_matAptr + kk 
* AKStepSize]); - for (int mm = 0; mm < _mtile; mm++) { - vpbroadcastd(Xbyak::Ymm(AReg), ptr[reg_tmp]); - add(reg_tmp, reg_astep); - for (int i = 0; i < _NRegs; i++) { - vpdpbusds_vex(Xbyak::Ymm(CReg + mm * NRegs + i), Xbyak::Ymm(AReg), - ptr[_reg_matBptr + kk * BKStepSize + i * VecBytes]); - } - } + void init_regs(int _mtile) { + inLocalLabel(); + load32(reg_tmp, ptr[parambase + OFFSET(init)]); + cmp(reg_tmp, 0); + je(".read", T_NEAR); + for (int i = 0; i < _mtile; i++) { + for (int j = 0; j < NRegs; j++) { + vxor(vreg_t(CReg + i * NRegs + j), vreg_t(CReg + i * NRegs + j), vreg_t(CReg + i * NRegs + j)); } } - - void write_back(int _mtile, int _NRegs) { - inLocalLabel(); - load32(reg_matCptr, ptr[parambase + OFFSET(kpos)]); - cmp(reg_matCptr, 0); - jg(".LACC", T_NEAR); - mov(reg_matCptr, ptr[parambase + OFFSET(matC)]); - lea(reg_matCptr, ptr[reg_matCptr + reg_itern * sizeof(CType)]); - load32(reg_cstep, ptr[parambase + OFFSET(cstep)]); - for (int i = 0; i < _mtile; i++) { - for (int j = 0; j < _NRegs; j++) { - vmovups(ptr[reg_matCptr + j * VecBytes], Xbyak::Ymm(CReg + i * NRegs + j)); - } - add(reg_matCptr, reg_cstep); - } - jmp(".LEND", T_NEAR); - L(".LACC"); - mov(reg_matCptr, ptr[parambase + OFFSET(matC)]); - lea(reg_matCptr, ptr[reg_matCptr + reg_itern * sizeof(CType)]); - load32(reg_cstep, ptr[parambase + OFFSET(cstep)]); - for (int i = 0; i < _mtile; i++) { - for (int j = 0; j < _NRegs; j++) { - vpaddd(Xbyak::Ymm(CReg + i * NRegs + j), Xbyak::Ymm(CReg + i * NRegs + j), ptr[reg_matCptr + j * VecBytes]); - vmovups(ptr[reg_matCptr + j * VecBytes], Xbyak::Ymm(CReg + i * NRegs + j)); - } - add(reg_matCptr, reg_cstep); + jmp(".end", T_NEAR); + L(".read"); + mov(reg_matCptr, ptr[parambase + OFFSET(matC)]); + lea(reg_matCptr, ptr[reg_matCptr + reg_itern * sizeof(CType)]); + load32(reg_cstride, ptr[parambase + OFFSET(cstride)]); + for (int i = 0; i < _mtile; i++) { + for (int j = 0; j < NRegs; j++) { + vmovups(vreg_t(CReg + i * NRegs + j), ptr[reg_matCptr + j * VecBytes]); } - L(".LEND"); - nop(); - outLocalLabel(); - } - - private: - Xbyak::Reg64 parambase; - Xbyak::Reg64 reg_matAptr; - Xbyak::Reg64 reg_matBptr; - Xbyak::Reg64 reg_matCptr; - Xbyak::Reg64 reg_ksize; - Xbyak::Reg64 reg_nsize; - Xbyak::Reg64 reg_cstep; - Xbyak::Reg64 reg_astep; - Xbyak::Reg64 reg_iterk; - Xbyak::Reg64 reg_itern; - Xbyak::Reg64 reg_tmp; - Xbyak::Reg64 reg_tmp1; - Xbyak::Reg64 reg_tmp2; - Xbyak::Reg64 reg_ret = rax; - }; - - public: - GemmCore_Row_NN_2x48_AVX_VNNI() { - for (int i = 0; i < MTILE; i++) { - mCodes[i].generate_code(i + 1); + add(reg_matCptr, reg_cstride); } + L(".end"); + outLocalLabel(); } - void forward(AType* matA, BType* matB, CType* matC, int _m, int _n, int _k, int _astride, int _bstride, int _cstride, - int kpos, void* tmpcache, size_t cachesize) { - auto param = params{matA, matB, matC, _k, _n, _astride, _bstride, _cstride, kpos}; - if (_m <= MTILE) { - mCodes[_m - 1].mKernel(¶m); - } else { - assert(0); + void write_back(int _mtile) { + inLocalLabel(); + mov(reg_matCptr, ptr[parambase + OFFSET(matC)]); + load32(reg_cstride, ptr[parambase + OFFSET(cstride)]); + lea(reg_matCptr, ptr[reg_matCptr + reg_itern * sizeof(CType)]); + for (int i = 0; i < _mtile; i++) { + for (int j = 0; j < NRegs; j++) { + vmovups(ptr[reg_matCptr + j * VecBytes], vreg_t(CReg + i * NRegs + j)); + } + add(reg_matCptr, reg_cstride); } + outLocalLabel(); } - - private: - std::array mCodes; }; -class GemmCore_Row_NN_8x48_AVX512F { +template +class Avx512fN16P1 : protected jblas::xbyak::JitAvx512f { public: - struct params 
{ - float *matA, *matB, *matC; - int k, nsize; - int astep, bstep, cstep; - int kpos; - }; - typedef long long (*func_t)(params*); + static int constexpr RegLen = 16, PackRow = 1; + static_assert(_NTILE % RegLen == 0); + static int constexpr NRegs = _NTILE / RegLen; + static int constexpr MRegs = _MTILE == 0 ? (RegCount - 1) / NRegs : _MTILE; + static_assert(NRegs * MRegs <= RegCount - 1); + static int constexpr NTILE = RegLen * NRegs, MTILE = MRegs, KTILE = 1; + static int constexpr KUNROLL = 2; + static uint32_t constexpr ISA = (uint32_t)JBLAS_ISA::JblasAVX512F; + static uint32_t constexpr COMPUTE = (uint32_t)CompType::COMP_FP32; typedef float AType; typedef float BType; typedef float CType; - static JBLAS_GEMM_CORE constexpr TYPE = JBLAS_GEMM_CORE::AVX512F_8x48; - static JBLAS_ISA constexpr ISA = - utils::gemm_core_mask(); - static int constexpr NTILE = 48, MTILE = 8, KTILE = 4 / sizeof(BType); - static int constexpr KUNROLL = 2; - static int constexpr PACK_ROW = 1; - static int constexpr PREFERED_N = 144; - class MicroKernel : protected jblas::xbyak::JitAvx512f { - public: - MicroKernel() {} - int CRegCount = 24, BRegCount = 6, ARegCount = 1; - int CReg = 0, BReg = 24, AReg = 27, TmpReg = 28; - int const NRegs = 3; - static int constexpr BKStepSize = KTILE * NTILE * sizeof(BType); - static int constexpr AKStepSize = KTILE * sizeof(AType); - static int constexpr VecBytes = 64; - - void generate_code(int _mtile) { - reset(); - generate_mtile(_mtile); - ready(); - mKernel = getCode(); - } - func_t mKernel = nullptr; - protected: - void generate_mtile(int _mtile) { - CRegCount = _mtile * NRegs; + struct params { + AType* matA; + int astride; + BType* matB; + int bstride; + CType* matC; + int cstride; + int k; + int n; + int init; + }; + typedef long long (*func_t)(params*); + + int CRegCount = 0, BRegCount = 0, ARegCount = 0, TmpRegCount = 0; + int CReg = 0, BReg = 0, AReg = 0, TmpReg = 0; + static int constexpr BKStepSize = KTILE * NTILE * sizeof(BType); + static int constexpr AKStepSize = KTILE * sizeof(AType); + + void generate_code(int _mtile) { + assign_regs(); + reset(); + generate_mtile(_mtile); + ready(); + mKernel = getCode(); + } + func_t mKernel = nullptr; + + protected: + Xbyak::Reg64 parambase; + Xbyak::Reg64 reg_matAptr; + Xbyak::Reg64 reg_matBptr; + Xbyak::Reg64 reg_matCptr; + Xbyak::Reg64 reg_ksize; + Xbyak::Reg64 reg_nsize; + Xbyak::Reg64 reg_cstride; + Xbyak::Reg64 reg_astride; + Xbyak::Reg64 reg_iterk; + Xbyak::Reg64 reg_itern; + Xbyak::Reg64 reg_tmp; + Xbyak::Reg64 reg_tmp1; + Xbyak::Reg64 reg_tmp2; + Xbyak::Reg64 reg_ret = rax; + Xbyak::Opmask msk_wr = k1; + + void assign_regs() { + CRegCount = MRegs * NRegs; + ARegCount = 1; + BRegCount = RegCount - ARegCount - CRegCount; + if (BRegCount < NRegs) { + BRegCount = 0; + ARegCount = BRegCount + 1; + } + if (BRegCount > NRegs) { BRegCount = NRegs; - BReg = CReg + CRegCount; - AReg = BReg + BRegCount; - TmpReg = AReg + ARegCount; - inLocalLabel(); // use local label for multiple instance - Xbyak::util::StackFrame st(this, 1, 11, 16 * 10); - parambase = st.p[0]; - reg_matAptr = st.t[0]; - reg_matBptr = st.t[1]; - reg_matCptr = st.t[0]; - reg_ksize = st.t[2]; - reg_nsize = st.t[9]; - reg_cstep = st.t[3]; - reg_astep = st.t[5]; - reg_iterk = st.t[4]; - reg_itern = st.t[7]; - reg_tmp = st.t[6]; - reg_tmp1 = st.t[8]; - reg_tmp2 = st.t[10]; - reg_ret = rax; - - vreg_push(rsp); - - mov(reg_matBptr, ptr[parambase + OFFSET(matB)]); - load32(reg_ksize, ptr[parambase + OFFSET(k)]); - load32(reg_nsize, ptr[parambase + OFFSET(nsize)]); 
- load32(reg_astep, ptr[parambase + OFFSET(astep)]); - - xor_(reg_itern, reg_itern); - L(".nloop"); - for (int i = 0; i < _mtile; i++) { - for (int j = 0; j < NRegs; j++) { - vpxorq(Xbyak::Zmm(CReg + i * NRegs + j), Xbyak::Zmm(CReg + i * NRegs + j), Xbyak::Zmm(CReg + i * NRegs + j)); - } - } - mov(reg_matAptr, ptr[parambase + OFFSET(matA)]); - mov(reg_tmp1, reg_matBptr); - - xor_(reg_iterk, reg_iterk); - - mov(reg_tmp, reg_nsize); - sub(reg_tmp, reg_itern); - cmp(reg_tmp, NTILE); - jl(".n32", T_NEAR); - generate_kloop(_mtile, NRegs); - write_back(_mtile, NRegs); - load32(reg_tmp, ptr[parambase + OFFSET(bstep)]); - imul(reg_tmp, reg_tmp, NTILE); - add(reg_matBptr, reg_tmp); - add(reg_itern, NTILE); - jmp(".nend", T_NEAR); - - L(".n32"); - cmp(reg_tmp, 32); - jl(".n16", T_NEAR); - generate_kloop(_mtile, 2); - write_back(_mtile, 2); - add(reg_itern, 32); - add(reg_matBptr, 32 * sizeof(BType)); - jmp(".nend", T_NEAR); - - L(".n16"); - xor_(reg_iterk, reg_iterk); - generate_kloop(_mtile, 1); - write_back(_mtile, 1); - add(reg_itern, 16); - add(reg_matBptr, 16 * sizeof(BType)); - L(".nend"); - cmp(reg_itern, reg_nsize); - jb(".nloop"); - - mov(reg_ret, 0); - vreg_pop(rsp); - - outLocalLabel(); // end of local label } + CReg = 0; + BReg = CReg + CRegCount; + AReg = BReg + BRegCount; + TmpReg = AReg + ARegCount; + assert(TmpReg <= RegCount); + TmpRegCount = RegCount - TmpReg; + } - void generate_kloop(int _mtile, int _nregs) { - inLocalLabel(); - L(".kloop"); - mov(reg_tmp, reg_ksize); - sub(reg_tmp, reg_iterk); - cmp(reg_tmp, KUNROLL * KTILE); - jl(".k1loop", T_NEAR); - generate_fma(_mtile, _nregs, KUNROLL, reg_tmp1); - add(reg_matAptr, KUNROLL * AKStepSize); - add(reg_tmp1, KUNROLL * BKStepSize); - add(reg_iterk, KUNROLL * KTILE); - jmp(".kloopend", T_NEAR); - - L(".k1loop"); - generate_fma(_mtile, _nregs, 1, reg_tmp1); - add(reg_matAptr, 1 * AKStepSize); - add(reg_tmp1, 1 * BKStepSize); - add(reg_iterk, 1 * KTILE); - L(".kloopend"); - cmp(reg_iterk, reg_ksize); // k iteration variable - jb(".kloop"); - outLocalLabel(); - } + void generate_mtile(int _mtile) { + inLocalLabel(); // use local label for multiple instance + Xbyak::util::StackFrame st(this, 1, 10, 16 * 10); + parambase = st.p[0]; + reg_matAptr = st.t[0]; + reg_matBptr = st.t[1]; + reg_matCptr = st.t[0]; + reg_ksize = st.t[2]; + reg_astride = st.t[3]; + reg_cstride = st.t[3]; + reg_iterk = st.t[4]; + reg_tmp = st.t[5]; + reg_tmp1 = st.t[6]; + reg_tmp2 = st.t[7]; + reg_nsize = st.t[8]; + reg_itern = st.t[9]; + reg_ret = rax; + + vreg_push(rsp); + + load32(reg_ksize, ptr[parambase + OFFSET(k)]); + load32(reg_nsize, ptr[parambase + OFFSET(n)]); + xor_(reg_itern, reg_itern); + L(".nloop"); + init_regs(_mtile); + mov(reg_matAptr, ptr[parambase + OFFSET(matA)]); + load32(reg_astride, ptr[parambase + OFFSET(astride)]); + mov(reg_matBptr, ptr[parambase + OFFSET(matB)]); + load32(reg_tmp, ptr[parambase + OFFSET(bstride)]); + imul(reg_tmp, reg_itern); + lea(reg_matBptr, ptr[reg_matBptr + reg_tmp]); + xor_(reg_iterk, reg_iterk); + generate_kloop(_mtile); + write_back(_mtile); + add(reg_itern, NTILE); + cmp(reg_itern, reg_nsize); + jb(".nloop"); + mov(reg_ret, 0); + vreg_pop(rsp); + + outLocalLabel(); // end of local label + } - void generate_fma(int _mtile, int _NRegs, int _ktile, const Xbyak::Reg64& _reg_matBptr) { - for (int kk = 0; kk < _ktile; kk++) { - lea(reg_tmp, ptr[reg_matAptr + kk * AKStepSize]); - for (int i = 0; i < _NRegs; i++) { - vmovups(Xbyak::Zmm(BReg + i), ptr[_reg_matBptr + kk * BKStepSize + i * VecBytes]); + void 
generate_kloop(int _mtile) { + inLocalLabel(); + mov(reg_tmp, reg_ksize); + padto_le(reg_tmp, KUNROLL * KTILE); + cmp(reg_tmp, 0); + jz(".kloop", T_NEAR); + L(".unkloop"); + generate_fma(_mtile, KUNROLL); + add(reg_matAptr, KUNROLL * AKStepSize); + add(reg_matBptr, KUNROLL * BKStepSize); + add(reg_iterk, KUNROLL * KTILE); + cmp(reg_iterk, reg_tmp); // k iteration variable + jb(".unkloop"); + cmp(reg_tmp, reg_ksize); + jge(".kend", T_NEAR); + L(".kloop"); + generate_fma(_mtile, 1); + add(reg_matAptr, 1 * AKStepSize); + add(reg_matBptr, 1 * BKStepSize); + add(reg_iterk, 1 * KTILE); + cmp(reg_iterk, reg_ksize); // k iteration variable + jb(".kloop"); + L(".kend"); + outLocalLabel(); + } + + void generate_fma(int _mtile, int _ktile) { + for (int kk = 0; kk < _ktile; kk++) { + lea(reg_tmp1, ptr[reg_matAptr + kk * AKStepSize]); + if (BRegCount == NRegs) { + for (int i = 0; i < NRegs; i++) { + vmovups(vreg_t(BReg + i), ptr[reg_matBptr + kk * BKStepSize + i * VecBytes]); } for (int mm = 0; mm < _mtile; mm++) { - vbroadcastss(Xbyak::Zmm(AReg), ptr[reg_tmp]); - add(reg_tmp, reg_astep); - for (int i = 0; i < _NRegs; i++) { - vfmadd231ps(Xbyak::Zmm(CReg + mm * NRegs + i), Xbyak::Zmm(BReg + i), Xbyak::Zmm(AReg)); + vbroadcastss(vreg_t(AReg), ptr[reg_tmp1]); + add(reg_tmp1, reg_astride); + for (int i = 0; i < NRegs; i++) { + vfmadd231ps(vreg_t(CReg + mm * NRegs + i), vreg_t(AReg), vreg_t(BReg + i)); + } + } + } else if (BRegCount == 0) { + for (int mm = 0; mm < _mtile; mm += ARegCount) { + int mm_re = utils::remainsize(mm, _mtile, ARegCount); + for (int imm = 0; imm < mm_re; imm++) { + vbroadcastss(vreg_t(AReg + imm), ptr[reg_tmp1]); + add(reg_tmp1, reg_astride); + for (int i = 0; i < NRegs; i++) { + vfmadd231ps(vreg_t(CReg + mm * NRegs + i), vreg_t(AReg + imm), + ptr[reg_matBptr + kk * BKStepSize + i * VecBytes]); + } } } + } else { + assert(0); } } + } - void write_back(int _mtile, int _NRegs) { - inLocalLabel(); - load32(reg_matCptr, ptr[parambase + OFFSET(kpos)]); - cmp(reg_matCptr, 0); - jg(".LACC", T_NEAR); - mov(reg_matCptr, ptr[parambase + OFFSET(matC)]); - lea(reg_matCptr, ptr[reg_matCptr + reg_itern * sizeof(CType)]); - load32(reg_cstep, ptr[parambase + OFFSET(cstep)]); - for (int i = 0; i < _mtile; i++) { - for (int j = 0; j < _NRegs; j++) { - vmovups(ptr[reg_matCptr + j * VecBytes], Xbyak::Zmm(CReg + i * NRegs + j)); - } - add(reg_matCptr, reg_cstep); - } - jmp(".LEND", T_NEAR); - L(".LACC"); - mov(reg_matCptr, ptr[parambase + OFFSET(matC)]); - lea(reg_matCptr, ptr[reg_matCptr + reg_itern * sizeof(CType)]); - load32(reg_cstep, ptr[parambase + OFFSET(cstep)]); - for (int i = 0; i < _mtile; i++) { - for (int j = 0; j < _NRegs; j++) { - vaddps(Xbyak::Zmm(CReg + i * NRegs + j), ptr[reg_matCptr + j * VecBytes]); - vmovups(ptr[reg_matCptr + j * VecBytes], Xbyak::Zmm(CReg + i * NRegs + j)); - } - add(reg_matCptr, reg_cstep); + void init_regs(int _mtile) { + inLocalLabel(); + load32(reg_tmp, ptr[parambase + OFFSET(init)]); + cmp(reg_tmp, 0); + je(".read", T_NEAR); + for (int i = 0; i < _mtile; i++) { + for (int j = 0; j < NRegs; j++) { + vxor(vreg_t(CReg + i * NRegs + j), vreg_t(CReg + i * NRegs + j), vreg_t(CReg + i * NRegs + j)); } - L(".LEND"); - nop(); - outLocalLabel(); } - - private: - Xbyak::Reg64 parambase; - Xbyak::Reg64 reg_matAptr; - Xbyak::Reg64 reg_matBptr; - Xbyak::Reg64 reg_matCptr; - Xbyak::Reg64 reg_ksize; - Xbyak::Reg64 reg_nsize; - Xbyak::Reg64 reg_cstep; - Xbyak::Reg64 reg_astep; - Xbyak::Reg64 reg_iterk; - Xbyak::Reg64 reg_itern; - Xbyak::Reg64 reg_tmp; - Xbyak::Reg64 reg_tmp1; 
-    Xbyak::Reg64 reg_tmp2;
-    Xbyak::Reg64 reg_ret = rax;
-    Xbyak::Opmask msk_wr = k1;
-  };
-
- public:
-  GemmCore_Row_NN_8x48_AVX512F() {
-    for (int i = 0; i < MTILE; i++) {
-      mCodes[i].generate_code(i + 1);
+    jmp(".end", T_NEAR);
+    L(".read");
+    mov(reg_matCptr, ptr[parambase + OFFSET(matC)]);
+    lea(reg_matCptr, ptr[reg_matCptr + reg_itern * sizeof(CType)]);
+    load32(reg_cstride, ptr[parambase + OFFSET(cstride)]);
+    for (int i = 0; i < _mtile; i++) {
+      for (int j = 0; j < NRegs; j++) {
+        vmovups(vreg_t(CReg + i * NRegs + j), ptr[reg_matCptr + j * VecBytes]);
+      }
+      add(reg_matCptr, reg_cstride);
     }
+    L(".end");
+    outLocalLabel();
   }
 
-  void forward(float* matA, float* matB, float* matC, int _m, int _n, int _k, int _astride, int _bstride, int _cstride,
-               int kpos, void* tmpcache, size_t cachesize) {
-    auto param = params{matA, matB, matC, _k, _n, _astride, _bstride, _cstride, kpos};
-    if (_m <= MTILE) {
-      mCodes[_m - 1].mKernel(&param);
-    } else {
-      assert(0);
+  void write_back(int _mtile) {
+    inLocalLabel();
+    mov(reg_matCptr, ptr[parambase + OFFSET(matC)]);
+    load32(reg_cstride, ptr[parambase + OFFSET(cstride)]);
+    lea(reg_matCptr, ptr[reg_matCptr + reg_itern * sizeof(CType)]);
+    for (int i = 0; i < _mtile; i++) {
+      for (int j = 0; j < NRegs; j++) {
+        vmovups(ptr[reg_matCptr + j * VecBytes], vreg_t(CReg + i * NRegs + j));
+      }
+      add(reg_matCptr, reg_cstride);
     }
+    outLocalLabel();
   }
-
- private:
-  std::array<MicroKernel, MTILE> mCodes;
 };
-class GemmCore_Row_NN_8x64_AVX512_FP16 {
+template <int _NTILE, int _MTILE = 0>
+class Avx512fp16N32P1 : protected jblas::xbyak::JitAvx512_fp16 {
  public:
+  static int constexpr RegLen = 32, PackRow = 1;
+  static_assert(_NTILE % RegLen == 0);
+  static int constexpr NRegs = _NTILE / RegLen;
+  static int constexpr MRegs = _MTILE == 0 ? (RegCount - 1) / NRegs : _MTILE;
+  static_assert(NRegs * MRegs <= RegCount - 1);
+  static int constexpr NTILE = RegLen * NRegs, MTILE = MRegs, KTILE = 1;
+  static int constexpr KUNROLL = 2;
+  static uint32_t constexpr ISA = (uint32_t)JBLAS_ISA::JblasAVX512_FP16;
+  static uint32_t constexpr COMPUTE = (uint32_t)CompType::COMP_FP16_FP16;
   typedef utils::fp16 AType;
   typedef utils::fp16 BType;
   typedef utils::fp16 CType;
+
   struct params {
     AType* matA;
+    int astride;
     BType* matB;
+    int bstride;
     CType* matC;
-    int k, nsize;
-    int astep, bstep, cstep;
-    int kpos;
+    int cstride;
+    int k;
+    int n;
+    int init;
   };
   typedef long long (*func_t)(params*);
-  static JBLAS_GEMM_CORE constexpr TYPE = JBLAS_GEMM_CORE::AVX512_FP16_8x64;
-  static JBLAS_ISA constexpr ISA =
-      utils::gemm_core_mask();
-  static int constexpr NTILE = 64, MTILE = 12, KTILE = 1;
-  static int constexpr KUNROLL = 2;
-  static int constexpr PACK_ROW = 1;
-  static int constexpr PREFERED_N = 128;
-  class MicroKernel : protected jblas::xbyak::JitAvx512_fp16 {
-   public:
-    MicroKernel() {}
-    int CRegCount = 24, BRegCount = 2, ARegCount = 1;
-    int CReg = 0, BReg = 24, AReg = 26, TmpReg = 27;
-    int const NRegs = 2;
-    static int constexpr BKStepSize = KTILE * NTILE * sizeof(BType);
-    static int constexpr AKStepSize = KTILE * sizeof(AType);
-    static int constexpr VecBytes = 64;
-
-    void generate_code(int _mtile) {
-      reset();
-      generate_mtile(_mtile);
-      ready();
-      mKernel = getCode();
-    }
-    func_t mKernel = nullptr;
-
-   protected:
-    void generate_mtile(int _mtile) {
-      CRegCount = _mtile * NRegs;
+  int CRegCount = 0, BRegCount = 0, ARegCount = 0, TmpRegCount = 0;
+  int CReg = 0, BReg = 0, AReg = 0, TmpReg = 0;
+  static int constexpr BKStepSize = KTILE * NTILE * sizeof(BType);
+  static int constexpr AKStepSize = KTILE * sizeof(AType);
+
+  void
generate_code(int _mtile) { + assign_regs(); + reset(); + generate_mtile(_mtile); + ready(); + mKernel = getCode(); + } + func_t mKernel = nullptr; + + protected: + Xbyak::Reg64 parambase; + Xbyak::Reg64 reg_matAptr; + Xbyak::Reg64 reg_matBptr; + Xbyak::Reg64 reg_matCptr; + Xbyak::Reg64 reg_ksize; + Xbyak::Reg64 reg_nsize; + Xbyak::Reg64 reg_cstride; + Xbyak::Reg64 reg_astride; + Xbyak::Reg64 reg_iterk; + Xbyak::Reg64 reg_itern; + Xbyak::Reg64 reg_tmp; + Xbyak::Reg64 reg_tmp1; + Xbyak::Reg64 reg_tmp2; + Xbyak::Reg64 reg_ret = rax; + Xbyak::Opmask msk_wr = k1; + + void assign_regs() { + CRegCount = MRegs * NRegs; + ARegCount = 1; + BRegCount = RegCount - ARegCount - CRegCount; + if (BRegCount < NRegs) { + BRegCount = 0; + ARegCount = BRegCount + 1; + } + if (BRegCount > NRegs) { BRegCount = NRegs; - BReg = CReg + CRegCount; - AReg = BReg + BRegCount; - TmpReg = AReg + ARegCount; - inLocalLabel(); // use local label for multiple instance - Xbyak::util::StackFrame st(this, 1, 11, 16 * 10); - parambase = st.p[0]; - reg_matAptr = st.t[0]; - reg_matBptr = st.t[1]; - reg_matCptr = st.t[0]; - reg_ksize = st.t[2]; - reg_nsize = st.t[9]; - reg_cstep = st.t[3]; - reg_astep = st.t[5]; - reg_iterk = st.t[4]; - reg_itern = st.t[7]; - reg_tmp = st.t[6]; - reg_tmp1 = st.t[8]; - reg_tmp2 = st.t[10]; - reg_ret = rax; - - vreg_push(rsp); - - mov(reg_matBptr, ptr[parambase + OFFSET(matB)]); - load32(reg_ksize, ptr[parambase + OFFSET(k)]); - load32(reg_nsize, ptr[parambase + OFFSET(nsize)]); - load32(reg_astep, ptr[parambase + OFFSET(astep)]); - - xor_(reg_itern, reg_itern); - L(".nloop"); - for (int i = 0; i < _mtile; i++) { - for (int j = 0; j < NRegs; j++) { - vpxorq(Xbyak::Zmm(CReg + i * NRegs + j), Xbyak::Zmm(CReg + i * NRegs + j), Xbyak::Zmm(CReg + i * NRegs + j)); - } - } - mov(reg_matAptr, ptr[parambase + OFFSET(matA)]); - mov(reg_tmp1, reg_matBptr); - - xor_(reg_iterk, reg_iterk); - - mov(reg_tmp, reg_nsize); - sub(reg_tmp, reg_itern); - cmp(reg_tmp, NTILE); - jl(".n32", T_NEAR); - generate_kloop(_mtile, NRegs); - write_back(_mtile, NRegs); - load32(reg_tmp, ptr[parambase + OFFSET(bstep)]); - imul(reg_tmp, reg_tmp, NTILE); - add(reg_matBptr, reg_tmp); - add(reg_itern, NTILE); - jmp(".nend", T_NEAR); - - L(".n32"); - generate_kloop(_mtile, 1); - write_back(_mtile, 1); - add(reg_itern, 32); - add(reg_matBptr, 32 * sizeof(BType)); - - L(".nend"); - cmp(reg_itern, reg_nsize); - jb(".nloop"); - - mov(reg_ret, 0); - vreg_pop(rsp); - - outLocalLabel(); // end of local label } + CReg = 0; + BReg = CReg + CRegCount; + AReg = BReg + BRegCount; + TmpReg = AReg + ARegCount; + assert(TmpReg <= RegCount); + TmpRegCount = RegCount - TmpReg; + } - void generate_kloop(int _mtile, int _nregs) { - inLocalLabel(); - L(".kloop"); - mov(reg_tmp, reg_ksize); - sub(reg_tmp, reg_iterk); - cmp(reg_tmp, KUNROLL * KTILE); - jl(".k1loop", T_NEAR); - generate_fma(_mtile, _nregs, KUNROLL, reg_tmp1); - add(reg_matAptr, KUNROLL * AKStepSize); - add(reg_tmp1, KUNROLL * BKStepSize); - add(reg_iterk, KUNROLL * KTILE); - jmp(".kloopend", T_NEAR); - - L(".k1loop"); - generate_fma(_mtile, _nregs, 1, reg_tmp1); - add(reg_matAptr, 1 * AKStepSize); - add(reg_tmp1, 1 * BKStepSize); - add(reg_iterk, 1 * KTILE); - L(".kloopend"); - cmp(reg_iterk, reg_ksize); // k iteration variable - jb(".kloop"); - outLocalLabel(); - } + void generate_mtile(int _mtile) { + inLocalLabel(); // use local label for multiple instance + Xbyak::util::StackFrame st(this, 1, 10, 16 * 10); + parambase = st.p[0]; + reg_matAptr = st.t[0]; + reg_matBptr = st.t[1]; + 
reg_matCptr = st.t[0]; + reg_ksize = st.t[2]; + reg_astride = st.t[3]; + reg_cstride = st.t[3]; + reg_iterk = st.t[4]; + reg_tmp = st.t[5]; + reg_tmp1 = st.t[6]; + reg_tmp2 = st.t[7]; + reg_nsize = st.t[8]; + reg_itern = st.t[9]; + reg_ret = rax; + + vreg_push(rsp); + + load32(reg_ksize, ptr[parambase + OFFSET(k)]); + load32(reg_nsize, ptr[parambase + OFFSET(n)]); + xor_(reg_itern, reg_itern); + L(".nloop"); + init_regs(_mtile); + mov(reg_matAptr, ptr[parambase + OFFSET(matA)]); + load32(reg_astride, ptr[parambase + OFFSET(astride)]); + mov(reg_matBptr, ptr[parambase + OFFSET(matB)]); + load32(reg_tmp, ptr[parambase + OFFSET(bstride)]); + imul(reg_tmp, reg_itern); + lea(reg_matBptr, ptr[reg_matBptr + reg_tmp]); + xor_(reg_iterk, reg_iterk); + generate_kloop(_mtile); + write_back(_mtile); + add(reg_itern, NTILE); + cmp(reg_itern, reg_nsize); + jb(".nloop"); + mov(reg_ret, 0); + vreg_pop(rsp); + + outLocalLabel(); // end of local label + } - void generate_fma(int _mtile, int _NRegs, int _ktile, const Xbyak::Reg64& _reg_matBptr) { - for (int kk = 0; kk < _ktile; kk++) { - lea(reg_tmp, ptr[reg_matAptr + kk * AKStepSize]); - for (int i = 0; i < _NRegs; i++) { - vmovups(Xbyak::Zmm(BReg + i), ptr[_reg_matBptr + kk * BKStepSize + i * VecBytes]); + void generate_kloop(int _mtile) { + inLocalLabel(); + mov(reg_tmp, reg_ksize); + padto_le(reg_tmp, KUNROLL * KTILE); + cmp(reg_tmp, 0); + jz(".kloop", T_NEAR); + L(".unkloop"); + generate_fma(_mtile, KUNROLL); + add(reg_matAptr, KUNROLL * AKStepSize); + add(reg_matBptr, KUNROLL * BKStepSize); + add(reg_iterk, KUNROLL * KTILE); + cmp(reg_iterk, reg_tmp); // k iteration variable + jb(".unkloop"); + cmp(reg_tmp, reg_ksize); + jge(".kend", T_NEAR); + L(".kloop"); + generate_fma(_mtile, 1); + add(reg_matAptr, 1 * AKStepSize); + add(reg_matBptr, 1 * BKStepSize); + add(reg_iterk, 1 * KTILE); + cmp(reg_iterk, reg_ksize); // k iteration variable + jb(".kloop"); + L(".kend"); + outLocalLabel(); + } + + void generate_fma(int _mtile, int _ktile) { + for (int kk = 0; kk < _ktile; kk++) { + lea(reg_tmp1, ptr[reg_matAptr + kk * AKStepSize]); + if (BRegCount == NRegs) { + for (int i = 0; i < NRegs; i++) { + vmovups(vreg_t(BReg + i), ptr[reg_matBptr + kk * BKStepSize + i * VecBytes]); } for (int mm = 0; mm < _mtile; mm++) { - vpbroadcastw(Xbyak::Zmm(AReg), ptr[reg_tmp]); - add(reg_tmp, reg_astep); - for (int i = 0; i < _NRegs; i++) { - vfmadd231ph(Xbyak::Zmm(CReg + mm * NRegs + i), Xbyak::Zmm(BReg + i), Xbyak::Zmm(AReg)); + vpbroadcastw(vreg_t(AReg), ptr[reg_tmp1]); + add(reg_tmp1, reg_astride); + for (int i = 0; i < NRegs; i++) { + vfmadd231ph(vreg_t(CReg + mm * NRegs + i), vreg_t(AReg), vreg_t(BReg + i)); + } + } + } else if (BRegCount == 0) { + for (int mm = 0; mm < _mtile; mm += ARegCount) { + int mm_re = utils::remainsize(mm, _mtile, ARegCount); + for (int imm = 0; imm < mm_re; imm++) { + vpbroadcastw(vreg_t(AReg + imm), ptr[reg_tmp1]); + add(reg_tmp1, reg_astride); + for (int i = 0; i < NRegs; i++) { + vfmadd231ph(vreg_t(CReg + mm * NRegs + i), vreg_t(AReg + imm), + ptr[reg_matBptr + kk * BKStepSize + i * VecBytes]); + } } } + } else { + assert(0); } } + } - void write_back(int _mtile, int _NRegs) { - inLocalLabel(); - load32(reg_matCptr, ptr[parambase + OFFSET(kpos)]); - cmp(reg_matCptr, 0); - jg(".LACC", T_NEAR); - mov(reg_matCptr, ptr[parambase + OFFSET(matC)]); - lea(reg_matCptr, ptr[reg_matCptr + reg_itern * sizeof(CType)]); - load32(reg_cstep, ptr[parambase + OFFSET(cstep)]); - for (int i = 0; i < _mtile; i++) { - for (int j = 0; j < _NRegs; j++) { - 
vmovups(ptr[reg_matCptr + j * VecBytes], Xbyak::Zmm(CReg + i * NRegs + j)); - } - add(reg_matCptr, reg_cstep); + void init_regs(int _mtile) { + inLocalLabel(); + load32(reg_tmp, ptr[parambase + OFFSET(init)]); + cmp(reg_tmp, 0); + je(".read", T_NEAR); + for (int i = 0; i < _mtile; i++) { + for (int j = 0; j < NRegs; j++) { + vxor(vreg_t(CReg + i * NRegs + j), vreg_t(CReg + i * NRegs + j), vreg_t(CReg + i * NRegs + j)); } - jmp(".LEND", T_NEAR); - L(".LACC"); - mov(reg_matCptr, ptr[parambase + OFFSET(matC)]); - lea(reg_matCptr, ptr[reg_matCptr + reg_itern * sizeof(CType)]); - load32(reg_cstep, ptr[parambase + OFFSET(cstep)]); - for (int i = 0; i < _mtile; i++) { - for (int j = 0; j < _NRegs; j++) { - vaddph(Xbyak::Zmm(CReg + i * NRegs + j), ptr[reg_matCptr + j * VecBytes]); - vmovups(ptr[reg_matCptr + j * VecBytes], Xbyak::Zmm(CReg + i * NRegs + j)); - } - add(reg_matCptr, reg_cstep); - } - L(".LEND"); - nop(); - outLocalLabel(); } - - private: - Xbyak::Reg64 parambase; - Xbyak::Reg64 reg_matAptr; - Xbyak::Reg64 reg_matBptr; - Xbyak::Reg64 reg_matCptr; - Xbyak::Reg64 reg_ksize; - Xbyak::Reg64 reg_nsize; - Xbyak::Reg64 reg_cstep; - Xbyak::Reg64 reg_astep; - Xbyak::Reg64 reg_iterk; - Xbyak::Reg64 reg_itern; - Xbyak::Reg64 reg_tmp; - Xbyak::Reg64 reg_tmp1; - Xbyak::Reg64 reg_tmp2; - Xbyak::Reg64 reg_ret = rax; - Xbyak::Opmask msk_wr = k1; - }; - - public: - GemmCore_Row_NN_8x64_AVX512_FP16() { - for (int i = 0; i < MTILE; i++) { - mCodes[i].generate_code(i + 1); + jmp(".end", T_NEAR); + L(".read"); + mov(reg_matCptr, ptr[parambase + OFFSET(matC)]); + lea(reg_matCptr, ptr[reg_matCptr + reg_itern * sizeof(CType)]); + load32(reg_cstride, ptr[parambase + OFFSET(cstride)]); + for (int i = 0; i < _mtile; i++) { + for (int j = 0; j < NRegs; j++) { + vmovups(vreg_t(CReg + i * NRegs + j), ptr[reg_matCptr + j * VecBytes]); + } + add(reg_matCptr, reg_cstride); } + L(".end"); + outLocalLabel(); } - void forward(utils::fp16* matA, utils::fp16* matB, utils::fp16* matC, int _m, int _n, int _k, int _astride, - int _bstride, int _cstride, int kpos, void* tmpcache, size_t cachesize) { - auto param = params{matA, matB, matC, _k, _n, _astride, _bstride, _cstride, kpos}; - if (_m <= MTILE) { - mCodes[_m - 1].mKernel(¶m); - } else { - assert(0); + void write_back(int _mtile) { + inLocalLabel(); + mov(reg_matCptr, ptr[parambase + OFFSET(matC)]); + load32(reg_cstride, ptr[parambase + OFFSET(cstride)]); + lea(reg_matCptr, ptr[reg_matCptr + reg_itern * sizeof(CType)]); + for (int i = 0; i < _mtile; i++) { + for (int j = 0; j < NRegs; j++) { + vmovups(ptr[reg_matCptr + j * VecBytes], vreg_t(CReg + i * NRegs + j)); + } + add(reg_matCptr, reg_cstride); } + outLocalLabel(); } - - private: - std::array mCodes; }; -class GemmCore_Row_NN_8x96_AVX512_FP16 { +template +class Avx512bf16N16P2 : protected jblas::xbyak::JitAvx512_bf16 { public: - typedef utils::fp16 AType; - typedef utils::fp16 BType; - typedef utils::fp16 CType; + static int constexpr RegLen = 16, PackRow = 2; + static_assert(_NTILE % RegLen == 0); + static int constexpr NRegs = _NTILE / RegLen; + static int constexpr MRegs = _MTILE == 0 ? 
(RegCount - 1) / NRegs : _MTILE; + static_assert(NRegs * MRegs <= RegCount - 1); + static int constexpr NTILE = RegLen * NRegs, MTILE = MRegs, KTILE = 2; + static int constexpr KUNROLL = 2; + static uint32_t constexpr ISA = (uint32_t)JBLAS_ISA::JblasAVX512_BF16; + static uint32_t constexpr COMPUTE = (uint32_t)CompType::COMP_BF16_FP32; + typedef utils::bf16 AType; + typedef utils::bf16 BType; + typedef float CType; + struct params { AType* matA; + int astride; BType* matB; + int bstride; CType* matC; - int k, nsize; - int astep, bstep, cstep; - int kpos; + int cstride; + int k; + int n; + int init; }; typedef long long (*func_t)(params*); - static JBLAS_GEMM_CORE constexpr TYPE = JBLAS_GEMM_CORE::AVX512_FP16_8x96; - static JBLAS_ISA constexpr ISA = - utils::gemm_core_mask(); - static int constexpr NTILE = 96, MTILE = 8, KTILE = 1; - static int constexpr KUNROLL = 2; - static int constexpr PACK_ROW = 1; - static int constexpr PREFERED_N = 192; - class MicroKernel : protected jblas::xbyak::JitAvx512_fp16 { - public: - MicroKernel() {} - int CRegCount = 24, BRegCount = 3, ARegCount = 1; - int CReg = 0, BReg = 24, AReg = 27, TmpReg = 28; - int const NRegs = 3; - static int constexpr BKStepSize = KTILE * NTILE * sizeof(BType); - static int constexpr AKStepSize = KTILE * sizeof(AType); - static int constexpr VecBytes = 64; - - void generate_code(int _mtile) { - reset(); - generate_mtile(_mtile); - ready(); - mKernel = getCode(); - } - func_t mKernel = nullptr; - protected: - void generate_mtile(int _mtile) { - CRegCount = _mtile * NRegs; + int CRegCount = 0, BRegCount = 0, ARegCount = 0, TmpRegCount = 0; + int CReg = 0, BReg = 0, AReg = 0, TmpReg = 0; + static int constexpr BKStepSize = KTILE * NTILE * sizeof(BType); + static int constexpr AKStepSize = KTILE * sizeof(AType); + + void generate_code(int _mtile) { + assign_regs(); + reset(); + generate_mtile(_mtile); + ready(); + mKernel = getCode(); + } + func_t mKernel = nullptr; + + protected: + Xbyak::Reg64 parambase; + Xbyak::Reg64 reg_matAptr; + Xbyak::Reg64 reg_matBptr; + Xbyak::Reg64 reg_matCptr; + Xbyak::Reg64 reg_ksize; + Xbyak::Reg64 reg_nsize; + Xbyak::Reg64 reg_cstride; + Xbyak::Reg64 reg_astride; + Xbyak::Reg64 reg_iterk; + Xbyak::Reg64 reg_itern; + Xbyak::Reg64 reg_tmp; + Xbyak::Reg64 reg_tmp1; + Xbyak::Reg64 reg_tmp2; + Xbyak::Reg64 reg_ret = rax; + Xbyak::Opmask msk_wr = k1; + + void assign_regs() { + CRegCount = MRegs * NRegs; + ARegCount = 1; + BRegCount = RegCount - ARegCount - CRegCount; + if (BRegCount < NRegs) { + BRegCount = 0; + ARegCount = BRegCount + 1; + } + if (BRegCount > NRegs) { BRegCount = NRegs; - BReg = CReg + CRegCount; - AReg = BReg + BRegCount; - TmpReg = AReg + ARegCount; - inLocalLabel(); // use local label for multiple instance - Xbyak::util::StackFrame st(this, 1, 11, 16 * 10); - parambase = st.p[0]; - reg_matAptr = st.t[0]; - reg_matBptr = st.t[1]; - reg_matCptr = st.t[0]; - reg_ksize = st.t[2]; - reg_nsize = st.t[9]; - reg_cstep = st.t[3]; - reg_astep = st.t[5]; - reg_iterk = st.t[4]; - reg_itern = st.t[7]; - reg_tmp = st.t[6]; - reg_tmp1 = st.t[8]; - reg_tmp2 = st.t[10]; - reg_ret = rax; - - vreg_push(rsp); - - mov(reg_matBptr, ptr[parambase + OFFSET(matB)]); - load32(reg_ksize, ptr[parambase + OFFSET(k)]); - load32(reg_nsize, ptr[parambase + OFFSET(nsize)]); - load32(reg_astep, ptr[parambase + OFFSET(astep)]); - - xor_(reg_itern, reg_itern); - L(".nloop"); - for (int i = 0; i < _mtile; i++) { - for (int j = 0; j < NRegs; j++) { - vpxorq(Xbyak::Zmm(CReg + i * NRegs + j), Xbyak::Zmm(CReg + i * NRegs + j), 
Xbyak::Zmm(CReg + i * NRegs + j)); - } - } - mov(reg_matAptr, ptr[parambase + OFFSET(matA)]); - mov(reg_tmp1, reg_matBptr); - - xor_(reg_iterk, reg_iterk); - - mov(reg_tmp, reg_nsize); - sub(reg_tmp, reg_itern); - cmp(reg_tmp, NTILE); - jl(".n64", T_NEAR); - generate_kloop(_mtile, NRegs); - write_back(_mtile, NRegs); - load32(reg_tmp, ptr[parambase + OFFSET(bstep)]); - imul(reg_tmp, reg_tmp, NTILE); - add(reg_matBptr, reg_tmp); - add(reg_itern, NTILE); - jmp(".nend", T_NEAR); - - L(".n64"); - cmp(reg_tmp, 64); - jl(".n32", T_NEAR); - generate_kloop(_mtile, 2); - write_back(_mtile, 2); - add(reg_itern, 64); - add(reg_matBptr, 64 * sizeof(BType)); - jmp(".nend", T_NEAR); - - L(".n32"); - generate_kloop(_mtile, 1); - write_back(_mtile, 1); - add(reg_itern, 32); - add(reg_matBptr, 32 * sizeof(BType)); - - L(".nend"); - cmp(reg_itern, reg_nsize); - jb(".nloop"); - - mov(reg_ret, 0); - vreg_pop(rsp); - - outLocalLabel(); // end of local label } + CReg = 0; + BReg = CReg + CRegCount; + AReg = BReg + BRegCount; + TmpReg = AReg + ARegCount; + assert(TmpReg <= RegCount); + TmpRegCount = RegCount - TmpReg; + } - void generate_kloop(int _mtile, int _nregs) { - inLocalLabel(); - L(".kloop"); - mov(reg_tmp, reg_ksize); - sub(reg_tmp, reg_iterk); - cmp(reg_tmp, KUNROLL * KTILE); - jl(".k1loop", T_NEAR); - generate_fma(_mtile, _nregs, KUNROLL, reg_tmp1); - add(reg_matAptr, KUNROLL * AKStepSize); - add(reg_tmp1, KUNROLL * BKStepSize); - add(reg_iterk, KUNROLL * KTILE); - jmp(".kloopend", T_NEAR); - - L(".k1loop"); - generate_fma(_mtile, _nregs, 1, reg_tmp1); - add(reg_matAptr, 1 * AKStepSize); - add(reg_tmp1, 1 * BKStepSize); - add(reg_iterk, 1 * KTILE); - L(".kloopend"); - cmp(reg_iterk, reg_ksize); // k iteration variable - jb(".kloop"); - outLocalLabel(); - } + void generate_mtile(int _mtile) { + inLocalLabel(); // use local label for multiple instance + Xbyak::util::StackFrame st(this, 1, 10, 16 * 10); + parambase = st.p[0]; + reg_matAptr = st.t[0]; + reg_matBptr = st.t[1]; + reg_matCptr = st.t[0]; + reg_ksize = st.t[2]; + reg_astride = st.t[3]; + reg_cstride = st.t[3]; + reg_iterk = st.t[4]; + reg_tmp = st.t[5]; + reg_tmp1 = st.t[6]; + reg_tmp2 = st.t[7]; + reg_nsize = st.t[8]; + reg_itern = st.t[9]; + reg_ret = rax; + + vreg_push(rsp); + + load32(reg_ksize, ptr[parambase + OFFSET(k)]); + load32(reg_nsize, ptr[parambase + OFFSET(n)]); + xor_(reg_itern, reg_itern); + L(".nloop"); + init_regs(_mtile); + mov(reg_matAptr, ptr[parambase + OFFSET(matA)]); + load32(reg_astride, ptr[parambase + OFFSET(astride)]); + mov(reg_matBptr, ptr[parambase + OFFSET(matB)]); + load32(reg_tmp, ptr[parambase + OFFSET(bstride)]); + imul(reg_tmp, reg_itern); + lea(reg_matBptr, ptr[reg_matBptr + reg_tmp]); + xor_(reg_iterk, reg_iterk); + generate_kloop(_mtile); + write_back(_mtile); + add(reg_itern, NTILE); + cmp(reg_itern, reg_nsize); + jb(".nloop"); + mov(reg_ret, 0); + vreg_pop(rsp); + + outLocalLabel(); // end of local label + } + + void generate_kloop(int _mtile) { + inLocalLabel(); + mov(reg_tmp, reg_ksize); + padto_le(reg_tmp, KUNROLL * KTILE); + cmp(reg_tmp, 0); + jz(".kloop", T_NEAR); + L(".unkloop"); + generate_fma(_mtile, KUNROLL); + add(reg_matAptr, KUNROLL * AKStepSize); + add(reg_matBptr, KUNROLL * BKStepSize); + add(reg_iterk, KUNROLL * KTILE); + cmp(reg_iterk, reg_tmp); // k iteration variable + jb(".unkloop"); + cmp(reg_tmp, reg_ksize); + jge(".kend", T_NEAR); + L(".kloop"); + generate_fma(_mtile, 1); + add(reg_matAptr, 1 * AKStepSize); + add(reg_matBptr, 1 * BKStepSize); + add(reg_iterk, 1 * KTILE); + 
cmp(reg_iterk, reg_ksize); // k iteration variable + jb(".kloop"); + L(".kend"); + outLocalLabel(); + } - void generate_fma(int _mtile, int _NRegs, int _ktile, const Xbyak::Reg64& _reg_matBptr) { - for (int kk = 0; kk < _ktile; kk++) { - lea(reg_tmp, ptr[reg_matAptr + kk * AKStepSize]); - for (int i = 0; i < _NRegs; i++) { - vmovups(Xbyak::Zmm(BReg + i), ptr[_reg_matBptr + kk * BKStepSize + i * VecBytes]); + void generate_fma(int _mtile, int _ktile) { + for (int kk = 0; kk < _ktile; kk++) { + lea(reg_tmp1, ptr[reg_matAptr + kk * AKStepSize]); + if (BRegCount == NRegs) { + for (int i = 0; i < NRegs; i++) { + vmovups(vreg_t(BReg + i), ptr[reg_matBptr + kk * BKStepSize + i * VecBytes]); } for (int mm = 0; mm < _mtile; mm++) { - vpbroadcastw(Xbyak::Zmm(AReg), ptr[reg_tmp]); - add(reg_tmp, reg_astep); - for (int i = 0; i < _NRegs; i++) { - vfmadd231ph(Xbyak::Zmm(CReg + mm * NRegs + i), Xbyak::Zmm(BReg + i), Xbyak::Zmm(AReg)); + vbroadcastss(vreg_t(AReg), ptr[reg_tmp1]); + add(reg_tmp1, reg_astride); + for (int i = 0; i < NRegs; i++) { + vdpbf16ps(vreg_t(CReg + mm * NRegs + i), vreg_t(AReg), vreg_t(BReg + i)); } } + } else if (BRegCount == 0) { + for (int mm = 0; mm < _mtile; mm += ARegCount) { + int mm_re = utils::remainsize(mm, _mtile, ARegCount); + for (int imm = 0; imm < mm_re; imm++) { + vbroadcastss(vreg_t(AReg + imm), ptr[reg_tmp1]); + add(reg_tmp1, reg_astride); + for (int i = 0; i < NRegs; i++) { + vdpbf16ps(vreg_t(CReg + mm * NRegs + i), vreg_t(AReg + imm), + ptr[reg_matBptr + kk * BKStepSize + i * VecBytes]); + } + } + } + } else { + assert(0); } } + } - void write_back(int _mtile, int _NRegs) { - inLocalLabel(); - load32(reg_matCptr, ptr[parambase + OFFSET(kpos)]); - cmp(reg_matCptr, 0); - jg(".LACC", T_NEAR); - mov(reg_matCptr, ptr[parambase + OFFSET(matC)]); - lea(reg_matCptr, ptr[reg_matCptr + reg_itern * sizeof(CType)]); - load32(reg_cstep, ptr[parambase + OFFSET(cstep)]); - for (int i = 0; i < _mtile; i++) { - for (int j = 0; j < _NRegs; j++) { - vmovups(ptr[reg_matCptr + j * VecBytes], Xbyak::Zmm(CReg + i * NRegs + j)); - } - add(reg_matCptr, reg_cstep); + void init_regs(int _mtile) { + inLocalLabel(); + load32(reg_tmp, ptr[parambase + OFFSET(init)]); + cmp(reg_tmp, 0); + je(".read", T_NEAR); + for (int i = 0; i < _mtile; i++) { + for (int j = 0; j < NRegs; j++) { + vxor(vreg_t(CReg + i * NRegs + j), vreg_t(CReg + i * NRegs + j), vreg_t(CReg + i * NRegs + j)); } - jmp(".LEND", T_NEAR); - L(".LACC"); - mov(reg_matCptr, ptr[parambase + OFFSET(matC)]); - lea(reg_matCptr, ptr[reg_matCptr + reg_itern * sizeof(CType)]); - load32(reg_cstep, ptr[parambase + OFFSET(cstep)]); - for (int i = 0; i < _mtile; i++) { - for (int j = 0; j < _NRegs; j++) { - vaddph(Xbyak::Zmm(CReg + i * NRegs + j), ptr[reg_matCptr + j * VecBytes]); - vmovups(ptr[reg_matCptr + j * VecBytes], Xbyak::Zmm(CReg + i * NRegs + j)); - } - add(reg_matCptr, reg_cstep); - } - L(".LEND"); - nop(); - outLocalLabel(); } - - private: - Xbyak::Reg64 parambase; - Xbyak::Reg64 reg_matAptr; - Xbyak::Reg64 reg_matBptr; - Xbyak::Reg64 reg_matCptr; - Xbyak::Reg64 reg_ksize; - Xbyak::Reg64 reg_nsize; - Xbyak::Reg64 reg_cstep; - Xbyak::Reg64 reg_astep; - Xbyak::Reg64 reg_iterk; - Xbyak::Reg64 reg_itern; - Xbyak::Reg64 reg_tmp; - Xbyak::Reg64 reg_tmp1; - Xbyak::Reg64 reg_tmp2; - Xbyak::Reg64 reg_ret = rax; - Xbyak::Opmask msk_wr = k1; - }; - - public: - GemmCore_Row_NN_8x96_AVX512_FP16() { - for (int i = 0; i < MTILE; i++) { - mCodes[i].generate_code(i + 1); + jmp(".end", T_NEAR); + L(".read"); + mov(reg_matCptr, ptr[parambase + 
OFFSET(matC)]); + lea(reg_matCptr, ptr[reg_matCptr + reg_itern * sizeof(CType)]); + load32(reg_cstride, ptr[parambase + OFFSET(cstride)]); + for (int i = 0; i < _mtile; i++) { + for (int j = 0; j < NRegs; j++) { + vmovups(vreg_t(CReg + i * NRegs + j), ptr[reg_matCptr + j * VecBytes]); + } + add(reg_matCptr, reg_cstride); } + L(".end"); + outLocalLabel(); } - void forward(utils::fp16* matA, utils::fp16* matB, utils::fp16* matC, int _m, int _n, int _k, int _astride, - int _bstride, int _cstride, int kpos, void* tmpcache, size_t cachesize) { - auto param = params{matA, matB, matC, _k, _n, _astride, _bstride, _cstride, kpos}; - if (_m <= MTILE) { - mCodes[_m - 1].mKernel(¶m); - } else { - assert(0); + void write_back(int _mtile) { + inLocalLabel(); + mov(reg_matCptr, ptr[parambase + OFFSET(matC)]); + load32(reg_cstride, ptr[parambase + OFFSET(cstride)]); + lea(reg_matCptr, ptr[reg_matCptr + reg_itern * sizeof(CType)]); + for (int i = 0; i < _mtile; i++) { + for (int j = 0; j < NRegs; j++) { + vmovups(ptr[reg_matCptr + j * VecBytes], vreg_t(CReg + i * NRegs + j)); + } + add(reg_matCptr, reg_cstride); } + outLocalLabel(); } - - private: - std::array mCodes; }; -class GemmCore_Row_NN_8x48_AVX512_VNNI { +template +class Avx512vnniN16P4 : protected jblas::xbyak::JitAvx512vnni { public: - struct params { - uint8_t* matA; - int8_t* matB; - int32_t* matC; - int k, nsize; - int astep, bstep, cstep; - int kpos; - }; - typedef long long (*func_t)(params*); + static int constexpr RegLen = 16, PackRow = 4; + static_assert(_NTILE % RegLen == 0); + static int constexpr NRegs = _NTILE / RegLen; + static int constexpr MRegs = _MTILE == 0 ? (RegCount - 1) / NRegs : _MTILE; + static_assert(NRegs * MRegs <= RegCount - 1); + static int constexpr NTILE = RegLen * NRegs, MTILE = MRegs, KTILE = 4; + static int constexpr KUNROLL = 2; + static uint32_t constexpr ISA = (uint32_t)JBLAS_ISA::JblasAVX512_VNNI; + static uint32_t constexpr COMPUTE = (uint32_t)CompType::COMP_INT8_US_INT32; typedef uint8_t AType; typedef int8_t BType; typedef int32_t CType; - static JBLAS_GEMM_CORE constexpr TYPE = JBLAS_GEMM_CORE::AVX512_VNNI_8x48; - static JBLAS_ISA constexpr ISA = - utils::gemm_core_mask(); - static int constexpr NTILE = 48, MTILE = 8, KTILE = 4 / sizeof(BType); - static int constexpr PACK_ROW = KTILE; - static int constexpr KUNROLL = 2; - static int constexpr PREFERED_N = 192; - - class MicroKernel : protected jblas::xbyak::JitAvx512vnni { - public: - MicroKernel() {} - int CRegCount = 24, BRegCount = 6, ARegCount = 1; - int CReg = 0, BReg = 24, AReg = 27, TmpReg = 28; - int const NRegs = 3; - static int constexpr BKStepSize = KTILE * NTILE * sizeof(BType); - static int constexpr AKStepSize = KTILE * sizeof(AType); - static int constexpr VecBytes = 64; - - void generate_code(int _mtile) { - reset(); - generate_mtile(_mtile); - ready(); - mKernel = getCode(); - } - func_t mKernel = nullptr; + struct params { + AType* matA; + int astride; + BType* matB; + int bstride; + CType* matC; + int cstride; + int k; + int n; + int init; + }; + typedef long long (*func_t)(params*); + + int CRegCount = 0, BRegCount = 0, ARegCount = 0, TmpRegCount = 0; + int CReg = 0, BReg = 0, AReg = 0, TmpReg = 0; + static int constexpr BKStepSize = KTILE * NTILE * sizeof(BType); + static int constexpr AKStepSize = KTILE * sizeof(AType); + + void generate_code(int _mtile) { + assign_regs(); + reset(); + generate_mtile(_mtile); + ready(); + mKernel = getCode(); + } + func_t mKernel = nullptr; - protected: - void generate_mtile(int _mtile) { - CRegCount 
= _mtile * NRegs; + private: + Xbyak::Reg64 parambase; + Xbyak::Reg64 reg_matAptr; + Xbyak::Reg64 reg_matBptr; + Xbyak::Reg64 reg_matCptr; + Xbyak::Reg64 reg_ksize; + Xbyak::Reg64 reg_nsize; + Xbyak::Reg64 reg_cstride; + Xbyak::Reg64 reg_astride; + Xbyak::Reg64 reg_iterk; + Xbyak::Reg64 reg_itern; + Xbyak::Reg64 reg_tmp; + Xbyak::Reg64 reg_tmp1; + Xbyak::Reg64 reg_tmp2; + Xbyak::Reg64 reg_ret = rax; + + protected: + void assign_regs() { + CRegCount = MRegs * NRegs; + ARegCount = 1; + BRegCount = RegCount - ARegCount - CRegCount; + if (BRegCount < NRegs) { + BRegCount = 0; + ARegCount = BRegCount + 1; + } + if (BRegCount > NRegs) { BRegCount = NRegs; - BReg = CReg + CRegCount; - AReg = BReg + BRegCount; - TmpReg = AReg + ARegCount; - inLocalLabel(); // use local label for multiple instance - Xbyak::util::StackFrame st(this, 1, 11, 16 * 10); - parambase = st.p[0]; - reg_matAptr = st.t[0]; - reg_matBptr = st.t[1]; - reg_matCptr = st.t[0]; - reg_ksize = st.t[2]; - reg_nsize = st.t[9]; - reg_cstep = st.t[3]; - reg_astep = st.t[5]; - reg_iterk = st.t[4]; - reg_itern = st.t[7]; - reg_tmp = st.t[6]; - reg_tmp1 = st.t[8]; - reg_tmp2 = st.t[10]; - reg_ret = rax; - - vreg_push(rsp); - - mov(reg_matBptr, ptr[parambase + OFFSET(matB)]); - load32(reg_ksize, ptr[parambase + OFFSET(k)]); - load32(reg_nsize, ptr[parambase + OFFSET(nsize)]); - load32(reg_astep, ptr[parambase + OFFSET(astep)]); - - xor_(reg_itern, reg_itern); - L(".nloop"); - for (int i = 0; i < _mtile; i++) { - for (int j = 0; j < NRegs; j++) { - vpxorq(Xbyak::Zmm(CReg + i * NRegs + j), Xbyak::Zmm(CReg + i * NRegs + j), Xbyak::Zmm(CReg + i * NRegs + j)); - } - } - mov(reg_matAptr, ptr[parambase + OFFSET(matA)]); - mov(reg_tmp1, reg_matBptr); - - xor_(reg_iterk, reg_iterk); - - mov(reg_tmp, reg_nsize); - sub(reg_tmp, reg_itern); - cmp(reg_tmp, NTILE); - jl(".n32", T_NEAR); - generate_kloop(_mtile, NRegs); - write_back(_mtile, NRegs); - load32(reg_tmp, ptr[parambase + OFFSET(bstep)]); - imul(reg_tmp, reg_tmp, NTILE); - add(reg_matBptr, reg_tmp); - add(reg_itern, NTILE); - jmp(".nend", T_NEAR); - - L(".n32"); - cmp(reg_tmp, 32); - jl(".n16", T_NEAR); - generate_kloop(_mtile, 2); - write_back(_mtile, 2); - add(reg_itern, 32); - add(reg_matBptr, 32 * sizeof(int8_t) * 4); - jmp(".nend", T_NEAR); - - L(".n16"); - xor_(reg_iterk, reg_iterk); - generate_kloop(_mtile, 1); - write_back(_mtile, 1); - add(reg_itern, 16); - add(reg_matBptr, 16 * sizeof(int8_t) * 4); - L(".nend"); - cmp(reg_itern, reg_nsize); - jb(".nloop"); - - mov(reg_ret, 0); - vreg_pop(rsp); - - outLocalLabel(); // end of local label } + CReg = 0; + BReg = CReg + CRegCount; + AReg = BReg + BRegCount; + TmpReg = AReg + ARegCount; + assert(TmpReg <= RegCount); + TmpRegCount = RegCount - TmpReg; + } - void generate_kloop(int _mtile, int _nregs) { - inLocalLabel(); - L(".kloop"); - mov(reg_tmp, reg_ksize); - sub(reg_tmp, reg_iterk); - cmp(reg_tmp, KTILE * KUNROLL); - jl(".k1loop", T_NEAR); - generate_fma(_mtile, _nregs, KUNROLL, reg_tmp1); - add(reg_matAptr, AKStepSize * KUNROLL); - add(reg_tmp1, BKStepSize * KUNROLL); - add(reg_iterk, KTILE * KUNROLL); - jmp(".kloopend", T_NEAR); - - L(".k1loop"); - generate_fma(_mtile, _nregs, 1, reg_tmp1); - add(reg_matAptr, 1 * AKStepSize); - add(reg_tmp1, 1 * BKStepSize); - add(reg_iterk, 1 * KTILE); - L(".kloopend"); - cmp(reg_iterk, reg_ksize); // k iteration variable - jb(".kloop"); - outLocalLabel(); - } + void generate_mtile(int _mtile) { + inLocalLabel(); + Xbyak::util::StackFrame st(this, 1, 10, 16 * 10); + parambase = st.p[0]; + reg_matAptr = 
st.t[0]; + reg_matBptr = st.t[1]; + reg_matCptr = st.t[0]; + reg_ksize = st.t[2]; + reg_astride = st.t[3]; + reg_cstride = st.t[3]; + reg_iterk = st.t[4]; + reg_tmp = st.t[5]; + reg_tmp1 = st.t[6]; + reg_tmp2 = st.t[7]; + reg_nsize = st.t[8]; + reg_itern = st.t[9]; + reg_ret = rax; + + vreg_push(rsp); + + load32(reg_ksize, ptr[parambase + OFFSET(k)]); + load32(reg_nsize, ptr[parambase + OFFSET(n)]); + xor_(reg_itern, reg_itern); + L(".nloop"); + init_regs(_mtile); + mov(reg_matAptr, ptr[parambase + OFFSET(matA)]); + load32(reg_astride, ptr[parambase + OFFSET(astride)]); + mov(reg_matBptr, ptr[parambase + OFFSET(matB)]); + load32(reg_tmp, ptr[parambase + OFFSET(bstride)]); + imul(reg_tmp, reg_itern); + lea(reg_matBptr, ptr[reg_matBptr + reg_tmp]); + xor_(reg_iterk, reg_iterk); + generate_kloop(_mtile); + write_back(_mtile); + add(reg_itern, NTILE); + cmp(reg_itern, reg_nsize); + jb(".nloop"); + mov(reg_ret, 0); + vreg_pop(rsp); + + outLocalLabel(); // end of local label + } + + void generate_kloop(int _mtile) { + inLocalLabel(); + mov(reg_tmp, reg_ksize); + padto_le(reg_tmp, KUNROLL * KTILE); + cmp(reg_tmp, 0); + jz(".kloop", T_NEAR); + L(".unkloop"); + generate_fma(_mtile, KUNROLL); + add(reg_matAptr, KUNROLL * AKStepSize); + add(reg_matBptr, KUNROLL * BKStepSize); + add(reg_iterk, KUNROLL * KTILE); + cmp(reg_iterk, reg_tmp); // k iteration variable + jb(".unkloop"); + cmp(reg_tmp, reg_ksize); + jge(".kend", T_NEAR); + L(".kloop"); + generate_fma(_mtile, 1); + add(reg_matAptr, 1 * AKStepSize); + add(reg_matBptr, 1 * BKStepSize); + add(reg_iterk, 1 * KTILE); + cmp(reg_iterk, reg_ksize); // k iteration variable + jb(".kloop"); + L(".kend"); + outLocalLabel(); + } - void generate_fma(int _mtile, int _NRegs, int _kunroll, const Xbyak::Reg64& _reg_matBptr) { - for (int kk = 0; kk < _kunroll; kk++) { - lea(reg_tmp, ptr[reg_matAptr + kk * AKStepSize]); - for (int i = 0; i < _NRegs; i++) { - vmovups(Xbyak::Zmm(BReg + i), ptr[_reg_matBptr + kk * BKStepSize + i * VecBytes]); + void generate_fma(int _mtile, int _kunroll) { + for (int kk = 0; kk < _kunroll; kk++) { + lea(reg_tmp1, ptr[reg_matAptr + kk * AKStepSize]); + if (BRegCount == NRegs) { + for (int i = 0; i < NRegs; i++) { + vmovups(vreg_t(BReg + i), ptr[reg_matBptr + kk * BKStepSize + i * VecBytes]); } for (int mm = 0; mm < _mtile; mm++) { - vpbroadcastd(Xbyak::Zmm(AReg), ptr[reg_tmp]); - add(reg_tmp, reg_astep); - for (int i = 0; i < _NRegs; i++) { - vpdpbusds_evex(Xbyak::Zmm(CReg + mm * NRegs + i), Xbyak::Zmm(AReg), Xbyak::Zmm(BReg + i)); + vpbroadcastd(vreg_t(AReg), ptr[reg_tmp1]); + add(reg_tmp1, reg_astride); + for (int i = 0; i < NRegs; i++) { + vpdpbusds_(vreg_t(CReg + mm * NRegs + i), vreg_t(AReg), vreg_t(BReg + i)); } } + } else if (BRegCount == 0) { + for (int mm = 0; mm < _mtile; mm += ARegCount) { + int mm_re = utils::remainsize(mm, _mtile, ARegCount); + for (int imm = 0; imm < mm_re; imm++) { + vpbroadcastd(vreg_t(AReg + imm), ptr[reg_tmp1]); + add(reg_tmp1, reg_astride); + for (int i = 0; i < NRegs; i++) { + vpdpbusds_(vreg_t(CReg + mm * NRegs + i), vreg_t(AReg + imm), + ptr[reg_matBptr + kk * BKStepSize + i * VecBytes]); + } + } + } + } else { + assert(0); } } + } - void write_back(int _mtile, int _NRegs) { - inLocalLabel(); - load32(reg_matCptr, ptr[parambase + OFFSET(kpos)]); - cmp(reg_matCptr, 0); - jg(".LACC", T_NEAR); - mov(reg_matCptr, ptr[parambase + OFFSET(matC)]); - lea(reg_matCptr, ptr[reg_matCptr + reg_itern * sizeof(CType)]); - load32(reg_cstep, ptr[parambase + OFFSET(cstep)]); - for (int i = 0; i < _mtile; i++) { - 
for (int j = 0; j < _NRegs; j++) { - vmovups(ptr[reg_matCptr + j * VecBytes], Xbyak::Zmm(CReg + i * NRegs + j)); - } - add(reg_matCptr, reg_cstep); - } - jmp(".LEND", T_NEAR); - L(".LACC"); - mov(reg_matCptr, ptr[parambase + OFFSET(matC)]); - lea(reg_matCptr, ptr[reg_matCptr + reg_itern * sizeof(CType)]); - load32(reg_cstep, ptr[parambase + OFFSET(cstep)]); - for (int i = 0; i < _mtile; i++) { - for (int j = 0; j < _NRegs; j++) { - vpaddd(Xbyak::Zmm(CReg + i * NRegs + j), Xbyak::Zmm(CReg + i * NRegs + j), ptr[reg_matCptr + j * VecBytes]); - vmovups(ptr[reg_matCptr + j * VecBytes], Xbyak::Zmm(CReg + i * NRegs + j)); - } - add(reg_matCptr, reg_cstep); + void init_regs(int _mtile) { + inLocalLabel(); + load32(reg_tmp, ptr[parambase + OFFSET(init)]); + cmp(reg_tmp, 0); + je(".read", T_NEAR); + for (int i = 0; i < _mtile; i++) { + for (int j = 0; j < NRegs; j++) { + vxor(vreg_t(CReg + i * NRegs + j), vreg_t(CReg + i * NRegs + j), vreg_t(CReg + i * NRegs + j)); } - L(".LEND"); - nop(); - outLocalLabel(); } - - private: - Xbyak::Reg64 parambase; - Xbyak::Reg64 reg_matAptr; - Xbyak::Reg64 reg_matBptr; - Xbyak::Reg64 reg_matCptr; - Xbyak::Reg64 reg_ksize; - Xbyak::Reg64 reg_nsize; - Xbyak::Reg64 reg_cstep; - Xbyak::Reg64 reg_astep; - Xbyak::Reg64 reg_iterk; - Xbyak::Reg64 reg_itern; - Xbyak::Reg64 reg_tmp; - Xbyak::Reg64 reg_tmp1; - Xbyak::Reg64 reg_tmp2; - Xbyak::Reg64 reg_ret = rax; - }; - - public: - GemmCore_Row_NN_8x48_AVX512_VNNI() { - for (int i = 0; i < MTILE; i++) { - mCodes[i].generate_code(i + 1); + jmp(".end", T_NEAR); + L(".read"); + mov(reg_matCptr, ptr[parambase + OFFSET(matC)]); + lea(reg_matCptr, ptr[reg_matCptr + reg_itern * sizeof(CType)]); + load32(reg_cstride, ptr[parambase + OFFSET(cstride)]); + for (int i = 0; i < _mtile; i++) { + for (int j = 0; j < NRegs; j++) { + vmovups(vreg_t(CReg + i * NRegs + j), ptr[reg_matCptr + j * VecBytes]); + } + add(reg_matCptr, reg_cstride); } + L(".end"); + outLocalLabel(); } - void forward(AType* matA, BType* matB, CType* matC, int _m, int _n, int _k, int _astride, int _bstride, int _cstride, - int kpos, void* tmpcache, size_t cachesize) { - auto param = params{matA, matB, matC, _k, _n, _astride, _bstride, _cstride, kpos}; - if (_m <= MTILE) { - mCodes[_m - 1].mKernel(¶m); - } else { - assert(0); + void write_back(int _mtile) { + inLocalLabel(); + mov(reg_matCptr, ptr[parambase + OFFSET(matC)]); + load32(reg_cstride, ptr[parambase + OFFSET(cstride)]); + lea(reg_matCptr, ptr[reg_matCptr + reg_itern * sizeof(CType)]); + for (int i = 0; i < _mtile; i++) { + for (int j = 0; j < NRegs; j++) { + vmovups(ptr[reg_matCptr + j * VecBytes], vreg_t(CReg + i * NRegs + j)); + } + add(reg_matCptr, reg_cstride); } + outLocalLabel(); } - - private: - std::array mCodes; }; -class GemmCore_Row_NN_16x64_AMX_BF16 { +template +class AvxvnniN8P4 : protected jblas::xbyak::JitAvxvnni { public: - typedef utils::bf16 AType; - typedef utils::bf16 BType; - typedef float CType; + static int constexpr RegLen = 8, PackRow = 4; + static_assert(_NTILE % RegLen == 0); + static int constexpr NRegs = _NTILE / RegLen; + static int constexpr MRegs = _MTILE == 0 ? 
(RegCount - 1) / NRegs : _MTILE; + static_assert(NRegs * MRegs <= RegCount - 1); + static int constexpr NTILE = RegLen * NRegs, MTILE = MRegs, KTILE = 4; + static int constexpr KUNROLL = 2; + static uint32_t constexpr ISA = (uint32_t)JBLAS_ISA::JblasAVX_VNNI; + static uint32_t constexpr COMPUTE = (uint32_t)CompType::COMP_INT8_US_INT32; + typedef uint8_t AType; + typedef int8_t BType; + typedef int32_t CType; struct params { AType* matA; + int astride; BType* matB; + int bstride; CType* matC; - int k, msize, nsize; - int astep, bstep, cstep; - int kpos; - void *workspace, *cfg; + int cstride; + int k; + int n; + int init; }; typedef long long (*func_t)(params*); - static JBLAS_GEMM_CORE constexpr TYPE = JBLAS_GEMM_CORE::AMX_BF16_16x64; - static JBLAS_ISA constexpr ISA = - utils::gemm_core_mask(); - static int constexpr NTILE = 64, MTILE = 16, KTILE = 64 / sizeof(BType); - static int constexpr PACK_ROW = 2; - static int constexpr KUNROLL = 2; - static int constexpr PREFERED_N = 256; - class MicroKernel : protected jblas::xbyak::JitAmxbf16 { - public: - friend GemmCore_Row_NN_16x64_AMX_BF16; - MicroKernel() {} - static int constexpr CReg = 0, TmpReg = 4; - static int constexpr NRegs = 4; - static int constexpr CRegCount = NRegs; - static int constexpr C_tilenum = 4, A_tilenum = 1, B_tilenum = 3; - static int constexpr CTile = 0, ATile = CTile + C_tilenum, BTile = ATile + A_tilenum; - static int constexpr BKStepSize = KTILE * NTILE * sizeof(BType); - static int constexpr AKStepSize = KTILE * sizeof(AType); - static int constexpr VecBytes = 64; - - void generate_code() { - reset(); - generate_mtile(); - ready(); - mKernel = getCode(); - } - func_t mKernel = nullptr; - - protected: - void generate_mtile() { - inLocalLabel(); // use local label for multiple instance - Xbyak::util::StackFrame st(this, 1, 11, 16 * 10); - parambase = st.p[0]; - reg_matAptr = st.t[0]; - reg_matBptr = st.t[1]; - reg_matCptr = st.t[0]; - reg_ksize = st.t[2]; - reg_nsize = st.t[9]; - reg_cstep = st.t[3]; - reg_astep = st.t[5]; - reg_iterk = st.t[4]; - reg_itern = st.t[7]; - reg_tmp = st.t[6]; - reg_tmp1 = st.t[8]; - reg_tmp2 = st.t[10]; - reg_ret = rax; - - vreg_push(rsp); - mov(reg_tmp, ptr[parambase + OFFSET(cfg)]); - ldtilecfg(ptr[reg_tmp]); - - mov(reg_matBptr, ptr[parambase + OFFSET(matB)]); - load32(reg_ksize, ptr[parambase + OFFSET(k)]); - load32(reg_nsize, ptr[parambase + OFFSET(nsize)]); - load32(reg_astep, ptr[parambase + OFFSET(astep)]); - - xor_(reg_itern, reg_itern); - L(".nloop"); - for (int i = 0; i < C_tilenum; i++) { - tilezero(Xbyak::Tmm(CTile + i)); - } - mov(reg_matAptr, ptr[parambase + OFFSET(matA)]); - mov(reg_tmp1, reg_matBptr); - - xor_(reg_iterk, reg_iterk); - - mov(reg_tmp, reg_nsize); - sub(reg_tmp, reg_itern); - cmp(reg_tmp, NTILE); - jl(".n48", T_NEAR); - generate_kloop(NRegs); - write_back(MTILE, NRegs); - load32(reg_tmp, ptr[parambase + OFFSET(bstep)]); - imul(reg_tmp, reg_tmp, NTILE); - add(reg_matBptr, reg_tmp); - add(reg_itern, NTILE); - jmp(".nend", T_NEAR); - - L(".n48"); - cmp(reg_tmp, 48); - jl(".n32", T_NEAR); - generate_kloop(3); - write_back(MTILE, 3); - add(reg_itern, 48); - add(reg_matBptr, 48 * sizeof(BType)); - jmp(".nend", T_NEAR); - - L(".n32"); - cmp(reg_tmp, 32); - jl(".n16", T_NEAR); - generate_kloop(2); - write_back(MTILE, 2); - add(reg_itern, 32); - add(reg_matBptr, 32 * sizeof(BType)); - jmp(".nend", T_NEAR); - - L(".n16"); - xor_(reg_iterk, reg_iterk); - generate_kloop(1); - write_back(MTILE, 1); - add(reg_itern, 16); - add(reg_matBptr, 16 * sizeof(BType)); - 
L(".nend"); - cmp(reg_itern, reg_nsize); - jb(".nloop"); - - mov(reg_ret, 0); - vreg_pop(rsp); - - outLocalLabel(); // end of local label - } + int CRegCount = 0, BRegCount = 0, ARegCount = 0, TmpRegCount = 0; + int CReg = 0, BReg = 0, AReg = 0, TmpReg = 0; + static int constexpr BKStepSize = KTILE * NTILE * sizeof(BType); + static int constexpr AKStepSize = KTILE * sizeof(AType); + + void generate_code(int _mtile) { + assign_regs(); + reset(); + generate_mtile(_mtile); + ready(); + mKernel = getCode(); + } + func_t mKernel = nullptr; - void generate_kloop(int _nregs) { - inLocalLabel(); - L(".kloop"); - mov(reg_tmp, reg_ksize); - sub(reg_tmp, reg_iterk); - cmp(reg_tmp, KUNROLL * KTILE); - jl(".k1loop", T_NEAR); - generate_fma(_nregs, KUNROLL, reg_tmp1); - add(reg_matAptr, KUNROLL * AKStepSize); - add(reg_tmp1, KUNROLL * BKStepSize); - add(reg_iterk, KUNROLL * KTILE); - jmp(".kloopend", T_NEAR); - - L(".k1loop"); - generate_fma(_nregs, 1, reg_tmp1); - add(reg_matAptr, 1 * AKStepSize); - add(reg_tmp1, 1 * BKStepSize); - add(reg_iterk, 1 * KTILE); - L(".kloopend"); - cmp(reg_iterk, reg_ksize); // k iteration variable - jb(".kloop"); - outLocalLabel(); + private: + Xbyak::Reg64 parambase; + Xbyak::Reg64 reg_matAptr; + Xbyak::Reg64 reg_matBptr; + Xbyak::Reg64 reg_matCptr; + Xbyak::Reg64 reg_ksize; + Xbyak::Reg64 reg_nsize; + Xbyak::Reg64 reg_cstride; + Xbyak::Reg64 reg_astride; + Xbyak::Reg64 reg_iterk; + Xbyak::Reg64 reg_itern; + Xbyak::Reg64 reg_tmp; + Xbyak::Reg64 reg_tmp1; + Xbyak::Reg64 reg_tmp2; + Xbyak::Reg64 reg_ret = rax; + Xbyak::Opmask msk_wr = k1; + + protected: + void assign_regs() { + CRegCount = MRegs * NRegs; + ARegCount = 1; + BRegCount = RegCount - ARegCount - CRegCount; + if (BRegCount < NRegs) { + BRegCount = 0; + ARegCount = BRegCount + 1; + } + if (BRegCount > NRegs) { + BRegCount = NRegs; } + CReg = 0; + BReg = CReg + CRegCount; + AReg = BReg + BRegCount; + TmpReg = AReg + ARegCount; + assert(TmpReg <= RegCount); + TmpRegCount = RegCount - TmpReg; + } - void generate_fma(int _NTile, int _kunroll, const Xbyak::Reg64& _reg_matBptr) { - mov(reg_tmp, NTILE * 4); - if (_NTile <= B_tilenum) { - for (int kk = 0; kk < _kunroll; kk++) { - for (int i = 0; i < _NTile; i++) { - tileloaddt1(Xbyak::Tmm(BTile + i), ptr[_reg_matBptr + reg_tmp + kk * BKStepSize + i * 64]); - } + void generate_mtile(int _mtile) { + inLocalLabel(); + Xbyak::util::StackFrame st(this, 1, 10, 16 * 10); + parambase = st.p[0]; + reg_matAptr = st.t[0]; + reg_matBptr = st.t[1]; + reg_matCptr = st.t[0]; + reg_ksize = st.t[2]; + reg_astride = st.t[3]; + reg_cstride = st.t[3]; + reg_iterk = st.t[4]; + reg_tmp = st.t[5]; + reg_tmp1 = st.t[6]; + reg_tmp2 = st.t[7]; + reg_nsize = st.t[8]; + reg_itern = st.t[9]; + reg_ret = rax; + + vreg_push(rsp); + + load32(reg_ksize, ptr[parambase + OFFSET(k)]); + load32(reg_nsize, ptr[parambase + OFFSET(n)]); + xor_(reg_itern, reg_itern); + L(".nloop"); + init_regs(_mtile); + mov(reg_matAptr, ptr[parambase + OFFSET(matA)]); + load32(reg_astride, ptr[parambase + OFFSET(astride)]); + mov(reg_matBptr, ptr[parambase + OFFSET(matB)]); + load32(reg_tmp, ptr[parambase + OFFSET(bstride)]); + imul(reg_tmp, reg_itern); + lea(reg_matBptr, ptr[reg_matBptr + reg_tmp]); + xor_(reg_iterk, reg_iterk); + generate_kloop(_mtile); + write_back(_mtile); + add(reg_itern, NTILE); + cmp(reg_itern, reg_nsize); + jb(".nloop"); + mov(reg_ret, 0); + vreg_pop(rsp); + + outLocalLabel(); // end of local label + } - for (int mm = 0; mm < 1; mm++) { - tileloadd(Xbyak::Tmm(ATile + mm), ptr[reg_matAptr + reg_astep + 
kk * AKStepSize]); - for (int i = 0; i < _NTile; i++) { - tdpbf16ps(Xbyak::Tmm(CTile + mm * C_tilenum + i), Xbyak::Tmm(ATile + mm), Xbyak::Tmm(BTile + i)); - } - } + void generate_kloop(int _mtile) { + inLocalLabel(); + mov(reg_tmp, reg_ksize); + padto_le(reg_tmp, KUNROLL * KTILE); + cmp(reg_tmp, 0); + jz(".kloop", T_NEAR); + L(".unkloop"); + generate_fma(_mtile, KUNROLL); + add(reg_matAptr, KUNROLL * AKStepSize); + add(reg_matBptr, KUNROLL * BKStepSize); + add(reg_iterk, KUNROLL * KTILE); + cmp(reg_iterk, reg_tmp); // k iteration variable + jb(".unkloop"); + cmp(reg_tmp, reg_ksize); + jge(".kend", T_NEAR); + L(".kloop"); + generate_fma(_mtile, 1); + add(reg_matAptr, 1 * AKStepSize); + add(reg_matBptr, 1 * BKStepSize); + add(reg_iterk, 1 * KTILE); + cmp(reg_iterk, reg_ksize); // k iteration variable + jb(".kloop"); + L(".kend"); + outLocalLabel(); + } + + void generate_fma(int _mtile, int _kunroll) { + for (int kk = 0; kk < _kunroll; kk++) { + lea(reg_tmp1, ptr[reg_matAptr + kk * AKStepSize]); + if (BRegCount == NRegs) { + for (int i = 0; i < NRegs; i++) { + vmovups(vreg_t(BReg + i), ptr[reg_matBptr + kk * BKStepSize + i * VecBytes]); } - } else { - for (int kk = 0; kk < _kunroll; kk++) { - for (int i = 0; i < _NTile - 1; i++) { - tileloaddt1(Xbyak::Tmm(BTile + i), ptr[_reg_matBptr + reg_tmp + kk * BKStepSize + i * 64]); + for (int mm = 0; mm < _mtile; mm++) { + vpbroadcastd(vreg_t(AReg), ptr[reg_tmp1]); + add(reg_tmp1, reg_astride); + for (int i = 0; i < NRegs; i++) { + vpdpbusds_(vreg_t(CReg + mm * NRegs + i), vreg_t(AReg), vreg_t(BReg + i)); } - - for (int mm = 0; mm < 1; mm++) { - tileloadd(Xbyak::Tmm(ATile + mm), ptr[reg_matAptr + reg_astep + kk * AKStepSize]); - for (int i = 0; i < _NTile - 1; i++) { - tdpbf16ps(Xbyak::Tmm(CTile + mm * C_tilenum + i), Xbyak::Tmm(ATile + mm), Xbyak::Tmm(BTile + i)); + } + } else if (BRegCount == 0) { + for (int mm = 0; mm < _mtile; mm += ARegCount) { + int mm_re = utils::remainsize(mm, _mtile, ARegCount); + for (int imm = 0; imm < mm_re; imm++) { + vpbroadcastd(vreg_t(AReg + imm), ptr[reg_tmp1]); + add(reg_tmp1, reg_astride); + for (int i = 0; i < NRegs; i++) { + vpdpbusds_(vreg_t(CReg + mm * NRegs + i), vreg_t(AReg + imm), + ptr[reg_matBptr + kk * BKStepSize + i * VecBytes]); } - tileloaddt1(Xbyak::Tmm(BTile + 0), ptr[_reg_matBptr + reg_tmp + kk * BKStepSize + (_NTile - 1) * 64]); - tdpbf16ps(Xbyak::Tmm(CTile + mm * C_tilenum + _NTile - 1), Xbyak::Tmm(ATile + mm), Xbyak::Tmm(BTile + 0)); } } + } else { + assert(0); } } + } - void write_back(int _mtile, int _NRegs) { - inLocalLabel(); - mov(reg_tmp, dword[parambase + OFFSET(workspace)]); - mov(reg_tmp1, NTILE * 4); - for (int mm = 0; mm < 1; mm++) { - for (int i = 0; i < _NRegs; i++) { - tilestored(ptr[reg_tmp + reg_tmp1 + i * 64 + mm * 16 * NTILE * 4], Xbyak::Tmm(CTile + mm * C_tilenum + i)); - } - } - load32(reg_matCptr, ptr[parambase + OFFSET(kpos)]); - cmp(reg_matCptr, 0); - jg(".LACC", T_NEAR); - mov(reg_matCptr, ptr[parambase + OFFSET(matC)]); - lea(reg_matCptr, ptr[reg_matCptr + reg_itern * sizeof(CType)]); - load32(reg_cstep, ptr[parambase + OFFSET(cstep)]); - inLocalLabel(); - xor_(reg_tmp1, reg_tmp1); - L(".mloop"); - for (int j = 0; j < _NRegs; j++) { - vmovups(Xbyak::Zmm(CReg + j), ptr[reg_tmp + j * 64]); - vmovups(ptr[reg_matCptr + j * VecBytes], Xbyak::Zmm(CReg + j)); + void init_regs(int _mtile) { + inLocalLabel(); + load32(reg_tmp, ptr[parambase + OFFSET(init)]); + cmp(reg_tmp, 0); + je(".read", T_NEAR); + for (int i = 0; i < _mtile; i++) { + for (int j = 0; j < NRegs; j++) { + 
vxor(vreg_t(CReg + i * NRegs + j), vreg_t(CReg + i * NRegs + j), vreg_t(CReg + i * NRegs + j)); } - add(reg_matCptr, reg_cstep); - add(reg_tmp, NTILE * 4); - add(reg_tmp1, 1); - cmp(reg_tmp1.cvt32(), ptr[parambase + OFFSET(msize)]); - jb(".mloop"); - outLocalLabel(); - jmp(".LEND", T_NEAR); - L(".LACC"); - mov(reg_matCptr, ptr[parambase + OFFSET(matC)]); - lea(reg_matCptr, ptr[reg_matCptr + reg_itern * sizeof(CType)]); - load32(reg_cstep, ptr[parambase + OFFSET(cstep)]); - inLocalLabel(); - xor_(reg_tmp1, reg_tmp1); - L(".mloop"); - for (int j = 0; j < _NRegs; j++) { - vmovups(Xbyak::Zmm(CReg + j), ptr[reg_tmp + j * 64]); - vaddps(Xbyak::Zmm(CReg + j), ptr[reg_matCptr + j * VecBytes]); - vmovups(ptr[reg_matCptr + j * VecBytes], Xbyak::Zmm(CReg + j)); + } + jmp(".end", T_NEAR); + L(".read"); + mov(reg_matCptr, ptr[parambase + OFFSET(matC)]); + lea(reg_matCptr, ptr[reg_matCptr + reg_itern * sizeof(CType)]); + load32(reg_cstride, ptr[parambase + OFFSET(cstride)]); + for (int i = 0; i < _mtile; i++) { + for (int j = 0; j < NRegs; j++) { + vmovups(vreg_t(CReg + i * NRegs + j), ptr[reg_matCptr + j * VecBytes]); } - add(reg_matCptr, reg_cstep); - add(reg_tmp, NTILE * 4); - add(reg_tmp1, 1); - cmp(reg_tmp1.cvt32(), ptr[parambase + OFFSET(msize)]); - jb(".mloop"); - outLocalLabel(); - L(".LEND"); - nop(); - outLocalLabel(); + add(reg_matCptr, reg_cstride); } + L(".end"); + outLocalLabel(); + } - private: - Xbyak::Reg64 parambase; - Xbyak::Reg64 reg_matAptr; - Xbyak::Reg64 reg_matBptr; - Xbyak::Reg64 reg_matCptr; - Xbyak::Reg64 reg_ksize; - Xbyak::Reg64 reg_nsize; - Xbyak::Reg64 reg_cstep; - Xbyak::Reg64 reg_astep; - Xbyak::Reg64 reg_iterk; - Xbyak::Reg64 reg_itern; - Xbyak::Reg64 reg_tmp; - Xbyak::Reg64 reg_tmp1; - Xbyak::Reg64 reg_tmp2; - Xbyak::Reg64 reg_ret = rax; - Xbyak::Opmask msk_wr = k1; - }; - - public: - GemmCore_Row_NN_16x64_AMX_BF16() { mCodes.generate_code(); } - - void forward(AType* matA, BType* matB, CType* matC, int _m, int _n, int _k, int _astride, int _bstride, int _cstride, - int kpos, void* tmpcache, size_t cachesize) { - assert((NTILE * MTILE * sizeof(CType))<= cachesize); - MicroKernel::tileconfig_t mCfg; - memset(&mCfg, 0, sizeof(mCfg)); - auto param = params{matA, matB, matC, _k, _m, _n, _astride, _bstride, _cstride, kpos, tmpcache, &mCfg}; - if (_m <= MTILE) { - jblas::xbyak::JitAmxtile::configure_tiles(mCfg, _m < 16 ? _m : 16, _n < 16 ? _n : 16, _k < KTILE ? _k : KTILE, - sizeof(BType), MicroKernel::A_tilenum, MicroKernel::B_tilenum, - MicroKernel::C_tilenum); - mCodes.mKernel(¶m); - } else { - assert(0); + void write_back(int _mtile) { + inLocalLabel(); + mov(reg_matCptr, ptr[parambase + OFFSET(matC)]); + load32(reg_cstride, ptr[parambase + OFFSET(cstride)]); + lea(reg_matCptr, ptr[reg_matCptr + reg_itern * sizeof(CType)]); + for (int i = 0; i < _mtile; i++) { + for (int j = 0; j < NRegs; j++) { + vmovups(ptr[reg_matCptr + j * VecBytes], vreg_t(CReg + i * NRegs + j)); + } + add(reg_matCptr, reg_cstride); } + outLocalLabel(); } - - private: - MicroKernel mCodes; }; -class GemmCore_Row_NN_16x48_AMX_BF16 { +template +class Amxbf16N16P2 : protected jblas::xbyak::JitAmxbf16 { public: + static int constexpr RegLen = 16, PackRow = 2; + static_assert(_NTILE % RegLen == 0); + static_assert(_MTILE % RegLen == 0); + static int constexpr NRegs = _NTILE / RegLen; + static int constexpr MRegs = _MTILE == 0 ? 
1 : _MTILE / RegLen; + static_assert(NRegs * MRegs + 2 <= TileCount); + static int constexpr NTILE = RegLen * NRegs, MTILE = MRegs * RegLen, KTILE = 32; + static int constexpr KUNROLL = 2; + static uint32_t constexpr ISA = (uint32_t)JBLAS_ISA::JblasAMX_BF16; + static uint32_t constexpr COMPUTE = (uint32_t)CompType::COMP_BF16_FP32; typedef utils::bf16 AType; typedef utils::bf16 BType; typedef float CType; + struct params { AType* matA; + int astride; BType* matB; + int bstride; CType* matC; - int k, msize, nsize; - int astep, bstep, cstep; - int kpos; - void *workspace, *cfg; + int cstride; + int k; + int n; + int init; + void* workspace; }; typedef long long (*func_t)(params*); - static JBLAS_GEMM_CORE constexpr TYPE = JBLAS_GEMM_CORE::AMX_BF16_16x48; - static JBLAS_ISA constexpr ISA = - utils::gemm_core_mask(); - static int constexpr NTILE = 48, MTILE = 16, KTILE = 64 / sizeof(BType); - static int constexpr PACK_ROW = 2; - static int constexpr KUNROLL = 2; - static int constexpr PREFERED_N = 240; - class MicroKernel : protected jblas::xbyak::JitAmxbf16 { - public: - friend GemmCore_Row_NN_16x48_AMX_BF16; - MicroKernel() {} - static int constexpr CReg = 0, TmpReg = 4; - static int constexpr NRegs = 3; - static int constexpr CRegCount = NRegs; - static int constexpr C_tilenum = 3, A_tilenum = 1, B_tilenum = 3; - static int constexpr CTile = 0, ATile = CTile + C_tilenum, BTile = ATile + A_tilenum; - static int constexpr BKStepSize = KTILE * NTILE * sizeof(BType); - static int constexpr AKStepSize = KTILE * sizeof(AType); - static int constexpr VecBytes = 64; - - void generate_code() { - reset(); - generate_mtile(); - ready(); - mKernel = getCode(); - } - func_t mKernel = nullptr; - - protected: - void generate_mtile() { - inLocalLabel(); // use local label for multiple instance - Xbyak::util::StackFrame st(this, 1, 11, 16 * 10); - parambase = st.p[0]; - reg_matAptr = st.t[0]; - reg_matBptr = st.t[1]; - reg_matCptr = st.t[0]; - reg_ksize = st.t[2]; - reg_nsize = st.t[9]; - reg_cstep = st.t[3]; - reg_astep = st.t[5]; - reg_iterk = st.t[4]; - reg_itern = st.t[7]; - reg_tmp = st.t[6]; - reg_tmp1 = st.t[8]; - reg_tmp2 = st.t[10]; - reg_ret = rax; - - vreg_push(rsp); - mov(reg_tmp, ptr[parambase + OFFSET(cfg)]); - ldtilecfg(ptr[reg_tmp]); - - mov(reg_matBptr, ptr[parambase + OFFSET(matB)]); - load32(reg_ksize, ptr[parambase + OFFSET(k)]); - load32(reg_nsize, ptr[parambase + OFFSET(nsize)]); - load32(reg_astep, ptr[parambase + OFFSET(astep)]); - - xor_(reg_itern, reg_itern); - L(".nloop"); - for (int i = 0; i < C_tilenum; i++) { - tilezero(Xbyak::Tmm(CTile + i)); - } - mov(reg_matAptr, ptr[parambase + OFFSET(matA)]); - mov(reg_tmp1, reg_matBptr); - - xor_(reg_iterk, reg_iterk); - - mov(reg_tmp, reg_nsize); - sub(reg_tmp, reg_itern); - cmp(reg_tmp, NTILE); - jl(".n32", T_NEAR); - generate_kloop(NRegs); - write_back(MTILE, NRegs); - load32(reg_tmp, ptr[parambase + OFFSET(bstep)]); - imul(reg_tmp, reg_tmp, NTILE); - add(reg_matBptr, reg_tmp); - add(reg_itern, NTILE); - jmp(".nend", T_NEAR); - - L(".n32"); - cmp(reg_tmp, 32); - jl(".n16", T_NEAR); - generate_kloop(2); - write_back(MTILE, 2); - add(reg_itern, 32); - add(reg_matBptr, 32 * sizeof(BType)); - jmp(".nend", T_NEAR); - - L(".n16"); - xor_(reg_iterk, reg_iterk); - generate_kloop(1); - write_back(MTILE, 1); - add(reg_itern, 16); - add(reg_matBptr, 16 * sizeof(BType)); - L(".nend"); - cmp(reg_itern, reg_nsize); - jb(".nloop"); - - mov(reg_ret, 0); - vreg_pop(rsp); - - outLocalLabel(); // end of local label + int TmpRegCount = RegCount; + int 
TmpReg = 0; + int CTileCount = 0, ATileCount = 0, BTileCount = 0; + int CTile = 0, ATile = 0, BTile = 0; + static int constexpr BKStepSize = KTILE * NTILE * sizeof(BType); + static int constexpr AKStepSize = KTILE * sizeof(AType); + + void generate_code(int _mtile) { + assign_regs(); + reset(); + generate_mtile(_mtile); + ready(); + mKernel = getCode(); + } + func_t mKernel = nullptr; + + protected: + Xbyak::Reg64 parambase; + Xbyak::Reg64 reg_matAptr; + Xbyak::Reg64 reg_matBptr; + Xbyak::Reg64 reg_matCptr; + Xbyak::Reg64 reg_ksize; + Xbyak::Reg64 reg_nsize; + Xbyak::Reg64 reg_cstride; + Xbyak::Reg64 reg_astride; + Xbyak::Reg64 reg_iterk; + Xbyak::Reg64 reg_itern; + Xbyak::Reg64 reg_tmp; + Xbyak::Reg64 reg_tmp1; + Xbyak::Reg64 reg_tmp2; + Xbyak::Reg64 reg_tmp3; + Xbyak::Reg64 reg_ret = rax; + + void assign_regs() { + CTileCount = NRegs * MRegs; + auto tile_re = TileCount - CTileCount; + if (tile_re - 1 >= NRegs) { + BTileCount = NRegs; + ATileCount = tile_re - BTileCount; + } else if (tile_re - 1 >= MRegs) { + ATileCount = MRegs; + BTileCount = tile_re - ATileCount; + } else { + ATileCount = 1; + BTileCount = tile_re - ATileCount; } + CTile = 0; + ATile = CTile + CTileCount; + BTile = ATile + ATileCount; + } - void generate_kloop(int _nregs) { - inLocalLabel(); - L(".kloop"); - mov(reg_tmp, reg_ksize); - sub(reg_tmp, reg_iterk); - cmp(reg_tmp, KUNROLL * KTILE); - jl(".k1loop", T_NEAR); - generate_fma(_nregs, KUNROLL, reg_tmp1); - add(reg_matAptr, KUNROLL * AKStepSize); - add(reg_tmp1, KUNROLL * BKStepSize); - add(reg_iterk, KUNROLL * KTILE); - jmp(".kloopend", T_NEAR); - - L(".k1loop"); - generate_fma(_nregs, 1, reg_tmp1); - add(reg_matAptr, 1 * AKStepSize); - add(reg_tmp1, 1 * BKStepSize); - add(reg_iterk, 1 * KTILE); - L(".kloopend"); - cmp(reg_iterk, reg_ksize); // k iteration variable - jb(".kloop"); - outLocalLabel(); - } + void generate_mtile(int _mtile) { + inLocalLabel(); // use local label for multiple instance + Xbyak::util::StackFrame st(this, 1, 11, 16 * 10); + parambase = st.p[0]; + reg_matAptr = st.t[0]; + reg_matBptr = st.t[1]; + reg_matCptr = st.t[0]; + reg_ksize = st.t[2]; + reg_astride = st.t[3]; + reg_cstride = st.t[3]; + reg_iterk = st.t[4]; + reg_tmp = st.t[5]; + reg_tmp1 = st.t[6]; + reg_tmp2 = st.t[7]; + reg_tmp3 = st.t[10]; + reg_nsize = st.t[8]; + reg_itern = st.t[9]; + reg_ret = rax; + + vreg_push(rsp); + + load32(reg_ksize, ptr[parambase + OFFSET(k)]); + load32(reg_nsize, ptr[parambase + OFFSET(n)]); + xor_(reg_itern, reg_itern); + L(".nloop"); + init_regs(_mtile); + mov(reg_matAptr, ptr[parambase + OFFSET(matA)]); + load32(reg_astride, ptr[parambase + OFFSET(astride)]); + mov(reg_matBptr, ptr[parambase + OFFSET(matB)]); + load32(reg_tmp, ptr[parambase + OFFSET(bstride)]); + imul(reg_tmp, reg_itern); + lea(reg_matBptr, ptr[reg_matBptr + reg_tmp]); + xor_(reg_iterk, reg_iterk); + generate_kloop(_mtile); + write_back(_mtile); + add(reg_itern, NTILE); + cmp(reg_itern, reg_nsize); + jb(".nloop"); + mov(reg_ret, 0); + vreg_pop(rsp); + + outLocalLabel(); // end of local label + } - void generate_fma(int _NTile, int _kunroll, const Xbyak::Reg64& _reg_matBptr) { - mov(reg_tmp, NTILE * 4); - for (int kk = 0; kk < _kunroll; kk++) { - for (int i = 0; i < _NTile; i++) { - tileloaddt1(Xbyak::Tmm(BTile + i), ptr[_reg_matBptr + reg_tmp + kk * BKStepSize + i * 64]); - } + void generate_kloop(int _mtile) { + inLocalLabel(); + mov(reg_tmp, reg_ksize); + padto_le(reg_tmp, KUNROLL * KTILE); + cmp(reg_tmp, 0); + jz(".kloop", T_NEAR); + L(".unkloop"); + generate_fma(_mtile, KUNROLL); 
+ add(reg_matAptr, KUNROLL * AKStepSize); + add(reg_matBptr, KUNROLL * BKStepSize); + add(reg_iterk, KUNROLL * KTILE); + cmp(reg_iterk, reg_tmp); // k iteration variable + jb(".unkloop"); + cmp(reg_tmp, reg_ksize); + jge(".kend", T_NEAR); + L(".kloop"); + generate_fma(_mtile, 1); + add(reg_matAptr, 1 * AKStepSize); + add(reg_matBptr, 1 * BKStepSize); + add(reg_iterk, 1 * KTILE); + cmp(reg_iterk, reg_ksize); // k iteration variable + jb(".kloop"); + L(".kend"); + outLocalLabel(); + } + + void generate_fma(int _mtile, int kunrll) { + auto& reg_Bstride = reg_tmp1; + mov(reg_Bstride, NTILE * 4); + int mtiles = _mtile / RegLen; - for (int mm = 0; mm < 1; mm++) { - tileloadd(Xbyak::Tmm(ATile + mm), ptr[reg_matAptr + reg_astep + kk * AKStepSize]); - for (int i = 0; i < _NTile; i++) { - tdpbf16ps(Xbyak::Tmm(CTile + mm * C_tilenum + i), Xbyak::Tmm(ATile + mm), Xbyak::Tmm(BTile + i)); + for (int kk = 0; kk < kunrll; kk++) { + auto& reg_Atmp = reg_tmp2; + if (mtiles == 1) { + reg_Atmp = reg_matAptr; + } else { + mov(reg_Atmp, reg_matAptr); + } + if (BTileCount == NRegs) { + for (int i = 0; i < NRegs; i++) { + tileloaddt1(Xbyak::Tmm(BTile + i), ptr[reg_matBptr + reg_Bstride + kk * BKStepSize + i * 64]); + } + for (int mm = 0; mm < mtiles; mm++) { + tileloadd(Xbyak::Tmm(ATile), ptr[reg_Atmp + reg_astride + kk * AKStepSize]); + for (int i = 0; i < NRegs; i++) { + tdpbf16ps(Xbyak::Tmm(CTile + mm * NRegs + i), Xbyak::Tmm(ATile), Xbyak::Tmm(BTile + i)); + } + if (mm != mtiles - 1) { + lea(reg_Atmp, ptr[reg_Atmp + 8 * reg_astride]); + lea(reg_Atmp, ptr[reg_Atmp + 8 * reg_astride]); + } + } + } else { + if (ATileCount == mtiles) { + for (int mm = 0; mm < mtiles; mm++) { + tileloadd(Xbyak::Tmm(ATile + mm), ptr[reg_Atmp + reg_astride + kk * AKStepSize]); + if (mm != mtiles - 1) { + lea(reg_Atmp, ptr[reg_Atmp + 8 * reg_astride]); + lea(reg_Atmp, ptr[reg_Atmp + 8 * reg_astride]); + } + } + for (int i = 0; i < NRegs; i++) { + tileloaddt1(Xbyak::Tmm(BTile), ptr[reg_matBptr + reg_Bstride + kk * BKStepSize + i * 64]); + for (int mm = 0; mm < mtiles; mm++) { + tdpbf16ps(Xbyak::Tmm(CTile + mm * NRegs + i), Xbyak::Tmm(ATile + mm), Xbyak::Tmm(BTile)); + } + } + } else { + for (int mm = 0; mm < mtiles; mm++) { + tileloadd(Xbyak::Tmm(ATile), ptr[reg_Atmp + reg_astride + kk * AKStepSize]); + for (int i = 0; i < NRegs; i++) { + tileloaddt1(Xbyak::Tmm(BTile), ptr[reg_matBptr + reg_Bstride + kk * BKStepSize + i * 64]); + tdpbf16ps(Xbyak::Tmm(CTile + mm * NRegs + i), Xbyak::Tmm(ATile), Xbyak::Tmm(BTile)); + } + if (mm != mtiles - 1) { + lea(reg_Atmp, ptr[reg_Atmp + 8 * reg_astride]); + lea(reg_Atmp, ptr[reg_Atmp + 8 * reg_astride]); + } } } } } + } - void write_back(int _mtile, int _NRegs) { - inLocalLabel(); - mov(reg_tmp, dword[parambase + OFFSET(workspace)]); - mov(reg_tmp1, NTILE * 4); - for (int mm = 0; mm < 1; mm++) { - for (int i = 0; i < _NRegs; i++) { - tilestored(ptr[reg_tmp + reg_tmp1 + i * 64 + mm * 16 * NTILE * 4], Xbyak::Tmm(CTile + mm * C_tilenum + i)); - } - } - load32(reg_matCptr, ptr[parambase + OFFSET(kpos)]); - cmp(reg_matCptr, 0); - jg(".LACC", T_NEAR); - mov(reg_matCptr, ptr[parambase + OFFSET(matC)]); - lea(reg_matCptr, ptr[reg_matCptr + reg_itern * sizeof(CType)]); - load32(reg_cstep, ptr[parambase + OFFSET(cstep)]); - inLocalLabel(); - xor_(reg_tmp1, reg_tmp1); - L(".mloop"); - for (int j = 0; j < _NRegs; j++) { - vmovups(Xbyak::Zmm(CReg + j), ptr[reg_tmp + j * 64]); - vmovups(ptr[reg_matCptr + j * VecBytes], Xbyak::Zmm(CReg + j)); + void init_regs(int _mtile) { + inLocalLabel(); + load32(reg_tmp, 
ptr[parambase + OFFSET(init)]); + cmp(reg_tmp, 0); + je(".read", T_NEAR); + for (int i = 0; i < CTileCount; i++) { + tilezero(Xbyak::Tmm(CTile + i)); + } + jmp(".end", T_NEAR); + L(".read"); + mov(reg_matCptr, ptr[parambase + OFFSET(matC)]); + lea(reg_matCptr, ptr[reg_matCptr + reg_itern * sizeof(CType)]); + load32(reg_cstride, ptr[parambase + OFFSET(cstride)]); + int mtnum = _mtile / 16; + for (int mm = 0; mm < mtnum; mm++) { + for (int i = 0; i < NRegs; i++) { + tileloaddt1(Xbyak::Tmm(CTile + mm * NRegs + i), ptr[reg_matCptr + reg_cstride + i * 64]); } - add(reg_matCptr, reg_cstep); - add(reg_tmp, NTILE * 4); - add(reg_tmp1, 1); - cmp(reg_tmp1.cvt32(), ptr[parambase + OFFSET(msize)]); - jb(".mloop"); - outLocalLabel(); - jmp(".LEND", T_NEAR); - L(".LACC"); - mov(reg_matCptr, ptr[parambase + OFFSET(matC)]); - lea(reg_matCptr, ptr[reg_matCptr + reg_itern * sizeof(CType)]); - load32(reg_cstep, ptr[parambase + OFFSET(cstep)]); - inLocalLabel(); - xor_(reg_tmp1, reg_tmp1); - L(".mloop"); - for (int j = 0; j < _NRegs; j++) { - vmovups(Xbyak::Zmm(CReg + j), ptr[reg_tmp + j * 64]); - vaddps(Xbyak::Zmm(CReg + j), ptr[reg_matCptr + j * VecBytes]); - vmovups(ptr[reg_matCptr + j * VecBytes], Xbyak::Zmm(CReg + j)); + if (mm != mtnum - 1) { + lea(reg_matCptr, ptr[reg_matCptr + 8 * reg_cstride]); + lea(reg_matCptr, ptr[reg_matCptr + 8 * reg_cstride]); } - add(reg_matCptr, reg_cstep); - add(reg_tmp, NTILE * 4); - add(reg_tmp1, 1); - cmp(reg_tmp1.cvt32(), ptr[parambase + OFFSET(msize)]); - jb(".mloop"); - outLocalLabel(); - L(".LEND"); - nop(); - outLocalLabel(); } + L(".end"); + outLocalLabel(); + } - private: - Xbyak::Reg64 parambase; - Xbyak::Reg64 reg_matAptr; - Xbyak::Reg64 reg_matBptr; - Xbyak::Reg64 reg_matCptr; - Xbyak::Reg64 reg_ksize; - Xbyak::Reg64 reg_nsize; - Xbyak::Reg64 reg_cstep; - Xbyak::Reg64 reg_astep; - Xbyak::Reg64 reg_iterk; - Xbyak::Reg64 reg_itern; - Xbyak::Reg64 reg_tmp; - Xbyak::Reg64 reg_tmp1; - Xbyak::Reg64 reg_tmp2; - Xbyak::Reg64 reg_ret = rax; - Xbyak::Opmask msk_wr = k1; - }; - - public: - GemmCore_Row_NN_16x48_AMX_BF16() { mCodes.generate_code(); } - - void forward(AType* matA, BType* matB, CType* matC, int _m, int _n, int _k, int _astride, int _bstride, int _cstride, - int kpos, void* tmpcache, size_t cachesize) { - assert((NTILE * MTILE * sizeof(CType)) <= cachesize); - MicroKernel::tileconfig_t mCfg; - std::memset(&mCfg, 0, sizeof(mCfg)); - auto param = params{matA, matB, matC, _k, _m, _n, _astride, _bstride, _cstride, kpos, tmpcache, &mCfg}; - if (_m <= MTILE) { - jblas::xbyak::JitAmxtile::configure_tiles(mCfg, _m < 16 ? _m : 16, _n < 16 ? _n : 16, _k < KTILE ? 
_k : KTILE, - sizeof(BType), MicroKernel::A_tilenum, MicroKernel::B_tilenum, - MicroKernel::C_tilenum); - mCodes.mKernel(¶m); - } else { - assert(0); + void write_back(int _mtile) { + inLocalLabel(); + mov(reg_tmp, dword[parambase + OFFSET(workspace)]); + mov(reg_tmp1, NTILE * 4); + for (int mm = 0; mm < MRegs; mm++) { + for (int i = 0; i < NRegs; i++) { + tilestored(ptr[reg_tmp + reg_tmp1 + i * 64 + mm * 16 * NTILE * 4], Xbyak::Tmm(CTile + mm * NRegs + i)); + } + } + mov(reg_matCptr, ptr[parambase + OFFSET(matC)]); + load32(reg_cstride, ptr[parambase + OFFSET(cstride)]); + lea(reg_matCptr, ptr[reg_matCptr + reg_itern * sizeof(CType)]); + int zunroll = TmpRegCount / NRegs; + for (int i = 0; i < _mtile; i += zunroll) { + int m_re = utils::remainsize(i, _mtile, zunroll); + for (int im = 0; im < m_re; im++) { + for (int j = 0; j < NRegs; j++) { + vmovups(vreg_t(TmpReg + im * NRegs + j), ptr[reg_tmp + j * 64 + (i + im) * NTILE * 4]); + vmovups(ptr[reg_matCptr + j * VecBytes], vreg_t(TmpReg + im * NRegs + j)); + } + add(reg_matCptr, reg_cstride); + } } + outLocalLabel(); } - - private: - MicroKernel mCodes; }; -template -class GemmCore_Row_NN_16x64_AMX_I8 { +template +class Amxint8N16P4 : protected jblas::xbyak::JitAmxint8 { public: - typedef T_A_ AType; - typedef T_B_ BType; + static int constexpr RegLen = 16, PackRow = 4; + static_assert(_NTILE % RegLen == 0); + static_assert(_MTILE % RegLen == 0); + static int constexpr NRegs = _NTILE / RegLen; + static int constexpr MRegs = _MTILE == 0 ? 1 : _MTILE / RegLen; + static_assert(NRegs * MRegs + 2 <= TileCount); + static int constexpr NTILE = RegLen * NRegs, MTILE = MRegs * RegLen, KTILE = 64; + static int constexpr KUNROLL = 2; + static uint32_t constexpr ISA = (uint32_t)JBLAS_ISA::JblasAMX_INT8; + static uint32_t constexpr COMPUTE = + (uint32_t)(std::is_same_v + ? std::is_same_v ? CompType::COMP_INT8_SS_INT32 : CompType::COMP_INT8_SU_INT32 + : std::is_same_v ? 
CompType::COMP_INT8_US_INT32 + : CompType::COMP_INT8_UU_INT32); + using AType = AT; + using BType = BT; typedef int32_t CType; + struct params { AType* matA; + int astride; BType* matB; + int bstride; CType* matC; - int k, msize, nsize; - int astep, bstep, cstep; - int kpos; - void *workspace, *cfg; + int cstride; + int k; + int n; + int init; + void* workspace; }; typedef long long (*func_t)(params*); - static JBLAS_GEMM_CORE constexpr TYPE = _ID; - static JBLAS_ISA constexpr ISA = - utils::gemm_core_mask(); - static int constexpr NTILE = 64, MTILE = 16, KTILE = 64 / sizeof(BType); - static int constexpr PACK_ROW = 4; - static int constexpr KUNROLL = 2; - static int constexpr PREFERED_N = 256; - class MicroKernel : protected jblas::xbyak::JitAmxint8 { - public: - friend GemmCore_Row_NN_16x64_AMX_I8; - MicroKernel() {} - static int constexpr CReg = 0, TmpReg = 4; - static int constexpr NRegs = 4; - static int constexpr CRegCount = NRegs; - static int constexpr C_tilenum = 4, A_tilenum = 1, B_tilenum = 3; - static int constexpr CTile = 0, ATile = CTile + C_tilenum, BTile = ATile + A_tilenum; - static int constexpr BKStepSize = KTILE * NTILE * sizeof(BType); - static int constexpr AKStepSize = KTILE * sizeof(AType); - static int constexpr VecBytes = 64; - - void generate_code() { - reset(); - generate_mtile(); - ready(); - mKernel = getCode(); - } - func_t mKernel = nullptr; - - protected: - void generate_mtile() { - inLocalLabel(); // use local label for multiple instance - Xbyak::util::StackFrame st(this, 1, 11, 16 * 10); - parambase = st.p[0]; - reg_matAptr = st.t[0]; - reg_matBptr = st.t[1]; - reg_matCptr = st.t[0]; - reg_ksize = st.t[2]; - reg_nsize = st.t[9]; - reg_cstep = st.t[3]; - reg_astep = st.t[5]; - reg_iterk = st.t[4]; - reg_itern = st.t[7]; - reg_tmp = st.t[6]; - reg_tmp1 = st.t[8]; - reg_tmp2 = st.t[10]; - reg_ret = rax; - - vreg_push(rsp); - mov(reg_tmp, ptr[parambase + OFFSET(cfg)]); - ldtilecfg(ptr[reg_tmp]); - - mov(reg_matBptr, ptr[parambase + OFFSET(matB)]); - load32(reg_ksize, ptr[parambase + OFFSET(k)]); - load32(reg_nsize, ptr[parambase + OFFSET(nsize)]); - load32(reg_astep, ptr[parambase + OFFSET(astep)]); - - xor_(reg_itern, reg_itern); - L(".nloop"); - for (int i = 0; i < C_tilenum; i++) { - tilezero(Xbyak::Tmm(CTile + i)); - } - mov(reg_matAptr, ptr[parambase + OFFSET(matA)]); - mov(reg_tmp1, reg_matBptr); - - xor_(reg_iterk, reg_iterk); - - mov(reg_tmp, reg_nsize); - sub(reg_tmp, reg_itern); - cmp(reg_tmp, NTILE); - jl(".n48", T_NEAR); - generate_kloop(NRegs); - write_back(MTILE, NRegs); - load32(reg_tmp, ptr[parambase + OFFSET(bstep)]); - imul(reg_tmp, reg_tmp, NTILE); - add(reg_matBptr, reg_tmp); - add(reg_itern, NTILE); - jmp(".nend", T_NEAR); - - L(".n48"); - cmp(reg_tmp, 48); - jl(".n32", T_NEAR); - generate_kloop(3); - write_back(MTILE, 3); - add(reg_itern, 48); - add(reg_matBptr, 48 * sizeof(BType)); - jmp(".nend", T_NEAR); - - L(".n32"); - cmp(reg_tmp, 32); - jl(".n16", T_NEAR); - generate_kloop(2); - write_back(MTILE, 2); - add(reg_itern, 32); - add(reg_matBptr, 32 * sizeof(BType)); - jmp(".nend", T_NEAR); - - L(".n16"); - xor_(reg_iterk, reg_iterk); - generate_kloop(1); - write_back(MTILE, 1); - add(reg_itern, 16); - add(reg_matBptr, 16 * sizeof(BType)); - L(".nend"); - cmp(reg_itern, reg_nsize); - jb(".nloop"); - - mov(reg_ret, 0); - vreg_pop(rsp); - - outLocalLabel(); // end of local label + int TmpRegCount = RegCount; + int TmpReg = 0; + int CTileCount = 0, ATileCount = 0, BTileCount = 0; + int CTile = 0, ATile = 0, BTile = 0; + static int constexpr 
BKStepSize = KTILE * NTILE * sizeof(BType); + static int constexpr AKStepSize = KTILE * sizeof(AType); + + void generate_code(int _mtile) { + assign_regs(); + reset(); + generate_mtile(_mtile); + ready(); + mKernel = getCode(); + } + func_t mKernel = nullptr; + + protected: + Xbyak::Reg64 parambase; + Xbyak::Reg64 reg_matAptr; + Xbyak::Reg64 reg_matBptr; + Xbyak::Reg64 reg_matCptr; + Xbyak::Reg64 reg_ksize; + Xbyak::Reg64 reg_nsize; + Xbyak::Reg64 reg_cstride; + Xbyak::Reg64 reg_astride; + Xbyak::Reg64 reg_iterk; + Xbyak::Reg64 reg_itern; + Xbyak::Reg64 reg_tmp; + Xbyak::Reg64 reg_tmp1; + Xbyak::Reg64 reg_tmp2; + Xbyak::Reg64 reg_tmp3; + Xbyak::Reg64 reg_ret = rax; + + void assign_regs() { + CTileCount = NRegs * MRegs; + auto tile_re = TileCount - CTileCount; + if (tile_re - 1 >= NRegs) { + BTileCount = NRegs; + ATileCount = tile_re - BTileCount; + } else if (tile_re - 1 >= MRegs) { + ATileCount = MRegs; + BTileCount = tile_re - ATileCount; + } else { + ATileCount = 1; + BTileCount = tile_re - ATileCount; } + CTile = 0; + ATile = CTile + CTileCount; + BTile = ATile + ATileCount; + } - void generate_kloop(int _nregs) { - inLocalLabel(); - L(".kloop"); - mov(reg_tmp, reg_ksize); - sub(reg_tmp, reg_iterk); - cmp(reg_tmp, KUNROLL * KTILE); - jl(".k1loop", T_NEAR); - generate_fma(_nregs, KUNROLL, reg_tmp1); - add(reg_matAptr, KUNROLL * AKStepSize); - add(reg_tmp1, KUNROLL * BKStepSize); - add(reg_iterk, KUNROLL * KTILE); - jmp(".kloopend", T_NEAR); - - L(".k1loop"); - generate_fma(_nregs, 1, reg_tmp1); - add(reg_matAptr, 1 * AKStepSize); - add(reg_tmp1, 1 * BKStepSize); - add(reg_iterk, 1 * KTILE); - L(".kloopend"); - cmp(reg_iterk, reg_ksize); // k iteration variable - jb(".kloop"); - outLocalLabel(); - } + void generate_mtile(int _mtile) { + inLocalLabel(); // use local label for multiple instance + Xbyak::util::StackFrame st(this, 1, 11, 16 * 10); + parambase = st.p[0]; + reg_matAptr = st.t[0]; + reg_matBptr = st.t[1]; + reg_matCptr = st.t[0]; + reg_ksize = st.t[2]; + reg_astride = st.t[3]; + reg_cstride = st.t[3]; + reg_iterk = st.t[4]; + reg_tmp = st.t[5]; + reg_tmp1 = st.t[6]; + reg_tmp2 = st.t[7]; + reg_tmp3 = st.t[10]; + reg_nsize = st.t[8]; + reg_itern = st.t[9]; + reg_ret = rax; + + vreg_push(rsp); + + load32(reg_ksize, ptr[parambase + OFFSET(k)]); + load32(reg_nsize, ptr[parambase + OFFSET(n)]); + xor_(reg_itern, reg_itern); + L(".nloop"); + init_regs(_mtile); + mov(reg_matAptr, ptr[parambase + OFFSET(matA)]); + load32(reg_astride, ptr[parambase + OFFSET(astride)]); + mov(reg_matBptr, ptr[parambase + OFFSET(matB)]); + load32(reg_tmp, ptr[parambase + OFFSET(bstride)]); + imul(reg_tmp, reg_itern); + lea(reg_matBptr, ptr[reg_matBptr + reg_tmp]); + xor_(reg_iterk, reg_iterk); + generate_kloop(_mtile); + write_back(_mtile); + add(reg_itern, NTILE); + cmp(reg_itern, reg_nsize); + jb(".nloop"); + mov(reg_ret, 0); + vreg_pop(rsp); + + outLocalLabel(); // end of local label + } - void generate_fma(int _NTile, int _kunroll, const Xbyak::Reg64& _reg_matBptr) { - mov(reg_tmp, NTILE * 4); - if (_NTile <= B_tilenum) { - for (int kk = 0; kk < _kunroll; kk++) { - for (int i = 0; i < _NTile; i++) { - tileloaddt1(Xbyak::Tmm(BTile + i), ptr[_reg_matBptr + reg_tmp + kk * BKStepSize + i * 64]); - } + void generate_kloop(int _mtile) { + inLocalLabel(); + mov(reg_tmp, reg_ksize); + padto_le(reg_tmp, KUNROLL * KTILE); + cmp(reg_tmp, 0); + jz(".kloop", T_NEAR); + L(".unkloop"); + generate_fma(_mtile, KUNROLL); + add(reg_matAptr, KUNROLL * AKStepSize); + add(reg_matBptr, KUNROLL * BKStepSize); + 
add(reg_iterk, KUNROLL * KTILE); + cmp(reg_iterk, reg_tmp); // k iteration variable + jb(".unkloop"); + cmp(reg_tmp, reg_ksize); + jge(".kend", T_NEAR); + L(".kloop"); + generate_fma(_mtile, 1); + add(reg_matAptr, 1 * AKStepSize); + add(reg_matBptr, 1 * BKStepSize); + add(reg_iterk, 1 * KTILE); + cmp(reg_iterk, reg_ksize); // k iteration variable + jb(".kloop"); + L(".kend"); + outLocalLabel(); + } - for (int mm = 0; mm < 1; mm++) { - tileloadd(Xbyak::Tmm(ATile + mm), ptr[reg_matAptr + reg_astep + kk * AKStepSize]); - for (int i = 0; i < _NTile; i++) { - _tdpb(Xbyak::Tmm(CTile + mm * C_tilenum + i), Xbyak::Tmm(ATile + mm), - Xbyak::Tmm(BTile + i)); - } + void generate_fma(int _mtile, int kunrll) { + auto& reg_Bstride = reg_tmp1; + mov(reg_Bstride, NTILE * 4); + int mtiles = _mtile / RegLen; + + for (int kk = 0; kk < kunrll; kk++) { + auto& reg_Atmp = reg_tmp2; + if (mtiles == 1) { + reg_Atmp = reg_matAptr; + } else { + mov(reg_Atmp, reg_matAptr); + } + if (BTileCount == NRegs) { + for (int i = 0; i < NRegs; i++) { + tileloaddt1(Xbyak::Tmm(BTile + i), ptr[reg_matBptr + reg_Bstride + kk * BKStepSize + i * 64]); + } + for (int mm = 0; mm < mtiles; mm++) { + tileloadd(Xbyak::Tmm(ATile), ptr[reg_Atmp + reg_astride + kk * AKStepSize]); + for (int i = 0; i < NRegs; i++) { + _tdpb(Xbyak::Tmm(CTile + mm * NRegs + i), Xbyak::Tmm(ATile), Xbyak::Tmm(BTile + i)); + } + if (mm != mtiles - 1) { + lea(reg_Atmp, ptr[reg_Atmp + 8 * reg_astride]); + lea(reg_Atmp, ptr[reg_Atmp + 8 * reg_astride]); } } } else { - for (int kk = 0; kk < _kunroll; kk++) { - for (int i = 0; i < _NTile - 1; i++) { - tileloaddt1(Xbyak::Tmm(BTile + i), ptr[_reg_matBptr + reg_tmp + kk * BKStepSize + i * 64]); + if (ATileCount == mtiles) { + for (int mm = 0; mm < mtiles; mm++) { + tileloadd(Xbyak::Tmm(ATile + mm), ptr[reg_Atmp + reg_astride + kk * AKStepSize]); + if (mm != mtiles - 1) { + lea(reg_Atmp, ptr[reg_Atmp + 8 * reg_astride]); + lea(reg_Atmp, ptr[reg_Atmp + 8 * reg_astride]); + } } - - for (int mm = 0; mm < 1; mm++) { - tileloadd(Xbyak::Tmm(ATile + mm), ptr[reg_matAptr + reg_astep + kk * AKStepSize]); - for (int i = 0; i < _NTile - 1; i++) { - _tdpb(Xbyak::Tmm(CTile + mm * C_tilenum + i), Xbyak::Tmm(ATile + mm), - Xbyak::Tmm(BTile + i)); + for (int i = 0; i < NRegs; i++) { + tileloaddt1(Xbyak::Tmm(BTile), ptr[reg_matBptr + reg_Bstride + kk * BKStepSize + i * 64]); + for (int mm = 0; mm < mtiles; mm++) { + _tdpb(Xbyak::Tmm(CTile + mm * NRegs + i), Xbyak::Tmm(ATile + mm), Xbyak::Tmm(BTile)); + } + } + } else { + for (int mm = 0; mm < mtiles; mm++) { + tileloadd(Xbyak::Tmm(ATile), ptr[reg_Atmp + reg_astride + kk * AKStepSize]); + for (int i = 0; i < NRegs; i++) { + tileloaddt1(Xbyak::Tmm(BTile), ptr[reg_matBptr + reg_Bstride + kk * BKStepSize + i * 64]); + _tdpb(Xbyak::Tmm(CTile + mm * NRegs + i), Xbyak::Tmm(ATile), Xbyak::Tmm(BTile)); + } + if (mm != mtiles - 1) { + lea(reg_Atmp, ptr[reg_Atmp + 8 * reg_astride]); + lea(reg_Atmp, ptr[reg_Atmp + 8 * reg_astride]); } - tileloaddt1(Xbyak::Tmm(BTile + 0), ptr[_reg_matBptr + reg_tmp + kk * BKStepSize + (_NTile - 1) * 64]); - _tdpb(Xbyak::Tmm(CTile + mm * C_tilenum + _NTile - 1), Xbyak::Tmm(ATile + mm), - Xbyak::Tmm(BTile + 0)); } } } } + } - void write_back(int _mtile, int _NRegs) { - inLocalLabel(); - mov(reg_tmp, dword[parambase + OFFSET(workspace)]); - mov(reg_tmp1, NTILE * 4); - for (int mm = 0; mm < 1; mm++) { - for (int i = 0; i < _NRegs; i++) { - tilestored(ptr[reg_tmp + reg_tmp1 + i * 64 + mm * 16 * NTILE * 4], Xbyak::Tmm(CTile + mm * C_tilenum + i)); - } + void init_regs(int 
_mtile) { + inLocalLabel(); + load32(reg_tmp, ptr[parambase + OFFSET(init)]); + cmp(reg_tmp, 0); + je(".read", T_NEAR); + for (int i = 0; i < CTileCount; i++) { + tilezero(Xbyak::Tmm(CTile + i)); + } + jmp(".end", T_NEAR); + L(".read"); + mov(reg_matCptr, ptr[parambase + OFFSET(matC)]); + lea(reg_matCptr, ptr[reg_matCptr + reg_itern * sizeof(CType)]); + load32(reg_cstride, ptr[parambase + OFFSET(cstride)]); + int mtnum = _mtile / 16; + for (int mm = 0; mm < mtnum; mm++) { + for (int i = 0; i < NRegs; i++) { + tileloaddt1(Xbyak::Tmm(CTile + mm * NRegs + i), ptr[reg_matCptr + reg_cstride + i * 64]); + } + if (mm != mtnum - 1) { + lea(reg_matCptr, ptr[reg_matCptr + 8 * reg_cstride]); + lea(reg_matCptr, ptr[reg_matCptr + 8 * reg_cstride]); } - load32(reg_matCptr, ptr[parambase + OFFSET(kpos)]); - cmp(reg_matCptr, 0); - jg(".LACC", T_NEAR); - mov(reg_matCptr, ptr[parambase + OFFSET(matC)]); - lea(reg_matCptr, ptr[reg_matCptr + reg_itern * sizeof(CType)]); - load32(reg_cstep, ptr[parambase + OFFSET(cstep)]); - inLocalLabel(); - xor_(reg_tmp1, reg_tmp1); - L(".mloop"); - for (int j = 0; j < _NRegs; j++) { - vmovups(Xbyak::Zmm(CReg + j), ptr[reg_tmp + j * 64]); - vmovups(ptr[reg_matCptr + j * VecBytes], Xbyak::Zmm(CReg + j)); + } + L(".end"); + outLocalLabel(); + } + + void write_back(int _mtile) { + inLocalLabel(); + mov(reg_tmp, dword[parambase + OFFSET(workspace)]); + mov(reg_tmp1, NTILE * 4); + for (int mm = 0; mm < MRegs; mm++) { + for (int i = 0; i < NRegs; i++) { + tilestored(ptr[reg_tmp + reg_tmp1 + i * 64 + mm * 16 * NTILE * 4], Xbyak::Tmm(CTile + mm * NRegs + i)); } - add(reg_matCptr, reg_cstep); - add(reg_tmp, NTILE * 4); - add(reg_tmp1, 1); - cmp(reg_tmp1.cvt32(), ptr[parambase + OFFSET(msize)]); - jb(".mloop"); - outLocalLabel(); - jmp(".LEND", T_NEAR); - L(".LACC"); - mov(reg_matCptr, ptr[parambase + OFFSET(matC)]); - lea(reg_matCptr, ptr[reg_matCptr + reg_itern * sizeof(CType)]); - load32(reg_cstep, ptr[parambase + OFFSET(cstep)]); - inLocalLabel(); - xor_(reg_tmp1, reg_tmp1); - L(".mloop"); - for (int j = 0; j < _NRegs; j++) { - vmovups(Xbyak::Zmm(CReg + j), ptr[reg_tmp + j * 64]); - vpaddd(Xbyak::Zmm(CReg + j), Xbyak::Zmm(CReg + j), ptr[reg_matCptr + j * VecBytes]); - vmovups(ptr[reg_matCptr + j * VecBytes], Xbyak::Zmm(CReg + j)); + } + mov(reg_matCptr, ptr[parambase + OFFSET(matC)]); + load32(reg_cstride, ptr[parambase + OFFSET(cstride)]); + lea(reg_matCptr, ptr[reg_matCptr + reg_itern * sizeof(CType)]); + int zunroll = TmpRegCount / NRegs; + for (int i = 0; i < _mtile; i += zunroll) { + int m_re = utils::remainsize(i, _mtile, zunroll); + for (int im = 0; im < m_re; im++) { + for (int j = 0; j < NRegs; j++) { + vmovups(vreg_t(TmpReg + im * NRegs + j), ptr[reg_tmp + j * 64 + (i + im) * NTILE * 4]); + vmovups(ptr[reg_matCptr + j * VecBytes], vreg_t(TmpReg + im * NRegs + j)); + } + add(reg_matCptr, reg_cstride); } - add(reg_matCptr, reg_cstep); - add(reg_tmp, NTILE * 4); - add(reg_tmp1, 1); - cmp(reg_tmp1.cvt32(), ptr[parambase + OFFSET(msize)]); - jb(".mloop"); - outLocalLabel(); - L(".LEND"); - nop(); - outLocalLabel(); } + outLocalLabel(); + } +}; +template +using Amxint8N16P4US = Amxint8N16P4; - private: - Xbyak::Reg64 parambase; - Xbyak::Reg64 reg_matAptr; - Xbyak::Reg64 reg_matBptr; - Xbyak::Reg64 reg_matCptr; - Xbyak::Reg64 reg_ksize; - Xbyak::Reg64 reg_nsize; - Xbyak::Reg64 reg_cstep; - Xbyak::Reg64 reg_astep; - Xbyak::Reg64 reg_iterk; - Xbyak::Reg64 reg_itern; - Xbyak::Reg64 reg_tmp; - Xbyak::Reg64 reg_tmp1; - Xbyak::Reg64 reg_tmp2; - Xbyak::Reg64 reg_ret = rax; - 
Xbyak::Opmask msk_wr = k1; - }; +template +using Amxint8N16P4SS = Amxint8N16P4; +class AmxConfigure : protected jblas::xbyak::JitAmxtile { public: - GemmCore_Row_NN_16x64_AMX_I8() { mCodes.generate_code(); } + typedef long long (*func_t)(tileconfig_t*); + + static void configure(int TILE_M, int TILE_N, int TILE_K, int elesize, int ANum, int BNum, int CNum) { + static AmxConfigure code; + tileconfig_t cfg; + std::memset(&cfg, 0, sizeof(cfg)); + configure_tiles(cfg, TILE_M, TILE_N, TILE_K, elesize, ANum, BNum, CNum); + code.mKernel(&cfg); + } - void forward(AType* matA, BType* matB, CType* matC, int _m, int _n, int _k, int _astride, int _bstride, int _cstride, + protected: + AmxConfigure() { + generate_config(this); + mKernel = getCode(); + } + + func_t mKernel = nullptr; +}; + +} // namespace code +template