Skip to content

Commit

Permalink
[Feature Extractors] Fix kwargs to pre-trained (#30260)
Browse files Browse the repository at this point in the history
fixes
  • Loading branch information
sanchit-gandhi authored and ydshieh committed Apr 23, 2024
1 parent 9900b6b commit 9b5041d
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 4 deletions.
8 changes: 4 additions & 4 deletions src/transformers/feature_extraction_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -566,17 +566,17 @@ def from_dict(cls, feature_extractor_dict: Dict[str, Any], **kwargs) -> PreTrain
"""
return_unused_kwargs = kwargs.pop("return_unused_kwargs", False)

feature_extractor = cls(**feature_extractor_dict)

# Update feature_extractor with kwargs if needed
to_remove = []
for key, value in kwargs.items():
if hasattr(feature_extractor, key):
setattr(feature_extractor, key, value)
if key in feature_extractor_dict:
feature_extractor_dict[key] = value
to_remove.append(key)
for key in to_remove:
kwargs.pop(key, None)

feature_extractor = cls(**feature_extractor_dict)

logger.info(f"Feature extractor {feature_extractor}")
if return_unused_kwargs:
return feature_extractor, kwargs
Expand Down
14 changes: 14 additions & 0 deletions tests/models/whisper/test_feature_extraction_whisper.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,20 @@ def test_feat_extract_to_json_file(self):
self.assertTrue(np.allclose(mel_1, mel_2))
self.assertEqual(dict_first, dict_second)

def test_feat_extract_from_pretrained_kwargs(self):
feat_extract_first = self.feature_extraction_class(**self.feat_extract_dict)

with tempfile.TemporaryDirectory() as tmpdirname:
saved_file = feat_extract_first.save_pretrained(tmpdirname)[0]
check_json_file_has_correct_format(saved_file)
feat_extract_second = self.feature_extraction_class.from_pretrained(
tmpdirname, feature_size=2 * self.feat_extract_dict["feature_size"]
)

mel_1 = feat_extract_first.mel_filters
mel_2 = feat_extract_second.mel_filters
self.assertTrue(2 * mel_1.shape[1] == mel_2.shape[1])

def test_call(self):
# Tests that all call wrap to encode_plus and batch_encode_plus
feature_extractor = self.feature_extraction_class(**self.feat_extract_tester.prepare_feat_extract_dict())
Expand Down

0 comments on commit 9b5041d

Please sign in to comment.