Skip to content

Commit

Permalink
[NFC] WrW gemm solvers (#852)
Browse files Browse the repository at this point in the history
* Implemented WrW gemm solvers
* Updated system-find-databases
  • Loading branch information
DrizztDoUrden authored May 7, 2021
1 parent f3921e7 commit 46fbbbf
Show file tree
Hide file tree
Showing 24 changed files with 147,826 additions and 147,838 deletions.
1 change: 1 addition & 0 deletions src/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,7 @@ set( MIOpen_Source
conv/problem_description.cpp
solver/gemm.cpp
solver/gemm_bwd.cpp
solver/gemm_wrw.cpp
dropout.cpp
dropout_api.cpp
readonlyramdb.cpp
Expand Down
120 changes: 32 additions & 88 deletions src/convolution.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -54,10 +54,6 @@ MIOPEN_DECLARE_ENV_VAR(MIOPEN_DEBUG_CONV_WINOGRAD)
MIOPEN_DECLARE_ENV_VAR(MIOPEN_DEBUG_CONV_GEMM)
MIOPEN_DECLARE_ENV_VAR(MIOPEN_DEBUG_CONV_FFT)

/// \todo Workaround for issue 1430.
/// Vega20 fails to access GPU memory beyond the size that GetMaxMemoryAllocSize()
/// reports on Vega10. The macro therefore yields the smaller of the handle's
/// GetMaxMemoryAllocSize() and the constant 7287183769 (presumably the Vega10
/// limit -- confirm against issue 1430).
#define MAX_MEM_ALLOC_SZ(handle) (std::min((handle).GetMaxMemoryAllocSize(), size_t(7287183769)))

namespace miopen {

ConvolutionDescriptor::ConvolutionDescriptor(std::size_t spatial_dim,
Expand Down Expand Up @@ -301,65 +297,32 @@ bool ConvolutionDescriptor::IsWinograd3x3SupportedAndFast(miopen::ConvolutionCon
return solver::ConvBinWinograd3x3U{}.IsApplicable(ctx);
}

/// Workspace size to reserve for the GEMM solvers of a backward-data convolution.
///
/// Builds a BackwardData ConvolutionContext from the given descriptors, asks
/// AllGemmWorkspaceSize() for its list of pairs and returns the largest `second`
/// member found in that list.
///
/// \return 0 when GEMM support is compiled out (MIOPEN_USE_GEMM is false), when
///         MIOPEN_DEBUG_CONV_GEMM disables it, or when the list is empty.
std::size_t
ConvolutionDescriptor::BackwardGetValidWorkSpaceSizeGemm(const TensorDescriptor& dyDesc,
                                                         const TensorDescriptor& wDesc,
                                                         const TensorDescriptor& dxDesc) const
{
#if MIOPEN_USE_GEMM
    // Guard: GEMM switched off through the debug environment variable.
    if(miopen::IsDisabled(MIOPEN_DEBUG_CONV_GEMM{}))
        return 0;

    const auto ctx =
        ConvolutionContext{dxDesc, wDesc, dyDesc, *this, conv::Direction::BackwardData};
    decltype(auto) solver_ws_pairs = AllGemmWorkspaceSize(ctx);

    // Guard: nothing was reported, so there is no maximum to take.
    if(solver_ws_pairs.empty())
        return 0;

    decltype(auto) ws_sizes =
        solver_ws_pairs |
        boost::adaptors::transformed([](const auto& entry) { return entry.second; });
    return *std::max_element(ws_sizes.begin(), ws_sizes.end());
#else
    std::ignore = dyDesc;
    std::ignore = wDesc;
    std::ignore = dxDesc;
    return 0;
#endif
}

std::size_t
ConvolutionDescriptor::WrwGetValidWorkSpaceSizeGemm(const TensorDescriptor& dyDesc,
const TensorDescriptor& /*xDesc*/,
const TensorDescriptor& xDesc,
const TensorDescriptor& dwDesc) const
{
#if MIOPEN_USE_GEMM
if(!miopen::IsDisabled(MIOPEN_DEBUG_CONV_GEMM{}))
{
const std::size_t spatial_dim = GetSpatialDimension();
const auto wei_spatial = boost::adaptors::slice(dwDesc.GetLengths(), 2, 2 + spatial_dim);

// if not 1x1
if((miopen::any_of(wei_spatial, [](auto v) { return v != 1; }) ||
miopen::any_of(GetConvPads(), [](auto v) { return v != 0; }) ||
miopen::any_of(GetConvStrides(), [](auto v) { return v != 1; })))
return BackwardWeightsGetWorkSpaceSizeGEMM(dyDesc, dwDesc);
if(miopen::IsDisabled(MIOPEN_DEBUG_CONV_GEMM{}))
return 0;

if(miopen::any_of(wei_spatial, [](auto v) { return v == 1; }) &&
miopen::any_of(GetConvPads(), [](auto v) { return v == 0; }) &&
miopen::any_of(GetConvStrides(), [](auto v) { return v == 1; }))
return 0;
const auto ctx =
ConvolutionContext{xDesc, dwDesc, dyDesc, *this, conv::Direction::BackwardWeights};
decltype(auto) gemm_ws_sz_pairs = AllGemmWorkspaceSize(ctx);

MIOPEN_THROW(miopenStatusNotImplemented);
if(!gemm_ws_sz_pairs.empty())
{
decltype(auto) gemm_ws_szs =
gemm_ws_sz_pairs | boost::adaptors::transformed([](const auto& p) { return p.second; });
return *std::max_element(gemm_ws_szs.begin(), gemm_ws_szs.end());
}
return 0;
#else
std::ignore = dwDesc;
std::ignore = dyDesc;
return 0;
std::ignore = xDesc;
std::ignore = dwDesc;
#endif

return 0;
}

std::size_t ConvolutionDescriptor::ForwardGetWorkSpaceSize(Handle& handle,
Expand Down Expand Up @@ -531,37 +494,28 @@ ConvolutionDescriptor::BackwardDataGetWorkSpaceSize(Handle& handle,
return workspace_size;
}

std::size_t
ConvolutionDescriptor::BackwardWeightsGetWorkSpaceSizeGEMM(const TensorDescriptor& dyDesc,
const TensorDescriptor& dwDesc) const
std::size_t ConvolutionDescriptor::BackwardWeightsGetWorkSpaceSizeGEMM(
const miopen::ConvolutionContext& ctx) const
{
const std::size_t spatial_dim = GetSpatialDimension();

auto out_spatial = boost::adaptors::slice(dyDesc.GetLengths(), 2, 2 + spatial_dim);
auto wei_spatial = boost::adaptors::slice(dwDesc.GetLengths(), 2, 2 + spatial_dim);

const std::size_t wei_c = dwDesc.GetLengths()[1];
#if MIOPEN_USE_GEMM
if(miopen::IsDisabled(MIOPEN_DEBUG_CONV_GEMM{}))
return 0;

const std::size_t gemm_size = GetTypeSize(dyDesc.GetType()) * wei_c *
std::accumulate(out_spatial.begin(),
out_spatial.end(),
std::size_t(1),
std::multiplies<std::size_t>()) *
std::accumulate(wei_spatial.begin(),
wei_spatial.end(),
std::size_t(1),
std::multiplies<std::size_t>()) *
group_count;
decltype(auto) gemm_ws_sz_pairs = AllGemmWorkspaceSize(ctx);

// No workspace is needed for 1x1_stride=1 convolutions
if(miopen::all_of(wei_spatial, [](auto v) { return v == 1; }) &&
miopen::all_of(GetConvStrides(), [](auto v) { return v == 1; }) &&
miopen::all_of(GetConvPads(), [](auto v) { return v == 0; }))
if(!gemm_ws_sz_pairs.empty())
{
return 0;
decltype(auto) gemm_ws_szs =
gemm_ws_sz_pairs | boost::adaptors::transformed([](const auto& p) { return p.second; });
return *std::max_element(gemm_ws_szs.begin(), gemm_ws_szs.end());
}
#else
std::ignore = dyDesc;
std::ignore = xDesc;
std::ignore = dwDesc;
#endif

return gemm_size;
return 0;
}

std::size_t ConvolutionDescriptor::ForwardBackwardGetWorkSpaceSizeImplicitGemm(
Expand Down Expand Up @@ -793,20 +747,10 @@ ConvolutionDescriptor::BackwardWeightsGetWorkSpaceSize(Handle& handle,
ctx.do_search = false;
ctx.disable_perfdb_access = true;

std::size_t workspace_size_gemm = 0;
#if MIOPEN_USE_GEMM
if(!miopen::IsDisabled(MIOPEN_DEBUG_CONV_GEMM{}))
{
workspace_size_gemm = BackwardWeightsGetWorkSpaceSizeGEMM(dyDesc, dwDesc);
if(workspace_size_gemm > MAX_MEM_ALLOC_SZ(handle))
workspace_size_gemm = 0;
}
#endif

const size_t workspace_size = std::max({BackwardWeightsGetWorkSpaceSizeImplicitGemm(ctx),
BackwardWeightsGetWorkSpaceSizeWinograd(ctx),
BackwardWeightsGetWorkSpaceSizeDirect(ctx),
workspace_size_gemm});
BackwardWeightsGetWorkSpaceSizeGEMM(ctx)});
MIOPEN_LOG_I2(workspace_size);
return workspace_size;
}
Expand Down
3 changes: 2 additions & 1 deletion src/find_db.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,8 @@ bool CheckInvokerSupport(const std::string& algo)
algo == "miopenConvolutionBwdDataAlgoImplicitGEMM" ||
algo == "miopenConvolutionBwdWeightsAlgoImplicitGEMM" ||
algo == "miopenConvolutionFwdAlgoFFT" || algo == "miopenConvolutionBwdDataAlgoFFT" ||
algo == "miopenConvolutionFwdAlgoGEMM" || algo == "miopenConvolutionBwdDataAlgoGEMM";
algo == "miopenConvolutionFwdAlgoGEMM" || algo == "miopenConvolutionBwdDataAlgoGEMM" ||
algo == "miopenConvolutionBwdWeightsAlgoGEMM";
}

template <class TDb>
Expand Down
30 changes: 2 additions & 28 deletions src/include/miopen/convolution.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -121,10 +121,6 @@ struct ConvolutionDescriptor : miopenConvolutionDescriptor

bool IsWinograd3x3SupportedAndFast(miopen::ConvolutionContext& ctx) const;

std::size_t BackwardGetValidWorkSpaceSizeGemm(const TensorDescriptor& dyDesc,
const TensorDescriptor& wDesc,
const TensorDescriptor& dxDesc) const;

std::size_t WrwGetValidWorkSpaceSizeGemm(const TensorDescriptor& dyDesc,
const TensorDescriptor& xDesc,
const TensorDescriptor& dwDesc) const;
Expand Down Expand Up @@ -338,8 +334,7 @@ struct ConvolutionDescriptor : miopenConvolutionDescriptor
const TensorDescriptor& xDesc,
const TensorDescriptor& dwDesc) const;

std::size_t BackwardWeightsGetWorkSpaceSizeGEMM(const TensorDescriptor& dyDesc,
const TensorDescriptor& dwDesc) const;
std::size_t BackwardWeightsGetWorkSpaceSizeGEMM(const miopen::ConvolutionContext& ctx) const;

std::size_t BackwardWeightsGetWorkSpaceSizeDirect(const miopen::ConvolutionContext& ctx) const;
std::size_t
Expand All @@ -361,7 +356,7 @@ struct ConvolutionDescriptor : miopenConvolutionDescriptor
std::size_t workSpaceSize,
bool exhaustiveSearch) const;

void ConvolutionBackwardWeights(Handle& handle,
void ConvolutionBackwardWeights(const Handle& handle,
const void* alpha,
const TensorDescriptor& dyDesc,
ConstData_t dy,
Expand Down Expand Up @@ -394,11 +389,6 @@ struct ConvolutionDescriptor : miopenConvolutionDescriptor
const TensorDescriptor& xDesc,
const TensorDescriptor& dwDesc) const;

void BackwardWeightsGemm(Handle& handle,
const ConvWrwTensors& tensors,
Data_t workSpace,
std::size_t workSpaceSize) const;

template <class TKernels>
void BackwardWeightsDirect(Handle& handle,
const ConvolutionContext& ctx,
Expand All @@ -412,22 +402,6 @@ struct ConvolutionDescriptor : miopenConvolutionDescriptor
size_t* solutionCount,
miopenConvSolution_t* solutions) const;

bool IsGemmApplicableBwd(const TensorDescriptor& dyDesc,
const TensorDescriptor& wDesc,
const TensorDescriptor& dxDesc) const;

bool IsGemmApplicableWrw(const TensorDescriptor& dyDesc,
const TensorDescriptor& xDesc,
const TensorDescriptor& dwDesc) const;

float ComputeGemmWtiBwd(const TensorDescriptor& dyDesc,
const TensorDescriptor& wDesc,
const TensorDescriptor& dxDesc) const;

float ComputeGemmWtiWrw(const TensorDescriptor& dyDesc,
const TensorDescriptor& xDesc,
const TensorDescriptor& dwDesc) const;

std::size_t GetSolutionCountFallback(Handle& handle, const ProblemDescription& problem) const;
};

Expand Down
1 change: 1 addition & 0 deletions src/include/miopen/find_solution.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,7 @@ struct SolverContainer
++count;
auto sz = solver.GetWorkspaceSize(search_params);
res.push_back(std::make_pair(SolverDbId(solver), sz));
MIOPEN_LOG_I2(SolverDbId(solver) << ": " << sz);
}
},
Solvers{}...);
Expand Down
62 changes: 52 additions & 10 deletions src/include/miopen/solver.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -1950,16 +1950,6 @@ struct fft : SolverBase<ConvolutionContext>
ConvSolution GetSolution(const ConvolutionContext& ctx) const;
};

/// Partial implementation.
/// Placeholder solver: IsApplicable() always answers false, and GetSolution()
/// returns a ConvSolution constructed from miopenStatusNotInitialized.
struct gemm : SolverBase<ConvolutionContext>
{
    bool IsApplicable(const ConvolutionContext& /*ctx*/) const
    {
        return false;
    }

    ConvSolution GetSolution(const ConvolutionContext& /*ctx*/) const
    {
        return ConvSolution{miopenStatusNotInitialized};
    }
};

struct PerformanceImplicitGemmWrwV4R4Xdlops : Serializable<PerformanceImplicitGemmWrwV4R4Xdlops>
{
int GemmMPerBlock;
Expand Down Expand Up @@ -2305,6 +2295,58 @@ struct GemmBwdRest : GemmBwdBase
ConvSolution GetSolution(const ExecutionContext&, const conv::ProblemDescription&) const;
};

/// Common base of the WrW GEMM solvers (GemmWrw1x1_stride1, GemmWrwUniversal).
struct GemmWrwBase : SolverBase<ConvolutionContext>
{
    /// Applicability check on an (execution context, problem) pair; defined out of line.
    bool IsApplicable(const ExecutionContext&, const conv::ProblemDescription&) const;

    /// Always reports true.
    bool IsDynamic() const
    {
        return true;
    }

    /// WTI value for an (execution context, problem) pair; defined out of line.
    float GetWti(const ExecutionContext& context, const conv::ProblemDescription& problem) const;

    /// Convenience overload: forwards \p ctx together with its conv_problem member.
    float GetWti(const ConvolutionContext& ctx) const
    {
        return GetWti(ctx, ctx.conv_problem);
    }
};

/// WrW GEMM solver; judging by its name it targets 1x1, stride-1 convolutions
/// (the actual applicability rule lives in the out-of-line IsApplicable()).
struct GemmWrw1x1_stride1 : GemmWrwBase
{
    // Primary interface, defined out of line.
    size_t GetWorkspaceSize(const ExecutionContext&, const conv::ProblemDescription&) const;
    bool IsApplicable(const ExecutionContext&, const conv::ProblemDescription&) const;
    ConvSolution GetSolution(const ExecutionContext&, const conv::ProblemDescription&) const;

    // ConvolutionContext adapters: each one forwards the context together with
    // its conv_problem member to the matching two-argument overload above.
    size_t GetWorkspaceSize(const ConvolutionContext& ctx) const
    {
        return GetWorkspaceSize(ctx, ctx.conv_problem);
    }

    bool IsApplicable(const ConvolutionContext& ctx) const
    {
        return IsApplicable(ctx, ctx.conv_problem);
    }

    ConvSolution GetSolution(const ConvolutionContext& ctx) const
    {
        return GetSolution(ctx, ctx.conv_problem);
    }
};

/// WrW GEMM solver; judging by its name it is the general-purpose variant
/// (the actual applicability rule lives in the out-of-line IsApplicable()).
struct GemmWrwUniversal : GemmWrwBase
{
    // Primary interface, defined out of line.
    size_t GetWorkspaceSize(const ExecutionContext&, const conv::ProblemDescription&) const;
    bool IsApplicable(const ExecutionContext&, const conv::ProblemDescription&) const;
    ConvSolution GetSolution(const ExecutionContext&, const conv::ProblemDescription&) const;

    // ConvolutionContext adapters: each one forwards the context together with
    // its conv_problem member to the matching two-argument overload above.
    size_t GetWorkspaceSize(const ConvolutionContext& ctx) const
    {
        return GetWorkspaceSize(ctx, ctx.conv_problem);
    }

    bool IsApplicable(const ConvolutionContext& ctx) const
    {
        return IsApplicable(ctx, ctx.conv_problem);
    }

    ConvSolution GetSolution(const ConvolutionContext& ctx) const
    {
        return GetSolution(ctx, ctx.conv_problem);
    }
};

struct AnySolver;

} // namespace solver
Expand Down
6 changes: 0 additions & 6 deletions src/include/miopen/solver_id.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -62,12 +62,6 @@ struct Id
}
bool operator!=(const Id& other) const { return !(*this == other); }

/// Returns the solver id built from the string "gemm". The id is constructed
/// once, on first call, and a copy is handed out on every call.
static solver::Id gemm()
{
    static const solver::Id gemm_id{"gemm"};
    return gemm_id;
}

private:
uint64_t value = invalid_value;
bool is_valid = false;
Expand Down
Loading

0 comments on commit 46fbbbf

Please sign in to comment.