Skip to content

Commit

Permalink
Fix type casting issue in mask length calculation
Browse files Browse the repository at this point in the history
  • Loading branch information
Pingchuan Ma committed Sep 7, 2023
1 parent ede4309 commit 0c0f3ea
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion torchaudio/models/wav2vec2/components.py
Original file line number Diff line number Diff line change
Expand Up @@ -894,7 +894,7 @@ def _compute_mask_indices(
if mask_type == "static":
lengths = torch.full((num_mask,), mask_length)
elif mask_type == "uniform":
lengths = torch.randint(mask_other, mask_length * 2 + 1, size=(num_mask,))
lengths = torch.randint(int(mask_other), mask_length * 2 + 1, size=(num_mask,))
elif mask_type == "normal":
lengths = torch.normal(mask_length, mask_other, size=(num_mask,))
lengths = torch.maximum(torch.ones(1), torch.round(lengths)).int()
Expand Down

0 comments on commit 0c0f3ea

Please sign in to comment.