diff --git a/src/solver/conv_hip_implicit_gemm_bwd_v4r1_xdlops.cpp b/src/solver/conv_hip_implicit_gemm_bwd_v4r1_xdlops.cpp index 07d5a8df10..2a4d7b67b1 100644 --- a/src/solver/conv_hip_implicit_gemm_bwd_v4r1_xdlops.cpp +++ b/src/solver/conv_hip_implicit_gemm_bwd_v4r1_xdlops.cpp @@ -35,6 +35,7 @@ #define WORKAROUND_ISSUE_1206 1 MIOPEN_DECLARE_ENV_VAR(MIOPEN_DEBUG_CONV_IMPLICIT_GEMM_HIP_BWD_V4R1_XDLOPS) +MIOPEN_DECLARE_ENV_VAR(MIOPEN_DEBUG_CONV_IMPLICIT_GEMM_HIP_BWD_V4R1_XDLOPS_PERF_VALS) namespace miopen { namespace solver { @@ -881,7 +882,7 @@ ConvHipImplicitGemmBwdDataV4R1Xdlops::Search(const ConvolutionContext& ctx, ConvSolution ConvHipImplicitGemmBwdDataV4R1Xdlops::GetSolution( const ConvolutionContext& ctx, const PerformanceImplicitGemmBwdDataV4R1Xdlops& config, - bool) const + const bool disableConfigOverrideFromEnv) const { ConvSolution result; @@ -891,6 +892,33 @@ ConvSolution ConvHipImplicitGemmBwdDataV4R1Xdlops::GetSolution( assert(false); } + const PerformanceImplicitGemmBwdDataV4R1Xdlops* pcfg = &config; + PerformanceImplicitGemmBwdDataV4R1Xdlops fromEnv; + if(!disableConfigOverrideFromEnv) + { + std::string s; + const auto p_asciz = + miopen::GetStringEnv(MIOPEN_DEBUG_CONV_IMPLICIT_GEMM_HIP_BWD_V4R1_XDLOPS_PERF_VALS{}); + if(p_asciz != nullptr) + { + s = std::string(p_asciz); + if(!s.empty()) // else nothing to parse. + { + if(!fromEnv.Deserialize(s) || !fromEnv.IsReallyValid(ctx)) + { + MIOPEN_LOG_E("MIOPEN_DEBUG_CONV_IMPLICIT_GEMM_HIP_BWD_V4R1_XDLOPS_PERF_VALS: " + "Bad format or invalid for the problem config: " + << s); + } + else + { + MIOPEN_LOG_I("Overridden from env: " << fromEnv.ToString()); + pcfg = &fromEnv; + } + } + } + } + // a series of kernels for(std::size_t gemm_id = 0; gemm_id < CalculateNumberOfGemm(ctx); ++gemm_id) { @@ -908,16 +936,16 @@ ConvSolution ConvHipImplicitGemmBwdDataV4R1Xdlops::GetSolution( { int grid_size = 0; - const std::size_t GemmMPerBlock = config.GemmMPerBlock; - const std::size_t GemmNPerBlock = config.GemmNPerBlock; - const std::size_t GemmKPerBlock = config.GemmKPerBlock; - const std::size_t GemmMPerWave = config.GemmMPerWave; - const std::size_t GemmNPerWave = config.GemmNPerWave; + const std::size_t GemmMPerBlock = pcfg->GemmMPerBlock; + const std::size_t GemmNPerBlock = pcfg->GemmNPerBlock; + const std::size_t GemmKPerBlock = pcfg->GemmKPerBlock; + const std::size_t GemmMPerWave = pcfg->GemmMPerWave; + const std::size_t GemmNPerWave = pcfg->GemmNPerWave; const std::size_t block_size = GemmNPerBlock * GemmMPerBlock / (GemmMPerWave * GemmNPerWave) * wave_size; - std::tie(grid_size, std::ignore) = config.CalculateGridSize(ctx); + std::tie(grid_size, std::ignore) = pcfg->CalculateGridSize(ctx); construction_parameters.l_wk.push_back(block_size); construction_parameters.l_wk.push_back(1); @@ -953,7 +981,7 @@ ConvSolution ConvHipImplicitGemmBwdDataV4R1Xdlops::GetSolution( GemmABlockCopyClusterLengths_GemmKPack, GemmABlockCopySrcDataPerRead_GemmM, GemmABlockCopyDstDataPerWrite_GemmKPack, - std::ignore) = config.CalculateGemmABlockCopyPerformanceParameters(ctx); + std::ignore) = pcfg->CalculateGemmABlockCopyPerformanceParameters(ctx); int GemmBBlockCopyClusterLengths_GemmKPack = 1; int GemmBBlockCopyDstDataPerWrite_GemmKPack = 1; @@ -963,7 +991,7 @@ ConvSolution ConvHipImplicitGemmBwdDataV4R1Xdlops::GetSolution( GemmBBlockCopyClusterLengths_GemmKPack, GemmBBlockCopySrcDataPerRead_GemmN, GemmBBlockCopyDstDataPerWrite_GemmKPack, - std::ignore) = config.CalculateGemmBBlockCopyPerformanceParameters(ctx); + std::ignore) = pcfg->CalculateGemmBBlockCopyPerformanceParameters(ctx); const auto GemmABlockCopyDstDataPerWrite_GemmKPACK = GemmABlockCopyDstDataPerWrite_GemmKPack; @@ -1015,7 +1043,7 @@ ConvSolution ConvHipImplicitGemmBwdDataV4R1Xdlops::GetSolution( ctx.general_compile_options; construction_parameters.comp_options += - std::string(" -DCK_PARAM_KPACK_LENGTH=") + std::to_string(config.GemmKPACKSize) + + std::string(" -DCK_PARAM_KPACK_LENGTH=") + std::to_string(pcfg->GemmKPACKSize) + std::string(" -DCK_PARAM_TUNABLE_GEMM_A_BLOCK_COPY_DST_DATA_PER_WRITE_GEMM_KPACK=") + std::to_string(GemmABlockCopyDstDataPerWrite_GemmKPACK) + std::string(" -DCK_PARAM_TUNABLE_GEMM_B_BLOCK_COPY_DST_DATA_PER_WRITE_GEMM_KPACK=") + std::to_string(GemmBBlockCopyDstDataPerWrite_GemmKPACK);