Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[MI100][FP16][ASM iGemm] Fix wrw's very small ho and wo error #1000

Merged
merged 4 commits into from
Jun 24, 2021
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -238,32 +238,33 @@
.set s_x, 30
.set s_gemmk_split, 31
.set s_group, 32
.set s_out_stride_k, 33
.set s_ho_padded, 33
.set s_out_stride_k, 34
.set s_hoxwo, 31
.set s_out_stride_n, 34
.set s_in_stride_c0, 35
.set s_in_stride_c, 36
.set s_in_stride_n, 37
.set s_wei_stride_c, 38
.set s_wei_stride_k, 39
.set s_out_stride_n_n1, 40
.set s_in_stride_n_n1, 41
.set s_move_slice_n_n1, 42
.set s_move_slice_n_dsho, 43
.set s_move_slice_n_dswo, 44
.set s_dim_b, 45
.set s_block_gtc_ik, 46
.set s_block_gtc_ic0, 47
.set s_block_gtc_ic1e, 48
.set s_block_gtc_in, 49
.set s_block_gtc_ig, 50
.set s_out_stride_n, 35
.set s_in_stride_c0, 36
.set s_in_stride_c, 37
.set s_in_stride_n, 38
.set s_wei_stride_c, 39
.set s_wei_stride_k, 40
.set s_out_stride_n_n1, 41
.set s_in_stride_n_n1, 42
.set s_move_slice_n_n1, 43
.set s_move_slice_n_dsho, 44
.set s_move_slice_n_dswo, 45
.set s_dim_b, 46
.set s_block_gtc_ik, 47
.set s_block_gtc_ic0, 48
.set s_block_gtc_ic1e, 49
.set s_block_gtc_in, 50
.set s_block_gtc_ig, 51
.set s_knum, 1
.set s_gemm_k_num_n1, 0
.set s_kitr, 3
.set s_in_offset, 51
.set s_out_offset, 57
.set s_sub_n, 63
.set s_k_padded, 64
.set s_in_offset, 52
.set s_out_offset, 58
.set s_sub_n, 64
.set s_k_padded, 65
.set s_tmp, 66
.set s_end, 72

Expand Down Expand Up @@ -334,7 +335,7 @@ igemm_wrw_gtcx_nchw_fp16_bx16_ex1_bt128x128x16_wt32x32x8_ws1x1_wr2x2_ta1x1x1x8_1
s_load_dwordx2 s[s_p_wei+0:s_p_wei+1], s[s_ka+0:s_ka+1], 0+k_p_wei
s_load_dwordx2 s[s_p_out+0:s_p_out+1], s[s_ka+0:s_ka+1], 0+k_p_out
s_load_dwordx16 s[s_hi+0:s_hi+15], s[s_ka+0:s_ka+1], 0+k_hi
s_load_dword s[s_group], s[s_ka+0:s_ka+1], 0+k_group
s_load_dwordx2 s[s_group+0:s_group+1], s[s_ka+0:s_ka+1], 0+k_group

; input, thread(n0,n1b,c0,c1e): 1x1x8x1, cluster(n0,n1b,c0,c1e): 1x16x1x16
v_mov_b32 v[v_tmp], v0
Expand Down Expand Up @@ -683,8 +684,8 @@ igemm_wrw_gtcx_nchw_fp16_bx16_ex1_bt128x128x16_wt32x32x8_ws1x1_wr2x2_ta1x1x1x8_1
s_mov_b64 exec, -1

v_add_u32 v[v_move_slice_n_idsho], s[s_move_slice_n_dsho], v[v_move_slice_n_idsho]
v_cmpx_le_u32 vcc, s[s_ho], v[v_move_slice_n_idsho]
v_subrev_u32 v[v_move_slice_n_idsho], s[s_ho], v[v_move_slice_n_idsho]
v_cmpx_le_u32 vcc, s[s_ho_padded], v[v_move_slice_n_idsho]
v_subrev_u32 v[v_move_slice_n_idsho], s[s_ho_padded], v[v_move_slice_n_idsho]
v_add_u32 v[v_move_slice_n_in1], 1, v[v_move_slice_n_in1]
v_add_u32 v[v_in_os_base], s[s_in_stride_n], v[v_in_os_base]
v_add_u32 v[v_out_os_base], s[s_out_stride_n], v[v_out_os_base]
Expand Down Expand Up @@ -759,8 +760,8 @@ L_igemm_wrw_gtcx_nchw_fp16_bx16_ex1_bt128x128x16_wt32x32x8_ws1x1_wr2x2_ta1x1x1x8
v_add_u32 v[v_move_slice_n_idsho], 1, v[v_move_slice_n_idsho]
s_mov_b64 exec, -1
v_add_u32 v[v_move_slice_n_idsho], s[s_move_slice_n_dsho], v[v_move_slice_n_idsho]
v_cmpx_le_u32 vcc, s[s_ho], v[v_move_slice_n_idsho]
v_subrev_u32 v[v_move_slice_n_idsho], s[s_ho], v[v_move_slice_n_idsho]
v_cmpx_le_u32 vcc, s[s_ho_padded], v[v_move_slice_n_idsho]
v_subrev_u32 v[v_move_slice_n_idsho], s[s_ho_padded], v[v_move_slice_n_idsho]
v_add_u32 v[v_move_slice_n_in1], 1, v[v_move_slice_n_in1]
v_add_u32 v[v_in_os_base], s[s_in_stride_n], v[v_in_os_base]
v_add_u32 v[v_out_os_base], s[s_out_stride_n], v[v_out_os_base]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -239,32 +239,33 @@
.set s_x, 30
.set s_gemmk_split, 31
.set s_group, 32
.set s_out_stride_k, 33
.set s_ho_padded, 33
.set s_out_stride_k, 34
.set s_hoxwo, 31
.set s_out_stride_n, 34
.set s_in_stride_c0, 35
.set s_in_stride_c, 36
.set s_in_stride_n, 37
.set s_wei_stride_c, 38
.set s_wei_stride_k, 39
.set s_out_stride_n_n1, 40
.set s_in_stride_n_n1, 41
.set s_move_slice_n_n1, 42
.set s_move_slice_n_dsho, 43
.set s_move_slice_n_dswo, 44
.set s_dim_b, 45
.set s_block_gtc_ik, 46
.set s_block_gtc_ic0, 47
.set s_block_gtc_ic1e, 48
.set s_block_gtc_in, 49
.set s_block_gtc_ig, 50
.set s_out_stride_n, 35
.set s_in_stride_c0, 36
.set s_in_stride_c, 37
.set s_in_stride_n, 38
.set s_wei_stride_c, 39
.set s_wei_stride_k, 40
.set s_out_stride_n_n1, 41
.set s_in_stride_n_n1, 42
.set s_move_slice_n_n1, 43
.set s_move_slice_n_dsho, 44
.set s_move_slice_n_dswo, 45
.set s_dim_b, 46
.set s_block_gtc_ik, 47
.set s_block_gtc_ic0, 48
.set s_block_gtc_ic1e, 49
.set s_block_gtc_in, 50
.set s_block_gtc_ig, 51
.set s_knum, 1
.set s_gemm_k_num_n1, 0
.set s_kitr, 3
.set s_in_offset, 51
.set s_out_offset, 57
.set s_sub_n, 63
.set s_k_padded, 64
.set s_in_offset, 52
.set s_out_offset, 58
.set s_sub_n, 64
.set s_k_padded, 65
.set s_tmp, 66
.set s_end, 72

Expand Down Expand Up @@ -335,7 +336,7 @@ igemm_wrw_gtcx_nchw_fp16_bx16_ex1_bt128x128x16_wt32x32x8_ws1x1_wr2x2_ta1x1x1x8_1
s_load_dwordx2 s[s_p_wei+0:s_p_wei+1], s[s_ka+0:s_ka+1], 0+k_p_wei
s_load_dwordx2 s[s_p_out+0:s_p_out+1], s[s_ka+0:s_ka+1], 0+k_p_out
s_load_dwordx16 s[s_hi+0:s_hi+15], s[s_ka+0:s_ka+1], 0+k_hi
s_load_dword s[s_group], s[s_ka+0:s_ka+1], 0+k_group
s_load_dwordx2 s[s_group+0:s_group+1], s[s_ka+0:s_ka+1], 0+k_group

; input, thread(n0,n1b,c0,c1e): 1x1x8x1, cluster(n0,n1b,c0,c1e): 1x16x1x16
v_mov_b32 v[v_tmp], v0
Expand Down Expand Up @@ -685,8 +686,8 @@ igemm_wrw_gtcx_nchw_fp16_bx16_ex1_bt128x128x16_wt32x32x8_ws1x1_wr2x2_ta1x1x1x8_1
s_mov_b64 exec, -1

v_add_u32 v[v_move_slice_n_idsho], s[s_move_slice_n_dsho], v[v_move_slice_n_idsho]
v_cmpx_le_u32 vcc, s[s_ho], v[v_move_slice_n_idsho]
v_subrev_u32 v[v_move_slice_n_idsho], s[s_ho], v[v_move_slice_n_idsho]
v_cmpx_le_u32 vcc, s[s_ho_padded], v[v_move_slice_n_idsho]
v_subrev_u32 v[v_move_slice_n_idsho], s[s_ho_padded], v[v_move_slice_n_idsho]
v_add_u32 v[v_move_slice_n_in1], 1, v[v_move_slice_n_in1]
v_add_u32 v[v_in_os_base], s[s_in_stride_n], v[v_in_os_base]
v_add_u32 v[v_out_os_base], s[s_out_stride_n], v[v_out_os_base]
Expand Down Expand Up @@ -761,8 +762,8 @@ L_igemm_wrw_gtcx_nchw_fp16_bx16_ex1_bt128x128x16_wt32x32x8_ws1x1_wr2x2_ta1x1x1x8
v_add_u32 v[v_move_slice_n_idsho], 1, v[v_move_slice_n_idsho]
s_mov_b64 exec, -1
v_add_u32 v[v_move_slice_n_idsho], s[s_move_slice_n_dsho], v[v_move_slice_n_idsho]
v_cmpx_le_u32 vcc, s[s_ho], v[v_move_slice_n_idsho]
v_subrev_u32 v[v_move_slice_n_idsho], s[s_ho], v[v_move_slice_n_idsho]
v_cmpx_le_u32 vcc, s[s_ho_padded], v[v_move_slice_n_idsho]
v_subrev_u32 v[v_move_slice_n_idsho], s[s_ho_padded], v[v_move_slice_n_idsho]
v_add_u32 v[v_move_slice_n_in1], 1, v[v_move_slice_n_in1]
v_add_u32 v[v_in_os_base], s[s_in_stride_n], v[v_in_os_base]
v_add_u32 v[v_out_os_base], s[s_out_stride_n], v[v_out_os_base]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -238,33 +238,34 @@
.set s_x, 30
.set s_gemmk_split, 31
.set s_group, 32
.set s_out_stride_k, 33
.set s_ho_padded, 33
.set s_out_stride_k, 34
.set s_hoxwo, 31
.set s_out_stride_n, 34
.set s_in_stride_c, 35
.set s_in_stride_n, 36
.set s_wei_stride_c, 37
.set s_wei_stride_k, 38
.set s_out_stride_n_n1, 39
.set s_in_stride_n_n1, 40
.set s_move_slice_n_n1, 41
.set s_move_slice_n_dsho, 42
.set s_move_slice_n_dswo, 43
.set s_dim_b, 44
.set s_block_gtc_ik, 45
.set s_block_gtc_ic0, 46
.set s_block_gtc_ic1e, 47
.set s_block_gtc_in, 48
.set s_block_gtc_ig, 49
.set s_out_stride_n, 35
.set s_in_stride_c, 36
.set s_in_stride_n, 37
.set s_wei_stride_c, 38
.set s_wei_stride_k, 39
.set s_out_stride_n_n1, 40
.set s_in_stride_n_n1, 41
.set s_move_slice_n_n1, 42
.set s_move_slice_n_dsho, 43
.set s_move_slice_n_dswo, 44
.set s_dim_b, 45
.set s_block_gtc_ik, 46
.set s_block_gtc_ic0, 47
.set s_block_gtc_ic1e, 48
.set s_block_gtc_in, 49
.set s_block_gtc_ig, 50
.set s_knum, 1
.set s_gemm_k_num_n1, 0
.set s_kitr, 3
.set s_in_offset, 50
.set s_out_offset, 50
.set s_sub_n, 56
.set s_k_padded, 57
.set s_tmp, 58
.set s_end, 64
.set s_in_offset, 51
.set s_out_offset, 51
.set s_sub_n, 57
.set s_k_padded, 58
.set s_tmp, 60
.set s_end, 66

.set v_c, 0 ; coalescing:8, needed:0, resuable:33
.set v_a, 0
Expand Down Expand Up @@ -334,7 +335,7 @@ igemm_wrw_gtcx_nchw_fp16_bx16_ex1_bt128x16x16_wt32x8x4_ws2x1_wr1x1_ta1x1x1x8_1x1
s_load_dwordx2 s[s_p_wei+0:s_p_wei+1], s[s_ka+0:s_ka+1], 0+k_p_wei
s_load_dwordx2 s[s_p_out+0:s_p_out+1], s[s_ka+0:s_ka+1], 0+k_p_out
s_load_dwordx16 s[s_hi+0:s_hi+15], s[s_ka+0:s_ka+1], 0+k_hi
s_load_dword s[s_group], s[s_ka+0:s_ka+1], 0+k_group
s_load_dwordx2 s[s_group+0:s_group+1], s[s_ka+0:s_ka+1], 0+k_group

; input, thread(n0,n1b,c0,c1e): 1x1x1x1, cluster(n0,n1b,c0,c1e): 1x16x1x16
v_mov_b32 v[v_tmp], v0
Expand Down Expand Up @@ -680,8 +681,8 @@ igemm_wrw_gtcx_nchw_fp16_bx16_ex1_bt128x16x16_wt32x8x4_ws2x1_wr1x1_ta1x1x1x8_1x1
s_mov_b64 exec, -1

v_add_u32 v[v_move_slice_n_idsho], s[s_move_slice_n_dsho], v[v_move_slice_n_idsho]
v_cmpx_le_u32 vcc, s[s_ho], v[v_move_slice_n_idsho]
v_subrev_u32 v[v_move_slice_n_idsho], s[s_ho], v[v_move_slice_n_idsho]
v_cmpx_le_u32 vcc, s[s_ho_padded], v[v_move_slice_n_idsho]
v_subrev_u32 v[v_move_slice_n_idsho], s[s_ho_padded], v[v_move_slice_n_idsho]
v_add_u32 v[v_move_slice_n_in1], 1, v[v_move_slice_n_in1]
v_add_u32 v[v_in_os_base], s[s_in_stride_n], v[v_in_os_base]
v_add_u32 v[v_out_os_base], s[s_out_stride_n], v[v_out_os_base]
Expand Down Expand Up @@ -750,8 +751,8 @@ L_igemm_wrw_gtcx_nchw_fp16_bx16_ex1_bt128x16x16_wt32x8x4_ws2x1_wr1x1_ta1x1x1x8_1
v_add_u32 v[v_move_slice_n_idsho], 1, v[v_move_slice_n_idsho]
s_mov_b64 exec, -1
v_add_u32 v[v_move_slice_n_idsho], s[s_move_slice_n_dsho], v[v_move_slice_n_idsho]
v_cmpx_le_u32 vcc, s[s_ho], v[v_move_slice_n_idsho]
v_subrev_u32 v[v_move_slice_n_idsho], s[s_ho], v[v_move_slice_n_idsho]
v_cmpx_le_u32 vcc, s[s_ho_padded], v[v_move_slice_n_idsho]
v_subrev_u32 v[v_move_slice_n_idsho], s[s_ho_padded], v[v_move_slice_n_idsho]
v_add_u32 v[v_move_slice_n_in1], 1, v[v_move_slice_n_in1]
v_add_u32 v[v_in_os_base], s[s_in_stride_n], v[v_in_os_base]
v_add_u32 v[v_out_os_base], s[s_out_stride_n], v[v_out_os_base]
Expand Down Expand Up @@ -911,7 +912,7 @@ L_igemm_wrw_gtcx_nchw_fp16_bx16_ex1_bt128x16x16_wt32x8x4_ws2x1_wr1x1_ta1x1x1x8_1
.amdhsa_system_sgpr_workgroup_id_x 1
.amdhsa_system_vgpr_workitem_id 0
.amdhsa_next_free_vgpr 72
.amdhsa_next_free_sgpr 64
.amdhsa_next_free_sgpr 66
.amdhsa_ieee_mode 0
.amdhsa_dx10_clamp 0
.end_amdhsa_kernel
Expand All @@ -922,7 +923,7 @@ amdhsa.version: [ 1, 0 ]
amdhsa.kernels:
- .name: igemm_wrw_gtcx_nchw_fp16_bx16_ex1_bt128x16x16_wt32x8x4_ws2x1_wr1x1_ta1x1x1x8_1x16x1x16_tb1x1x1x1_1x16x1x16
.symbol: igemm_wrw_gtcx_nchw_fp16_bx16_ex1_bt128x16x16_wt32x8x4_ws2x1_wr1x1_ta1x1x1x8_1x16x1x16_tb1x1x1x1_1x16x1x16.kd
.sgpr_count: 70
.sgpr_count: 72
.vgpr_count: 72
.kernarg_segment_align: 8
.kernarg_segment_size: 96
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -239,33 +239,34 @@
.set s_x, 30
.set s_gemmk_split, 31
.set s_group, 32
.set s_out_stride_k, 33
.set s_ho_padded, 33
.set s_out_stride_k, 34
.set s_hoxwo, 31
.set s_out_stride_n, 34
.set s_in_stride_c, 35
.set s_in_stride_n, 36
.set s_wei_stride_c, 37
.set s_wei_stride_k, 38
.set s_out_stride_n_n1, 39
.set s_in_stride_n_n1, 40
.set s_move_slice_n_n1, 41
.set s_move_slice_n_dsho, 42
.set s_move_slice_n_dswo, 43
.set s_dim_b, 44
.set s_block_gtc_ik, 45
.set s_block_gtc_ic0, 46
.set s_block_gtc_ic1e, 47
.set s_block_gtc_in, 48
.set s_block_gtc_ig, 49
.set s_out_stride_n, 35
.set s_in_stride_c, 36
.set s_in_stride_n, 37
.set s_wei_stride_c, 38
.set s_wei_stride_k, 39
.set s_out_stride_n_n1, 40
.set s_in_stride_n_n1, 41
.set s_move_slice_n_n1, 42
.set s_move_slice_n_dsho, 43
.set s_move_slice_n_dswo, 44
.set s_dim_b, 45
.set s_block_gtc_ik, 46
.set s_block_gtc_ic0, 47
.set s_block_gtc_ic1e, 48
.set s_block_gtc_in, 49
.set s_block_gtc_ig, 50
.set s_knum, 1
.set s_gemm_k_num_n1, 0
.set s_kitr, 3
.set s_in_offset, 50
.set s_out_offset, 50
.set s_sub_n, 56
.set s_k_padded, 57
.set s_tmp, 58
.set s_end, 64
.set s_in_offset, 51
.set s_out_offset, 51
.set s_sub_n, 57
.set s_k_padded, 58
.set s_tmp, 60
.set s_end, 66

.set v_c, 0 ; coalescing:8, needed:0, resuable:33
.set v_a, 0
Expand Down Expand Up @@ -335,7 +336,7 @@ igemm_wrw_gtcx_nchw_fp16_bx16_ex1_bt128x16x16_wt32x8x4_ws2x1_wr1x1_ta1x1x1x8_1x1
s_load_dwordx2 s[s_p_wei+0:s_p_wei+1], s[s_ka+0:s_ka+1], 0+k_p_wei
s_load_dwordx2 s[s_p_out+0:s_p_out+1], s[s_ka+0:s_ka+1], 0+k_p_out
s_load_dwordx16 s[s_hi+0:s_hi+15], s[s_ka+0:s_ka+1], 0+k_hi
s_load_dword s[s_group], s[s_ka+0:s_ka+1], 0+k_group
s_load_dwordx2 s[s_group+0:s_group+1], s[s_ka+0:s_ka+1], 0+k_group

; input, thread(n0,n1b,c0,c1e): 1x1x1x1, cluster(n0,n1b,c0,c1e): 1x16x1x16
v_mov_b32 v[v_tmp], v0
Expand Down Expand Up @@ -682,8 +683,8 @@ igemm_wrw_gtcx_nchw_fp16_bx16_ex1_bt128x16x16_wt32x8x4_ws2x1_wr1x1_ta1x1x1x8_1x1
s_mov_b64 exec, -1

v_add_u32 v[v_move_slice_n_idsho], s[s_move_slice_n_dsho], v[v_move_slice_n_idsho]
v_cmpx_le_u32 vcc, s[s_ho], v[v_move_slice_n_idsho]
v_subrev_u32 v[v_move_slice_n_idsho], s[s_ho], v[v_move_slice_n_idsho]
v_cmpx_le_u32 vcc, s[s_ho_padded], v[v_move_slice_n_idsho]
v_subrev_u32 v[v_move_slice_n_idsho], s[s_ho_padded], v[v_move_slice_n_idsho]
v_add_u32 v[v_move_slice_n_in1], 1, v[v_move_slice_n_in1]
v_add_u32 v[v_in_os_base], s[s_in_stride_n], v[v_in_os_base]
v_add_u32 v[v_out_os_base], s[s_out_stride_n], v[v_out_os_base]
Expand Down Expand Up @@ -752,8 +753,8 @@ L_igemm_wrw_gtcx_nchw_fp16_bx16_ex1_bt128x16x16_wt32x8x4_ws2x1_wr1x1_ta1x1x1x8_1
v_add_u32 v[v_move_slice_n_idsho], 1, v[v_move_slice_n_idsho]
s_mov_b64 exec, -1
v_add_u32 v[v_move_slice_n_idsho], s[s_move_slice_n_dsho], v[v_move_slice_n_idsho]
v_cmpx_le_u32 vcc, s[s_ho], v[v_move_slice_n_idsho]
v_subrev_u32 v[v_move_slice_n_idsho], s[s_ho], v[v_move_slice_n_idsho]
v_cmpx_le_u32 vcc, s[s_ho_padded], v[v_move_slice_n_idsho]
v_subrev_u32 v[v_move_slice_n_idsho], s[s_ho_padded], v[v_move_slice_n_idsho]
v_add_u32 v[v_move_slice_n_in1], 1, v[v_move_slice_n_in1]
v_add_u32 v[v_in_os_base], s[s_in_stride_n], v[v_in_os_base]
v_add_u32 v[v_out_os_base], s[s_out_stride_n], v[v_out_os_base]
Expand Down Expand Up @@ -901,7 +902,7 @@ L_igemm_wrw_gtcx_nchw_fp16_bx16_ex1_bt128x16x16_wt32x8x4_ws2x1_wr1x1_ta1x1x1x8_1
.amdhsa_system_sgpr_workgroup_id_x 1
.amdhsa_system_vgpr_workitem_id 0
.amdhsa_next_free_vgpr 72
.amdhsa_next_free_sgpr 64
.amdhsa_next_free_sgpr 66
.amdhsa_ieee_mode 0
.amdhsa_dx10_clamp 0
.end_amdhsa_kernel
Expand All @@ -912,7 +913,7 @@ amdhsa.version: [ 1, 0 ]
amdhsa.kernels:
- .name: igemm_wrw_gtcx_nchw_fp16_bx16_ex1_bt128x16x16_wt32x8x4_ws2x1_wr1x1_ta1x1x1x8_1x16x1x16_tb1x1x1x1_1x16x1x16_gkgs
.symbol: igemm_wrw_gtcx_nchw_fp16_bx16_ex1_bt128x16x16_wt32x8x4_ws2x1_wr1x1_ta1x1x1x8_1x16x1x16_tb1x1x1x1_1x16x1x16_gkgs.kd
.sgpr_count: 70
.sgpr_count: 72
.vgpr_count: 72
.kernarg_segment_align: 8
.kernarg_segment_size: 96
Expand Down
Loading