Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for MIOPEN_DEBUG_CONV_IMPLICIT_GEMM_HIP_BWD_V4R1_XDLOPS_PERF_VALS #1217

Merged
merged 3 commits into from
Oct 12, 2021
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 38 additions & 10 deletions src/solver/conv_hip_implicit_gemm_bwd_v4r1_xdlops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
#define WORKAROUND_ISSUE_1206 1

MIOPEN_DECLARE_ENV_VAR(MIOPEN_DEBUG_CONV_IMPLICIT_GEMM_HIP_BWD_V4R1_XDLOPS)
MIOPEN_DECLARE_ENV_VAR(MIOPEN_DEBUG_CONV_IMPLICIT_GEMM_HIP_BWD_V4R1_XDLOPS_PERF_VALS)

namespace miopen {
namespace solver {
Expand Down Expand Up @@ -881,7 +882,7 @@ ConvHipImplicitGemmBwdDataV4R1Xdlops::Search(const ConvolutionContext& ctx,
ConvSolution ConvHipImplicitGemmBwdDataV4R1Xdlops::GetSolution(
const ConvolutionContext& ctx,
const PerformanceImplicitGemmBwdDataV4R1Xdlops& config,
bool) const
const bool disableConfigOverrideFromEnv) const
{
ConvSolution result;

Expand All @@ -891,6 +892,33 @@ ConvSolution ConvHipImplicitGemmBwdDataV4R1Xdlops::GetSolution(
assert(false);
}

const PerformanceImplicitGemmBwdDataV4R1Xdlops* pcfg = &config;
PerformanceImplicitGemmBwdDataV4R1Xdlops fromEnv;
if(!disableConfigOverrideFromEnv)
{
std::string s;
const auto p_asciz =
miopen::GetStringEnv(MIOPEN_DEBUG_CONV_IMPLICIT_GEMM_HIP_BWD_V4R1_XDLOPS_PERF_VALS{});
if(p_asciz != nullptr)
{
s = std::string(p_asciz);
if(!s.empty()) // else nothing to parse.
{
if(!fromEnv.Deserialize(s) || !fromEnv.IsReallyValid(ctx))
{
MIOPEN_LOG_E("MIOPEN_DEBUG_CONV_IMPLICIT_GEMM_HIP_BWD_V4R1_XDLOPS_PERF_VALS: "
"Bad format or invalid for the problem config: "
<< s);
}
else
{
MIOPEN_LOG_I("Overridden from env: " << fromEnv.ToString());
pcfg = &fromEnv;
}
}
}
}

// a series of kernels
for(std::size_t gemm_id = 0; gemm_id < CalculateNumberOfGemm(ctx); ++gemm_id)
{
Expand All @@ -908,16 +936,16 @@ ConvSolution ConvHipImplicitGemmBwdDataV4R1Xdlops::GetSolution(
{
int grid_size = 0;

const std::size_t GemmMPerBlock = config.GemmMPerBlock;
const std::size_t GemmNPerBlock = config.GemmNPerBlock;
const std::size_t GemmKPerBlock = config.GemmKPerBlock;
const std::size_t GemmMPerWave = config.GemmMPerWave;
const std::size_t GemmNPerWave = config.GemmNPerWave;
const std::size_t GemmMPerBlock = pcfg->GemmMPerBlock;
const std::size_t GemmNPerBlock = pcfg->GemmNPerBlock;
const std::size_t GemmKPerBlock = pcfg->GemmKPerBlock;
const std::size_t GemmMPerWave = pcfg->GemmMPerWave;
const std::size_t GemmNPerWave = pcfg->GemmNPerWave;

const std::size_t block_size =
GemmNPerBlock * GemmMPerBlock / (GemmMPerWave * GemmNPerWave) * wave_size;

std::tie(grid_size, std::ignore) = config.CalculateGridSize(ctx);
std::tie(grid_size, std::ignore) = pcfg->CalculateGridSize(ctx);

construction_parameters.l_wk.push_back(block_size);
construction_parameters.l_wk.push_back(1);
Expand Down Expand Up @@ -953,7 +981,7 @@ ConvSolution ConvHipImplicitGemmBwdDataV4R1Xdlops::GetSolution(
GemmABlockCopyClusterLengths_GemmKPack,
GemmABlockCopySrcDataPerRead_GemmM,
GemmABlockCopyDstDataPerWrite_GemmKPack,
std::ignore) = config.CalculateGemmABlockCopyPerformanceParameters(ctx);
std::ignore) = pcfg->CalculateGemmABlockCopyPerformanceParameters(ctx);

int GemmBBlockCopyClusterLengths_GemmKPack = 1;
int GemmBBlockCopyDstDataPerWrite_GemmKPack = 1;
Expand All @@ -963,7 +991,7 @@ ConvSolution ConvHipImplicitGemmBwdDataV4R1Xdlops::GetSolution(
GemmBBlockCopyClusterLengths_GemmKPack,
GemmBBlockCopySrcDataPerRead_GemmN,
GemmBBlockCopyDstDataPerWrite_GemmKPack,
std::ignore) = config.CalculateGemmBBlockCopyPerformanceParameters(ctx);
std::ignore) = pcfg->CalculateGemmBBlockCopyPerformanceParameters(ctx);

const auto GemmABlockCopyDstDataPerWrite_GemmKPACK =
GemmABlockCopyDstDataPerWrite_GemmKPack;
Expand Down Expand Up @@ -1015,7 +1043,7 @@ ConvSolution ConvHipImplicitGemmBwdDataV4R1Xdlops::GetSolution(
ctx.general_compile_options;

construction_parameters.comp_options +=
std::string(" -DCK_PARAM_KPACK_LENGTH=") + std::to_string(config.GemmKPACKSize) +
std::string(" -DCK_PARAM_KPACK_LENGTH=") + std::to_string(pcfg->GemmKPACKSize) +
std::string(" -DCK_PARAM_TUNABLE_GEMM_A_BLOCK_COPY_DST_DATA_PER_WRITE_GEMM_KPACK=") + std::to_string(GemmABlockCopyDstDataPerWrite_GemmKPACK) +
std::string(" -DCK_PARAM_TUNABLE_GEMM_B_BLOCK_COPY_DST_DATA_PER_WRITE_GEMM_KPACK=") + std::to_string(GemmBBlockCopyDstDataPerWrite_GemmKPACK);

Expand Down