[MLIR] Implement tuning - step 2: fwd, xdlops (#1139)
* Implement tuning support for ConvMlirIgemmFwdXdlops

* Narrowing KPACK size tuning range to 4/8
jerryyin authored Sep 8, 2021
1 parent d70ece3 commit c636bf2
Showing 2 changed files with 181 additions and 3 deletions.
53 changes: 53 additions & 0 deletions src/include/miopen/solver.hpp
@@ -849,10 +849,63 @@ struct ConvMlirIgemmFwd : SolverBase<ConvolutionContext>
bool disableConfigOverrideFromEnv = false) const;
};

struct PerformanceConvMlirIgemmXdlops : Serializable<PerformanceConvMlirIgemmXdlops>
{
int GemmMPerBlock; // 2^n[4..256]
int GemmNPerBlock; // 2^n[16..256]
int GemmKPerBlock; // 2^n[1..8]

int GemmMPerWave;
int GemmNPerWave;

int GemmKPACKSize; // 2^n[4..8]

// GemmAThreadCopyMoreGemmK is currently a fixed value and is not tunable
bool GemmAThreadCopyMoreGemmK;
bool GemmBThreadCopyMoreGemmKPack;

bool use_spare_set;
PerformanceConvMlirIgemmXdlops(int, int, int, int, int, int, bool, bool, bool);

PerformanceConvMlirIgemmXdlops();
PerformanceConvMlirIgemmXdlops(bool spare);
PerformanceConvMlirIgemmXdlops(int a, int b, int c, int d, int e, int f, bool g, bool h)
: PerformanceConvMlirIgemmXdlops(a, b, c, d, e, f, g, h, false)
{
}

bool operator==(const PerformanceConvMlirIgemmXdlops& other) const;

template <class Self, class F>
static void Visit(Self&& self, F f)
{
f(self.GemmNPerBlock, "GemmNPerBlock");
f(self.GemmMPerBlock, "GemmMPerBlock");
f(self.GemmKPerBlock, "GemmKPerBlock");
f(self.GemmMPerWave, "GemmMPerWave");
f(self.GemmNPerWave, "GemmNPerWave");
f(self.GemmKPACKSize, "GemmKPACKSize");
f(self.GemmAThreadCopyMoreGemmK, "GemmAThreadCopyMoreGemmK");
f(self.GemmBThreadCopyMoreGemmKPack, "GemmBThreadCopyMoreGemmKPack");
}

bool IsValid(const ConvolutionContext& ctx) const;
bool SetNextValue(const ConvolutionContext& config);
std::string ToString() const;
};

struct ConvMlirIgemmFwdXdlops : SolverBase<ConvolutionContext>
{
bool IsApplicable(const ConvolutionContext& ctx) const;
ConvSolution GetSolution(const ConvolutionContext& ctx) const;
PerformanceConvMlirIgemmXdlops GetPerformanceConfig(const ConvolutionContext& ctx) const;
bool IsValidPerformanceConfig(const ConvolutionContext& ctx,
const PerformanceConvMlirIgemmXdlops& config) const;
PerformanceConvMlirIgemmXdlops Search(const ConvolutionContext&,
const AnyInvokeParams& invoke_ctx) const;
ConvSolution GetSolution(const ConvolutionContext& ctx,
const PerformanceConvMlirIgemmXdlops& config,
bool disableConfigOverrideFromEnv = false) const;
};

struct PerformanceImplicitGemmV4R4GenXdlopsFwdFp32
131 changes: 128 additions & 3 deletions src/solver/conv_mlir_igemm_fwd_xdlops.cpp
@@ -27,6 +27,7 @@
#include <miopen/conv/invokers/mlir_impl_gemm.hpp>
#include <miopen/config.h>
#include <miopen/env.hpp>
#include <miopen/generic_search.hpp>
#include <miopen/mlir_build.hpp>
#include <miopen/solver.hpp>
#include <miopen/solver/implicitgemm_util.hpp>
@@ -69,16 +70,139 @@ bool ConvMlirIgemmFwdXdlops::IsApplicable(const ConvolutionContext& ctx) const
#endif
}

ConvSolution ConvMlirIgemmFwdXdlops::GetSolution(const ConvolutionContext& ctx) const
PerformanceConvMlirIgemmXdlops::PerformanceConvMlirIgemmXdlops(int GemmMPerBlock_,
int GemmNPerBlock_,
int GemmKPerBlock_,
int GemmMPerWave_,
int GemmNPerWave_,
int GemmKPACKSize_,
bool GemmAThreadCopyMoreGemmK_,
bool GemmBThreadCopyMoreGemmKPack_,
bool use_spare_set_)
: GemmMPerBlock(GemmMPerBlock_),
GemmNPerBlock(GemmNPerBlock_),
GemmKPerBlock(GemmKPerBlock_),
GemmMPerWave(GemmMPerWave_),
GemmNPerWave(GemmNPerWave_),
GemmKPACKSize(GemmKPACKSize_),
GemmAThreadCopyMoreGemmK(GemmAThreadCopyMoreGemmK_),
GemmBThreadCopyMoreGemmKPack(GemmBThreadCopyMoreGemmKPack_),
use_spare_set(use_spare_set_)
{
}

PerformanceConvMlirIgemmXdlops::PerformanceConvMlirIgemmXdlops(bool spare)
: PerformanceConvMlirIgemmXdlops::PerformanceConvMlirIgemmXdlops(
4, 16, 1, 4, 16, 4, false, false, spare)
{
}

PerformanceConvMlirIgemmXdlops::PerformanceConvMlirIgemmXdlops()
: PerformanceConvMlirIgemmXdlops::PerformanceConvMlirIgemmXdlops(
-1, -1, -1, -1, -1, -1, false, false)
{
}

bool PerformanceConvMlirIgemmXdlops::operator==(const PerformanceConvMlirIgemmXdlops& other) const
{
// clang-format off
return GemmMPerBlock == other.GemmMPerBlock
&& GemmNPerBlock == other.GemmNPerBlock
&& GemmKPerBlock == other.GemmKPerBlock
&& GemmMPerWave == other.GemmMPerWave
&& GemmNPerWave == other.GemmNPerWave
&& GemmKPACKSize == other.GemmKPACKSize
&& GemmAThreadCopyMoreGemmK == other.GemmAThreadCopyMoreGemmK
&& GemmBThreadCopyMoreGemmKPack == other.GemmBThreadCopyMoreGemmKPack
&& use_spare_set == other.use_spare_set;
// clang-format on
}

bool PerformanceConvMlirIgemmXdlops::IsValid(const ConvolutionContext& ctx) const
{
#if MIOPEN_USE_MLIR
bool isValid = MiirIsConfigApplicable(
mlir::ConstructBuildOptions(ctx, GetOperation(), GetKernelName(), ToString(), true));
return isValid;
#else
std::ignore = ctx;
return false;
#endif
}

bool PerformanceConvMlirIgemmXdlops::SetNextValue(const ConvolutionContext& /*config*/)
{
GemmBThreadCopyMoreGemmKPack = true;
GemmAThreadCopyMoreGemmK = true;
do
{
if(!NextTwoPower<4, 256>(GemmMPerBlock))
break;
if(!NextTwoPower<16, 256>(GemmNPerBlock))
break;
if(!NextTwoPower<1, 8>(GemmKPerBlock))
break;
if(!NextTwoPower<4, 128>(GemmMPerWave))
break;
if(!NextTwoPower<16, 128>(GemmNPerWave))
break;
if(!NextTwoPower<4, 8>(GemmKPACKSize))
break;

return false;
} while(false);

return true;
}
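
SetNextValue() above advances the tuning parameters like an odometer: NextTwoPower<L, H> doubles a field while it stays in [L..H] and reports a wrap-around from H back to L, and only when every field wraps is the search space exhausted (SetNextValue then returns false so GenericSearch stops enumerating). A minimal standalone sketch of that pattern, not part of this commit, using only two of the fields and a stand-in NextTwoPowerSketch for MIOpen's NextTwoPower from implicitgemm_util.hpp:

#include <iostream>

// Stand-in for MIOpen's NextTwoPower<L, H>: doubles v while it stays
// within [lo, hi]; on reaching hi it wraps back to lo and returns true
// (a "carry" telling the caller to advance the next field).
template <int lo, int hi>
bool NextTwoPowerSketch(int& v)
{
    if(v < hi)
    {
        v *= 2;
        return false; // advanced without wrapping
    }
    v = lo;
    return true; // wrapped around
}

int main()
{
    int gemmKPerBlock = 1; // enumerates 1, 2, 4, 8
    int gemmKPACKSize = 4; // enumerates 4, 8 (the narrowed KPACK range)
    bool done = false;
    while(!done)
    {
        std::cout << "GemmKPerBlock=" << gemmKPerBlock
                  << " GemmKPACKSize=" << gemmKPACKSize << '\n';
        // Advance GemmKPACKSize; on its wrap advance GemmKPerBlock;
        // when both wrap, the whole space has been enumerated.
        done = NextTwoPowerSketch<4, 8>(gemmKPACKSize) &&
               NextTwoPowerSketch<1, 8>(gemmKPerBlock);
    }
    return 0;
}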

std::string PerformanceConvMlirIgemmXdlops::ToString() const
{
std::ostringstream ss;
Serialize(ss);
return ss.str();
}

PerformanceConvMlirIgemmXdlops
ConvMlirIgemmFwdXdlops::GetPerformanceConfig(const ConvolutionContext& ctx) const
{
std::ignore = ctx;
return {};
}

bool ConvMlirIgemmFwdXdlops::IsValidPerformanceConfig(
const ConvolutionContext& ctx, const PerformanceConvMlirIgemmXdlops& config) const
{
MIOPEN_LOG_I("");
return config.IsValid(ctx);
}

PerformanceConvMlirIgemmXdlops
ConvMlirIgemmFwdXdlops::Search(const ConvolutionContext& ctx,
const AnyInvokeParams& invoke_ctx) const
{
return GenericSearch(*this, ctx, invoke_ctx);
}

ConvSolution ConvMlirIgemmFwdXdlops::GetSolution(const ConvolutionContext& ctx,
const PerformanceConvMlirIgemmXdlops& config,
bool) const
{
#if MIOPEN_USE_MLIR
ConvSolution result;
KernelInfo construction_parameters;

construction_parameters.kernel_name = GetKernelName();
construction_parameters.kernel_file = construction_parameters.kernel_name + ".mlir";
construction_parameters.comp_options =
mlir::ConstructBuildOptions(ctx, GetOperation(), GetKernelName(), true);

if(config == PerformanceConvMlirIgemmXdlops())
// In this case, do not pass in the invalid perf config; instead, let the Miir library do its
// heuristic initialization
construction_parameters.comp_options =
mlir::ConstructBuildOptions(ctx, GetOperation(), GetKernelName(), true);
else
// In this case, let the Miir library use the valid perf config
construction_parameters.comp_options = mlir::ConstructBuildOptions(
ctx, GetOperation(), GetKernelName(), config.ToString(), true);

size_t local_size = 0;
size_t global_size = 0;
@@ -97,6 +221,7 @@ ConvSolution ConvMlirIgemmFwdXdlops::GetSolution(const ConvolutionContext& ctx)
return result;
#else
std::ignore = ctx;
std::ignore = config;
return {};
#endif
}
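
A hypothetical usage sketch, not part of this commit, assuming MIOpen's headers are on the include path and that the struct lives in the usual miopen::solver namespace: GetSolution() falls back to MIIR's heuristic initialization exactly when the perf config equals a default-constructed one (all tuning fields -1); otherwise it serializes the config via ToString() into the build options. The tuning values below are made up for illustration.

#include <iostream>
#include <miopen/solver.hpp>

int main()
{
    using miopen::solver::PerformanceConvMlirIgemmXdlops;

    PerformanceConvMlirIgemmXdlops heuristic; // default ctor: every tuning field is -1
    PerformanceConvMlirIgemmXdlops tuned(64, 64, 4, 32, 32, 4, true, true); // hypothetical values

    // Mirrors the branch in GetSolution(): a default config means
    // "let the Miir library pick", anything else is passed on as a string.
    if(tuned == heuristic)
        std::cout << "heuristic initialization\n";
    else
        std::cout << "perf config: " << tuned.ToString() << '\n';
    return 0;
}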
