Allow dict input for audio classification pipeline #23445
Conversation
The documentation is not available anymore as the PR was closed or merged.
Force-pushed from 3eb2035 to ae50bbb.
Thanks for your PR. Since you didn't put a description, I don't really understand what this tries to solve while making the API more complex.
_inputs = inputs.pop("raw", None)
if _inputs is None:
    # Remove path which will not be used from `datasets`.
I have no idea what this means.
- `dict` form can be used to pass raw audio sampled at arbitrary `sampling_rate` and let this
  pipeline do the resampling. The dict must be in the format `{"sampling_rate": int, "raw": np.array}`
  or `{"sampling_rate": int, "array": np.array}`, where the key `"raw"` or `"array"` is used to
  denote the raw audio waveform.
Why accept both `raw` and `array`? Seems very brittle as an API. Also, why add a new argument type instead of just accepting a new `sampling_rate` argument?
Copied one-for-one from:

inputs (`np.ndarray` or `bytes` or `str` or `dict`):

Originally, the ASR pipeline only accepted the `raw` key for the input waveform, but this was updated to accept both `raw` and `array` to bring the pipeline into alignment with `datasets`, where the 1-d audio arrays go under the dict key `array` (see the comment below and the motivation for this consistency in #20414 (comment)).
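For reference, a minimal sketch of the `datasets` format being aligned with, using the dummy dataset that comes up later in this thread (the inspected keys are what `datasets` returns for an audio column):

```python
from datasets import load_dataset

ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
sample = ds[0]["audio"]

# `datasets` stores the 1-d waveform under "array", alongside "sampling_rate" (and "path").
print(sample["sampling_rate"], sample["array"].shape)
```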
Saw the raw key on the ASR pipeline was kept for backward compatibility. Do we really need to introduce it there?
Indeed! I'm agnostic - I kept it in to have consistency across the pipeline classes (e.g. if a user typically passes the `raw` key in the ASR pipeline, then they would expect it to work for the audio classification pipeline), but can simplify it to just accept the `array` key if we don't mind about this.
Let's see what @Narsil thinks.
Gently pinging @Narsil - would be nice to have this ready in `transformers` for the next release (unblocks huggingface/audio-transformers-course#25).
Hey, sorry for this.
Usually I'm kind of against accepting widely different types, but this case is different since it's about our ecosystem and making `datasets` + `pipeline` work together more nicely.
If we're all happy I'll keep it as is for now, and we can explore a joint refactor of the ASR + audio classification pipelines in the future?
Yup
inputs = _inputs
if in_sampling_rate != self.feature_extractor.sampling_rate:
    import torch
    from torchaudio import functional as F
This adds a soft dep on torchaudio which is not necessary otherwise, no? Might be worth detecting if it's available and throwing a helpful error message?
Also copied from:

from torchaudio import functional as F

Will update with an error message and propagate the changes here 👍
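A minimal sketch of what such a guard might look like (the error message wording and the exact availability check are assumptions; the real fix may use the import utilities that `transformers` already ships instead):

```python
import importlib.util

if in_sampling_rate != self.feature_extractor.sampling_rate:
    # Fail early with an actionable message if the optional dependency is missing.
    if importlib.util.find_spec("torchaudio") is None:
        raise ImportError(
            "torchaudio is required to resample the input audio to "
            f"{self.feature_extractor.sampling_rate} Hz. Install it with `pip install torchaudio`, "
            "or resample the audio yourself before calling the pipeline."
        )
    import torch
    from torchaudio import functional as F

    inputs = F.resample(
        torch.from_numpy(inputs), in_sampling_rate, self.feature_extractor.sampling_rate
    ).numpy()
```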
Apologies @sgugger! To clarify, the changes in this PR are copied one-for-one from the input arguments in https://github.com/huggingface/transformers/blob/main/src/transformers/pipelines/automatic_speech_recognition.py

Essentially, the PR allows users to input a dictionary of inputs to the pipeline. This aligns the pipeline with `datasets`:

from datasets import load_dataset

librispeech = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
librispeech[0]["audio"]

Output: the audio dict returned by `datasets`, with the waveform under the `"array"` key and the sampling rate under `"sampling_rate"`.

This PR enables that dict to be passed directly to the pipeline, in the same way that we do for the ASR pipeline:

pred_labels = pipe(librispeech[0]["audio"])

If there are any API decisions you feel require changing, I'd be happy to update these in the original code before propagating to this file.
Thanks for the explanations!
- `dict` form can be used to pass raw audio sampled at arbitrary `sampling_rate` and let this
  pipeline do the resampling. The dict must be in the format `{"sampling_rate": int, "raw": np.array}`
  or `{"sampling_rate": int, "array": np.array}`, where the key `"raw"` or `"array"` is used to
  denote the raw audio waveform.
> Saw the raw key on the ASR pipeline was kept for backward compatibility. Do we really need to introduce it there?

I think what you're trying to do is already supported, but the sampling rate needs to be in the same dict as the array (both are needed to represent a single audio). That being said, the errors raised when misusing this feature could probably be largely improved (to guide users towards the correct form).
_inputs = inputs.pop("array", None) | ||
in_sampling_rate = inputs.pop("sampling_rate") | ||
inputs = _inputs | ||
if in_sampling_rate != self.feature_extractor.sampling_rate: |
Are you sure about that @Narsil? It is indeed the case that the ASR pipeline respects the `sampling_rate` argument, but the audio classification pipeline does not. Note that the resampling operation is new; there is currently no sampling rate check or resampling performed. This PR adds it.
Ohh thanks for the ping, not sure how I missed notifications several times here. You're indeed correct, `audio-classification` didn't have support.
LGTM (just a small fix on the error message).
Sorry for the late review, I'm not sure how I missed those notifications.
Overall there might be room to refactor and abstract this for both pipelines so that we can easily reuse it later, but it's good enough for now.
_inputs = inputs.pop("array", None) | ||
in_sampling_rate = inputs.pop("sampling_rate") | ||
inputs = _inputs | ||
if in_sampling_rate != self.feature_extractor.sampling_rate: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ohh thanks for the ping, not sure how I missed notifications several times here. You're indeed correct, audio-classification
didn't have support.
- `dict` form can be used to pass raw audio sampled at arbitrary `sampling_rate` and let this | ||
pipeline do the resampling. The dict must be either be in the format `{"sampling_rate": int, | ||
"raw": np.array}`, or `{"sampling_rate": int, "array": np.array}`, where the key `"raw"` or | ||
`"array"` is used to denote the raw audio waveform. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hey, sorry for this.
Usually I'm kind of against accepting widely different types, this case is different since it's about our ecosystem and making datasets
+ pipeline
work together nicer.
Co-authored-by: Sylvain <[email protected]>
Co-authored-by: Nicolas Patry <[email protected]>
Force-pushed from 4b8c7cd to fc5fa01.
What does this PR do?
Allow dictionary inputs (`{"array": np.ndarray, "sampling_rate": int}` or `{"raw": np.ndarray, "sampling_rate": int}`, as returned by `datasets`) for the audio classification pipeline, resampling the audio when its sampling rate differs from the feature extractor's.
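A minimal usage sketch of what this enables (the checkpoint name is illustrative; any audio classification model works, and the predicted labels depend on the model):

```python
from datasets import load_dataset
from transformers import pipeline

# Illustrative checkpoint; substitute any audio classification model.
pipe = pipeline("audio-classification", model="superb/wav2vec2-base-superb-ks")

librispeech = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")

# The `datasets` audio dict can now be passed directly; the pipeline resamples
# if its sampling rate differs from the feature extractor's.
pred_labels = pipe(librispeech[0]["audio"])
```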
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.