Skip to content

Commit

Permalink
Fix type casting issue in mask length calculation (#3599)
Browse files Browse the repository at this point in the history
  • Loading branch information
Pingchuan Ma authored Sep 7, 2023
1 parent ede4309 commit e756b23
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion torchaudio/models/wav2vec2/components.py
Original file line number Diff line number Diff line change
Expand Up @@ -894,7 +894,7 @@ def _compute_mask_indices(
if mask_type == "static":
lengths = torch.full((num_mask,), mask_length)
elif mask_type == "uniform":
lengths = torch.randint(mask_other, mask_length * 2 + 1, size=(num_mask,))
lengths = torch.randint(int(mask_other), mask_length * 2 + 1, size=(num_mask,))
elif mask_type == "normal":
lengths = torch.normal(mask_length, mask_other, size=(num_mask,))
lengths = torch.maximum(torch.ones(1), torch.round(lengths)).int()
Expand Down

0 comments on commit e756b23

Please sign in to comment.