Skip to content

Commit

Permalink
Fix a bug introduced when refactoring the forward invoker
Browse files Browse the repository at this point in the history
  • Loading branch information
carlushuang committed Jul 22, 2020
1 parent 5096c06 commit fa44e70
Show file tree
Hide file tree
Showing 3 changed files with 111 additions and 23 deletions.
106 changes: 89 additions & 17 deletions src/conv/invokers/impl_gemm_dynamic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,6 @@ float CallImplGemmDynamicForward(const miopen::Handle& handle,
int x = conv_problem.GetWeightsWidth();
int __pack0 = 0;
// clang-format on
bool kernel_is_1x1 = ((x == 1) && (y == 1));

std::vector<OpKernelArg> opArgs;
opArgs.emplace_back(src);
Expand All @@ -58,31 +57,78 @@ float CallImplGemmDynamicForward(const miopen::Handle& handle,
opArgs.emplace_back(dilation_w);
opArgs.emplace_back(pad_h);
opArgs.emplace_back(pad_w);
// clang-format off
if(kernel_is_1x1)
{
opArgs.emplace_back(__pack0);
}
else
{
opArgs.emplace_back(y);
opArgs.emplace_back(x);
opArgs.emplace_back(__pack0);
}
opArgs.emplace_back(y);
opArgs.emplace_back(x);
opArgs.emplace_back(__pack0);

kernel(opArgs);

if(handle.IsProfilingEnabled())
elapsed += handle.GetKernelTime();
return elapsed;
}

/// Launches the first kernel in `kernels` using the 1x1 forward argument layout.
///
/// The argument list is: src, wei, dst, followed by the integer problem
/// parameters (hi, wi, n, k, c, ho, wo, strides, dilations, pads) and one
/// trailing zero pack word. Unlike the generic forward call, no filter
/// height/width (y, x) arguments are passed.
///
/// @param handle        Handle queried for profiling state and kernel time.
/// @param conv_problem  Source of the integer problem parameters.
/// @param src           Input tensor buffer (first kernel argument).
/// @param dst           Output tensor buffer (third kernel argument).
/// @param wei           Weights buffer (second kernel argument).
/// @param kernels       Kernel invokers; only element 0 is used.
/// @return handle.GetKernelTime() when profiling is enabled, otherwise 0.
float CallImplGemmDynamicForward1x1(const miopen::Handle& handle,
                                    const ProblemDescription& conv_problem,
                                    ConstData_t src,
                                    Data_t dst,
                                    ConstData_t wei,
                                    const std::vector<KernelInvoke>& kernels)
{
    auto kernel = kernels[0];
    MIOPEN_LOG_I(kernel.GetName());

    // clang-format off
    int in_height   = conv_problem.GetInHeight();
    int in_width    = conv_problem.GetInWidth();
    int batch       = conv_problem.GetInBatchSize();
    int out_chan    = conv_problem.GetOutChannels();
    int in_chan     = conv_problem.GetInChannels();
    int out_height  = conv_problem.GetOutHeight();
    int out_width   = conv_problem.GetOutWidth();
    int str_h       = conv_problem.GetKernelStrideH();
    int str_w       = conv_problem.GetKernelStrideW();
    int dil_h       = conv_problem.GetDilationH();
    int dil_w       = conv_problem.GetDilationW();
    int padding_h   = conv_problem.GetPadH();
    int padding_w   = conv_problem.GetPadW();
    int pack_word   = 0;
    // clang-format on

    // Integer arguments in the exact order they are appended after the
    // three buffer arguments.
    const int scalar_args[] = {in_height,
                               in_width,
                               batch,
                               out_chan,
                               in_chan,
                               out_height,
                               out_width,
                               str_h,
                               str_w,
                               dil_h,
                               dil_w,
                               padding_h,
                               padding_w,
                               pack_word};

    std::vector<OpKernelArg> opArgs;
    opArgs.emplace_back(src);
    opArgs.emplace_back(wei);
    opArgs.emplace_back(dst);
    for(const int value : scalar_args)
    {
        opArgs.emplace_back(value);
    }

    kernel(opArgs);

    float elapsed = 0.0f;
    if(handle.IsProfilingEnabled())
    {
        elapsed += handle.GetKernelTime();
    }
    return elapsed;
}

float CallImplGemmDynamicBackwardData(const miopen::Handle& handle,
const ProblemDescription& conv_problem,
ConstData_t src,
Data_t dst,
ConstData_t wei,
const std::vector<KernelInvoke>& kernels)
const ProblemDescription& conv_problem,
ConstData_t src,
Data_t dst,
ConstData_t wei,
const std::vector<KernelInvoke>& kernels)
{
float elapsed = 0.0f;

Expand Down Expand Up @@ -227,6 +273,32 @@ InvokerFactory MakeImplGemmDynamicForwardInvokerFactory(const ConvolutionContext
};
}

/// Builds the invoker factory for the 1x1 forward dynamic implicit-GEMM path.
///
/// The returned factory receives the solution's kernels and produces an
/// invoker. On each invocation the invoker:
///   1. extracts conv::DataInvokeParams from `primitive_parameters`,
///   2. turns every Kernel into a KernelInvoke via handle.Run(),
///   3. calls CallImplGemmDynamicForward1x1 with the in/out/w tensors,
///   4. when profiling is enabled, resets the handle's kernel time and
///      stores the elapsed time returned by the call.
///
/// @param ctx  Convolution context; its conv_problem is copied into the factory.
InvokerFactory MakeImplGemmDynamicForward1x1InvokerFactory(const ConvolutionContext& ctx)
{
    const auto& conv_problem = ctx.conv_problem;
    return [conv_problem](const std::vector<Kernel>& kernels) {
        return [=](const Handle& handle, const boost::any& primitive_parameters) {
            const auto data_ctx = boost::any_cast<conv::DataInvokeParams>(primitive_parameters);
            const auto& tensors = data_ctx.tensors;

            // Every kernel (including kernels[0]) is run through the handle
            // here, so no separate handle.Run(kernels[0]) is needed.
            std::vector<KernelInvoke> ks;
            ks.reserve(kernels.size());
            std::transform(kernels.begin(),
                           kernels.end(),
                           std::back_inserter(ks),
                           [&](const Kernel& k) { return handle.Run(k); });

            const float elapsed = CallImplGemmDynamicForward1x1(
                handle, conv_problem, tensors.in, tensors.out, tensors.w, ks);
            if(handle.IsProfilingEnabled())
            {
                handle.ResetKernelTime();
                handle.AccumKernelTime(elapsed);
            }
        };
    };
}

InvokerFactory MakeImplGemmDynamicBackwardDataInvokerFactory(const ConvolutionContext& ctx)
{
const auto& conv_problem = ctx.conv_problem;
Expand Down
8 changes: 7 additions & 1 deletion src/include/miopen/conv/invokers/impl_gemm_dynamic.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,12 @@ float CallImplGemmDynamicForward(const miopen::Handle& handle,
Data_t dst,
ConstData_t wei,
const std::vector<KernelInvoke>& kernels);

/// Launches kernels[0] with the 1x1 forward argument layout: the same
/// arguments as CallImplGemmDynamicForward but without the filter height/width
/// (y, x) entries. Returns the handle's kernel time when profiling is enabled,
/// otherwise 0.
float CallImplGemmDynamicForward1x1(const miopen::Handle& handle,
const ProblemDescription& conv_problem,
ConstData_t src,
Data_t dst,
ConstData_t wei,
const std::vector<KernelInvoke>& kernels);
float CallImplGemmDynamicBackwardData(const miopen::Handle& handle,
const ProblemDescription& conv_problem,
ConstData_t src,
Expand All @@ -50,6 +55,7 @@ float CallImplGemmDynamicBackwardData(const miopen::Handle& handle,
const std::vector<KernelInvoke>& kernels);

InvokerFactory MakeImplGemmDynamicForwardInvokerFactory(const ConvolutionContext& ctx);
/// Invoker factory for the 1x1 forward path; the produced invoker dispatches
/// through CallImplGemmDynamicForward1x1.
InvokerFactory MakeImplGemmDynamicForward1x1InvokerFactory(const ConvolutionContext& ctx);
InvokerFactory MakeImplGemmDynamicBackwardDataInvokerFactory(const ConvolutionContext& ctx);

} // namespace conv
Expand Down
20 changes: 15 additions & 5 deletions src/solver/conv_asm_implicit_gemm_v4r1_dynamic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -142,8 +142,14 @@ static inline int RunAndMeasureSolutionDynamicBase(miopen::Handle& profile_h,
k_info.comp_options);
kernels.push_back(kernel);
}
float time = conv::CallImplGemmDynamicForward(
profile_h, conv_problem, bot_buf, top_buf, wei_buf, kernels);
bool kernel_is_1x1 = (kernels[0].GetName().find("igemm_v4r1_1x1_dynamic") == 0);
float time;
if(kernel_is_1x1)
time = conv::CallImplGemmDynamicForward1x1(
profile_h, conv_problem, bot_buf, top_buf, wei_buf, kernels);
else
time = conv::CallImplGemmDynamicForward(
profile_h, conv_problem, bot_buf, top_buf, wei_buf, kernels);
elapsed_time += time;
}
#ifdef NDEBUG
Expand Down Expand Up @@ -531,8 +537,9 @@ static inline ConvSolution GetSolutionBase(const ConvolutionContext& ctx,

std::string kernel_name = GetKernelNameImplicitGemmV4R1Dynamic(config, kernel_type);

int block_size = GetImplicitGemmV4R1DynamicBlockSize(config);
int grid_size = GetImplicitGemmV4R1DynamicGridSize(ctx, config);
int block_size = GetImplicitGemmV4R1DynamicBlockSize(config);
int grid_size = GetImplicitGemmV4R1DynamicGridSize(ctx, config);
bool kernel_is_1x1 = (kernel_name.find("igemm_v4r1_1x1_dynamic") == 0);

KernelInfo kernel;
std::ostringstream options;
Expand All @@ -558,7 +565,10 @@ static inline ConvSolution GetSolutionBase(const ConvolutionContext& ctx,

MIOPEN_LOG_I2(kernel.kernel_file + ":" + kernel.kernel_name);

result.invoker_factory = conv::MakeImplGemmDynamicForwardInvokerFactory(ctx);
if(kernel_is_1x1)
result.invoker_factory = conv::MakeImplGemmDynamicForward1x1InvokerFactory(ctx);
else
result.invoker_factory = conv::MakeImplGemmDynamicForwardInvokerFactory(ctx);
result.construction_params.push_back(kernel);
return result;
}
Expand Down

0 comments on commit fa44e70

Please sign in to comment.