From c636bf257b39f288090735a9a4abfdfec08fe107 Mon Sep 17 00:00:00 2001 From: Zhuoran Yin Date: Tue, 7 Sep 2021 21:30:45 -0500 Subject: [PATCH] [MLIR] Implement tuning - step 2: fwd, xdlops (#1139) * Implement tuning support for ConvMlirIgemmFwdXdlops * Narrowing KPACK size tuning range to 4/8 --- src/include/miopen/solver.hpp | 53 +++++++++ src/solver/conv_mlir_igemm_fwd_xdlops.cpp | 131 +++++++++++++++++++++- 2 files changed, 181 insertions(+), 3 deletions(-) diff --git a/src/include/miopen/solver.hpp b/src/include/miopen/solver.hpp index 8e00805db9..c8114fe3b3 100644 --- a/src/include/miopen/solver.hpp +++ b/src/include/miopen/solver.hpp @@ -849,10 +849,63 @@ struct ConvMlirIgemmFwd : SolverBase bool disableConfigOverrideFromEnv = false) const; }; +struct PerformanceConvMlirIgemmXdlops : Serializable +{ + int GemmMPerBlock; // 2^n[32..128] + int GemmNPerBlock; // 2^n[8..16] + int GemmKPerBlock; // 2^n[4..16] + + int GemmMPerWave; + int GemmNPerWave; + + int GemmKPACKSize; // 2^[1..4] + + // GemmAThreadCopyMoreGemmK is currently a fix value, is untunable + bool GemmAThreadCopyMoreGemmK; + bool GemmBThreadCopyMoreGemmKPack; + + bool use_spare_set; + PerformanceConvMlirIgemmXdlops(int, int, int, int, int, int, bool, bool, bool); + + PerformanceConvMlirIgemmXdlops(); + PerformanceConvMlirIgemmXdlops(bool spare); + PerformanceConvMlirIgemmXdlops(int a, int b, int c, int d, int e, int f, bool g, bool h) + : PerformanceConvMlirIgemmXdlops(a, b, c, d, e, f, g, h, false) + { + } + + bool operator==(const PerformanceConvMlirIgemmXdlops& other) const; + + template + static void Visit(Self&& self, F f) + { + f(self.GemmNPerBlock, "GemmNPerBlock"); + f(self.GemmMPerBlock, "GemmMPerBlock"); + f(self.GemmKPerBlock, "GemmKPerBlock"); + f(self.GemmMPerWave, "GemmMPerWave"); + f(self.GemmNPerWave, "GemmNPerWave"); + f(self.GemmKPACKSize, "GemmKPACKSize"); + f(self.GemmAThreadCopyMoreGemmK, "GemmAThreadCopyMoreGemmK"); + f(self.GemmBThreadCopyMoreGemmKPack, "GemmBThreadCopyMoreGemmKPack"); + } + + bool IsValid(const ConvolutionContext& ctx) const; + bool SetNextValue(const ConvolutionContext& config); + std::string ToString() const; +}; + struct ConvMlirIgemmFwdXdlops : SolverBase { bool IsApplicable(const ConvolutionContext& ctx) const; ConvSolution GetSolution(const ConvolutionContext& ctx) const; + PerformanceConvMlirIgemmXdlops GetPerformanceConfig(const ConvolutionContext& ctx) const; + bool IsValidPerformanceConfig(const ConvolutionContext& ctx, + const PerformanceConvMlirIgemmXdlops& config) const; + PerformanceConvMlirIgemmXdlops Search(const ConvolutionContext&, + const AnyInvokeParams& invoke_ctx) const; + ConvSolution GetSolution(const ConvolutionContext& ctx, + const PerformanceConvMlirIgemmXdlops& config, + bool disableConfigOverrideFromEnv = false) const; }; struct PerformanceImplicitGemmV4R4GenXdlopsFwdFp32 diff --git a/src/solver/conv_mlir_igemm_fwd_xdlops.cpp b/src/solver/conv_mlir_igemm_fwd_xdlops.cpp index 0ff503834b..e00eca8047 100644 --- a/src/solver/conv_mlir_igemm_fwd_xdlops.cpp +++ b/src/solver/conv_mlir_igemm_fwd_xdlops.cpp @@ -27,6 +27,7 @@ #include #include #include +#include #include #include #include @@ -69,7 +70,122 @@ bool ConvMlirIgemmFwdXdlops::IsApplicable(const ConvolutionContext& ctx) const #endif } -ConvSolution ConvMlirIgemmFwdXdlops::GetSolution(const ConvolutionContext& ctx) const +PerformanceConvMlirIgemmXdlops::PerformanceConvMlirIgemmXdlops(int GemmMPerBlock_, + int GemmNPerBlock_, + int GemmKPerBlock_, + int GemmMPerWave_, + int GemmNPerWave_, + int GemmKPACKSize_, + bool GemmAThreadCopyMoreGemmK_, + bool GemmBThreadCopyMoreGemmKPack_, + bool use_spare_set_) + : GemmMPerBlock(GemmMPerBlock_), + GemmNPerBlock(GemmNPerBlock_), + GemmKPerBlock(GemmKPerBlock_), + GemmMPerWave(GemmMPerWave_), + GemmNPerWave(GemmNPerWave_), + GemmKPACKSize(GemmKPACKSize_), + GemmAThreadCopyMoreGemmK(GemmAThreadCopyMoreGemmK_), + GemmBThreadCopyMoreGemmKPack(GemmBThreadCopyMoreGemmKPack_), + use_spare_set(use_spare_set_) +{ +} + +PerformanceConvMlirIgemmXdlops::PerformanceConvMlirIgemmXdlops(bool spare) + : PerformanceConvMlirIgemmXdlops::PerformanceConvMlirIgemmXdlops( + 4, 16, 1, 4, 16, 4, false, false, spare) +{ +} + +PerformanceConvMlirIgemmXdlops::PerformanceConvMlirIgemmXdlops() + : PerformanceConvMlirIgemmXdlops::PerformanceConvMlirIgemmXdlops( + -1, -1, -1, -1, -1, -1, false, false) +{ +} + +bool PerformanceConvMlirIgemmXdlops::operator==(const PerformanceConvMlirIgemmXdlops& other) const +{ + // clang-format off + return GemmMPerBlock == other.GemmMPerBlock + && GemmNPerBlock == other.GemmNPerBlock + && GemmKPerBlock == other.GemmKPerBlock + && GemmMPerWave == other.GemmMPerWave + && GemmNPerWave == other.GemmNPerWave + && GemmKPACKSize == other.GemmKPACKSize + && GemmAThreadCopyMoreGemmK == other.GemmAThreadCopyMoreGemmK + && GemmBThreadCopyMoreGemmKPack == other.GemmBThreadCopyMoreGemmKPack + && use_spare_set == other.use_spare_set; + // clang-format on +} + +bool PerformanceConvMlirIgemmXdlops::IsValid(const ConvolutionContext& ctx) const +{ +#if MIOPEN_USE_MLIR + bool isValid = MiirIsConfigApplicable( + mlir::ConstructBuildOptions(ctx, GetOperation(), GetKernelName(), ToString(), true)); + return isValid; +#else + std::ignore = ctx; + return false; +#endif +} + +bool PerformanceConvMlirIgemmXdlops::SetNextValue(const ConvolutionContext& /*config*/) +{ + GemmBThreadCopyMoreGemmKPack = true; + GemmAThreadCopyMoreGemmK = true; + do + { + if(!NextTwoPower<4, 256>(GemmMPerBlock)) + break; + if(!NextTwoPower<16, 256>(GemmNPerBlock)) + break; + if(!NextTwoPower<1, 8>(GemmKPerBlock)) + break; + if(!NextTwoPower<4, 128>(GemmMPerWave)) + break; + if(!NextTwoPower<16, 128>(GemmNPerWave)) + break; + if(!NextTwoPower<4, 8>(GemmKPACKSize)) + break; + + return false; + } while(false); + + return true; +} + +std::string PerformanceConvMlirIgemmXdlops::ToString() const +{ + std::ostringstream ss; + Serialize(ss); + return ss.str(); +} + +PerformanceConvMlirIgemmXdlops +ConvMlirIgemmFwdXdlops::GetPerformanceConfig(const ConvolutionContext& ctx) const +{ + std::ignore = ctx; + return {}; +} + +bool ConvMlirIgemmFwdXdlops::IsValidPerformanceConfig( + const ConvolutionContext& ctx, const PerformanceConvMlirIgemmXdlops& config) const +{ + MIOPEN_LOG_I(""); + return config.IsValid(ctx); +} + +PerformanceConvMlirIgemmXdlops +ConvMlirIgemmFwdXdlops::Search(const ConvolutionContext& ctx, + const AnyInvokeParams& invoke_ctx) const +{ + return GenericSearch(*this, ctx, invoke_ctx); +} + +ConvSolution ConvMlirIgemmFwdXdlops::GetSolution(const ConvolutionContext& ctx, + const PerformanceConvMlirIgemmXdlops& config, + bool) const { #if MIOPEN_USE_MLIR ConvSolution result; @@ -77,8 +193,16 @@ ConvSolution ConvMlirIgemmFwdXdlops::GetSolution(const ConvolutionContext& ctx) construction_parameters.kernel_name = GetKernelName(); construction_parameters.kernel_file = construction_parameters.kernel_name + ".mlir"; - construction_parameters.comp_options = - mlir::ConstructBuildOptions(ctx, GetOperation(), GetKernelName(), true); + + if(config == PerformanceConvMlirIgemmXdlops()) + // At this case, do not pass in the invalid perf config and instead make Miir library to do + // heuristic initialization + construction_parameters.comp_options = + mlir::ConstructBuildOptions(ctx, GetOperation(), GetKernelName(), true); + else + // At this case, Make Miir library to use the valid perf config + construction_parameters.comp_options = mlir::ConstructBuildOptions( + ctx, GetOperation(), GetKernelName(), config.ToString(), true); size_t local_size = 0; size_t global_size = 0; @@ -97,6 +221,7 @@ ConvSolution ConvMlirIgemmFwdXdlops::GetSolution(const ConvolutionContext& ctx) return result; #else std::ignore = ctx; + std::ignore = config; return {}; #endif }