diff --git a/src/transformers/models/vit_mae/modeling_vit_mae.py b/src/transformers/models/vit_mae/modeling_vit_mae.py
index e7afe0ca21ac99..f434efbe3e5af5 100755
--- a/src/transformers/models/vit_mae/modeling_vit_mae.py
+++ b/src/transformers/models/vit_mae/modeling_vit_mae.py
@@ -241,18 +241,18 @@ def random_masking(self, sequence, noise=None):
             noise = torch.rand(batch_size, seq_length, device=sequence.device)  # noise in [0, 1]
 
         # sort noise for each sample
-        ids_shuffle = torch.argsort(noise, dim=1)  # ascend: small is keep, large is remove
-        ids_restore = torch.argsort(ids_shuffle, dim=1)
+        ids_shuffle = torch.argsort(noise, dim=1).to(sequence.device)  # ascend: small is keep, large is remove
+        ids_restore = torch.argsort(ids_shuffle, dim=1).to(sequence.device)
 
         # keep the first subset
-        ids_keep = ids_shuffle[:, :len_keep].to(sequence.device)
+        ids_keep = ids_shuffle[:, :len_keep]
         sequence_unmasked = torch.gather(sequence, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, dim))
 
         # generate the binary mask: 0 is keep, 1 is remove
         mask = torch.ones([batch_size, seq_length], device=sequence.device)
         mask[:, :len_keep] = 0
         # unshuffle to get the binary mask
-        mask = torch.gather(mask, dim=1, index=ids_restore)
+        mask = torch.gather(mask, dim=1, index=ids_restore)
 
         return sequence_unmasked, mask, ids_restore
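The patch moves the `.to(sequence.device)` calls from the downstream index tensors onto the `torch.argsort` outputs, so that `ids_keep`, `ids_restore`, and every `torch.gather` that consumes them already live on the same device as `sequence`, even when a caller passes in a `noise` tensor from a different device. Below is a minimal standalone sketch of the patched logic, assuming a free function with a plain `mask_ratio` argument in place of the model's `self.config.mask_ratio`; it is an illustration of the device-placement fix, not the library's actual method.

import torch

def random_masking(sequence, mask_ratio=0.75, noise=None):
    """Per-sample random masking via argsort-ed noise (MAE-style sketch)."""
    batch_size, seq_length, dim = sequence.shape
    len_keep = int(seq_length * (1 - mask_ratio))

    if noise is None:
        noise = torch.rand(batch_size, seq_length, device=sequence.device)  # noise in [0, 1]

    # As in the patch: move both index tensors to sequence.device up front,
    # so every later gather stays on one device even if `noise` came from the CPU.
    ids_shuffle = torch.argsort(noise, dim=1).to(sequence.device)  # ascend: small is keep
    ids_restore = torch.argsort(ids_shuffle, dim=1).to(sequence.device)

    # keep the first len_keep patches per sample
    ids_keep = ids_shuffle[:, :len_keep]
    sequence_unmasked = torch.gather(sequence, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, dim))

    # binary mask: 0 is keep, 1 is remove, unshuffled back to the original patch order
    mask = torch.ones([batch_size, seq_length], device=sequence.device)
    mask[:, :len_keep] = 0
    mask = torch.gather(mask, dim=1, index=ids_restore)

    return sequence_unmasked, mask, ids_restore

# Quick check with an externally supplied (CPU) noise tensor: the indices
# follow sequence.device, so no cross-device gather can occur.
seq = torch.randn(2, 16, 8)
unmasked, mask, ids_restore = random_masking(seq, noise=torch.rand(2, 16))
print(unmasked.shape, mask.shape, ids_restore.shape)  # (2, 4, 8), (2, 16), (2, 16)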