From 0c0f3eaf44250d5493c1e48391a13828cf7ab892 Mon Sep 17 00:00:00 2001 From: Pingchuan Ma Date: Thu, 7 Sep 2023 11:16:11 +0000 Subject: [PATCH] Fix type casting issue in mask length calculation --- torchaudio/models/wav2vec2/components.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchaudio/models/wav2vec2/components.py b/torchaudio/models/wav2vec2/components.py index 2717ead9c9..480a6ae509 100644 --- a/torchaudio/models/wav2vec2/components.py +++ b/torchaudio/models/wav2vec2/components.py @@ -894,7 +894,7 @@ def _compute_mask_indices( if mask_type == "static": lengths = torch.full((num_mask,), mask_length) elif mask_type == "uniform": - lengths = torch.randint(mask_other, mask_length * 2 + 1, size=(num_mask,)) + lengths = torch.randint(int(mask_other), mask_length * 2 + 1, size=(num_mask,)) elif mask_type == "normal": lengths = torch.normal(mask_length, mask_other, size=(num_mask,)) lengths = torch.maximum(torch.ones(1), torch.round(lengths)).int()