diff --git a/src/conv/invokers/impl_gemm_dynamic.cpp b/src/conv/invokers/impl_gemm_dynamic.cpp index bc7baa39bc..d7d2deed25 100644 --- a/src/conv/invokers/impl_gemm_dynamic.cpp +++ b/src/conv/invokers/impl_gemm_dynamic.cpp @@ -39,7 +39,6 @@ float CallImplGemmDynamicForward(const miopen::Handle& handle, int x = conv_problem.GetWeightsWidth(); int __pack0 = 0; // clang-format on - bool kernel_is_1x1 = ((x == 1) && (y == 1)); std::vector opArgs; opArgs.emplace_back(src); @@ -58,17 +57,10 @@ float CallImplGemmDynamicForward(const miopen::Handle& handle, opArgs.emplace_back(dilation_w); opArgs.emplace_back(pad_h); opArgs.emplace_back(pad_w); - // clang-format off - if(kernel_is_1x1) - { - opArgs.emplace_back(__pack0); - } - else - { - opArgs.emplace_back(y); - opArgs.emplace_back(x); - opArgs.emplace_back(__pack0); - } + opArgs.emplace_back(y); + opArgs.emplace_back(x); + opArgs.emplace_back(__pack0); + kernel(opArgs); if(handle.IsProfilingEnabled()) @@ -76,13 +68,67 @@ float CallImplGemmDynamicForward(const miopen::Handle& handle, return elapsed; } +float CallImplGemmDynamicForward1x1(const miopen::Handle& handle, + const ProblemDescription& conv_problem, + ConstData_t src, + Data_t dst, + ConstData_t wei, + const std::vector& kernels) +{ + float elapsed = 0.0f; + + auto kernel = kernels[0]; + MIOPEN_LOG_I(kernel.GetName()); + + // clang-format off + int hi = conv_problem.GetInHeight(); + int wi = conv_problem.GetInWidth(); + int n = conv_problem.GetInBatchSize(); + int k = conv_problem.GetOutChannels(); + int c = conv_problem.GetInChannels(); + int ho = conv_problem.GetOutHeight(); + int wo = conv_problem.GetOutWidth(); + int stride_h = conv_problem.GetKernelStrideH(); + int stride_w = conv_problem.GetKernelStrideW(); + int dilation_h = conv_problem.GetDilationH(); + int dilation_w = conv_problem.GetDilationW(); + int pad_h = conv_problem.GetPadH(); + int pad_w = conv_problem.GetPadW(); + int __pack0 = 0; + // clang-format on + + std::vector opArgs; + opArgs.emplace_back(src); + opArgs.emplace_back(wei); + opArgs.emplace_back(dst); + opArgs.emplace_back(hi); + opArgs.emplace_back(wi); + opArgs.emplace_back(n); + opArgs.emplace_back(k); + opArgs.emplace_back(c); + opArgs.emplace_back(ho); + opArgs.emplace_back(wo); + opArgs.emplace_back(stride_h); + opArgs.emplace_back(stride_w); + opArgs.emplace_back(dilation_h); + opArgs.emplace_back(dilation_w); + opArgs.emplace_back(pad_h); + opArgs.emplace_back(pad_w); + opArgs.emplace_back(__pack0); + + kernel(opArgs); + + if(handle.IsProfilingEnabled()) + elapsed += handle.GetKernelTime(); + return elapsed; +} float CallImplGemmDynamicBackwardData(const miopen::Handle& handle, - const ProblemDescription& conv_problem, - ConstData_t src, - Data_t dst, - ConstData_t wei, - const std::vector& kernels) + const ProblemDescription& conv_problem, + ConstData_t src, + Data_t dst, + ConstData_t wei, + const std::vector& kernels) { float elapsed = 0.0f; @@ -227,6 +273,32 @@ InvokerFactory MakeImplGemmDynamicForwardInvokerFactory(const ConvolutionContext }; } +InvokerFactory MakeImplGemmDynamicForward1x1InvokerFactory(const ConvolutionContext& ctx) +{ + const auto& conv_problem = ctx.conv_problem; + return [conv_problem](const std::vector& kernels) { + return [=](const Handle& handle, const boost::any& primitive_parameters) { + const auto data_ctx = boost::any_cast(primitive_parameters); + const auto& tensors = data_ctx.tensors; + auto kernel = handle.Run(kernels[0]); + + std::vector ks; + std::transform(kernels.begin(), + kernels.end(), + std::back_inserter(ks), + [&](const Kernel& k) { return handle.Run(k); }); + float elapsed = 0; + elapsed = CallImplGemmDynamicForward1x1( + handle, conv_problem, tensors.in, tensors.out, tensors.w, ks); + if(handle.IsProfilingEnabled()) + { + handle.ResetKernelTime(); + handle.AccumKernelTime(elapsed); + } + }; + }; +} + InvokerFactory MakeImplGemmDynamicBackwardDataInvokerFactory(const ConvolutionContext& ctx) { const auto& conv_problem = ctx.conv_problem; diff --git a/src/include/miopen/conv/invokers/impl_gemm_dynamic.hpp b/src/include/miopen/conv/invokers/impl_gemm_dynamic.hpp index 4897acad00..619be17a59 100644 --- a/src/include/miopen/conv/invokers/impl_gemm_dynamic.hpp +++ b/src/include/miopen/conv/invokers/impl_gemm_dynamic.hpp @@ -41,7 +41,12 @@ float CallImplGemmDynamicForward(const miopen::Handle& handle, Data_t dst, ConstData_t wei, const std::vector& kernels); - +float CallImplGemmDynamicForward1x1(const miopen::Handle& handle, + const ProblemDescription& conv_problem, + ConstData_t src, + Data_t dst, + ConstData_t wei, + const std::vector& kernels); float CallImplGemmDynamicBackwardData(const miopen::Handle& handle, const ProblemDescription& conv_problem, ConstData_t src, @@ -50,6 +55,7 @@ float CallImplGemmDynamicBackwardData(const miopen::Handle& handle, const std::vector& kernels); InvokerFactory MakeImplGemmDynamicForwardInvokerFactory(const ConvolutionContext& ctx); +InvokerFactory MakeImplGemmDynamicForward1x1InvokerFactory(const ConvolutionContext& ctx); InvokerFactory MakeImplGemmDynamicBackwardDataInvokerFactory(const ConvolutionContext& ctx); } // namespace conv diff --git a/src/solver/conv_asm_implicit_gemm_v4r1_dynamic.cpp b/src/solver/conv_asm_implicit_gemm_v4r1_dynamic.cpp index c1f01d2057..e305e02f45 100644 --- a/src/solver/conv_asm_implicit_gemm_v4r1_dynamic.cpp +++ b/src/solver/conv_asm_implicit_gemm_v4r1_dynamic.cpp @@ -142,8 +142,14 @@ static inline int RunAndMeasureSolutionDynamicBase(miopen::Handle& profile_h, k_info.comp_options); kernels.push_back(kernel); } - float time = conv::CallImplGemmDynamicForward( - profile_h, conv_problem, bot_buf, top_buf, wei_buf, kernels); + bool kernel_is_1x1 = (kernels[0].GetName().find("igemm_v4r1_1x1_dynamic") == 0); + float time; + if(kernel_is_1x1) + time = conv::CallImplGemmDynamicForward1x1( + profile_h, conv_problem, bot_buf, top_buf, wei_buf, kernels); + else + time = conv::CallImplGemmDynamicForward( + profile_h, conv_problem, bot_buf, top_buf, wei_buf, kernels); elapsed_time += time; } #ifdef NDEBUG @@ -531,8 +537,9 @@ static inline ConvSolution GetSolutionBase(const ConvolutionContext& ctx, std::string kernel_name = GetKernelNameImplicitGemmV4R1Dynamic(config, kernel_type); - int block_size = GetImplicitGemmV4R1DynamicBlockSize(config); - int grid_size = GetImplicitGemmV4R1DynamicGridSize(ctx, config); + int block_size = GetImplicitGemmV4R1DynamicBlockSize(config); + int grid_size = GetImplicitGemmV4R1DynamicGridSize(ctx, config); + bool kernel_is_1x1 = (kernel_name.find("igemm_v4r1_1x1_dynamic") == 0); KernelInfo kernel; std::ostringstream options; @@ -558,7 +565,10 @@ static inline ConvSolution GetSolutionBase(const ConvolutionContext& ctx, MIOPEN_LOG_I2(kernel.kernel_file + ":" + kernel.kernel_name); - result.invoker_factory = conv::MakeImplGemmDynamicForwardInvokerFactory(ctx); + if(kernel_is_1x1) + result.invoker_factory = conv::MakeImplGemmDynamicForward1x1InvokerFactory(ctx); + else + result.invoker_factory = conv::MakeImplGemmDynamicForwardInvokerFactory(ctx); result.construction_params.push_back(kernel); return result; }