Skip to content

Latest commit

 

History

History
13 lines (8 loc) · 325 Bytes

nozero2fixedshape.md

File metadata and controls

13 lines (8 loc) · 325 Bytes

尝试用torch.where()以及toch.stack()代替torch.nonzero()

# patch_q_cnt shape [b, k]
patch_nonzero_idxs = (patch_q_cnt > 0).nozero(as_tuple=False) 

# ---------------等效做法----------------  
patch_q_idxs_tmp = torch.where(patch_q_cnt > 0)
patch_nonzero_idxs = torch.stack(patch_q_idxs_tmp, dim=1)