diff --git a/src/transformers/models/seamless_m4t/feature_extraction_seamless_m4t.py b/src/transformers/models/seamless_m4t/feature_extraction_seamless_m4t.py
index 0d4879a35ea3..2a83e56fc0bd 100644
--- a/src/transformers/models/seamless_m4t/feature_extraction_seamless_m4t.py
+++ b/src/transformers/models/seamless_m4t/feature_extraction_seamless_m4t.py
@@ -286,14 +286,14 @@ def __call__(
         remainder = num_frames % self.stride
 
         if remainder != 0:
-            input_features = input_features[:, :num_frames, :]
-            attention_mask = attention_mask[:, :num_frames]
+            input_features = input_features[:, : num_frames - remainder, :]
+            attention_mask = attention_mask[:, : num_frames - remainder]
 
         input_features = np.reshape(
             input_features, (batch_size, num_frames // self.stride, num_channels * self.stride)
         )
 
-        indices = np.arange(0, num_frames)
+        indices = np.arange(0, num_frames - remainder)
         attention_mask = attention_mask[:, indices % self.stride == 1]
 
         padded_inputs["input_features"] = input_features
diff --git a/tests/models/seamless_m4t/test_feature_extraction_seamless_m4t.py b/tests/models/seamless_m4t/test_feature_extraction_seamless_m4t.py
index a8fca4b90ba9..8830660c097c 100644
--- a/tests/models/seamless_m4t/test_feature_extraction_seamless_m4t.py
+++ b/tests/models/seamless_m4t/test_feature_extraction_seamless_m4t.py
@@ -171,6 +171,63 @@ def test_call_numpy(self):
         for enc_seq_1, enc_seq_2 in zip(encoded_sequences_1, encoded_sequences_2):
             self.assertTrue(np.allclose(enc_seq_1, enc_seq_2, atol=1e-3))
 
+    def test_call_with_padded_input_not_multiple_of_stride(self):
+        # same as test_call_numpy but with stride=6 and pad_to_multiple_of=8
+        # the inputs from range(800, 1400, 200) pad to a frame count that is a multiple of pad_to_multiple_of
+        # but not of stride, so remainder = num_frames % self.stride is non-zero and must be subtracted
+        stride = 6
+        pad_to_multiple_of = 8
+
+        feature_extractor_args = self.feat_extract_tester.prepare_feat_extract_dict()
+        feature_extractor_args["stride"] = stride
+        feature_extractor = self.feature_extraction_class(**feature_extractor_args)
+
+        speech_inputs = [floats_list((1, x))[0] for x in range(800, 1400, 200)]
+        np_speech_inputs = [np.asarray(speech_input) for speech_input in speech_inputs]
+
+        # Test feature size and attention mask size
+        output = feature_extractor(np_speech_inputs, pad_to_multiple_of=pad_to_multiple_of, return_tensors="np")
+        input_features = output.input_features
+        self.assertTrue(input_features.ndim == 3)
+        self.assertTrue(input_features.shape[0] == 3)
+        self.assertTrue(input_features.shape[-1] == feature_extractor.feature_size * feature_extractor.stride)
+        # same as test_attention_mask
+        attention_mask = output.attention_mask
+        self.assertTrue(attention_mask.ndim == 2)
+        self.assertTrue(attention_mask.shape[0] == 3)
+        self.assertTrue(attention_mask.shape[-1] == input_features.shape[1])
+
+        # Test not batched input
+        encoded_sequences_1 = feature_extractor(
+            speech_inputs[0], pad_to_multiple_of=pad_to_multiple_of, return_tensors="np"
+        ).input_features
+        encoded_sequences_2 = feature_extractor(
+            np_speech_inputs[0], pad_to_multiple_of=pad_to_multiple_of, return_tensors="np"
+        ).input_features
+        self.assertTrue(np.allclose(encoded_sequences_1, encoded_sequences_2, atol=1e-3))
+
+        # Test batched
+        encoded_sequences_1 = feature_extractor(
+            speech_inputs, pad_to_multiple_of=pad_to_multiple_of, return_tensors="np"
+        ).input_features
+        encoded_sequences_2 = feature_extractor(
+            np_speech_inputs, pad_to_multiple_of=pad_to_multiple_of, return_tensors="np"
+        ).input_features
+        for enc_seq_1, enc_seq_2 in zip(encoded_sequences_1, encoded_sequences_2):
+            self.assertTrue(np.allclose(enc_seq_1, enc_seq_2, atol=1e-3))
+
+        # Test 2-D numpy arrays are batched.
+        speech_inputs = [floats_list((1, x))[0] for x in (800, 800, 800)]
+        np_speech_inputs = np.asarray(speech_inputs)
+        encoded_sequences_1 = feature_extractor(
+            speech_inputs, pad_to_multiple_of=pad_to_multiple_of, return_tensors="np"
+        ).input_features
+        encoded_sequences_2 = feature_extractor(
+            np_speech_inputs, pad_to_multiple_of=pad_to_multiple_of, return_tensors="np"
+        ).input_features
+        for enc_seq_1, enc_seq_2 in zip(encoded_sequences_1, encoded_sequences_2):
+            self.assertTrue(np.allclose(enc_seq_1, enc_seq_2, atol=1e-3))
+
     def test_call_without_attention_mask(self):
         feature_extractor_args = self.feat_extract_tester.prepare_feat_extract_dict()
         feature_extractor = self.feature_extraction_class(**feature_extractor_args)