diff --git a/src/transformers/feature_extraction_utils.py b/src/transformers/feature_extraction_utils.py index b0df39e1642b..12fef5103d85 100644 --- a/src/transformers/feature_extraction_utils.py +++ b/src/transformers/feature_extraction_utils.py @@ -566,17 +566,17 @@ def from_dict(cls, feature_extractor_dict: Dict[str, Any], **kwargs) -> PreTrain """ return_unused_kwargs = kwargs.pop("return_unused_kwargs", False) - feature_extractor = cls(**feature_extractor_dict) - # Update feature_extractor with kwargs if needed to_remove = [] for key, value in kwargs.items(): - if hasattr(feature_extractor, key): - setattr(feature_extractor, key, value) + if key in feature_extractor_dict: + feature_extractor_dict[key] = value to_remove.append(key) for key in to_remove: kwargs.pop(key, None) + feature_extractor = cls(**feature_extractor_dict) + logger.info(f"Feature extractor {feature_extractor}") if return_unused_kwargs: return feature_extractor, kwargs diff --git a/tests/models/whisper/test_feature_extraction_whisper.py b/tests/models/whisper/test_feature_extraction_whisper.py index 8b1e25927e33..a8295542f4e3 100644 --- a/tests/models/whisper/test_feature_extraction_whisper.py +++ b/tests/models/whisper/test_feature_extraction_whisper.py @@ -142,6 +142,20 @@ def test_feat_extract_to_json_file(self): self.assertTrue(np.allclose(mel_1, mel_2)) self.assertEqual(dict_first, dict_second) + def test_feat_extract_from_pretrained_kwargs(self): + feat_extract_first = self.feature_extraction_class(**self.feat_extract_dict) + + with tempfile.TemporaryDirectory() as tmpdirname: + saved_file = feat_extract_first.save_pretrained(tmpdirname)[0] + check_json_file_has_correct_format(saved_file) + feat_extract_second = self.feature_extraction_class.from_pretrained( + tmpdirname, feature_size=2 * self.feat_extract_dict["feature_size"] + ) + + mel_1 = feat_extract_first.mel_filters + mel_2 = feat_extract_second.mel_filters + self.assertTrue(2 * mel_1.shape[1] == mel_2.shape[1]) + def test_call(self): # Tests that all call wrap to encode_plus and batch_encode_plus feature_extractor = self.feature_extraction_class(**self.feat_extract_tester.prepare_feat_extract_dict())