Skip to content

Commit

Permalink
[MI100][FP16][ASM iGemm] Fix wrw error for very small ho and wo (#1000)
Browse files Browse the repository at this point in the history
- Fix asm igemm wrw bug when hoxwo (ho * wo) is less than b_padding
- [TESTS] Revert W/A for issue #996
  • Loading branch information
shaojiewang authored and atamazov committed Jul 22, 2021
1 parent d710b26 commit 81be7d0
Show file tree
Hide file tree
Showing 216 changed files with 5,828 additions and 5,608 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -238,32 +238,33 @@
.set s_x, 30
.set s_gemmk_split, 31
.set s_group, 32
.set s_out_stride_k, 33
.set s_ho_padded, 33
.set s_out_stride_k, 34
.set s_hoxwo, 31
.set s_out_stride_n, 34
.set s_in_stride_c0, 35
.set s_in_stride_c, 36
.set s_in_stride_n, 37
.set s_wei_stride_c, 38
.set s_wei_stride_k, 39
.set s_out_stride_n_n1, 40
.set s_in_stride_n_n1, 41
.set s_move_slice_n_n1, 42
.set s_move_slice_n_dsho, 43
.set s_move_slice_n_dswo, 44
.set s_dim_b, 45
.set s_block_gtc_ik, 46
.set s_block_gtc_ic0, 47
.set s_block_gtc_ic1e, 48
.set s_block_gtc_in, 49
.set s_block_gtc_ig, 50
.set s_out_stride_n, 35
.set s_in_stride_c0, 36
.set s_in_stride_c, 37
.set s_in_stride_n, 38
.set s_wei_stride_c, 39
.set s_wei_stride_k, 40
.set s_out_stride_n_n1, 41
.set s_in_stride_n_n1, 42
.set s_move_slice_n_n1, 43
.set s_move_slice_n_dsho, 44
.set s_move_slice_n_dswo, 45
.set s_dim_b, 46
.set s_block_gtc_ik, 47
.set s_block_gtc_ic0, 48
.set s_block_gtc_ic1e, 49
.set s_block_gtc_in, 50
.set s_block_gtc_ig, 51
.set s_knum, 1
.set s_gemm_k_num_n1, 0
.set s_kitr, 3
.set s_in_offset, 51
.set s_out_offset, 57
.set s_sub_n, 63
.set s_k_padded, 64
.set s_in_offset, 52
.set s_out_offset, 58
.set s_sub_n, 64
.set s_k_padded, 65
.set s_tmp, 66
.set s_end, 72

Expand Down Expand Up @@ -334,7 +335,7 @@ igemm_wrw_gtcx_nchw_fp16_bx16_ex1_bt128x128x16_wt32x32x8_ws1x1_wr2x2_ta1x1x1x8_1
s_load_dwordx2 s[s_p_wei+0:s_p_wei+1], s[s_ka+0:s_ka+1], 0+k_p_wei
s_load_dwordx2 s[s_p_out+0:s_p_out+1], s[s_ka+0:s_ka+1], 0+k_p_out
s_load_dwordx16 s[s_hi+0:s_hi+15], s[s_ka+0:s_ka+1], 0+k_hi
s_load_dword s[s_group], s[s_ka+0:s_ka+1], 0+k_group
s_load_dwordx2 s[s_group+0:s_group+1], s[s_ka+0:s_ka+1], 0+k_group

; input, thread(n0,n1b,c0,c1e): 1x1x8x1, cluster(n0,n1b,c0,c1e): 1x16x1x16
v_mov_b32 v[v_tmp], v0
Expand Down Expand Up @@ -683,8 +684,8 @@ igemm_wrw_gtcx_nchw_fp16_bx16_ex1_bt128x128x16_wt32x32x8_ws1x1_wr2x2_ta1x1x1x8_1
s_mov_b64 exec, -1

v_add_u32 v[v_move_slice_n_idsho], s[s_move_slice_n_dsho], v[v_move_slice_n_idsho]
v_cmpx_le_u32 vcc, s[s_ho], v[v_move_slice_n_idsho]
v_subrev_u32 v[v_move_slice_n_idsho], s[s_ho], v[v_move_slice_n_idsho]
v_cmpx_le_u32 vcc, s[s_ho_padded], v[v_move_slice_n_idsho]
v_subrev_u32 v[v_move_slice_n_idsho], s[s_ho_padded], v[v_move_slice_n_idsho]
v_add_u32 v[v_move_slice_n_in1], 1, v[v_move_slice_n_in1]
v_add_u32 v[v_in_os_base], s[s_in_stride_n], v[v_in_os_base]
v_add_u32 v[v_out_os_base], s[s_out_stride_n], v[v_out_os_base]
Expand Down Expand Up @@ -759,8 +760,8 @@ L_igemm_wrw_gtcx_nchw_fp16_bx16_ex1_bt128x128x16_wt32x32x8_ws1x1_wr2x2_ta1x1x1x8
v_add_u32 v[v_move_slice_n_idsho], 1, v[v_move_slice_n_idsho]
s_mov_b64 exec, -1
v_add_u32 v[v_move_slice_n_idsho], s[s_move_slice_n_dsho], v[v_move_slice_n_idsho]
v_cmpx_le_u32 vcc, s[s_ho], v[v_move_slice_n_idsho]
v_subrev_u32 v[v_move_slice_n_idsho], s[s_ho], v[v_move_slice_n_idsho]
v_cmpx_le_u32 vcc, s[s_ho_padded], v[v_move_slice_n_idsho]
v_subrev_u32 v[v_move_slice_n_idsho], s[s_ho_padded], v[v_move_slice_n_idsho]
v_add_u32 v[v_move_slice_n_in1], 1, v[v_move_slice_n_in1]
v_add_u32 v[v_in_os_base], s[s_in_stride_n], v[v_in_os_base]
v_add_u32 v[v_out_os_base], s[s_out_stride_n], v[v_out_os_base]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -239,32 +239,33 @@
.set s_x, 30
.set s_gemmk_split, 31
.set s_group, 32
.set s_out_stride_k, 33
.set s_ho_padded, 33
.set s_out_stride_k, 34
.set s_hoxwo, 31
.set s_out_stride_n, 34
.set s_in_stride_c0, 35
.set s_in_stride_c, 36
.set s_in_stride_n, 37
.set s_wei_stride_c, 38
.set s_wei_stride_k, 39
.set s_out_stride_n_n1, 40
.set s_in_stride_n_n1, 41
.set s_move_slice_n_n1, 42
.set s_move_slice_n_dsho, 43
.set s_move_slice_n_dswo, 44
.set s_dim_b, 45
.set s_block_gtc_ik, 46
.set s_block_gtc_ic0, 47
.set s_block_gtc_ic1e, 48
.set s_block_gtc_in, 49
.set s_block_gtc_ig, 50
.set s_out_stride_n, 35
.set s_in_stride_c0, 36
.set s_in_stride_c, 37
.set s_in_stride_n, 38
.set s_wei_stride_c, 39
.set s_wei_stride_k, 40
.set s_out_stride_n_n1, 41
.set s_in_stride_n_n1, 42
.set s_move_slice_n_n1, 43
.set s_move_slice_n_dsho, 44
.set s_move_slice_n_dswo, 45
.set s_dim_b, 46
.set s_block_gtc_ik, 47
.set s_block_gtc_ic0, 48
.set s_block_gtc_ic1e, 49
.set s_block_gtc_in, 50
.set s_block_gtc_ig, 51
.set s_knum, 1
.set s_gemm_k_num_n1, 0
.set s_kitr, 3
.set s_in_offset, 51
.set s_out_offset, 57
.set s_sub_n, 63
.set s_k_padded, 64
.set s_in_offset, 52
.set s_out_offset, 58
.set s_sub_n, 64
.set s_k_padded, 65
.set s_tmp, 66
.set s_end, 72

Expand Down Expand Up @@ -335,7 +336,7 @@ igemm_wrw_gtcx_nchw_fp16_bx16_ex1_bt128x128x16_wt32x32x8_ws1x1_wr2x2_ta1x1x1x8_1
s_load_dwordx2 s[s_p_wei+0:s_p_wei+1], s[s_ka+0:s_ka+1], 0+k_p_wei
s_load_dwordx2 s[s_p_out+0:s_p_out+1], s[s_ka+0:s_ka+1], 0+k_p_out
s_load_dwordx16 s[s_hi+0:s_hi+15], s[s_ka+0:s_ka+1], 0+k_hi
s_load_dword s[s_group], s[s_ka+0:s_ka+1], 0+k_group
s_load_dwordx2 s[s_group+0:s_group+1], s[s_ka+0:s_ka+1], 0+k_group

; input, thread(n0,n1b,c0,c1e): 1x1x8x1, cluster(n0,n1b,c0,c1e): 1x16x1x16
v_mov_b32 v[v_tmp], v0
Expand Down Expand Up @@ -685,8 +686,8 @@ igemm_wrw_gtcx_nchw_fp16_bx16_ex1_bt128x128x16_wt32x32x8_ws1x1_wr2x2_ta1x1x1x8_1
s_mov_b64 exec, -1

v_add_u32 v[v_move_slice_n_idsho], s[s_move_slice_n_dsho], v[v_move_slice_n_idsho]
v_cmpx_le_u32 vcc, s[s_ho], v[v_move_slice_n_idsho]
v_subrev_u32 v[v_move_slice_n_idsho], s[s_ho], v[v_move_slice_n_idsho]
v_cmpx_le_u32 vcc, s[s_ho_padded], v[v_move_slice_n_idsho]
v_subrev_u32 v[v_move_slice_n_idsho], s[s_ho_padded], v[v_move_slice_n_idsho]
v_add_u32 v[v_move_slice_n_in1], 1, v[v_move_slice_n_in1]
v_add_u32 v[v_in_os_base], s[s_in_stride_n], v[v_in_os_base]
v_add_u32 v[v_out_os_base], s[s_out_stride_n], v[v_out_os_base]
Expand Down Expand Up @@ -761,8 +762,8 @@ L_igemm_wrw_gtcx_nchw_fp16_bx16_ex1_bt128x128x16_wt32x32x8_ws1x1_wr2x2_ta1x1x1x8
v_add_u32 v[v_move_slice_n_idsho], 1, v[v_move_slice_n_idsho]
s_mov_b64 exec, -1
v_add_u32 v[v_move_slice_n_idsho], s[s_move_slice_n_dsho], v[v_move_slice_n_idsho]
v_cmpx_le_u32 vcc, s[s_ho], v[v_move_slice_n_idsho]
v_subrev_u32 v[v_move_slice_n_idsho], s[s_ho], v[v_move_slice_n_idsho]
v_cmpx_le_u32 vcc, s[s_ho_padded], v[v_move_slice_n_idsho]
v_subrev_u32 v[v_move_slice_n_idsho], s[s_ho_padded], v[v_move_slice_n_idsho]
v_add_u32 v[v_move_slice_n_in1], 1, v[v_move_slice_n_in1]
v_add_u32 v[v_in_os_base], s[s_in_stride_n], v[v_in_os_base]
v_add_u32 v[v_out_os_base], s[s_out_stride_n], v[v_out_os_base]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -238,33 +238,34 @@
.set s_x, 30
.set s_gemmk_split, 31
.set s_group, 32
.set s_out_stride_k, 33
.set s_ho_padded, 33
.set s_out_stride_k, 34
.set s_hoxwo, 31
.set s_out_stride_n, 34
.set s_in_stride_c, 35
.set s_in_stride_n, 36
.set s_wei_stride_c, 37
.set s_wei_stride_k, 38
.set s_out_stride_n_n1, 39
.set s_in_stride_n_n1, 40
.set s_move_slice_n_n1, 41
.set s_move_slice_n_dsho, 42
.set s_move_slice_n_dswo, 43
.set s_dim_b, 44
.set s_block_gtc_ik, 45
.set s_block_gtc_ic0, 46
.set s_block_gtc_ic1e, 47
.set s_block_gtc_in, 48
.set s_block_gtc_ig, 49
.set s_out_stride_n, 35
.set s_in_stride_c, 36
.set s_in_stride_n, 37
.set s_wei_stride_c, 38
.set s_wei_stride_k, 39
.set s_out_stride_n_n1, 40
.set s_in_stride_n_n1, 41
.set s_move_slice_n_n1, 42
.set s_move_slice_n_dsho, 43
.set s_move_slice_n_dswo, 44
.set s_dim_b, 45
.set s_block_gtc_ik, 46
.set s_block_gtc_ic0, 47
.set s_block_gtc_ic1e, 48
.set s_block_gtc_in, 49
.set s_block_gtc_ig, 50
.set s_knum, 1
.set s_gemm_k_num_n1, 0
.set s_kitr, 3
.set s_in_offset, 50
.set s_out_offset, 50
.set s_sub_n, 56
.set s_k_padded, 57
.set s_tmp, 58
.set s_end, 64
.set s_in_offset, 51
.set s_out_offset, 51
.set s_sub_n, 57
.set s_k_padded, 58
.set s_tmp, 60
.set s_end, 66

.set v_c, 0 ; coalescing:8, needed:0, reusable:33
.set v_a, 0
Expand Down Expand Up @@ -334,7 +335,7 @@ igemm_wrw_gtcx_nchw_fp16_bx16_ex1_bt128x16x16_wt32x8x4_ws2x1_wr1x1_ta1x1x1x8_1x1
s_load_dwordx2 s[s_p_wei+0:s_p_wei+1], s[s_ka+0:s_ka+1], 0+k_p_wei
s_load_dwordx2 s[s_p_out+0:s_p_out+1], s[s_ka+0:s_ka+1], 0+k_p_out
s_load_dwordx16 s[s_hi+0:s_hi+15], s[s_ka+0:s_ka+1], 0+k_hi
s_load_dword s[s_group], s[s_ka+0:s_ka+1], 0+k_group
s_load_dwordx2 s[s_group+0:s_group+1], s[s_ka+0:s_ka+1], 0+k_group

; input, thread(n0,n1b,c0,c1e): 1x1x1x1, cluster(n0,n1b,c0,c1e): 1x16x1x16
v_mov_b32 v[v_tmp], v0
Expand Down Expand Up @@ -680,8 +681,8 @@ igemm_wrw_gtcx_nchw_fp16_bx16_ex1_bt128x16x16_wt32x8x4_ws2x1_wr1x1_ta1x1x1x8_1x1
s_mov_b64 exec, -1

v_add_u32 v[v_move_slice_n_idsho], s[s_move_slice_n_dsho], v[v_move_slice_n_idsho]
v_cmpx_le_u32 vcc, s[s_ho], v[v_move_slice_n_idsho]
v_subrev_u32 v[v_move_slice_n_idsho], s[s_ho], v[v_move_slice_n_idsho]
v_cmpx_le_u32 vcc, s[s_ho_padded], v[v_move_slice_n_idsho]
v_subrev_u32 v[v_move_slice_n_idsho], s[s_ho_padded], v[v_move_slice_n_idsho]
v_add_u32 v[v_move_slice_n_in1], 1, v[v_move_slice_n_in1]
v_add_u32 v[v_in_os_base], s[s_in_stride_n], v[v_in_os_base]
v_add_u32 v[v_out_os_base], s[s_out_stride_n], v[v_out_os_base]
Expand Down Expand Up @@ -750,8 +751,8 @@ L_igemm_wrw_gtcx_nchw_fp16_bx16_ex1_bt128x16x16_wt32x8x4_ws2x1_wr1x1_ta1x1x1x8_1
v_add_u32 v[v_move_slice_n_idsho], 1, v[v_move_slice_n_idsho]
s_mov_b64 exec, -1
v_add_u32 v[v_move_slice_n_idsho], s[s_move_slice_n_dsho], v[v_move_slice_n_idsho]
v_cmpx_le_u32 vcc, s[s_ho], v[v_move_slice_n_idsho]
v_subrev_u32 v[v_move_slice_n_idsho], s[s_ho], v[v_move_slice_n_idsho]
v_cmpx_le_u32 vcc, s[s_ho_padded], v[v_move_slice_n_idsho]
v_subrev_u32 v[v_move_slice_n_idsho], s[s_ho_padded], v[v_move_slice_n_idsho]
v_add_u32 v[v_move_slice_n_in1], 1, v[v_move_slice_n_in1]
v_add_u32 v[v_in_os_base], s[s_in_stride_n], v[v_in_os_base]
v_add_u32 v[v_out_os_base], s[s_out_stride_n], v[v_out_os_base]
Expand Down Expand Up @@ -911,7 +912,7 @@ L_igemm_wrw_gtcx_nchw_fp16_bx16_ex1_bt128x16x16_wt32x8x4_ws2x1_wr1x1_ta1x1x1x8_1
.amdhsa_system_sgpr_workgroup_id_x 1
.amdhsa_system_vgpr_workitem_id 0
.amdhsa_next_free_vgpr 72
.amdhsa_next_free_sgpr 64
.amdhsa_next_free_sgpr 66
.amdhsa_ieee_mode 0
.amdhsa_dx10_clamp 0
.end_amdhsa_kernel
Expand All @@ -922,7 +923,7 @@ amdhsa.version: [ 1, 0 ]
amdhsa.kernels:
- .name: igemm_wrw_gtcx_nchw_fp16_bx16_ex1_bt128x16x16_wt32x8x4_ws2x1_wr1x1_ta1x1x1x8_1x16x1x16_tb1x1x1x1_1x16x1x16
.symbol: igemm_wrw_gtcx_nchw_fp16_bx16_ex1_bt128x16x16_wt32x8x4_ws2x1_wr1x1_ta1x1x1x8_1x16x1x16_tb1x1x1x1_1x16x1x16.kd
.sgpr_count: 70
.sgpr_count: 72
.vgpr_count: 72
.kernarg_segment_align: 8
.kernarg_segment_size: 96
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -239,33 +239,34 @@
.set s_x, 30
.set s_gemmk_split, 31
.set s_group, 32
.set s_out_stride_k, 33
.set s_ho_padded, 33
.set s_out_stride_k, 34
.set s_hoxwo, 31
.set s_out_stride_n, 34
.set s_in_stride_c, 35
.set s_in_stride_n, 36
.set s_wei_stride_c, 37
.set s_wei_stride_k, 38
.set s_out_stride_n_n1, 39
.set s_in_stride_n_n1, 40
.set s_move_slice_n_n1, 41
.set s_move_slice_n_dsho, 42
.set s_move_slice_n_dswo, 43
.set s_dim_b, 44
.set s_block_gtc_ik, 45
.set s_block_gtc_ic0, 46
.set s_block_gtc_ic1e, 47
.set s_block_gtc_in, 48
.set s_block_gtc_ig, 49
.set s_out_stride_n, 35
.set s_in_stride_c, 36
.set s_in_stride_n, 37
.set s_wei_stride_c, 38
.set s_wei_stride_k, 39
.set s_out_stride_n_n1, 40
.set s_in_stride_n_n1, 41
.set s_move_slice_n_n1, 42
.set s_move_slice_n_dsho, 43
.set s_move_slice_n_dswo, 44
.set s_dim_b, 45
.set s_block_gtc_ik, 46
.set s_block_gtc_ic0, 47
.set s_block_gtc_ic1e, 48
.set s_block_gtc_in, 49
.set s_block_gtc_ig, 50
.set s_knum, 1
.set s_gemm_k_num_n1, 0
.set s_kitr, 3
.set s_in_offset, 50
.set s_out_offset, 50
.set s_sub_n, 56
.set s_k_padded, 57
.set s_tmp, 58
.set s_end, 64
.set s_in_offset, 51
.set s_out_offset, 51
.set s_sub_n, 57
.set s_k_padded, 58
.set s_tmp, 60
.set s_end, 66

.set v_c, 0 ; coalescing:8, needed:0, reusable:33
.set v_a, 0
Expand Down Expand Up @@ -335,7 +336,7 @@ igemm_wrw_gtcx_nchw_fp16_bx16_ex1_bt128x16x16_wt32x8x4_ws2x1_wr1x1_ta1x1x1x8_1x1
s_load_dwordx2 s[s_p_wei+0:s_p_wei+1], s[s_ka+0:s_ka+1], 0+k_p_wei
s_load_dwordx2 s[s_p_out+0:s_p_out+1], s[s_ka+0:s_ka+1], 0+k_p_out
s_load_dwordx16 s[s_hi+0:s_hi+15], s[s_ka+0:s_ka+1], 0+k_hi
s_load_dword s[s_group], s[s_ka+0:s_ka+1], 0+k_group
s_load_dwordx2 s[s_group+0:s_group+1], s[s_ka+0:s_ka+1], 0+k_group

; input, thread(n0,n1b,c0,c1e): 1x1x1x1, cluster(n0,n1b,c0,c1e): 1x16x1x16
v_mov_b32 v[v_tmp], v0
Expand Down Expand Up @@ -682,8 +683,8 @@ igemm_wrw_gtcx_nchw_fp16_bx16_ex1_bt128x16x16_wt32x8x4_ws2x1_wr1x1_ta1x1x1x8_1x1
s_mov_b64 exec, -1

v_add_u32 v[v_move_slice_n_idsho], s[s_move_slice_n_dsho], v[v_move_slice_n_idsho]
v_cmpx_le_u32 vcc, s[s_ho], v[v_move_slice_n_idsho]
v_subrev_u32 v[v_move_slice_n_idsho], s[s_ho], v[v_move_slice_n_idsho]
v_cmpx_le_u32 vcc, s[s_ho_padded], v[v_move_slice_n_idsho]
v_subrev_u32 v[v_move_slice_n_idsho], s[s_ho_padded], v[v_move_slice_n_idsho]
v_add_u32 v[v_move_slice_n_in1], 1, v[v_move_slice_n_in1]
v_add_u32 v[v_in_os_base], s[s_in_stride_n], v[v_in_os_base]
v_add_u32 v[v_out_os_base], s[s_out_stride_n], v[v_out_os_base]
Expand Down Expand Up @@ -752,8 +753,8 @@ L_igemm_wrw_gtcx_nchw_fp16_bx16_ex1_bt128x16x16_wt32x8x4_ws2x1_wr1x1_ta1x1x1x8_1
v_add_u32 v[v_move_slice_n_idsho], 1, v[v_move_slice_n_idsho]
s_mov_b64 exec, -1
v_add_u32 v[v_move_slice_n_idsho], s[s_move_slice_n_dsho], v[v_move_slice_n_idsho]
v_cmpx_le_u32 vcc, s[s_ho], v[v_move_slice_n_idsho]
v_subrev_u32 v[v_move_slice_n_idsho], s[s_ho], v[v_move_slice_n_idsho]
v_cmpx_le_u32 vcc, s[s_ho_padded], v[v_move_slice_n_idsho]
v_subrev_u32 v[v_move_slice_n_idsho], s[s_ho_padded], v[v_move_slice_n_idsho]
v_add_u32 v[v_move_slice_n_in1], 1, v[v_move_slice_n_in1]
v_add_u32 v[v_in_os_base], s[s_in_stride_n], v[v_in_os_base]
v_add_u32 v[v_out_os_base], s[s_out_stride_n], v[v_out_os_base]
Expand Down Expand Up @@ -901,7 +902,7 @@ L_igemm_wrw_gtcx_nchw_fp16_bx16_ex1_bt128x16x16_wt32x8x4_ws2x1_wr1x1_ta1x1x1x8_1
.amdhsa_system_sgpr_workgroup_id_x 1
.amdhsa_system_vgpr_workitem_id 0
.amdhsa_next_free_vgpr 72
.amdhsa_next_free_sgpr 64
.amdhsa_next_free_sgpr 66
.amdhsa_ieee_mode 0
.amdhsa_dx10_clamp 0
.end_amdhsa_kernel
Expand All @@ -912,7 +913,7 @@ amdhsa.version: [ 1, 0 ]
amdhsa.kernels:
- .name: igemm_wrw_gtcx_nchw_fp16_bx16_ex1_bt128x16x16_wt32x8x4_ws2x1_wr1x1_ta1x1x1x8_1x16x1x16_tb1x1x1x1_1x16x1x16_gkgs
.symbol: igemm_wrw_gtcx_nchw_fp16_bx16_ex1_bt128x16x16_wt32x8x4_ws2x1_wr1x1_ta1x1x1x8_1x16x1x16_tb1x1x1x1_1x16x1x16_gkgs.kd
.sgpr_count: 70
.sgpr_count: 72
.vgpr_count: 72
.kernarg_segment_align: 8
.kernarg_segment_size: 96
Expand Down
Loading

0 comments on commit 81be7d0

Please sign in to comment.