diff --git a/src/conv/invokers/impl_gemm_dynamic.cpp b/src/conv/invokers/impl_gemm_dynamic.cpp index c993eadf40..bc7baa39bc 100644 --- a/src/conv/invokers/impl_gemm_dynamic.cpp +++ b/src/conv/invokers/impl_gemm_dynamic.cpp @@ -21,7 +21,6 @@ float CallImplGemmDynamicForward(const miopen::Handle& handle, auto kernel = kernels[0]; MIOPEN_LOG_I(kernel.GetName()); - bool kernel_is_1x1 = (kernel.GetName().find("igemm_v4r1_1x1_dynamic") == 0); // clang-format off int hi = conv_problem.GetInHeight(); int wi = conv_problem.GetInWidth(); @@ -40,6 +39,8 @@ float CallImplGemmDynamicForward(const miopen::Handle& handle, int x = conv_problem.GetWeightsWidth(); int __pack0 = 0; // clang-format on + bool kernel_is_1x1 = ((x == 1) && (y == 1)); + std::vector opArgs; opArgs.emplace_back(src); opArgs.emplace_back(wei); @@ -208,28 +209,19 @@ InvokerFactory MakeImplGemmDynamicForwardInvokerFactory(const ConvolutionContext const auto data_ctx = boost::any_cast(primitive_parameters); const auto& tensors = data_ctx.tensors; auto kernel = handle.Run(kernels[0]); - if(kernel.GetName().find("igemm_v4r1_dynamic") == 0 || - kernel.GetName().find("igemm_v4r1_1x1_dynamic") == 0) - { - std::vector ks; - std::transform(kernels.begin(), - kernels.end(), - std::back_inserter(ks), - [&](const Kernel& k) { return handle.Run(k); }); - float elapsed = 0; - elapsed = CallImplGemmDynamicForward( - handle, conv_problem, tensors.in, tensors.out, tensors.w, ks); - if(handle.IsProfilingEnabled()) - { - handle.ResetKernelTime(); - handle.AccumKernelTime(elapsed); - } - } - else + + std::vector ks; + std::transform(kernels.begin(), + kernels.end(), + std::back_inserter(ks), + [&](const Kernel& k) { return handle.Run(k); }); + float elapsed = 0; + elapsed = CallImplGemmDynamicForward( + handle, conv_problem, tensors.in, tensors.out, tensors.w, ks); + if(handle.IsProfilingEnabled()) { - MIOPEN_THROW( - "Error running dynamic implicit GEMM convolution (invalid kernel name " + - kernel.GetName() + ")"); + handle.ResetKernelTime(); + handle.AccumKernelTime(elapsed); } }; };