Skip to content

Commit

Permalink
[NHWC][asm igemm][GFX90a] add support for several NHWC bwd ssd config…
Browse files Browse the repository at this point in the history
…, when k=4x, 2x (#1136)
  • Loading branch information
carlushuang authored Sep 6, 2021
1 parent 142fe49 commit 4e7b1ab
Show file tree
Hide file tree
Showing 186 changed files with 66,554 additions and 9,184 deletions.
50 changes: 6 additions & 44 deletions ...c_fp16_bx0_ex0_bt128x128x32_wt32x32x8_ws1x1_wr2x2_ta1x8x2x1_1x4x1x64_tb1x8x1x2_1x4x1x64.s
100755 → 100644
Original file line number Diff line number Diff line change
Expand Up @@ -23,47 +23,9 @@
* SOFTWARE.
*
*******************************************************************************/
; generated by igemm_codegen.py (2461eab400b8d4378cb16e464421a920037d1b0f)
; generated by igemm_codegen.py (32a41f791dcf0139e95f217f3905939fbbae794c)
;
.macro .mdiv_u32_ss s_quot s_numer s_magic s_shift s_tmp
s_mul_hi_u32 s[\s_tmp], s[\s_magic], s[\s_numer]
s_add_u32 s[\s_tmp], s[\s_tmp], s[\s_numer]
s_lshr_b32 s[\s_quot], s[\s_tmp], s[\s_shift]
.endm

.macro .mdiv_u32_rem_ss s_rem s_quot s_numer s_magic s_shift s_denom s_tmp
.mdiv_u32_ss \s_quot,\s_numer,\s_magic,\s_shift,\s_tmp
s_mul_i32 s[\s_tmp], s[\s_denom], s[\s_quot]
s_sub_u32 s[\s_rem], s[\s_numer], s[\s_tmp]
.endm

.macro .mdiv_u32_vs v_quot v_numer s_magic s_shift v_tmp
v_mul_hi_u32 v[\v_tmp], s[\s_magic], v[\v_numer]
v_add_u32 v[\v_tmp], v[\v_tmp], v[\v_numer]
v_lshrrev_b32 v[\v_quot], s[\s_shift], v[\v_tmp]
.endm

.macro .mdiv_u32_rem_vs v_rem v_quot v_numer s_magic s_shift s_denom v_tmp
.mdiv_u32_vs \v_quot,\v_numer,\s_magic,\s_shift,\v_tmp
v_mul_lo_u32 v[\v_tmp], s[\s_denom], v[\v_quot]
v_sub_u32 v[\v_rem], v[\v_numer], v[\v_tmp]
.endm

.macro .v_clear_acc_c a, num
_a = \a
.rept \num
v_accvgpr_write_b32 a[_a], 0
_a = _a + 1
.endr
.endm

.macro .v_clear_nc vid, num
_v = \vid
.rept \num
v_mov_b32 v[_v], 0
_v = _v + 1
.endr
.endm
.include "igemm_bwd_gtcx2_nhwc_fp16_utils.inc"

;----------------------------------------------------------
; starting of kernel igemm_bwd_gtcx2_nhwc_fp16_bx0_ex0_bt128x128x32_wt32x32x8_ws1x1_wr2x2_ta1x8x2x1_1x4x1x64_tb1x8x1x2_1x4x1x64
Expand Down Expand Up @@ -230,12 +192,12 @@ igemm_bwd_gtcx2_nhwc_fp16_bx0_ex0_bt128x128x32_wt32x32x8_ws1x1_wr2x2_ta1x8x2x1_1
s_load_dwordx2 s[s_magic_2+0:s_magic_2+1], s[s_ka+0:s_ka+1], 0+k_magic_2
s_load_dword s[s_shift_pack_0], s[s_ka+0:s_ka+1], 0+k_shift_pack_0
; out(e, k, nb0, nb1) thread_lengths: 1x8x2x1, cluster_length: 1x4x1x64, k_pack:8
; wei(e, k, c0, c1) thread_length: 1x8x1x2, cluster_length: 1x4x1x64, k_pack:8
v_mov_b32 v[v_tmp], v0
v_and_b32 v[v_out_ik], 3, v[v_tmp]
v_lshlrev_b32 v[v_out_ik], 3, v[v_out_ik]
v_lshrrev_b32 v[v_tmp], 2, v[v_tmp]
v_and_b32 v[v_out_inb], 63, v[v_tmp]
; wei(e, k, c0, c1) thread_length: 1x8x1x2, cluster_length: 1x4x1x64, k_pack:8
v_mov_b32 v[v_tmp], v0
v_and_b32 v[v_wei_ic], 63, v[v_tmp]
v_lshlrev_b32 v[v_wei_ic], 1, v[v_wei_ic]
Expand Down Expand Up @@ -293,9 +255,6 @@ igemm_bwd_gtcx2_nhwc_fp16_bx0_ex0_bt128x128x32_wt32x32x8_ws1x1_wr2x2_ta1x8x2x1_1
.mdiv_u32_rem_vs v_tmp+4,v_out_in,v_tmp+5,s_magic_3,s_shift_m3,s_dim_br,v_tmp
s_bfe_u32 s[s_shift_m2], s[s_shift_pack_0], 0x00080010 ; offset:16, width:8
.mdiv_u32_rem_vs v_out_iwo_list,v_out_iho_list,v_tmp+4,s_magic_2,s_shift_m2,s_wi,v_tmp
v_cmp_gt_u32 vcc, s[s_n], v[v_out_in]
v_cndmask_b32 v[v_tmp], 0, 1, vcc
v_lshlrev_b32 v[v_out_flag_n], 0, v[v_tmp]
s_lshl_b32 s[s_block_gtc_ig], s[s_block_gtc_ig], 1
; calculate wei offset
s_mul_i32 s[s_tmp+2], s[s_k], s[s_wei_stride_k]
Expand Down Expand Up @@ -332,6 +291,9 @@ igemm_bwd_gtcx2_nhwc_fp16_bx0_ex0_bt128x128x32_wt32x32x8_ws1x1_wr2x2_ta1x8x2x1_1
buffer_load_dword v[v_gld_b+7], v[v_wei_os], s[s_p_wei:s_p_wei+3], s[s_wei_offset+5] offen offset:0
s_mov_b64 exec, -1

v_cmp_gt_u32 vcc, s[s_n], v[v_out_in]
v_cndmask_b32 v[v_tmp], 0, 1, vcc
v_lshlrev_b32 v[v_out_flag_n], 0, v[v_tmp]
; calculate output offset
s_mov_b32 s[s_out_offset], 0
s_mul_i32 s[s_tmp], s[s_block_gtc_ig], s[s_k]
Expand Down
52 changes: 8 additions & 44 deletions ...6_bx0_ex0_bt128x128x32_wt32x32x8_ws1x1_wr2x2_ta1x8x2x1_1x4x1x64_tb1x8x1x2_1x4x1x64_gkgs.s
100755 → 100644
Original file line number Diff line number Diff line change
Expand Up @@ -23,47 +23,9 @@
* SOFTWARE.
*
*******************************************************************************/
; generated by igemm_codegen.py (2461eab400b8d4378cb16e464421a920037d1b0f)
; generated by igemm_codegen.py (32a41f791dcf0139e95f217f3905939fbbae794c)
;
.macro .mdiv_u32_ss s_quot s_numer s_magic s_shift s_tmp
s_mul_hi_u32 s[\s_tmp], s[\s_magic], s[\s_numer]
s_add_u32 s[\s_tmp], s[\s_tmp], s[\s_numer]
s_lshr_b32 s[\s_quot], s[\s_tmp], s[\s_shift]
.endm

.macro .mdiv_u32_rem_ss s_rem s_quot s_numer s_magic s_shift s_denom s_tmp
.mdiv_u32_ss \s_quot,\s_numer,\s_magic,\s_shift,\s_tmp
s_mul_i32 s[\s_tmp], s[\s_denom], s[\s_quot]
s_sub_u32 s[\s_rem], s[\s_numer], s[\s_tmp]
.endm

.macro .mdiv_u32_vs v_quot v_numer s_magic s_shift v_tmp
v_mul_hi_u32 v[\v_tmp], s[\s_magic], v[\v_numer]
v_add_u32 v[\v_tmp], v[\v_tmp], v[\v_numer]
v_lshrrev_b32 v[\v_quot], s[\s_shift], v[\v_tmp]
.endm

.macro .mdiv_u32_rem_vs v_rem v_quot v_numer s_magic s_shift s_denom v_tmp
.mdiv_u32_vs \v_quot,\v_numer,\s_magic,\s_shift,\v_tmp
v_mul_lo_u32 v[\v_tmp], s[\s_denom], v[\v_quot]
v_sub_u32 v[\v_rem], v[\v_numer], v[\v_tmp]
.endm

.macro .v_clear_acc_c a, num
_a = \a
.rept \num
v_accvgpr_write_b32 a[_a], 0
_a = _a + 1
.endr
.endm

.macro .v_clear_nc vid, num
_v = \vid
.rept \num
v_mov_b32 v[_v], 0
_v = _v + 1
.endr
.endm
.include "igemm_bwd_gtcx2_nhwc_fp16_utils.inc"

;----------------------------------------------------------
; starting of kernel igemm_bwd_gtcx2_nhwc_fp16_bx0_ex0_bt128x128x32_wt32x32x8_ws1x1_wr2x2_ta1x8x2x1_1x4x1x64_tb1x8x1x2_1x4x1x64_gkgs
Expand Down Expand Up @@ -235,12 +197,12 @@ igemm_bwd_gtcx2_nhwc_fp16_bx0_ex0_bt128x128x32_wt32x32x8_ws1x1_wr2x2_ta1x8x2x1_1
s_load_dword s[s_shift_pack_0], s[s_ka+0:s_ka+1], 0+k_shift_pack_0
s_load_dword s[s_gemmk_split], s[s_ka+0:s_ka+1], 0+k_gemm_k_global_split
; out(e, k, nb0, nb1) thread_lengths: 1x8x2x1, cluster_length: 1x4x1x64, k_pack:8
; wei(e, k, c0, c1) thread_length: 1x8x1x2, cluster_length: 1x4x1x64, k_pack:8
v_mov_b32 v[v_tmp], v0
v_and_b32 v[v_out_ik], 3, v[v_tmp]
v_lshlrev_b32 v[v_out_ik], 3, v[v_out_ik]
v_lshrrev_b32 v[v_tmp], 2, v[v_tmp]
v_and_b32 v[v_out_inb], 63, v[v_tmp]
; wei(e, k, c0, c1) thread_length: 1x8x1x2, cluster_length: 1x4x1x64, k_pack:8
v_mov_b32 v[v_tmp], v0
v_and_b32 v[v_wei_ic], 63, v[v_tmp]
v_lshlrev_b32 v[v_wei_ic], 1, v[v_wei_ic]
Expand Down Expand Up @@ -286,6 +248,8 @@ igemm_bwd_gtcx2_nhwc_fp16_bx0_ex0_bt128x128x32_wt32x32x8_ws1x1_wr2x2_ta1x8x2x1_1
s_and_b32 s[s_block_gtc_ik], s[s_bx], s[s_tmp+3]
s_lshr_b32 s[s_bx], s[s_bx], s[s_gemmk_split]
s_mul_i32 s[s_block_gtc_ik], s[s_block_gtc_ik], s[s_sub_k]
s_cmp_lt_u32 s[s_block_gtc_ik], s[s_k]
s_cbranch_scc0 L_igemm_bwd_gtcx2_nhwc_fp16_bx0_ex0_bt128x128x32_wt32x32x8_ws1x1_wr2x2_ta1x8x2x1_1x4x1x64_tb1x8x1x2_1x4x1x64_gkgs_out
s_lshr_b32 s[s_tmp], s[s_dim_mp], 7
s_lshr_b32 s[s_tmp+1], s[s_dim_np], 7
s_mul_i32 s[0], s[s_tmp+1], s[s_tmp]
Expand All @@ -305,9 +269,6 @@ igemm_bwd_gtcx2_nhwc_fp16_bx0_ex0_bt128x128x32_wt32x32x8_ws1x1_wr2x2_ta1x8x2x1_1
.mdiv_u32_rem_vs v_tmp+4,v_out_in,v_tmp+5,s_magic_3,s_shift_m3,s_dim_br,v_tmp
s_bfe_u32 s[s_shift_m2], s[s_shift_pack_0], 0x00080010 ; offset:16, width:8
.mdiv_u32_rem_vs v_out_iwo_list,v_out_iho_list,v_tmp+4,s_magic_2,s_shift_m2,s_wi,v_tmp
v_cmp_gt_u32 vcc, s[s_n], v[v_out_in]
v_cndmask_b32 v[v_tmp], 0, 1, vcc
v_lshlrev_b32 v[v_out_flag_n], 0, v[v_tmp]
s_lshl_b32 s[s_block_gtc_ig], s[s_block_gtc_ig], 1
; calculate wei offset
s_mul_i32 s[s_tmp+2], s[s_k], s[s_wei_stride_k]
Expand Down Expand Up @@ -345,6 +306,9 @@ igemm_bwd_gtcx2_nhwc_fp16_bx0_ex0_bt128x128x32_wt32x32x8_ws1x1_wr2x2_ta1x8x2x1_1
buffer_load_dword v[v_gld_b+7], v[v_wei_os], s[s_p_wei:s_p_wei+3], s[s_wei_offset+5] offen offset:0
s_mov_b64 exec, -1

v_cmp_gt_u32 vcc, s[s_n], v[v_out_in]
v_cndmask_b32 v[v_tmp], 0, 1, vcc
v_lshlrev_b32 v[v_out_flag_n], 0, v[v_tmp]
; calculate output offset
s_mov_b32 s[s_out_offset], 0
s_mul_i32 s[s_tmp], s[s_block_gtc_ig], s[s_k]
Expand Down
50 changes: 6 additions & 44 deletions ...c_fp16_bx0_ex0_bt128x256x32_wt32x32x8_ws1x2_wr2x2_ta1x8x2x1_1x4x1x64_tb1x8x1x4_1x4x1x64.s
100755 → 100644
Original file line number Diff line number Diff line change
Expand Up @@ -23,47 +23,9 @@
* SOFTWARE.
*
*******************************************************************************/
; generated by igemm_codegen.py (2461eab400b8d4378cb16e464421a920037d1b0f)
; generated by igemm_codegen.py (32a41f791dcf0139e95f217f3905939fbbae794c)
;
.macro .mdiv_u32_ss s_quot s_numer s_magic s_shift s_tmp
s_mul_hi_u32 s[\s_tmp], s[\s_magic], s[\s_numer]
s_add_u32 s[\s_tmp], s[\s_tmp], s[\s_numer]
s_lshr_b32 s[\s_quot], s[\s_tmp], s[\s_shift]
.endm

.macro .mdiv_u32_rem_ss s_rem s_quot s_numer s_magic s_shift s_denom s_tmp
.mdiv_u32_ss \s_quot,\s_numer,\s_magic,\s_shift,\s_tmp
s_mul_i32 s[\s_tmp], s[\s_denom], s[\s_quot]
s_sub_u32 s[\s_rem], s[\s_numer], s[\s_tmp]
.endm

.macro .mdiv_u32_vs v_quot v_numer s_magic s_shift v_tmp
v_mul_hi_u32 v[\v_tmp], s[\s_magic], v[\v_numer]
v_add_u32 v[\v_tmp], v[\v_tmp], v[\v_numer]
v_lshrrev_b32 v[\v_quot], s[\s_shift], v[\v_tmp]
.endm

.macro .mdiv_u32_rem_vs v_rem v_quot v_numer s_magic s_shift s_denom v_tmp
.mdiv_u32_vs \v_quot,\v_numer,\s_magic,\s_shift,\v_tmp
v_mul_lo_u32 v[\v_tmp], s[\s_denom], v[\v_quot]
v_sub_u32 v[\v_rem], v[\v_numer], v[\v_tmp]
.endm

.macro .v_clear_acc_c a, num
_a = \a
.rept \num
v_accvgpr_write_b32 a[_a], 0
_a = _a + 1
.endr
.endm

.macro .v_clear_nc vid, num
_v = \vid
.rept \num
v_mov_b32 v[_v], 0
_v = _v + 1
.endr
.endm
.include "igemm_bwd_gtcx2_nhwc_fp16_utils.inc"

;----------------------------------------------------------
; starting of kernel igemm_bwd_gtcx2_nhwc_fp16_bx0_ex0_bt128x256x32_wt32x32x8_ws1x2_wr2x2_ta1x8x2x1_1x4x1x64_tb1x8x1x4_1x4x1x64
Expand Down Expand Up @@ -230,12 +192,12 @@ igemm_bwd_gtcx2_nhwc_fp16_bx0_ex0_bt128x256x32_wt32x32x8_ws1x2_wr2x2_ta1x8x2x1_1
s_load_dwordx2 s[s_magic_2+0:s_magic_2+1], s[s_ka+0:s_ka+1], 0+k_magic_2
s_load_dword s[s_shift_pack_0], s[s_ka+0:s_ka+1], 0+k_shift_pack_0
; out(e, k, nb0, nb1) thread_lengths: 1x8x2x1, cluster_length: 1x4x1x64, k_pack:8
; wei(e, k, c0, c1) thread_length: 1x8x1x4, cluster_length: 1x4x1x64, k_pack:8
v_mov_b32 v[v_tmp], v0
v_and_b32 v[v_out_ik], 3, v[v_tmp]
v_lshlrev_b32 v[v_out_ik], 3, v[v_out_ik]
v_lshrrev_b32 v[v_tmp], 2, v[v_tmp]
v_and_b32 v[v_out_inb], 63, v[v_tmp]
; wei(e, k, c0, c1) thread_length: 1x8x1x4, cluster_length: 1x4x1x64, k_pack:8
v_mov_b32 v[v_tmp], v0
v_and_b32 v[v_wei_ic], 63, v[v_tmp]
v_lshlrev_b32 v[v_wei_ic], 2, v[v_wei_ic]
Expand Down Expand Up @@ -293,9 +255,6 @@ igemm_bwd_gtcx2_nhwc_fp16_bx0_ex0_bt128x256x32_wt32x32x8_ws1x2_wr2x2_ta1x8x2x1_1
.mdiv_u32_rem_vs v_tmp+4,v_out_in,v_tmp+5,s_magic_3,s_shift_m3,s_dim_br,v_tmp
s_bfe_u32 s[s_shift_m2], s[s_shift_pack_0], 0x00080010 ; offset:16, width:8
.mdiv_u32_rem_vs v_out_iwo_list,v_out_iho_list,v_tmp+4,s_magic_2,s_shift_m2,s_wi,v_tmp
v_cmp_gt_u32 vcc, s[s_n], v[v_out_in]
v_cndmask_b32 v[v_tmp], 0, 1, vcc
v_lshlrev_b32 v[v_out_flag_n], 0, v[v_tmp]
s_lshl_b32 s[s_block_gtc_ig], s[s_block_gtc_ig], 1
; calculate wei offset
s_mul_i32 s[s_tmp+2], s[s_k], s[s_wei_stride_k]
Expand Down Expand Up @@ -332,6 +291,9 @@ igemm_bwd_gtcx2_nhwc_fp16_bx0_ex0_bt128x256x32_wt32x32x8_ws1x2_wr2x2_ta1x8x2x1_1
buffer_load_dwordx2 v[v_gld_b+14:v_gld_b+14+1], v[v_wei_os], s[s_p_wei:s_p_wei+3], s[s_wei_offset+5] offen offset:0
s_mov_b64 exec, -1

v_cmp_gt_u32 vcc, s[s_n], v[v_out_in]
v_cndmask_b32 v[v_tmp], 0, 1, vcc
v_lshlrev_b32 v[v_out_flag_n], 0, v[v_tmp]
; calculate output offset
s_mov_b32 s[s_out_offset], 0
s_mul_i32 s[s_tmp], s[s_block_gtc_ig], s[s_k]
Expand Down
52 changes: 8 additions & 44 deletions ...6_bx0_ex0_bt128x256x32_wt32x32x8_ws1x2_wr2x2_ta1x8x2x1_1x4x1x64_tb1x8x1x4_1x4x1x64_gkgs.s
100755 → 100644
Original file line number Diff line number Diff line change
Expand Up @@ -23,47 +23,9 @@
* SOFTWARE.
*
*******************************************************************************/
; generated by igemm_codegen.py (2461eab400b8d4378cb16e464421a920037d1b0f)
; generated by igemm_codegen.py (32a41f791dcf0139e95f217f3905939fbbae794c)
;
.macro .mdiv_u32_ss s_quot s_numer s_magic s_shift s_tmp
s_mul_hi_u32 s[\s_tmp], s[\s_magic], s[\s_numer]
s_add_u32 s[\s_tmp], s[\s_tmp], s[\s_numer]
s_lshr_b32 s[\s_quot], s[\s_tmp], s[\s_shift]
.endm

.macro .mdiv_u32_rem_ss s_rem s_quot s_numer s_magic s_shift s_denom s_tmp
.mdiv_u32_ss \s_quot,\s_numer,\s_magic,\s_shift,\s_tmp
s_mul_i32 s[\s_tmp], s[\s_denom], s[\s_quot]
s_sub_u32 s[\s_rem], s[\s_numer], s[\s_tmp]
.endm

.macro .mdiv_u32_vs v_quot v_numer s_magic s_shift v_tmp
v_mul_hi_u32 v[\v_tmp], s[\s_magic], v[\v_numer]
v_add_u32 v[\v_tmp], v[\v_tmp], v[\v_numer]
v_lshrrev_b32 v[\v_quot], s[\s_shift], v[\v_tmp]
.endm

.macro .mdiv_u32_rem_vs v_rem v_quot v_numer s_magic s_shift s_denom v_tmp
.mdiv_u32_vs \v_quot,\v_numer,\s_magic,\s_shift,\v_tmp
v_mul_lo_u32 v[\v_tmp], s[\s_denom], v[\v_quot]
v_sub_u32 v[\v_rem], v[\v_numer], v[\v_tmp]
.endm

.macro .v_clear_acc_c a, num
_a = \a
.rept \num
v_accvgpr_write_b32 a[_a], 0
_a = _a + 1
.endr
.endm

.macro .v_clear_nc vid, num
_v = \vid
.rept \num
v_mov_b32 v[_v], 0
_v = _v + 1
.endr
.endm
.include "igemm_bwd_gtcx2_nhwc_fp16_utils.inc"

;----------------------------------------------------------
; starting of kernel igemm_bwd_gtcx2_nhwc_fp16_bx0_ex0_bt128x256x32_wt32x32x8_ws1x2_wr2x2_ta1x8x2x1_1x4x1x64_tb1x8x1x4_1x4x1x64_gkgs
Expand Down Expand Up @@ -235,12 +197,12 @@ igemm_bwd_gtcx2_nhwc_fp16_bx0_ex0_bt128x256x32_wt32x32x8_ws1x2_wr2x2_ta1x8x2x1_1
s_load_dword s[s_shift_pack_0], s[s_ka+0:s_ka+1], 0+k_shift_pack_0
s_load_dword s[s_gemmk_split], s[s_ka+0:s_ka+1], 0+k_gemm_k_global_split
; out(e, k, nb0, nb1) thread_lengths: 1x8x2x1, cluster_length: 1x4x1x64, k_pack:8
; wei(e, k, c0, c1) thread_length: 1x8x1x4, cluster_length: 1x4x1x64, k_pack:8
v_mov_b32 v[v_tmp], v0
v_and_b32 v[v_out_ik], 3, v[v_tmp]
v_lshlrev_b32 v[v_out_ik], 3, v[v_out_ik]
v_lshrrev_b32 v[v_tmp], 2, v[v_tmp]
v_and_b32 v[v_out_inb], 63, v[v_tmp]
; wei(e, k, c0, c1) thread_length: 1x8x1x4, cluster_length: 1x4x1x64, k_pack:8
v_mov_b32 v[v_tmp], v0
v_and_b32 v[v_wei_ic], 63, v[v_tmp]
v_lshlrev_b32 v[v_wei_ic], 2, v[v_wei_ic]
Expand Down Expand Up @@ -286,6 +248,8 @@ igemm_bwd_gtcx2_nhwc_fp16_bx0_ex0_bt128x256x32_wt32x32x8_ws1x2_wr2x2_ta1x8x2x1_1
s_and_b32 s[s_block_gtc_ik], s[s_bx], s[s_tmp+3]
s_lshr_b32 s[s_bx], s[s_bx], s[s_gemmk_split]
s_mul_i32 s[s_block_gtc_ik], s[s_block_gtc_ik], s[s_sub_k]
s_cmp_lt_u32 s[s_block_gtc_ik], s[s_k]
s_cbranch_scc0 L_igemm_bwd_gtcx2_nhwc_fp16_bx0_ex0_bt128x256x32_wt32x32x8_ws1x2_wr2x2_ta1x8x2x1_1x4x1x64_tb1x8x1x4_1x4x1x64_gkgs_out
s_lshr_b32 s[s_tmp], s[s_dim_mp], 7
s_lshr_b32 s[s_tmp+1], s[s_dim_np], 8
s_mul_i32 s[0], s[s_tmp+1], s[s_tmp]
Expand All @@ -305,9 +269,6 @@ igemm_bwd_gtcx2_nhwc_fp16_bx0_ex0_bt128x256x32_wt32x32x8_ws1x2_wr2x2_ta1x8x2x1_1
.mdiv_u32_rem_vs v_tmp+4,v_out_in,v_tmp+5,s_magic_3,s_shift_m3,s_dim_br,v_tmp
s_bfe_u32 s[s_shift_m2], s[s_shift_pack_0], 0x00080010 ; offset:16, width:8
.mdiv_u32_rem_vs v_out_iwo_list,v_out_iho_list,v_tmp+4,s_magic_2,s_shift_m2,s_wi,v_tmp
v_cmp_gt_u32 vcc, s[s_n], v[v_out_in]
v_cndmask_b32 v[v_tmp], 0, 1, vcc
v_lshlrev_b32 v[v_out_flag_n], 0, v[v_tmp]
s_lshl_b32 s[s_block_gtc_ig], s[s_block_gtc_ig], 1
; calculate wei offset
s_mul_i32 s[s_tmp+2], s[s_k], s[s_wei_stride_k]
Expand Down Expand Up @@ -345,6 +306,9 @@ igemm_bwd_gtcx2_nhwc_fp16_bx0_ex0_bt128x256x32_wt32x32x8_ws1x2_wr2x2_ta1x8x2x1_1
buffer_load_dwordx2 v[v_gld_b+14:v_gld_b+14+1], v[v_wei_os], s[s_p_wei:s_p_wei+3], s[s_wei_offset+5] offen offset:0
s_mov_b64 exec, -1

v_cmp_gt_u32 vcc, s[s_n], v[v_out_in]
v_cndmask_b32 v[v_tmp], 0, 1, vcc
v_lshlrev_b32 v[v_out_flag_n], 0, v[v_tmp]
; calculate output offset
s_mov_b32 s[s_out_offset], 0
s_mul_i32 s[s_tmp], s[s_block_gtc_ig], s[s_k]
Expand Down
Loading

0 comments on commit 4e7b1ab

Please sign in to comment.