Skip to content

Commit

Permalink
Explicitly disable asm solvers if XNACK is enabled (#1142)
Browse files Browse the repository at this point in the history
* disable all asm kernels if xnack enabled

* remove test_find_db, test_main, test_immed_conv2d from skip tests for gfx90a

* fix clang-tidy

Co-authored-by: Jun Liu <[email protected]>
  • Loading branch information
Slimakanzer and junliume committed Sep 30, 2021
1 parent f00f73b commit 5576c06
Show file tree
Hide file tree
Showing 25 changed files with 103 additions and 25 deletions.
4 changes: 3 additions & 1 deletion src/ocl/batchnormocl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -990,7 +990,9 @@ void BatchNormBackward(Handle& handle,
ctx.use_asm_kernels && ctx.rmv.IsV2orV3() &&
(StartsWith(handle.GetDeviceName(), "gfx8") ||
(StartsWith(handle.GetDeviceName(), "gfx9") &&
(handle.GetDeviceName() != "gfx90a"))))
(handle.GetDeviceName() != "gfx90a"))) &&
(!handle.GetTargetProperties().Xnack() ||
!*handle.GetTargetProperties().Xnack()))
{
kernel_name = "miopenGcnAsmBNBwdTrainSpatial";
program_name = "gcnAsmBNBwdTrainSpatial.s";
Expand Down
4 changes: 4 additions & 0 deletions src/solver/conv_MP_bidirectional_winograd.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,10 @@ inline bool IsApplicableTransform(const ConvolutionContext& params)
if(!(params.IsFp32() || params.IsFp16()))
return false;

const auto target = params.GetStream().GetTargetProperties();
if(target.Xnack() && *target.Xnack())
return false;

const std::string name = params.GetStream().GetDeviceName();
if(!StartsWith(name, "gfx9") || name == "gfx90a")
return false;
Expand Down
4 changes: 4 additions & 0 deletions src/solver/conv_asm_1x1u.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -393,6 +393,10 @@ bool ConvAsm1x1U::IsApplicable(const ConvolutionContext& params) const
if(!(params.IsFp32() || params.IsFp16()))
return false;

const auto target = params.GetStream().GetTargetProperties();
if(target.Xnack() && *target.Xnack())
return false;

const std::string name = params.GetStream().GetDeviceName();
if(name.find("gfx9") == std::string::npos)
{
Expand Down
5 changes: 5 additions & 0 deletions src/solver/conv_asm_1x1u_stride2.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -488,6 +488,11 @@ bool ConvAsm1x1UV2::IsApplicable(const ConvolutionContext& params) const
return false;
if(!params.IsFp32())
return false;

const auto target = params.GetStream().GetTargetProperties();
if(target.Xnack() && *target.Xnack())
return false;

const std::string name = params.GetStream().GetDeviceName();
if(name.find("gfx8") == std::string::npos && name.find("gfx9") == std::string::npos)
{
Expand Down
5 changes: 5 additions & 0 deletions src/solver/conv_asm_3x3u.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,11 @@ bool ConvAsm3x3U::IsApplicable(const ConvolutionContext& params) const
return false;
if(!params.rmv.IsV2orV3())
return false;

const auto target = params.GetStream().GetTargetProperties();
if(target.Xnack() && *target.Xnack())
return false;

const std::string name = params.GetStream().GetDeviceName();
if(!(StartsWith(name, "gfx8") || StartsWith(name, "gfx9")) || name == "gfx90a")
return false;
Expand Down
4 changes: 4 additions & 0 deletions src/solver/conv_asm_5x10u2v2b1.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,10 @@ bool ConvAsm5x10u2v2b1::IsApplicable(const ConvolutionContext& params) const
if(!params.rmv.IsV2orV3())
return false;

const auto target = params.GetStream().GetTargetProperties();
if(target.Xnack() && *target.Xnack())
return false;

const std::string name = params.GetStream().GetDeviceName();
const bool device_is_gfx8_9_no_xnack =
(name == "gfx800" || name == "gfx802" || name == "gfx803" || name == "gfx804" ||
Expand Down
4 changes: 4 additions & 0 deletions src/solver/conv_asm_5x10u2v2f1.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,10 @@ bool ConvAsm5x10u2v2f1::IsApplicable(const ConvolutionContext& params) const
if(!params.rmv.IsV2orV3())
return false;

const auto target = params.GetStream().GetTargetProperties();
if(target.Xnack() && *target.Xnack())
return false;

const std::string name = params.GetStream().GetDeviceName();
const bool device_is_gfx8_9_no_xnack =
(name == "gfx800" || name == "gfx802" || name == "gfx803" || name == "gfx804" ||
Expand Down
4 changes: 4 additions & 0 deletions src/solver/conv_asm_7x7c3h224w224k64u2v2p3q3f1.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,10 @@ bool ConvAsm7x7c3h224w224k64u2v2p3q3f1::IsApplicable(const ConvolutionContext& p
if(!params.rmv.IsV2orV3())
return false;

const auto target = params.GetStream().GetTargetProperties();
if(target.Xnack() && *target.Xnack())
return false;

const std::string name = params.GetStream().GetDeviceName();
if(!(name == "gfx800" || name == "gfx802" || name == "gfx803" || name == "gfx804" ||
name == "gfx900" || name == "gfx904" || name == "gfx906" || name == "gfx908"))
Expand Down
4 changes: 4 additions & 0 deletions src/solver/conv_asm_dir_BwdWrW1x1.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -479,6 +479,10 @@ bool ConvAsmBwdWrW1x1::IsApplicable(const ConvolutionContext& params) const
if(!params.rmv.IsV2orV3())
return false;

const auto target = params.GetStream().GetTargetProperties();
if(target.Xnack() && *target.Xnack())
return false;

const std::string name = params.GetStream().GetDeviceName();
if(name.find("gfx8") == std::string::npos && name.find("gfx9") == std::string::npos)
{
Expand Down
5 changes: 5 additions & 0 deletions src/solver/conv_asm_dir_BwdWrW3x3.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -351,6 +351,11 @@ bool ConvAsmBwdWrW3x3::IsApplicable(const ConvolutionContext& params) const
return false;
if(!params.rmv.IsV2orV3())
return false;

const auto target = params.GetStream().GetTargetProperties();
if(target.Xnack() && *target.Xnack())
return false;

const std::string name = params.GetStream().GetDeviceName();
if(!(StartsWith(name, "gfx8") || StartsWith(name, "gfx9")) || name == "gfx90a")
return false;
Expand Down
4 changes: 4 additions & 0 deletions src/solver/conv_asm_implicit_gemm_bwd_v4r1_dynamic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,10 @@ bool ConvAsmImplicitGemmV4R1DynamicBwd::IsApplicable(const ConvolutionContext& c
return false;
}

const auto target = ctx.GetStream().GetTargetProperties();
if(target.Xnack() && *target.Xnack())
return false;

std::string kernel_name;
int block_size;
int grid_size;
Expand Down
4 changes: 4 additions & 0 deletions src/solver/conv_asm_implicit_gemm_gtc_bwd.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1006,6 +1006,10 @@ bool ConvAsmImplicitGemmGTCDynamicBwdXdlops::IsApplicable(const ConvolutionConte
return false;
}

const auto target = ctx.GetStream().GetTargetProperties();
if(target.Xnack() && *target.Xnack())
return false;

bool isValid;
std::tie(isValid, std::ignore, std::ignore, std::ignore, std::ignore) =
FindImplicitGemmGtcDynamicBwdKernel(ctx);
Expand Down
17 changes: 3 additions & 14 deletions src/solver/conv_asm_implicit_gemm_gtc_bwd_nhwc.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -632,21 +632,10 @@ bool ConvAsmImplicitGemmGTCDynamicBwdXdlopsNHWC::IsApplicable(const ConvolutionC
if(!ctx.IsLayoutNHWC())
return false;

const auto k = ctx.n_inputs;
const auto c = ctx.n_outputs;
const auto group = ctx.group_counts;

if(ctx.IsFp32() && (k / group) % 4 != 0)
return false; // gemm_k limitation for fp32

if(ctx.IsFp16() && (k / group) % 16 != 0)
return false; // gemm_k limitation for fp16
const auto target = ctx.GetStream().GetTargetProperties();
if(target.Xnack() && *target.Xnack())
return false; // NOLINT (readability-simplify-boolean-expr)

if(ctx.IsFp16())
{
if((c / group) % 2 != 0)
return false; // vector store limitation
}
return true;
}
ConvSolution ConvAsmImplicitGemmGTCDynamicBwdXdlopsNHWC::GetSolution(
Expand Down
4 changes: 4 additions & 0 deletions src/solver/conv_asm_implicit_gemm_gtc_fwd.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1529,6 +1529,10 @@ bool ConvAsmImplicitGemmGTCDynamicFwdXdlops::IsApplicable(const ConvolutionConte
return false;
}

const auto target = ctx.GetStream().GetTargetProperties();
if(target.Xnack() && *target.Xnack())
return false;

bool isValid;
std::tie(isValid, std::ignore, std::ignore, std::ignore, std::ignore) =
FindImplicitGemmGtcDynamicFwdKernel(ctx);
Expand Down
5 changes: 5 additions & 0 deletions src/solver/conv_asm_implicit_gemm_gtc_fwd_nhwc.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -583,6 +583,11 @@ bool ConvAsmImplicitGemmGTCDynamicFwdXdlopsNHWC::IsApplicable(const ConvolutionC

if(!ctx.IsLayoutNHWC())
return false;

const auto target = ctx.GetStream().GetTargetProperties();
if(target.Xnack() && *target.Xnack())
return false; // NOLINT (readability-simplify-boolean-expr)

return true;
}
ConvSolution ConvAsmImplicitGemmGTCDynamicFwdXdlopsNHWC::GetSolution(
Expand Down
5 changes: 5 additions & 0 deletions src/solver/conv_asm_implicit_gemm_gtc_wrw_nhwc.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -616,6 +616,11 @@ bool ConvAsmImplicitGemmGTCDynamicWrwXdlopsNHWC::IsApplicable(const ConvolutionC

if(!ctx.IsLayoutNHWC())
return false;

const auto target = ctx.GetStream().GetTargetProperties();
if(target.Xnack() && *target.Xnack())
return false; // NOLINT (readability-simplify-boolean-expr)

return true;
}

Expand Down
8 changes: 8 additions & 0 deletions src/solver/conv_asm_implicit_gemm_v4r1_dynamic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -303,6 +303,10 @@ bool ConvAsmImplicitGemmV4R1DynamicFwd::IsApplicable(const ConvolutionContext& c
{
return false;
}

const auto target = ctx.GetStream().GetTargetProperties();
if(target.Xnack() && *target.Xnack())
return false;
auto tunables = GetImplicitGemmV4R1DynamicTunables();
return !std::none_of(
tunables.begin(), tunables.end(), [&](auto tunable) { return tunable.IsValid(ctx); });
Expand Down Expand Up @@ -342,6 +346,10 @@ bool ConvAsmImplicitGemmV4R1DynamicFwd_1x1::IsApplicable(const ConvolutionContex
{
return false;
}

const auto target = ctx.GetStream().GetTargetProperties();
if(target.Xnack() && *target.Xnack())
return false;
auto tunables = GetImplicitGemmV4R1DynamicTunables();
return !std::none_of(
tunables.begin(), tunables.end(), [&](auto tunable) { return tunable.IsValid(ctx); });
Expand Down
4 changes: 4 additions & 0 deletions src/solver/conv_asm_implicit_gemm_wrw_gtc_dynamic_xdlops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -836,6 +836,10 @@ bool ConvAsmImplicitGemmGTCDynamicWrwXdlops::IsApplicable(const ConvolutionConte
{
return false;
}

const auto target = ctx.GetStream().GetTargetProperties();
if(target.Xnack() && *target.Xnack())
return false;
bool is_valid;
std::tie(is_valid, std::ignore, std::ignore, std::ignore, std::ignore) =
FindImplicitGemmWrwGTCDynamicXdlopsKernel(ctx);
Expand Down
4 changes: 4 additions & 0 deletions src/solver/conv_asm_implicit_gemm_wrw_v4r1_dynamic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -329,6 +329,10 @@ bool ConvAsmImplicitGemmV4R1DynamicWrw::IsApplicable(const ConvolutionContext& c
{
return false;
}

const auto target = ctx.GetStream().GetTargetProperties();
if(target.Xnack() && *target.Xnack())
return false;
std::string kernel_name;
int block_size;
int grid_size;
Expand Down
4 changes: 4 additions & 0 deletions src/solver/conv_bin_wino3x3U.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,10 @@ bool ConvBinWinograd3x3U::IsApplicable(const ConvolutionContext& params) const
if(!(params.rmv.IsV2orV3() && params.use_asm_kernels))
return false;

const auto target = params.GetStream().GetTargetProperties();
if(target.Xnack() && *target.Xnack())
return false;

const auto name = params.GetStream().GetDeviceName();
if(!(name == "gfx803" || name == "gfx900" || name == "gfx906" || name == "gfx908"))
return false;
Expand Down
4 changes: 4 additions & 0 deletions src/solver/conv_bin_winoRxS.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -240,6 +240,10 @@ bool ConvBinWinogradRxS::IsApplicable(const ConvolutionContext& params) const
if(!params.rmv.IsV2orV3())
return false;

const auto target = params.GetStream().GetTargetProperties();
if(target.Xnack() && *target.Xnack())
return false;

const auto name = params.GetStream().GetDeviceName();
const bool fp16 = params.IsFp16();
if(fp16)
Expand Down
4 changes: 4 additions & 0 deletions src/solver/conv_multipass_wino3x3WrW.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -431,6 +431,10 @@ bool ConvWinograd3x3MultipassWrW<WinoDataH, WinoFilterH, WinoDataW, WinoFilterW>
return false;
}

const auto target = params.GetStream().GetTargetProperties();
if(target.Xnack() && *target.Xnack())
return false;

if(!(InTransform<WinoDataH, WinoFilterH, WinoDataW, WinoFilterW>::IsApplicable(params) &&
OutTransform<WinoDataH, WinoFilterH, WinoDataW, WinoFilterW>::IsApplicable(params) &&
FilterTransform<WinoDataH, WinoFilterH, WinoDataW, WinoFilterW>::IsApplicable(params)))
Expand Down
10 changes: 4 additions & 6 deletions src/solver/conv_winoRxS_f2x3.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,6 @@ MIOPEN_DECLARE_ENV_VAR(MIOPEN_DEBUG_AMD_WINOGRAD_RXS_F2X3)
MIOPEN_DECLARE_ENV_VAR(MIOPEN_DEBUG_AMD_WINOGRAD_RXS_F2X3_PERF_VALS)
MIOPEN_DECLARE_ENV_VAR(MIOPEN_DEBUG_AMD_WINOGRAD_RXS_F2X3_G1)

#define WORKAROUND_ISSUE_1093 1

#define WINODATA 2
#define WINOFILTER 3
#define MAX_CU_LIMIT 512
Expand Down Expand Up @@ -453,12 +451,12 @@ static bool IsApplicableBase(const ConvolutionContext& params)
if(!params.rmv.IsV3())
return false;

const auto target = params.GetStream().GetTargetProperties();
if(target.Xnack() && *target.Xnack())
return false;

const auto name = params.GetStream().GetDeviceName();
#if WORKAROUND_ISSUE_1093
if(!(StartsWith(name, "gfx9") || StartsWith(name, "gfx10")) || name == "gfx90a")
#else
if(!(StartsWith(name, "gfx9") || StartsWith(name, "gfx10")))
#endif
return false;
if(params.IsFp16() &&
!(StartsWith(name, "gfx906") || StartsWith(name, "gfx908") || StartsWith(name, "gfx1011") ||
Expand Down
4 changes: 4 additions & 0 deletions src/solver/conv_winoRxS_f3x2.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -314,6 +314,10 @@ bool ConvBinWinogradRxSf3x2::IsApplicable(const ConvolutionContext& params) cons
if(!params.IsLayoutDefault())
return false;

const auto target = params.GetStream().GetTargetProperties();
if(target.Xnack() && *target.Xnack())
return false;

const auto max_cu = params.GetStream().GetMaxHardwareComputeUnits();
if(max_cu > MAX_CU_LIMIT)
return false;
Expand Down
4 changes: 0 additions & 4 deletions test/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,6 @@ option( WORKAROUND_ISSUE_898 "" ON)
option( WORKAROUND_ISSUE_936 "" ON)
option( WORKAROUND_ISSUE_1053 "" ON)
option( WORKAROUND_ISSUE_1064 "" ON)
option( WORKAROUND_ISSUE_1093 "" ON)
option( WORKAROUND_ISSUE_1095 "" ON)

# Run the test suite to a depth limit
Expand Down Expand Up @@ -198,9 +197,6 @@ if(MIOPEN_TEST_GFX908)
endif()

if (MIOPEN_TEST_GFX90a)
if(WORKAROUND_ISSUE_1093)
list(APPEND SKIP_TESTS test_find_db test_main test_immed_conv2d)
endif()
if(WORKAROUND_ISSUE_1095)
list(APPEND SKIP_TESTS test_dropout)
endif()
Expand Down

0 comments on commit 5576c06

Please sign in to comment.