Skip to content

Commit

Permalink
Fix NCHW bwd asm igemm solver. Remove WORKAROUND_SWDEV_286774. (#952)
Browse files Browse the repository at this point in the history
  • Loading branch information
carlushuang authored May 27, 2021
1 parent 7c34cde commit fc8a05b
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 8 deletions.
21 changes: 14 additions & 7 deletions src/solver/conv_asm_implicit_gemm_gtc_bwd.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,6 @@
#include <miopen/solver/implicitgemm_util.hpp>
#include <miopen/conv/asm_implicit_gemm.hpp>

#define WORKAROUND_SWDEV_286774 1

MIOPEN_DECLARE_ENV_VAR(MIOPEN_DEBUG_CONV_IMPLICIT_GEMM_ASM_BWD_GTC_XDLOPS)

namespace miopen {
Expand Down Expand Up @@ -860,6 +858,13 @@ static std::tuple<bool, // is suitable kernel found
if(b % cfg.nxb != 0)
continue;

// if thread length have multiple value in K direction, then must make sure K is multiply of
// gemm_k_per_block
if((cfg.tensor_a_thread_lengths[0] != 1 || cfg.tensor_a_thread_lengths[1] != 1 ||
cfg.tensor_b_thread_lengths[0] != 1 || cfg.tensor_b_thread_lengths[1] != 1) &&
(k % cfg.gemm_k_per_block != 0))
continue;

bool gemm_k_valid = true;
for(int gemm_id = 0; gemm_id < num_of_gemm; gemm_id++)
{
Expand Down Expand Up @@ -931,6 +936,13 @@ static std::tuple<bool, // is suitable kernel found
if(b % cfg.nxb != 0)
continue;

// if thread length have multiple value in K direction, then must make sure K is
// multiply of gemm_k_per_block
if((cfg.tensor_a_thread_lengths[0] != 1 || cfg.tensor_a_thread_lengths[1] != 1 ||
cfg.tensor_b_thread_lengths[0] != 1 || cfg.tensor_b_thread_lengths[1] != 1) &&
(k % cfg.gemm_k_per_block != 0))
continue;

bool gemm_k_valid = true;
for(int gemm_id = 0; gemm_id < num_of_gemm; gemm_id++)
{
Expand Down Expand Up @@ -982,11 +994,6 @@ bool ConvAsmImplicitGemmGTCDynamicBwdXdlops::IsApplicable(const ConvolutionConte
if(!ctx.IsFp32() && !ctx.IsFp16())
return false;

#if WORKAROUND_SWDEV_286774
if(ctx.IsFp16() && !miopen::IsEnabled(MIOPEN_DEBUG_CONV_IMPLICIT_GEMM_ASM_BWD_GTC_XDLOPS{}))
return false;
#endif

if(!ctx.rmv.IsV3())
return false;

Expand Down
4 changes: 3 additions & 1 deletion test/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -770,7 +770,7 @@ COMMAND ${DYNAMIC_IMPLICITGEMM_BWD_ENVS} $<TARGET_FILE:test_conv2d> --verbose --
COMMAND ${DYNAMIC_IMPLICITGEMM_BWD_ENVS} $<TARGET_FILE:test_conv2d> --verbose --input 128 256 56 56 --weights 64 256 1 1 --pads_strides_dilations 0 0 1 1 1 1 --disable-forward --disable-backward-weights
)

add_custom_test(test_conv_igemm_dynamic_xdlops_bwd SKIP_UNLESS_ALL
add_custom_test(test_conv_igemm_dynamic_xdlops_bwd SKIP_UNLESS_ALL ALLOW_HALF
COMMAND ${DYNAMIC_IMPLICITGEMM_BWD_ENVS_XDLOPS} $<TARGET_FILE:test_conv2d> --verbose --input 64 64 28 28 --weights 16 64 1 1 --pads_strides_dilations 0 0 1 1 1 1 --disable-forward --disable-backward-weights
COMMAND ${DYNAMIC_IMPLICITGEMM_BWD_ENVS_XDLOPS} $<TARGET_FILE:test_conv2d> --verbose --input 16 128 36 36 --weights 32 128 1 1 --pads_strides_dilations 0 0 1 1 1 1 --disable-forward --disable-backward-weights
COMMAND ${DYNAMIC_IMPLICITGEMM_BWD_ENVS_XDLOPS} $<TARGET_FILE:test_conv2d> --verbose --input 64 64 56 56 --weights 256 64 1 1 --pads_strides_dilations 0 0 1 1 1 1 --disable-forward --disable-backward-weights
Expand All @@ -783,7 +783,9 @@ COMMAND ${DYNAMIC_IMPLICITGEMM_BWD_ENVS_XDLOPS} $<TARGET_FILE:test_conv2d> --ver
COMMAND ${DYNAMIC_IMPLICITGEMM_BWD_ENVS_XDLOPS} $<TARGET_FILE:test_conv2d> --verbose --input 4 512 128 128 --weights 12 512 1 1 --pads_strides_dilations 0 0 1 1 1 1 --disable-forward --disable-backward-weights
COMMAND ${DYNAMIC_IMPLICITGEMM_BWD_ENVS_XDLOPS} $<TARGET_FILE:test_conv2d> --verbose --input 400 256 7 7 --weights 1024 256 7 7 --pads_strides_dilations 0 0 1 1 1 1 --disable-forward --disable-backward-weights
COMMAND ${DYNAMIC_IMPLICITGEMM_BWD_ENVS_XDLOPS} $<TARGET_FILE:test_conv2d> --verbose --input 400 256 1 1 --weights 1024 256 1 1 --pads_strides_dilations 0 0 1 1 1 1 --disable-forward --disable-backward-weights
COMMAND ${DYNAMIC_IMPLICITGEMM_BWD_ENVS_XDLOPS} $<TARGET_FILE:test_conv2d> --verbose --input 8 16 5 5 --weights 8 16 2 2 --pads_strides_dilations 0 0 1 1 1 1 --disable-forward --disable-backward-weights
)

add_custom_test(test_conv_igemm_dynamic_xdlops_fwd SKIP_UNLESS_ALL ALLOW_HALF
COMMAND ${DYNAMIC_IMPLICITGEMM_FWD_GTC_DYNAMIC_XDLOPS_ENVS} $<TARGET_FILE:test_conv2d> --verbose --input 64 1024 14 14 --weights 1024 1024 1 1 --pads_strides_dilations 0 0 1 1 1 1 --disable-backward-data --disable-backward-weights
COMMAND ${DYNAMIC_IMPLICITGEMM_FWD_GTC_DYNAMIC_XDLOPS_ENVS} $<TARGET_FILE:test_conv2d> --verbose --input 64 256 56 56 --weights 512 256 1 1 --pads_strides_dilations 0 0 2 2 1 1 --disable-backward-data --disable-backward-weights
Expand Down

0 comments on commit fc8a05b

Please sign in to comment.