diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt
index 37f5bee33a..89c58a5040 100755
--- a/src/CMakeLists.txt
+++ b/src/CMakeLists.txt
@@ -173,6 +173,7 @@ set( MIOpen_Source
     include/miopen/rnn_util.hpp
     include/miopen/bz2.hpp
     include/miopen/comgr.hpp
+    include/miopen/numeric.hpp
     md_graph.cpp
     mdg_expr.cpp
     conv/invokers/gcn_asm_1x1u.cpp
@@ -223,6 +224,7 @@ set( MIOpen_Source
     solver/conv_asm_implicit_gemm_v4r1_dynamic.cpp
     solver/conv_hip_implicit_gemm_v4r4_gen_xdlops_fwd_fp32.cpp
     solver/conv_hip_implicit_gemm_bwd_v1r1_xdlops_nchw_kcyx_nkhw.cpp
+    solver/conv_asm_implicit_gemm_bwd_v4r1_dynamic.cpp
 )
 list(APPEND MIOpen_Source tmp_dir.cpp binary_cache.cpp md5.cpp)
diff --git a/src/conv/invokers/impl_gemm_dynamic.cpp b/src/conv/invokers/impl_gemm_dynamic.cpp
index 76164fde5f..d7d2deed25 100644
--- a/src/conv/invokers/impl_gemm_dynamic.cpp
+++ b/src/conv/invokers/impl_gemm_dynamic.cpp
@@ -3,42 +3,43 @@
 #include
 #include
 #include
-
+#include <miopen/numeric.hpp>
 #include

 namespace miopen {
 namespace conv {

-float CallImplicitGemmDynamic(const miopen::Handle& handle,
-                              const ConvolutionContext& ctx,
-                              ConstData_t src,
-                              Data_t dst,
-                              ConstData_t wei,
-                              const std::vector<KernelInvoke>& kernels)
+float CallImplGemmDynamicForward(const miopen::Handle& handle,
+                                 const ProblemDescription& conv_problem,
+                                 ConstData_t src,
+                                 Data_t dst,
+                                 ConstData_t wei,
+                                 const std::vector<KernelInvoke>& kernels)
 {
     float elapsed = 0.0f;

     auto kernel = kernels[0];
     MIOPEN_LOG_I(kernel.GetName());
-    bool kernel_is_1x1 = (kernel.GetName().find("igemm_v4r1_1x1_dynamic") == 0);
+
     // clang-format off
-    int hi         = ctx.in_height;
-    int wi         = ctx.in_width;
-    int n          = ctx.batch_sz;
-    int k          = ctx.n_outputs;
-    int c          = ctx.n_inputs;
-    int ho         = ctx.out_height;
-    int wo         = ctx.out_width;
-    int stride_h   = ctx.kernel_stride_h;
-    int stride_w   = ctx.kernel_stride_w;
-    int dilation_h = ctx.kernel_dilation_h;
-    int dilation_w = ctx.kernel_dilation_w;
-    int pad_h      = ctx.pad_h;
-    int pad_w      = ctx.pad_w;
-    int y          = ctx.kernel_size_h;
-    int x          = ctx.kernel_size_w;
+    int hi         = conv_problem.GetInHeight();
+    int wi         = conv_problem.GetInWidth();
+    int n          = conv_problem.GetInBatchSize();
+    int k          = conv_problem.GetOutChannels();
+    int c          = conv_problem.GetInChannels();
+    int ho         = conv_problem.GetOutHeight();
+    int wo         = conv_problem.GetOutWidth();
+    int stride_h   = conv_problem.GetKernelStrideH();
+    int stride_w   = conv_problem.GetKernelStrideW();
+    int dilation_h = conv_problem.GetDilationH();
+    int dilation_w = conv_problem.GetDilationW();
+    int pad_h      = conv_problem.GetPadH();
+    int pad_w      = conv_problem.GetPadW();
+    int y          = conv_problem.GetWeightsHeight();
+    int x          = conv_problem.GetWeightsWidth();
     int __pack0    = 0;
     // clang-format on
+
     std::vector<OpKernelArg> opArgs;
     opArgs.emplace_back(src);
     opArgs.emplace_back(wei);
@@ -56,16 +57,10 @@ float CallImplGemmDynamicForward(const miopen::Handle& handle,
     opArgs.emplace_back(dilation_w);
     opArgs.emplace_back(pad_h);
     opArgs.emplace_back(pad_w);
-    if(kernel_is_1x1)
-    {
-        opArgs.emplace_back(__pack0);
-    }
-    else
-    {
-        opArgs.emplace_back(y);
-        opArgs.emplace_back(x);
-        opArgs.emplace_back(__pack0);
-    }
+    opArgs.emplace_back(y);
+    opArgs.emplace_back(x);
+    opArgs.emplace_back(__pack0);
+
     kernel(opArgs);

     if(handle.IsProfilingEnabled())
@@ -73,45 +68,263 @@ float CallImplicitGemmDynamic(const miopen::Handle& handle,
     return elapsed;
 }
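// ---------------------------------------------------------------------------
// Note: the forward wrapper above pushes 19 arguments (y and x included); the
// 1x1 wrapper below pushes 17, because the igemm_v4r1_1x1_dynamic kernels
// hard-code y = x = 1 and reserve no kernarg slots for them. A sketch of the
// two layouts (hypothetical mirror structs, not MIOpen API -- the real layout
// is fixed by the .set k_* offsets in the assembly kernels):
//
//   struct FwdArgs    { void *in, *wei, *out; int hi, wi, n, k, c, ho, wo,
//                       stride_h, stride_w, dilation_h, dilation_w,
//                       pad_h, pad_w, y, x, pack0; };
//   struct Fwd1x1Args { void *in, *wei, *out; int hi, wi, n, k, c, ho, wo,
//                       stride_h, stride_w, dilation_h, dilation_w,
//                       pad_h, pad_w, pack0; };
//
// Pushing the wrong variant would shift every argument after pad_w, which is
// why the old single entry point with its kernel_is_1x1 branch is split in two.
// ---------------------------------------------------------------------------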
-InvokerFactory MakeImplGemmDynamicDataInvokerFactory(const ConvolutionContext& ctx)
+float CallImplGemmDynamicForward1x1(const miopen::Handle& handle,
+                                    const ProblemDescription& conv_problem,
+                                    ConstData_t src,
+                                    Data_t dst,
+                                    ConstData_t wei,
+                                    const std::vector<KernelInvoke>& kernels)
 {
-    if(ctx.direction.IsForward())
-    {
-        return [ctx](const std::vector<Kernel>& kernels) {
-            return [=](const Handle& handle, const boost::any& primitive_parameters) {
-                const auto data_ctx = boost::any_cast<conv::DataInvokeParams>(primitive_parameters);
-                const auto& tensors = data_ctx.tensors;
-                auto kernel         = handle.Run(kernels[0]);
-                if(kernel.GetName().find("igemm_v4r1_dynamic") == 0 ||
-                   kernel.GetName().find("igemm_v4r1_1x1_dynamic") == 0)
-                {
-                    std::vector<KernelInvoke> ks;
-                    std::transform(kernels.begin(),
-                                   kernels.end(),
-                                   std::back_inserter(ks),
-                                   [&](const Kernel& k) { return handle.Run(k); });
-                    float elapsed = 0;
-                    elapsed       = CallImplicitGemmDynamic(
-                        handle, ctx, tensors.in, tensors.out, tensors.w, ks);
-                    if(handle.IsProfilingEnabled())
-                    {
-                        handle.ResetKernelTime();
-                        handle.AccumKernelTime(elapsed);
-                    }
-                }
-                else
-                {
-                    MIOPEN_THROW(
-                        "Error running dynamic implicit GEMM convolution (invalid kernel name?)");
-                }
-            };
-        };
-    }
-    else
+    float elapsed = 0.0f;
+
+    auto kernel = kernels[0];
+    MIOPEN_LOG_I(kernel.GetName());
+
+    // clang-format off
+    int hi         = conv_problem.GetInHeight();
+    int wi         = conv_problem.GetInWidth();
+    int n          = conv_problem.GetInBatchSize();
+    int k          = conv_problem.GetOutChannels();
+    int c          = conv_problem.GetInChannels();
+    int ho         = conv_problem.GetOutHeight();
+    int wo         = conv_problem.GetOutWidth();
+    int stride_h   = conv_problem.GetKernelStrideH();
+    int stride_w   = conv_problem.GetKernelStrideW();
+    int dilation_h = conv_problem.GetDilationH();
+    int dilation_w = conv_problem.GetDilationW();
+    int pad_h      = conv_problem.GetPadH();
+    int pad_w      = conv_problem.GetPadW();
+    int __pack0    = 0;
+    // clang-format on
+
+    std::vector<OpKernelArg> opArgs;
+    opArgs.emplace_back(src);
+    opArgs.emplace_back(wei);
+    opArgs.emplace_back(dst);
+    opArgs.emplace_back(hi);
+    opArgs.emplace_back(wi);
+    opArgs.emplace_back(n);
+    opArgs.emplace_back(k);
+    opArgs.emplace_back(c);
+    opArgs.emplace_back(ho);
+    opArgs.emplace_back(wo);
+    opArgs.emplace_back(stride_h);
+    opArgs.emplace_back(stride_w);
+    opArgs.emplace_back(dilation_h);
+    opArgs.emplace_back(dilation_w);
+    opArgs.emplace_back(pad_h);
+    opArgs.emplace_back(pad_w);
+    opArgs.emplace_back(__pack0);
+
+    kernel(opArgs);
+
+    if(handle.IsProfilingEnabled())
+        elapsed += handle.GetKernelTime();
+    return elapsed;
+}
+
+float CallImplGemmDynamicBackwardData(const miopen::Handle& handle,
+                                      const ProblemDescription& conv_problem,
+                                      ConstData_t src,
+                                      Data_t dst,
+                                      ConstData_t wei,
+                                      const std::vector<KernelInvoke>& kernels)
+{
+    float elapsed = 0.0f;
+
+    auto kernel = kernels[0];
+    MIOPEN_LOG_I(kernel.GetName());
+
+    // clang-format off
+    int hi         = conv_problem.GetOutHeight();
+    int wi         = conv_problem.GetOutWidth();
+    int n          = conv_problem.GetInBatchSize();
+    int k          = conv_problem.GetInChannels();
+    int c          = conv_problem.GetOutChannels();
+    int ho         = conv_problem.GetInHeight();
+    int wo         = conv_problem.GetInWidth();
+    int stride_h   = conv_problem.GetInHeight() > 1 ? conv_problem.GetKernelStrideH() : 1;
+    int stride_w   = conv_problem.GetInWidth() > 1 ? conv_problem.GetKernelStrideW() : 1;
+    int dilation_h = conv_problem.GetWeightsHeight() > 1 ? conv_problem.GetDilationH() : 1;
+    int dilation_w = conv_problem.GetWeightsWidth() > 1 ? conv_problem.GetDilationW() : 1;
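// ---------------------------------------------------------------------------
// In the backward-data wrapper the problem getters are read crosswise: the
// kernel's hi/wi/k come from the Out*/In-channel getters and its ho/wo/c from
// the opposite side, because the kernel consumes the output-gradient as its
// GEMM input and writes the input-gradient as its output. The clamping above
// exploits that a step over an extent of 1 is a no-op; a minimal sketch of
// that rule (illustrative helper, not MIOpen API):

static inline int EffectiveStep(int extent, int step)
{
    // With a single position along a dimension, any stride or dilation
    // behaves like 1, which keeps the y_tilda/x_tilda tiling below
    // non-degenerate.
    return extent > 1 ? step : 1;
}
// ---------------------------------------------------------------------------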
+    int pad_h      = conv_problem.GetPadH();
+    int pad_w      = conv_problem.GetPadW();
+    int y          = conv_problem.GetWeightsHeight();
+    int x          = conv_problem.GetWeightsWidth();
+
+    int gcd_stride_dilation_h = gcd(stride_h, dilation_h);
+    int gcd_stride_dilation_w = gcd(stride_w, dilation_w);
+    int y_tilda = stride_h / gcd_stride_dilation_h;
+    int x_tilda = stride_w / gcd_stride_dilation_w;
+
+    int y_dot = (y + y_tilda - 1) / y_tilda;
+    int x_dot = (x + x_tilda - 1) / x_tilda;
+
+    int h_tilda = ho + (dilation_h * (y - 1) + stride_h - 1) / stride_h;
+    int w_tilda = wo + (dilation_w * (x - 1) + stride_w - 1) / stride_w;
+
+    int h_tilda_left = std::max(0, pad_h - dilation_h * (y_tilda - 1)) / stride_h;
+    int w_tilda_left = std::max(0, pad_w - dilation_w * (x_tilda - 1)) / stride_w;
+
+    int h_tilda_right = std::min(h_tilda, (pad_h + hi - 1 + stride_h - 1) / stride_h + 1);
+    int w_tilda_right = std::min(w_tilda, (pad_w + wi - 1 + stride_w - 1) / stride_w + 1);
+
+    int h_tilda_slice = h_tilda_right - h_tilda_left;
+    int w_tilda_slice = w_tilda_right - w_tilda_left;
+
+    int num_of_gemms = x_tilda * y_tilda;
+
+    int dtile_iy = 0;
+    int dtile_ix = 0;
+    int dtile_dy = dilation_h / gcd_stride_dilation_h;
+    int dtile_dx = dilation_w / gcd_stride_dilation_w;
+    int dtile_y  = y_tilda;
+    int dtile_x  = x_tilda;
+    int dtile_h  = h_tilda;
+    int dtile_w  = w_tilda;
+    int dslice_y = 0;
+    int dslice_x = 0;
+    int dslice_h = h_tilda_slice;
+    int dslice_w = w_tilda_slice;
+    int dslice_h_left = h_tilda_left;
+    int dslice_w_left = w_tilda_left;
+    int __pack0  = 0;
+    // clang-format on
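// ---------------------------------------------------------------------------
// Worked example of the tiling above (illustrative numbers, not from the
// code): stride_h = 2, dilation_h = 1, y = 3 gives
//   gcd_stride_dilation_h = gcd(2, 1) = 1
//   y_tilda = 2 / 1 = 2          // sub-GEMM groups along filter height
//   y_dot   = ceil(3 / 2) = 2    // max filter taps handled per group
// Each group owns the taps congruent to its _dtile_iy modulo y_tilda, so
// _dtile_iy = 0 handles taps {0, 2} (_y_dot_slice = 2) and _dtile_iy = 1
// handles tap {1} (_y_dot_slice = 3 % 2 = 1) -- see the loop below. Any
// (stride, dilation) pair is thereby reduced to y_tilda * x_tilda dense
// GEMMs with no filter tap visited twice.
// ---------------------------------------------------------------------------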
+    std::vector<OpKernelArg> opArgs;
+    opArgs.emplace_back(dst);
+    opArgs.emplace_back(wei);
+    opArgs.emplace_back(src);
+    opArgs.emplace_back(hi);
+    opArgs.emplace_back(wi);
+    opArgs.emplace_back(n);
+    opArgs.emplace_back(k);
+    opArgs.emplace_back(c);
+    opArgs.emplace_back(ho);
+    opArgs.emplace_back(wo);
+    opArgs.emplace_back(stride_h);
+    opArgs.emplace_back(stride_w);
+    opArgs.emplace_back(dilation_h);
+    opArgs.emplace_back(dilation_w);
+    opArgs.emplace_back(pad_h);
+    opArgs.emplace_back(pad_w);
+    opArgs.emplace_back(y);
+    opArgs.emplace_back(x);
+    opArgs.emplace_back(dtile_iy);
+    opArgs.emplace_back(dtile_ix);
+    opArgs.emplace_back(dtile_dy);
+    opArgs.emplace_back(dtile_dx);
+    opArgs.emplace_back(dtile_y);
+    opArgs.emplace_back(dtile_x);
+    opArgs.emplace_back(dtile_h);
+    opArgs.emplace_back(dtile_w);
+    opArgs.emplace_back(dslice_y);
+    opArgs.emplace_back(dslice_x);
+    opArgs.emplace_back(dslice_h);
+    opArgs.emplace_back(dslice_w);
+    opArgs.emplace_back(dslice_h_left);
+    opArgs.emplace_back(dslice_w_left);
+    opArgs.emplace_back(__pack0);
+
+    for(int gemm_id = 0; gemm_id < num_of_gemms; gemm_id++)
     {
-        MIOPEN_THROW(
-            "Error running dynamic implicit GEMM convolution (currently only support forward)");
+        int _dtile_iy    = gemm_id / x_tilda;
+        int _dtile_ix    = gemm_id % x_tilda;
+        int _y_dot_slice = (_dtile_iy + 1) * y_dot <= y ? y_dot : y % y_dot;
+        int _x_dot_slice = (_dtile_ix + 1) * x_dot <= x ? x_dot : x % x_dot;
+        int _gemm_k      = k * _y_dot_slice * _x_dot_slice;
+        bool is_gemm_not_empty = _gemm_k > 0;
+        opArgs[18] = OpKernelArg(_dtile_iy);
+        opArgs[19] = OpKernelArg(_dtile_ix);
+        opArgs[26] = OpKernelArg(_y_dot_slice);
+        opArgs[27] = OpKernelArg(_x_dot_slice);
+        if(is_gemm_not_empty)
+            kernel(opArgs);
     }
+
+    if(handle.IsProfilingEnabled())
+        elapsed += handle.GetKernelTime();
+    return elapsed;
+}
+
+InvokerFactory MakeImplGemmDynamicForwardInvokerFactory(const ConvolutionContext& ctx)
+{
+    const auto& conv_problem = ctx.conv_problem;
+    return [conv_problem](const std::vector<Kernel>& kernels) {
+        return [=](const Handle& handle, const boost::any& primitive_parameters) {
+            const auto data_ctx = boost::any_cast<conv::DataInvokeParams>(primitive_parameters);
+            const auto& tensors = data_ctx.tensors;
+            auto kernel         = handle.Run(kernels[0]);
+
+            std::vector<KernelInvoke> ks;
+            std::transform(kernels.begin(),
+                           kernels.end(),
+                           std::back_inserter(ks),
+                           [&](const Kernel& k) { return handle.Run(k); });
+            float elapsed = 0;
+            elapsed       = CallImplGemmDynamicForward(
+                handle, conv_problem, tensors.in, tensors.out, tensors.w, ks);
+            if(handle.IsProfilingEnabled())
+            {
+                handle.ResetKernelTime();
+                handle.AccumKernelTime(elapsed);
+            }
+        };
+    };
+}
+
+InvokerFactory MakeImplGemmDynamicForward1x1InvokerFactory(const ConvolutionContext& ctx)
+{
+    const auto& conv_problem = ctx.conv_problem;
+    return [conv_problem](const std::vector<Kernel>& kernels) {
+        return [=](const Handle& handle, const boost::any& primitive_parameters) {
+            const auto data_ctx = boost::any_cast<conv::DataInvokeParams>(primitive_parameters);
+            const auto& tensors = data_ctx.tensors;
+            auto kernel         = handle.Run(kernels[0]);
+
+            std::vector<KernelInvoke> ks;
+            std::transform(kernels.begin(),
+                           kernels.end(),
+                           std::back_inserter(ks),
+                           [&](const Kernel& k) { return handle.Run(k); });
+            float elapsed = 0;
+            elapsed       = CallImplGemmDynamicForward1x1(
+                handle, conv_problem, tensors.in, tensors.out, tensors.w, ks);
+            if(handle.IsProfilingEnabled())
+            {
+                handle.ResetKernelTime();
+                handle.AccumKernelTime(elapsed);
+            }
+        };
+    };
+}
+
+InvokerFactory MakeImplGemmDynamicBackwardDataInvokerFactory(const ConvolutionContext& ctx)
+{
+    const auto& conv_problem = ctx.conv_problem;
+    return [conv_problem](const std::vector<Kernel>& kernels) {
+        return [=](const Handle& handle, const boost::any& primitive_parameters) {
+            const auto data_ctx = boost::any_cast<conv::DataInvokeParams>(primitive_parameters);
+            const auto& tensors = data_ctx.tensors;
+            auto kernel         = handle.Run(kernels[0]);
+
+            std::vector<KernelInvoke> ks;
+            std::transform(kernels.begin(),
+                           kernels.end(),
+                           std::back_inserter(ks),
+                           [&](const Kernel& k) { return handle.Run(k); });
+            float elapsed = 0;
+
+            elapsed = CallImplGemmDynamicBackwardData(
+                handle, conv_problem, tensors.in, tensors.out, tensors.w, ks);
+
+            if(handle.IsProfilingEnabled())
+            {
+                handle.ResetKernelTime();
+                handle.AccumKernelTime(elapsed);
+            }
+        };
+    };
 }
 } // namespace conv
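// ---------------------------------------------------------------------------
// The three factories share one shape: an InvokerFactory takes the compiled
// kernels and returns the callable the framework caches per problem. A sketch
// of how a solver is expected to wire one in (hypothetical call site; the
// local names are illustrative, not verbatim MIOpen code):
//
//   ConvSolution sol;
//   sol.construction_params.push_back(kernel_info);
//   sol.invoker_factory = conv::MakeImplGemmDynamicBackwardDataInvokerFactory(ctx);
//   // the framework then does roughly:
//   //   auto invoker = (*sol.invoker_factory)(built_kernels);
//   //   invoker(handle, boost::any{conv::DataInvokeParams{...}});
//
// Only conv_problem is captured (by value), not the whole ctx, so the returned
// lambdas stay valid after the ConvolutionContext goes out of scope.
// ---------------------------------------------------------------------------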
diff --git a/src/include/miopen/conv/invokers/impl_gemm_dynamic.hpp b/src/include/miopen/conv/invokers/impl_gemm_dynamic.hpp
index 28937c666e..619be17a59 100644
--- a/src/include/miopen/conv/invokers/impl_gemm_dynamic.hpp
+++ b/src/include/miopen/conv/invokers/impl_gemm_dynamic.hpp
@@ -35,17 +35,28 @@ namespace miopen {

 namespace conv {

-// Beside used in invoker, currently this function is only called in RunAndMeasure() of dynamic
-// igemm solver
-// Remove this in the future when invoker is fully re-factored.
-float CallImplicitGemmDynamic(const miopen::Handle& handle,
-                              const ConvolutionContext& ctx,
-                              ConstData_t src,
-                              Data_t dst,
-                              ConstData_t wei,
-                              const std::vector<KernelInvoke>& kernels);
+float CallImplGemmDynamicForward(const miopen::Handle& handle,
+                                 const ProblemDescription& conv_problem,
+                                 ConstData_t src,
+                                 Data_t dst,
+                                 ConstData_t wei,
+                                 const std::vector<KernelInvoke>& kernels);
+float CallImplGemmDynamicForward1x1(const miopen::Handle& handle,
+                                    const ProblemDescription& conv_problem,
+                                    ConstData_t src,
+                                    Data_t dst,
+                                    ConstData_t wei,
+                                    const std::vector<KernelInvoke>& kernels);
+float CallImplGemmDynamicBackwardData(const miopen::Handle& handle,
+                                      const ProblemDescription& conv_problem,
+                                      ConstData_t src,
+                                      Data_t dst,
+                                      ConstData_t wei,
+                                      const std::vector<KernelInvoke>& kernels);

-InvokerFactory MakeImplGemmDynamicDataInvokerFactory(const ConvolutionContext& ctx);
+InvokerFactory MakeImplGemmDynamicForwardInvokerFactory(const ConvolutionContext& ctx);
+InvokerFactory MakeImplGemmDynamicForward1x1InvokerFactory(const ConvolutionContext& ctx);
+InvokerFactory MakeImplGemmDynamicBackwardDataInvokerFactory(const ConvolutionContext& ctx);

 } // namespace conv
 } // namespace miopen
diff --git a/src/include/miopen/numeric.hpp b/src/include/miopen/numeric.hpp
new file mode 100644
index 0000000000..d64119cdea
--- /dev/null
+++ b/src/include/miopen/numeric.hpp
@@ -0,0 +1,64 @@
+/*******************************************************************************
+ *
+ * MIT License
+ *
+ * Copyright (c) 2020 Advanced Micro Devices, Inc.
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ *
+ *******************************************************************************/
+#ifndef GUARD_MLOPEN_NUMERIC_HPP
+#define GUARD_MLOPEN_NUMERIC_HPP
+
+#include <cassert>
+
+namespace miopen {
+
+template <typename T>
+T gcd(T x, T y)
+{
+    assert(!(x == 0 && y == 0));
+
+    if(x == y || x == 0)
+    {
+        return y;
+    }
+    else if(y == 0)
+    {
+        return x;
+    }
+    else if(x > y)
+    {
+        return gcd(x - y, y);
+    }
+    else
+    {
+        return gcd(x, y - x);
+    }
+}
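// ---------------------------------------------------------------------------
// The subtraction form above terminates quickly for the tiny operands it is
// used on here (convolution strides and dilations). An equivalent modulo-based
// Euclid, shown only for comparison (hypothetical name, not part of this
// header):

template <typename T>
T gcd_mod(T x, T y)
{
    // gcd_mod(x, 0) == x, matching gcd(x, 0) above; takes far fewer
    // recursive steps than repeated subtraction when operands are large.
    return y == 0 ? x : gcd_mod(y, static_cast<T>(x % y));
}
// ---------------------------------------------------------------------------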
+template <typename T, typename... Ys>
+T gcd(T x, Ys... ys)
+{
+    return gcd(x, gcd(ys...));
+}
+
+} // namespace miopen
+
+#endif
diff --git a/src/include/miopen/solver.hpp b/src/include/miopen/solver.hpp
index 0a04249fdd..422ddf1a4f 100644
--- a/src/include/miopen/solver.hpp
+++ b/src/include/miopen/solver.hpp
@@ -1409,6 +1409,12 @@ struct ConvAsmImplicitGemmV4R1DynamicFwd_1x1 : SolverBase<ConvolutionContext>
                        float& elapsed_time) const;
 };

+struct ConvAsmImplicitGemmV4R1DynamicBwd : SolverBase<ConvolutionContext>
+{
+    bool IsApplicable(const ConvolutionContext&) const;
+    ConvSolution GetSolution(const ConvolutionContext&) const;
+};
+
 /// Holds common member functions for the Solvers which share the same
 /// "legacy exhaustive search" machinery.
 struct ConvOclDirectFwdLegacyExhaustiveSearch : SolverBase<ConvolutionContext>
diff --git a/src/kernels/dynamic_igemm/igemm_bwd_gtc_dynamic.s b/src/kernels/dynamic_igemm/igemm_bwd_gtc_dynamic.s
new file mode 100644
index 0000000000..f573d97baa
--- /dev/null
+++ b/src/kernels/dynamic_igemm/igemm_bwd_gtc_dynamic.s
@@ -0,0 +1,4746 @@
+/*******************************************************************************
+ *
+ * MIT License
+ *
+ * Copyright (c) 2020 Advanced Micro Devices, Inc.
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ * + *******************************************************************************/ +; generated by igemm_codegen.py +; +.macro .v_u32_div v_q, v_n, v_d, v_tmp4, s_tmp4 + v_cvt_f32_u32 v[\v_tmp4+0], v[\v_d] + v_rcp_f32 v[\v_tmp4+0], v[\v_tmp4+0] + v_mul_f32 v[\v_tmp4+0], 0x4f800000, v[\v_tmp4+0] + v_cvt_u32_f32 v[\v_tmp4+0], v[\v_tmp4+0] + v_mul_lo_u32 v[\v_tmp4+1], v[\v_d], v[\v_tmp4+0] + v_mul_hi_u32 v[\v_tmp4+2], v[\v_d], v[\v_tmp4+0] + v_sub_co_u32 v[\v_tmp4+3], vcc, 0, v[\v_tmp4+1] + v_cmp_ne_i32 s[\s_tmp4:\s_tmp4+1], 0, v[\v_tmp4+2] + v_cndmask_b32 v[\v_tmp4+1], v[\v_tmp4+3], v[\v_tmp4+1], s[\s_tmp4:\s_tmp4+1] + v_mul_hi_u32 v[\v_tmp4+1], v[\v_tmp4+1], v[\v_tmp4+0] + v_sub_co_u32 v[\v_tmp4+2], vcc, v[\v_tmp4+0], v[\v_tmp4+1] + v_add_co_u32 v[\v_tmp4+0], vcc, v[\v_tmp4+0], v[\v_tmp4+1] + v_cndmask_b32 v[\v_tmp4+0], v[\v_tmp4+0], v[\v_tmp4+2], s[\s_tmp4:\s_tmp4+1] + v_mul_hi_u32 v[\v_tmp4+0], v[\v_tmp4+0], v[\v_n] + v_mul_lo_u32 v[\v_tmp4+1], v[\v_tmp4+0], v[\v_d] + v_sub_co_u32 v[\v_tmp4+2], vcc, v[\v_n], v[\v_tmp4+1] + v_cmp_ge_u32 s[\s_tmp4:\s_tmp4+1], v[\v_n], v[\v_tmp4+1] + v_cmp_ge_u32 s[\s_tmp4+2:\s_tmp4+3], v[\v_tmp4+2], v[\v_d] + v_add_co_u32 v[\v_tmp4+2], vcc, 1, v[\v_tmp4+0] + s_and_b64 s[\s_tmp4+2:\s_tmp4+3], s[\s_tmp4:\s_tmp4+1], s[\s_tmp4+2:\s_tmp4+3] + v_add_co_u32 v[\v_tmp4+1], vcc, -1, v[\v_tmp4+0] + v_cndmask_b32 v[\v_tmp4+2], v[\v_tmp4+0], v[\v_tmp4+2], s[\s_tmp4+2:\s_tmp4+3] + v_cndmask_b32 v[\v_tmp4+2], v[\v_tmp4+1], v[\v_tmp4+2], s[\s_tmp4:\s_tmp4+1] + v_cmp_ne_i32 vcc, 0, v[\v_d] + v_cndmask_b32 v[\v_q], -1, v[\v_tmp4+2], vcc +.endm + +.macro .v_u32_div_vs v_q, v_n, s_d, v_tmp4, s_tmp4 + v_cvt_f32_u32 v[\v_tmp4+0], s[\s_d] + v_rcp_f32 v[\v_tmp4+0], v[\v_tmp4+0] + v_mul_f32 v[\v_tmp4+0], 0x4f800000, v[\v_tmp4+0] + v_cvt_u32_f32 v[\v_tmp4+0], v[\v_tmp4+0] + v_mul_lo_u32 v[\v_tmp4+1], s[\s_d], v[\v_tmp4+0] + v_mul_hi_u32 v[\v_tmp4+2], s[\s_d], v[\v_tmp4+0] + v_sub_co_u32 v[\v_tmp4+3], vcc, 0, v[\v_tmp4+1] + v_cmp_ne_i32 s[\s_tmp4:\s_tmp4+1], 0, v[\v_tmp4+2] + v_cndmask_b32 v[\v_tmp4+1], v[\v_tmp4+3], v[\v_tmp4+1], s[\s_tmp4:\s_tmp4+1] + v_mul_hi_u32 v[\v_tmp4+1], v[\v_tmp4+1], v[\v_tmp4+0] + v_sub_co_u32 v[\v_tmp4+2], vcc, v[\v_tmp4+0], v[\v_tmp4+1] + v_add_co_u32 v[\v_tmp4+0], vcc, v[\v_tmp4+0], v[\v_tmp4+1] + v_cndmask_b32 v[\v_tmp4+0], v[\v_tmp4+0], v[\v_tmp4+2], s[\s_tmp4:\s_tmp4+1] + v_mul_hi_u32 v[\v_tmp4+0], v[\v_tmp4+0], v[\v_n] + v_mul_lo_u32 v[\v_tmp4+1], s[\s_d], v[\v_tmp4+0] + v_sub_co_u32 v[\v_tmp4+2], vcc, v[\v_n], v[\v_tmp4+1] + v_cmp_ge_u32 s[\s_tmp4:\s_tmp4+1], v[\v_n], v[\v_tmp4+1] + v_cmp_le_u32 s[\s_tmp4+2:\s_tmp4+3], s[\s_d], v[\v_tmp4+2] + v_add_co_u32 v[\v_tmp4+2], vcc, 1, v[\v_tmp4+0] + s_and_b64 s[\s_tmp4+2:\s_tmp4+3], s[\s_tmp4:\s_tmp4+1], s[\s_tmp4+2:\s_tmp4+3] + v_add_co_u32 v[\v_tmp4+1], vcc, -1, v[\v_tmp4+0] + v_cndmask_b32 v[\v_tmp4+2], v[\v_tmp4+0], v[\v_tmp4+2], s[\s_tmp4+2:\s_tmp4+3] + v_cndmask_b32 v[\v_tmp4+2], v[\v_tmp4+1], v[\v_tmp4+2], s[\s_tmp4:\s_tmp4+1] + v_cmp_ne_i32 vcc, s[\s_d], 0 + v_cndmask_b32 v[\v_q], -1, v[\v_tmp4+2], vcc +.endm + +.macro .v_u32_div_ss v_q, s_n, s_d, v_tmp4, s_tmp4 + v_cvt_f32_u32 v[\v_tmp4+0], s[\s_d] + v_rcp_f32 v[\v_tmp4+0], v[\v_tmp4+0] + v_mul_f32 v[\v_tmp4+0], 0x4f800000, v[\v_tmp4+0] + v_cvt_u32_f32 v[\v_tmp4+0], v[\v_tmp4+0] + v_mul_lo_u32 v[\v_tmp4+1], s[\s_d], v[\v_tmp4+0] + v_mul_hi_u32 v[\v_tmp4+2], s[\s_d], v[\v_tmp4+0] + v_sub_co_u32 v[\v_tmp4+3], vcc, 0, v[\v_tmp4+1] + v_cmp_ne_i32 s[\s_tmp4:\s_tmp4+1], 0, v[\v_tmp4+2] + v_cndmask_b32 v[\v_tmp4+1], v[\v_tmp4+3], v[\v_tmp4+1], 
s[\s_tmp4:\s_tmp4+1] + v_mul_hi_u32 v[\v_tmp4+1], v[\v_tmp4+1], v[\v_tmp4+0] + v_sub_co_u32 v[\v_tmp4+2], vcc, v[\v_tmp4+0], v[\v_tmp4+1] + v_add_co_u32 v[\v_tmp4+0], vcc, v[\v_tmp4+0], v[\v_tmp4+1] + v_cndmask_b32 v[\v_tmp4+0], v[\v_tmp4+0], v[\v_tmp4+2], s[\s_tmp4:\s_tmp4+1] + v_mul_hi_u32 v[\v_tmp4+0], s[\s_n], v[\v_tmp4+0] + v_mul_lo_u32 v[\v_tmp4+1], s[\s_d], v[\v_tmp4+0] + v_sub_co_u32 v[\v_tmp4+2], vcc, s[\s_n], v[\v_tmp4+1] + v_cmp_ge_u32 s[\s_tmp4:\s_tmp4+1], s[\s_n], v[\v_tmp4+1] + v_cmp_le_u32 s[\s_tmp4+2:\s_tmp4+3], s[\s_d], v[\v_tmp4+2] + v_add_co_u32 v[\v_tmp4+2], vcc, 1, v[\v_tmp4+0] + s_and_b64 s[\s_tmp4+2:\s_tmp4+3], s[\s_tmp4:\s_tmp4+1], s[\s_tmp4+2:\s_tmp4+3] + v_add_co_u32 v[\v_tmp4+1], vcc, -1, v[\v_tmp4+0] + v_cndmask_b32 v[\v_tmp4+2], v[\v_tmp4+0], v[\v_tmp4+2], s[\s_tmp4+2:\s_tmp4+3] + v_cndmask_b32 v[\v_tmp4+2], v[\v_tmp4+1], v[\v_tmp4+2], s[\s_tmp4:\s_tmp4+1] + v_cmp_ne_i32 vcc, s[\s_d], 0 + v_cndmask_b32 v[\v_q], -1, v[\v_tmp4+2], vcc +.endm + +.macro .v_clear_nc vid, num + _v = \vid + .rept \num + v_mov_b32 v[_v], 0 + _v = _v + 1 + .endr +.endm + +.macro .v_fma_4x4_s8 c, a, b + v_mac_f32 v[\c], v[\a], v[\b] + v_mac_f32 v[\c+1], v[\a], v[\b+1] + v_mac_f32 v[\c+2], v[\a], v[\b+2] + v_mac_f32 v[\c+3], v[\a], v[\b+3] + v_mac_f32 v[\c+8], v[\a+1], v[\b] + v_mac_f32 v[\c+9], v[\a+1], v[\b+1] + v_mac_f32 v[\c+10], v[\a+1], v[\b+2] + v_mac_f32 v[\c+11], v[\a+1], v[\b+3] + v_mac_f32 v[\c+16], v[\a+2], v[\b] + v_mac_f32 v[\c+17], v[\a+2], v[\b+1] + v_mac_f32 v[\c+18], v[\a+2], v[\b+2] + v_mac_f32 v[\c+19], v[\a+2], v[\b+3] + v_mac_f32 v[\c+24], v[\a+3], v[\b] + v_mac_f32 v[\c+25], v[\a+3], v[\b+1] + v_mac_f32 v[\c+26], v[\a+3], v[\b+2] + v_mac_f32 v[\c+27], v[\a+3], v[\b+3] +.endm + +.macro .v_fma_2x4_s8 c, a, b + v_mac_f32 v[\c], v[\a], v[\b] + v_mac_f32 v[\c+1], v[\a], v[\b+1] + v_mac_f32 v[\c+2], v[\a], v[\b+2] + v_mac_f32 v[\c+3], v[\a], v[\b+3] + v_mac_f32 v[\c+8], v[\a+1], v[\b] + v_mac_f32 v[\c+9], v[\a+1], v[\b+1] + v_mac_f32 v[\c+10], v[\a+1], v[\b+2] + v_mac_f32 v[\c+11], v[\a+1], v[\b+3] +.endm + +.macro .v_fma_4x2_s4 c, a, b + v_mac_f32 v[\c], v[\a], v[\b] + v_mac_f32 v[\c+1], v[\a], v[\b+1] + v_mac_f32 v[\c+4], v[\a+1], v[\b] + v_mac_f32 v[\c+5], v[\a+1], v[\b+1] + v_mac_f32 v[\c+8], v[\a+2], v[\b] + v_mac_f32 v[\c+9], v[\a+2], v[\b+1] + v_mac_f32 v[\c+12], v[\a+3], v[\b] + v_mac_f32 v[\c+13], v[\a+3], v[\b+1] +.endm + +.macro .v_fma_2x2_s4 c, a, b + v_mac_f32 v[\c], v[\a], v[\b] + v_mac_f32 v[\c+1], v[\a], v[\b+1] + v_mac_f32 v[\c+4], v[\a+1], v[\b] + v_mac_f32 v[\c+5], v[\a+1], v[\b+1] +.endm + +; update v_out_flag for output +.macro .v_out_set_flag v_out_flag, v_out_iho, v_out_iwo, s_ho, s_wo, s_tmp2 + ; flag: 0<= * n*dslice_h*dslice_w, dslice_h,dslice_y -> hip, dslice_w,dslicw_x -> wip +.macro .v_in_transform_gemm_n v_in_in, v_in_ihi, v_in_iwi, v_in_gemm_in, v_in_dslice_h, v_in_dslice_w, s_stride_dslice_hw, s_dslice_w, s_dtile_iy, s_dtile_ix, s_dslice_h_left, s_dslice_w_left, s_dilation_h, s_dilation_w, s_stride_h, s_stride_w, s_pad_h, s_pad_w, v_tmp4, s_tmp4 + ; n -> n*dslice_h*dslice_w + .v_u32_div_vs \v_in_in, \v_in_gemm_in, \s_stride_dslice_hw, \v_tmp4, \s_tmp4 + v_mul_lo_u32 v[\v_tmp4+1], s[\s_stride_dslice_hw], v[\v_in_in] + v_sub_u32 v[\v_in_gemm_in], v[\v_in_gemm_in], v[\v_tmp4+1] + .v_u32_div_vs \v_in_dslice_h, \v_in_gemm_in, \s_dslice_w, \v_tmp4, \s_tmp4 + v_mul_lo_u32 v[\v_tmp4+1], s[\s_dslice_w], v[\v_in_dslice_h] + v_sub_u32 v[\v_in_dslice_w], v[\v_in_gemm_in], v[\v_tmp4+1] + + ; + v_add_u32 v[\v_in_dslice_h], s[\s_dslice_h_left], 
v[\v_in_dslice_h] + v_add_u32 v[\v_in_dslice_w], s[\s_dslice_w_left], v[\v_in_dslice_w] + + ; dslice_h,dslice_y -> hip, dslice_w,dslicw_x -> wip + s_mul_i32 s[\s_tmp4], s[\s_dtile_iy], s[\s_dilation_h] + v_mul_lo_u32 v[\v_tmp4], s[\s_stride_h], v[\v_in_dslice_h] + v_add_u32 v[\v_tmp4], s[\s_tmp4], v[\v_tmp4] + s_mul_i32 s[\s_tmp4+1], s[\s_dtile_ix], s[\s_dilation_w] + v_mul_lo_u32 v[\v_tmp4+1], s[\s_stride_w], v[\v_in_dslice_w] + v_add_u32 v[\v_tmp4+1], s[\s_tmp4+1], v[\v_tmp4+1] + ; v_tmp4: hip, v_tmp4+1: wip + + ; hip->h, wip->w + v_sub_i32 v[\v_in_ihi], v[\v_tmp4], s[\s_pad_h] + v_sub_i32 v[\v_in_iwi], v[\v_tmp4+1], s[\s_pad_w] +.endm + +.macro .v_in_move_step_n1 v_in_in, v_in_ihi, v_in_iwi, v_in_dslice_h, v_in_dslice_w, v_dtile_iy_x_dilation_h, v_dtile_ix_x_dilation_w, s_in_stride_n, s_dslice_h_left, s_dslice_w_left, s_dslice_h_shifted, s_dslice_w_shifted, s_dtile_iy, s_dtile_ix, s_dilation_h, s_dilation_w, s_stride_h, s_stride_w, s_pad_h, s_pad_w, v_tmp2, s_tmp2 + ; n -> n*dslice_h*dslice_w + v_add_u32 v[\v_in_dslice_w], 1, v[\v_in_dslice_w] + + v_cmpx_le_u32 vcc, s[\s_dslice_w_shifted], v[\v_in_dslice_w] + v_mov_b32 v[\v_in_dslice_w], s[\s_dslice_w_left] + v_add_u32 v[\v_in_dslice_h], 1, v[\v_in_dslice_h] + s_mov_b64 exec, -1 + + v_cmpx_le_u32 vcc, s[\s_dslice_h_shifted], v[\v_in_dslice_h] + v_mov_b32 v[\v_in_dslice_h], s[\s_dslice_h_left] + v_add_u32 v[\v_in_in], s[\s_in_stride_n], v[\v_in_in] + s_mov_b64 exec, -1 + + ; dslice_h,dslice_y -> hip, dslice_w,dslicw_x -> wip + ;s_mul_i32 s[\s_tmp2], s[\s_dtile_iy], s[\s_dilation_h] + ;v_mul_lo_u32 v[\v_tmp2], s[\s_stride_h], v[\v_in_dslice_h] + ;v_add_u32 v[\v_tmp2], s[\s_tmp2], v[\v_tmp2] + ;s_mul_i32 s[\s_tmp2+1], s[\s_dtile_ix], s[\s_dilation_w] + ;v_mul_lo_u32 v[\v_tmp2+1], s[\s_stride_w], v[\v_in_dslice_w] + ;v_add_u32 v[\v_tmp2+1], s[\s_tmp2+1], v[\v_tmp2+1] + v_mad_u32_u24 v[\v_tmp2], s[\s_stride_h], v[\v_in_dslice_h], v[\v_dtile_iy_x_dilation_h] + v_mad_u32_u24 v[\v_tmp2+1], s[\s_stride_w], v[\v_in_dslice_w], v[\v_dtile_ix_x_dilation_w] + ; v_tmp2: hip, v_tmp2+1: wip + + ; hip->h, wip->w + v_sub_i32 v[\v_in_ihi], v[\v_tmp2], s[\s_pad_h] + v_sub_i32 v[\v_in_iwi], v[\v_tmp2+1], s[\s_pad_w] +.endm + +.macro .v_in_move_slice_window v_in_in, v_in_ihi, v_in_iwi, v_in_dslice_h, v_in_dslice_w, v_dtile_iy_x_dilation_h, v_dtile_ix_x_dilation_w, s_in_stride_n, s_move_slice_in_in, s_move_slice_in_dslice_h, s_move_slice_in_dslice_w, s_dslice_h_shifted, s_dslice_w_shifted, s_dslice_h, s_dslice_w, s_dtile_iy, s_dtile_ix, s_dilation_h, s_dilation_w, s_stride_h, s_stride_w, s_pad_h, s_pad_w, v_tmp2, s_tmp2 + ; n -> n*dslice_h*dslice_w + v_add_u32 v[\v_in_dslice_w], s[\s_move_slice_in_dslice_w], v[\v_in_dslice_w] + + v_cmpx_le_u32 vcc, s[\s_dslice_w_shifted], v[\v_in_dslice_w] + v_subrev_u32 v[\v_in_dslice_w], s[\s_dslice_w], v[\v_in_dslice_w] + v_add_u32 v[\v_in_dslice_h], 1, v[\v_in_dslice_h] + s_mov_b64 exec, -1 + + v_add_u32 v[\v_in_dslice_h], s[\s_move_slice_in_dslice_h], v[\v_in_dslice_h] + v_cmpx_le_u32 vcc, s[\s_dslice_h_shifted], v[\v_in_dslice_h] + v_subrev_u32 v[\v_in_dslice_h], s[\s_dslice_h], v[\v_in_dslice_h] + v_add_u32 v[\v_in_in], s[\s_in_stride_n], v[\v_in_in] + s_mov_b64 exec, -1 + + v_add_u32 v[\v_in_in], s[\s_move_slice_in_in], v[\v_in_in] + + ; dslice_h,dslice_y -> hip, dslice_w,dslicw_x -> wip + ; s_mul_i32 s[\s_tmp2], s[\s_dtile_iy], s[\s_dilation_h] + ; v_mul_lo_u32 v[\v_tmp2], s[\s_stride_h], v[\v_in_dslice_h] + ; v_add_u32 v[\v_tmp2], s[\s_tmp2], v[\v_tmp2] + ; s_mul_i32 s[\s_tmp2+1], s[\s_dtile_ix], s[\s_dilation_w] + ; 
v_mul_lo_u32 v[\v_tmp2+1], s[\s_stride_w], v[\v_in_dslice_w] + ; v_add_u32 v[\v_tmp2+1], s[\s_tmp2+1], v[\v_tmp2+1] + v_mad_u32_u24 v[\v_tmp2], s[\s_stride_h], v[\v_in_dslice_h], v[\v_dtile_iy_x_dilation_h] + v_mad_u32_u24 v[\v_tmp2+1], s[\s_stride_w], v[\v_in_dslice_w], v[\v_dtile_ix_x_dilation_w] + ; v_tmp2: hip, v_tmp2+1: wip + + ; hip->h, wip->w + v_sub_i32 v[\v_in_ihi], v[\v_tmp2], s[\s_pad_h] + v_sub_i32 v[\v_in_iwi], v[\v_tmp2+1], s[\s_pad_w] +.endm + +.macro .v_in_write_m0_m1_n0_n1_step v_src, s_p_in, v_in_os, v_in_flag, v_in_ihi_itr, v_in_iwi_itr, v_in_in_itr, v_in_dslice_h_itr, v_in_dslice_w_itr, v_in_ic_itr, v_in_in, v_in_ic, v_in_ihi, v_in_iwi, v_in_dslice_h, v_in_dslice_w, v_dtile_iy_x_dilation_h, v_dtile_ix_x_dilation_w, s_move_slice_in_in, s_move_slice_in_dslice_h, s_move_slice_in_dslice_w, s_dslice_h_left, s_dslice_w_left, s_dslice_h_shifted, s_dslice_w_shifted, s_dslice_h, s_dslice_w, s_dtile_iy, s_dtile_ix, s_dilation_h, s_dilation_w, s_stride_h, s_stride_w, s_pad_h, s_pad_w, s_in_stride_n, s_in_stride_c, s_in_stride_c_m0, s_in_stride_hi, s_hi, s_wi, v_tmp2, s_tmp2, k_m0, k_m1, k_n0, k_n1 + v_mov_b32 v[\v_in_in_itr], v[\v_in_in] + v_mov_b32 v[\v_in_dslice_h_itr], v[\v_in_dslice_h] + v_mov_b32 v[\v_in_dslice_w_itr], v[\v_in_dslice_w] + v_mov_b32 v[\v_in_ic_itr], v[\v_in_ic] + v_mov_b32 v[\v_in_ihi_itr], v[\v_in_ihi] + v_mov_b32 v[\v_in_iwi_itr], v[\v_in_iwi] + .itr_m0 = 0 + .rept \k_m0 + + .itr_m1 = 0 + .rept \k_m1 + + .itr_n0 = 0 + .rept \k_n0 + + .itr_n1 = 0 + .rept \k_n1 + v_cmpx_eq_u32 vcc, 1, v[\v_in_flag] + buffer_store_dword v[\v_src + (.itr_m0 * 32 + .itr_m1 * 8 + .itr_n0 * 4 + .itr_n1) ], v[\v_in_os], s[\s_p_in:\s_p_in+3], 0 offen + s_mov_b64 exec, -1 + + .if .itr_n1 != \k_n1 - 1 + .v_in_move_step_n1 \v_in_in_itr, \v_in_ihi_itr, \v_in_iwi_itr, \v_in_dslice_h_itr, \v_in_dslice_w_itr, \v_dtile_iy_x_dilation_h, \v_dtile_ix_x_dilation_w, \s_in_stride_n, \s_dslice_h_left, \s_dslice_w_left, \s_dslice_h_shifted, \s_dslice_w_shifted, \s_dtile_iy, \s_dtile_ix, \s_dilation_h, \s_dilation_w, \s_stride_h, \s_stride_w, \s_pad_h, \s_pad_w, \v_tmp2, \s_tmp2 + .v_in_calculate_os \v_in_os, \v_in_in_itr, \v_in_ic_itr, \v_in_ihi_itr, \v_in_iwi_itr, \s_in_stride_n, \s_in_stride_c, \s_in_stride_hi, \v_tmp2 + .v_in_set_flag \v_in_flag, \v_in_ihi_itr, \v_in_iwi_itr, \s_hi, \s_wi, \s_tmp2 + .endif + .itr_n1 = .itr_n1 + 1 + .endr ; n1 + .if .itr_n0 != \k_n0 - 1 + v_mov_b32 v[\v_in_in_itr], v[\v_in_in] + v_mov_b32 v[\v_in_dslice_h_itr], v[\v_in_dslice_h] + v_mov_b32 v[\v_in_dslice_w_itr], v[\v_in_dslice_w] + + .v_in_move_slice_window \v_in_in_itr, \v_in_ihi_itr, \v_in_iwi_itr, \v_in_dslice_h_itr, \v_in_dslice_w_itr, \v_dtile_iy_x_dilation_h, \v_dtile_ix_x_dilation_w, \s_in_stride_n, \s_move_slice_in_in, \s_move_slice_in_dslice_h, \s_move_slice_in_dslice_w, \s_dslice_h_shifted, \s_dslice_w_shifted, \s_dslice_h, \s_dslice_w, \s_dtile_iy, \s_dtile_ix, \s_dilation_h, \s_dilation_w, \s_stride_h, \s_stride_w, \s_pad_h, \s_pad_w, \v_tmp2, \s_tmp2 + .v_in_calculate_os \v_in_os, \v_in_in_itr, \v_in_ic_itr, \v_in_ihi_itr, \v_in_iwi_itr, \s_in_stride_n, \s_in_stride_c, \s_in_stride_hi, \v_tmp2 + .v_in_set_flag \v_in_flag, \v_in_ihi_itr, \v_in_iwi_itr, \s_hi, \s_wi, \s_tmp2 + .endif + .itr_n0 = .itr_n0 + 1 + .endr ; n0 + + .if .itr_m1 != \k_m1 - 1 + v_mov_b32 v[\v_in_in_itr], v[\v_in_in] + v_mov_b32 v[\v_in_dslice_h_itr], v[\v_in_dslice_h] + v_mov_b32 v[\v_in_dslice_w_itr], v[\v_in_dslice_w] + v_mov_b32 v[\v_in_ihi_itr], v[\v_in_ihi] + v_mov_b32 v[\v_in_iwi_itr], v[\v_in_iwi] + v_add_u32 
v[\v_in_ic_itr], s[\s_in_stride_c], v[\v_in_ic_itr] + .v_in_calculate_os \v_in_os, \v_in_in_itr, \v_in_ic_itr, \v_in_ihi_itr, \v_in_iwi_itr, \s_in_stride_n, \s_in_stride_c, \s_in_stride_hi, \v_tmp2 + .v_in_set_flag \v_in_flag, \v_in_ihi_itr, \v_in_iwi_itr, \s_hi, \s_wi, \s_tmp2 + .endif + .itr_m1 = .itr_m1 + 1 + .endr ; m1 + + .if .itr_m0 != \k_m0 - 1 + v_mov_b32 v[\v_in_in_itr], v[\v_in_in] + v_mov_b32 v[\v_in_dslice_h_itr], v[\v_in_dslice_h] + v_mov_b32 v[\v_in_dslice_w_itr], v[\v_in_dslice_w] + v_mov_b32 v[\v_in_ihi_itr], v[\v_in_ihi] + v_mov_b32 v[\v_in_iwi_itr], v[\v_in_iwi] + v_add_u32 v[\v_in_ic_itr], s[\s_in_stride_c_m0], v[\v_in_ic] + .v_in_calculate_os \v_in_os, \v_in_in_itr, \v_in_ic_itr, \v_in_ihi_itr, \v_in_iwi_itr, \s_in_stride_n, \s_in_stride_c, \s_in_stride_hi, \v_tmp2 + .v_in_set_flag \v_in_flag, \v_in_ihi_itr, \v_in_iwi_itr, \s_hi, \s_wi, \s_tmp2 + .endif + .itr_m0 = .itr_m0 + 1 + .endr ; m0 +.endm + + +; store input to LDS. {k, n}:{8, 1} +.macro .v_out_sst_k_n_8_1 v_src, v_sst_os + ds_write2st64_b32 v[\v_sst_os], v[\v_src+0], v[\v_src+1] offset0:0 offset1:2 + ds_write2st64_b32 v[\v_sst_os], v[\v_src+2], v[\v_src+3] offset0:4 offset1:6 + ds_write2st64_b32 v[\v_sst_os], v[\v_src+4], v[\v_src+5] offset0:8 offset1:10 + ds_write2st64_b32 v[\v_sst_os], v[\v_src+6], v[\v_src+7] offset0:12 offset1:14 +.endm + +; store weight to LDS. {k, n}:{8, 1} +.macro .v_wei_sst_k_n_8_1 v_src, v_sst_os + ds_write2st64_b32 v[\v_sst_os], v[\v_src+0], v[\v_src+1] offset0:0 offset1:2 + ds_write2st64_b32 v[\v_sst_os], v[\v_src+2], v[\v_src+3] offset0:4 offset1:6 + ds_write2st64_b32 v[\v_sst_os], v[\v_src+4], v[\v_src+5] offset0:8 offset1:10 + ds_write2st64_b32 v[\v_sst_os], v[\v_src+6], v[\v_src+7] offset0:12 offset1:14 +.endm + +.macro .s_bwd_gtc_move_slice_window_k_dsy_dsx s_move_slice_k_ik, s_move_slice_k_idsy, s_move_slice_k_idsx, s_gemm_k_num_dsy, s_gemm_k_num_dsx, k_step_k, k_step_dsy, k_step_dsx, v_out_os_base, v_wei_os_base, s_out_move_slice_stride_k, s_wei_move_slice_stride_k + ; order always x->y->k first + ; if k=k0*k1, always iterate k then compute back to k0, k1 + ; update offset if k need step, for both output, weight + s_add_u32 s[\s_move_slice_k_idsx], \k_step_dsx, s[\s_move_slice_k_idsx] + s_cmp_lt_u32 s[\s_move_slice_k_idsx], s[\s_gemm_k_num_dsx] + s_cbranch_scc1 s_bwd_gtc_move_slice_window_k_dsy_dsx_L0_\@ + s_mov_b32 s[\s_move_slice_k_idsx], 0 + s_add_u32 s[\s_move_slice_k_idsy], \k_step_dsy, s[\s_move_slice_k_idsy] + s_bwd_gtc_move_slice_window_k_dsy_dsx_L0_\@: + s_cmp_lt_u32 s[\s_move_slice_k_idsy], s[\s_gemm_k_num_dsy] + s_cbranch_scc1 s_bwd_gtc_move_slice_window_k_dsy_dsx_L1_\@ + v_add_u32 v[\v_out_os_base], s[\s_out_move_slice_stride_k], v[\v_out_os_base] + s_mov_b32 s[\s_move_slice_k_idsy], 0 + v_add_u32 v[\v_wei_os_base], s[\s_wei_move_slice_stride_k], v[\v_wei_os_base] + s_add_u32 s[\s_move_slice_k_ik], \k_step_k, s[\s_move_slice_k_ik] + s_bwd_gtc_move_slice_window_k_dsy_dsx_L1_\@: +.endm + +.macro .v_bwd_gtc_out_update_hw v_out_iho, v_out_iwo, v_out_dslice_ih, v_out_dslice_iw, s_out_dslice_iy, s_out_dslice_ix, s_dtile_dy, s_dtile_dx, s_tmp2 + ; dslice_y,dslice_h -> oh, dslice_x,dslice_w -> ow + s_mul_i32 s[\s_tmp2], s[\s_dtile_dy], s[\s_out_dslice_iy] + v_subrev_u32 v[\v_out_iho], s[\s_tmp2], v[\v_out_dslice_ih] + s_mul_i32 s[\s_tmp2+1], s[\s_dtile_dx], s[\s_out_dslice_ix] + v_subrev_u32 v[\v_out_iwo], s[\s_tmp2+1], v[\v_out_dslice_iw] +.endm + +.macro .v_bwd_gtc_out_update_os v_out_os, v_out_os_base, v_out_iho, v_out_iwo, s_wo, v_tmp + ; from ho, wo, os_base, compute 
final offset + v_mad_u32_u24 v[\v_tmp], s[\s_wo], v[\v_out_iho], v[\v_out_iwo] + v_lshl_add_u32 v[\v_out_os], v[\v_tmp], 2, v[\v_out_os_base] +.endm + +.macro .v_bwd_gtc_wei_update_os v_wei_os, v_wei_os_base, v_wei_iy, v_wei_ix, s_x, v_tmp + ; from y, x, os_base, compute final offset + v_mad_u32_u24 v[\v_tmp], v[\v_wei_iy], s[\s_x], v[\v_wei_ix] + v_lshl_add_u32 v[\v_wei_os], v[\v_tmp], 2, v[\v_wei_os_base] +.endm + +.macro .v_bwd_gtc_wei_update_yx v_wei_iy, v_wei_ix, s_move_slice_k_idsy, s_move_slice_k_idsx, s_dtile_y, s_dtile_x, v_dtile_iy, v_dtile_ix, s_tmp2 + s_mul_i32 s[\s_tmp2], s[\s_dtile_y], s[\s_move_slice_k_idsy] + v_add_u32 v[\v_wei_iy], s[\s_tmp2], v[\v_dtile_iy] + s_mul_i32 s[\s_tmp2+1], s[\s_dtile_x], s[\s_move_slice_k_idsx] + v_add_u32 v[\v_wei_ix], s[\s_tmp2+1], v[\v_dtile_ix] +.endm + +.macro .v_gld_1x2_b32_v1 v_dst, s_ptr, v_os, s_stride_d0, s_stride_d1, s_tmp2 + .v_clear_nc \v_dst, 2 + buffer_load_dword v[\v_dst+0], v[\v_os], s[\s_ptr:\s_ptr+3], 0 offen offset:0 + s_mov_b32 s[\s_tmp2], s[\s_stride_d1] + buffer_load_dword v[\v_dst+1], v[\v_os], s[\s_ptr:\s_ptr+3], s[\s_tmp2] offen offset:0 +.endm + +.macro .v_gld_1x4_b32_v1 v_dst, s_ptr, v_os, s_stride_d0, s_stride_d1, s_tmp2 + .v_clear_nc \v_dst, 4 + buffer_load_dword v[\v_dst+0], v[\v_os], s[\s_ptr:\s_ptr+3], 0 offen offset:0 + s_mov_b32 s[\s_tmp2], s[\s_stride_d1] + buffer_load_dword v[\v_dst+1], v[\v_os], s[\s_ptr:\s_ptr+3], s[\s_tmp2] offen offset:0 + s_add_u32 s[\s_tmp2], s[\s_tmp2], s[\s_stride_d1] + buffer_load_dword v[\v_dst+2], v[\v_os], s[\s_ptr:\s_ptr+3], s[\s_tmp2] offen offset:0 + s_add_u32 s[\s_tmp2], s[\s_tmp2], s[\s_stride_d1] + buffer_load_dword v[\v_dst+3], v[\v_os], s[\s_ptr:\s_ptr+3], s[\s_tmp2] offen offset:0 +.endm + +.macro .v_gld_2x4_b32_v1 v_dst, s_ptr, v_os, s_stride_d0, s_stride_d1, s_tmp2 + .v_clear_nc \v_dst, 8 + buffer_load_dword v[\v_dst+0], v[\v_os], s[\s_ptr:\s_ptr+3], 0 offen offset:0 + s_mov_b32 s[\s_tmp2], s[\s_stride_d1] + buffer_load_dword v[\v_dst+1], v[\v_os], s[\s_ptr:\s_ptr+3], s[\s_tmp2] offen offset:0 + s_add_u32 s[\s_tmp2], s[\s_tmp2], s[\s_stride_d1] + buffer_load_dword v[\v_dst+2], v[\v_os], s[\s_ptr:\s_ptr+3], s[\s_tmp2] offen offset:0 + s_add_u32 s[\s_tmp2], s[\s_tmp2], s[\s_stride_d1] + buffer_load_dword v[\v_dst+3], v[\v_os], s[\s_ptr:\s_ptr+3], s[\s_tmp2] offen offset:0 + s_mov_b32 s[\s_tmp2+1], s[\s_stride_d0] + buffer_load_dword v[\v_dst+4], v[\v_os], s[\s_ptr:\s_ptr+3], s[\s_tmp2+1] offen offset:0 + s_add_u32 s[\s_tmp2], s[\s_tmp2+1], s[\s_stride_d1] + buffer_load_dword v[\v_dst+5], v[\v_os], s[\s_ptr:\s_ptr+3], s[\s_tmp2] offen offset:0 + s_add_u32 s[\s_tmp2], s[\s_tmp2], s[\s_stride_d1] + buffer_load_dword v[\v_dst+6], v[\v_os], s[\s_ptr:\s_ptr+3], s[\s_tmp2] offen offset:0 + s_add_u32 s[\s_tmp2], s[\s_tmp2], s[\s_stride_d1] + buffer_load_dword v[\v_dst+7], v[\v_os], s[\s_ptr:\s_ptr+3], s[\s_tmp2] offen offset:0 +.endm + +.macro .v_set_flag_hw v_flag, v_ih, v_iw, s_h, s_w + v_cmp_gt_u32 vcc, s[\s_h], v[\v_ih] + v_cndmask_b32 v[\v_flag], 0, 1, vcc + v_cmp_gt_u32 vcc, s[\s_w], v[\v_iw] + v_cndmask_b32 v[\v_flag], 0, v[\v_flag], vcc +.endm + +.macro .v_sst_so0_1x2_b32_v2 v_src, v_sst_os + ds_write_b64 v[\v_sst_os], v[\v_src+0:\v_src+0+1] +.endm + +.macro .v_sst_so0_1x4_b32_v4 v_src, v_sst_os + ds_write_b128 v[\v_sst_os], v[\v_src+0:\v_src+0+3] +.endm + +.macro .v_sst_so0_2x4_b32_v4_st16 v_src, v_sst_os + ds_write_b128 v[\v_sst_os], v[\v_src+0:\v_src+0+3] + ds_write_b128 v[\v_sst_os], v[\v_src+4:\v_src+4+3] offset:16 +.endm + +.macro .v_sst_so0_2x4_b32_v4_st256 v_src, 
v_sst_os + ds_write_b128 v[\v_sst_os], v[\v_src+0:\v_src+0+3] + ds_write_b128 v[\v_sst_os], v[\v_src+4:\v_src+4+3] offset:256 +.endm + +.macro .v_sst_so0_2x4_b32_v4_st512 v_src, v_sst_os + ds_write_b128 v[\v_sst_os], v[\v_src+0:\v_src+0+3] + ds_write_b128 v[\v_sst_os], v[\v_src+4:\v_src+4+3] offset:512 +.endm + +; write 1d tensor to global with stride +.macro .v_write1d_strided v_src, s_p_buf_dst, v_dst_os, s_dst_diff, s_dst_os, t_dim_1d, vec_1d=1 + .itr_1d = 0 + .if \t_dim_1d % \vec_1d != 0 + .error "\t_dim_1d can not evenly divided by \vec_1d" + .end + .endif + .t_dim_1d_v = \t_dim_1d / \vec_1d + .rept .t_dim_1d_v + .if \vec_1d == 1 + buffer_store_dword v[\v_src+.itr_1d], v[\v_dst_os], s[\s_p_buf_dst:\s_p_buf_dst+3], s[\s_dst_os] offen + .elseif \vec_1d == 2 + buffer_store_dwordx2 v[\v_src+.itr_1d:\v_src+.itr_1d+1], v[\v_dst_os], s[\s_p_buf_dst:\s_p_buf_dst+3], s[\s_dst_os] offen + .elseif \vec_1d == 4 + buffer_store_dwordx4 v[\v_src+.itr_1d:\v_src+.itr_1d+3], v[\v_dst_os], s[\s_p_buf_dst:\s_p_buf_dst+3], s[\s_dst_os] offen + .endif + .if .itr_1d != .t_dim_1d_v - 1 + s_add_u32 s[\s_dst_os], s[\s_dst_os], s[\s_dst_diff] + .endif + .itr_1d = .itr_1d + \vec_1d + .endr +.endm + +; write 2d tensor to global with stride +.macro .v_write2d_strided v_src, s_p_dst, v_dst_os, s_dst_diff1d, s_dst_diff2d, s_dst_os_2, t_dim_1d, t_dim_2d, vec_1d=1 + .itr_2d = 0 + .rept \t_dim_2d + .v_write1d_strided (\v_src + .itr_2d * \t_dim_1d), \s_p_dst, \v_dst_os, \s_dst_diff1d, \s_dst_os_2, \t_dim_1d + .if .itr_2d != \t_dim_2d - 1 + s_add_u32 s[\s_dst_os_2+1], s[\s_dst_os_2+1], s[\s_dst_diff2d] + s_mov_b32 s[\s_dst_os_2], s[\s_dst_os_2+1] + .endif + .itr_2d = .itr_2d + 1 + .endr +.endm + +; write 3d tensor to global with stride +.macro .v_write3d_strided v_src, s_p_dst, v_dst_os, s_dst_diff1d, s_dst_diff2d, s_dst_diff3d, s_dst_os_3, t_dim_1d, t_dim_2d, t_dim_3d, vec_1d=1 + .itr_3d = 0 + .rept \t_dim_3d + .v_write2d_strided (\v_src+ .itr_3d * \t_dim_1d * \t_dim_2d), \s_p_dst, \v_dst_os, \s_dst_diff1d, \s_dst_diff2d, \s_dst_os_3, \t_dim_1d, \t_dim_2d + .if .itr_3d != \t_dim_3d - 1 + s_add_u32 s[\s_dst_os_3+2], s[\s_dst_os_3+2], s[\s_dst_diff3d] + s_mov_b32 s[\s_dst_os_3+1], s[\s_dst_os_3+2] + s_mov_b32 s[\s_dst_os_3], s[\s_dst_os_3+1] + .endif + .itr_3d = .itr_3d + 1 + .endr +.endm + +; write 4d tensor to global with stride +.macro .v_write4d_strided v_src, s_p_dst, v_dst_os, s_dst_diff1d, s_dst_diff2d, s_dst_diff3d, s_dst_diff4d, s_dst_os_4, t_dim_1d, t_dim_2d, t_dim_3d, t_dim_4d, vec_1d=1 + .itr_4d = 0 + .rept \t_dim_4d + .v_write3d_strided (\v_src+ .itr_4d * \t_dim_1d * \t_dim_2d * \t_dim_3d), \s_p_dst, \v_dst_os, \s_dst_diff1d, \s_dst_diff2d, \s_dst_diff3d, \s_dst_os_4, \t_dim_1d, \t_dim_2d, \t_dim_3d + .if .itr_4d != \t_dim_4d - 1 + s_add_u32 s[\s_dst_os_4+3], s[\s_dst_os_4+3], s[\s_dst_diff4d] + s_mov_b32 s[\s_dst_os_4+2], s[\s_dst_os_4+3] + s_mov_b32 s[\s_dst_os_4+1], s[\s_dst_os_4+2] + s_mov_b32 s[\s_dst_os_4], s[\s_dst_os_4+1] + .endif + .itr_4d = .itr_4d + 1 + .endr +.endm +;---------------------------------------------------------- +; starting of kernel igemm_bwd_gtc +; block_size : 256 +; thread_tile : 8x8 +; +; GemmMPerBlock : 128 +; GemmNPerBlock : 128 +; GemmKPerBlock : 16 +; GemmMPerThread : 4 (x2) +; GemmNPerThread : 4 (x2) +; GemmKPerThread, : 1 +; GemmMLevel0Cluster, ; 4 +; GemmNLevel0Cluster, ; 4 +; GemmMLevel1Cluster, ; 4 +; GemmNLevel1Cluster, ; 4 +; GemmThreadGemmDataPerReadM, ; 4 +; GemmThreadGemmDataPerReadN, ; 4 +; GemmABlockCopyThreadSliceLengths_GemmK_GemmM : 8, 1 +; 
GemmABlockCopyThreadClusterLengths_GemmK_GemmM : 2, 128 +; GemmABlockCopySrcDataPerRead_GemmM : 1 +; GemmABlockCopyDstDataPerWrite_GemmM : 1 +; GemmBBlockCopyThreadSliceLengths_GemmK_GemmN : 8, 1 +; GemmBBlockCopyThreadClusterLengths_GemmK_GemmN : 2, 128 +; GemmBBlockCopySrcDataPerRead_GemmN : 1 +; GemmBBlockCopyDstDataPerWrite_GemmN : 1 +; GemmCThreadCopyDstDataPerWrite_GemmN1 : 1 +; kernarg offset +.set k_p_in, 0 +.set k_p_wei, 8 +.set k_p_out, 16 +.set k_hi, 24 +.set k_wi, 28 +.set k_n, 32 +.set k_k, 36 +.set k_c, 40 +.set k_ho, 44 +.set k_wo, 48 +.set k_stride_h, 52 +.set k_stride_w, 56 +.set k_dilation_h, 60 +.set k_dilation_w, 64 +.set k_pad_h, 68 +.set k_pad_w, 72 +.set k_y, 76 +.set k_x, 80 +.set k_dtile_iy, 84 +.set k_dtile_ix, 88 +.set k_dtile_dy, 92 ; ConvDilationH / GcdStrideDilationH +.set k_dtile_dx, 96 ; ConvDilationW / GcdStrideDilationW +.set k_dtile_y, 100 ; ConvStrideH / GcdStrideDilationH +.set k_dtile_x, 104 ; ConvStrideW / GcdStrideDilationW +.set k_dtile_h, 108 ; Ho + math::integer_divide_ceil(ConvDilationH * (Y - 1), ConvStrideH); +.set k_dtile_w, 112 ; Wo + math::integer_divide_ceil(ConvDilationW * (X - 1), ConvStrideW); +.set k_dslice_y, 116 ; YDotSlice +.set k_dslice_x, 120 ; XDotSlice +.set k_dslice_h, 124 ; HTildaSlice (iHTildaLeft->0) +.set k_dslice_w, 128 ; WTildaSlice (iWTildaLeft->0) +.set k_dslice_h_left, 132 ; HTildaLeft +.set k_dslice_w_left, 136 ; WTildaLeft +.set k_pack0, 140 +.set k_end, 144 + +; sgpr +.set s_ka, 0 +.set s_bx, 2 +.set s_p_in, 4 +.set s_p_wei, 8 +.set s_p_out, 12 +.set s_hi, 16 +.set s_wi, 17 +.set s_n, 18 +.set s_k, 19 +.set s_c, 20 +.set s_ho, 21 +.set s_wo, 22 +.set s_stride_h, 23 +.set s_stride_w, 24 +.set s_dilation_h, 25 +.set s_dilation_w, 26 +.set s_pad_h, 27 +.set s_pad_w, 28 +.set s_y, 29 +.set s_x, 30 +.set s_dtile_iy, 31 +.set s_dtile_ix, 32 +.set s_dtile_dy, 33 +.set s_dtile_dx, 34 +.set s_dtile_y, 35 +.set s_dtile_x, 36 +.set s_dtile_h, 37 +.set s_dtile_w, 38 +.set s_dslice_y, 39 +.set s_dslice_x, 40 +.set s_dslice_y_neg, 2 +.set s_dslice_x_neg, 3 +.set s_dslice_h, 41 +.set s_dslice_w, 42 +.set s_dslice_h_left, 43 +.set s_dslice_w_left, 44 +.set s_dslice_h_shifted, 45 +.set s_dslice_w_shifted, 46 + +.set s_out_stride_k, 47 +.set s_out_stride_n, 48 +.set s_out_stride_ho, 49 +.set s_in_stride_c, 50 +.set s_in_stride_n, 51 +.set s_in_stride_hi, 52 +.set s_wei_stride_c, 53 +.set s_wei_stride_k, 54 +.set s_wei_stride_y, 55 +.set s_stride_dslice_hw, 56 +.set s_stride_dslice_yx, 57 +.set s_block_gemm_im, 58 +.set s_block_gemm_in, 59 +.set s_move_slice_ik, 60 +.set s_move_slice_ik_wei, 61 +.set s_move_slice_ik_out, 62 +.set s_move_slice_idslice_y, 63 +.set s_move_slice_idslice_x, 64 +.set s_move_slice_in_in, 65 +.set s_move_slice_in_dslice_h, 66 +.set s_move_slice_in_dslice_w, 67 +.set s_kitr, 0 +.set s_tmp, 68 +.set s_end, 72 + +; vgpr +.set v_c, 0 +.set v_a, 64 +.set v_b, 72 +.set v_gld_a, 80 +.set v_gld_b, 88 +.set v_out_os, 96 +.set v_wei_os, 97 +.set v_sst_a_os, 98 +.set v_sst_b_os, 99 +.set v_sld_a_os, 100 +.set v_sld_b_os, 101 +.set v_out_flag, 102 + +.set v_out_gemm_ik, 63 +.set v_out_gemm_in, 62 +.set v_wei_gemm_ik, 61 +.set v_wei_gemm_im, 60 +.set v_gemm_in, 59 +.set v_gemm_im, 58 + +.set v_out_ik, 103 +.set v_out_in, 104 ; ! This is indeed the offset from in_n. 
only valid if each thread only load different K direction +.set v_out_iho, 105 +.set v_out_iwo, 106 +.set v_out_dslice_ih, 107 +.set v_out_dslice_iw, 108 +.set v_out_dslice_iy, 109 +.set v_out_dslice_ix, 110 + +.set v_wei_ik, 111 +.set v_wei_ic, 112 +.set v_wei_iy, 113 +.set v_wei_ix, 114 +.set v_wei_dslice_iy, 115 +.set v_wei_dslice_ix, 116 + +.set v_in_gemm_in0, 117 +.set v_in_gemm_in1, 118 +.set v_in_ic, 119 +.set v_in_in, v_gld_a+0 +.set v_in_ihi, v_gld_a+1 +.set v_in_iwi, v_gld_a+2 +.set v_in_dslice_h, v_gld_a+3 +.set v_in_dslice_w, v_gld_a+4 +.set v_in_gemm_in, v_gld_a+5 + +.set v_in_dslice_h_itr, v_gld_a+6 +.set v_in_dslice_w_itr, v_gld_a+7 +.set v_in_in_itr, v_gld_a+8 +.set v_in_ic_itr, v_gld_a+9 +.set v_in_os, v_gld_a+10 +.set v_in_flag, v_gld_a+11 +.set v_in_ihi_itr, v_gld_a+12 +.set v_in_iwi_itr, v_gld_a+13 +.set v_dtile_iy_x_dilation_h, v_gld_a+14 +.set v_dtile_ix_x_dilation_w, v_gld_a+15 +.set v_tmp, 120 +;.set v_tid, 127 +.set v_dtile_iy, 126 +.set v_dtile_ix, 127 +.set v_end, 128 +.set v_in_gemm_in0_itr, v_in_dslice_h_itr +.set v_in_gemm_in1_itr, v_in_dslice_w_itr + +.text +.globl igemm_bwd_gtc +.p2align 8 +.type igemm_bwd_gtc,@function +igemm_bwd_gtc: + s_load_dwordx2 s[s_p_in:s_p_in+1], s[s_ka:s_ka+1], 0+k_p_in + s_load_dwordx2 s[s_p_wei:s_p_wei+1], s[s_ka:s_ka+1], 0+k_p_wei + s_load_dwordx2 s[s_p_out:s_p_out+1], s[s_ka:s_ka+1], 0+k_p_out + s_load_dwordx16 s[s_hi:s_hi+15], s[s_ka:s_ka+1], 0+k_hi + s_load_dwordx8 s[s_dtile_ix:s_dtile_ix+7], s[s_ka:s_ka+1], 0+k_dtile_ix + s_load_dwordx4 s[s_dslice_x:s_dslice_x+3], s[s_ka:s_ka+1], 0+k_dslice_x + s_load_dword s[s_dslice_w_left], s[s_ka:s_ka+1], 0+k_dslice_w_left + + ; GemmBBlockCopyThreadClusterLengths_GemmK_GemmN:{2,128}, slice:{8,1} + v_and_b32 v[v_out_gemm_in], 127, v0 + v_lshrrev_b32 v[v_tmp], 7, v0 + v_lshlrev_b32 v[v_out_gemm_ik], 3, v[v_tmp] + + ; GemmABlockCopyThreadClusterLengths_GemmK_GemmM:{2,128}, slice:{8,1} + v_and_b32 v[v_wei_gemm_im], 127, v0 + v_lshrrev_b32 v[v_tmp], 7, v0 + v_lshlrev_b32 v[v_wei_gemm_ik], 3, v[v_tmp] + + ; v_mov_b32 v[v_tid], v0 + + s_mov_b32 s[s_p_in + 2], 0xffffffff + s_mov_b32 s[s_p_in + 3], 0x27000 + s_mov_b32 s[s_p_wei + 2], 0xffffffff + s_mov_b32 s[s_p_wei + 3], 0x27000 + s_mov_b32 s[s_p_out + 2], 0xffffffff + s_mov_b32 s[s_p_out + 3], 0x27000 + s_waitcnt lgkmcnt(0) + + ; calculate index + s_mul_i32 s[s_out_stride_k], s[s_ho], s[s_wo] + s_mul_i32 s[s_out_stride_n], s[s_k], s[s_out_stride_k] + s_mul_i32 s[s_in_stride_c], s[s_hi], s[s_wi] + s_mul_i32 s[s_in_stride_n], s[s_c], s[s_in_stride_c] + s_mul_i32 s[s_wei_stride_c], s[s_y], s[s_x] + s_mul_i32 s[s_wei_stride_k], s[s_c], s[s_wei_stride_c] + s_mul_i32 s[s_stride_dslice_hw], s[s_dslice_h], s[s_dslice_w] + s_mul_i32 s[s_stride_dslice_yx], s[s_dslice_y], s[s_dslice_x] + + ; block gemm_m, gemm_n index on global + s_mul_i32 s[s_tmp], s[s_stride_dslice_hw], s[s_n] + s_lshr_b32 s[0], s[s_tmp], 7 ; gemm_n:128 + .v_u32_div_ss v_tmp+5, s_bx, 0, v_tmp, s_tmp + v_readfirstlane_b32 s[s_tmp], v[v_tmp+5] + s_mul_i32 s[s_tmp+2], s[s_tmp], s[0] + s_sub_i32 s[s_tmp+1], s[s_bx], s[s_tmp+2] + s_lshl_b32 s[s_block_gemm_im], s[s_tmp], 7 + s_lshl_b32 s[s_block_gemm_in], s[s_tmp+1], 7 + + ; calculate output transform + ; gemm_n -> n*dslice_h*dslice_w + v_add_u32 v[v_tmp+4], s[s_block_gemm_in], v[v_out_gemm_in] + .v_u32_div_vs v_out_in, v_tmp+4, s_stride_dslice_hw, v_tmp, s_tmp + v_mul_lo_u32 v[v_tmp], s[s_stride_dslice_hw], v[v_out_in] + v_sub_u32 v[v_tmp+4], v[v_tmp+4], v[v_tmp] + .v_u32_div_vs v_out_dslice_ih, v_tmp+4, s_dslice_w, v_tmp, s_tmp + 
v_mul_lo_u32 v[v_tmp], s[s_dslice_w], v[v_out_dslice_ih] + v_sub_u32 v[v_out_dslice_iw], v[v_tmp+4], v[v_tmp] + + ; iHTildaLeft, iWTildaLeft + v_add_u32 v[v_out_dslice_ih], s[s_dslice_h_left], v[v_out_dslice_ih] + v_add_u32 v[v_out_dslice_iw], s[s_dslice_w_left], v[v_out_dslice_iw] + + ; gemm_k -> k*dslice_y*dslice_x + .v_u32_div_vs v_out_ik, v_out_gemm_ik, s_stride_dslice_yx, v_tmp, s_tmp + v_mul_lo_u32 v[v_tmp], s[s_stride_dslice_yx], v[v_out_ik] + v_sub_u32 v[v_tmp+4], v[v_out_gemm_ik], v[v_tmp] + .v_u32_div_vs v_out_dslice_iy, v_tmp+4, s_dslice_x, v_tmp, s_tmp + v_mul_lo_u32 v[v_tmp], s[s_dslice_x], v[v_out_dslice_iy] + v_sub_u32 v[v_out_dslice_ix], v[v_tmp+4], v[v_tmp] + + ; dslice_y,dslice_h -> oh, dslice_x,dslice_w -> ow + v_mul_lo_u32 v[v_tmp+1], s[s_dtile_dy], v[v_out_dslice_iy] + v_sub_i32 v[v_out_iho], v[v_out_dslice_ih], v[v_tmp+1] + v_mul_lo_u32 v[v_tmp+1], s[s_dtile_dx], v[v_out_dslice_ix] + v_sub_i32 v[v_out_iwo], v[v_out_dslice_iw], v[v_tmp+1] + + s_mul_i32 s[s_dslice_x_neg], -1, s[s_dslice_x] + s_mul_i32 s[s_dslice_y_neg], -1, s[s_dslice_y] + + v_mul_i32_i24 v[v_out_dslice_iy], -1, v[v_out_dslice_iy] + v_mul_i32_i24 v[v_out_dslice_ix], -1, v[v_out_dslice_ix] + + ; update out flag + .v_out_set_flag v_out_flag, v_out_iho, v_out_iwo, s_ho, s_wo, s_tmp + + s_lshl_b32 s[s_out_stride_n], s[s_out_stride_n], 2 + s_lshl_b32 s[s_out_stride_k], s[s_out_stride_k], 2 + s_lshl_b32 s[s_out_stride_ho], s[s_wo], 2 + v_mul_lo_u32 v[v_out_in], s[s_out_stride_n], v[v_out_in] + v_mul_lo_u32 v[v_out_ik], s[s_out_stride_k], v[v_out_ik] + .v_out_calculate_os v_out_os, v_out_in, v_out_ik, v_out_iho, v_out_iwo, s_out_stride_n, s_out_stride_k, s_out_stride_ho, v_tmp + ; load output + + .v_out_load_k_n_8_1 v_gld_b, s_p_out, v_out_os, v_out_flag, v_out_ik, v_out_in, v_out_iho, v_out_iwo, v_out_dslice_ih, v_out_dslice_iw, v_out_dslice_iy, v_out_dslice_ix, s_dslice_y_neg, s_dslice_x_neg, s_dtile_dy, s_dtile_dx, s_out_stride_n, s_out_stride_k, s_out_stride_ho, s_ho, s_wo, v_tmp, s_tmp + + v_mov_b32 v[v_dtile_iy], s[s_dtile_iy] + v_mov_b32 v[v_dtile_ix], s[s_dtile_ix] + + ; move slice window + s_mov_b32 s[1], 9 ; unroll 16, but iterate 8 in each load + .v_u32_div_ss v_tmp+4, 1, s_stride_dslice_yx, v_tmp, s_tmp + v_readfirstlane_b32 s[s_move_slice_ik], v[v_tmp+4] + s_mul_i32 s[s_tmp], s[s_stride_dslice_yx], s[s_move_slice_ik] + s_mul_i32 s[s_move_slice_ik_out], s[s_out_stride_k], s[s_move_slice_ik] + s_sub_i32 s[1], s[1], s[s_tmp] + .v_u32_div_ss v_tmp+4, 1, s_dslice_x, v_tmp, s_tmp + v_readfirstlane_b32 s[s_move_slice_idslice_y], v[v_tmp+4] + s_mul_i32 s[s_tmp], s[s_dslice_x], s[s_move_slice_idslice_y] + s_sub_i32 s[s_move_slice_idslice_x], s[1], s[s_tmp] + + .v_u32_div_vs v_wei_ik, v_wei_gemm_ik, s_stride_dslice_yx, v_tmp, s_tmp + v_mul_lo_u32 v[v_tmp], s[s_stride_dslice_yx], v[v_wei_ik] + v_sub_u32 v[v_tmp+4], v[v_wei_gemm_ik], v[v_tmp] + .v_u32_div_vs v_wei_dslice_iy, v_tmp+4, s_dslice_x, v_tmp, s_tmp + v_mul_lo_u32 v[v_tmp], s[s_dslice_x], v[v_wei_dslice_iy] + v_sub_u32 v[v_wei_dslice_ix], v[v_tmp+4], v[v_tmp] + + ; gemm_m-> wei_ic + v_add_u32 v[v_wei_ic], s[s_block_gemm_im], v[v_wei_gemm_im] + + ; dslice_y -> y, dslice_x -> x + v_mad_u32_u24 v[v_wei_iy], v[v_wei_dslice_iy], s[s_dtile_y], v[v_dtile_iy] + v_mad_u32_u24 v[v_wei_ix], v[v_wei_dslice_ix], s[s_dtile_x], v[v_dtile_ix] + + ; calculate wei offset + s_lshl_b32 s[s_wei_stride_c], s[s_wei_stride_c], 2 + s_lshl_b32 s[s_wei_stride_k], s[s_wei_stride_k], 2 + s_lshl_b32 s[s_wei_stride_y], s[s_x], 2 + s_mul_i32 s[s_move_slice_ik_wei], 
s[s_wei_stride_k], s[s_move_slice_ik] + v_mul_lo_u32 v[v_wei_ic], s[s_wei_stride_c], v[v_wei_ic] + v_mul_lo_u32 v[v_wei_ik], s[s_wei_stride_k], v[v_wei_ik] + .v_wei_calculate_os v_wei_os, v_wei_ik, v_wei_ic, v_wei_iy, v_wei_ix, s_wei_stride_k, s_wei_stride_c, s_wei_stride_y, v_tmp + .v_wei_load_k_n_8_1 v_gld_a, s_p_wei, v_wei_os, v_wei_ik, v_wei_ic, v_wei_iy, v_wei_ix, v_wei_dslice_iy, v_wei_dslice_ix, v_dtile_iy, v_dtile_ix, s_dslice_y, s_dslice_x, s_dtile_y, s_dtile_x, s_wei_stride_c, s_wei_stride_k, s_wei_stride_y, v_tmp, s_tmp + + ; c thread mapping + v_and_b32 v[v_tmp+4], 15, v0 + v_and_b32 v[v_tmp], 3, v[v_tmp+4] + v_lshrrev_b32 v[v_tmp+1], 2, v[v_tmp+4] + + v_lshrrev_b32 v[v_tmp+4], 4, v0 + v_and_b32 v[v_tmp+2], 3, v[v_tmp+4] + v_lshrrev_b32 v[v_tmp+3], 2, v[v_tmp+4] + + v_lshl_or_b32 v[v_gemm_in], v[v_tmp+2], 2, v[v_tmp] ; in + v_lshl_or_b32 v[v_gemm_im], v[v_tmp+3], 2, v[v_tmp+1] ; im + + v_lshlrev_b32 v[v_sld_b_os], 4, v[v_gemm_in] + v_lshlrev_b32 v[v_sld_a_os], 4, v[v_gemm_im] + v_add_u32 v[v_sld_a_os], 8192, v[v_sld_a_os] + + ; calculate input index, m0, m1, n0, n1 + + v_lshlrev_b32 v[v_tmp+1], 2, v[v_gemm_im] + v_add_u32 v[v_in_ic], s[s_block_gemm_im], v[v_tmp+1] + + v_lshlrev_b32 v[v_tmp+1], 2, v[v_gemm_in] + v_add_u32 v[v_tmp+5], s[s_block_gemm_in], v[v_tmp+1] + v_lshrrev_b32 v[v_in_gemm_in0], 6, v[v_tmp+5] + v_and_b32 v[v_in_gemm_in1], 63, v[v_tmp+5] + + s_lshl_b32 s[s_in_stride_hi], s[s_wi], 2 + s_lshl_b32 s[s_in_stride_c], s[s_in_stride_c], 2 + s_lshl_b32 s[s_in_stride_n], s[s_in_stride_n], 2 + + s_add_u32 s[s_dslice_h_shifted], s[s_dslice_h], s[s_dslice_h_left] + s_add_u32 s[s_dslice_w_shifted], s[s_dslice_w], s[s_dslice_w_left] + + ; input move n0 + s_mov_b32 s[0], 64 ; n0 -> 64 + .v_u32_div_ss v_tmp+4, 0, s_stride_dslice_hw, v_tmp, s_tmp + v_readfirstlane_b32 s[s_move_slice_in_in], v[v_tmp+4] + s_mul_i32 s[s_tmp], s[s_stride_dslice_hw], s[s_move_slice_in_in] + s_sub_i32 s[0], s[0], s[s_tmp] + .v_u32_div_ss v_tmp+4, 0, s_dslice_w, v_tmp, s_tmp + v_readfirstlane_b32 s[s_move_slice_in_dslice_h], v[v_tmp+4] + s_mul_i32 s[s_tmp], s[s_dslice_w], s[s_move_slice_in_dslice_h] + s_sub_i32 s[s_move_slice_in_dslice_w], s[0], s[s_tmp] + + s_mul_i32 s[s_move_slice_in_in], s[s_in_stride_n], s[s_move_slice_in_in] + + ; out lds offset block k_n + v_lshlrev_b32 v[v_tmp], 7, v[v_out_gemm_ik] + v_or_b32 v[v_tmp+1], v[v_tmp], v[v_out_gemm_in] + v_lshlrev_b32 v[v_sst_b_os], 2, v[v_tmp+1] + + ; wei lds offset block k_m + v_lshlrev_b32 v[v_tmp], 7, v[v_wei_gemm_ik] + v_or_b32 v[v_tmp+1], v[v_tmp], v[v_wei_gemm_im] + v_lshlrev_b32 v[v_sst_a_os], 2, v[v_tmp+1] + v_add_u32 v[v_sst_a_os], 8192, v[v_sst_a_os] + + .v_clear_nc v_c, 64 + + ; start FMA loop, 8x8 thread tile with 4x4 sub-tile + s_waitcnt vmcnt(8) + .v_out_sst_k_n_8_1 v_gld_b, v_sst_b_os + + s_waitcnt vmcnt(0) + .v_wei_sst_k_n_8_1 v_gld_a, v_sst_a_os + + + ; gemm_k -> k*dslice_y*dslice_x + s_mul_i32 s[s_tmp], s[s_stride_dslice_yx], s[s_k] + s_sub_i32 s[s_kitr], s[s_tmp], 16 + s_cmp_gt_i32 s[s_kitr], 0 + + + s_cbranch_scc0 L_igemm_v4r1_bwd_dynamic_end + .v_out_move_slice_window_v4r1_bwd v_out_ik, v_out_iho, v_out_iwo, v_out_dslice_ih, v_out_dslice_iw, v_out_dslice_iy, v_out_dslice_ix, s_out_stride_k, s_move_slice_ik_out, s_move_slice_idslice_y, s_move_slice_idslice_x, s_dslice_y_neg, s_dslice_x_neg, s_dtile_dy, s_dtile_dx, v_tmp, s_tmp + .v_out_calculate_os v_out_os, v_out_in, v_out_ik, v_out_iho, v_out_iwo, s_out_stride_n, s_out_stride_k, s_out_stride_ho, v_tmp + .v_out_set_flag v_out_flag, v_out_iho, v_out_iwo, s_ho, s_wo, s_tmp + 
.v_wei_move_slice_window_v4r1_bwd v_wei_ik, v_wei_ic, v_wei_iy, v_wei_ix, v_wei_dslice_iy, v_wei_dslice_ix, v_dtile_iy, v_dtile_ix, s_wei_stride_k, s_move_slice_ik_wei, s_move_slice_idslice_y, s_move_slice_idslice_x, s_dslice_y, s_dslice_x, s_dtile_y, s_dtile_x, v_tmp, s_tmp + .v_wei_calculate_os v_wei_os, v_wei_ik, v_wei_ic, v_wei_iy, v_wei_ix, s_wei_stride_k, s_wei_stride_c, s_wei_stride_y, v_tmp + + v_xor_b32 v[v_sst_b_os], 0x4000, v[v_sst_b_os] ; switch double buffer b store + v_xor_b32 v[v_sst_a_os], 0x4000, v[v_sst_a_os] ; switch double buffer a store + s_waitcnt lgkmcnt(0) + s_barrier + + .v_out_load_k_n_8_1 v_gld_b, s_p_out, v_out_os, v_out_flag, v_out_ik, v_out_in, v_out_iho, v_out_iwo, v_out_dslice_ih, v_out_dslice_iw, v_out_dslice_iy, v_out_dslice_ix, s_dslice_y_neg, s_dslice_x_neg, s_dtile_dy, s_dtile_dx, s_out_stride_n, s_out_stride_k, s_out_stride_ho, s_ho, s_wo, v_tmp, s_tmp + .v_wei_load_k_n_8_1 v_gld_a, s_p_wei, v_wei_os, v_wei_ik, v_wei_ic, v_wei_iy, v_wei_ix, v_wei_dslice_iy, v_wei_dslice_ix, v_dtile_iy, v_dtile_ix, s_dslice_y, s_dslice_x, s_dtile_y, s_dtile_x, s_wei_stride_c, s_wei_stride_k, s_wei_stride_y, v_tmp, s_tmp + +L_igemm_v4r1_bwd_dynamic_fma_body: + ; do fma accumulate with unroll 16 + ds_read_b128 v[v_a:v_a+3], v[v_sld_a_os] + ds_read_b128 v[v_b:v_b+3], v[v_sld_b_os] + ds_read_b128 v[v_b+4:v_b+4+3], v[v_sld_b_os] offset:256 + ds_read_b128 v[v_a+4:v_a+4+3], v[v_sld_a_os] offset:256 + .itr_k = 0 + .rept 15 + s_waitcnt lgkmcnt(2) + .v_fma_4x4_s8 v_c,v_a,v_b + + s_waitcnt lgkmcnt(1) + .v_fma_4x4_s8 v_c+4,v_a,v_b+4 + + ds_read_b128 v[v_a:v_a+3], v[v_sld_a_os] offset:0+(.itr_k+1)*512 + s_waitcnt lgkmcnt(1) + .v_fma_4x4_s8 v_c+32,v_a+4,v_b + + ds_read_b128 v[v_b:v_b+3], v[v_sld_b_os] offset:0+(.itr_k+1)*512 + .v_fma_4x4_s8 v_c+36,v_a+4,v_b+4 + + ds_read_b128 v[v_b+4:v_b+4+3], v[v_sld_b_os] offset:0+(.itr_k+1)*512+256 + ds_read_b128 v[v_a+4:v_a+4+3], v[v_sld_a_os] offset:0+(.itr_k+1)*512+256 + .itr_k = .itr_k + 1 + .endr + + ; last unroll + v_xor_b32 v[v_sld_b_os], 0x4000, v[v_sld_b_os] ; switch double buffer b load + v_xor_b32 v[v_sld_a_os], 0x4000, v[v_sld_a_os] ; switch double buffer a load + s_waitcnt lgkmcnt(2) + .v_fma_4x4_s8 v_c,v_a,v_b + + s_waitcnt lgkmcnt(1) + .v_fma_4x4_s8 v_c+4,v_a,v_b+4 + + s_waitcnt vmcnt(8) + .v_out_sst_k_n_8_1 v_gld_b, v_sst_b_os + s_waitcnt vmcnt(0) + .v_wei_sst_k_n_8_1 v_gld_a, v_sst_a_os + + s_sub_i32 s[s_kitr], s[s_kitr], 16 + s_cmp_gt_i32 s[s_kitr], 0 + s_cbranch_scc0 L_igemm_v4r1_bwd_dynamic_fma_finishing + .v_out_move_slice_window_v4r1_bwd v_out_ik, v_out_iho, v_out_iwo, v_out_dslice_ih, v_out_dslice_iw, v_out_dslice_iy, v_out_dslice_ix, s_out_stride_k, s_move_slice_ik_out, s_move_slice_idslice_y, s_move_slice_idslice_x, s_dslice_y_neg, s_dslice_x_neg, s_dtile_dy, s_dtile_dx, v_tmp, s_tmp + .v_out_calculate_os v_out_os, v_out_in, v_out_ik, v_out_iho, v_out_iwo, s_out_stride_n, s_out_stride_k, s_out_stride_ho, v_tmp + .v_out_set_flag v_out_flag, v_out_iho, v_out_iwo, s_ho, s_wo, s_tmp + .v_wei_move_slice_window_v4r1_bwd v_wei_ik, v_wei_ic, v_wei_iy, v_wei_ix, v_wei_dslice_iy, v_wei_dslice_ix, v_dtile_iy, v_dtile_ix, s_wei_stride_k, s_move_slice_ik_wei, s_move_slice_idslice_y, s_move_slice_idslice_x, s_dslice_y, s_dslice_x, s_dtile_y, s_dtile_x, v_tmp, s_tmp + .v_wei_calculate_os v_wei_os, v_wei_ik, v_wei_ic, v_wei_iy, v_wei_ix, s_wei_stride_k, s_wei_stride_c, s_wei_stride_y, v_tmp + + s_waitcnt lgkmcnt(8) + .v_fma_4x4_s8 v_c+32,v_a+4,v_b + + v_xor_b32 v[v_sst_b_os], 0x4000, v[v_sst_b_os] ; switch double buffer b store + 
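; note: xor with 0x4000 (16 KiB) flips the store offsets between the two
+ ; halves of the 32 KiB LDS double buffer, so the next tile is written while
+ ; the FMA loop still reads the other half. +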
v_xor_b32 v[v_sst_a_os], 0x4000, v[v_sst_a_os] ; switch double buffer a store + s_waitcnt lgkmcnt(0) + s_barrier + .v_out_load_k_n_8_1 v_gld_b, s_p_out, v_out_os, v_out_flag, v_out_ik, v_out_in, v_out_iho, v_out_iwo, v_out_dslice_ih, v_out_dslice_iw, v_out_dslice_iy, v_out_dslice_ix, s_dslice_y_neg, s_dslice_x_neg, s_dtile_dy, s_dtile_dx, s_out_stride_n, s_out_stride_k, s_out_stride_ho, s_ho, s_wo, v_tmp, s_tmp + .v_wei_load_k_n_8_1 v_gld_a, s_p_wei, v_wei_os, v_wei_ik, v_wei_ic, v_wei_iy, v_wei_ix, v_wei_dslice_iy, v_wei_dslice_ix, v_dtile_iy, v_dtile_ix, s_dslice_y, s_dslice_x, s_dtile_y, s_dtile_x, s_wei_stride_c, s_wei_stride_k, s_wei_stride_y, v_tmp, s_tmp + + .v_fma_4x4_s8 v_c+36,v_a+4,v_b+4 + + s_branch L_igemm_v4r1_bwd_dynamic_fma_body +L_igemm_v4r1_bwd_dynamic_fma_finishing: + s_waitcnt lgkmcnt(8) + .v_fma_4x4_s8 v_c+32,v_a+4,v_b + .v_fma_4x4_s8 v_c+36,v_a+4,v_b+4 +L_igemm_v4r1_bwd_dynamic_end: + + s_waitcnt lgkmcnt(0) + s_barrier + ds_read_b128 v[v_a:v_a+3], v[v_sld_a_os] + ds_read_b128 v[v_b:v_b+3], v[v_sld_b_os] + ds_read_b128 v[v_b+4:v_b+4+3], v[v_sld_b_os] offset:256 + ds_read_b128 v[v_a+4:v_a+4+3], v[v_sld_a_os] offset:256 + .itr_k = 0 + .rept 15 + s_waitcnt lgkmcnt(2) + .v_fma_4x4_s8 v_c,v_a,v_b + + s_waitcnt lgkmcnt(1) + .v_fma_4x4_s8 v_c+4,v_a,v_b+4 + + ds_read_b128 v[v_a:v_a+3], v[v_sld_a_os] offset:0+(.itr_k+1)*512 + s_waitcnt lgkmcnt(1) + .v_fma_4x4_s8 v_c+32,v_a+4,v_b + + ds_read_b128 v[v_b:v_b+3], v[v_sld_b_os] offset:0+(.itr_k+1)*512 + .v_fma_4x4_s8 v_c+36,v_a+4,v_b+4 + + ds_read_b128 v[v_b+4:v_b+4+3], v[v_sld_b_os] offset:0+(.itr_k+1)*512+256 + ds_read_b128 v[v_a+4:v_a+4+3], v[v_sld_a_os] offset:0+(.itr_k+1)*512+256 + .itr_k = .itr_k + 1 + .endr + + ; last unroll + s_waitcnt lgkmcnt(2) + .v_fma_4x4_s8 v_c,v_a,v_b + + s_waitcnt lgkmcnt(1) + .v_fma_4x4_s8 v_c+4,v_a,v_b+4 + + s_waitcnt lgkmcnt(0) + .v_fma_4x4_s8 v_c+32,v_a+4,v_b + + .v_fma_4x4_s8 v_c+36,v_a+4,v_b+4 + + s_mul_i32 s[s_tmp], s[s_dtile_iy], s[s_dilation_h] + v_mov_b32 v[v_dtile_iy_x_dilation_h], s[s_tmp] + s_mul_i32 s[s_tmp+1], s[s_dtile_ix], s[s_dilation_w] + v_mov_b32 v[v_dtile_ix_x_dilation_w], s[s_tmp+1] + + v_lshl_or_b32 v[v_in_gemm_in], v[v_in_gemm_in0], 6, v[v_in_gemm_in1] + .v_in_transform_gemm_n v_in_in, v_in_ihi, v_in_iwi, v_in_gemm_in, v_in_dslice_h, v_in_dslice_w, s_stride_dslice_hw, s_dslice_w, s_dtile_iy, s_dtile_ix, s_dslice_h_left, s_dslice_w_left, s_dilation_h, s_dilation_w, s_stride_h, s_stride_w, s_pad_h, s_pad_w, v_tmp, s_tmp + + v_mul_lo_u32 v[v_in_in], s[s_in_stride_n], v[v_in_in] + v_mul_lo_u32 v[v_in_ic], s[s_in_stride_c], v[v_in_ic] + .v_in_calculate_os v_in_os, v_in_in, v_in_ic, v_in_ihi, v_in_iwi, s_in_stride_n, s_in_stride_c, s_in_stride_hi, v_tmp + + .v_in_set_flag v_in_flag, v_in_ihi, v_in_iwi, s_hi, s_wi, s_tmp + + s_lshl_b32 s[s_block_gemm_in], s[s_in_stride_c], 6 + .v_in_write_m0_m1_n0_n1_step v_c, s_p_in, v_in_os, v_in_flag, v_in_ihi_itr, v_in_iwi_itr, v_in_in_itr, v_in_dslice_h_itr, v_in_dslice_w_itr, v_in_ic_itr, v_in_in, v_in_ic, v_in_ihi, v_in_iwi, v_in_dslice_h, v_in_dslice_w, v_dtile_iy_x_dilation_h, v_dtile_ix_x_dilation_w, s_move_slice_in_in, s_move_slice_in_dslice_h, s_move_slice_in_dslice_w, s_dslice_h_left, s_dslice_w_left, s_dslice_h_shifted, s_dslice_w_shifted, s_dslice_h, s_dslice_w, s_dtile_iy, s_dtile_ix, s_dilation_h, s_dilation_w, s_stride_h, s_stride_w, s_pad_h, s_pad_w, s_in_stride_n, s_in_stride_c, s_block_gemm_in, s_in_stride_hi, s_hi, s_wi, v_tmp, s_tmp, 2, 4, 2, 4 + s_endpgm +.rodata +.p2align 6 +.amdhsa_kernel igemm_bwd_gtc + 
.amdhsa_group_segment_fixed_size 32768 + .amdhsa_user_sgpr_kernarg_segment_ptr 1 + .amdhsa_system_sgpr_workgroup_id_x 1 + .amdhsa_system_vgpr_workitem_id 0 + .amdhsa_next_free_vgpr 128 + .amdhsa_next_free_sgpr 72 + .amdhsa_ieee_mode 0 + .amdhsa_dx10_clamp 0 +.end_amdhsa_kernel +;---------------------------------------------------------- +; starting of kernel igemm_bwd_gtc_bt128x128x16_tt8x8_gm2x4x4_gn2x4x4_ta1x1x1x2x4_16x1x1x16x1_tb1x1x1x2x4x1x1_16x1x1x16x1x1x1 +; gemm_m_per_block : 128 +; gemm_n_per_block : 128 +; gemm_k_per_block : 16 +; gemm_m_per_thread : 4 +; gemm_m_level0_cluster : 4 +; gemm_m_level1_cluster : 4 +; gemm_n_per_thread : 4 +; gemm_n_level0_cluster : 4 +; gemm_n_level1_cluster : 4 +; tensor_a_thread_lengths : [1, 1, 1, 2, 4] +; tensor_a_cluster_lengths : [16, 1, 1, 16, 1] +; tensor_b_thread_lengths : [1, 1, 1, 2, 4, 1, 1] +; tensor_b_cluster_lengths : [16, 1, 1, 16, 1, 1, 1] +; direction : bwd +; precision : fp32 +; opt_1x1 : 0 +; +; block_size : 256 +; thread_tile : 8x8 +; lds_total : 32768 +; +.set k_p_in, 0 +.set k_p_wei, 8 +.set k_p_out, 16 +.set k_hi, 24 +.set k_wi, 28 +.set k_n, 32 +.set k_k, 36 +.set k_c, 40 +.set k_ho, 44 +.set k_wo, 48 +.set k_stride_h, 52 +.set k_stride_w, 56 +.set k_dilation_h, 60 +.set k_dilation_w, 64 +.set k_pad_h, 68 +.set k_pad_w, 72 +.set k_y, 76 +.set k_x, 80 +.set k_dtile_iy, 84 +.set k_dtile_ix, 88 +.set k_dtile_dy, 92 +.set k_dtile_dx, 96 +.set k_dtile_y, 100 +.set k_dtile_x, 104 +.set k_dtile_h, 108 +.set k_dtile_w, 112 +.set k_dslice_y, 116 +.set k_dslice_x, 120 +.set k_dslice_h, 124 +.set k_dslice_w, 128 +.set k_dslice_h_left, 132 +.set k_dslice_w_left, 136 +.set k_pack0, 140 +.set k_end, 144 + +.set s_ka, 0 +.set s_bx, 2 +.set s_p_in, 4 +.set s_p_wei, 8 +.set s_p_out, 12 +.set s_hi, 16 +.set s_wi, 17 +.set s_n, 18 +.set s_k, 19 +.set s_c, 20 +.set s_ho, 21 +.set s_wo, 22 +.set s_stride_h, 23 +.set s_stride_w, 24 +.set s_dilation_h, 25 +.set s_dilation_w, 26 +.set s_pad_h, 27 +.set s_pad_w, 28 +.set s_y, 29 +.set s_x, 30 +.set s_dtile_iy, 31 +.set s_dtile_ix, 32 +.set s_dtile_dy, 33 +.set s_dtile_dx, 34 +.set s_dtile_y, 35 +.set s_dtile_x, 36 +.set s_dtile_h, 37 +.set s_dtile_w, 38 +.set s_dslice_y, 39 +.set s_dslice_x, 40 +.set s_dslice_h, 41 +.set s_dslice_w, 42 +.set s_dslice_h_left, 43 +.set s_dslice_w_left, 44 +.set s_out_stride_k, 45 +.set s_out_stride_k0, 46 +.set s_out_stride_n, 47 +.set s_out_stride_n0, 48 +.set s_out_stride_b0, 49 +.set s_out_move_slice_stride_k, 50 +.set s_in_stride_c, 51 +.set s_in_stride_cr, 52 +.set s_in_stride_n, 53 +.set s_in_stride_nr, 54 +.set s_wei_stride_c, 55 +.set s_wei_stride_c0, 56 +.set s_wei_stride_k, 57 +.set s_wei_stride_k0, 58 +.set s_wei_move_slice_stride_k, 59 +.set s_stride_dslice_hw, 60 +.set s_stride_dslice_yx, 61 +.set s_block_gtc_ib, 62 +.set s_block_gtc_ic, 63 +.set s_block_gtc_in, 64 +.set s_knum, 19 +.set s_move_slice_k_idsy, 0 +.set s_move_slice_k_idsx, 1 +.set s_move_slice_k_ik, 2 +.set s_kitr, 3 +.set s_tmp, 66 +.set s_end, 72 + +.set v_c, 0 +.set v_a, 64 +.set v_b, 72 +.set v_gld_a, 80 +.set v_gld_b, 88 +.set v_sst_a_os, 96 +.set v_sst_b_os, 97 +.set v_sld_a_os, 98 +.set v_sld_b_os, 99 +.set v_out_iho, 100 +.set v_out_iwo, 101 +.set v_out_dslice_ih, 102 +.set v_out_dslice_iw, 103 +.set v_out_os, 104 +.set v_out_os_base, 105 +.set v_wei_iy, 106 +.set v_wei_ix, 107 +.set v_dtile_iy, 108 +.set v_dtile_ix, 109 +.set v_wei_os, 110 +.set v_wei_os_base, 111 +.set v_out_flag, 112 +.set v_in_flag, 113 +.set v_in_os, 114 +.set v_gtc_ic0, 63 +.set v_gtc_ic1, 62 +.set v_gtc_in0, 61 
+.set v_gtc_in1, 60 +.set v_gtc_ib0, 59 +.set v_gtc_ib1, 58 +.set v_gtc_ik0, 57 +.set v_gtc_ik1, 56 +.set v_gtc_ie, 55 +.set v_gemm_in, 54 +.set v_gemm_im, 53 +.set v_in_in0, 52 +.set v_in_in1, 51 +.set v_in_ib0, 50 +.set v_in_ib1, 49 +.set v_in_ic0, 48 +.set v_in_ic1, 47 +.set v_in_ihi, 46 +.set v_in_iwi, 45 +.set v_in_dslice_ih, 44 +.set v_in_dslice_iw, 43 +.set v_tmp, 116 +.set v_end, 122 + +.text +.globl igemm_bwd_gtc_bt128x128x16_tt8x8_gm2x4x4_gn2x4x4_ta1x1x1x2x4_16x1x1x16x1_tb1x1x1x2x4x1x1_16x1x1x16x1x1x1 +.p2align 8 +.type igemm_bwd_gtc_bt128x128x16_tt8x8_gm2x4x4_gn2x4x4_ta1x1x1x2x4_16x1x1x16x1_tb1x1x1x2x4x1x1_16x1x1x16x1x1x1,@function +igemm_bwd_gtc_bt128x128x16_tt8x8_gm2x4x4_gn2x4x4_ta1x1x1x2x4_16x1x1x16x1_tb1x1x1x2x4x1x1_16x1x1x16x1x1x1: + s_load_dwordx2 s[s_p_in+0:s_p_in+1], s[s_ka+0:s_ka+1], 0+k_p_in + s_load_dwordx2 s[s_p_wei+0:s_p_wei+1], s[s_ka+0:s_ka+1], 0+k_p_wei + s_load_dwordx2 s[s_p_out+0:s_p_out+1], s[s_ka+0:s_ka+1], 0+k_p_out + s_load_dwordx16 s[s_hi+0:s_hi+15], s[s_ka+0:s_ka+1], 0+k_hi + s_load_dwordx8 s[s_dtile_ix+0:s_dtile_ix+7], s[s_ka+0:s_ka+1], 0+k_dtile_ix + s_load_dwordx4 s[s_dslice_x+0:s_dslice_x+3], s[s_ka+0:s_ka+1], 0+k_dslice_x + s_load_dword s[s_dslice_w_left], s[s_ka+0:s_ka+1], 0+k_dslice_w_left + v_mov_b32 v[v_tmp], v0 + ; output: K0xK1xExN0xN1xB0xB1: 1x1x1x2x4x1x1, slice 16x1x1x16x1x1x1 + v_and_b32 v[v_gtc_in0], 15, v[v_tmp] + v_lshrrev_b32 v[v_tmp], 4, v[v_tmp] + v_lshlrev_b32 v[v_gtc_in0], 1, v[v_gtc_in0] + v_and_b32 v[v_gtc_ik0], 15, v[v_tmp] + + v_mov_b32 v[v_tmp], v0 + ; wei: K0xK1xExC0xC1: 1x1x1x2x4, slice 16x1x1x16x1 + v_and_b32 v[v_gtc_ic0], 15, v[v_tmp] + v_lshlrev_b32 v[v_gtc_ic0], 1, v[v_gtc_ic0] + + s_mov_b32 s[s_p_in + 2], 0xffffffff + s_mov_b32 s[s_p_in + 3], 0x27000 + s_mov_b32 s[s_p_wei + 2], 0xffffffff + s_mov_b32 s[s_p_wei + 3], 0x27000 + s_mov_b32 s[s_p_out + 2], 0xffffffff + s_mov_b32 s[s_p_out + 3], 0x27000 + s_waitcnt lgkmcnt(0) + + ; calculate index + s_mul_i32 s[s_out_stride_k], s[s_ho], s[s_wo] + s_mul_i32 s[s_out_stride_n], s[s_k], s[s_out_stride_k] + s_mul_i32 s[s_in_stride_c], s[s_hi], s[s_wi] + s_mul_i32 s[s_in_stride_n], s[s_c], s[s_in_stride_c] + s_mul_i32 s[s_wei_stride_c], s[s_y], s[s_x] + s_mul_i32 s[s_wei_stride_k], s[s_c], s[s_wei_stride_c] + s_mul_i32 s[s_stride_dslice_hw], s[s_dslice_h], s[s_dslice_w] + s_mul_i32 s[s_stride_dslice_yx], s[s_dslice_y], s[s_dslice_x] + s_lshl_b32 s[s_out_stride_n0], s[s_out_stride_n], 2 + s_lshl_b32 s[s_wei_stride_c0], s[s_wei_stride_c], 2 + + ; N0xN1xB0xB1, gemm_m_per_block:128, gemm_n_per_block:128 + s_mul_i32 s[s_tmp], s[s_stride_dslice_hw], s[s_n] + s_lshr_b32 s[0], s[s_tmp], 7 + .v_u32_div_ss v_tmp+5, s_bx, 0, v_tmp, s_tmp + v_readfirstlane_b32 s[s_tmp+4], v[v_tmp+5] ; gemm_m, + s_mul_i32 s[s_tmp+2], s[s_tmp+4], s[0] + s_sub_i32 s[s_tmp+5], s[s_bx], s[s_tmp+2] ; gemm_n, cnt + s_lshl_b32 s[s_block_gtc_ic], s[s_tmp+4], 7 + s_mov_b32 s[0], s[s_stride_dslice_hw] ; B:1 per block, total num of B + .v_u32_div_ss v_tmp+5, s_tmp+5, 0, v_tmp, s_tmp + v_readfirstlane_b32 s[s_tmp], v[v_tmp+5] ; => N + s_mul_i32 s[s_tmp+2], s[s_tmp], s[0] + s_sub_i32 s[s_tmp+1], s[s_tmp+5], s[s_tmp+2] ; => B + s_lshl_b32 s[s_block_gtc_in], s[s_tmp], 7 + s_mov_b32 s[s_block_gtc_ib], s[s_tmp+1] + + v_mov_b32 v[v_gtc_ib1], 0 + v_lshlrev_b32 v[v_gtc_in1], 2, v[v_gtc_in0] + v_mov_b32 v[v_gtc_ik1], v[v_gtc_ik0] + + v_add_u32 v[v_tmp+5], s[s_block_gtc_ib], v[v_gtc_ib1] + ; calculate output transform, B -> dslice_h*dslice_w. 
always use ib1
+ .v_u32_div_vs v_out_dslice_ih, v_tmp+5, s_dslice_w, v_tmp, s_tmp
+ v_mul_lo_u32 v[v_tmp], s[s_dslice_w], v[v_out_dslice_ih]
+ v_sub_u32 v[v_out_dslice_iw], v[v_tmp+5], v[v_tmp]
+ ; iHTildaLeft, iWTildaLeft
+ v_add_u32 v[v_out_dslice_ih], s[s_dslice_h_left], v[v_out_dslice_ih]
+ v_add_u32 v[v_out_dslice_iw], s[s_dslice_w_left], v[v_out_dslice_iw]
+ ; dslice_y,dslice_h -> oh, dslice_x,dslice_w -> ow
+ v_mov_b32 v[v_out_iho], v[v_out_dslice_ih]
+ v_mov_b32 v[v_out_iwo], v[v_out_dslice_iw]
+ v_mul_lo_u32 v[v_tmp], s[s_out_stride_k], v[v_gtc_ik1]
+ v_mul_lo_u32 v[v_tmp+1], s[s_out_stride_n], v[v_gtc_in1]
+ v_add_lshl_u32 v[v_out_os_base], v[v_tmp], v[v_tmp+1], 2
+ ; n to statically accumulate into base pointer
+ s_lshl_b32 s[s_tmp+3], s[s_block_gtc_in], 2
+ s_mul_i32 s[s_tmp], s[s_out_stride_n], s[s_tmp+3]
+ s_mul_hi_u32 s[s_tmp+1], s[s_out_stride_n], s[s_tmp+3]
+ s_add_u32 s[s_p_out], s[s_p_out], s[s_tmp]
+ s_addc_u32 s[s_p_out+1], s[s_p_out+1], s[s_tmp+1]
+
+ .v_bwd_gtc_out_update_os v_out_os, v_out_os_base, v_out_iho, v_out_iwo, s_wo, v_tmp
+ .v_set_flag_hw v_out_flag, v_out_iho, v_out_iwo, s_ho, s_wo
+
+ s_lshl_b32 s[s_out_stride_n0], s[s_out_stride_n0], 2
+ s_lshl_b32 s[s_out_stride_n], s[s_out_stride_n], 2
+
+ ; load output
+ v_cmp_eq_u32 vcc, 1, v[v_out_flag]
+ s_and_saveexec_b64 s[s_tmp+4:s_tmp+5], vcc
+ .v_gld_2x4_b32_v1 v_gld_b, s_p_out, v_out_os, s_out_stride_n0, s_out_stride_n, s_tmp
+ s_or_b64 exec, exec, s[s_tmp+4:s_tmp+5]
+
+ v_lshlrev_b32 v[v_gtc_ic1], 2, v[v_gtc_ic0]
+ v_add_u32 v[v_tmp+5], s[s_block_gtc_ic], v[v_gtc_ic1]
+ v_mov_b32 v[v_dtile_iy], s[s_dtile_iy]
+ v_mov_b32 v[v_dtile_ix], s[s_dtile_ix]
+ v_mov_b32 v[v_wei_iy], s[s_dtile_iy]
+ v_mov_b32 v[v_wei_ix], s[s_dtile_ix]
+
+ ; calculate wei offset
+ v_mul_lo_u32 v[v_tmp], s[s_wei_stride_c], v[v_tmp+5]
+ v_mul_lo_u32 v[v_tmp+1], s[s_wei_stride_k], v[v_gtc_ik1]
+ v_add_lshl_u32 v[v_wei_os_base], v[v_tmp], v[v_tmp+1], 2
+ .v_bwd_gtc_wei_update_os v_wei_os, v_wei_os_base, v_wei_iy, v_wei_ix, s_x, v_tmp
+
+ s_lshl_b32 s[s_wei_stride_c0], s[s_wei_stride_c0], 2
+ s_lshl_b32 s[s_wei_stride_c], s[s_wei_stride_c], 2
+
+ ; load weight
+ .v_gld_2x4_b32_v1 v_gld_a, s_p_wei, v_wei_os, s_wei_stride_c0, s_wei_stride_c, s_tmp
+
+ ; c thread mapping
+ ; -> MR x NR x ML1 x NL1 x ML0 x NL0
+ ; cluster 1 x 1 x 4 x 4 x 4 x 4
+ ; perthrd 2 x 2 x 1 x 1 x 4 x 4
+ v_and_b32 v[v_tmp], 3, v0
+ v_lshlrev_b32 v[v_tmp], 2, v[v_tmp] ; => iNL0
+ v_lshrrev_b32 v[v_tmp+5], 2, v0
+ v_and_b32 v[v_tmp+1], 3, v[v_tmp+5]
+ v_lshlrev_b32 v[v_tmp+1], 2, v[v_tmp+1] ; => iML0
+ v_lshrrev_b32 v[v_tmp+5], 2, v[v_tmp+5]
+ v_and_b32 v[v_tmp+2], 3, v[v_tmp+5] ; => iNL1
+ v_lshrrev_b32 v[v_tmp+5], 2, v[v_tmp+5]
+ v_and_b32 v[v_tmp+3], 3, v[v_tmp+5] ; => iML1
+ v_lshl_or_b32 v[v_gemm_in], v[v_tmp+2], 4, v[v_tmp] ; in (without repeat)
+ v_lshl_or_b32 v[v_gemm_im], v[v_tmp+3], 4, v[v_tmp+1] ; im (without repeat)
+
+ ; remapping, gemm_im => C0xC1:32x4, gemm_in => N0xB0xB1xN1:32x1x1x4
+ v_and_b32 v[v_in_in1], 3, v[v_gemm_in]
+ v_lshrrev_b32 v[v_tmp], 2, v[v_gemm_in]
+ v_mov_b32 v[v_in_ib1], 0
+ v_mov_b32 v[v_in_ib0], 0
+ v_and_b32 v[v_in_in0], 31, v[v_tmp]
+ v_lshrrev_b32 v[v_tmp], 5, v[v_tmp]
+ v_and_b32 v[v_in_ic1], 3, v[v_gemm_im]
+ v_lshrrev_b32 v[v_tmp], 2, v[v_gemm_im]
+ v_and_b32 v[v_in_ic0], 31, v[v_tmp]
+ v_lshrrev_b32 v[v_tmp], 5, v[v_tmp]
+
+ v_lshl_or_b32 v[v_in_ib1], v[v_in_ib0], 0, v[v_in_ib1]
+ v_lshl_or_b32 v[v_in_in1], v[v_in_in0], 2, v[v_in_in1]
+ v_lshl_or_b32 v[v_in_ic1], v[v_in_ic0], 2, v[v_in_ic1]
+ v_add_u32 v[v_in_ib1], s[s_block_gtc_ib], v[v_in_ib1]
+
+ .v_u32_div_vs v_in_dslice_ih, v_in_ib1, s_dslice_w, v_tmp, s_tmp
+ v_mul_lo_u32 v[v_tmp+1], s[s_dslice_w], v[v_in_dslice_ih]
+ v_sub_u32 v[v_in_dslice_iw], v[v_in_ib1], v[v_tmp+1]
+ v_add_u32 v[v_in_dslice_ih], s[s_dslice_h_left], v[v_in_dslice_ih]
+ v_add_u32 v[v_in_dslice_iw], s[s_dslice_w_left], v[v_in_dslice_iw]
+
+ ; dslice_h,dslice_y -> hip, dslice_w,dslice_x -> wip
+ s_mul_i32 s[s_tmp], s[s_dtile_iy], s[s_dilation_h]
+ v_mul_lo_u32 v[v_tmp], s[s_stride_h], v[v_in_dslice_ih]
+ v_add_u32 v[v_tmp], s[s_tmp], v[v_tmp]
+ s_mul_i32 s[s_tmp+1], s[s_dtile_ix], s[s_dilation_w]
+ v_mul_lo_u32 v[v_tmp+1], s[s_stride_w], v[v_in_dslice_iw]
+ v_add_u32 v[v_tmp+1], s[s_tmp+1], v[v_tmp+1]
+ ; v_tmp: hip, v_tmp+1: wip
+
+ ; hip->h, wip->w
+ v_sub_i32 v[v_in_ihi], v[v_tmp], s[s_pad_h]
+ v_sub_i32 v[v_in_iwi], v[v_tmp+1], s[s_pad_w]
+
+ .v_set_flag_hw v_in_flag, v_in_ihi, v_in_iwi, s_hi, s_wi
+
+ ; input offset
+ s_lshl_b32 s[s_tmp+3], s[s_block_gtc_in], 2
+ s_mul_i32 s[s_tmp], s[s_in_stride_n], s[s_tmp+3]
+ s_mul_hi_u32 s[s_tmp+1], s[s_in_stride_n], s[s_tmp+3]
+ s_add_u32 s[s_p_in], s[s_p_in], s[s_tmp]
+ s_addc_u32 s[s_p_in+1], s[s_p_in+1], s[s_tmp+1]
+
+ s_lshl_b32 s[s_tmp+3], s[s_block_gtc_ic], 2
+ s_mul_i32 s[s_tmp], s[s_in_stride_c], s[s_tmp+3]
+ s_mul_hi_u32 s[s_tmp+1], s[s_in_stride_c], s[s_tmp+3]
+ s_add_u32 s[s_p_in], s[s_p_in], s[s_tmp]
+ s_addc_u32 s[s_p_in+1], s[s_p_in+1], s[s_tmp+1]
+ v_mul_lo_u32 v[v_tmp], s[s_in_stride_n], v[v_in_in1]
+ v_mul_lo_u32 v[v_tmp+1], s[s_in_stride_c], v[v_in_ic1]
+ v_add_u32 v[v_tmp], v[v_tmp], v[v_tmp+1]
+ v_mul_lo_u32 v[v_tmp+1], s[s_wi], v[v_in_ihi]
+ v_add3_u32 v[v_in_os], v[v_tmp], v[v_tmp+1], v[v_in_iwi]
+ v_lshlrev_b32 v[v_in_os], 2, v[v_in_os]
+
+ s_lshl_b32 s[s_in_stride_n], s[s_in_stride_n], 2
+ s_lshl_b32 s[s_in_stride_nr], s[s_in_stride_n], 6
+
+ s_lshl_b32 s[s_in_stride_c], s[s_in_stride_c], 2
+ s_lshl_b32 s[s_in_stride_cr], s[s_in_stride_c], 6
+
+ v_lshrrev_b32 v[v_gtc_in0], 2, v[v_gtc_in1]
+ ; LDS store, order out: K0xK1xExN0xB0xB1xN1: 1x1x1x2x1x1x4, 16x1x1x16x1x1x1
+ v_lshlrev_b32 v[v_tmp], 2, v[v_gtc_ib1]
+ v_lshl_or_b32 v[v_tmp], v[v_gtc_in0], 2, v[v_tmp]
+ v_lshl_or_b32 v[v_tmp+1], v[v_gtc_ik1], 7, v[v_tmp]
+ v_lshlrev_b32 v[v_sst_b_os], 2, v[v_tmp+1]
+
+ ; LDS store, order wei not shuffled, wei: K0xK1xExC0xC1: 1x1x1x2x4, 16x1x1x16x1
+ v_lshl_or_b32 v[v_tmp], v[v_gtc_ik1], 7, v[v_gtc_ic1]
+ v_lshlrev_b32 v[v_sst_a_os], 2, v[v_tmp]
+ v_add_u32 v[v_sst_b_os], 8192, v[v_sst_b_os]
+
+ ; LDS load
+ v_lshlrev_b32 v[v_sld_b_os], 2, v[v_gemm_in]
+ v_lshlrev_b32 v[v_sld_a_os], 2, v[v_gemm_im]
+ v_add_u32 v[v_sld_b_os], 8192, v[v_sld_b_os]
+
+ s_mul_i32 s[s_knum], s[s_stride_dslice_yx], s[s_k]
+ s_lshl_b32 s[s_out_move_slice_stride_k], s[s_out_stride_k], 6 ; 16x
+ s_lshl_b32 s[s_wei_move_slice_stride_k], s[s_wei_stride_k], 6 ; 16x
+ s_mov_b32 s[s_move_slice_k_ik], 0
+ s_mov_b32 s[s_move_slice_k_idsy], 0
+ s_mov_b32 s[s_move_slice_k_idsx], 0
+
+ .v_clear_nc v_c, 64
+ ; start FMA loop, 8x8 thread tile with 4x4 sub-tile
+ s_waitcnt vmcnt(8)
+ .v_sst_so0_2x4_b32_v4_st16 v_gld_b, v_sst_b_os
+
+ s_waitcnt vmcnt(0)
+ .v_sst_so0_2x4_b32_v4_st16 v_gld_a, v_sst_a_os
+
+ s_sub_i32 s[s_kitr], s[s_knum], 16
+ s_cmp_gt_i32 s[s_kitr], 0
+ s_cbranch_scc0 L_igemm_bwd_gtc_bt128x128x16_tt8x8_gm2x4x4_gn2x4x4_ta1x1x1x2x4_16x1x1x16x1_tb1x1x1x2x4x1x1_16x1x1x16x1x1x1_end
+
+ .s_bwd_gtc_move_slice_window_k_dsy_dsx s_move_slice_k_ik, s_move_slice_k_idsy, s_move_slice_k_idsx, s_dslice_y, s_dslice_x, 16, 1, 1, v_out_os_base, v_wei_os_base, 
s_out_move_slice_stride_k, s_wei_move_slice_stride_k + .v_bwd_gtc_out_update_hw v_out_iho, v_out_iwo, v_out_dslice_ih, v_out_dslice_iw, s_move_slice_k_idsy, s_move_slice_k_idsx, s_dtile_dy, s_dtile_dx, s_tmp + .v_bwd_gtc_out_update_os v_out_os, v_out_os_base, v_out_iho, v_out_iwo, s_wo, v_tmp + .v_set_flag_hw v_out_flag, v_out_iho, v_out_iwo, s_ho, s_wo + .v_bwd_gtc_wei_update_yx v_wei_iy, v_wei_ix, s_move_slice_k_idsy, s_move_slice_k_idsx, s_dtile_y, s_dtile_x, v_dtile_iy, v_dtile_ix, s_tmp + .v_bwd_gtc_wei_update_os v_wei_os, v_wei_os_base, v_wei_iy, v_wei_ix, s_x, v_tmp + v_xor_b32 v[v_sst_b_os], 0x4000, v[v_sst_b_os] ; switch double buffer b store + v_xor_b32 v[v_sst_a_os], 0x4000, v[v_sst_a_os] ; switch double buffer a store + s_waitcnt lgkmcnt(0) + s_barrier + + ; load output + v_cmp_eq_u32 vcc, 1, v[v_out_flag] + s_and_saveexec_b64 s[s_tmp+4:s_tmp+5], vcc + .v_gld_2x4_b32_v1 v_gld_b, s_p_out, v_out_os, s_out_stride_n0, s_out_stride_n, s_tmp + s_or_b64 exec, exec, s[s_tmp+4:s_tmp+5] + ; load weight + .v_gld_2x4_b32_v1 v_gld_a, s_p_wei, v_wei_os, s_wei_stride_c0, s_wei_stride_c, s_tmp + +L_igemm_bwd_gtc_bt128x128x16_tt8x8_gm2x4x4_gn2x4x4_ta1x1x1x2x4_16x1x1x16x1_tb1x1x1x2x4x1x1_16x1x1x16x1x1x1_fma_body: + ; do fma accumulate with unroll 16 + ds_read_b128 v[v_a:v_a+3], v[v_sld_a_os] + ds_read_b128 v[v_b:v_b+3], v[v_sld_b_os] + ds_read_b128 v[v_b+4:v_b+4+3], v[v_sld_b_os] offset:256 + ds_read_b128 v[v_a+4:v_a+4+3], v[v_sld_a_os] offset:256 + .itr_k = 0 + .rept 15 + s_waitcnt lgkmcnt(2) + .v_fma_4x4_s8 v_c,v_a,v_b + s_waitcnt lgkmcnt(1) + .v_fma_4x4_s8 v_c+4,v_a,v_b+4 + ds_read_b128 v[v_a:v_a+3], v[v_sld_a_os] offset:0+(.itr_k+1)*512 + s_waitcnt lgkmcnt(1) + .v_fma_4x4_s8 v_c+32,v_a+4,v_b + ds_read_b128 v[v_b:v_b+3], v[v_sld_b_os] offset:0+(.itr_k+1)*512 + .v_fma_4x4_s8 v_c+36,v_a+4,v_b+4 + + ds_read_b128 v[v_b+4:v_b+4+3], v[v_sld_b_os] offset:0+(.itr_k+1)*512+256 + ds_read_b128 v[v_a+4:v_a+4+3], v[v_sld_a_os] offset:0+(.itr_k+1)*512+256 + .itr_k = .itr_k + 1 + .endr + + ; last unroll + v_xor_b32 v[v_sld_b_os], 16384, v[v_sld_b_os] ; switch double buffer b load + v_xor_b32 v[v_sld_a_os], 16384, v[v_sld_a_os] ; switch double buffer a load + s_waitcnt lgkmcnt(2) + .v_fma_4x4_s8 v_c,v_a,v_b + s_waitcnt lgkmcnt(1) + .v_fma_4x4_s8 v_c+4,v_a,v_b+4 + s_waitcnt vmcnt(8) + .v_sst_so0_2x4_b32_v4_st16 v_gld_b, v_sst_b_os + s_waitcnt vmcnt(0) + .v_sst_so0_2x4_b32_v4_st16 v_gld_a, v_sst_a_os + s_sub_i32 s[s_kitr], s[s_kitr], 16 + s_cmp_gt_i32 s[s_kitr], 0 + s_cbranch_scc0 L_igemm_bwd_gtc_bt128x128x16_tt8x8_gm2x4x4_gn2x4x4_ta1x1x1x2x4_16x1x1x16x1_tb1x1x1x2x4x1x1_16x1x1x16x1x1x1_fma_finishing + .s_bwd_gtc_move_slice_window_k_dsy_dsx s_move_slice_k_ik, s_move_slice_k_idsy, s_move_slice_k_idsx, s_dslice_y, s_dslice_x, 16, 1, 1, v_out_os_base, v_wei_os_base, s_out_move_slice_stride_k, s_wei_move_slice_stride_k + .v_bwd_gtc_out_update_hw v_out_iho, v_out_iwo, v_out_dslice_ih, v_out_dslice_iw, s_move_slice_k_idsy, s_move_slice_k_idsx, s_dtile_dy, s_dtile_dx, s_tmp + .v_bwd_gtc_out_update_os v_out_os, v_out_os_base, v_out_iho, v_out_iwo, s_wo, v_tmp + .v_set_flag_hw v_out_flag, v_out_iho, v_out_iwo, s_ho, s_wo + .v_bwd_gtc_wei_update_yx v_wei_iy, v_wei_ix, s_move_slice_k_idsy, s_move_slice_k_idsx, s_dtile_y, s_dtile_x, v_dtile_iy, v_dtile_ix, s_tmp + .v_bwd_gtc_wei_update_os v_wei_os, v_wei_os_base, v_wei_iy, v_wei_ix, s_x, v_tmp + s_waitcnt lgkmcnt(4) + .v_fma_4x4_s8 v_c+32,v_a+4,v_b + v_xor_b32 v[v_sst_b_os], 16384, v[v_sst_b_os] ; switch double buffer b store + v_xor_b32 v[v_sst_a_os], 16384, v[v_sst_a_os] 
; switch double buffer a store + s_waitcnt lgkmcnt(0) + s_barrier + ; load output + v_cmp_eq_u32 vcc, 1, v[v_out_flag] + s_and_saveexec_b64 s[s_tmp+4:s_tmp+5], vcc + .v_gld_2x4_b32_v1 v_gld_b, s_p_out, v_out_os, s_out_stride_n0, s_out_stride_n, s_tmp + s_or_b64 exec, exec, s[s_tmp+4:s_tmp+5] + ; load weight + .v_gld_2x4_b32_v1 v_gld_a, s_p_wei, v_wei_os, s_wei_stride_c0, s_wei_stride_c, s_tmp + .v_fma_4x4_s8 v_c+36,v_a+4,v_b+4 + + s_branch L_igemm_bwd_gtc_bt128x128x16_tt8x8_gm2x4x4_gn2x4x4_ta1x1x1x2x4_16x1x1x16x1_tb1x1x1x2x4x1x1_16x1x1x16x1x1x1_fma_body +L_igemm_bwd_gtc_bt128x128x16_tt8x8_gm2x4x4_gn2x4x4_ta1x1x1x2x4_16x1x1x16x1_tb1x1x1x2x4x1x1_16x1x1x16x1x1x1_fma_finishing: + s_waitcnt lgkmcnt(4) + .v_fma_4x4_s8 v_c+32,v_a+4,v_b + .v_fma_4x4_s8 v_c+36,v_a+4,v_b+4 +L_igemm_bwd_gtc_bt128x128x16_tt8x8_gm2x4x4_gn2x4x4_ta1x1x1x2x4_16x1x1x16x1_tb1x1x1x2x4x1x1_16x1x1x16x1x1x1_end: + s_waitcnt lgkmcnt(0) + s_barrier + ds_read_b128 v[v_a:v_a+3], v[v_sld_a_os] + ds_read_b128 v[v_b:v_b+3], v[v_sld_b_os] + ds_read_b128 v[v_b+4:v_b+4+3], v[v_sld_b_os] offset:256 + ds_read_b128 v[v_a+4:v_a+4+3], v[v_sld_a_os] offset:256 + .itr_k = 0 + .rept 15 + s_waitcnt lgkmcnt(2) + .v_fma_4x4_s8 v_c,v_a,v_b + s_waitcnt lgkmcnt(1) + .v_fma_4x4_s8 v_c+4,v_a,v_b+4 + ds_read_b128 v[v_a:v_a+3], v[v_sld_a_os] offset:0+(.itr_k+1)*512 + s_waitcnt lgkmcnt(1) + .v_fma_4x4_s8 v_c+32,v_a+4,v_b + ds_read_b128 v[v_b:v_b+3], v[v_sld_b_os] offset:0+(.itr_k+1)*512 + .v_fma_4x4_s8 v_c+36,v_a+4,v_b+4 + + ds_read_b128 v[v_b+4:v_b+4+3], v[v_sld_b_os] offset:0+(.itr_k+1)*512+256 + ds_read_b128 v[v_a+4:v_a+4+3], v[v_sld_a_os] offset:0+(.itr_k+1)*512+256 + .itr_k = .itr_k + 1 + .endr + + ; last unroll + s_waitcnt lgkmcnt(2) + .v_fma_4x4_s8 v_c,v_a,v_b + s_waitcnt lgkmcnt(1) + .v_fma_4x4_s8 v_c+4,v_a,v_b+4 + s_waitcnt lgkmcnt(0) + .v_fma_4x4_s8 v_c+32,v_a+4,v_b + .v_fma_4x4_s8 v_c+36,v_a+4,v_b+4 + + v_cmpx_eq_u32 vcc, 1, v[v_in_flag] + s_cbranch_execz L_igemm_bwd_gtc_bt128x128x16_tt8x8_gm2x4x4_gn2x4x4_ta1x1x1x2x4_16x1x1x16x1_tb1x1x1x2x4x1x1_16x1x1x16x1x1x1_out + s_mov_b32 s[s_tmp], 0 + s_mov_b32 s[s_tmp+1], 0 + s_mov_b32 s[s_tmp+2], 0 + s_mov_b32 s[s_tmp+3], 0 + .v_write4d_strided v_c,s_p_in,v_in_os,s_in_stride_n,s_in_stride_nr,s_in_stride_c,s_in_stride_cr,s_tmp,4,2,4,2 +L_igemm_bwd_gtc_bt128x128x16_tt8x8_gm2x4x4_gn2x4x4_ta1x1x1x2x4_16x1x1x16x1_tb1x1x1x2x4x1x1_16x1x1x16x1x1x1_out: + s_endpgm +.rodata +.p2align 6 +.amdhsa_kernel igemm_bwd_gtc_bt128x128x16_tt8x8_gm2x4x4_gn2x4x4_ta1x1x1x2x4_16x1x1x16x1_tb1x1x1x2x4x1x1_16x1x1x16x1x1x1 + .amdhsa_group_segment_fixed_size 32768 + .amdhsa_user_sgpr_kernarg_segment_ptr 1 + .amdhsa_system_sgpr_workgroup_id_x 1 + .amdhsa_system_vgpr_workitem_id 0 + .amdhsa_next_free_vgpr 122 + .amdhsa_next_free_sgpr 78 + .amdhsa_ieee_mode 0 + .amdhsa_dx10_clamp 0 +.end_amdhsa_kernel + +;---------------------------------------------------------- +; starting of kernel igemm_bwd_gtc_bt64x64x8_tt8x8_gm2x4x2_gn2x4x2_ta1x2x1x1x4_4x1x1x16x1_tb1x2x1x1x4x1x1_4x1x1x16x1x1x1 +; gemm_m_per_block : 64 +; gemm_n_per_block : 64 +; gemm_k_per_block : 8 +; gemm_m_per_thread : 4 +; gemm_m_level0_cluster : 4 +; gemm_m_level1_cluster : 2 +; gemm_n_per_thread : 4 +; gemm_n_level0_cluster : 4 +; gemm_n_level1_cluster : 2 +; tensor_a_thread_lengths : [1, 2, 1, 1, 4] +; tensor_a_cluster_lengths : [4, 1, 1, 16, 1] +; tensor_b_thread_lengths : [1, 2, 1, 1, 4, 1, 1] +; tensor_b_cluster_lengths : [4, 1, 1, 16, 1, 1, 1] +; direction : bwd +; precision : fp32 +; opt_1x1 : 0 +; +; block_size : 64 +; thread_tile : 8x8 +; lds_total : 8192 +; +.set 
k_p_in, 0 +.set k_p_wei, 8 +.set k_p_out, 16 +.set k_hi, 24 +.set k_wi, 28 +.set k_n, 32 +.set k_k, 36 +.set k_c, 40 +.set k_ho, 44 +.set k_wo, 48 +.set k_stride_h, 52 +.set k_stride_w, 56 +.set k_dilation_h, 60 +.set k_dilation_w, 64 +.set k_pad_h, 68 +.set k_pad_w, 72 +.set k_y, 76 +.set k_x, 80 +.set k_dtile_iy, 84 +.set k_dtile_ix, 88 +.set k_dtile_dy, 92 +.set k_dtile_dx, 96 +.set k_dtile_y, 100 +.set k_dtile_x, 104 +.set k_dtile_h, 108 +.set k_dtile_w, 112 +.set k_dslice_y, 116 +.set k_dslice_x, 120 +.set k_dslice_h, 124 +.set k_dslice_w, 128 +.set k_dslice_h_left, 132 +.set k_dslice_w_left, 136 +.set k_pack0, 140 +.set k_end, 144 + +.set s_ka, 0 +.set s_bx, 2 +.set s_p_in, 4 +.set s_p_wei, 8 +.set s_p_out, 12 +.set s_hi, 16 +.set s_wi, 17 +.set s_n, 18 +.set s_k, 19 +.set s_c, 20 +.set s_ho, 21 +.set s_wo, 22 +.set s_stride_h, 23 +.set s_stride_w, 24 +.set s_dilation_h, 25 +.set s_dilation_w, 26 +.set s_pad_h, 27 +.set s_pad_w, 28 +.set s_y, 29 +.set s_x, 30 +.set s_dtile_iy, 31 +.set s_dtile_ix, 32 +.set s_dtile_dy, 33 +.set s_dtile_dx, 34 +.set s_dtile_y, 35 +.set s_dtile_x, 36 +.set s_dtile_h, 37 +.set s_dtile_w, 38 +.set s_dslice_y, 39 +.set s_dslice_x, 40 +.set s_dslice_h, 41 +.set s_dslice_w, 42 +.set s_dslice_h_left, 43 +.set s_dslice_w_left, 44 +.set s_out_stride_k, 45 +.set s_out_stride_k0, 46 +.set s_out_stride_n, 47 +.set s_out_stride_n0, 48 +.set s_out_stride_b0, 49 +.set s_out_move_slice_stride_k, 50 +.set s_in_stride_c, 51 +.set s_in_stride_cr, 52 +.set s_in_stride_n, 53 +.set s_in_stride_nr, 54 +.set s_wei_stride_c, 55 +.set s_wei_stride_c0, 56 +.set s_wei_stride_k, 57 +.set s_wei_stride_k0, 58 +.set s_wei_move_slice_stride_k, 59 +.set s_stride_dslice_hw, 60 +.set s_stride_dslice_yx, 61 +.set s_block_gtc_ib, 62 +.set s_block_gtc_ic, 63 +.set s_block_gtc_in, 64 +.set s_knum, 19 +.set s_move_slice_k_idsy, 0 +.set s_move_slice_k_idsx, 1 +.set s_move_slice_k_ik, 2 +.set s_kitr, 3 +.set s_tmp, 66 +.set s_end, 72 + +.set v_c, 0 +.set v_a, 64 +.set v_b, 72 +.set v_gld_a, 80 +.set v_gld_b, 88 +.set v_sst_a_os, 96 +.set v_sst_b_os, 97 +.set v_sld_a_os, 98 +.set v_sld_b_os, 99 +.set v_out_iho, 100 +.set v_out_iwo, 101 +.set v_out_dslice_ih, 102 +.set v_out_dslice_iw, 103 +.set v_out_os, 104 +.set v_out_os_base, 105 +.set v_wei_iy, 106 +.set v_wei_ix, 107 +.set v_dtile_iy, 108 +.set v_dtile_ix, 109 +.set v_wei_os, 110 +.set v_wei_os_base, 111 +.set v_out_flag, 112 +.set v_in_flag, 113 +.set v_in_os, 114 +.set v_gtc_ic0, 63 +.set v_gtc_ic1, 62 +.set v_gtc_in0, 61 +.set v_gtc_in1, 60 +.set v_gtc_ib0, 59 +.set v_gtc_ib1, 58 +.set v_gtc_ik0, 57 +.set v_gtc_ik1, 56 +.set v_gtc_ie, 55 +.set v_gemm_in, 54 +.set v_gemm_im, 53 +.set v_in_in0, 52 +.set v_in_in1, 51 +.set v_in_ib0, 50 +.set v_in_ib1, 49 +.set v_in_ic0, 48 +.set v_in_ic1, 47 +.set v_in_ihi, 46 +.set v_in_iwi, 45 +.set v_in_dslice_ih, 44 +.set v_in_dslice_iw, 43 +.set v_tmp, 116 +.set v_end, 122 + +.text +.globl igemm_bwd_gtc_bt64x64x8_tt8x8_gm2x4x2_gn2x4x2_ta1x2x1x1x4_4x1x1x16x1_tb1x2x1x1x4x1x1_4x1x1x16x1x1x1 +.p2align 8 +.type igemm_bwd_gtc_bt64x64x8_tt8x8_gm2x4x2_gn2x4x2_ta1x2x1x1x4_4x1x1x16x1_tb1x2x1x1x4x1x1_4x1x1x16x1x1x1,@function +igemm_bwd_gtc_bt64x64x8_tt8x8_gm2x4x2_gn2x4x2_ta1x2x1x1x4_4x1x1x16x1_tb1x2x1x1x4x1x1_4x1x1x16x1x1x1: + s_load_dwordx2 s[s_p_in+0:s_p_in+1], s[s_ka+0:s_ka+1], 0+k_p_in + s_load_dwordx2 s[s_p_wei+0:s_p_wei+1], s[s_ka+0:s_ka+1], 0+k_p_wei + s_load_dwordx2 s[s_p_out+0:s_p_out+1], s[s_ka+0:s_ka+1], 0+k_p_out + s_load_dwordx16 s[s_hi+0:s_hi+15], s[s_ka+0:s_ka+1], 0+k_hi + s_load_dwordx8 
s[s_dtile_ix+0:s_dtile_ix+7], s[s_ka+0:s_ka+1], 0+k_dtile_ix + s_load_dwordx4 s[s_dslice_x+0:s_dslice_x+3], s[s_ka+0:s_ka+1], 0+k_dslice_x + s_load_dword s[s_dslice_w_left], s[s_ka+0:s_ka+1], 0+k_dslice_w_left + v_mov_b32 v[v_tmp], v0 + ; output: K0xK1xExN0xN1xB0xB1: 1x2x1x1x4x1x1, slice 4x1x1x16x1x1x1 + v_and_b32 v[v_gtc_in0], 15, v[v_tmp] + v_lshrrev_b32 v[v_tmp], 4, v[v_tmp] + v_and_b32 v[v_gtc_ik0], 3, v[v_tmp] + + v_mov_b32 v[v_tmp], v0 + ; wei: K0xK1xExC0xC1: 1x2x1x1x4, slice 4x1x1x16x1 + v_and_b32 v[v_gtc_ic0], 15, v[v_tmp] + + s_mov_b32 s[s_p_in + 2], 0xffffffff + s_mov_b32 s[s_p_in + 3], 0x27000 + s_mov_b32 s[s_p_wei + 2], 0xffffffff + s_mov_b32 s[s_p_wei + 3], 0x27000 + s_mov_b32 s[s_p_out + 2], 0xffffffff + s_mov_b32 s[s_p_out + 3], 0x27000 + s_waitcnt lgkmcnt(0) + + ; calculate index + s_mul_i32 s[s_out_stride_k], s[s_ho], s[s_wo] + s_mul_i32 s[s_out_stride_n], s[s_k], s[s_out_stride_k] + s_mul_i32 s[s_in_stride_c], s[s_hi], s[s_wi] + s_mul_i32 s[s_in_stride_n], s[s_c], s[s_in_stride_c] + s_mul_i32 s[s_wei_stride_c], s[s_y], s[s_x] + s_mul_i32 s[s_wei_stride_k], s[s_c], s[s_wei_stride_c] + s_mul_i32 s[s_stride_dslice_hw], s[s_dslice_h], s[s_dslice_w] + s_mul_i32 s[s_stride_dslice_yx], s[s_dslice_y], s[s_dslice_x] + + ; N0xN1xB0xB1, gemm_m_per_block:64, gemm_n_per_block:64 + s_mul_i32 s[s_tmp], s[s_stride_dslice_hw], s[s_n] + s_lshr_b32 s[0], s[s_tmp], 6 + .v_u32_div_ss v_tmp+5, s_bx, 0, v_tmp, s_tmp + v_readfirstlane_b32 s[s_tmp+4], v[v_tmp+5] ; gemm_m, + s_mul_i32 s[s_tmp+2], s[s_tmp+4], s[0] + s_sub_i32 s[s_tmp+5], s[s_bx], s[s_tmp+2] ; gemm_n, cnt + s_lshl_b32 s[s_block_gtc_ic], s[s_tmp+4], 6 + s_mov_b32 s[0], s[s_stride_dslice_hw] ; B:1 per block, total num of B + .v_u32_div_ss v_tmp+5, s_tmp+5, 0, v_tmp, s_tmp + v_readfirstlane_b32 s[s_tmp], v[v_tmp+5] ; => N + s_mul_i32 s[s_tmp+2], s[s_tmp], s[0] + s_sub_i32 s[s_tmp+1], s[s_tmp+5], s[s_tmp+2] ; => B + s_lshl_b32 s[s_block_gtc_in], s[s_tmp], 6 + s_mov_b32 s[s_block_gtc_ib], s[s_tmp+1] + + v_mov_b32 v[v_gtc_ib1], 0 + v_lshlrev_b32 v[v_gtc_in1], 2, v[v_gtc_in0] + v_lshlrev_b32 v[v_gtc_ik1], 1, v[v_gtc_ik0] + + v_add_u32 v[v_tmp+5], s[s_block_gtc_ib], v[v_gtc_ib1] + ; calculate output transform, B -> dslice_h*dslice_w. 
always use ib1
+ .v_u32_div_vs v_out_dslice_ih, v_tmp+5, s_dslice_w, v_tmp, s_tmp
+ v_mul_lo_u32 v[v_tmp], s[s_dslice_w], v[v_out_dslice_ih]
+ v_sub_u32 v[v_out_dslice_iw], v[v_tmp+5], v[v_tmp]
+ ; iHTildaLeft, iWTildaLeft
+ v_add_u32 v[v_out_dslice_ih], s[s_dslice_h_left], v[v_out_dslice_ih]
+ v_add_u32 v[v_out_dslice_iw], s[s_dslice_w_left], v[v_out_dslice_iw]
+ ; dslice_y,dslice_h -> oh, dslice_x,dslice_w -> ow
+ v_mov_b32 v[v_out_iho], v[v_out_dslice_ih]
+ v_mov_b32 v[v_out_iwo], v[v_out_dslice_iw]
+ v_mul_lo_u32 v[v_tmp], s[s_out_stride_k], v[v_gtc_ik1]
+ v_mul_lo_u32 v[v_tmp+1], s[s_out_stride_n], v[v_gtc_in1]
+ v_add_lshl_u32 v[v_out_os_base], v[v_tmp], v[v_tmp+1], 2
+ ; n to statically accumulate into base pointer
+ s_lshl_b32 s[s_tmp+3], s[s_block_gtc_in], 2
+ s_mul_i32 s[s_tmp], s[s_out_stride_n], s[s_tmp+3]
+ s_mul_hi_u32 s[s_tmp+1], s[s_out_stride_n], s[s_tmp+3]
+ s_add_u32 s[s_p_out], s[s_p_out], s[s_tmp]
+ s_addc_u32 s[s_p_out+1], s[s_p_out+1], s[s_tmp+1]
+
+ .v_bwd_gtc_out_update_os v_out_os, v_out_os_base, v_out_iho, v_out_iwo, s_wo, v_tmp
+ .v_set_flag_hw v_out_flag, v_out_iho, v_out_iwo, s_ho, s_wo
+
+ s_lshl_b32 s[s_out_stride_k], s[s_out_stride_k], 2
+ s_lshl_b32 s[s_out_stride_n], s[s_out_stride_n], 2
+
+ ; load output
+ v_cmp_eq_u32 vcc, 1, v[v_out_flag]
+ s_and_saveexec_b64 s[s_tmp+4:s_tmp+5], vcc
+ .v_gld_2x4_b32_v1 v_gld_b, s_p_out, v_out_os, s_out_stride_k, s_out_stride_n, s_tmp
+ s_or_b64 exec, exec, s[s_tmp+4:s_tmp+5]
+
+ v_lshlrev_b32 v[v_gtc_ic1], 2, v[v_gtc_ic0]
+ v_add_u32 v[v_tmp+5], s[s_block_gtc_ic], v[v_gtc_ic1]
+ v_mov_b32 v[v_dtile_iy], s[s_dtile_iy]
+ v_mov_b32 v[v_dtile_ix], s[s_dtile_ix]
+ v_mov_b32 v[v_wei_iy], s[s_dtile_iy]
+ v_mov_b32 v[v_wei_ix], s[s_dtile_ix]
+
+ ; calculate wei offset
+ v_mul_lo_u32 v[v_tmp], s[s_wei_stride_c], v[v_tmp+5]
+ v_mul_lo_u32 v[v_tmp+1], s[s_wei_stride_k], v[v_gtc_ik1]
+ v_add_lshl_u32 v[v_wei_os_base], v[v_tmp], v[v_tmp+1], 2
+ .v_bwd_gtc_wei_update_os v_wei_os, v_wei_os_base, v_wei_iy, v_wei_ix, s_x, v_tmp
+
+ s_lshl_b32 s[s_wei_stride_k], s[s_wei_stride_k], 2
+ s_lshl_b32 s[s_wei_stride_c], s[s_wei_stride_c], 2
+
+ ; load weight
+ .v_gld_2x4_b32_v1 v_gld_a, s_p_wei, v_wei_os, s_wei_stride_k, s_wei_stride_c, s_tmp
+
+ ; c thread mapping
+ ; -> MR x NR x ML1 x NL1 x ML0 x NL0
+ ; cluster 1 x 1 x 2 x 2 x 4 x 4
+ ; perthrd 2 x 2 x 1 x 1 x 4 x 4
+ v_and_b32 v[v_tmp], 3, v0
+ v_lshlrev_b32 v[v_tmp], 2, v[v_tmp] ; => iNL0
+ v_lshrrev_b32 v[v_tmp+5], 2, v0
+ v_and_b32 v[v_tmp+1], 3, v[v_tmp+5]
+ v_lshlrev_b32 v[v_tmp+1], 2, v[v_tmp+1] ; => iML0
+ v_lshrrev_b32 v[v_tmp+5], 2, v[v_tmp+5]
+ v_and_b32 v[v_tmp+2], 1, v[v_tmp+5] ; => iNL1
+ v_lshrrev_b32 v[v_tmp+5], 1, v[v_tmp+5]
+ v_and_b32 v[v_tmp+3], 1, v[v_tmp+5] ; => iML1
+ v_lshl_or_b32 v[v_gemm_in], v[v_tmp+2], 4, v[v_tmp] ; in (without repeat)
+ v_lshl_or_b32 v[v_gemm_im], v[v_tmp+3], 4, v[v_tmp+1] ; im (without repeat)
+
+ ; remapping, gemm_im => C0xC1:16x4, gemm_in => N0xB0xB1xN1:16x1x1x4
+ v_and_b32 v[v_in_in1], 3, v[v_gemm_in]
+ v_lshrrev_b32 v[v_tmp], 2, v[v_gemm_in]
+ v_mov_b32 v[v_in_ib1], 0
+ v_mov_b32 v[v_in_ib0], 0
+ v_and_b32 v[v_in_in0], 15, v[v_tmp]
+ v_lshrrev_b32 v[v_tmp], 4, v[v_tmp]
+ v_and_b32 v[v_in_ic1], 3, v[v_gemm_im]
+ v_lshrrev_b32 v[v_tmp], 2, v[v_gemm_im]
+ v_and_b32 v[v_in_ic0], 15, v[v_tmp]
+ v_lshrrev_b32 v[v_tmp], 4, v[v_tmp]
+
+ v_lshl_or_b32 v[v_in_ib1], v[v_in_ib0], 0, v[v_in_ib1]
+ v_lshl_or_b32 v[v_in_in1], v[v_in_in0], 2, v[v_in_in1]
+ v_lshl_or_b32 v[v_in_ic1], v[v_in_ic0], 2, v[v_in_ic1]
+ v_add_u32 v[v_in_ib1], s[s_block_gtc_ib], v[v_in_ib1]
+
+ .v_u32_div_vs v_in_dslice_ih, v_in_ib1, s_dslice_w, v_tmp, s_tmp
+ v_mul_lo_u32 v[v_tmp+1], s[s_dslice_w], v[v_in_dslice_ih]
+ v_sub_u32 v[v_in_dslice_iw], v[v_in_ib1], v[v_tmp+1]
+ v_add_u32 v[v_in_dslice_ih], s[s_dslice_h_left], v[v_in_dslice_ih]
+ v_add_u32 v[v_in_dslice_iw], s[s_dslice_w_left], v[v_in_dslice_iw]
+
+ ; dslice_h,dslice_y -> hip, dslice_w,dslice_x -> wip
+ s_mul_i32 s[s_tmp], s[s_dtile_iy], s[s_dilation_h]
+ v_mul_lo_u32 v[v_tmp], s[s_stride_h], v[v_in_dslice_ih]
+ v_add_u32 v[v_tmp], s[s_tmp], v[v_tmp]
+ s_mul_i32 s[s_tmp+1], s[s_dtile_ix], s[s_dilation_w]
+ v_mul_lo_u32 v[v_tmp+1], s[s_stride_w], v[v_in_dslice_iw]
+ v_add_u32 v[v_tmp+1], s[s_tmp+1], v[v_tmp+1]
+ ; v_tmp: hip, v_tmp+1: wip
+
+ ; hip->h, wip->w
+ v_sub_i32 v[v_in_ihi], v[v_tmp], s[s_pad_h]
+ v_sub_i32 v[v_in_iwi], v[v_tmp+1], s[s_pad_w]
+
+ .v_set_flag_hw v_in_flag, v_in_ihi, v_in_iwi, s_hi, s_wi
+
+ ; input offset
+ s_lshl_b32 s[s_tmp+3], s[s_block_gtc_in], 2
+ s_mul_i32 s[s_tmp], s[s_in_stride_n], s[s_tmp+3]
+ s_mul_hi_u32 s[s_tmp+1], s[s_in_stride_n], s[s_tmp+3]
+ s_add_u32 s[s_p_in], s[s_p_in], s[s_tmp]
+ s_addc_u32 s[s_p_in+1], s[s_p_in+1], s[s_tmp+1]
+
+ s_lshl_b32 s[s_tmp+3], s[s_block_gtc_ic], 2
+ s_mul_i32 s[s_tmp], s[s_in_stride_c], s[s_tmp+3]
+ s_mul_hi_u32 s[s_tmp+1], s[s_in_stride_c], s[s_tmp+3]
+ s_add_u32 s[s_p_in], s[s_p_in], s[s_tmp]
+ s_addc_u32 s[s_p_in+1], s[s_p_in+1], s[s_tmp+1]
+ v_mul_lo_u32 v[v_tmp], s[s_in_stride_n], v[v_in_in1]
+ v_mul_lo_u32 v[v_tmp+1], s[s_in_stride_c], v[v_in_ic1]
+ v_add_u32 v[v_tmp], v[v_tmp], v[v_tmp+1]
+ v_mul_lo_u32 v[v_tmp+1], s[s_wi], v[v_in_ihi]
+ v_add3_u32 v[v_in_os], v[v_tmp], v[v_tmp+1], v[v_in_iwi]
+ v_lshlrev_b32 v[v_in_os], 2, v[v_in_os]
+
+ s_lshl_b32 s[s_in_stride_n], s[s_in_stride_n], 2
+ s_lshl_b32 s[s_in_stride_nr], s[s_in_stride_n], 5
+
+ s_lshl_b32 s[s_in_stride_c], s[s_in_stride_c], 2
+ s_lshl_b32 s[s_in_stride_cr], s[s_in_stride_c], 5
+
+ v_lshrrev_b32 v[v_gtc_in0], 2, v[v_gtc_in1]
+ ; LDS store, order out: K0xK1xExN0xB0xB1xN1: 1x2x1x1x1x1x4, 4x1x1x16x1x1x1
+ v_lshlrev_b32 v[v_tmp], 2, v[v_gtc_ib1]
+ v_lshl_or_b32 v[v_tmp], v[v_gtc_in0], 2, v[v_tmp]
+ v_lshl_or_b32 v[v_tmp+1], v[v_gtc_ik1], 6, v[v_tmp]
+ v_lshlrev_b32 v[v_sst_b_os], 2, v[v_tmp+1]
+
+ ; LDS store, order wei not shuffled, wei: K0xK1xExC0xC1: 1x2x1x1x4, 4x1x1x16x1
+ v_lshl_or_b32 v[v_tmp], v[v_gtc_ik1], 6, v[v_gtc_ic1]
+ v_lshlrev_b32 v[v_sst_a_os], 2, v[v_tmp]
+ v_add_u32 v[v_sst_b_os], 2048, v[v_sst_b_os]
+
+ ; LDS load
+ v_lshlrev_b32 v[v_sld_b_os], 2, v[v_gemm_in]
+ v_lshlrev_b32 v[v_sld_a_os], 2, v[v_gemm_im]
+ v_add_u32 v[v_sld_b_os], 2048, v[v_sld_b_os]
+
+ s_mul_i32 s[s_knum], s[s_stride_dslice_yx], s[s_k]
+ s_lshl_b32 s[s_out_move_slice_stride_k], s[s_out_stride_k], 3 ; 8x
+ s_lshl_b32 s[s_wei_move_slice_stride_k], s[s_wei_stride_k], 3 ; 8x
+ s_mov_b32 s[s_move_slice_k_ik], 0
+ s_mov_b32 s[s_move_slice_k_idsy], 0
+ s_mov_b32 s[s_move_slice_k_idsx], 0
+
+ .v_clear_nc v_c, 64
+ ; start FMA loop, 8x8 thread tile with 4x4 sub-tile
+ s_waitcnt vmcnt(8)
+ .v_sst_so0_2x4_b32_v4_st256 v_gld_b, v_sst_b_os
+
+ s_waitcnt vmcnt(0)
+ .v_sst_so0_2x4_b32_v4_st256 v_gld_a, v_sst_a_os
+
+ s_sub_i32 s[s_kitr], s[s_knum], 8
+ s_cmp_gt_i32 s[s_kitr], 0
+ s_cbranch_scc0 L_igemm_bwd_gtc_bt64x64x8_tt8x8_gm2x4x2_gn2x4x2_ta1x2x1x1x4_4x1x1x16x1_tb1x2x1x1x4x1x1_4x1x1x16x1x1x1_end
+
+ .s_bwd_gtc_move_slice_window_k_dsy_dsx s_move_slice_k_ik, s_move_slice_k_idsy, s_move_slice_k_idsx, s_dslice_y, s_dslice_x, 8, 1, 1, v_out_os_base, v_wei_os_base, s_out_move_slice_stride_k, 
s_wei_move_slice_stride_k + .v_bwd_gtc_out_update_hw v_out_iho, v_out_iwo, v_out_dslice_ih, v_out_dslice_iw, s_move_slice_k_idsy, s_move_slice_k_idsx, s_dtile_dy, s_dtile_dx, s_tmp + .v_bwd_gtc_out_update_os v_out_os, v_out_os_base, v_out_iho, v_out_iwo, s_wo, v_tmp + .v_set_flag_hw v_out_flag, v_out_iho, v_out_iwo, s_ho, s_wo + .v_bwd_gtc_wei_update_yx v_wei_iy, v_wei_ix, s_move_slice_k_idsy, s_move_slice_k_idsx, s_dtile_y, s_dtile_x, v_dtile_iy, v_dtile_ix, s_tmp + .v_bwd_gtc_wei_update_os v_wei_os, v_wei_os_base, v_wei_iy, v_wei_ix, s_x, v_tmp + v_xor_b32 v[v_sst_b_os], 0x1000, v[v_sst_b_os] ; switch double buffer b store + v_xor_b32 v[v_sst_a_os], 0x1000, v[v_sst_a_os] ; switch double buffer a store + s_waitcnt lgkmcnt(0) + s_barrier + + ; load output + v_cmp_eq_u32 vcc, 1, v[v_out_flag] + s_and_saveexec_b64 s[s_tmp+4:s_tmp+5], vcc + .v_gld_2x4_b32_v1 v_gld_b, s_p_out, v_out_os, s_out_stride_k, s_out_stride_n, s_tmp + s_or_b64 exec, exec, s[s_tmp+4:s_tmp+5] + ; load weight + .v_gld_2x4_b32_v1 v_gld_a, s_p_wei, v_wei_os, s_wei_stride_k, s_wei_stride_c, s_tmp + +L_igemm_bwd_gtc_bt64x64x8_tt8x8_gm2x4x2_gn2x4x2_ta1x2x1x1x4_4x1x1x16x1_tb1x2x1x1x4x1x1_4x1x1x16x1x1x1_fma_body: + ; do fma accumulate with unroll 8 + ds_read_b128 v[v_a:v_a+3], v[v_sld_a_os] + ds_read_b128 v[v_b:v_b+3], v[v_sld_b_os] + ds_read_b128 v[v_b+4:v_b+4+3], v[v_sld_b_os] offset:128 + ds_read_b128 v[v_a+4:v_a+4+3], v[v_sld_a_os] offset:128 + .itr_k = 0 + .rept 7 + s_waitcnt lgkmcnt(2) + .v_fma_4x4_s8 v_c,v_a,v_b + s_waitcnt lgkmcnt(1) + .v_fma_4x4_s8 v_c+4,v_a,v_b+4 + ds_read_b128 v[v_a:v_a+3], v[v_sld_a_os] offset:0+(.itr_k+1)*256 + s_waitcnt lgkmcnt(1) + .v_fma_4x4_s8 v_c+32,v_a+4,v_b + ds_read_b128 v[v_b:v_b+3], v[v_sld_b_os] offset:0+(.itr_k+1)*256 + .v_fma_4x4_s8 v_c+36,v_a+4,v_b+4 + + ds_read_b128 v[v_b+4:v_b+4+3], v[v_sld_b_os] offset:0+(.itr_k+1)*256+128 + ds_read_b128 v[v_a+4:v_a+4+3], v[v_sld_a_os] offset:0+(.itr_k+1)*256+128 + .itr_k = .itr_k + 1 + .endr + + ; last unroll + v_xor_b32 v[v_sld_b_os], 4096, v[v_sld_b_os] ; switch double buffer b load + v_xor_b32 v[v_sld_a_os], 4096, v[v_sld_a_os] ; switch double buffer a load + s_waitcnt lgkmcnt(2) + .v_fma_4x4_s8 v_c,v_a,v_b + s_waitcnt lgkmcnt(1) + .v_fma_4x4_s8 v_c+4,v_a,v_b+4 + s_waitcnt vmcnt(8) + .v_sst_so0_2x4_b32_v4_st256 v_gld_b, v_sst_b_os + s_waitcnt vmcnt(0) + .v_sst_so0_2x4_b32_v4_st256 v_gld_a, v_sst_a_os + s_sub_i32 s[s_kitr], s[s_kitr], 8 + s_cmp_gt_i32 s[s_kitr], 0 + s_cbranch_scc0 L_igemm_bwd_gtc_bt64x64x8_tt8x8_gm2x4x2_gn2x4x2_ta1x2x1x1x4_4x1x1x16x1_tb1x2x1x1x4x1x1_4x1x1x16x1x1x1_fma_finishing + .s_bwd_gtc_move_slice_window_k_dsy_dsx s_move_slice_k_ik, s_move_slice_k_idsy, s_move_slice_k_idsx, s_dslice_y, s_dslice_x, 8, 1, 1, v_out_os_base, v_wei_os_base, s_out_move_slice_stride_k, s_wei_move_slice_stride_k + .v_bwd_gtc_out_update_hw v_out_iho, v_out_iwo, v_out_dslice_ih, v_out_dslice_iw, s_move_slice_k_idsy, s_move_slice_k_idsx, s_dtile_dy, s_dtile_dx, s_tmp + .v_bwd_gtc_out_update_os v_out_os, v_out_os_base, v_out_iho, v_out_iwo, s_wo, v_tmp + .v_set_flag_hw v_out_flag, v_out_iho, v_out_iwo, s_ho, s_wo + .v_bwd_gtc_wei_update_yx v_wei_iy, v_wei_ix, s_move_slice_k_idsy, s_move_slice_k_idsx, s_dtile_y, s_dtile_x, v_dtile_iy, v_dtile_ix, s_tmp + .v_bwd_gtc_wei_update_os v_wei_os, v_wei_os_base, v_wei_iy, v_wei_ix, s_x, v_tmp + s_waitcnt lgkmcnt(4) + .v_fma_4x4_s8 v_c+32,v_a+4,v_b + v_xor_b32 v[v_sst_b_os], 4096, v[v_sst_b_os] ; switch double buffer b store + v_xor_b32 v[v_sst_a_os], 4096, v[v_sst_a_os] ; switch double buffer a store + s_waitcnt 
lgkmcnt(0) + s_barrier + ; load output + v_cmp_eq_u32 vcc, 1, v[v_out_flag] + s_and_saveexec_b64 s[s_tmp+4:s_tmp+5], vcc + .v_gld_2x4_b32_v1 v_gld_b, s_p_out, v_out_os, s_out_stride_k, s_out_stride_n, s_tmp + s_or_b64 exec, exec, s[s_tmp+4:s_tmp+5] + ; load weight + .v_gld_2x4_b32_v1 v_gld_a, s_p_wei, v_wei_os, s_wei_stride_k, s_wei_stride_c, s_tmp + .v_fma_4x4_s8 v_c+36,v_a+4,v_b+4 + + s_branch L_igemm_bwd_gtc_bt64x64x8_tt8x8_gm2x4x2_gn2x4x2_ta1x2x1x1x4_4x1x1x16x1_tb1x2x1x1x4x1x1_4x1x1x16x1x1x1_fma_body +L_igemm_bwd_gtc_bt64x64x8_tt8x8_gm2x4x2_gn2x4x2_ta1x2x1x1x4_4x1x1x16x1_tb1x2x1x1x4x1x1_4x1x1x16x1x1x1_fma_finishing: + s_waitcnt lgkmcnt(4) + .v_fma_4x4_s8 v_c+32,v_a+4,v_b + .v_fma_4x4_s8 v_c+36,v_a+4,v_b+4 +L_igemm_bwd_gtc_bt64x64x8_tt8x8_gm2x4x2_gn2x4x2_ta1x2x1x1x4_4x1x1x16x1_tb1x2x1x1x4x1x1_4x1x1x16x1x1x1_end: + s_waitcnt lgkmcnt(0) + s_barrier + ds_read_b128 v[v_a:v_a+3], v[v_sld_a_os] + ds_read_b128 v[v_b:v_b+3], v[v_sld_b_os] + ds_read_b128 v[v_b+4:v_b+4+3], v[v_sld_b_os] offset:128 + ds_read_b128 v[v_a+4:v_a+4+3], v[v_sld_a_os] offset:128 + .itr_k = 0 + .rept 7 + s_waitcnt lgkmcnt(2) + .v_fma_4x4_s8 v_c,v_a,v_b + s_waitcnt lgkmcnt(1) + .v_fma_4x4_s8 v_c+4,v_a,v_b+4 + ds_read_b128 v[v_a:v_a+3], v[v_sld_a_os] offset:0+(.itr_k+1)*256 + s_waitcnt lgkmcnt(1) + .v_fma_4x4_s8 v_c+32,v_a+4,v_b + ds_read_b128 v[v_b:v_b+3], v[v_sld_b_os] offset:0+(.itr_k+1)*256 + .v_fma_4x4_s8 v_c+36,v_a+4,v_b+4 + + ds_read_b128 v[v_b+4:v_b+4+3], v[v_sld_b_os] offset:0+(.itr_k+1)*256+128 + ds_read_b128 v[v_a+4:v_a+4+3], v[v_sld_a_os] offset:0+(.itr_k+1)*256+128 + .itr_k = .itr_k + 1 + .endr + + ; last unroll + s_waitcnt lgkmcnt(2) + .v_fma_4x4_s8 v_c,v_a,v_b + s_waitcnt lgkmcnt(1) + .v_fma_4x4_s8 v_c+4,v_a,v_b+4 + s_waitcnt lgkmcnt(0) + .v_fma_4x4_s8 v_c+32,v_a+4,v_b + .v_fma_4x4_s8 v_c+36,v_a+4,v_b+4 + + v_cmpx_eq_u32 vcc, 1, v[v_in_flag] + s_cbranch_execz L_igemm_bwd_gtc_bt64x64x8_tt8x8_gm2x4x2_gn2x4x2_ta1x2x1x1x4_4x1x1x16x1_tb1x2x1x1x4x1x1_4x1x1x16x1x1x1_out + s_mov_b32 s[s_tmp], 0 + s_mov_b32 s[s_tmp+1], 0 + s_mov_b32 s[s_tmp+2], 0 + s_mov_b32 s[s_tmp+3], 0 + .v_write4d_strided v_c,s_p_in,v_in_os,s_in_stride_n,s_in_stride_nr,s_in_stride_c,s_in_stride_cr,s_tmp,4,2,4,2 +L_igemm_bwd_gtc_bt64x64x8_tt8x8_gm2x4x2_gn2x4x2_ta1x2x1x1x4_4x1x1x16x1_tb1x2x1x1x4x1x1_4x1x1x16x1x1x1_out: + s_endpgm +.rodata +.p2align 6 +.amdhsa_kernel igemm_bwd_gtc_bt64x64x8_tt8x8_gm2x4x2_gn2x4x2_ta1x2x1x1x4_4x1x1x16x1_tb1x2x1x1x4x1x1_4x1x1x16x1x1x1 + .amdhsa_group_segment_fixed_size 8192 + .amdhsa_user_sgpr_kernarg_segment_ptr 1 + .amdhsa_system_sgpr_workgroup_id_x 1 + .amdhsa_system_vgpr_workitem_id 0 + .amdhsa_next_free_vgpr 122 + .amdhsa_next_free_sgpr 78 + .amdhsa_ieee_mode 0 + .amdhsa_dx10_clamp 0 +.end_amdhsa_kernel + +;---------------------------------------------------------- +; starting of kernel igemm_bwd_gtc_bt128x128x16_tt8x8_gm2x4x4_gn2x4x4_ta1x2x1x1x4_8x1x1x32x1_tb1x2x1x1x4x1x1_8x1x1x2x1x1x16 +; gemm_m_per_block : 128 +; gemm_n_per_block : 128 +; gemm_k_per_block : 16 +; gemm_m_per_thread : 4 +; gemm_m_level0_cluster : 4 +; gemm_m_level1_cluster : 4 +; gemm_n_per_thread : 4 +; gemm_n_level0_cluster : 4 +; gemm_n_level1_cluster : 4 +; tensor_a_thread_lengths : [1, 2, 1, 1, 4] +; tensor_a_cluster_lengths : [8, 1, 1, 32, 1] +; tensor_b_thread_lengths : [1, 2, 1, 1, 4, 1, 1] +; tensor_b_cluster_lengths : [8, 1, 1, 2, 1, 1, 16] +; direction : bwd +; precision : fp32 +; opt_1x1 : 0 +; +; block_size : 256 +; thread_tile : 8x8 +; lds_total : 32768 +; +.set k_p_in, 0 +.set k_p_wei, 8 +.set k_p_out, 16 +.set k_hi, 24 +.set k_wi, 
28 +.set k_n, 32 +.set k_k, 36 +.set k_c, 40 +.set k_ho, 44 +.set k_wo, 48 +.set k_stride_h, 52 +.set k_stride_w, 56 +.set k_dilation_h, 60 +.set k_dilation_w, 64 +.set k_pad_h, 68 +.set k_pad_w, 72 +.set k_y, 76 +.set k_x, 80 +.set k_dtile_iy, 84 +.set k_dtile_ix, 88 +.set k_dtile_dy, 92 +.set k_dtile_dx, 96 +.set k_dtile_y, 100 +.set k_dtile_x, 104 +.set k_dtile_h, 108 +.set k_dtile_w, 112 +.set k_dslice_y, 116 +.set k_dslice_x, 120 +.set k_dslice_h, 124 +.set k_dslice_w, 128 +.set k_dslice_h_left, 132 +.set k_dslice_w_left, 136 +.set k_pack0, 140 +.set k_end, 144 + +.set s_ka, 0 +.set s_bx, 2 +.set s_p_in, 4 +.set s_p_wei, 8 +.set s_p_out, 12 +.set s_hi, 16 +.set s_wi, 17 +.set s_n, 18 +.set s_k, 19 +.set s_c, 20 +.set s_ho, 21 +.set s_wo, 22 +.set s_stride_h, 23 +.set s_stride_w, 24 +.set s_dilation_h, 25 +.set s_dilation_w, 26 +.set s_pad_h, 27 +.set s_pad_w, 28 +.set s_y, 29 +.set s_x, 30 +.set s_dtile_iy, 31 +.set s_dtile_ix, 32 +.set s_dtile_dy, 33 +.set s_dtile_dx, 34 +.set s_dtile_y, 35 +.set s_dtile_x, 36 +.set s_dtile_h, 37 +.set s_dtile_w, 38 +.set s_dslice_y, 39 +.set s_dslice_x, 40 +.set s_dslice_h, 41 +.set s_dslice_w, 42 +.set s_dslice_h_left, 43 +.set s_dslice_w_left, 44 +.set s_out_stride_k, 45 +.set s_out_stride_k0, 46 +.set s_out_stride_n, 47 +.set s_out_stride_n0, 48 +.set s_out_stride_b0, 49 +.set s_out_move_slice_stride_k, 50 +.set s_in_stride_c, 51 +.set s_in_stride_cr, 52 +.set s_in_stride_n, 53 +.set s_in_stride_nr, 54 +.set s_wei_stride_c, 55 +.set s_wei_stride_c0, 56 +.set s_wei_stride_k, 57 +.set s_wei_stride_k0, 58 +.set s_wei_move_slice_stride_k, 59 +.set s_stride_dslice_hw, 60 +.set s_stride_dslice_yx, 61 +.set s_block_gtc_ib, 62 +.set s_block_gtc_ic, 63 +.set s_block_gtc_in, 64 +.set s_knum, 19 +.set s_move_slice_k_idsy, 0 +.set s_move_slice_k_idsx, 1 +.set s_move_slice_k_ik, 2 +.set s_kitr, 3 +.set s_tmp, 66 +.set s_end, 72 + +.set v_c, 0 +.set v_a, 64 +.set v_b, 72 +.set v_gld_a, 80 +.set v_gld_b, 88 +.set v_sst_a_os, 96 +.set v_sst_b_os, 97 +.set v_sld_a_os, 98 +.set v_sld_b_os, 99 +.set v_out_iho, 100 +.set v_out_iwo, 101 +.set v_out_dslice_ih, 102 +.set v_out_dslice_iw, 103 +.set v_out_os, 104 +.set v_out_os_base, 105 +.set v_wei_iy, 106 +.set v_wei_ix, 107 +.set v_dtile_iy, 108 +.set v_dtile_ix, 109 +.set v_wei_os, 110 +.set v_wei_os_base, 111 +.set v_out_flag, 112 +.set v_in_flag, 113 +.set v_in_os, 114 +.set v_gtc_ic0, 63 +.set v_gtc_ic1, 62 +.set v_gtc_in0, 61 +.set v_gtc_in1, 60 +.set v_gtc_ib0, 59 +.set v_gtc_ib1, 58 +.set v_gtc_ik0, 57 +.set v_gtc_ik1, 56 +.set v_gtc_ie, 55 +.set v_gemm_in, 54 +.set v_gemm_im, 53 +.set v_in_in0, 52 +.set v_in_in1, 51 +.set v_in_ib0, 50 +.set v_in_ib1, 49 +.set v_in_ic0, 48 +.set v_in_ic1, 47 +.set v_in_ihi, 46 +.set v_in_iwi, 45 +.set v_in_dslice_ih, 44 +.set v_in_dslice_iw, 43 +.set v_tmp, 116 +.set v_end, 122 + +.text +.globl igemm_bwd_gtc_bt128x128x16_tt8x8_gm2x4x4_gn2x4x4_ta1x2x1x1x4_8x1x1x32x1_tb1x2x1x1x4x1x1_8x1x1x2x1x1x16 +.p2align 8 +.type igemm_bwd_gtc_bt128x128x16_tt8x8_gm2x4x4_gn2x4x4_ta1x2x1x1x4_8x1x1x32x1_tb1x2x1x1x4x1x1_8x1x1x2x1x1x16,@function +igemm_bwd_gtc_bt128x128x16_tt8x8_gm2x4x4_gn2x4x4_ta1x2x1x1x4_8x1x1x32x1_tb1x2x1x1x4x1x1_8x1x1x2x1x1x16: + s_load_dwordx2 s[s_p_in+0:s_p_in+1], s[s_ka+0:s_ka+1], 0+k_p_in + s_load_dwordx2 s[s_p_wei+0:s_p_wei+1], s[s_ka+0:s_ka+1], 0+k_p_wei + s_load_dwordx2 s[s_p_out+0:s_p_out+1], s[s_ka+0:s_ka+1], 0+k_p_out + s_load_dwordx16 s[s_hi+0:s_hi+15], s[s_ka+0:s_ka+1], 0+k_hi + s_load_dwordx8 s[s_dtile_ix+0:s_dtile_ix+7], s[s_ka+0:s_ka+1], 0+k_dtile_ix + 
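; note: kernargs are fetched in batches: dwordx16 covers k_hi..k_dtile_iy,
+ ; dwordx8 covers k_dtile_ix..k_dslice_y, and the dwordx4/dword below cover
+ ; k_dslice_x..k_dslice_w_left; the dtile_*/dslice_* values are expected to
+ ; be precomputed by the host-side launcher. +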
s_load_dwordx4 s[s_dslice_x+0:s_dslice_x+3], s[s_ka+0:s_ka+1], 0+k_dslice_x + s_load_dword s[s_dslice_w_left], s[s_ka+0:s_ka+1], 0+k_dslice_w_left + v_mov_b32 v[v_tmp], v0 + ; output: K0xK1xExN0xN1xB0xB1: 1x2x1x1x4x1x1, slice 8x1x1x2x1x1x16 + v_and_b32 v[v_gtc_ib1], 15, v[v_tmp] + v_lshrrev_b32 v[v_tmp], 4, v[v_tmp] + v_and_b32 v[v_gtc_in0], 1, v[v_tmp] + v_lshrrev_b32 v[v_tmp], 1, v[v_tmp] + v_and_b32 v[v_gtc_ik0], 7, v[v_tmp] + + v_mov_b32 v[v_tmp], v0 + ; wei: K0xK1xExC0xC1: 1x2x1x1x4, slice 8x1x1x32x1 + v_and_b32 v[v_gtc_ic0], 31, v[v_tmp] + + s_mov_b32 s[s_p_in + 2], 0xffffffff + s_mov_b32 s[s_p_in + 3], 0x27000 + s_mov_b32 s[s_p_wei + 2], 0xffffffff + s_mov_b32 s[s_p_wei + 3], 0x27000 + s_mov_b32 s[s_p_out + 2], 0xffffffff + s_mov_b32 s[s_p_out + 3], 0x27000 + s_waitcnt lgkmcnt(0) + + ; calculate index + s_mul_i32 s[s_out_stride_k], s[s_ho], s[s_wo] + s_mul_i32 s[s_out_stride_n], s[s_k], s[s_out_stride_k] + s_mul_i32 s[s_in_stride_c], s[s_hi], s[s_wi] + s_mul_i32 s[s_in_stride_n], s[s_c], s[s_in_stride_c] + s_mul_i32 s[s_wei_stride_c], s[s_y], s[s_x] + s_mul_i32 s[s_wei_stride_k], s[s_c], s[s_wei_stride_c] + s_mul_i32 s[s_stride_dslice_hw], s[s_dslice_h], s[s_dslice_w] + s_mul_i32 s[s_stride_dslice_yx], s[s_dslice_y], s[s_dslice_x] + + ; N0xN1xB0xB1, gemm_m_per_block:128, gemm_n_per_block:128 + s_mul_i32 s[s_tmp], s[s_stride_dslice_hw], s[s_n] + s_lshr_b32 s[0], s[s_tmp], 7 + .v_u32_div_ss v_tmp+5, s_bx, 0, v_tmp, s_tmp + v_readfirstlane_b32 s[s_tmp+4], v[v_tmp+5] ; gemm_m, + s_mul_i32 s[s_tmp+2], s[s_tmp+4], s[0] + s_sub_i32 s[s_tmp+5], s[s_bx], s[s_tmp+2] ; gemm_n, cnt + s_lshl_b32 s[s_block_gtc_ic], s[s_tmp+4], 7 + s_lshr_b32 s[0], s[s_stride_dslice_hw], 4 ; B:16 per block, total num of B + .v_u32_div_ss v_tmp+5, s_tmp+5, 0, v_tmp, s_tmp + v_readfirstlane_b32 s[s_tmp], v[v_tmp+5] ; => N + s_mul_i32 s[s_tmp+2], s[s_tmp], s[0] + s_sub_i32 s[s_tmp+1], s[s_tmp+5], s[s_tmp+2] ; => B + s_lshl_b32 s[s_block_gtc_in], s[s_tmp], 3 + s_lshl_b32 s[s_block_gtc_ib], s[s_tmp+1], 4 + + v_lshlrev_b32 v[v_gtc_in1], 2, v[v_gtc_in0] + v_lshlrev_b32 v[v_gtc_ik1], 1, v[v_gtc_ik0] + + v_add_u32 v[v_tmp+5], s[s_block_gtc_ib], v[v_gtc_ib1] + ; calculate output transform, B -> dslice_h*dslice_w. 
always use ib1
+ .v_u32_div_vs v_out_dslice_ih, v_tmp+5, s_dslice_w, v_tmp, s_tmp
+ v_mul_lo_u32 v[v_tmp], s[s_dslice_w], v[v_out_dslice_ih]
+ v_sub_u32 v[v_out_dslice_iw], v[v_tmp+5], v[v_tmp]
+ ; iHTildaLeft, iWTildaLeft
+ v_add_u32 v[v_out_dslice_ih], s[s_dslice_h_left], v[v_out_dslice_ih]
+ v_add_u32 v[v_out_dslice_iw], s[s_dslice_w_left], v[v_out_dslice_iw]
+ ; dslice_y,dslice_h -> oh, dslice_x,dslice_w -> ow
+ v_mov_b32 v[v_out_iho], v[v_out_dslice_ih]
+ v_mov_b32 v[v_out_iwo], v[v_out_dslice_iw]
+ v_mul_lo_u32 v[v_tmp], s[s_out_stride_k], v[v_gtc_ik1]
+ v_mul_lo_u32 v[v_tmp+1], s[s_out_stride_n], v[v_gtc_in1]
+ v_add_lshl_u32 v[v_out_os_base], v[v_tmp], v[v_tmp+1], 2
+ ; n to statically accumulate into base pointer
+ s_lshl_b32 s[s_tmp+3], s[s_block_gtc_in], 2
+ s_mul_i32 s[s_tmp], s[s_out_stride_n], s[s_tmp+3]
+ s_mul_hi_u32 s[s_tmp+1], s[s_out_stride_n], s[s_tmp+3]
+ s_add_u32 s[s_p_out], s[s_p_out], s[s_tmp]
+ s_addc_u32 s[s_p_out+1], s[s_p_out+1], s[s_tmp+1]
+
+ .v_bwd_gtc_out_update_os v_out_os, v_out_os_base, v_out_iho, v_out_iwo, s_wo, v_tmp
+ .v_set_flag_hw v_out_flag, v_out_iho, v_out_iwo, s_ho, s_wo
+
+ s_lshl_b32 s[s_out_stride_k], s[s_out_stride_k], 2
+ s_lshl_b32 s[s_out_stride_n], s[s_out_stride_n], 2
+
+ ; load output
+ v_cmp_eq_u32 vcc, 1, v[v_out_flag]
+ s_and_saveexec_b64 s[s_tmp+4:s_tmp+5], vcc
+ .v_gld_2x4_b32_v1 v_gld_b, s_p_out, v_out_os, s_out_stride_k, s_out_stride_n, s_tmp
+ s_or_b64 exec, exec, s[s_tmp+4:s_tmp+5]
+
+ v_lshlrev_b32 v[v_gtc_ic1], 2, v[v_gtc_ic0]
+ v_add_u32 v[v_tmp+5], s[s_block_gtc_ic], v[v_gtc_ic1]
+ v_mov_b32 v[v_dtile_iy], s[s_dtile_iy]
+ v_mov_b32 v[v_dtile_ix], s[s_dtile_ix]
+ v_mov_b32 v[v_wei_iy], s[s_dtile_iy]
+ v_mov_b32 v[v_wei_ix], s[s_dtile_ix]
+
+ ; calculate wei offset
+ v_mul_lo_u32 v[v_tmp], s[s_wei_stride_c], v[v_tmp+5]
+ v_mul_lo_u32 v[v_tmp+1], s[s_wei_stride_k], v[v_gtc_ik1]
+ v_add_lshl_u32 v[v_wei_os_base], v[v_tmp], v[v_tmp+1], 2
+ .v_bwd_gtc_wei_update_os v_wei_os, v_wei_os_base, v_wei_iy, v_wei_ix, s_x, v_tmp
+
+ s_lshl_b32 s[s_wei_stride_k], s[s_wei_stride_k], 2
+ s_lshl_b32 s[s_wei_stride_c], s[s_wei_stride_c], 2
+
+ ; load weight
+ .v_gld_2x4_b32_v1 v_gld_a, s_p_wei, v_wei_os, s_wei_stride_k, s_wei_stride_c, s_tmp
+
+ ; c thread mapping
+ ; -> MR x NR x ML1 x NL1 x ML0 x NL0
+ ; cluster 1 x 1 x 4 x 4 x 4 x 4
+ ; perthrd 2 x 2 x 1 x 1 x 4 x 4
+ v_and_b32 v[v_tmp], 3, v0
+ v_lshlrev_b32 v[v_tmp], 2, v[v_tmp] ; => iNL0
+ v_lshrrev_b32 v[v_tmp+5], 2, v0
+ v_and_b32 v[v_tmp+1], 3, v[v_tmp+5]
+ v_lshlrev_b32 v[v_tmp+1], 2, v[v_tmp+1] ; => iML0
+ v_lshrrev_b32 v[v_tmp+5], 2, v[v_tmp+5]
+ v_and_b32 v[v_tmp+2], 3, v[v_tmp+5] ; => iNL1
+ v_lshrrev_b32 v[v_tmp+5], 2, v[v_tmp+5]
+ v_and_b32 v[v_tmp+3], 3, v[v_tmp+5] ; => iML1
+ v_lshl_or_b32 v[v_gemm_in], v[v_tmp+2], 4, v[v_tmp] ; in (without repeat)
+ v_lshl_or_b32 v[v_gemm_im], v[v_tmp+3], 4, v[v_tmp+1] ; im (without repeat)
+
+ ; remapping, gemm_im => C0xC1:32x4, gemm_in => N0xB0xB1xN1:2x1x16x4
+ v_and_b32 v[v_in_in1], 3, v[v_gemm_in]
+ v_lshrrev_b32 v[v_tmp], 2, v[v_gemm_in]
+ v_and_b32 v[v_in_ib1], 15, v[v_tmp]
+ v_lshrrev_b32 v[v_tmp], 4, v[v_tmp]
+ v_mov_b32 v[v_in_ib0], 0
+ v_and_b32 v[v_in_in0], 1, v[v_tmp]
+ v_lshrrev_b32 v[v_tmp], 1, v[v_tmp]
+ v_and_b32 v[v_in_ic1], 3, v[v_gemm_im]
+ v_lshrrev_b32 v[v_tmp], 2, v[v_gemm_im]
+ v_and_b32 v[v_in_ic0], 31, v[v_tmp]
+ v_lshrrev_b32 v[v_tmp], 5, v[v_tmp]
+
+ v_lshl_or_b32 v[v_in_ib1], v[v_in_ib0], 4, v[v_in_ib1]
+ v_lshl_or_b32 v[v_in_in1], v[v_in_in0], 2, v[v_in_in1]
+ v_lshl_or_b32 v[v_in_ic1], v[v_in_ic0], 2, v[v_in_ic1]
+ v_add_u32 v[v_in_ib1], s[s_block_gtc_ib], v[v_in_ib1]
+
+ .v_u32_div_vs v_in_dslice_ih, v_in_ib1, s_dslice_w, v_tmp, s_tmp
+ v_mul_lo_u32 v[v_tmp+1], s[s_dslice_w], v[v_in_dslice_ih]
+ v_sub_u32 v[v_in_dslice_iw], v[v_in_ib1], v[v_tmp+1]
+ v_add_u32 v[v_in_dslice_ih], s[s_dslice_h_left], v[v_in_dslice_ih]
+ v_add_u32 v[v_in_dslice_iw], s[s_dslice_w_left], v[v_in_dslice_iw]
+
+ ; dslice_h,dslice_y -> hip, dslice_w,dslice_x -> wip
+ s_mul_i32 s[s_tmp], s[s_dtile_iy], s[s_dilation_h]
+ v_mul_lo_u32 v[v_tmp], s[s_stride_h], v[v_in_dslice_ih]
+ v_add_u32 v[v_tmp], s[s_tmp], v[v_tmp]
+ s_mul_i32 s[s_tmp+1], s[s_dtile_ix], s[s_dilation_w]
+ v_mul_lo_u32 v[v_tmp+1], s[s_stride_w], v[v_in_dslice_iw]
+ v_add_u32 v[v_tmp+1], s[s_tmp+1], v[v_tmp+1]
+ ; v_tmp: hip, v_tmp+1: wip
+
+ ; hip->h, wip->w
+ v_sub_i32 v[v_in_ihi], v[v_tmp], s[s_pad_h]
+ v_sub_i32 v[v_in_iwi], v[v_tmp+1], s[s_pad_w]
+
+ .v_set_flag_hw v_in_flag, v_in_ihi, v_in_iwi, s_hi, s_wi
+
+ ; input offset
+ s_lshl_b32 s[s_tmp+3], s[s_block_gtc_in], 2
+ s_mul_i32 s[s_tmp], s[s_in_stride_n], s[s_tmp+3]
+ s_mul_hi_u32 s[s_tmp+1], s[s_in_stride_n], s[s_tmp+3]
+ s_add_u32 s[s_p_in], s[s_p_in], s[s_tmp]
+ s_addc_u32 s[s_p_in+1], s[s_p_in+1], s[s_tmp+1]
+
+ s_lshl_b32 s[s_tmp+3], s[s_block_gtc_ic], 2
+ s_mul_i32 s[s_tmp], s[s_in_stride_c], s[s_tmp+3]
+ s_mul_hi_u32 s[s_tmp+1], s[s_in_stride_c], s[s_tmp+3]
+ s_add_u32 s[s_p_in], s[s_p_in], s[s_tmp]
+ s_addc_u32 s[s_p_in+1], s[s_p_in+1], s[s_tmp+1]
+ v_mul_lo_u32 v[v_tmp], s[s_in_stride_n], v[v_in_in1]
+ v_mul_lo_u32 v[v_tmp+1], s[s_in_stride_c], v[v_in_ic1]
+ v_add_u32 v[v_tmp], v[v_tmp], v[v_tmp+1]
+ v_mul_lo_u32 v[v_tmp+1], s[s_wi], v[v_in_ihi]
+ v_add3_u32 v[v_in_os], v[v_tmp], v[v_tmp+1], v[v_in_iwi]
+ v_lshlrev_b32 v[v_in_os], 2, v[v_in_os]
+
+ s_lshl_b32 s[s_in_stride_n], s[s_in_stride_n], 2
+ s_lshl_b32 s[s_in_stride_nr], s[s_in_stride_n], 2
+
+ s_lshl_b32 s[s_in_stride_c], s[s_in_stride_c], 2
+ s_lshl_b32 s[s_in_stride_cr], s[s_in_stride_c], 6
+
+ v_lshrrev_b32 v[v_gtc_in0], 2, v[v_gtc_in1]
+ ; LDS store, order out: K0xK1xExN0xB0xB1xN1: 1x2x1x1x1x1x4, 8x1x1x2x1x16x1
+ v_lshlrev_b32 v[v_tmp], 2, v[v_gtc_ib1]
+ v_lshl_or_b32 v[v_tmp], v[v_gtc_in0], 6, v[v_tmp]
+ v_lshl_or_b32 v[v_tmp+1], v[v_gtc_ik1], 7, v[v_tmp]
+ v_lshlrev_b32 v[v_sst_b_os], 2, v[v_tmp+1]
+
+ ; LDS store, order wei not shuffled, wei: K0xK1xExC0xC1: 1x2x1x1x4, 8x1x1x32x1
+ v_lshl_or_b32 v[v_tmp], v[v_gtc_ik1], 7, v[v_gtc_ic1]
+ v_lshlrev_b32 v[v_sst_a_os], 2, v[v_tmp]
+ v_add_u32 v[v_sst_b_os], 8192, v[v_sst_b_os]
+
+ ; LDS load
+ v_lshlrev_b32 v[v_sld_b_os], 2, v[v_gemm_in]
+ v_lshlrev_b32 v[v_sld_a_os], 2, v[v_gemm_im]
+ v_add_u32 v[v_sld_b_os], 8192, v[v_sld_b_os]
+
+ s_mul_i32 s[s_knum], s[s_stride_dslice_yx], s[s_k]
+ s_lshl_b32 s[s_out_move_slice_stride_k], s[s_out_stride_k], 4 ; 16x
+ s_lshl_b32 s[s_wei_move_slice_stride_k], s[s_wei_stride_k], 4 ; 16x
+ s_mov_b32 s[s_move_slice_k_ik], 0
+ s_mov_b32 s[s_move_slice_k_idsy], 0
+ s_mov_b32 s[s_move_slice_k_idsx], 0
+
+ .v_clear_nc v_c, 64
+ ; start FMA loop, 8x8 thread tile with 4x4 sub-tile
+ s_waitcnt vmcnt(8)
+ .v_sst_so0_2x4_b32_v4_st512 v_gld_b, v_sst_b_os
+
+ s_waitcnt vmcnt(0)
+ .v_sst_so0_2x4_b32_v4_st512 v_gld_a, v_sst_a_os
+
+ s_sub_i32 s[s_kitr], s[s_knum], 16
+ s_cmp_gt_i32 s[s_kitr], 0
+ s_cbranch_scc0 L_igemm_bwd_gtc_bt128x128x16_tt8x8_gm2x4x4_gn2x4x4_ta1x2x1x1x4_8x1x1x32x1_tb1x2x1x1x4x1x1_8x1x1x2x1x1x16_end
+
+ .s_bwd_gtc_move_slice_window_k_dsy_dsx s_move_slice_k_ik, s_move_slice_k_idsy, s_move_slice_k_idsx, s_dslice_y, s_dslice_x, 16, 1, 
1, v_out_os_base, v_wei_os_base, s_out_move_slice_stride_k, s_wei_move_slice_stride_k + .v_bwd_gtc_out_update_hw v_out_iho, v_out_iwo, v_out_dslice_ih, v_out_dslice_iw, s_move_slice_k_idsy, s_move_slice_k_idsx, s_dtile_dy, s_dtile_dx, s_tmp + .v_bwd_gtc_out_update_os v_out_os, v_out_os_base, v_out_iho, v_out_iwo, s_wo, v_tmp + .v_set_flag_hw v_out_flag, v_out_iho, v_out_iwo, s_ho, s_wo + .v_bwd_gtc_wei_update_yx v_wei_iy, v_wei_ix, s_move_slice_k_idsy, s_move_slice_k_idsx, s_dtile_y, s_dtile_x, v_dtile_iy, v_dtile_ix, s_tmp + .v_bwd_gtc_wei_update_os v_wei_os, v_wei_os_base, v_wei_iy, v_wei_ix, s_x, v_tmp + v_xor_b32 v[v_sst_b_os], 0x4000, v[v_sst_b_os] ; switch double buffer b store + v_xor_b32 v[v_sst_a_os], 0x4000, v[v_sst_a_os] ; switch double buffer a store + s_waitcnt lgkmcnt(0) + s_barrier + + ; load output + v_cmp_eq_u32 vcc, 1, v[v_out_flag] + s_and_saveexec_b64 s[s_tmp+4:s_tmp+5], vcc + .v_gld_2x4_b32_v1 v_gld_b, s_p_out, v_out_os, s_out_stride_k, s_out_stride_n, s_tmp + s_or_b64 exec, exec, s[s_tmp+4:s_tmp+5] + ; load weight + .v_gld_2x4_b32_v1 v_gld_a, s_p_wei, v_wei_os, s_wei_stride_k, s_wei_stride_c, s_tmp + +L_igemm_bwd_gtc_bt128x128x16_tt8x8_gm2x4x4_gn2x4x4_ta1x2x1x1x4_8x1x1x32x1_tb1x2x1x1x4x1x1_8x1x1x2x1x1x16_fma_body: + ; do fma accumulate with unroll 16 + ds_read_b128 v[v_a:v_a+3], v[v_sld_a_os] + ds_read_b128 v[v_b:v_b+3], v[v_sld_b_os] + ds_read_b128 v[v_b+4:v_b+4+3], v[v_sld_b_os] offset:256 + ds_read_b128 v[v_a+4:v_a+4+3], v[v_sld_a_os] offset:256 + .itr_k = 0 + .rept 15 + s_waitcnt lgkmcnt(2) + .v_fma_4x4_s8 v_c,v_a,v_b + s_waitcnt lgkmcnt(1) + .v_fma_4x4_s8 v_c+4,v_a,v_b+4 + ds_read_b128 v[v_a:v_a+3], v[v_sld_a_os] offset:0+(.itr_k+1)*512 + s_waitcnt lgkmcnt(1) + .v_fma_4x4_s8 v_c+32,v_a+4,v_b + ds_read_b128 v[v_b:v_b+3], v[v_sld_b_os] offset:0+(.itr_k+1)*512 + .v_fma_4x4_s8 v_c+36,v_a+4,v_b+4 + + ds_read_b128 v[v_b+4:v_b+4+3], v[v_sld_b_os] offset:0+(.itr_k+1)*512+256 + ds_read_b128 v[v_a+4:v_a+4+3], v[v_sld_a_os] offset:0+(.itr_k+1)*512+256 + .itr_k = .itr_k + 1 + .endr + + ; last unroll + v_xor_b32 v[v_sld_b_os], 16384, v[v_sld_b_os] ; switch double buffer b load + v_xor_b32 v[v_sld_a_os], 16384, v[v_sld_a_os] ; switch double buffer a load + s_waitcnt lgkmcnt(2) + .v_fma_4x4_s8 v_c,v_a,v_b + s_waitcnt lgkmcnt(1) + .v_fma_4x4_s8 v_c+4,v_a,v_b+4 + s_waitcnt vmcnt(8) + .v_sst_so0_2x4_b32_v4_st512 v_gld_b, v_sst_b_os + s_waitcnt vmcnt(0) + .v_sst_so0_2x4_b32_v4_st512 v_gld_a, v_sst_a_os + s_sub_i32 s[s_kitr], s[s_kitr], 16 + s_cmp_gt_i32 s[s_kitr], 0 + s_cbranch_scc0 L_igemm_bwd_gtc_bt128x128x16_tt8x8_gm2x4x4_gn2x4x4_ta1x2x1x1x4_8x1x1x32x1_tb1x2x1x1x4x1x1_8x1x1x2x1x1x16_fma_finishing + .s_bwd_gtc_move_slice_window_k_dsy_dsx s_move_slice_k_ik, s_move_slice_k_idsy, s_move_slice_k_idsx, s_dslice_y, s_dslice_x, 16, 1, 1, v_out_os_base, v_wei_os_base, s_out_move_slice_stride_k, s_wei_move_slice_stride_k + .v_bwd_gtc_out_update_hw v_out_iho, v_out_iwo, v_out_dslice_ih, v_out_dslice_iw, s_move_slice_k_idsy, s_move_slice_k_idsx, s_dtile_dy, s_dtile_dx, s_tmp + .v_bwd_gtc_out_update_os v_out_os, v_out_os_base, v_out_iho, v_out_iwo, s_wo, v_tmp + .v_set_flag_hw v_out_flag, v_out_iho, v_out_iwo, s_ho, s_wo + .v_bwd_gtc_wei_update_yx v_wei_iy, v_wei_ix, s_move_slice_k_idsy, s_move_slice_k_idsx, s_dtile_y, s_dtile_x, v_dtile_iy, v_dtile_ix, s_tmp + .v_bwd_gtc_wei_update_os v_wei_os, v_wei_os_base, v_wei_iy, v_wei_ix, s_x, v_tmp + s_waitcnt lgkmcnt(4) + .v_fma_4x4_s8 v_c+32,v_a+4,v_b + v_xor_b32 v[v_sst_b_os], 16384, v[v_sst_b_os] ; switch double buffer b store + v_xor_b32 
v[v_sst_a_os], 16384, v[v_sst_a_os] ; switch double buffer a store + s_waitcnt lgkmcnt(0) + s_barrier + ; load output + v_cmp_eq_u32 vcc, 1, v[v_out_flag] + s_and_saveexec_b64 s[s_tmp+4:s_tmp+5], vcc + .v_gld_2x4_b32_v1 v_gld_b, s_p_out, v_out_os, s_out_stride_k, s_out_stride_n, s_tmp + s_or_b64 exec, exec, s[s_tmp+4:s_tmp+5] + ; load weight + .v_gld_2x4_b32_v1 v_gld_a, s_p_wei, v_wei_os, s_wei_stride_k, s_wei_stride_c, s_tmp + .v_fma_4x4_s8 v_c+36,v_a+4,v_b+4 + + s_branch L_igemm_bwd_gtc_bt128x128x16_tt8x8_gm2x4x4_gn2x4x4_ta1x2x1x1x4_8x1x1x32x1_tb1x2x1x1x4x1x1_8x1x1x2x1x1x16_fma_body +L_igemm_bwd_gtc_bt128x128x16_tt8x8_gm2x4x4_gn2x4x4_ta1x2x1x1x4_8x1x1x32x1_tb1x2x1x1x4x1x1_8x1x1x2x1x1x16_fma_finishing: + s_waitcnt lgkmcnt(4) + .v_fma_4x4_s8 v_c+32,v_a+4,v_b + .v_fma_4x4_s8 v_c+36,v_a+4,v_b+4 +L_igemm_bwd_gtc_bt128x128x16_tt8x8_gm2x4x4_gn2x4x4_ta1x2x1x1x4_8x1x1x32x1_tb1x2x1x1x4x1x1_8x1x1x2x1x1x16_end: + s_waitcnt lgkmcnt(0) + s_barrier + ds_read_b128 v[v_a:v_a+3], v[v_sld_a_os] + ds_read_b128 v[v_b:v_b+3], v[v_sld_b_os] + ds_read_b128 v[v_b+4:v_b+4+3], v[v_sld_b_os] offset:256 + ds_read_b128 v[v_a+4:v_a+4+3], v[v_sld_a_os] offset:256 + .itr_k = 0 + .rept 15 + s_waitcnt lgkmcnt(2) + .v_fma_4x4_s8 v_c,v_a,v_b + s_waitcnt lgkmcnt(1) + .v_fma_4x4_s8 v_c+4,v_a,v_b+4 + ds_read_b128 v[v_a:v_a+3], v[v_sld_a_os] offset:0+(.itr_k+1)*512 + s_waitcnt lgkmcnt(1) + .v_fma_4x4_s8 v_c+32,v_a+4,v_b + ds_read_b128 v[v_b:v_b+3], v[v_sld_b_os] offset:0+(.itr_k+1)*512 + .v_fma_4x4_s8 v_c+36,v_a+4,v_b+4 + + ds_read_b128 v[v_b+4:v_b+4+3], v[v_sld_b_os] offset:0+(.itr_k+1)*512+256 + ds_read_b128 v[v_a+4:v_a+4+3], v[v_sld_a_os] offset:0+(.itr_k+1)*512+256 + .itr_k = .itr_k + 1 + .endr + + ; last unroll + s_waitcnt lgkmcnt(2) + .v_fma_4x4_s8 v_c,v_a,v_b + s_waitcnt lgkmcnt(1) + .v_fma_4x4_s8 v_c+4,v_a,v_b+4 + s_waitcnt lgkmcnt(0) + .v_fma_4x4_s8 v_c+32,v_a+4,v_b + .v_fma_4x4_s8 v_c+36,v_a+4,v_b+4 + + v_cmpx_eq_u32 vcc, 1, v[v_in_flag] + s_cbranch_execz L_igemm_bwd_gtc_bt128x128x16_tt8x8_gm2x4x4_gn2x4x4_ta1x2x1x1x4_8x1x1x32x1_tb1x2x1x1x4x1x1_8x1x1x2x1x1x16_out + s_mov_b32 s[s_tmp], 0 + s_mov_b32 s[s_tmp+1], 0 + s_mov_b32 s[s_tmp+2], 0 + s_mov_b32 s[s_tmp+3], 0 + .v_write4d_strided v_c,s_p_in,v_in_os,s_in_stride_n,s_in_stride_nr,s_in_stride_c,s_in_stride_cr,s_tmp,4,2,4,2 +L_igemm_bwd_gtc_bt128x128x16_tt8x8_gm2x4x4_gn2x4x4_ta1x2x1x1x4_8x1x1x32x1_tb1x2x1x1x4x1x1_8x1x1x2x1x1x16_out: + s_endpgm +.rodata +.p2align 6 +.amdhsa_kernel igemm_bwd_gtc_bt128x128x16_tt8x8_gm2x4x4_gn2x4x4_ta1x2x1x1x4_8x1x1x32x1_tb1x2x1x1x4x1x1_8x1x1x2x1x1x16 + .amdhsa_group_segment_fixed_size 32768 + .amdhsa_user_sgpr_kernarg_segment_ptr 1 + .amdhsa_system_sgpr_workgroup_id_x 1 + .amdhsa_system_vgpr_workitem_id 0 + .amdhsa_next_free_vgpr 122 + .amdhsa_next_free_sgpr 78 + .amdhsa_ieee_mode 0 + .amdhsa_dx10_clamp 0 +.end_amdhsa_kernel + +;---------------------------------------------------------- +; starting of kernel igemm_bwd_gtc_bt128x128x8_tt8x8_gm2x4x4_gn2x4x4_ta1x1x1x1x4_8x1x1x32x1_tb1x1x1x1x4x1x1_8x1x1x2x1x1x16 +; gemm_m_per_block : 128 +; gemm_n_per_block : 128 +; gemm_k_per_block : 8 +; gemm_m_per_thread : 4 +; gemm_m_level0_cluster : 4 +; gemm_m_level1_cluster : 4 +; gemm_n_per_thread : 4 +; gemm_n_level0_cluster : 4 +; gemm_n_level1_cluster : 4 +; tensor_a_thread_lengths : [1, 1, 1, 1, 4] +; tensor_a_cluster_lengths : [8, 1, 1, 32, 1] +; tensor_b_thread_lengths : [1, 1, 1, 1, 4, 1, 1] +; tensor_b_cluster_lengths : [8, 1, 1, 2, 1, 1, 16] +; direction : bwd +; precision : fp32 +; opt_1x1 : 0 +; +; block_size : 256 +; thread_tile : 8x8 +; 
lds_total : 16384 +; +.set k_p_in, 0 +.set k_p_wei, 8 +.set k_p_out, 16 +.set k_hi, 24 +.set k_wi, 28 +.set k_n, 32 +.set k_k, 36 +.set k_c, 40 +.set k_ho, 44 +.set k_wo, 48 +.set k_stride_h, 52 +.set k_stride_w, 56 +.set k_dilation_h, 60 +.set k_dilation_w, 64 +.set k_pad_h, 68 +.set k_pad_w, 72 +.set k_y, 76 +.set k_x, 80 +.set k_dtile_iy, 84 +.set k_dtile_ix, 88 +.set k_dtile_dy, 92 +.set k_dtile_dx, 96 +.set k_dtile_y, 100 +.set k_dtile_x, 104 +.set k_dtile_h, 108 +.set k_dtile_w, 112 +.set k_dslice_y, 116 +.set k_dslice_x, 120 +.set k_dslice_h, 124 +.set k_dslice_w, 128 +.set k_dslice_h_left, 132 +.set k_dslice_w_left, 136 +.set k_pack0, 140 +.set k_end, 144 + +.set s_ka, 0 +.set s_bx, 2 +.set s_p_in, 4 +.set s_p_wei, 8 +.set s_p_out, 12 +.set s_hi, 16 +.set s_wi, 17 +.set s_n, 18 +.set s_k, 19 +.set s_c, 20 +.set s_ho, 21 +.set s_wo, 22 +.set s_stride_h, 23 +.set s_stride_w, 24 +.set s_dilation_h, 25 +.set s_dilation_w, 26 +.set s_pad_h, 27 +.set s_pad_w, 28 +.set s_y, 29 +.set s_x, 30 +.set s_dtile_iy, 31 +.set s_dtile_ix, 32 +.set s_dtile_dy, 33 +.set s_dtile_dx, 34 +.set s_dtile_y, 35 +.set s_dtile_x, 36 +.set s_dtile_h, 37 +.set s_dtile_w, 38 +.set s_dslice_y, 39 +.set s_dslice_x, 40 +.set s_dslice_h, 41 +.set s_dslice_w, 42 +.set s_dslice_h_left, 43 +.set s_dslice_w_left, 44 +.set s_out_stride_k, 45 +.set s_out_stride_k0, 46 +.set s_out_stride_n, 47 +.set s_out_stride_n0, 48 +.set s_out_stride_b0, 49 +.set s_out_move_slice_stride_k, 50 +.set s_in_stride_c, 51 +.set s_in_stride_cr, 52 +.set s_in_stride_n, 53 +.set s_in_stride_nr, 54 +.set s_wei_stride_c, 55 +.set s_wei_stride_c0, 56 +.set s_wei_stride_k, 57 +.set s_wei_stride_k0, 58 +.set s_wei_move_slice_stride_k, 59 +.set s_stride_dslice_hw, 60 +.set s_stride_dslice_yx, 61 +.set s_block_gtc_ib, 62 +.set s_block_gtc_ic, 63 +.set s_block_gtc_in, 64 +.set s_knum, 19 +.set s_move_slice_k_idsy, 0 +.set s_move_slice_k_idsx, 1 +.set s_move_slice_k_ik, 2 +.set s_kitr, 3 +.set s_tmp, 66 +.set s_end, 72 + +.set v_c, 0 +.set v_a, 64 +.set v_b, 72 +.set v_gld_a, 80 +.set v_gld_b, 84 +.set v_sst_a_os, 88 +.set v_sst_b_os, 89 +.set v_sld_a_os, 90 +.set v_sld_b_os, 91 +.set v_out_iho, 92 +.set v_out_iwo, 93 +.set v_out_dslice_ih, 94 +.set v_out_dslice_iw, 95 +.set v_out_os, 96 +.set v_out_os_base, 97 +.set v_wei_iy, 98 +.set v_wei_ix, 99 +.set v_dtile_iy, 100 +.set v_dtile_ix, 101 +.set v_wei_os, 102 +.set v_wei_os_base, 103 +.set v_out_flag, 104 +.set v_in_flag, 105 +.set v_in_os, 106 +.set v_gtc_ic0, 63 +.set v_gtc_ic1, 62 +.set v_gtc_in0, 61 +.set v_gtc_in1, 60 +.set v_gtc_ib0, 59 +.set v_gtc_ib1, 58 +.set v_gtc_ik0, 57 +.set v_gtc_ik1, 56 +.set v_gtc_ie, 55 +.set v_gemm_in, 54 +.set v_gemm_im, 53 +.set v_in_in0, 52 +.set v_in_in1, 51 +.set v_in_ib0, 50 +.set v_in_ib1, 49 +.set v_in_ic0, 48 +.set v_in_ic1, 47 +.set v_in_ihi, 46 +.set v_in_iwi, 45 +.set v_in_dslice_ih, 44 +.set v_in_dslice_iw, 43 +.set v_tmp, 108 +.set v_end, 114 + +.text +.globl igemm_bwd_gtc_bt128x128x8_tt8x8_gm2x4x4_gn2x4x4_ta1x1x1x1x4_8x1x1x32x1_tb1x1x1x1x4x1x1_8x1x1x2x1x1x16 +.p2align 8 +.type igemm_bwd_gtc_bt128x128x8_tt8x8_gm2x4x4_gn2x4x4_ta1x1x1x1x4_8x1x1x32x1_tb1x1x1x1x4x1x1_8x1x1x2x1x1x16,@function +igemm_bwd_gtc_bt128x128x8_tt8x8_gm2x4x4_gn2x4x4_ta1x1x1x1x4_8x1x1x32x1_tb1x1x1x1x4x1x1_8x1x1x2x1x1x16: + s_load_dwordx2 s[s_p_in+0:s_p_in+1], s[s_ka+0:s_ka+1], 0+k_p_in + s_load_dwordx2 s[s_p_wei+0:s_p_wei+1], s[s_ka+0:s_ka+1], 0+k_p_wei + s_load_dwordx2 s[s_p_out+0:s_p_out+1], s[s_ka+0:s_ka+1], 0+k_p_out + s_load_dwordx16 s[s_hi+0:s_hi+15], s[s_ka+0:s_ka+1], 0+k_hi + 
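+ ; note: kernarg fetches are batched to match the .set table above:
+ ;   the s_load_dwordx16 covers k_hi..k_dtile_iy        (byte offsets 24..87)
+ ;   the s_load_dwordx8 below covers k_dtile_ix..k_dslice_y    (88..119)
+ ;   the s_load_dwordx4 covers k_dslice_x..k_dslice_h_left     (120..135)
+ ;   the final s_load_dword picks up k_dslice_w_left            (136)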
s_load_dwordx8 s[s_dtile_ix+0:s_dtile_ix+7], s[s_ka+0:s_ka+1], 0+k_dtile_ix + s_load_dwordx4 s[s_dslice_x+0:s_dslice_x+3], s[s_ka+0:s_ka+1], 0+k_dslice_x + s_load_dword s[s_dslice_w_left], s[s_ka+0:s_ka+1], 0+k_dslice_w_left + v_mov_b32 v[v_tmp], v0 + ; output: K0xK1xExN0xN1xB0xB1: 1x1x1x1x4x1x1, slice 8x1x1x2x1x1x16 + v_and_b32 v[v_gtc_ib1], 15, v[v_tmp] + v_lshrrev_b32 v[v_tmp], 4, v[v_tmp] + v_and_b32 v[v_gtc_in0], 1, v[v_tmp] + v_lshrrev_b32 v[v_tmp], 1, v[v_tmp] + v_and_b32 v[v_gtc_ik0], 7, v[v_tmp] + + v_mov_b32 v[v_tmp], v0 + ; wei: K0xK1xExC0xC1: 1x1x1x1x4, slice 8x1x1x32x1 + v_and_b32 v[v_gtc_ic0], 31, v[v_tmp] + + s_mov_b32 s[s_p_in + 2], 0xffffffff + s_mov_b32 s[s_p_in + 3], 0x27000 + s_mov_b32 s[s_p_wei + 2], 0xffffffff + s_mov_b32 s[s_p_wei + 3], 0x27000 + s_mov_b32 s[s_p_out + 2], 0xffffffff + s_mov_b32 s[s_p_out + 3], 0x27000 + s_waitcnt lgkmcnt(0) + + ; calculate index + s_mul_i32 s[s_out_stride_k], s[s_ho], s[s_wo] + s_mul_i32 s[s_out_stride_n], s[s_k], s[s_out_stride_k] + s_mul_i32 s[s_in_stride_c], s[s_hi], s[s_wi] + s_mul_i32 s[s_in_stride_n], s[s_c], s[s_in_stride_c] + s_mul_i32 s[s_wei_stride_c], s[s_y], s[s_x] + s_mul_i32 s[s_wei_stride_k], s[s_c], s[s_wei_stride_c] + s_mul_i32 s[s_stride_dslice_hw], s[s_dslice_h], s[s_dslice_w] + s_mul_i32 s[s_stride_dslice_yx], s[s_dslice_y], s[s_dslice_x] + + ; N0xN1xB0xB1, gemm_m_per_block:128, gemm_n_per_block:128 + s_mul_i32 s[s_tmp], s[s_stride_dslice_hw], s[s_n] + s_lshr_b32 s[0], s[s_tmp], 7 + .v_u32_div_ss v_tmp+5, s_bx, 0, v_tmp, s_tmp + v_readfirstlane_b32 s[s_tmp+4], v[v_tmp+5] ; gemm_m, + s_mul_i32 s[s_tmp+2], s[s_tmp+4], s[0] + s_sub_i32 s[s_tmp+5], s[s_bx], s[s_tmp+2] ; gemm_n, cnt + s_lshl_b32 s[s_block_gtc_ic], s[s_tmp+4], 7 + s_lshr_b32 s[0], s[s_stride_dslice_hw], 4 ; B:16 per block, total num of B + .v_u32_div_ss v_tmp+5, s_tmp+5, 0, v_tmp, s_tmp + v_readfirstlane_b32 s[s_tmp], v[v_tmp+5] ; => N + s_mul_i32 s[s_tmp+2], s[s_tmp], s[0] + s_sub_i32 s[s_tmp+1], s[s_tmp+5], s[s_tmp+2] ; => B + s_lshl_b32 s[s_block_gtc_in], s[s_tmp], 3 + s_lshl_b32 s[s_block_gtc_ib], s[s_tmp+1], 4 + + v_lshlrev_b32 v[v_gtc_in1], 2, v[v_gtc_in0] + v_mov_b32 v[v_gtc_ik1], v[v_gtc_ik0] + + v_add_u32 v[v_tmp+5], s[s_block_gtc_ib], v[v_gtc_ib1] + ; calculate output transform, B -> dslice_h*dslice_w. 
always use ib1 + .v_u32_div_vs v_out_dslice_ih, v_tmp+5, s_dslice_w, v_tmp, s_tmp + v_mul_lo_u32 v[v_tmp], s[s_dslice_w], v[v_out_dslice_ih] + v_sub_u32 v[v_out_dslice_iw], v[v_tmp+5], v[v_tmp] + ; iHTildaLeft, iWTildaLeft + v_add_u32 v[v_out_dslice_ih], s[s_dslice_h_left], v[v_out_dslice_ih] + v_add_u32 v[v_out_dslice_iw], s[s_dslice_w_left], v[v_out_dslice_iw] + ; dslice_y,dslice_h -> oh, dslice_x,dslice_w -> ow + v_mov_b32 v[v_out_iho], v[v_out_dslice_ih] + v_mov_b32 v[v_out_iwo], v[v_out_dslice_iw] + v_mul_lo_u32 v[v_tmp], s[s_out_stride_k], v[v_gtc_ik1] + v_mul_lo_u32 v[v_tmp+1], s[s_out_stride_n], v[v_gtc_in1] + v_add_lshl_u32 v[v_out_os_base], v[v_tmp], v[v_tmp+1], 2 + ; n to staticly accumulate into base pointer + s_lshl_b32 s[s_tmp+3], s[s_block_gtc_in], 2 + s_mul_i32 s[s_tmp], s[s_out_stride_n], s[s_tmp+3] + s_mul_hi_u32 s[s_tmp+1], s[s_out_stride_n], s[s_tmp+3] + s_add_u32 s[s_p_out], s[s_p_out], s[s_tmp] + s_addc_u32 s[s_p_out+1], s[s_p_out+1], s[s_tmp+1] + + .v_bwd_gtc_out_update_os v_out_os, v_out_os_base, v_out_iho, v_out_iwo, s_wo, v_tmp + .v_set_flag_hw v_out_flag, v_out_iho, v_out_iwo, s_ho, s_wo + + s_lshl_b32 s[s_out_stride_n], s[s_out_stride_n], 2 + + ; load output + v_cmp_eq_u32 vcc, 1, v[v_out_flag] + s_and_saveexec_b64 s[s_tmp+4:s_tmp+5], vcc + .v_gld_1x4_b32_v1 v_gld_b, s_p_out, v_out_os, 0, s_out_stride_n, s_tmp + s_or_b64 exec, exec, s[s_tmp+4:s_tmp+5] + + v_lshlrev_b32 v[v_gtc_ic1], 2, v[v_gtc_ic0] + v_add_u32 v[v_tmp+5], s[s_block_gtc_ic], v[v_gtc_ic1] + v_mov_b32 v[v_dtile_iy], s[s_dtile_iy] + v_mov_b32 v[v_dtile_ix], s[s_dtile_ix] + v_mov_b32 v[v_wei_iy], s[s_dtile_iy] + v_mov_b32 v[v_wei_ix], s[s_dtile_ix] + + ; calculate wei offset + v_mul_lo_u32 v[v_tmp], s[s_wei_stride_c], v[v_tmp+5] + v_mul_lo_u32 v[v_tmp+1], s[s_wei_stride_k], v[v_gtc_ik1] + v_add_lshl_u32 v[v_wei_os_base], v[v_tmp], v[v_tmp+1], 2 + .v_bwd_gtc_wei_update_os v_wei_os, v_wei_os_base, v_wei_iy, v_wei_ix, s_x, v_tmp + + s_lshl_b32 s[s_wei_stride_c], s[s_wei_stride_c], 2 + + ; load weight + .v_gld_1x4_b32_v1 v_gld_a, s_p_wei, v_wei_os, 0, s_wei_stride_c, s_tmp + + ; c thread mapping + ; -> MR x NR x ML1 x NL1 x ML0 x NL0 + ; cluster 1 x 1 x 4 x 4 x 4 x 4 + ; perthrd 2 x 2 x 1 x 1 x 4 x 4 + v_and_b32 v[v_tmp], 3, v0 + v_lshlrev_b32 v[v_tmp], 2, v[v_tmp] ; => iNL0 + v_lshrrev_b32 v[v_tmp+5], 2, v0 + v_and_b32 v[v_tmp+1], 3, v[v_tmp+5] + v_lshlrev_b32 v[v_tmp+1], 2, v[v_tmp+1] ; => iML0 + v_lshrrev_b32 v[v_tmp+5], 2, v[v_tmp+5] + v_and_b32 v[v_tmp+2], 3, v[v_tmp+5] ; => iNL1 + v_lshrrev_b32 v[v_tmp+5], 2, v[v_tmp+5] + v_and_b32 v[v_tmp+3], 3, v[v_tmp+5] ; => iML1 + v_lshl_or_b32 v[v_gemm_in], v[v_tmp+2], 4, v[v_tmp] ; in (without repeat) + v_lshl_or_b32 v[v_gemm_im], v[v_tmp+3], 4, v[v_tmp+1] ; im (without repeat) + + ; remapping, gemm_im => C0xC1:32x4, gemm_in => N0xB0xB1xN1:2x1x16x4 + v_and_b32 v[v_in_in1], 3, v[v_gemm_in] + v_lshrrev_b32 v[v_tmp], 2, v[v_gemm_in] + v_and_b32 v[v_in_ib1], 15, v[v_tmp] + v_lshrrev_b32 v[v_tmp], 4, v[v_tmp] + v_mov_b32 v[v_in_ib0], 0 + v_and_b32 v[v_in_in0], 1, v[v_tmp] + v_lshrrev_b32 v[v_tmp], 1, v[v_tmp] + v_and_b32 v[v_in_ic1], 3, v[v_gemm_im] + v_lshrrev_b32 v[v_tmp], 2, v[v_gemm_im] + v_and_b32 v[v_in_ic0], 31, v[v_tmp] + v_lshrrev_b32 v[v_tmp], 5, v[v_tmp] + + v_lshl_or_b32 v[v_in_ib1], v[v_in_ib0], 4, v[v_in_ib1] + v_lshl_or_b32 v[v_in_in1], v[v_in_in0], 2, v[v_in_in1] + v_lshl_or_b32 v[v_in_ic1], v[v_in_ic0], 2, v[v_in_ic1] + v_add_u32 v[v_in_ib1], s[s_block_gtc_ib], v[v_in_ib1] + + .v_u32_div_vs v_in_dslice_ih, v_in_ib1, s_dslice_w, v_tmp, s_tmp + 
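+ ; the v_mul_lo/v_sub pair below recovers the remainder of the division
+ ; above, in effect: in_dslice_ih = in_ib1 / dslice_w and
+ ; in_dslice_iw = in_ib1 % dslice_w; the h_left/w_left adds then shift
+ ; both indices back into the full dslice window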
v_mul_lo_u32 v[v_tmp+1], s[s_dslice_w], v[v_in_dslice_ih] + v_sub_u32 v[v_in_dslice_iw], v[v_in_ib1], v[v_tmp+1] + v_add_u32 v[v_in_dslice_ih], s[s_dslice_h_left], v[v_in_dslice_ih] + v_add_u32 v[v_in_dslice_iw], s[s_dslice_w_left], v[v_in_dslice_iw] + + ; dslice_h,dslice_y -> hip, dslice_w,dslicw_x -> wip + s_mul_i32 s[s_tmp], s[s_dtile_iy], s[s_dilation_h] + v_mul_lo_u32 v[v_tmp], s[s_stride_h], v[v_in_dslice_ih] + v_add_u32 v[v_tmp], s[s_tmp], v[v_tmp] + s_mul_i32 s[s_tmp+1], s[s_dtile_ix], s[s_dilation_w] + v_mul_lo_u32 v[v_tmp+1], s[s_stride_w], v[v_in_dslice_iw] + v_add_u32 v[v_tmp+1], s[s_tmp+1], v[v_tmp+1] + ; v_tmp: hip, v_tmp+1: wip + + ; hip->h, wip->w + v_sub_i32 v[v_in_ihi], v[v_tmp], s[s_pad_h] + v_sub_i32 v[v_in_iwi], v[v_tmp+1], s[s_pad_w] + + .v_set_flag_hw v_in_flag, v_in_ihi, v_in_iwi, s_hi, s_wi + + ; input offset + s_lshl_b32 s[s_tmp+3], s[s_block_gtc_in], 2 + s_mul_i32 s[s_tmp], s[s_in_stride_n], s[s_tmp+3] + s_mul_hi_u32 s[s_tmp+1], s[s_in_stride_n], s[s_tmp+3] + s_add_u32 s[s_p_in], s[s_p_in], s[s_tmp] + s_addc_u32 s[s_p_in+1], s[s_p_in+1], s[s_tmp+1] + + s_lshl_b32 s[s_tmp+3], s[s_block_gtc_ic], 2 + s_mul_i32 s[s_tmp], s[s_in_stride_c], s[s_tmp+3] + s_mul_hi_u32 s[s_tmp+1], s[s_in_stride_c], s[s_tmp+3] + s_add_u32 s[s_p_in], s[s_p_in], s[s_tmp] + s_addc_u32 s[s_p_in+1], s[s_p_in+1], s[s_tmp+1] + v_mul_lo_u32 v[v_tmp], s[s_in_stride_n], v[v_in_in1] + v_mul_lo_u32 v[v_tmp+1], s[s_in_stride_c], v[v_in_ic1] + v_add_u32 v[v_tmp], v[v_tmp], v[v_tmp+1] + v_mul_lo_u32 v[v_tmp+1], s[s_wi], v[v_in_ihi] + v_add3_u32 v[v_in_os], v[v_tmp], v[v_tmp+1], v[v_in_iwi] + v_lshlrev_b32 v[v_in_os], 2, v[v_in_os] + + s_lshl_b32 s[s_in_stride_n], s[s_in_stride_n], 2 + s_lshl_b32 s[s_in_stride_nr], s[s_in_stride_n], 2 + + s_lshl_b32 s[s_in_stride_c], s[s_in_stride_c], 2 + s_lshl_b32 s[s_in_stride_cr], s[s_in_stride_c], 6 + + v_lshrrev_b32 v[v_gtc_in0], 2, v[v_gtc_in1] + ; LDS store, order out: K0xK1xExN0xB0xB1xN1: 1x1x1x1x1x1x4, 8x1x1x2x1x16x1 + v_lshlrev_b32 v[v_tmp], 2, v[v_gtc_ib1] + v_lshl_or_b32 v[v_tmp], v[v_gtc_in0], 6, v[v_tmp] + v_lshl_or_b32 v[v_tmp+1], v[v_gtc_ik1], 7, v[v_tmp] + v_lshlrev_b32 v[v_sst_b_os], 2, v[v_tmp+1] + + ; LDS store, order wei not shuffled, wei: K0xK1xExC0xC1: 1x1x1x1x4, 8x1x1x32x1 + v_lshl_or_b32 v[v_tmp], v[v_gtc_ik1], 7, v[v_gtc_ic1] + v_lshlrev_b32 v[v_sst_a_os], 2, v[v_tmp] + v_add_u32 v[v_sst_b_os], 4096, v[v_sst_b_os] + + ; LDS load + v_lshlrev_b32 v[v_sld_b_os], 2, v[v_gemm_in] + v_lshlrev_b32 v[v_sld_a_os], 2, v[v_gemm_im] + v_add_u32 v[v_sld_b_os], 4096, v[v_sld_b_os] + + s_mul_i32 s[s_knum], s[s_stride_dslice_yx], s[s_k] + s_lshl_b32 s[s_out_move_slice_stride_k], s[s_out_stride_k], 5 ; 8x + s_lshl_b32 s[s_wei_move_slice_stride_k], s[s_wei_stride_k], 5 ; 8x + s_mov_b32 s[s_move_slice_k_ik], 0 + s_mov_b32 s[s_move_slice_k_idsy], 0 + s_mov_b32 s[s_move_slice_k_idsx], 0 + + .v_clear_nc v_c, 64 + ; start FMA loop, 8x8 thread tile with 4x4 sub-tile + s_waitcnt vmcnt(4) + .v_sst_so0_1x4_b32_v4 v_gld_b, v_sst_b_os + + s_waitcnt vmcnt(0) + .v_sst_so0_1x4_b32_v4 v_gld_a, v_sst_a_os + + s_sub_i32 s[s_kitr], s[s_knum], 8 + s_cmp_gt_i32 s[s_kitr], 0 + s_cbranch_scc0 L_igemm_bwd_gtc_bt128x128x8_tt8x8_gm2x4x4_gn2x4x4_ta1x1x1x1x4_8x1x1x32x1_tb1x1x1x1x4x1x1_8x1x1x2x1x1x16_end + + .s_bwd_gtc_move_slice_window_k_dsy_dsx s_move_slice_k_ik, s_move_slice_k_idsy, s_move_slice_k_idsx, s_dslice_y, s_dslice_x, 8, 1, 1, v_out_os_base, v_wei_os_base, s_out_move_slice_stride_k, s_wei_move_slice_stride_k + .v_bwd_gtc_out_update_hw v_out_iho, v_out_iwo, v_out_dslice_ih, 
v_out_dslice_iw, s_move_slice_k_idsy, s_move_slice_k_idsx, s_dtile_dy, s_dtile_dx, s_tmp + .v_bwd_gtc_out_update_os v_out_os, v_out_os_base, v_out_iho, v_out_iwo, s_wo, v_tmp + .v_set_flag_hw v_out_flag, v_out_iho, v_out_iwo, s_ho, s_wo + .v_bwd_gtc_wei_update_yx v_wei_iy, v_wei_ix, s_move_slice_k_idsy, s_move_slice_k_idsx, s_dtile_y, s_dtile_x, v_dtile_iy, v_dtile_ix, s_tmp + .v_bwd_gtc_wei_update_os v_wei_os, v_wei_os_base, v_wei_iy, v_wei_ix, s_x, v_tmp + v_xor_b32 v[v_sst_b_os], 0x2000, v[v_sst_b_os] ; switch double buffer b store + v_xor_b32 v[v_sst_a_os], 0x2000, v[v_sst_a_os] ; switch double buffer a store + s_waitcnt lgkmcnt(0) + s_barrier + + ; load output + v_cmp_eq_u32 vcc, 1, v[v_out_flag] + s_and_saveexec_b64 s[s_tmp+4:s_tmp+5], vcc + .v_gld_1x4_b32_v1 v_gld_b, s_p_out, v_out_os, 0, s_out_stride_n, s_tmp + s_or_b64 exec, exec, s[s_tmp+4:s_tmp+5] + ; load weight + .v_gld_1x4_b32_v1 v_gld_a, s_p_wei, v_wei_os, 0, s_wei_stride_c, s_tmp + +L_igemm_bwd_gtc_bt128x128x8_tt8x8_gm2x4x4_gn2x4x4_ta1x1x1x1x4_8x1x1x32x1_tb1x1x1x1x4x1x1_8x1x1x2x1x1x16_fma_body: + ; do fma accumulate with unroll 8 + ds_read_b128 v[v_a:v_a+3], v[v_sld_a_os] + ds_read_b128 v[v_b:v_b+3], v[v_sld_b_os] + ds_read_b128 v[v_b+4:v_b+4+3], v[v_sld_b_os] offset:256 + ds_read_b128 v[v_a+4:v_a+4+3], v[v_sld_a_os] offset:256 + .itr_k = 0 + .rept 7 + s_waitcnt lgkmcnt(2) + .v_fma_4x4_s8 v_c,v_a,v_b + s_waitcnt lgkmcnt(1) + .v_fma_4x4_s8 v_c+4,v_a,v_b+4 + ds_read_b128 v[v_a:v_a+3], v[v_sld_a_os] offset:0+(.itr_k+1)*512 + s_waitcnt lgkmcnt(1) + .v_fma_4x4_s8 v_c+32,v_a+4,v_b + ds_read_b128 v[v_b:v_b+3], v[v_sld_b_os] offset:0+(.itr_k+1)*512 + .v_fma_4x4_s8 v_c+36,v_a+4,v_b+4 + + ds_read_b128 v[v_b+4:v_b+4+3], v[v_sld_b_os] offset:0+(.itr_k+1)*512+256 + ds_read_b128 v[v_a+4:v_a+4+3], v[v_sld_a_os] offset:0+(.itr_k+1)*512+256 + .itr_k = .itr_k + 1 + .endr + + ; last unroll + v_xor_b32 v[v_sld_b_os], 8192, v[v_sld_b_os] ; switch double buffer b load + v_xor_b32 v[v_sld_a_os], 8192, v[v_sld_a_os] ; switch double buffer a load + s_waitcnt lgkmcnt(2) + .v_fma_4x4_s8 v_c,v_a,v_b + s_waitcnt lgkmcnt(1) + .v_fma_4x4_s8 v_c+4,v_a,v_b+4 + s_waitcnt vmcnt(4) + .v_sst_so0_1x4_b32_v4 v_gld_b, v_sst_b_os + s_waitcnt vmcnt(0) + .v_sst_so0_1x4_b32_v4 v_gld_a, v_sst_a_os + s_sub_i32 s[s_kitr], s[s_kitr], 8 + s_cmp_gt_i32 s[s_kitr], 0 + s_cbranch_scc0 L_igemm_bwd_gtc_bt128x128x8_tt8x8_gm2x4x4_gn2x4x4_ta1x1x1x1x4_8x1x1x32x1_tb1x1x1x1x4x1x1_8x1x1x2x1x1x16_fma_finishing + .s_bwd_gtc_move_slice_window_k_dsy_dsx s_move_slice_k_ik, s_move_slice_k_idsy, s_move_slice_k_idsx, s_dslice_y, s_dslice_x, 8, 1, 1, v_out_os_base, v_wei_os_base, s_out_move_slice_stride_k, s_wei_move_slice_stride_k + .v_bwd_gtc_out_update_hw v_out_iho, v_out_iwo, v_out_dslice_ih, v_out_dslice_iw, s_move_slice_k_idsy, s_move_slice_k_idsx, s_dtile_dy, s_dtile_dx, s_tmp + .v_bwd_gtc_out_update_os v_out_os, v_out_os_base, v_out_iho, v_out_iwo, s_wo, v_tmp + .v_set_flag_hw v_out_flag, v_out_iho, v_out_iwo, s_ho, s_wo + .v_bwd_gtc_wei_update_yx v_wei_iy, v_wei_ix, s_move_slice_k_idsy, s_move_slice_k_idsx, s_dtile_y, s_dtile_x, v_dtile_iy, v_dtile_ix, s_tmp + .v_bwd_gtc_wei_update_os v_wei_os, v_wei_os_base, v_wei_iy, v_wei_ix, s_x, v_tmp + s_waitcnt lgkmcnt(2) + .v_fma_4x4_s8 v_c+32,v_a+4,v_b + v_xor_b32 v[v_sst_b_os], 8192, v[v_sst_b_os] ; switch double buffer b store + v_xor_b32 v[v_sst_a_os], 8192, v[v_sst_a_os] ; switch double buffer a store + s_waitcnt lgkmcnt(0) + s_barrier + ; load output + v_cmp_eq_u32 vcc, 1, v[v_out_flag] + s_and_saveexec_b64 s[s_tmp+4:s_tmp+5], vcc + 
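+ ; guarded load: lanes whose (iho,iwo) fall outside (ho,wo) carry
+ ; v_out_flag == 0 and were just masked off exec by s_and_saveexec_b64;
+ ; the s_or_b64 after the load macro below restores the saved mask so
+ ; every lane rejoins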
.v_gld_1x4_b32_v1 v_gld_b, s_p_out, v_out_os, 0, s_out_stride_n, s_tmp + s_or_b64 exec, exec, s[s_tmp+4:s_tmp+5] + ; load weight + .v_gld_1x4_b32_v1 v_gld_a, s_p_wei, v_wei_os, 0, s_wei_stride_c, s_tmp + .v_fma_4x4_s8 v_c+36,v_a+4,v_b+4 + + s_branch L_igemm_bwd_gtc_bt128x128x8_tt8x8_gm2x4x4_gn2x4x4_ta1x1x1x1x4_8x1x1x32x1_tb1x1x1x1x4x1x1_8x1x1x2x1x1x16_fma_body +L_igemm_bwd_gtc_bt128x128x8_tt8x8_gm2x4x4_gn2x4x4_ta1x1x1x1x4_8x1x1x32x1_tb1x1x1x1x4x1x1_8x1x1x2x1x1x16_fma_finishing: + s_waitcnt lgkmcnt(2) + .v_fma_4x4_s8 v_c+32,v_a+4,v_b + .v_fma_4x4_s8 v_c+36,v_a+4,v_b+4 +L_igemm_bwd_gtc_bt128x128x8_tt8x8_gm2x4x4_gn2x4x4_ta1x1x1x1x4_8x1x1x32x1_tb1x1x1x1x4x1x1_8x1x1x2x1x1x16_end: + s_waitcnt lgkmcnt(0) + s_barrier + ds_read_b128 v[v_a:v_a+3], v[v_sld_a_os] + ds_read_b128 v[v_b:v_b+3], v[v_sld_b_os] + ds_read_b128 v[v_b+4:v_b+4+3], v[v_sld_b_os] offset:256 + ds_read_b128 v[v_a+4:v_a+4+3], v[v_sld_a_os] offset:256 + .itr_k = 0 + .rept 7 + s_waitcnt lgkmcnt(2) + .v_fma_4x4_s8 v_c,v_a,v_b + s_waitcnt lgkmcnt(1) + .v_fma_4x4_s8 v_c+4,v_a,v_b+4 + ds_read_b128 v[v_a:v_a+3], v[v_sld_a_os] offset:0+(.itr_k+1)*512 + s_waitcnt lgkmcnt(1) + .v_fma_4x4_s8 v_c+32,v_a+4,v_b + ds_read_b128 v[v_b:v_b+3], v[v_sld_b_os] offset:0+(.itr_k+1)*512 + .v_fma_4x4_s8 v_c+36,v_a+4,v_b+4 + + ds_read_b128 v[v_b+4:v_b+4+3], v[v_sld_b_os] offset:0+(.itr_k+1)*512+256 + ds_read_b128 v[v_a+4:v_a+4+3], v[v_sld_a_os] offset:0+(.itr_k+1)*512+256 + .itr_k = .itr_k + 1 + .endr + + ; last unroll + s_waitcnt lgkmcnt(2) + .v_fma_4x4_s8 v_c,v_a,v_b + s_waitcnt lgkmcnt(1) + .v_fma_4x4_s8 v_c+4,v_a,v_b+4 + s_waitcnt lgkmcnt(0) + .v_fma_4x4_s8 v_c+32,v_a+4,v_b + .v_fma_4x4_s8 v_c+36,v_a+4,v_b+4 + + v_cmpx_eq_u32 vcc, 1, v[v_in_flag] + s_cbranch_execz L_igemm_bwd_gtc_bt128x128x8_tt8x8_gm2x4x4_gn2x4x4_ta1x1x1x1x4_8x1x1x32x1_tb1x1x1x1x4x1x1_8x1x1x2x1x1x16_out + s_mov_b32 s[s_tmp], 0 + s_mov_b32 s[s_tmp+1], 0 + s_mov_b32 s[s_tmp+2], 0 + s_mov_b32 s[s_tmp+3], 0 + .v_write4d_strided v_c,s_p_in,v_in_os,s_in_stride_n,s_in_stride_nr,s_in_stride_c,s_in_stride_cr,s_tmp,4,2,4,2 +L_igemm_bwd_gtc_bt128x128x8_tt8x8_gm2x4x4_gn2x4x4_ta1x1x1x1x4_8x1x1x32x1_tb1x1x1x1x4x1x1_8x1x1x2x1x1x16_out: + s_endpgm +.rodata +.p2align 6 +.amdhsa_kernel igemm_bwd_gtc_bt128x128x8_tt8x8_gm2x4x4_gn2x4x4_ta1x1x1x1x4_8x1x1x32x1_tb1x1x1x1x4x1x1_8x1x1x2x1x1x16 + .amdhsa_group_segment_fixed_size 16384 + .amdhsa_user_sgpr_kernarg_segment_ptr 1 + .amdhsa_system_sgpr_workgroup_id_x 1 + .amdhsa_system_vgpr_workitem_id 0 + .amdhsa_next_free_vgpr 114 + .amdhsa_next_free_sgpr 78 + .amdhsa_ieee_mode 0 + .amdhsa_dx10_clamp 0 +.end_amdhsa_kernel + +;---------------------------------------------------------- +; starting of kernel igemm_bwd_gtc_bt64x128x8_tt4x8_gm2x4x4_gn2x4x4_ta1x1x1x1x2_8x1x1x32x1_tb1x1x1x1x4x1x1_8x1x1x2x1x1x16 +; gemm_m_per_block : 64 +; gemm_n_per_block : 128 +; gemm_k_per_block : 8 +; gemm_m_per_thread : 2 +; gemm_m_level0_cluster : 4 +; gemm_m_level1_cluster : 4 +; gemm_n_per_thread : 4 +; gemm_n_level0_cluster : 4 +; gemm_n_level1_cluster : 4 +; tensor_a_thread_lengths : [1, 1, 1, 1, 2] +; tensor_a_cluster_lengths : [8, 1, 1, 32, 1] +; tensor_b_thread_lengths : [1, 1, 1, 1, 4, 1, 1] +; tensor_b_cluster_lengths : [8, 1, 1, 2, 1, 1, 16] +; direction : bwd +; precision : fp32 +; opt_1x1 : 0 +; +; block_size : 256 +; thread_tile : 4x8 +; lds_total : 16384 +; +.set k_p_in, 0 +.set k_p_wei, 8 +.set k_p_out, 16 +.set k_hi, 24 +.set k_wi, 28 +.set k_n, 32 +.set k_k, 36 +.set k_c, 40 +.set k_ho, 44 +.set k_wo, 48 +.set k_stride_h, 52 +.set k_stride_w, 56 +.set k_dilation_h, 60 
+.set k_dilation_w, 64 +.set k_pad_h, 68 +.set k_pad_w, 72 +.set k_y, 76 +.set k_x, 80 +.set k_dtile_iy, 84 +.set k_dtile_ix, 88 +.set k_dtile_dy, 92 +.set k_dtile_dx, 96 +.set k_dtile_y, 100 +.set k_dtile_x, 104 +.set k_dtile_h, 108 +.set k_dtile_w, 112 +.set k_dslice_y, 116 +.set k_dslice_x, 120 +.set k_dslice_h, 124 +.set k_dslice_w, 128 +.set k_dslice_h_left, 132 +.set k_dslice_w_left, 136 +.set k_pack0, 140 +.set k_end, 144 + +.set s_ka, 0 +.set s_bx, 2 +.set s_p_in, 4 +.set s_p_wei, 8 +.set s_p_out, 12 +.set s_hi, 16 +.set s_wi, 17 +.set s_n, 18 +.set s_k, 19 +.set s_c, 20 +.set s_ho, 21 +.set s_wo, 22 +.set s_stride_h, 23 +.set s_stride_w, 24 +.set s_dilation_h, 25 +.set s_dilation_w, 26 +.set s_pad_h, 27 +.set s_pad_w, 28 +.set s_y, 29 +.set s_x, 30 +.set s_dtile_iy, 31 +.set s_dtile_ix, 32 +.set s_dtile_dy, 33 +.set s_dtile_dx, 34 +.set s_dtile_y, 35 +.set s_dtile_x, 36 +.set s_dtile_h, 37 +.set s_dtile_w, 38 +.set s_dslice_y, 39 +.set s_dslice_x, 40 +.set s_dslice_h, 41 +.set s_dslice_w, 42 +.set s_dslice_h_left, 43 +.set s_dslice_w_left, 44 +.set s_out_stride_k, 45 +.set s_out_stride_k0, 46 +.set s_out_stride_n, 47 +.set s_out_stride_n0, 48 +.set s_out_stride_b0, 49 +.set s_out_move_slice_stride_k, 50 +.set s_in_stride_c, 51 +.set s_in_stride_cr, 52 +.set s_in_stride_n, 53 +.set s_in_stride_nr, 54 +.set s_wei_stride_c, 55 +.set s_wei_stride_c0, 56 +.set s_wei_stride_k, 57 +.set s_wei_stride_k0, 58 +.set s_wei_move_slice_stride_k, 59 +.set s_stride_dslice_hw, 60 +.set s_stride_dslice_yx, 61 +.set s_block_gtc_ib, 62 +.set s_block_gtc_ic, 63 +.set s_block_gtc_in, 64 +.set s_knum, 19 +.set s_move_slice_k_idsy, 0 +.set s_move_slice_k_idsx, 1 +.set s_move_slice_k_ik, 2 +.set s_kitr, 3 +.set s_tmp, 66 +.set s_end, 72 + +.set v_c, 0 +.set v_a, 32 +.set v_b, 36 +.set v_gld_a, 44 +.set v_gld_b, 46 +.set v_sst_a_os, 50 +.set v_sst_b_os, 51 +.set v_sld_a_os, 52 +.set v_sld_b_os, 53 +.set v_out_iho, 54 +.set v_out_iwo, 55 +.set v_out_dslice_ih, 56 +.set v_out_dslice_iw, 57 +.set v_out_os, 58 +.set v_out_os_base, 59 +.set v_wei_iy, 60 +.set v_wei_ix, 61 +.set v_dtile_iy, 62 +.set v_dtile_ix, 63 +.set v_wei_os, 64 +.set v_wei_os_base, 65 +.set v_out_flag, 66 +.set v_in_flag, 67 +.set v_in_os, 68 +.set v_gtc_ic0, 31 +.set v_gtc_ic1, 30 +.set v_gtc_in0, 29 +.set v_gtc_in1, 28 +.set v_gtc_ib0, 27 +.set v_gtc_ib1, 26 +.set v_gtc_ik0, 25 +.set v_gtc_ik1, 24 +.set v_gtc_ie, 23 +.set v_gemm_in, 22 +.set v_gemm_im, 21 +.set v_in_in0, 20 +.set v_in_in1, 19 +.set v_in_ib0, 18 +.set v_in_ib1, 17 +.set v_in_ic0, 16 +.set v_in_ic1, 15 +.set v_in_ihi, 14 +.set v_in_iwi, 13 +.set v_in_dslice_ih, 12 +.set v_in_dslice_iw, 11 +.set v_tmp, 70 +.set v_end, 76 + +.text +.globl igemm_bwd_gtc_bt64x128x8_tt4x8_gm2x4x4_gn2x4x4_ta1x1x1x1x2_8x1x1x32x1_tb1x1x1x1x4x1x1_8x1x1x2x1x1x16 +.p2align 8 +.type igemm_bwd_gtc_bt64x128x8_tt4x8_gm2x4x4_gn2x4x4_ta1x1x1x1x2_8x1x1x32x1_tb1x1x1x1x4x1x1_8x1x1x2x1x1x16,@function +igemm_bwd_gtc_bt64x128x8_tt4x8_gm2x4x4_gn2x4x4_ta1x1x1x1x2_8x1x1x32x1_tb1x1x1x1x4x1x1_8x1x1x2x1x1x16: + s_load_dwordx2 s[s_p_in+0:s_p_in+1], s[s_ka+0:s_ka+1], 0+k_p_in + s_load_dwordx2 s[s_p_wei+0:s_p_wei+1], s[s_ka+0:s_ka+1], 0+k_p_wei + s_load_dwordx2 s[s_p_out+0:s_p_out+1], s[s_ka+0:s_ka+1], 0+k_p_out + s_load_dwordx16 s[s_hi+0:s_hi+15], s[s_ka+0:s_ka+1], 0+k_hi + s_load_dwordx8 s[s_dtile_ix+0:s_dtile_ix+7], s[s_ka+0:s_ka+1], 0+k_dtile_ix + s_load_dwordx4 s[s_dslice_x+0:s_dslice_x+3], s[s_ka+0:s_ka+1], 0+k_dslice_x + s_load_dword s[s_dslice_w_left], s[s_ka+0:s_ka+1], 0+k_dslice_w_left + v_mov_b32 v[v_tmp], v0 + 
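+ ; the masks/shifts below peel per-thread tensor indices off workitem id
+ ; v0, in effect: ib1 = tid & 15, in0 = (tid >> 4) & 1, ik0 = (tid >> 5) & 7,
+ ; matching the output-side cluster lengths 8x1x1x2x1x1x16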
; output: K0xK1xExN0xN1xB0xB1: 1x1x1x1x4x1x1, slice 8x1x1x2x1x1x16 + v_and_b32 v[v_gtc_ib1], 15, v[v_tmp] + v_lshrrev_b32 v[v_tmp], 4, v[v_tmp] + v_and_b32 v[v_gtc_in0], 1, v[v_tmp] + v_lshrrev_b32 v[v_tmp], 1, v[v_tmp] + v_and_b32 v[v_gtc_ik0], 7, v[v_tmp] + + v_mov_b32 v[v_tmp], v0 + ; wei: K0xK1xExC0xC1: 1x1x1x1x2, slice 8x1x1x32x1 + v_and_b32 v[v_gtc_ic0], 31, v[v_tmp] + + s_mov_b32 s[s_p_in + 2], 0xffffffff + s_mov_b32 s[s_p_in + 3], 0x27000 + s_mov_b32 s[s_p_wei + 2], 0xffffffff + s_mov_b32 s[s_p_wei + 3], 0x27000 + s_mov_b32 s[s_p_out + 2], 0xffffffff + s_mov_b32 s[s_p_out + 3], 0x27000 + s_waitcnt lgkmcnt(0) + + ; calculate index + s_mul_i32 s[s_out_stride_k], s[s_ho], s[s_wo] + s_mul_i32 s[s_out_stride_n], s[s_k], s[s_out_stride_k] + s_mul_i32 s[s_in_stride_c], s[s_hi], s[s_wi] + s_mul_i32 s[s_in_stride_n], s[s_c], s[s_in_stride_c] + s_mul_i32 s[s_wei_stride_c], s[s_y], s[s_x] + s_mul_i32 s[s_wei_stride_k], s[s_c], s[s_wei_stride_c] + s_mul_i32 s[s_stride_dslice_hw], s[s_dslice_h], s[s_dslice_w] + s_mul_i32 s[s_stride_dslice_yx], s[s_dslice_y], s[s_dslice_x] + + ; N0xN1xB0xB1, gemm_m_per_block:64, gemm_n_per_block:128 + s_mul_i32 s[s_tmp], s[s_stride_dslice_hw], s[s_n] + s_lshr_b32 s[0], s[s_tmp], 7 + .v_u32_div_ss v_tmp+5, s_bx, 0, v_tmp, s_tmp + v_readfirstlane_b32 s[s_tmp+4], v[v_tmp+5] ; gemm_m, + s_mul_i32 s[s_tmp+2], s[s_tmp+4], s[0] + s_sub_i32 s[s_tmp+5], s[s_bx], s[s_tmp+2] ; gemm_n, cnt + s_lshl_b32 s[s_block_gtc_ic], s[s_tmp+4], 6 + s_lshr_b32 s[0], s[s_stride_dslice_hw], 4 ; B:16 per block, total num of B + .v_u32_div_ss v_tmp+5, s_tmp+5, 0, v_tmp, s_tmp + v_readfirstlane_b32 s[s_tmp], v[v_tmp+5] ; => N + s_mul_i32 s[s_tmp+2], s[s_tmp], s[0] + s_sub_i32 s[s_tmp+1], s[s_tmp+5], s[s_tmp+2] ; => B + s_lshl_b32 s[s_block_gtc_in], s[s_tmp], 3 + s_lshl_b32 s[s_block_gtc_ib], s[s_tmp+1], 4 + + v_lshlrev_b32 v[v_gtc_in1], 2, v[v_gtc_in0] + v_mov_b32 v[v_gtc_ik1], v[v_gtc_ik0] + + v_add_u32 v[v_tmp+5], s[s_block_gtc_ib], v[v_gtc_ib1] + ; calculate output transform, B -> dslice_h*dslice_w. 
always use ib1 + .v_u32_div_vs v_out_dslice_ih, v_tmp+5, s_dslice_w, v_tmp, s_tmp + v_mul_lo_u32 v[v_tmp], s[s_dslice_w], v[v_out_dslice_ih] + v_sub_u32 v[v_out_dslice_iw], v[v_tmp+5], v[v_tmp] + ; iHTildaLeft, iWTildaLeft + v_add_u32 v[v_out_dslice_ih], s[s_dslice_h_left], v[v_out_dslice_ih] + v_add_u32 v[v_out_dslice_iw], s[s_dslice_w_left], v[v_out_dslice_iw] + ; dslice_y,dslice_h -> oh, dslice_x,dslice_w -> ow + v_mov_b32 v[v_out_iho], v[v_out_dslice_ih] + v_mov_b32 v[v_out_iwo], v[v_out_dslice_iw] + v_mul_lo_u32 v[v_tmp], s[s_out_stride_k], v[v_gtc_ik1] + v_mul_lo_u32 v[v_tmp+1], s[s_out_stride_n], v[v_gtc_in1] + v_add_lshl_u32 v[v_out_os_base], v[v_tmp], v[v_tmp+1], 2 + ; n to staticly accumulate into base pointer + s_lshl_b32 s[s_tmp+3], s[s_block_gtc_in], 2 + s_mul_i32 s[s_tmp], s[s_out_stride_n], s[s_tmp+3] + s_mul_hi_u32 s[s_tmp+1], s[s_out_stride_n], s[s_tmp+3] + s_add_u32 s[s_p_out], s[s_p_out], s[s_tmp] + s_addc_u32 s[s_p_out+1], s[s_p_out+1], s[s_tmp+1] + + .v_bwd_gtc_out_update_os v_out_os, v_out_os_base, v_out_iho, v_out_iwo, s_wo, v_tmp + .v_set_flag_hw v_out_flag, v_out_iho, v_out_iwo, s_ho, s_wo + + s_lshl_b32 s[s_out_stride_n], s[s_out_stride_n], 2 + + ; load output + v_cmp_eq_u32 vcc, 1, v[v_out_flag] + s_and_saveexec_b64 s[s_tmp+4:s_tmp+5], vcc + .v_gld_1x4_b32_v1 v_gld_b, s_p_out, v_out_os, 0, s_out_stride_n, s_tmp + s_or_b64 exec, exec, s[s_tmp+4:s_tmp+5] + + v_lshlrev_b32 v[v_gtc_ic1], 1, v[v_gtc_ic0] + v_add_u32 v[v_tmp+5], s[s_block_gtc_ic], v[v_gtc_ic1] + v_mov_b32 v[v_dtile_iy], s[s_dtile_iy] + v_mov_b32 v[v_dtile_ix], s[s_dtile_ix] + v_mov_b32 v[v_wei_iy], s[s_dtile_iy] + v_mov_b32 v[v_wei_ix], s[s_dtile_ix] + + ; calculate wei offset + v_mul_lo_u32 v[v_tmp], s[s_wei_stride_c], v[v_tmp+5] + v_mul_lo_u32 v[v_tmp+1], s[s_wei_stride_k], v[v_gtc_ik1] + v_add_lshl_u32 v[v_wei_os_base], v[v_tmp], v[v_tmp+1], 2 + .v_bwd_gtc_wei_update_os v_wei_os, v_wei_os_base, v_wei_iy, v_wei_ix, s_x, v_tmp + + s_lshl_b32 s[s_wei_stride_c], s[s_wei_stride_c], 2 + + ; load weight + .v_gld_1x2_b32_v1 v_gld_a, s_p_wei, v_wei_os, 0, s_wei_stride_c, s_tmp + + ; c thread mapping + ; -> MR x NR x ML1 x NL1 x ML0 x NL0 + ; cluster 1 x 1 x 4 x 4 x 4 x 4 + ; perthrd 2 x 2 x 1 x 1 x 2 x 4 + v_and_b32 v[v_tmp], 3, v0 + v_lshlrev_b32 v[v_tmp], 2, v[v_tmp] ; => iNL0 + v_lshrrev_b32 v[v_tmp+5], 2, v0 + v_and_b32 v[v_tmp+1], 3, v[v_tmp+5] + v_lshlrev_b32 v[v_tmp+1], 1, v[v_tmp+1] ; => iML0 + v_lshrrev_b32 v[v_tmp+5], 2, v[v_tmp+5] + v_and_b32 v[v_tmp+2], 3, v[v_tmp+5] ; => iNL1 + v_lshrrev_b32 v[v_tmp+5], 2, v[v_tmp+5] + v_and_b32 v[v_tmp+3], 3, v[v_tmp+5] ; => iML1 + v_lshl_or_b32 v[v_gemm_in], v[v_tmp+2], 4, v[v_tmp] ; in (without repeat) + v_lshl_or_b32 v[v_gemm_im], v[v_tmp+3], 3, v[v_tmp+1] ; im (without repeat) + + ; remapping, gemm_im => C0xC1:32x2, gemm_in => N0xB0xB1xN1:2x1x16x4 + v_and_b32 v[v_in_in1], 3, v[v_gemm_in] + v_lshrrev_b32 v[v_tmp], 2, v[v_gemm_in] + v_and_b32 v[v_in_ib1], 15, v[v_tmp] + v_lshrrev_b32 v[v_tmp], 4, v[v_tmp] + v_mov_b32 v[v_in_ib0], 0 + v_and_b32 v[v_in_in0], 1, v[v_tmp] + v_lshrrev_b32 v[v_tmp], 1, v[v_tmp] + v_and_b32 v[v_in_ic1], 1, v[v_gemm_im] + v_lshrrev_b32 v[v_tmp], 1, v[v_gemm_im] + v_and_b32 v[v_in_ic0], 31, v[v_tmp] + v_lshrrev_b32 v[v_tmp], 5, v[v_tmp] + + v_lshl_or_b32 v[v_in_ib1], v[v_in_ib0], 4, v[v_in_ib1] + v_lshl_or_b32 v[v_in_in1], v[v_in_in0], 2, v[v_in_in1] + v_lshl_or_b32 v[v_in_ic1], v[v_in_ic0], 1, v[v_in_ic1] + v_add_u32 v[v_in_ib1], s[s_block_gtc_ib], v[v_in_ib1] + + .v_u32_div_vs v_in_dslice_ih, v_in_ib1, s_dslice_w, v_tmp, s_tmp + 
v_mul_lo_u32 v[v_tmp+1], s[s_dslice_w], v[v_in_dslice_ih] + v_sub_u32 v[v_in_dslice_iw], v[v_in_ib1], v[v_tmp+1] + v_add_u32 v[v_in_dslice_ih], s[s_dslice_h_left], v[v_in_dslice_ih] + v_add_u32 v[v_in_dslice_iw], s[s_dslice_w_left], v[v_in_dslice_iw] + + ; dslice_h,dslice_y -> hip, dslice_w,dslicw_x -> wip + s_mul_i32 s[s_tmp], s[s_dtile_iy], s[s_dilation_h] + v_mul_lo_u32 v[v_tmp], s[s_stride_h], v[v_in_dslice_ih] + v_add_u32 v[v_tmp], s[s_tmp], v[v_tmp] + s_mul_i32 s[s_tmp+1], s[s_dtile_ix], s[s_dilation_w] + v_mul_lo_u32 v[v_tmp+1], s[s_stride_w], v[v_in_dslice_iw] + v_add_u32 v[v_tmp+1], s[s_tmp+1], v[v_tmp+1] + ; v_tmp: hip, v_tmp+1: wip + + ; hip->h, wip->w + v_sub_i32 v[v_in_ihi], v[v_tmp], s[s_pad_h] + v_sub_i32 v[v_in_iwi], v[v_tmp+1], s[s_pad_w] + + .v_set_flag_hw v_in_flag, v_in_ihi, v_in_iwi, s_hi, s_wi + + ; input offset + s_lshl_b32 s[s_tmp+3], s[s_block_gtc_in], 2 + s_mul_i32 s[s_tmp], s[s_in_stride_n], s[s_tmp+3] + s_mul_hi_u32 s[s_tmp+1], s[s_in_stride_n], s[s_tmp+3] + s_add_u32 s[s_p_in], s[s_p_in], s[s_tmp] + s_addc_u32 s[s_p_in+1], s[s_p_in+1], s[s_tmp+1] + + s_lshl_b32 s[s_tmp+3], s[s_block_gtc_ic], 2 + s_mul_i32 s[s_tmp], s[s_in_stride_c], s[s_tmp+3] + s_mul_hi_u32 s[s_tmp+1], s[s_in_stride_c], s[s_tmp+3] + s_add_u32 s[s_p_in], s[s_p_in], s[s_tmp] + s_addc_u32 s[s_p_in+1], s[s_p_in+1], s[s_tmp+1] + v_mul_lo_u32 v[v_tmp], s[s_in_stride_n], v[v_in_in1] + v_mul_lo_u32 v[v_tmp+1], s[s_in_stride_c], v[v_in_ic1] + v_add_u32 v[v_tmp], v[v_tmp], v[v_tmp+1] + v_mul_lo_u32 v[v_tmp+1], s[s_wi], v[v_in_ihi] + v_add3_u32 v[v_in_os], v[v_tmp], v[v_tmp+1], v[v_in_iwi] + v_lshlrev_b32 v[v_in_os], 2, v[v_in_os] + + s_lshl_b32 s[s_in_stride_n], s[s_in_stride_n], 2 + s_lshl_b32 s[s_in_stride_nr], s[s_in_stride_n], 2 + + s_lshl_b32 s[s_in_stride_c], s[s_in_stride_c], 2 + s_lshl_b32 s[s_in_stride_cr], s[s_in_stride_c], 5 + + v_lshrrev_b32 v[v_gtc_in0], 2, v[v_gtc_in1] + ; LDS store, order out: K0xK1xExN0xB0xB1xN1: 1x1x1x1x1x1x4, 8x1x1x2x1x16x1 + v_lshlrev_b32 v[v_tmp], 2, v[v_gtc_ib1] + v_lshl_or_b32 v[v_tmp], v[v_gtc_in0], 6, v[v_tmp] + v_lshl_or_b32 v[v_tmp+1], v[v_gtc_ik1], 7, v[v_tmp] + v_lshlrev_b32 v[v_sst_b_os], 2, v[v_tmp+1] + + ; LDS store, order wei not shuffled, wei: K0xK1xExC0xC1: 1x1x1x1x2, 8x1x1x32x1 + v_lshl_or_b32 v[v_tmp], v[v_gtc_ik1], 6, v[v_gtc_ic1] + v_lshlrev_b32 v[v_sst_a_os], 2, v[v_tmp] + v_add_u32 v[v_sst_b_os], 2048, v[v_sst_b_os] + + ; LDS load + v_lshlrev_b32 v[v_sld_b_os], 2, v[v_gemm_in] + v_lshlrev_b32 v[v_sld_a_os], 2, v[v_gemm_im] + v_add_u32 v[v_sld_b_os], 2048, v[v_sld_b_os] + + s_mul_i32 s[s_knum], s[s_stride_dslice_yx], s[s_k] + s_lshl_b32 s[s_out_move_slice_stride_k], s[s_out_stride_k], 5 ; 8x + s_lshl_b32 s[s_wei_move_slice_stride_k], s[s_wei_stride_k], 5 ; 8x + s_mov_b32 s[s_move_slice_k_ik], 0 + s_mov_b32 s[s_move_slice_k_idsy], 0 + s_mov_b32 s[s_move_slice_k_idsx], 0 + + .v_clear_nc v_c, 32 + ; start FMA loop, 4x8 thread tile with 2x4 sub-tile + s_waitcnt vmcnt(2) + .v_sst_so0_1x4_b32_v4 v_gld_b, v_sst_b_os + + s_waitcnt vmcnt(0) + .v_sst_so0_1x2_b32_v2 v_gld_a, v_sst_a_os + + s_sub_i32 s[s_kitr], s[s_knum], 8 + s_cmp_gt_i32 s[s_kitr], 0 + s_cbranch_scc0 L_igemm_bwd_gtc_bt64x128x8_tt4x8_gm2x4x4_gn2x4x4_ta1x1x1x1x2_8x1x1x32x1_tb1x1x1x1x4x1x1_8x1x1x2x1x1x16_end + + .s_bwd_gtc_move_slice_window_k_dsy_dsx s_move_slice_k_ik, s_move_slice_k_idsy, s_move_slice_k_idsx, s_dslice_y, s_dslice_x, 8, 1, 1, v_out_os_base, v_wei_os_base, s_out_move_slice_stride_k, s_wei_move_slice_stride_k + .v_bwd_gtc_out_update_hw v_out_iho, v_out_iwo, v_out_dslice_ih, 
v_out_dslice_iw, s_move_slice_k_idsy, s_move_slice_k_idsx, s_dtile_dy, s_dtile_dx, s_tmp + .v_bwd_gtc_out_update_os v_out_os, v_out_os_base, v_out_iho, v_out_iwo, s_wo, v_tmp + .v_set_flag_hw v_out_flag, v_out_iho, v_out_iwo, s_ho, s_wo + .v_bwd_gtc_wei_update_yx v_wei_iy, v_wei_ix, s_move_slice_k_idsy, s_move_slice_k_idsx, s_dtile_y, s_dtile_x, v_dtile_iy, v_dtile_ix, s_tmp + .v_bwd_gtc_wei_update_os v_wei_os, v_wei_os_base, v_wei_iy, v_wei_ix, s_x, v_tmp + v_xor_b32 v[v_sst_b_os], 0x2000, v[v_sst_b_os] ; switch double buffer b store + v_xor_b32 v[v_sst_a_os], 0x2000, v[v_sst_a_os] ; switch double buffer a store + s_waitcnt lgkmcnt(0) + s_barrier + + ; load output + v_cmp_eq_u32 vcc, 1, v[v_out_flag] + s_and_saveexec_b64 s[s_tmp+4:s_tmp+5], vcc + .v_gld_1x4_b32_v1 v_gld_b, s_p_out, v_out_os, 0, s_out_stride_n, s_tmp + s_or_b64 exec, exec, s[s_tmp+4:s_tmp+5] + ; load weight + .v_gld_1x2_b32_v1 v_gld_a, s_p_wei, v_wei_os, 0, s_wei_stride_c, s_tmp + +L_igemm_bwd_gtc_bt64x128x8_tt4x8_gm2x4x4_gn2x4x4_ta1x1x1x1x2_8x1x1x32x1_tb1x1x1x1x4x1x1_8x1x1x2x1x1x16_fma_body: + ; do fma accumulate with unroll 8 + ds_read_b64 v[v_a:v_a+1], v[v_sld_a_os] + ds_read_b128 v[v_b:v_b+3], v[v_sld_b_os] + ds_read_b128 v[v_b+4:v_b+4+3], v[v_sld_b_os] offset:256 + ds_read_b64 v[v_a+2:v_a+2+1], v[v_sld_a_os] offset:128 + .itr_k = 0 + .rept 7 + s_waitcnt lgkmcnt(2) + .v_fma_2x4_s8 v_c,v_a,v_b + s_waitcnt lgkmcnt(1) + .v_fma_2x4_s8 v_c+4,v_a,v_b+4 + ds_read_b64 v[v_a:v_a+1], v[v_sld_a_os] offset:0+(.itr_k+1)*256 + s_waitcnt lgkmcnt(1) + .v_fma_2x4_s8 v_c+16,v_a+2,v_b + ds_read_b128 v[v_b:v_b+3], v[v_sld_b_os] offset:0+(.itr_k+1)*512 + .v_fma_2x4_s8 v_c+20,v_a+2,v_b+4 + + ds_read_b128 v[v_b+4:v_b+4+3], v[v_sld_b_os] offset:0+(.itr_k+1)*512+256 + ds_read_b64 v[v_a+2:v_a+2+1], v[v_sld_a_os] offset:0+(.itr_k+1)*256+128 + .itr_k = .itr_k + 1 + .endr + + ; last unroll + v_xor_b32 v[v_sld_b_os], 8192, v[v_sld_b_os] ; switch double buffer b load + v_xor_b32 v[v_sld_a_os], 8192, v[v_sld_a_os] ; switch double buffer a load + s_waitcnt lgkmcnt(2) + .v_fma_2x4_s8 v_c,v_a,v_b + s_waitcnt lgkmcnt(1) + .v_fma_2x4_s8 v_c+4,v_a,v_b+4 + s_waitcnt vmcnt(2) + .v_sst_so0_1x4_b32_v4 v_gld_b, v_sst_b_os + s_waitcnt vmcnt(0) + .v_sst_so0_1x2_b32_v2 v_gld_a, v_sst_a_os + s_sub_i32 s[s_kitr], s[s_kitr], 8 + s_cmp_gt_i32 s[s_kitr], 0 + s_cbranch_scc0 L_igemm_bwd_gtc_bt64x128x8_tt4x8_gm2x4x4_gn2x4x4_ta1x1x1x1x2_8x1x1x32x1_tb1x1x1x1x4x1x1_8x1x1x2x1x1x16_fma_finishing + .s_bwd_gtc_move_slice_window_k_dsy_dsx s_move_slice_k_ik, s_move_slice_k_idsy, s_move_slice_k_idsx, s_dslice_y, s_dslice_x, 8, 1, 1, v_out_os_base, v_wei_os_base, s_out_move_slice_stride_k, s_wei_move_slice_stride_k + .v_bwd_gtc_out_update_hw v_out_iho, v_out_iwo, v_out_dslice_ih, v_out_dslice_iw, s_move_slice_k_idsy, s_move_slice_k_idsx, s_dtile_dy, s_dtile_dx, s_tmp + .v_bwd_gtc_out_update_os v_out_os, v_out_os_base, v_out_iho, v_out_iwo, s_wo, v_tmp + .v_set_flag_hw v_out_flag, v_out_iho, v_out_iwo, s_ho, s_wo + .v_bwd_gtc_wei_update_yx v_wei_iy, v_wei_ix, s_move_slice_k_idsy, s_move_slice_k_idsx, s_dtile_y, s_dtile_x, v_dtile_iy, v_dtile_ix, s_tmp + .v_bwd_gtc_wei_update_os v_wei_os, v_wei_os_base, v_wei_iy, v_wei_ix, s_x, v_tmp + s_waitcnt lgkmcnt(2) + .v_fma_2x4_s8 v_c+16,v_a+2,v_b + v_xor_b32 v[v_sst_b_os], 8192, v[v_sst_b_os] ; switch double buffer b store + v_xor_b32 v[v_sst_a_os], 8192, v[v_sst_a_os] ; switch double buffer a store + s_waitcnt lgkmcnt(0) + s_barrier + ; load output + v_cmp_eq_u32 vcc, 1, v[v_out_flag] + s_and_saveexec_b64 s[s_tmp+4:s_tmp+5], vcc + 
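+ ; steady state of the loop: the guarded output load below and the weight
+ ; load after it prefetch the next k-slice from global memory while the
+ ; remaining .v_fma_2x4_s8 of the current slice drain; the 8192-byte xors
+ ; above retarget the LDS store pointers at the other half of the double
+ ; buffer for that incoming data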
.v_gld_1x4_b32_v1 v_gld_b, s_p_out, v_out_os, 0, s_out_stride_n, s_tmp + s_or_b64 exec, exec, s[s_tmp+4:s_tmp+5] + ; load weight + .v_gld_1x2_b32_v1 v_gld_a, s_p_wei, v_wei_os, 0, s_wei_stride_c, s_tmp + .v_fma_2x4_s8 v_c+20,v_a+2,v_b+4 + + s_branch L_igemm_bwd_gtc_bt64x128x8_tt4x8_gm2x4x4_gn2x4x4_ta1x1x1x1x2_8x1x1x32x1_tb1x1x1x1x4x1x1_8x1x1x2x1x1x16_fma_body +L_igemm_bwd_gtc_bt64x128x8_tt4x8_gm2x4x4_gn2x4x4_ta1x1x1x1x2_8x1x1x32x1_tb1x1x1x1x4x1x1_8x1x1x2x1x1x16_fma_finishing: + s_waitcnt lgkmcnt(2) + .v_fma_2x4_s8 v_c+16,v_a+2,v_b + .v_fma_2x4_s8 v_c+20,v_a+2,v_b+4 +L_igemm_bwd_gtc_bt64x128x8_tt4x8_gm2x4x4_gn2x4x4_ta1x1x1x1x2_8x1x1x32x1_tb1x1x1x1x4x1x1_8x1x1x2x1x1x16_end: + s_waitcnt lgkmcnt(0) + s_barrier + ds_read_b64 v[v_a:v_a+1], v[v_sld_a_os] + ds_read_b128 v[v_b:v_b+3], v[v_sld_b_os] + ds_read_b128 v[v_b+4:v_b+4+3], v[v_sld_b_os] offset:256 + ds_read_b64 v[v_a+2:v_a+2+1], v[v_sld_a_os] offset:128 + .itr_k = 0 + .rept 7 + s_waitcnt lgkmcnt(2) + .v_fma_2x4_s8 v_c,v_a,v_b + s_waitcnt lgkmcnt(1) + .v_fma_2x4_s8 v_c+4,v_a,v_b+4 + ds_read_b64 v[v_a:v_a+1], v[v_sld_a_os] offset:0+(.itr_k+1)*256 + s_waitcnt lgkmcnt(1) + .v_fma_2x4_s8 v_c+16,v_a+2,v_b + ds_read_b128 v[v_b:v_b+3], v[v_sld_b_os] offset:0+(.itr_k+1)*512 + .v_fma_2x4_s8 v_c+20,v_a+2,v_b+4 + + ds_read_b128 v[v_b+4:v_b+4+3], v[v_sld_b_os] offset:0+(.itr_k+1)*512+256 + ds_read_b64 v[v_a+2:v_a+2+1], v[v_sld_a_os] offset:0+(.itr_k+1)*256+128 + .itr_k = .itr_k + 1 + .endr + + ; last unroll + s_waitcnt lgkmcnt(2) + .v_fma_2x4_s8 v_c,v_a,v_b + s_waitcnt lgkmcnt(1) + .v_fma_2x4_s8 v_c+4,v_a,v_b+4 + s_waitcnt lgkmcnt(0) + .v_fma_2x4_s8 v_c+16,v_a+2,v_b + .v_fma_2x4_s8 v_c+20,v_a+2,v_b+4 + + v_cmpx_eq_u32 vcc, 1, v[v_in_flag] + s_cbranch_execz L_igemm_bwd_gtc_bt64x128x8_tt4x8_gm2x4x4_gn2x4x4_ta1x1x1x1x2_8x1x1x32x1_tb1x1x1x1x4x1x1_8x1x1x2x1x1x16_out + s_mov_b32 s[s_tmp], 0 + s_mov_b32 s[s_tmp+1], 0 + s_mov_b32 s[s_tmp+2], 0 + s_mov_b32 s[s_tmp+3], 0 + .v_write4d_strided v_c,s_p_in,v_in_os,s_in_stride_n,s_in_stride_nr,s_in_stride_c,s_in_stride_cr,s_tmp,4,2,2,2 +L_igemm_bwd_gtc_bt64x128x8_tt4x8_gm2x4x4_gn2x4x4_ta1x1x1x1x2_8x1x1x32x1_tb1x1x1x1x4x1x1_8x1x1x2x1x1x16_out: + s_endpgm +.rodata +.p2align 6 +.amdhsa_kernel igemm_bwd_gtc_bt64x128x8_tt4x8_gm2x4x4_gn2x4x4_ta1x1x1x1x2_8x1x1x32x1_tb1x1x1x1x4x1x1_8x1x1x2x1x1x16 + .amdhsa_group_segment_fixed_size 16384 + .amdhsa_user_sgpr_kernarg_segment_ptr 1 + .amdhsa_system_sgpr_workgroup_id_x 1 + .amdhsa_system_vgpr_workitem_id 0 + .amdhsa_next_free_vgpr 76 + .amdhsa_next_free_sgpr 78 + .amdhsa_ieee_mode 0 + .amdhsa_dx10_clamp 0 +.end_amdhsa_kernel + +;---------------------------------------------------------- +; starting of kernel igemm_bwd_gtc_bt64x128x8_tt4x8_gm2x4x4_gn2x4x4_ta1x1x1x1x2_8x1x1x32x1_tb1x1x1x1x4x1x1_8x1x1x1x1x1x32 +; gemm_m_per_block : 64 +; gemm_n_per_block : 128 +; gemm_k_per_block : 8 +; gemm_m_per_thread : 2 +; gemm_m_level0_cluster : 4 +; gemm_m_level1_cluster : 4 +; gemm_n_per_thread : 4 +; gemm_n_level0_cluster : 4 +; gemm_n_level1_cluster : 4 +; tensor_a_thread_lengths : [1, 1, 1, 1, 2] +; tensor_a_cluster_lengths : [8, 1, 1, 32, 1] +; tensor_b_thread_lengths : [1, 1, 1, 1, 4, 1, 1] +; tensor_b_cluster_lengths : [8, 1, 1, 1, 1, 1, 32] +; direction : bwd +; precision : fp32 +; opt_1x1 : 0 +; +; block_size : 256 +; thread_tile : 4x8 +; lds_total : 16384 +; +.set k_p_in, 0 +.set k_p_wei, 8 +.set k_p_out, 16 +.set k_hi, 24 +.set k_wi, 28 +.set k_n, 32 +.set k_k, 36 +.set k_c, 40 +.set k_ho, 44 +.set k_wo, 48 +.set k_stride_h, 52 +.set k_stride_w, 56 +.set k_dilation_h, 60 +.set 
k_dilation_w, 64 +.set k_pad_h, 68 +.set k_pad_w, 72 +.set k_y, 76 +.set k_x, 80 +.set k_dtile_iy, 84 +.set k_dtile_ix, 88 +.set k_dtile_dy, 92 +.set k_dtile_dx, 96 +.set k_dtile_y, 100 +.set k_dtile_x, 104 +.set k_dtile_h, 108 +.set k_dtile_w, 112 +.set k_dslice_y, 116 +.set k_dslice_x, 120 +.set k_dslice_h, 124 +.set k_dslice_w, 128 +.set k_dslice_h_left, 132 +.set k_dslice_w_left, 136 +.set k_pack0, 140 +.set k_end, 144 + +.set s_ka, 0 +.set s_bx, 2 +.set s_p_in, 4 +.set s_p_wei, 8 +.set s_p_out, 12 +.set s_hi, 16 +.set s_wi, 17 +.set s_n, 18 +.set s_k, 19 +.set s_c, 20 +.set s_ho, 21 +.set s_wo, 22 +.set s_stride_h, 23 +.set s_stride_w, 24 +.set s_dilation_h, 25 +.set s_dilation_w, 26 +.set s_pad_h, 27 +.set s_pad_w, 28 +.set s_y, 29 +.set s_x, 30 +.set s_dtile_iy, 31 +.set s_dtile_ix, 32 +.set s_dtile_dy, 33 +.set s_dtile_dx, 34 +.set s_dtile_y, 35 +.set s_dtile_x, 36 +.set s_dtile_h, 37 +.set s_dtile_w, 38 +.set s_dslice_y, 39 +.set s_dslice_x, 40 +.set s_dslice_h, 41 +.set s_dslice_w, 42 +.set s_dslice_h_left, 43 +.set s_dslice_w_left, 44 +.set s_out_stride_k, 45 +.set s_out_stride_k0, 46 +.set s_out_stride_n, 47 +.set s_out_stride_n0, 48 +.set s_out_stride_b0, 49 +.set s_out_move_slice_stride_k, 50 +.set s_in_stride_c, 51 +.set s_in_stride_cr, 52 +.set s_in_stride_n, 53 +.set s_in_stride_nr, 54 +.set s_wei_stride_c, 55 +.set s_wei_stride_c0, 56 +.set s_wei_stride_k, 57 +.set s_wei_stride_k0, 58 +.set s_wei_move_slice_stride_k, 59 +.set s_stride_dslice_hw, 60 +.set s_stride_dslice_yx, 61 +.set s_block_gtc_ib, 62 +.set s_block_gtc_ic, 63 +.set s_block_gtc_in, 64 +.set s_knum, 19 +.set s_move_slice_k_idsy, 0 +.set s_move_slice_k_idsx, 1 +.set s_move_slice_k_ik, 2 +.set s_kitr, 3 +.set s_tmp, 66 +.set s_end, 72 + +.set v_c, 0 +.set v_a, 32 +.set v_b, 36 +.set v_gld_a, 44 +.set v_gld_b, 46 +.set v_sst_a_os, 50 +.set v_sst_b_os, 51 +.set v_sld_a_os, 52 +.set v_sld_b_os, 53 +.set v_out_iho, 54 +.set v_out_iwo, 55 +.set v_out_dslice_ih, 56 +.set v_out_dslice_iw, 57 +.set v_out_os, 58 +.set v_out_os_base, 59 +.set v_wei_iy, 60 +.set v_wei_ix, 61 +.set v_dtile_iy, 62 +.set v_dtile_ix, 63 +.set v_wei_os, 64 +.set v_wei_os_base, 65 +.set v_out_flag, 66 +.set v_in_flag, 67 +.set v_in_os, 68 +.set v_gtc_ic0, 31 +.set v_gtc_ic1, 30 +.set v_gtc_in0, 29 +.set v_gtc_in1, 28 +.set v_gtc_ib0, 27 +.set v_gtc_ib1, 26 +.set v_gtc_ik0, 25 +.set v_gtc_ik1, 24 +.set v_gtc_ie, 23 +.set v_gemm_in, 22 +.set v_gemm_im, 21 +.set v_in_in0, 20 +.set v_in_in1, 19 +.set v_in_ib0, 18 +.set v_in_ib1, 17 +.set v_in_ic0, 16 +.set v_in_ic1, 15 +.set v_in_ihi, 14 +.set v_in_iwi, 13 +.set v_in_dslice_ih, 12 +.set v_in_dslice_iw, 11 +.set v_tmp, 70 +.set v_end, 76 + +.text +.globl igemm_bwd_gtc_bt64x128x8_tt4x8_gm2x4x4_gn2x4x4_ta1x1x1x1x2_8x1x1x32x1_tb1x1x1x1x4x1x1_8x1x1x1x1x1x32 +.p2align 8 +.type igemm_bwd_gtc_bt64x128x8_tt4x8_gm2x4x4_gn2x4x4_ta1x1x1x1x2_8x1x1x32x1_tb1x1x1x1x4x1x1_8x1x1x1x1x1x32,@function +igemm_bwd_gtc_bt64x128x8_tt4x8_gm2x4x4_gn2x4x4_ta1x1x1x1x2_8x1x1x32x1_tb1x1x1x1x4x1x1_8x1x1x1x1x1x32: + s_load_dwordx2 s[s_p_in+0:s_p_in+1], s[s_ka+0:s_ka+1], 0+k_p_in + s_load_dwordx2 s[s_p_wei+0:s_p_wei+1], s[s_ka+0:s_ka+1], 0+k_p_wei + s_load_dwordx2 s[s_p_out+0:s_p_out+1], s[s_ka+0:s_ka+1], 0+k_p_out + s_load_dwordx16 s[s_hi+0:s_hi+15], s[s_ka+0:s_ka+1], 0+k_hi + s_load_dwordx8 s[s_dtile_ix+0:s_dtile_ix+7], s[s_ka+0:s_ka+1], 0+k_dtile_ix + s_load_dwordx4 s[s_dslice_x+0:s_dslice_x+3], s[s_ka+0:s_ka+1], 0+k_dslice_x + s_load_dword s[s_dslice_w_left], s[s_ka+0:s_ka+1], 0+k_dslice_w_left + v_mov_b32 v[v_tmp], v0 + ; 
output: K0xK1xExN0xN1xB0xB1: 1x1x1x1x4x1x1, slice 8x1x1x1x1x1x32 + v_and_b32 v[v_gtc_ib1], 31, v[v_tmp] + v_lshrrev_b32 v[v_tmp], 5, v[v_tmp] + v_and_b32 v[v_gtc_ik0], 7, v[v_tmp] + + v_mov_b32 v[v_tmp], v0 + ; wei: K0xK1xExC0xC1: 1x1x1x1x2, slice 8x1x1x32x1 + v_and_b32 v[v_gtc_ic0], 31, v[v_tmp] + + s_mov_b32 s[s_p_in + 2], 0xffffffff + s_mov_b32 s[s_p_in + 3], 0x27000 + s_mov_b32 s[s_p_wei + 2], 0xffffffff + s_mov_b32 s[s_p_wei + 3], 0x27000 + s_mov_b32 s[s_p_out + 2], 0xffffffff + s_mov_b32 s[s_p_out + 3], 0x27000 + s_waitcnt lgkmcnt(0) + + ; calculate index + s_mul_i32 s[s_out_stride_k], s[s_ho], s[s_wo] + s_mul_i32 s[s_out_stride_n], s[s_k], s[s_out_stride_k] + s_mul_i32 s[s_in_stride_c], s[s_hi], s[s_wi] + s_mul_i32 s[s_in_stride_n], s[s_c], s[s_in_stride_c] + s_mul_i32 s[s_wei_stride_c], s[s_y], s[s_x] + s_mul_i32 s[s_wei_stride_k], s[s_c], s[s_wei_stride_c] + s_mul_i32 s[s_stride_dslice_hw], s[s_dslice_h], s[s_dslice_w] + s_mul_i32 s[s_stride_dslice_yx], s[s_dslice_y], s[s_dslice_x] + + ; N0xN1xB0xB1, gemm_m_per_block:64, gemm_n_per_block:128 + s_mul_i32 s[s_tmp], s[s_stride_dslice_hw], s[s_n] + s_lshr_b32 s[0], s[s_tmp], 7 + .v_u32_div_ss v_tmp+5, s_bx, 0, v_tmp, s_tmp + v_readfirstlane_b32 s[s_tmp+4], v[v_tmp+5] ; gemm_m, + s_mul_i32 s[s_tmp+2], s[s_tmp+4], s[0] + s_sub_i32 s[s_tmp+5], s[s_bx], s[s_tmp+2] ; gemm_n, cnt + s_lshl_b32 s[s_block_gtc_ic], s[s_tmp+4], 6 + s_lshr_b32 s[0], s[s_stride_dslice_hw], 5 ; B:32 per block, total num of B + .v_u32_div_ss v_tmp+5, s_tmp+5, 0, v_tmp, s_tmp + v_readfirstlane_b32 s[s_tmp], v[v_tmp+5] ; => N + s_mul_i32 s[s_tmp+2], s[s_tmp], s[0] + s_sub_i32 s[s_tmp+1], s[s_tmp+5], s[s_tmp+2] ; => B + s_lshl_b32 s[s_block_gtc_in], s[s_tmp], 2 + s_lshl_b32 s[s_block_gtc_ib], s[s_tmp+1], 5 + + v_mov_b32 v[v_gtc_in1], 0 + v_mov_b32 v[v_gtc_ik1], v[v_gtc_ik0] + + v_add_u32 v[v_tmp+5], s[s_block_gtc_ib], v[v_gtc_ib1] + ; calculate output transform, B -> dslice_h*dslice_w. 
always use ib1 + .v_u32_div_vs v_out_dslice_ih, v_tmp+5, s_dslice_w, v_tmp, s_tmp + v_mul_lo_u32 v[v_tmp], s[s_dslice_w], v[v_out_dslice_ih] + v_sub_u32 v[v_out_dslice_iw], v[v_tmp+5], v[v_tmp] + ; iHTildaLeft, iWTildaLeft + v_add_u32 v[v_out_dslice_ih], s[s_dslice_h_left], v[v_out_dslice_ih] + v_add_u32 v[v_out_dslice_iw], s[s_dslice_w_left], v[v_out_dslice_iw] + ; dslice_y,dslice_h -> oh, dslice_x,dslice_w -> ow + v_mov_b32 v[v_out_iho], v[v_out_dslice_ih] + v_mov_b32 v[v_out_iwo], v[v_out_dslice_iw] + v_mul_lo_u32 v[v_tmp], s[s_out_stride_k], v[v_gtc_ik1] + v_mul_lo_u32 v[v_tmp+1], s[s_out_stride_n], v[v_gtc_in1] + v_add_lshl_u32 v[v_out_os_base], v[v_tmp], v[v_tmp+1], 2 + ; n to staticly accumulate into base pointer + s_lshl_b32 s[s_tmp+3], s[s_block_gtc_in], 2 + s_mul_i32 s[s_tmp], s[s_out_stride_n], s[s_tmp+3] + s_mul_hi_u32 s[s_tmp+1], s[s_out_stride_n], s[s_tmp+3] + s_add_u32 s[s_p_out], s[s_p_out], s[s_tmp] + s_addc_u32 s[s_p_out+1], s[s_p_out+1], s[s_tmp+1] + + .v_bwd_gtc_out_update_os v_out_os, v_out_os_base, v_out_iho, v_out_iwo, s_wo, v_tmp + .v_set_flag_hw v_out_flag, v_out_iho, v_out_iwo, s_ho, s_wo + + s_lshl_b32 s[s_out_stride_n], s[s_out_stride_n], 2 + + ; load output + v_cmp_eq_u32 vcc, 1, v[v_out_flag] + s_and_saveexec_b64 s[s_tmp+4:s_tmp+5], vcc + .v_gld_1x4_b32_v1 v_gld_b, s_p_out, v_out_os, 0, s_out_stride_n, s_tmp + s_or_b64 exec, exec, s[s_tmp+4:s_tmp+5] + + v_lshlrev_b32 v[v_gtc_ic1], 1, v[v_gtc_ic0] + v_add_u32 v[v_tmp+5], s[s_block_gtc_ic], v[v_gtc_ic1] + v_mov_b32 v[v_dtile_iy], s[s_dtile_iy] + v_mov_b32 v[v_dtile_ix], s[s_dtile_ix] + v_mov_b32 v[v_wei_iy], s[s_dtile_iy] + v_mov_b32 v[v_wei_ix], s[s_dtile_ix] + + ; calculate wei offset + v_mul_lo_u32 v[v_tmp], s[s_wei_stride_c], v[v_tmp+5] + v_mul_lo_u32 v[v_tmp+1], s[s_wei_stride_k], v[v_gtc_ik1] + v_add_lshl_u32 v[v_wei_os_base], v[v_tmp], v[v_tmp+1], 2 + .v_bwd_gtc_wei_update_os v_wei_os, v_wei_os_base, v_wei_iy, v_wei_ix, s_x, v_tmp + + s_lshl_b32 s[s_wei_stride_c], s[s_wei_stride_c], 2 + + ; load weight + .v_gld_1x2_b32_v1 v_gld_a, s_p_wei, v_wei_os, 0, s_wei_stride_c, s_tmp + + ; c thread mapping + ; -> MR x NR x ML1 x NL1 x ML0 x NL0 + ; cluster 1 x 1 x 4 x 4 x 4 x 4 + ; perthrd 2 x 2 x 1 x 1 x 2 x 4 + v_and_b32 v[v_tmp], 3, v0 + v_lshlrev_b32 v[v_tmp], 2, v[v_tmp] ; => iNL0 + v_lshrrev_b32 v[v_tmp+5], 2, v0 + v_and_b32 v[v_tmp+1], 3, v[v_tmp+5] + v_lshlrev_b32 v[v_tmp+1], 1, v[v_tmp+1] ; => iML0 + v_lshrrev_b32 v[v_tmp+5], 2, v[v_tmp+5] + v_and_b32 v[v_tmp+2], 3, v[v_tmp+5] ; => iNL1 + v_lshrrev_b32 v[v_tmp+5], 2, v[v_tmp+5] + v_and_b32 v[v_tmp+3], 3, v[v_tmp+5] ; => iML1 + v_lshl_or_b32 v[v_gemm_in], v[v_tmp+2], 4, v[v_tmp] ; in (without repeat) + v_lshl_or_b32 v[v_gemm_im], v[v_tmp+3], 3, v[v_tmp+1] ; im (without repeat) + + ; remapping, gemm_im => C0xC1:32x2, gemm_in => N0xB0xB1xN1:1x1x32x4 + v_and_b32 v[v_in_in1], 3, v[v_gemm_in] + v_lshrrev_b32 v[v_tmp], 2, v[v_gemm_in] + v_and_b32 v[v_in_ib1], 31, v[v_tmp] + v_lshrrev_b32 v[v_tmp], 5, v[v_tmp] + v_mov_b32 v[v_in_ib0], 0 + v_mov_b32 v[v_in_in0], 0 + v_and_b32 v[v_in_ic1], 1, v[v_gemm_im] + v_lshrrev_b32 v[v_tmp], 1, v[v_gemm_im] + v_and_b32 v[v_in_ic0], 31, v[v_tmp] + v_lshrrev_b32 v[v_tmp], 5, v[v_tmp] + + v_lshl_or_b32 v[v_in_ib1], v[v_in_ib0], 5, v[v_in_ib1] + v_lshl_or_b32 v[v_in_in1], v[v_in_in0], 2, v[v_in_in1] + v_lshl_or_b32 v[v_in_ic1], v[v_in_ic0], 1, v[v_in_ic1] + v_add_u32 v[v_in_ib1], s[s_block_gtc_ib], v[v_in_ib1] + + .v_u32_div_vs v_in_dslice_ih, v_in_ib1, s_dslice_w, v_tmp, s_tmp + v_mul_lo_u32 v[v_tmp+1], s[s_dslice_w], 
v[v_in_dslice_ih] + v_sub_u32 v[v_in_dslice_iw], v[v_in_ib1], v[v_tmp+1] + v_add_u32 v[v_in_dslice_ih], s[s_dslice_h_left], v[v_in_dslice_ih] + v_add_u32 v[v_in_dslice_iw], s[s_dslice_w_left], v[v_in_dslice_iw] + + ; dslice_h,dslice_y -> hip, dslice_w,dslice_x -> wip + s_mul_i32 s[s_tmp], s[s_dtile_iy], s[s_dilation_h] + v_mul_lo_u32 v[v_tmp], s[s_stride_h], v[v_in_dslice_ih] + v_add_u32 v[v_tmp], s[s_tmp], v[v_tmp] + s_mul_i32 s[s_tmp+1], s[s_dtile_ix], s[s_dilation_w] + v_mul_lo_u32 v[v_tmp+1], s[s_stride_w], v[v_in_dslice_iw] + v_add_u32 v[v_tmp+1], s[s_tmp+1], v[v_tmp+1] + ; v_tmp: hip, v_tmp+1: wip + + ; hip->h, wip->w + v_sub_i32 v[v_in_ihi], v[v_tmp], s[s_pad_h] + v_sub_i32 v[v_in_iwi], v[v_tmp+1], s[s_pad_w] + + .v_set_flag_hw v_in_flag, v_in_ihi, v_in_iwi, s_hi, s_wi
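The per-thread transform just computed reduces to a few integer operations; a compilable C++ sketch under the same definitions (hypothetical names; the arithmetic is copied from the instructions above, and out-of-range coordinates are masked exactly as .v_set_flag_hw does):

// Sketch, not part of the patch: flattened dslice index -> real input coordinate.
#include <cstdint>

struct InputCoord
{
    int32_t ihi, iwi;
    bool valid;
};

inline InputCoord MapDsliceToInput(int ib, // in [0, dslice_h * dslice_w)
                                   int dslice_w, int dslice_h_left, int dslice_w_left,
                                   int dtile_iy, int dtile_ix,
                                   int stride_h, int stride_w,
                                   int dilation_h, int dilation_w,
                                   int pad_h, int pad_w, int hi, int wi)
{
    const int dslice_ih = ib / dslice_w + dslice_h_left;
    const int dslice_iw = ib % dslice_w + dslice_w_left;
    const int hip       = stride_h * dslice_ih + dilation_h * dtile_iy; // "hip" above
    const int wip       = stride_w * dslice_iw + dilation_w * dtile_ix; // "wip" above
    const int ihi       = hip - pad_h;
    const int iwi       = wip - pad_w;
    return {ihi, iwi, ihi >= 0 && ihi < hi && iwi >= 0 && iwi < wi};
}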
+ + ; input offset + s_lshl_b32 s[s_tmp+3], s[s_block_gtc_in], 2 + s_mul_i32 s[s_tmp], s[s_in_stride_n], s[s_tmp+3] + s_mul_hi_u32 s[s_tmp+1], s[s_in_stride_n], s[s_tmp+3] + s_add_u32 s[s_p_in], s[s_p_in], s[s_tmp] + s_addc_u32 s[s_p_in+1], s[s_p_in+1], s[s_tmp+1] + + s_lshl_b32 s[s_tmp+3], s[s_block_gtc_ic], 2 + s_mul_i32 s[s_tmp], s[s_in_stride_c], s[s_tmp+3] + s_mul_hi_u32 s[s_tmp+1], s[s_in_stride_c], s[s_tmp+3] + s_add_u32 s[s_p_in], s[s_p_in], s[s_tmp] + s_addc_u32 s[s_p_in+1], s[s_p_in+1], s[s_tmp+1] + v_mul_lo_u32 v[v_tmp], s[s_in_stride_n], v[v_in_in1] + v_mul_lo_u32 v[v_tmp+1], s[s_in_stride_c], v[v_in_ic1] + v_add_u32 v[v_tmp], v[v_tmp], v[v_tmp+1] + v_mul_lo_u32 v[v_tmp+1], s[s_wi], v[v_in_ihi] + v_add3_u32 v[v_in_os], v[v_tmp], v[v_tmp+1], v[v_in_iwi] + v_lshlrev_b32 v[v_in_os], 2, v[v_in_os] + + s_lshl_b32 s[s_in_stride_n], s[s_in_stride_n], 2 + s_mov_b32 s[s_in_stride_nr], 64 + + s_lshl_b32 s[s_in_stride_c], s[s_in_stride_c], 2 + s_lshl_b32 s[s_in_stride_cr], s[s_in_stride_c], 5 + + v_lshrrev_b32 v[v_gtc_in0], 2, v[v_gtc_in1] + ; LDS store, order out: K0xK1xExN0xB0xB1xN1: 1x1x1x1x1x1x4, 8x1x1x1x1x32x1 + v_lshlrev_b32 v[v_tmp], 2, v[v_gtc_ib1] + v_lshl_or_b32 v[v_tmp], v[v_gtc_in0], 7, v[v_tmp] + v_lshl_or_b32 v[v_tmp+1], v[v_gtc_ik1], 7, v[v_tmp] + v_lshlrev_b32 v[v_sst_b_os], 2, v[v_tmp+1] + + ; LDS store, order wei not shuffled, wei: K0xK1xExC0xC1: 1x1x1x1x2, 8x1x1x32x1 + v_lshl_or_b32 v[v_tmp], v[v_gtc_ik1], 6, v[v_gtc_ic1] + v_lshlrev_b32 v[v_sst_a_os], 2, v[v_tmp] + v_add_u32 v[v_sst_b_os], 2048, v[v_sst_b_os] + + ; LDS load + v_lshlrev_b32 v[v_sld_b_os], 2, v[v_gemm_in] + v_lshlrev_b32 v[v_sld_a_os], 2, v[v_gemm_im] + v_add_u32 v[v_sld_b_os], 2048, v[v_sld_b_os] + + s_mul_i32 s[s_knum], s[s_stride_dslice_yx], s[s_k] + s_lshl_b32 s[s_out_move_slice_stride_k], s[s_out_stride_k], 5 ; 8x + s_lshl_b32 s[s_wei_move_slice_stride_k], s[s_wei_stride_k], 5 ; 8x + s_mov_b32 s[s_move_slice_k_ik], 0 + s_mov_b32 s[s_move_slice_k_idsy], 0 + s_mov_b32 s[s_move_slice_k_idsx], 0 + + .v_clear_nc v_c, 32 + ; start FMA loop, 4x8 thread tile with 2x4 sub-tile + s_waitcnt vmcnt(2) + .v_sst_so0_1x4_b32_v4 v_gld_b, v_sst_b_os + + s_waitcnt vmcnt(0) + .v_sst_so0_1x2_b32_v2 v_gld_a, v_sst_a_os + + s_sub_i32 s[s_kitr], s[s_knum], 8 + s_cmp_gt_i32 s[s_kitr], 0 + s_cbranch_scc0 L_igemm_bwd_gtc_bt64x128x8_tt4x8_gm2x4x4_gn2x4x4_ta1x1x1x1x2_8x1x1x32x1_tb1x1x1x1x4x1x1_8x1x1x1x1x1x32_end + + .s_bwd_gtc_move_slice_window_k_dsy_dsx s_move_slice_k_ik, s_move_slice_k_idsy, s_move_slice_k_idsx, s_dslice_y, s_dslice_x, 8, 1, 1, v_out_os_base, v_wei_os_base, s_out_move_slice_stride_k, s_wei_move_slice_stride_k + .v_bwd_gtc_out_update_hw v_out_iho, v_out_iwo, v_out_dslice_ih, v_out_dslice_iw, s_move_slice_k_idsy, s_move_slice_k_idsx, s_dtile_dy, s_dtile_dx, s_tmp + .v_bwd_gtc_out_update_os v_out_os, v_out_os_base, v_out_iho, v_out_iwo, s_wo, v_tmp + .v_set_flag_hw v_out_flag, v_out_iho, v_out_iwo, s_ho, s_wo + .v_bwd_gtc_wei_update_yx v_wei_iy, v_wei_ix, s_move_slice_k_idsy, s_move_slice_k_idsx, s_dtile_y, s_dtile_x, v_dtile_iy, v_dtile_ix, s_tmp + .v_bwd_gtc_wei_update_os v_wei_os, v_wei_os_base, v_wei_iy, v_wei_ix, s_x, v_tmp + v_xor_b32 v[v_sst_b_os], 0x2000, v[v_sst_b_os] ; switch double buffer b store + v_xor_b32 v[v_sst_a_os], 0x2000, v[v_sst_a_os] ; switch double buffer a store + s_waitcnt lgkmcnt(0) + s_barrier + + ; load output + v_cmp_eq_u32 vcc, 1, v[v_out_flag] + s_and_saveexec_b64 s[s_tmp+4:s_tmp+5], vcc + .v_gld_1x4_b32_v1 v_gld_b, s_p_out, v_out_os, 0, s_out_stride_n, s_tmp + s_or_b64 exec, exec, s[s_tmp+4:s_tmp+5] + ; load weight + .v_gld_1x2_b32_v1 v_gld_a, s_p_wei, v_wei_os, 0, s_wei_stride_c, s_tmp + +L_igemm_bwd_gtc_bt64x128x8_tt4x8_gm2x4x4_gn2x4x4_ta1x1x1x1x2_8x1x1x32x1_tb1x1x1x1x4x1x1_8x1x1x1x1x1x32_fma_body: + ; FMA accumulation with unroll of 8 + ds_read_b64 v[v_a:v_a+1], v[v_sld_a_os] + ds_read_b128 v[v_b:v_b+3], v[v_sld_b_os] + ds_read_b128 v[v_b+4:v_b+4+3], v[v_sld_b_os] offset:256 + ds_read_b64 v[v_a+2:v_a+2+1], v[v_sld_a_os] offset:128 + .itr_k = 0 + .rept 7 + s_waitcnt lgkmcnt(2) + .v_fma_2x4_s8 v_c,v_a,v_b + s_waitcnt lgkmcnt(1) + .v_fma_2x4_s8 v_c+4,v_a,v_b+4 + ds_read_b64 v[v_a:v_a+1], v[v_sld_a_os] offset:0+(.itr_k+1)*256 + s_waitcnt lgkmcnt(1) + .v_fma_2x4_s8 v_c+16,v_a+2,v_b + ds_read_b128 v[v_b:v_b+3], v[v_sld_b_os] offset:0+(.itr_k+1)*512 + .v_fma_2x4_s8 v_c+20,v_a+2,v_b+4 + + ds_read_b128 v[v_b+4:v_b+4+3], v[v_sld_b_os] offset:0+(.itr_k+1)*512+256 + ds_read_b64 v[v_a+2:v_a+2+1], v[v_sld_a_os] offset:0+(.itr_k+1)*256+128 + .itr_k = .itr_k + 1 + .endr + + ; last unroll + v_xor_b32 v[v_sld_b_os], 8192, v[v_sld_b_os] ; switch double buffer b load + v_xor_b32 v[v_sld_a_os], 8192, v[v_sld_a_os] ; switch double buffer a load + s_waitcnt lgkmcnt(2) + .v_fma_2x4_s8 v_c,v_a,v_b + s_waitcnt lgkmcnt(1) + .v_fma_2x4_s8 v_c+4,v_a,v_b+4 + s_waitcnt vmcnt(2) + .v_sst_so0_1x4_b32_v4 v_gld_b, v_sst_b_os + s_waitcnt vmcnt(0) + .v_sst_so0_1x2_b32_v2 v_gld_a, v_sst_a_os + s_sub_i32 s[s_kitr], s[s_kitr], 8 + s_cmp_gt_i32 s[s_kitr], 0 + s_cbranch_scc0 L_igemm_bwd_gtc_bt64x128x8_tt4x8_gm2x4x4_gn2x4x4_ta1x1x1x1x2_8x1x1x32x1_tb1x1x1x1x4x1x1_8x1x1x1x1x1x32_fma_finishing + .s_bwd_gtc_move_slice_window_k_dsy_dsx s_move_slice_k_ik, s_move_slice_k_idsy, s_move_slice_k_idsx, s_dslice_y, s_dslice_x, 8, 1, 1, v_out_os_base, v_wei_os_base, s_out_move_slice_stride_k, s_wei_move_slice_stride_k + .v_bwd_gtc_out_update_hw v_out_iho, v_out_iwo, v_out_dslice_ih, v_out_dslice_iw, s_move_slice_k_idsy, s_move_slice_k_idsx, s_dtile_dy, s_dtile_dx, s_tmp + .v_bwd_gtc_out_update_os v_out_os, v_out_os_base, v_out_iho, v_out_iwo, s_wo, v_tmp + .v_set_flag_hw v_out_flag, v_out_iho, v_out_iwo, s_ho, s_wo + .v_bwd_gtc_wei_update_yx v_wei_iy, v_wei_ix, s_move_slice_k_idsy, s_move_slice_k_idsx, s_dtile_y, s_dtile_x, v_dtile_iy, v_dtile_ix, s_tmp + .v_bwd_gtc_wei_update_os v_wei_os, v_wei_os_base, v_wei_iy, v_wei_ix, s_x, v_tmp + s_waitcnt lgkmcnt(2) + .v_fma_2x4_s8 v_c+16,v_a+2,v_b + v_xor_b32 v[v_sst_b_os], 8192, v[v_sst_b_os] ; switch double buffer b store + v_xor_b32 v[v_sst_a_os], 8192, v[v_sst_a_os] ; switch double buffer a store + s_waitcnt lgkmcnt(0) + s_barrier + ; load output + v_cmp_eq_u32 vcc, 1, v[v_out_flag] + s_and_saveexec_b64 s[s_tmp+4:s_tmp+5], vcc + .v_gld_1x4_b32_v1 v_gld_b, s_p_out, v_out_os, 0, s_out_stride_n, s_tmp
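+ ; note: the output tile just loaded (and the weight tile loaded below) is the prefetch for the next fma_body iteration; the remaining FMAs of this iteration still read the other half of the double-buffered LDS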
+ s_or_b64 exec, exec, s[s_tmp+4:s_tmp+5] + ; load weight + .v_gld_1x2_b32_v1 v_gld_a, s_p_wei, v_wei_os, 0, s_wei_stride_c, s_tmp + .v_fma_2x4_s8 v_c+20,v_a+2,v_b+4 + + s_branch L_igemm_bwd_gtc_bt64x128x8_tt4x8_gm2x4x4_gn2x4x4_ta1x1x1x1x2_8x1x1x32x1_tb1x1x1x1x4x1x1_8x1x1x1x1x1x32_fma_body +L_igemm_bwd_gtc_bt64x128x8_tt4x8_gm2x4x4_gn2x4x4_ta1x1x1x1x2_8x1x1x32x1_tb1x1x1x1x4x1x1_8x1x1x1x1x1x32_fma_finishing: + s_waitcnt lgkmcnt(2) + .v_fma_2x4_s8 v_c+16,v_a+2,v_b + .v_fma_2x4_s8 v_c+20,v_a+2,v_b+4 +L_igemm_bwd_gtc_bt64x128x8_tt4x8_gm2x4x4_gn2x4x4_ta1x1x1x1x2_8x1x1x32x1_tb1x1x1x1x4x1x1_8x1x1x1x1x1x32_end: + s_waitcnt lgkmcnt(0) + s_barrier + ds_read_b64 v[v_a:v_a+1], v[v_sld_a_os] + ds_read_b128 v[v_b:v_b+3], v[v_sld_b_os] + ds_read_b128 v[v_b+4:v_b+4+3], v[v_sld_b_os] offset:256 + ds_read_b64 v[v_a+2:v_a+2+1], v[v_sld_a_os] offset:128 + .itr_k = 0 + .rept 7 + s_waitcnt lgkmcnt(2) + .v_fma_2x4_s8 v_c,v_a,v_b + s_waitcnt lgkmcnt(1) + .v_fma_2x4_s8 v_c+4,v_a,v_b+4 + ds_read_b64 v[v_a:v_a+1], v[v_sld_a_os] offset:0+(.itr_k+1)*256 + s_waitcnt lgkmcnt(1) + .v_fma_2x4_s8 v_c+16,v_a+2,v_b + ds_read_b128 v[v_b:v_b+3], v[v_sld_b_os] offset:0+(.itr_k+1)*512 + .v_fma_2x4_s8 v_c+20,v_a+2,v_b+4 + + ds_read_b128 v[v_b+4:v_b+4+3], v[v_sld_b_os] offset:0+(.itr_k+1)*512+256 + ds_read_b64 v[v_a+2:v_a+2+1], v[v_sld_a_os] offset:0+(.itr_k+1)*256+128 + .itr_k = .itr_k + 1 + .endr + + ; last unroll + s_waitcnt lgkmcnt(2) + .v_fma_2x4_s8 v_c,v_a,v_b + s_waitcnt lgkmcnt(1) + .v_fma_2x4_s8 v_c+4,v_a,v_b+4 + s_waitcnt lgkmcnt(0) + .v_fma_2x4_s8 v_c+16,v_a+2,v_b + .v_fma_2x4_s8 v_c+20,v_a+2,v_b+4 + + v_cmpx_eq_u32 vcc, 1, v[v_in_flag] + s_cbranch_execz L_igemm_bwd_gtc_bt64x128x8_tt4x8_gm2x4x4_gn2x4x4_ta1x1x1x1x2_8x1x1x32x1_tb1x1x1x1x4x1x1_8x1x1x1x1x1x32_out + s_mov_b32 s[s_tmp], 0 + s_mov_b32 s[s_tmp+1], 0 + s_mov_b32 s[s_tmp+2], 0 + s_mov_b32 s[s_tmp+3], 0 + .v_write4d_strided v_c,s_p_in,v_in_os,s_in_stride_n,s_in_stride_nr,s_in_stride_c,s_in_stride_cr,s_tmp,4,2,2,2 +L_igemm_bwd_gtc_bt64x128x8_tt4x8_gm2x4x4_gn2x4x4_ta1x1x1x1x2_8x1x1x32x1_tb1x1x1x1x4x1x1_8x1x1x1x1x1x32_out: + s_endpgm +.rodata +.p2align 6 +.amdhsa_kernel igemm_bwd_gtc_bt64x128x8_tt4x8_gm2x4x4_gn2x4x4_ta1x1x1x1x2_8x1x1x32x1_tb1x1x1x1x4x1x1_8x1x1x1x1x1x32 + .amdhsa_group_segment_fixed_size 16384 + .amdhsa_user_sgpr_kernarg_segment_ptr 1 + .amdhsa_system_sgpr_workgroup_id_x 1 + .amdhsa_system_vgpr_workitem_id 0 + .amdhsa_next_free_vgpr 76 + .amdhsa_next_free_sgpr 78 + .amdhsa_ieee_mode 0 + .amdhsa_dx10_clamp 0 +.end_amdhsa_kernel + +.amdgpu_metadata +--- +amdhsa.version: [ 1, 0 ] +amdhsa.kernels: + - .name: igemm_bwd_gtc + .symbol: igemm_bwd_gtc.kd + .sgpr_count: 72 + .vgpr_count: 128 + .kernarg_segment_align: 8 + .kernarg_segment_size: 144 + .group_segment_fixed_size: 32768 + .private_segment_fixed_size: 0 + .wavefront_size: 64 + .reqd_workgroup_size : [256, 1, 1] + .max_flat_workgroup_size: 256 + .args: + - { .name: p_in , .size: 8, .offset: 0, .value_kind: global_buffer, .value_type: f32, .address_space: global, .is_const: false} + - { .name: p_wei , .size: 8, .offset: 8, .value_kind: global_buffer, .value_type: f32, .address_space: global, .is_const: true} + - { .name: p_out , .size: 8, .offset: 16, .value_kind: global_buffer, .value_type: f32, .address_space: global, .is_const: true} + - { .name: hi , .size: 4, .offset: 24, .value_kind: by_value, .value_type: i32} + - { .name: wi , .size: 4, .offset: 28, .value_kind: by_value, .value_type: i32} + - { .name: n , .size: 4, .offset: 32, .value_kind: by_value, .value_type: i32} + - { .name: k ,
.size: 4, .offset: 36, .value_kind: by_value, .value_type: i32} + - { .name: c , .size: 4, .offset: 40, .value_kind: by_value, .value_type: i32} + - { .name: ho , .size: 4, .offset: 44, .value_kind: by_value, .value_type: i32} + - { .name: wo , .size: 4, .offset: 48, .value_kind: by_value, .value_type: i32} + - { .name: stride_h , .size: 4, .offset: 52, .value_kind: by_value, .value_type: i32} + - { .name: stride_w , .size: 4, .offset: 56, .value_kind: by_value, .value_type: i32} + - { .name: dilation_h, .size: 4, .offset: 60, .value_kind: by_value, .value_type: i32} + - { .name: dilation_w, .size: 4, .offset: 64, .value_kind: by_value, .value_type: i32} + - { .name: pad_h , .size: 4, .offset: 68, .value_kind: by_value, .value_type: i32} + - { .name: pad_w , .size: 4, .offset: 72, .value_kind: by_value, .value_type: i32} + - { .name: y , .size: 4, .offset: 76, .value_kind: by_value, .value_type: i32} + - { .name: x , .size: 4, .offset: 80, .value_kind: by_value, .value_type: i32} + - { .name: dtile_iy , .size: 4, .offset: 84, .value_kind: by_value, .value_type: i32} + - { .name: dtile_ix , .size: 4, .offset: 88, .value_kind: by_value, .value_type: i32} + - { .name: dtile_dy , .size: 4, .offset: 92, .value_kind: by_value, .value_type: i32} + - { .name: dtile_dx , .size: 4, .offset: 96, .value_kind: by_value, .value_type: i32} + - { .name: dtile_y , .size: 4, .offset: 100, .value_kind: by_value, .value_type: i32} + - { .name: dtile_x , .size: 4, .offset: 104, .value_kind: by_value, .value_type: i32} + - { .name: dtile_h , .size: 4, .offset: 108, .value_kind: by_value, .value_type: i32} + - { .name: dtile_w , .size: 4, .offset: 112, .value_kind: by_value, .value_type: i32} + - { .name: dslice_y , .size: 4, .offset: 116, .value_kind: by_value, .value_type: i32} + - { .name: dslice_x , .size: 4, .offset: 120, .value_kind: by_value, .value_type: i32} + - { .name: dslice_h , .size: 4, .offset: 124, .value_kind: by_value, .value_type: i32} + - { .name: dslice_w , .size: 4, .offset: 128, .value_kind: by_value, .value_type: i32} + - { .name: dslice_h_left , .size: 4, .offset: 132, .value_kind: by_value, .value_type: i32} + - { .name: dslice_w_left , .size: 4, .offset: 136, .value_kind: by_value, .value_type: i32} + - { .name: __pack0 , .size: 4, .offset: 140, .value_kind: by_value, .value_type: i32} + - .name: igemm_bwd_gtc_bt128x128x16_tt8x8_gm2x4x4_gn2x4x4_ta1x1x1x2x4_16x1x1x16x1_tb1x1x1x2x4x1x1_16x1x1x16x1x1x1 + .symbol: igemm_bwd_gtc_bt128x128x16_tt8x8_gm2x4x4_gn2x4x4_ta1x1x1x2x4_16x1x1x16x1_tb1x1x1x2x4x1x1_16x1x1x16x1x1x1.kd + .sgpr_count: 78 + .vgpr_count: 122 + .kernarg_segment_align: 8 + .kernarg_segment_size: 144 + .group_segment_fixed_size: 32768 + .private_segment_fixed_size: 0 + .wavefront_size: 64 + .reqd_workgroup_size : [256, 1, 1] + .max_flat_workgroup_size: 256 + .args: + - { .name: p_in , .size: 8, .offset: 0, .value_kind: global_buffer, .value_type: f32, .address_space: global, .is_const: false} + - { .name: p_wei , .size: 8, .offset: 8, .value_kind: global_buffer, .value_type: f32, .address_space: global, .is_const: true} + - { .name: p_out , .size: 8, .offset: 16, .value_kind: global_buffer, .value_type: f32, .address_space: global, .is_const: true} + - { .name: hi , .size: 4, .offset: 24, .value_kind: by_value, .value_type: i32} + - { .name: wi , .size: 4, .offset: 28, .value_kind: by_value, .value_type: i32} + - { .name: n , .size: 4, .offset: 32, .value_kind: by_value, .value_type: i32} + - { .name: k , .size: 4, .offset: 36, .value_kind: by_value, .value_type: i32} + - { 
.name: c , .size: 4, .offset: 40, .value_kind: by_value, .value_type: i32} + - { .name: ho , .size: 4, .offset: 44, .value_kind: by_value, .value_type: i32} + - { .name: wo , .size: 4, .offset: 48, .value_kind: by_value, .value_type: i32} + - { .name: stride_h , .size: 4, .offset: 52, .value_kind: by_value, .value_type: i32} + - { .name: stride_w , .size: 4, .offset: 56, .value_kind: by_value, .value_type: i32} + - { .name: dilation_h, .size: 4, .offset: 60, .value_kind: by_value, .value_type: i32} + - { .name: dilation_w, .size: 4, .offset: 64, .value_kind: by_value, .value_type: i32} + - { .name: pad_h , .size: 4, .offset: 68, .value_kind: by_value, .value_type: i32} + - { .name: pad_w , .size: 4, .offset: 72, .value_kind: by_value, .value_type: i32} + - { .name: y , .size: 4, .offset: 76, .value_kind: by_value, .value_type: i32} + - { .name: x , .size: 4, .offset: 80, .value_kind: by_value, .value_type: i32} + - { .name: dtile_iy , .size: 4, .offset: 84, .value_kind: by_value, .value_type: i32} + - { .name: dtile_ix , .size: 4, .offset: 88, .value_kind: by_value, .value_type: i32} + - { .name: dtile_dy , .size: 4, .offset: 92, .value_kind: by_value, .value_type: i32} + - { .name: dtile_dx , .size: 4, .offset: 96, .value_kind: by_value, .value_type: i32} + - { .name: dtile_y , .size: 4, .offset: 100, .value_kind: by_value, .value_type: i32} + - { .name: dtile_x , .size: 4, .offset: 104, .value_kind: by_value, .value_type: i32} + - { .name: dtile_h , .size: 4, .offset: 108, .value_kind: by_value, .value_type: i32} + - { .name: dtile_w , .size: 4, .offset: 112, .value_kind: by_value, .value_type: i32} + - { .name: dslice_y , .size: 4, .offset: 116, .value_kind: by_value, .value_type: i32} + - { .name: dslice_x , .size: 4, .offset: 120, .value_kind: by_value, .value_type: i32} + - { .name: dslice_h , .size: 4, .offset: 124, .value_kind: by_value, .value_type: i32} + - { .name: dslice_w , .size: 4, .offset: 128, .value_kind: by_value, .value_type: i32} + - { .name: dslice_h_left, .size: 4, .offset: 132, .value_kind: by_value, .value_type: i32} + - { .name: dslice_w_left, .size: 4, .offset: 136, .value_kind: by_value, .value_type: i32} + - { .name: __pack0 , .size: 4, .offset: 140, .value_kind: by_value, .value_type: i32} + - .name: igemm_bwd_gtc_bt64x64x8_tt8x8_gm2x4x2_gn2x4x2_ta1x2x1x1x4_4x1x1x16x1_tb1x2x1x1x4x1x1_4x1x1x16x1x1x1 + .symbol: igemm_bwd_gtc_bt64x64x8_tt8x8_gm2x4x2_gn2x4x2_ta1x2x1x1x4_4x1x1x16x1_tb1x2x1x1x4x1x1_4x1x1x16x1x1x1.kd + .sgpr_count: 78 + .vgpr_count: 122 + .kernarg_segment_align: 8 + .kernarg_segment_size: 144 + .group_segment_fixed_size: 8192 + .private_segment_fixed_size: 0 + .wavefront_size: 64 + .reqd_workgroup_size : [64, 1, 1] + .max_flat_workgroup_size: 64 + .args: + - { .name: p_in , .size: 8, .offset: 0, .value_kind: global_buffer, .value_type: f32, .address_space: global, .is_const: false} + - { .name: p_wei , .size: 8, .offset: 8, .value_kind: global_buffer, .value_type: f32, .address_space: global, .is_const: true} + - { .name: p_out , .size: 8, .offset: 16, .value_kind: global_buffer, .value_type: f32, .address_space: global, .is_const: true} + - { .name: hi , .size: 4, .offset: 24, .value_kind: by_value, .value_type: i32} + - { .name: wi , .size: 4, .offset: 28, .value_kind: by_value, .value_type: i32} + - { .name: n , .size: 4, .offset: 32, .value_kind: by_value, .value_type: i32} + - { .name: k , .size: 4, .offset: 36, .value_kind: by_value, .value_type: i32} + - { .name: c , .size: 4, .offset: 40, .value_kind: by_value, .value_type: i32} + - { .name: 
ho , .size: 4, .offset: 44, .value_kind: by_value, .value_type: i32} + - { .name: wo , .size: 4, .offset: 48, .value_kind: by_value, .value_type: i32} + - { .name: stride_h , .size: 4, .offset: 52, .value_kind: by_value, .value_type: i32} + - { .name: stride_w , .size: 4, .offset: 56, .value_kind: by_value, .value_type: i32} + - { .name: dilation_h, .size: 4, .offset: 60, .value_kind: by_value, .value_type: i32} + - { .name: dilation_w, .size: 4, .offset: 64, .value_kind: by_value, .value_type: i32} + - { .name: pad_h , .size: 4, .offset: 68, .value_kind: by_value, .value_type: i32} + - { .name: pad_w , .size: 4, .offset: 72, .value_kind: by_value, .value_type: i32} + - { .name: y , .size: 4, .offset: 76, .value_kind: by_value, .value_type: i32} + - { .name: x , .size: 4, .offset: 80, .value_kind: by_value, .value_type: i32} + - { .name: dtile_iy , .size: 4, .offset: 84, .value_kind: by_value, .value_type: i32} + - { .name: dtile_ix , .size: 4, .offset: 88, .value_kind: by_value, .value_type: i32} + - { .name: dtile_dy , .size: 4, .offset: 92, .value_kind: by_value, .value_type: i32} + - { .name: dtile_dx , .size: 4, .offset: 96, .value_kind: by_value, .value_type: i32} + - { .name: dtile_y , .size: 4, .offset: 100, .value_kind: by_value, .value_type: i32} + - { .name: dtile_x , .size: 4, .offset: 104, .value_kind: by_value, .value_type: i32} + - { .name: dtile_h , .size: 4, .offset: 108, .value_kind: by_value, .value_type: i32} + - { .name: dtile_w , .size: 4, .offset: 112, .value_kind: by_value, .value_type: i32} + - { .name: dslice_y , .size: 4, .offset: 116, .value_kind: by_value, .value_type: i32} + - { .name: dslice_x , .size: 4, .offset: 120, .value_kind: by_value, .value_type: i32} + - { .name: dslice_h , .size: 4, .offset: 124, .value_kind: by_value, .value_type: i32} + - { .name: dslice_w , .size: 4, .offset: 128, .value_kind: by_value, .value_type: i32} + - { .name: dslice_h_left, .size: 4, .offset: 132, .value_kind: by_value, .value_type: i32} + - { .name: dslice_w_left, .size: 4, .offset: 136, .value_kind: by_value, .value_type: i32} + - { .name: __pack0 , .size: 4, .offset: 140, .value_kind: by_value, .value_type: i32} + - .name: igemm_bwd_gtc_bt128x128x16_tt8x8_gm2x4x4_gn2x4x4_ta1x2x1x1x4_8x1x1x32x1_tb1x2x1x1x4x1x1_8x1x1x2x1x1x16 + .symbol: igemm_bwd_gtc_bt128x128x16_tt8x8_gm2x4x4_gn2x4x4_ta1x2x1x1x4_8x1x1x32x1_tb1x2x1x1x4x1x1_8x1x1x2x1x1x16.kd + .sgpr_count: 78 + .vgpr_count: 122 + .kernarg_segment_align: 8 + .kernarg_segment_size: 144 + .group_segment_fixed_size: 32768 + .private_segment_fixed_size: 0 + .wavefront_size: 64 + .reqd_workgroup_size : [256, 1, 1] + .max_flat_workgroup_size: 256 + .args: + - { .name: p_in , .size: 8, .offset: 0, .value_kind: global_buffer, .value_type: f32, .address_space: global, .is_const: false} + - { .name: p_wei , .size: 8, .offset: 8, .value_kind: global_buffer, .value_type: f32, .address_space: global, .is_const: true} + - { .name: p_out , .size: 8, .offset: 16, .value_kind: global_buffer, .value_type: f32, .address_space: global, .is_const: true} + - { .name: hi , .size: 4, .offset: 24, .value_kind: by_value, .value_type: i32} + - { .name: wi , .size: 4, .offset: 28, .value_kind: by_value, .value_type: i32} + - { .name: n , .size: 4, .offset: 32, .value_kind: by_value, .value_type: i32} + - { .name: k , .size: 4, .offset: 36, .value_kind: by_value, .value_type: i32} + - { .name: c , .size: 4, .offset: 40, .value_kind: by_value, .value_type: i32} + - { .name: ho , .size: 4, .offset: 44, .value_kind: by_value, .value_type: i32} + - { 
.name: wo , .size: 4, .offset: 48, .value_kind: by_value, .value_type: i32} + - { .name: stride_h , .size: 4, .offset: 52, .value_kind: by_value, .value_type: i32} + - { .name: stride_w , .size: 4, .offset: 56, .value_kind: by_value, .value_type: i32} + - { .name: dilation_h, .size: 4, .offset: 60, .value_kind: by_value, .value_type: i32} + - { .name: dilation_w, .size: 4, .offset: 64, .value_kind: by_value, .value_type: i32} + - { .name: pad_h , .size: 4, .offset: 68, .value_kind: by_value, .value_type: i32} + - { .name: pad_w , .size: 4, .offset: 72, .value_kind: by_value, .value_type: i32} + - { .name: y , .size: 4, .offset: 76, .value_kind: by_value, .value_type: i32} + - { .name: x , .size: 4, .offset: 80, .value_kind: by_value, .value_type: i32} + - { .name: dtile_iy , .size: 4, .offset: 84, .value_kind: by_value, .value_type: i32} + - { .name: dtile_ix , .size: 4, .offset: 88, .value_kind: by_value, .value_type: i32} + - { .name: dtile_dy , .size: 4, .offset: 92, .value_kind: by_value, .value_type: i32} + - { .name: dtile_dx , .size: 4, .offset: 96, .value_kind: by_value, .value_type: i32} + - { .name: dtile_y , .size: 4, .offset: 100, .value_kind: by_value, .value_type: i32} + - { .name: dtile_x , .size: 4, .offset: 104, .value_kind: by_value, .value_type: i32} + - { .name: dtile_h , .size: 4, .offset: 108, .value_kind: by_value, .value_type: i32} + - { .name: dtile_w , .size: 4, .offset: 112, .value_kind: by_value, .value_type: i32} + - { .name: dslice_y , .size: 4, .offset: 116, .value_kind: by_value, .value_type: i32} + - { .name: dslice_x , .size: 4, .offset: 120, .value_kind: by_value, .value_type: i32} + - { .name: dslice_h , .size: 4, .offset: 124, .value_kind: by_value, .value_type: i32} + - { .name: dslice_w , .size: 4, .offset: 128, .value_kind: by_value, .value_type: i32} + - { .name: dslice_h_left, .size: 4, .offset: 132, .value_kind: by_value, .value_type: i32} + - { .name: dslice_w_left, .size: 4, .offset: 136, .value_kind: by_value, .value_type: i32} + - { .name: __pack0 , .size: 4, .offset: 140, .value_kind: by_value, .value_type: i32} + - .name: igemm_bwd_gtc_bt128x128x8_tt8x8_gm2x4x4_gn2x4x4_ta1x1x1x1x4_8x1x1x32x1_tb1x1x1x1x4x1x1_8x1x1x2x1x1x16 + .symbol: igemm_bwd_gtc_bt128x128x8_tt8x8_gm2x4x4_gn2x4x4_ta1x1x1x1x4_8x1x1x32x1_tb1x1x1x1x4x1x1_8x1x1x2x1x1x16.kd + .sgpr_count: 78 + .vgpr_count: 114 + .kernarg_segment_align: 8 + .kernarg_segment_size: 144 + .group_segment_fixed_size: 16384 + .private_segment_fixed_size: 0 + .wavefront_size: 64 + .reqd_workgroup_size : [256, 1, 1] + .max_flat_workgroup_size: 256 + .args: + - { .name: p_in , .size: 8, .offset: 0, .value_kind: global_buffer, .value_type: f32, .address_space: global, .is_const: false} + - { .name: p_wei , .size: 8, .offset: 8, .value_kind: global_buffer, .value_type: f32, .address_space: global, .is_const: true} + - { .name: p_out , .size: 8, .offset: 16, .value_kind: global_buffer, .value_type: f32, .address_space: global, .is_const: true} + - { .name: hi , .size: 4, .offset: 24, .value_kind: by_value, .value_type: i32} + - { .name: wi , .size: 4, .offset: 28, .value_kind: by_value, .value_type: i32} + - { .name: n , .size: 4, .offset: 32, .value_kind: by_value, .value_type: i32} + - { .name: k , .size: 4, .offset: 36, .value_kind: by_value, .value_type: i32} + - { .name: c , .size: 4, .offset: 40, .value_kind: by_value, .value_type: i32} + - { .name: ho , .size: 4, .offset: 44, .value_kind: by_value, .value_type: i32} + - { .name: wo , .size: 4, .offset: 48, .value_kind: by_value, .value_type: i32} + - { 
.name: stride_h , .size: 4, .offset: 52, .value_kind: by_value, .value_type: i32} + - { .name: stride_w , .size: 4, .offset: 56, .value_kind: by_value, .value_type: i32} + - { .name: dilation_h, .size: 4, .offset: 60, .value_kind: by_value, .value_type: i32} + - { .name: dilation_w, .size: 4, .offset: 64, .value_kind: by_value, .value_type: i32} + - { .name: pad_h , .size: 4, .offset: 68, .value_kind: by_value, .value_type: i32} + - { .name: pad_w , .size: 4, .offset: 72, .value_kind: by_value, .value_type: i32} + - { .name: y , .size: 4, .offset: 76, .value_kind: by_value, .value_type: i32} + - { .name: x , .size: 4, .offset: 80, .value_kind: by_value, .value_type: i32} + - { .name: dtile_iy , .size: 4, .offset: 84, .value_kind: by_value, .value_type: i32} + - { .name: dtile_ix , .size: 4, .offset: 88, .value_kind: by_value, .value_type: i32} + - { .name: dtile_dy , .size: 4, .offset: 92, .value_kind: by_value, .value_type: i32} + - { .name: dtile_dx , .size: 4, .offset: 96, .value_kind: by_value, .value_type: i32} + - { .name: dtile_y , .size: 4, .offset: 100, .value_kind: by_value, .value_type: i32} + - { .name: dtile_x , .size: 4, .offset: 104, .value_kind: by_value, .value_type: i32} + - { .name: dtile_h , .size: 4, .offset: 108, .value_kind: by_value, .value_type: i32} + - { .name: dtile_w , .size: 4, .offset: 112, .value_kind: by_value, .value_type: i32} + - { .name: dslice_y , .size: 4, .offset: 116, .value_kind: by_value, .value_type: i32} + - { .name: dslice_x , .size: 4, .offset: 120, .value_kind: by_value, .value_type: i32} + - { .name: dslice_h , .size: 4, .offset: 124, .value_kind: by_value, .value_type: i32} + - { .name: dslice_w , .size: 4, .offset: 128, .value_kind: by_value, .value_type: i32} + - { .name: dslice_h_left, .size: 4, .offset: 132, .value_kind: by_value, .value_type: i32} + - { .name: dslice_w_left, .size: 4, .offset: 136, .value_kind: by_value, .value_type: i32} + - { .name: __pack0 , .size: 4, .offset: 140, .value_kind: by_value, .value_type: i32} + - .name: igemm_bwd_gtc_bt64x128x8_tt4x8_gm2x4x4_gn2x4x4_ta1x1x1x1x2_8x1x1x32x1_tb1x1x1x1x4x1x1_8x1x1x2x1x1x16 + .symbol: igemm_bwd_gtc_bt64x128x8_tt4x8_gm2x4x4_gn2x4x4_ta1x1x1x1x2_8x1x1x32x1_tb1x1x1x1x4x1x1_8x1x1x2x1x1x16.kd + .sgpr_count: 78 + .vgpr_count: 76 + .kernarg_segment_align: 8 + .kernarg_segment_size: 144 + .group_segment_fixed_size: 16384 + .private_segment_fixed_size: 0 + .wavefront_size: 64 + .reqd_workgroup_size : [256, 1, 1] + .max_flat_workgroup_size: 256 + .args: + - { .name: p_in , .size: 8, .offset: 0, .value_kind: global_buffer, .value_type: f32, .address_space: global, .is_const: false} + - { .name: p_wei , .size: 8, .offset: 8, .value_kind: global_buffer, .value_type: f32, .address_space: global, .is_const: true} + - { .name: p_out , .size: 8, .offset: 16, .value_kind: global_buffer, .value_type: f32, .address_space: global, .is_const: true} + - { .name: hi , .size: 4, .offset: 24, .value_kind: by_value, .value_type: i32} + - { .name: wi , .size: 4, .offset: 28, .value_kind: by_value, .value_type: i32} + - { .name: n , .size: 4, .offset: 32, .value_kind: by_value, .value_type: i32} + - { .name: k , .size: 4, .offset: 36, .value_kind: by_value, .value_type: i32} + - { .name: c , .size: 4, .offset: 40, .value_kind: by_value, .value_type: i32} + - { .name: ho , .size: 4, .offset: 44, .value_kind: by_value, .value_type: i32} + - { .name: wo , .size: 4, .offset: 48, .value_kind: by_value, .value_type: i32} + - { .name: stride_h , .size: 4, .offset: 52, .value_kind: by_value, .value_type: i32} + 
- { .name: stride_w , .size: 4, .offset: 56, .value_kind: by_value, .value_type: i32} + - { .name: dilation_h, .size: 4, .offset: 60, .value_kind: by_value, .value_type: i32} + - { .name: dilation_w, .size: 4, .offset: 64, .value_kind: by_value, .value_type: i32} + - { .name: pad_h , .size: 4, .offset: 68, .value_kind: by_value, .value_type: i32} + - { .name: pad_w , .size: 4, .offset: 72, .value_kind: by_value, .value_type: i32} + - { .name: y , .size: 4, .offset: 76, .value_kind: by_value, .value_type: i32} + - { .name: x , .size: 4, .offset: 80, .value_kind: by_value, .value_type: i32} + - { .name: dtile_iy , .size: 4, .offset: 84, .value_kind: by_value, .value_type: i32} + - { .name: dtile_ix , .size: 4, .offset: 88, .value_kind: by_value, .value_type: i32} + - { .name: dtile_dy , .size: 4, .offset: 92, .value_kind: by_value, .value_type: i32} + - { .name: dtile_dx , .size: 4, .offset: 96, .value_kind: by_value, .value_type: i32} + - { .name: dtile_y , .size: 4, .offset: 100, .value_kind: by_value, .value_type: i32} + - { .name: dtile_x , .size: 4, .offset: 104, .value_kind: by_value, .value_type: i32} + - { .name: dtile_h , .size: 4, .offset: 108, .value_kind: by_value, .value_type: i32} + - { .name: dtile_w , .size: 4, .offset: 112, .value_kind: by_value, .value_type: i32} + - { .name: dslice_y , .size: 4, .offset: 116, .value_kind: by_value, .value_type: i32} + - { .name: dslice_x , .size: 4, .offset: 120, .value_kind: by_value, .value_type: i32} + - { .name: dslice_h , .size: 4, .offset: 124, .value_kind: by_value, .value_type: i32} + - { .name: dslice_w , .size: 4, .offset: 128, .value_kind: by_value, .value_type: i32} + - { .name: dslice_h_left, .size: 4, .offset: 132, .value_kind: by_value, .value_type: i32} + - { .name: dslice_w_left, .size: 4, .offset: 136, .value_kind: by_value, .value_type: i32} + - { .name: __pack0 , .size: 4, .offset: 140, .value_kind: by_value, .value_type: i32} + - .name: igemm_bwd_gtc_bt64x128x8_tt4x8_gm2x4x4_gn2x4x4_ta1x1x1x1x2_8x1x1x32x1_tb1x1x1x1x4x1x1_8x1x1x1x1x1x32 + .symbol: igemm_bwd_gtc_bt64x128x8_tt4x8_gm2x4x4_gn2x4x4_ta1x1x1x1x2_8x1x1x32x1_tb1x1x1x1x4x1x1_8x1x1x1x1x1x32.kd + .sgpr_count: 78 + .vgpr_count: 76 + .kernarg_segment_align: 8 + .kernarg_segment_size: 144 + .group_segment_fixed_size: 16384 + .private_segment_fixed_size: 0 + .wavefront_size: 64 + .reqd_workgroup_size : [256, 1, 1] + .max_flat_workgroup_size: 256 + .args: + - { .name: p_in , .size: 8, .offset: 0, .value_kind: global_buffer, .value_type: f32, .address_space: global, .is_const: false} + - { .name: p_wei , .size: 8, .offset: 8, .value_kind: global_buffer, .value_type: f32, .address_space: global, .is_const: true} + - { .name: p_out , .size: 8, .offset: 16, .value_kind: global_buffer, .value_type: f32, .address_space: global, .is_const: true} + - { .name: hi , .size: 4, .offset: 24, .value_kind: by_value, .value_type: i32} + - { .name: wi , .size: 4, .offset: 28, .value_kind: by_value, .value_type: i32} + - { .name: n , .size: 4, .offset: 32, .value_kind: by_value, .value_type: i32} + - { .name: k , .size: 4, .offset: 36, .value_kind: by_value, .value_type: i32} + - { .name: c , .size: 4, .offset: 40, .value_kind: by_value, .value_type: i32} + - { .name: ho , .size: 4, .offset: 44, .value_kind: by_value, .value_type: i32} + - { .name: wo , .size: 4, .offset: 48, .value_kind: by_value, .value_type: i32} + - { .name: stride_h , .size: 4, .offset: 52, .value_kind: by_value, .value_type: i32} + - { .name: stride_w , .size: 4, .offset: 56, .value_kind: by_value, .value_type: 
i32} + - { .name: dilation_h, .size: 4, .offset: 60, .value_kind: by_value, .value_type: i32} + - { .name: dilation_w, .size: 4, .offset: 64, .value_kind: by_value, .value_type: i32} + - { .name: pad_h , .size: 4, .offset: 68, .value_kind: by_value, .value_type: i32} + - { .name: pad_w , .size: 4, .offset: 72, .value_kind: by_value, .value_type: i32} + - { .name: y , .size: 4, .offset: 76, .value_kind: by_value, .value_type: i32} + - { .name: x , .size: 4, .offset: 80, .value_kind: by_value, .value_type: i32} + - { .name: dtile_iy , .size: 4, .offset: 84, .value_kind: by_value, .value_type: i32} + - { .name: dtile_ix , .size: 4, .offset: 88, .value_kind: by_value, .value_type: i32} + - { .name: dtile_dy , .size: 4, .offset: 92, .value_kind: by_value, .value_type: i32} + - { .name: dtile_dx , .size: 4, .offset: 96, .value_kind: by_value, .value_type: i32} + - { .name: dtile_y , .size: 4, .offset: 100, .value_kind: by_value, .value_type: i32} + - { .name: dtile_x , .size: 4, .offset: 104, .value_kind: by_value, .value_type: i32} + - { .name: dtile_h , .size: 4, .offset: 108, .value_kind: by_value, .value_type: i32} + - { .name: dtile_w , .size: 4, .offset: 112, .value_kind: by_value, .value_type: i32} + - { .name: dslice_y , .size: 4, .offset: 116, .value_kind: by_value, .value_type: i32} + - { .name: dslice_x , .size: 4, .offset: 120, .value_kind: by_value, .value_type: i32} + - { .name: dslice_h , .size: 4, .offset: 124, .value_kind: by_value, .value_type: i32} + - { .name: dslice_w , .size: 4, .offset: 128, .value_kind: by_value, .value_type: i32} + - { .name: dslice_h_left, .size: 4, .offset: 132, .value_kind: by_value, .value_type: i32} + - { .name: dslice_w_left, .size: 4, .offset: 136, .value_kind: by_value, .value_type: i32} + - { .name: __pack0 , .size: 4, .offset: 140, .value_kind: by_value, .value_type: i32} +... +.end_amdgpu_metadata
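The .args tables above describe a 144-byte kernarg segment shared by all kernel variants. Purely as an illustration (a hypothetical host-side mirror, not part of the patch), the layout corresponds to:

// Field names and offsets copied from the .args tables; 64-bit host assumed.
#include <cstdint>

struct IgemmBwdGtcKernargs
{
    void*   p_in;                                   // offset 0, written by the kernel
    void*   p_wei;                                  // offset 8, const
    void*   p_out;                                  // offset 16, const
    int32_t hi, wi, n, k, c, ho, wo;                // offsets 24..48
    int32_t stride_h, stride_w;                     // 52, 56
    int32_t dilation_h, dilation_w;                 // 60, 64
    int32_t pad_h, pad_w;                           // 68, 72
    int32_t y, x;                                   // 76, 80
    int32_t dtile_iy, dtile_ix, dtile_dy, dtile_dx; // 84..96
    int32_t dtile_y, dtile_x, dtile_h, dtile_w;     // 100..112
    int32_t dslice_y, dslice_x, dslice_h, dslice_w; // 116..128
    int32_t dslice_h_left, dslice_w_left;           // 132, 136
    int32_t __pack0;                                // 140, pads the struct to 144 bytes
};
static_assert(sizeof(IgemmBwdGtcKernargs) == 144, "must match kernarg_segment_size");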
diff --git a/src/mlo_dir_conv.cpp b/src/mlo_dir_conv.cpp index 2b4bdcac69..9f4550b740 100644 --- a/src/mlo_dir_conv.cpp +++ b/src/mlo_dir_conv.cpp @@ -140,7 +140,8 @@ static auto GetImplicitGemmSolvers() miopen::solver::ConvHipImplicitGemmBwdDataV1R1, miopen::solver::ConvHipImplicitGemmBwdDataV4R1, miopen::solver::ConvAsmImplicitGemmV4R1DynamicFwd_1x1, - miopen::solver::ConvAsmImplicitGemmV4R1DynamicFwd>{}; + miopen::solver::ConvAsmImplicitGemmV4R1DynamicFwd, + miopen::solver::ConvAsmImplicitGemmV4R1DynamicBwd>{}; } static auto GetWindogradSolvers() diff --git a/src/solver.cpp b/src/solver.cpp index c406681f7c..52759e4ce5 100644 --- a/src/solver.cpp +++ b/src/solver.cpp @@ -323,6 +323,9 @@ inline SolverRegistrar::SolverRegistrar(IdRegistryData& registry) RegisterWithSolver( registry, ++id, ConvHipImplicitGemmForwardV4R4Xdlops{}, miopenConvolutionAlgoImplicitGEMM); + + RegisterWithSolver( + registry, ++id, ConvAsmImplicitGemmV4R1DynamicBwd{}, miopenConvolutionAlgoImplicitGEMM); } } // namespace solver diff --git a/src/solver/conv_asm_implicit_gemm_bwd_v4r1_dynamic.cpp b/src/solver/conv_asm_implicit_gemm_bwd_v4r1_dynamic.cpp new file mode 100644 index 0000000000..959467eae6 --- /dev/null +++ b/src/solver/conv_asm_implicit_gemm_bwd_v4r1_dynamic.cpp @@ -0,0 +1,195 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (c) 2020 Advanced Micro Devices, Inc. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + *******************************************************************************/ +#include "miopen/solver.hpp" +#include "miopen/handle.hpp" +#include +#include +#include +#include +#include +#include "implicitgemm_util.hpp" + +namespace miopen { +namespace solver { + +static inline bool FindImplicitGemmDynamicKernelBwd(const ConvolutionContext& ctx, + std::string& kernel_name, + int& block_size, + int& grid_size) +{ + // TODO: add more dynamic kernels to expand the supported range, and update this function + // clang-format off + // refer to ConvolutionContextInterpreter; in bwd data, most dimensions are reversed + int hi = ctx.out_height; + int wi = ctx.out_width; + int n = ctx.batch_sz; + int k = ctx.n_inputs; + int c = ctx.n_outputs; + int ho = ctx.in_height; + int wo = ctx.in_width; + int stride_h = ctx.in_height > 1 ? ctx.kernel_stride_h : 1; + int stride_w = ctx.in_width > 1 ? ctx.kernel_stride_w : 1; + int dilation_h = ctx.kernel_size_h > 1? ctx.kernel_dilation_h : 1; + int dilation_w = ctx.kernel_size_w > 1?
ctx.kernel_dilation_w : 1; + int pad_h = ctx.pad_h; + int pad_w = ctx.pad_w; + int y = ctx.kernel_size_h; + int x = ctx.kernel_size_w; + + int gcd_stride_dilation_h = gcd(stride_h, dilation_h); + int gcd_stride_dilation_w = gcd(stride_w, dilation_w); + int y_tilda = stride_h / gcd_stride_dilation_h; + int x_tilda = stride_w / gcd_stride_dilation_w; + + int h_tilda = ho + (dilation_h * (y - 1) + stride_h - 1) / stride_h; + int w_tilda = wo + (dilation_w * (x - 1) + stride_w - 1) / stride_w; + + int h_tilda_left = std::max(0, pad_h - dilation_h * (y_tilda - 1)) / stride_h; + int w_tilda_left = std::max(0, pad_w - dilation_w * (x_tilda - 1)) / stride_w; + + int h_tilda_right = std::min(h_tilda, (pad_h + hi - 1 + stride_h - 1) / stride_h + 1); + int w_tilda_right = std::min(w_tilda, (pad_w + wi - 1 + stride_w - 1) / stride_w + 1); + + int h_tilda_slice = h_tilda_right - h_tilda_left; + int w_tilda_slice = w_tilda_right - w_tilda_left; + // clang-format on + int gemm_m = c; + int gemm_n = n * h_tilda_slice * w_tilda_slice; + // gemm_k is not computed; since the k dimension is merged, we only check k + + // TODO: this selection is too simple; more kernels and better selection logic are needed + if((gemm_m % 128 == 0) && (gemm_n % 128 == 0) && (k % 16 == 0)) + { + if((y == 1) && (x == 1) && (stride_h == 1) && (stride_w == 1) && (dilation_h == 1) && + (dilation_w == 1) && (pad_h == 0) && (pad_w == 0) && (n % 128 == 0)) + { + grid_size = (gemm_m >> 7) * (gemm_n >> 7); + block_size = 256; + kernel_name = "igemm_bwd_gtc_bt128x128x16_tt8x8_gm2x4x4_gn2x4x4_ta1x1x1x2x4_" + "16x1x1x16x1_tb1x1x1x2x4x1x1_16x1x1x16x1x1x1"; + return true; + } + else + { + grid_size = (gemm_m >> 7) * (gemm_n >> 7); + block_size = 256; + kernel_name = "igemm_bwd_gtc"; + return true; + } + } + else + { + if((y == 1) && (x == 1) && (stride_h == 1) && (stride_w == 1) && (dilation_h == 1) && + (dilation_w == 1) && (pad_h == 0) && (pad_w == 0)) + { + if((gemm_m % 128 == 0) && (gemm_n % 128 == 0) && (k % 8 == 0) && ((ho * wo) % 16 == 0)) + { + grid_size = (gemm_m >> 7) * (gemm_n >> 7); + block_size = 256; + kernel_name = "igemm_bwd_gtc_bt128x128x8_tt8x8_gm2x4x4_gn2x4x4_ta1x1x1x1x4_" + "8x1x1x32x1_tb1x1x1x1x4x1x1_8x1x1x2x1x1x16"; + return true; + } + else if((gemm_m % 64 == 0) && (gemm_n % 64 == 0) && (k % 8 == 0) && (n % 64 == 0)) + { + grid_size = (gemm_m >> 6) * (gemm_n >> 6); + block_size = 64; + kernel_name = "igemm_bwd_gtc_bt64x64x8_tt8x8_gm2x4x2_gn2x4x2_ta1x2x1x1x4_" + "4x1x1x16x1_tb1x2x1x1x4x1x1_4x1x1x16x1x1x1"; + return true; + } + } + } + return false; +}
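For reference, the "tilda" arithmetic above can be lifted into a standalone helper. This is an illustrative sketch only (hypothetical names; formulas copied verbatim from FindImplicitGemmDynamicKernelBwd), followed by one worked case:

#include <algorithm>
#include <numeric>

struct TildaSlice
{
    int y_tilda, x_tilda, h_tilda_slice, w_tilda_slice;
};

inline TildaSlice ComputeTildaSlice(int hi, int wi, int ho, int wo, int y, int x,
                                    int stride_h, int stride_w,
                                    int dilation_h, int dilation_w,
                                    int pad_h, int pad_w)
{
    const int y_tilda = stride_h / std::gcd(stride_h, dilation_h);
    const int x_tilda = stride_w / std::gcd(stride_w, dilation_w);
    const int h_tilda = ho + (dilation_h * (y - 1) + stride_h - 1) / stride_h;
    const int w_tilda = wo + (dilation_w * (x - 1) + stride_w - 1) / stride_w;
    const int h_left  = std::max(0, pad_h - dilation_h * (y_tilda - 1)) / stride_h;
    const int w_left  = std::max(0, pad_w - dilation_w * (x_tilda - 1)) / stride_w;
    const int h_right = std::min(h_tilda, (pad_h + hi - 1 + stride_h - 1) / stride_h + 1);
    const int w_right = std::min(w_tilda, (pad_w + wi - 1 + stride_w - 1) / stride_w + 1);
    return {y_tilda, x_tilda, h_right - h_left, w_right - w_left};
}

// Worked case: hi = wi = 34, ho = wo = 17, y = x = 3, stride = 2, dilation = 1, pad = 1
// gives y_tilda = x_tilda = 2, h_tilda = w_tilda = 18, left = 0, right = 18,
// so the slice is 18x18 and gemm_n = n * 18 * 18 for this problem.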
+ +bool ConvAsmImplicitGemmV4R1DynamicBwd::IsApplicable(const ConvolutionContext& ctx) const +{ + const auto device_name = ctx.GetStream().GetDeviceName(); + if(!(StartsWith(device_name, "gfx900") || StartsWith(device_name, "gfx906"))) + return false; + + if(!ctx.direction.IsBackwardData()) + return false; + + if(!ctx.Is2d()) + return false; + + if(!ctx.IsFp32()) + return false; + + if(!ctx.rmv.IsV3()) + return false; + + if(ctx.group_counts != 1) + return false; + + std::string kernel_name; + int block_size; + int grid_size; + return FindImplicitGemmDynamicKernelBwd(ctx, kernel_name, block_size, grid_size); +} + +ConvSolution ConvAsmImplicitGemmV4R1DynamicBwd::GetSolution(const ConvolutionContext& ctx) const +{ + ConvSolution result; + + std::string kernel_name; + int block_size; + int grid_size; + bool ret = FindImplicitGemmDynamicKernelBwd(ctx, kernel_name, block_size, grid_size); + if(!ret) + MIOPEN_THROW("should not happen!"); + + KernelInfo kernel; + std::ostringstream options; + + kernel.kernel_file = "igemm_bwd_gtc_dynamic.s"; + kernel.kernel_name = kernel_name; + kernel.g_wk.clear(); + /* Note: for APIs like hipHccModuleLaunchKernel() and hipExtModuleLaunchKernel(), + * grid dims are in units of work items, + * but for APIs like hipModuleLaunchKernel(), grid dims are in units of blocks. + */ + kernel.g_wk.push_back(grid_size * block_size); + kernel.g_wk.push_back(1); + kernel.g_wk.push_back(1); + kernel.l_wk.clear(); + kernel.l_wk.push_back(block_size); + kernel.l_wk.push_back(1); + kernel.l_wk.push_back(1); + + GenerateClangDefsym(options, "ROCM_METADATA_VERSION", ctx.rmv.UseV3() ? 5 : 4); + + kernel.comp_options = options.str(); + + result.invoker_factory = conv::MakeImplGemmDynamicBackwardDataInvokerFactory(ctx); + result.construction_params.push_back(kernel); + return result; +} + +} // namespace solver +} // namespace miopen diff --git a/src/solver/conv_asm_implicit_gemm_v4r1_dynamic.cpp b/src/solver/conv_asm_implicit_gemm_v4r1_dynamic.cpp index 6aaa73bc2e..e305e02f45 100644 --- a/src/solver/conv_asm_implicit_gemm_v4r1_dynamic.cpp +++ b/src/solver/conv_asm_implicit_gemm_v4r1_dynamic.cpp @@ -129,6 +129,7 @@ static inline int RunAndMeasureSolutionDynamicBase(miopen::Handle& profile_h, { elapsed_time = float(0); std::vector kernels; + const auto& conv_problem = ctx.conv_problem; for(auto& k_info : solution.construction_params) { @@ -141,8 +142,14 @@ static inline int RunAndMeasureSolutionDynamicBase(miopen::Handle& profile_h, k_info.comp_options); kernels.push_back(kernel); } - float time = - conv::CallImplicitGemmDynamic(profile_h, ctx, bot_buf, top_buf, wei_buf, kernels); + bool kernel_is_1x1 = (kernels[0].GetName().find("igemm_v4r1_1x1_dynamic") == 0); + float time; + if(kernel_is_1x1) + time = conv::CallImplGemmDynamicForward1x1( + profile_h, conv_problem, bot_buf, top_buf, wei_buf, kernels); + else + time = conv::CallImplGemmDynamicForward( + profile_h, conv_problem, bot_buf, top_buf, wei_buf, kernels); elapsed_time += time; } #ifdef NDEBUG @@ -530,8 +537,9 @@ static inline ConvSolution GetSolutionBase(const ConvolutionContext& ctx, std::string kernel_name = GetKernelNameImplicitGemmV4R1Dynamic(config, kernel_type); - int block_size = GetImplicitGemmV4R1DynamicBlockSize(config); - int grid_size = GetImplicitGemmV4R1DynamicGridSize(ctx, config); + int block_size = GetImplicitGemmV4R1DynamicBlockSize(config); + int grid_size = GetImplicitGemmV4R1DynamicGridSize(ctx, config); + bool kernel_is_1x1 = (kernel_name.find("igemm_v4r1_1x1_dynamic") == 0); KernelInfo kernel; std::ostringstream options; @@ -557,7 +565,10 @@ static inline ConvSolution GetSolutionBase(const ConvolutionContext& ctx, MIOPEN_LOG_I2(kernel.kernel_file + ":" + kernel.kernel_name); - result.invoker_factory = conv::MakeImplGemmDynamicDataInvokerFactory(ctx); + if(kernel_is_1x1) + result.invoker_factory = conv::MakeImplGemmDynamicForward1x1InvokerFactory(ctx); + else + result.invoker_factory = conv::MakeImplGemmDynamicForwardInvokerFactory(ctx); result.construction_params.push_back(kernel); return result; } diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index 9c665a112b..0948ed8e74 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -534,20 +534,27 @@ COMMAND $ --verbose --conv_dim_type conv3d --input 8 COMMAND $ --verbose --conv_dim_type conv3d --input 16 64 3 4 4 --weights 64 32 1 3 3 --pads_strides_dilations 0 0 0 1 2 3 1 2 3 --trans_output_pads 0 0 0 --group-count 4 --cmode trans --pmode default ) -set(DYNAMIC_IMPLICITGEMM_ENVS
MIOPEN_DEBUG_FIND_ONLY_SOLVER=ConvAsmImplicitGemmV4R1DynamicFwd +set(DYNAMIC_IMPLICITGEMM_COMMON MIOPEN_DEBUG_CONV_FFT=0 - MIOPEN_DEBUG_CONV_GEMM=0) + MIOPEN_DEBUG_CONV_GEMM=0 + MIOPEN_DEBUG_CONV_WINOGRAD=0) +set(DYNAMIC_IMPLICITGEMM_ENVS + ${DYNAMIC_IMPLICITGEMM_COMMON} + MIOPEN_DEBUG_FIND_ONLY_SOLVER=ConvAsmImplicitGemmV4R1DynamicFwd) set(DYNAMIC_IMPLICITGEMM_1X1_ENVS - MIOPEN_DEBUG_FIND_ONLY_SOLVER=ConvAsmImplicitGemmV4R1DynamicFwd_1x1 - MIOPEN_DEBUG_CONV_FFT=0 - MIOPEN_DEBUG_CONV_GEMM=0) + ${DYNAMIC_IMPLICITGEMM_COMMON} + MIOPEN_DEBUG_FIND_ONLY_SOLVER=ConvAsmImplicitGemmV4R1DynamicFwd_1x1) +set(DYNAMIC_IMPLICITGEMM_BWD_ENVS + ${DYNAMIC_IMPLICITGEMM_COMMON} + MIOPEN_DEBUG_FIND_ONLY_SOLVER=ConvAsmImplicitGemmV4R1DynamicBwd) add_custom_test(test_conv_igemm_dynamic_small COMMAND ${DYNAMIC_IMPLICITGEMM_ENVS} $ --verbose --input 16 16 56 56 --weights 64 16 1 1 --pads_strides_dilations 0 0 1 1 1 1 --disable-backward-data --disable-backward-weights COMMAND ${DYNAMIC_IMPLICITGEMM_ENVS} $ --verbose --input 16 64 34 34 --weights 64 64 3 3 --pads_strides_dilations 0 0 1 1 1 1 --disable-backward-data --disable-backward-weights COMMAND ${DYNAMIC_IMPLICITGEMM_ENVS} $ --verbose --input 32 32 17 17 --weights 32 32 1 7 --pads_strides_dilations 0 3 1 1 1 1 --disable-backward-data --disable-backward-weights COMMAND ${DYNAMIC_IMPLICITGEMM_1X1_ENVS} $ --verbose --input 16 384 8 8 --weights 64 384 1 1 --pads_strides_dilations 0 0 1 1 1 1 --disable-backward-data --disable-backward-weights +COMMAND ${DYNAMIC_IMPLICITGEMM_BWD_ENVS} $ --verbose --input 64 64 28 28 --weights 16 64 1 1 --pads_strides_dilations 0 0 1 1 1 1 --disable-forward --disable-backward-weights +COMMAND ${DYNAMIC_IMPLICITGEMM_BWD_ENVS} $ --verbose --input 16 128 36 36 --weights 32 128 1 1 --pads_strides_dilations 0 0 1 1 1 1 --disable-forward --disable-backward-weights ) add_custom_test(test_conv_igemm_dynamic SKIP_UNLESS_ALL @@ -560,7 +567,10 @@ COMMAND ${DYNAMIC_IMPLICITGEMM_ENVS} $ --verbose -- COMMAND ${DYNAMIC_IMPLICITGEMM_1X1_ENVS} $ --verbose --input 128 256 28 28 --weights 128 256 1 1 --pads_strides_dilations 0 0 1 1 1 1 --disable-backward-data --disable-backward-weights COMMAND ${DYNAMIC_IMPLICITGEMM_1X1_ENVS} $ --verbose --input 64 1536 8 8 --weights 256 1536 1 1 --pads_strides_dilations 0 0 1 1 1 1 --disable-backward-data --disable-backward-weights COMMAND ${DYNAMIC_IMPLICITGEMM_1X1_ENVS} $ --verbose --input 128 768 17 17 --weights 128 768 1 1 --pads_strides_dilations 0 0 1 1 1 1 --disable-backward-data --disable-backward-weights -) +COMMAND ${DYNAMIC_IMPLICITGEMM_BWD_ENVS} $ --verbose --input 64 64 56 56 --weights 256 64 1 1 --pads_strides_dilations 0 0 1 1 1 1 --disable-forward --disable-backward-weights +COMMAND ${DYNAMIC_IMPLICITGEMM_BWD_ENVS} $ --verbose --input 32 128 34 34 --weights 64 128 3 3 --pads_strides_dilations 0 0 1 1 1 1 --disable-forward --disable-backward-weights +COMMAND ${DYNAMIC_IMPLICITGEMM_BWD_ENVS} $ --verbose --input 128 128 35 35 --weights 128 128 3 3 --pads_strides_dilations 1 1 1 1 1 1 --disable-forward --disable-backward-weights +COMMAND ${DYNAMIC_IMPLICITGEMM_BWD_ENVS} $ --verbose --input 128 256 56 56 --weights 64 256 1 1 --pads_strides_dilations 0 0 1 1 1 1 --disable-forward --disable-backward-weights) if(MIOPEN_TEST_DEEPBENCH) add_custom_test(test_deepbench_conv