Skip to content

Commit

Permalink
remove kernel name check in invoker
Browse files Browse the repository at this point in the history
  • Loading branch information
carlushuang committed Jul 21, 2020
1 parent 5368a56 commit 5096c06
Showing 1 changed file with 14 additions and 22 deletions.
36 changes: 14 additions & 22 deletions src/conv/invokers/impl_gemm_dynamic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@ float CallImplGemmDynamicForward(const miopen::Handle& handle,
auto kernel = kernels[0];
MIOPEN_LOG_I(kernel.GetName());

bool kernel_is_1x1 = (kernel.GetName().find("igemm_v4r1_1x1_dynamic") == 0);
// clang-format off
int hi = conv_problem.GetInHeight();
int wi = conv_problem.GetInWidth();
Expand All @@ -40,6 +39,8 @@ float CallImplGemmDynamicForward(const miopen::Handle& handle,
int x = conv_problem.GetWeightsWidth();
int __pack0 = 0;
// clang-format on
bool kernel_is_1x1 = ((x == 1) && (y == 1));

std::vector<OpKernelArg> opArgs;
opArgs.emplace_back(src);
opArgs.emplace_back(wei);
Expand Down Expand Up @@ -208,28 +209,19 @@ InvokerFactory MakeImplGemmDynamicForwardInvokerFactory(const ConvolutionContext
const auto data_ctx = boost::any_cast<conv::DataInvokeParams>(primitive_parameters);
const auto& tensors = data_ctx.tensors;
auto kernel = handle.Run(kernels[0]);
if(kernel.GetName().find("igemm_v4r1_dynamic") == 0 ||
kernel.GetName().find("igemm_v4r1_1x1_dynamic") == 0)
{
std::vector<KernelInvoke> ks;
std::transform(kernels.begin(),
kernels.end(),
std::back_inserter(ks),
[&](const Kernel& k) { return handle.Run(k); });
float elapsed = 0;
elapsed = CallImplGemmDynamicForward(
handle, conv_problem, tensors.in, tensors.out, tensors.w, ks);
if(handle.IsProfilingEnabled())
{
handle.ResetKernelTime();
handle.AccumKernelTime(elapsed);
}
}
else

std::vector<KernelInvoke> ks;
std::transform(kernels.begin(),
kernels.end(),
std::back_inserter(ks),
[&](const Kernel& k) { return handle.Run(k); });
float elapsed = 0;
elapsed = CallImplGemmDynamicForward(
handle, conv_problem, tensors.in, tensors.out, tensors.w, ks);
if(handle.IsProfilingEnabled())
{
MIOPEN_THROW(
"Error running dynamic implicit GEMM convolution (invalid kernel name " +
kernel.GetName() + ")");
handle.ResetKernelTime();
handle.AccumKernelTime(elapsed);
}
};
};
Expand Down

0 comments on commit 5096c06

Please sign in to comment.