
Add AutoFeatureExtractor support to Wav2Vec2ProcessorWithLM #28706

Merged

Conversation

ylacombe (Contributor)

What does this PR do?

Using an n-gram language model on top of Wav2Vec2-based models is an easy way to get a performance boost. Until now, Wav2Vec2ProcessorWithLM was only compatible with Wav2Vec2FeatureExtractor.

W2V2-Bert could also benefit from this boost, but needs its feature extractor to be compatible with Wav2Vec2ProcessorWithLM as well.

The easiest way to do this is to use AutoFeatureExtractor instead of Wav2Vec2FeatureExtractor in the code, since the processor only changes the tokenizer behaviour.
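
As a minimal sketch of what this enables (the checkpoint path below is a placeholder, not a real model):

```python
from transformers import Wav2Vec2ProcessorWithLM

# With feature_extractor_class = "AutoFeatureExtractor", the processor loads
# whichever feature extractor the checkpoint declares (e.g. the
# SeamlessM4TFeatureExtractor used by W2V2-Bert) instead of only accepting
# Wav2Vec2FeatureExtractor. "path/to/w2v2-bert-with-lm" is a placeholder.
processor = Wav2Vec2ProcessorWithLM.from_pretrained("path/to/w2v2-bert-with-lm")
print(type(processor.feature_extractor).__name__)  # e.g. SeamlessM4TFeatureExtractor
```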

cc @sanchit-gandhi @amyeroberts

@sanchit-gandhi (Contributor) left a comment:

In principle the changes LGTM - just a question of whether we should use the AutoFeatureExtractor class, or specify the feature extractors explicitly. Perhaps @amyeroberts can also lend her opinion here!

```diff
         tokenizer ([`Wav2Vec2CTCTokenizer`]):
             An instance of [`Wav2Vec2CTCTokenizer`]. The tokenizer is a required input.
         decoder (`pyctcdecode.BeamSearchDecoderCTC`):
             An instance of [`pyctcdecode.BeamSearchDecoderCTC`]. The decoder is a required input.
     """

-    feature_extractor_class = "Wav2Vec2FeatureExtractor"
+    feature_extractor_class = "AutoFeatureExtractor"
```
@sanchit-gandhi (Contributor) commented on Jan 26, 2024:
Perhaps explicitly defining the valid feature extractor classes is less vulnerable to unexpected behaviour? Or do you think we should accept all feature extractors as valid attributes of the Wav2Vec2 processor? I would be in favour of the former.

Suggested change:
```diff
-    feature_extractor_class = "AutoFeatureExtractor"
+    feature_extractor_class = ("Wav2Vec2FeatureExtractor", "SeamlessM4TFeatureExtractor")
```

ylacombe (Contributor, Author) replied:

Because of the tokenizer behaviour, when passing a tuple to ..._class, the second class is chosen by default if use_fast=True (and the first one otherwise). So in that case, it won't work.

This is the first time we've had to face a processor with two possible classes. I'd be happy to modify the ProcessorMixin behaviour, but I believe passing AutoFeatureExtractor is an easier workaround, especially since I don't really see a case in which a user would pass another FE yet!
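
For illustration, a toy reduction of the tuple behaviour described above (not the actual ProcessorMixin code, which may differ in detail):

```python
# Toy reduction: a (slow, fast) class tuple is resolved purely by the
# use_fast flag, regardless of which class the checkpoint actually uses.
def resolve_class(class_spec, use_fast=True):
    if isinstance(class_spec, tuple):
        slow_class, fast_class = class_spec
        return fast_class if use_fast else slow_class
    return class_spec

# The suggested tuple would therefore always pick the second entry:
print(resolve_class(("Wav2Vec2FeatureExtractor", "SeamlessM4TFeatureExtractor")))
# -> SeamlessM4TFeatureExtractor
```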

cc @amyeroberts and @ArthurZucker, what do you think?

A Contributor replied:

I see - good with me to use AutoFeatureExtractor in this case, since changing the ProcessorMixin would add unnecessary complexity to handle this unique example.

A Collaborator commented:

This isn't the first processor to handle the case with two objects -- that's Instruct BLIP -- although it's slightly different, as that has two tokenizers rather than accepting one of two tokenizers.

I'm not a fan of using AutoFeatureExtractor here, as the processor doesn't accept just any feature extractor, and so the type is misleading. Ideally the processor should be adapted so that, e.g. for tokenizers, it can accept a list of lists, e.g. [[TokenizerA, TokenizerAFast], [TokenizerB, TokenizerBFast]], and accept lists of objects for the other types, e.g. [FeatureExtractorA, FeatureExtractorB, FeatureExtractorC].

At the very least, the processor doc should indicate that only these two types are accepted and that type verification happens in the init.
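
For concreteness, a hypothetical sketch of the attribute shapes proposed above (illustrative names only, none of this is actual transformers API):

```python
# Paired slow/fast tokenizer alternatives, as a list of lists:
tokenizer_class = [["TokenizerA", "TokenizerAFast"], ["TokenizerB", "TokenizerBFast"]]
# A flat list of accepted feature extractor classes:
feature_extractor_class = ["FeatureExtractorA", "FeatureExtractorB", "FeatureExtractorC"]
```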

@NielsRogge (Contributor) commented on Mar 3, 2024:

I am hitting the same issue in #29414; however, I don't really see the problem of using auto classes in the processors, I would just like it to be consistent (use auto classes for both the feature extractor and the tokenizer, or not at all). If auto classes are used, it would indeed be great to explicitly check for the supported classes.

As it stands, this PR would use AutoFeatureExtractor for the feature extractor but a specific tokenizer class for the tokenizer. Hence I've opened #29414 to make the discussion more general, feel free to comment there!

@ArthurZucker (Collaborator) left a comment:

LGTM, let's make sure we indeed check for the supported feature extractors:

  • either we check later that the class is in the list of supported classes: preferred for a quick fix
  • or we update the logic of feature_extractor_class to support multiple classes: preferred if we don't want to support everything

@ylacombe (Contributor, Author) left a comment:

Hey @amyeroberts, @ArthurZucker, @sanchit-gandhi, thanks for your reviews!
I've decided to go with the quickest way for now, WDYT?

Comment on lines +96 to +100:

```python
if feature_extractor.__class__.__name__ not in ["Wav2Vec2FeatureExtractor", "SeamlessM4TFeatureExtractor"]:
    raise ValueError(
        f"`feature_extractor` has to be of type `Wav2Vec2FeatureExtractor` or `SeamlessM4TFeatureExtractor`, but is {type(feature_extractor)}"
    )
```

ylacombe (Contributor, Author) commented:

Not sure if this is the most transformers-like way of checking this, WDYT?

A Collaborator replied:

LGTM!

@NielsRogge (Contributor) commented on Mar 3, 2024:

Perhaps it would be more general to add this verification to the ProcessorMixin class rather than manually defining this ValueError, as we have this case for various models. The attribute could be called supported_feature_extractors, for instance, and the mixin class would then automatically raise a ValueError in case the class is not supported; see the sketch below.

Alternatively, this could be checked automatically based on feature_extraction_auto.py, where wav2vec2_with_lm would map to both Wav2Vec2FeatureExtractor and SeamlessM4TFeatureExtractor.
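
A hypothetical sketch of that supported_feature_extractors idea (illustrative only; this is not the actual transformers implementation, and the attribute name is an assumption):

```python
# Hypothetical extension of ProcessorMixin, as suggested above.
class ProcessorMixin:
    # Subclasses may set this to restrict the accepted feature extractor
    # classes; None means no restriction.
    supported_feature_extractors = None

    def _validate_feature_extractor(self, feature_extractor):
        supported = self.supported_feature_extractors
        if supported is not None and feature_extractor.__class__.__name__ not in supported:
            raise ValueError(
                f"`feature_extractor` must be one of {supported}, but is {type(feature_extractor)}"
            )

class Wav2Vec2ProcessorWithLM(ProcessorMixin):
    supported_feature_extractors = ["Wav2Vec2FeatureExtractor", "SeamlessM4TFeatureExtractor"]
```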

A Collaborator replied:

I think auto classes are only used for two processors at the moment (neither merged)? If we enable using auto classes and then add validation in the ProcessorMixin, we're just doing a more roundabout alternative to simply being able to specify more than one class for feature_extractor_class, which I think would be preferable.

@HuggingFaceDocBuilderDev commented:

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@amyeroberts (Collaborator) left a comment:

Thanks for iterating on this!

```diff
@@ -157,6 +157,27 @@ def test_feature_extractor(self):
         for key in input_feat_extract.keys():
             self.assertAlmostEqual(input_feat_extract[key].sum(), input_processor[key].sum(), delta=1e-2)

+    def test_another_feature_extractor(self):
```
@amyeroberts (Collaborator) commented:
We should also test that a feature extractor that isn't SeamlessM4T or Wav2Vec2 raises an error on instantiation.
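
As a self-contained sketch of the shape such a test could take (the allow-list check is reproduced inline here, since wiring up a real tokenizer and pyctcdecode decoder is beyond this illustration; the test actually added in the PR differs):

```python
import unittest

# Inline reproduction of the processor's allow-list check, for illustration only.
ALLOWED = ["Wav2Vec2FeatureExtractor", "SeamlessM4TFeatureExtractor"]

def check_feature_extractor(feature_extractor):
    if feature_extractor.__class__.__name__ not in ALLOWED:
        raise ValueError(
            f"`feature_extractor` has to be of type `Wav2Vec2FeatureExtractor` "
            f"or `SeamlessM4TFeatureExtractor`, but is {type(feature_extractor)}"
        )

class WhisperFeatureExtractor:
    """Stand-in for a feature extractor the processor does not support."""

class GuardTest(unittest.TestCase):
    def test_unsupported_feature_extractor_raises(self):
        with self.assertRaises(ValueError):
            check_feature_extractor(WhisperFeatureExtractor())

if __name__ == "__main__":
    unittest.main()
```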


@amyeroberts mentioned this pull request on Feb 27, 2024.
github-actions bot commented on Mar 1, 2024:

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

@ylacombe reopened this on May 20, 2024.

ylacombe (Contributor, Author) commented:
Sorry for the long wait, I've added the final test you asked for, @amyeroberts! Still good to merge (once all's green)?

@ylacombe merged commit e670870 into huggingface:main on May 20, 2024. 22 of 24 checks passed.
itazap pushed a commit that referenced this pull request on May 24, 2024:

* Add AutoFeatureExtractor support to Wav2Vec2ProcessorWithLM
* update with a type filter
* add raises error test
* fix added test

zucchini-nlp pushed a commit to zucchini-nlp/transformers that referenced this pull request on Jun 11, 2024 (huggingface#28706):

* Add AutoFeatureExtractor support to Wav2Vec2ProcessorWithLM
* update with a type filter
* add raises error test
* fix added test