
Allow dict input for audio classification pipeline #23445

Merged · 9 commits · Jun 23, 2023

Conversation

sanchit-gandhi
Contributor

What does this PR do?

Allow dictionary inputs for the audio classification pipeline

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a GitHub issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@HuggingFaceDocBuilderDev

HuggingFaceDocBuilderDev commented May 18, 2023

The documentation is not available anymore as the PR was closed or merged.

@sanchit-gandhi sanchit-gandhi force-pushed the audio-class-pipeline branch from 3eb2035 to ae50bbb Compare May 31, 2023 14:16
@sanchit-gandhi sanchit-gandhi requested review from Narsil and sgugger June 1, 2023 15:36
Collaborator

@sgugger sgugger left a comment

Thanks for your PR. Since you didn't put a description, I don't really understand what this tries to solve while making the API more complex.


_inputs = inputs.pop("raw", None)
if _inputs is None:
# Remove path which will not be used from `datasets`.
Collaborator

I have no idea what this means.

Comment on lines +121 to +124
- `dict` form can be used to pass raw audio sampled at an arbitrary `sampling_rate` and let this
pipeline do the resampling. The dict must be in either the format `{"sampling_rate": int,
"raw": np.array}` or `{"sampling_rate": int, "array": np.array}`, where the key `"raw"` or
`"array"` is used to denote the raw audio waveform.
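A minimal sketch of how a dict in either form could be unpacked (illustrative only; the helper name is hypothetical, and this is not the actual pipeline code):

```python
import numpy as np

def unpack_audio_dict(inputs: dict):
    """Extract the waveform and sampling rate from a datasets-style dict.

    Accepts either the "raw" or the "array" key for the waveform,
    mirroring the docstring above. Sketch only, not pipeline code.
    """
    inputs = dict(inputs)  # don't mutate the caller's dict
    inputs.pop("path", None)  # deprecated `datasets` key, unused here
    waveform = inputs.pop("raw", None)
    if waveform is None:
        # Fall back to the key name used by `datasets`.
        waveform = inputs.pop("array", None)
    if waveform is None:
        raise ValueError('dict input must contain a "raw" or "array" key')
    sampling_rate = inputs.pop("sampling_rate")
    return waveform, sampling_rate

waveform, sr = unpack_audio_dict(
    {"array": np.zeros(16000, dtype=np.float32), "sampling_rate": 16000}
)
```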
Collaborator

Why accept both raw and array? Seems very brittle as an API. Also why add a new argument type instead of just accepting a new sampling_rate argument?

Contributor Author

@sanchit-gandhi sanchit-gandhi Jun 2, 2023

Copied one-for-one from:

inputs (`np.ndarray` or `bytes` or `str` or `dict`):

Originally, the ASR pipeline only accepted the `raw` key for the input waveform, but this was updated to accept both `raw` and `array` to bring the pipeline into alignment with datasets, where the 1-d audio arrays go under the dict key `array` (see comment below and motivations for this consistency in #20414 (comment))

Collaborator

Saw the raw key on the ASR pipeline was kept for backward compatibility. Do we really need to introduce it there?

Contributor Author

Indeed! I'm agnostic. I kept it in for consistency across the pipeline classes (e.g. if a user typically passes the `raw` key to the ASR pipeline, they would expect it to work for the audio classification pipeline), but I can simplify it to accept only the `array` key if we don't mind losing this.

Collaborator

Let's see what @Narsil thinks.

Contributor Author

Gently pinging @Narsil - would be nice to have this ready in transformers for the next release (unblocks huggingface/audio-transformers-course#25)

Contributor

@Narsil Narsil Jun 13, 2023

Hey, sorry for this.

Usually I'm kind of against accepting widely different types, but this case is different, since it's about our ecosystem and making datasets + pipeline work together more nicely.

Contributor Author

If we're all happy, I'll keep it as is for now, and we can explore a joint refactor of the ASR + audio classification pipelines in the future?

Contributor

Yup

inputs = _inputs
if in_sampling_rate != self.feature_extractor.sampling_rate:
import torch
from torchaudio import functional as F
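The actual code above resamples with torchaudio; as a rough stand-in, the effect of that resampling step can be sketched with plain numpy linear interpolation (illustrative only, not what the pipeline does, and lower quality than `torchaudio.functional.resample`):

```python
import numpy as np

def naive_resample(waveform: np.ndarray, in_sr: int, out_sr: int) -> np.ndarray:
    """Crude linear-interpolation resampling, a stand-in sketch for
    torchaudio.functional.resample used in the actual code."""
    if in_sr == out_sr:
        return waveform
    duration = len(waveform) / in_sr
    n_out = int(round(duration * out_sr))
    old_t = np.arange(len(waveform)) / in_sr  # original sample times (s)
    new_t = np.arange(n_out) / out_sr         # target sample times (s)
    return np.interp(new_t, old_t, waveform)

x = np.ones(8000, dtype=np.float32)  # 1 s of audio at 8 kHz
y = naive_resample(x, 8000, 16000)   # upsampled to 16 kHz
```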
Collaborator

This adds a soft dep on torchaudio which is not necessary otherwise, no? Might be worth detecting if it's available and throwing a helpful error message?

Contributor Author

Also copied from

from torchaudio import functional as F

Will update with an error message and propagate the changes here 👍
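The kind of guard discussed here could look something like the following sketch (the helper name is hypothetical; transformers has its own availability utilities for optional dependencies):

```python
import importlib.util

def require_module(module_name: str, feature: str) -> None:
    """Raise a helpful ImportError if an optional dependency is missing.

    Hypothetical helper illustrating the guard discussed above, so that a
    missing soft dependency like torchaudio fails with a clear message.
    """
    if importlib.util.find_spec(module_name) is None:
        raise ImportError(
            f"{module_name} is required for {feature} but is not installed. "
            f"Install it with `pip install {module_name}`."
        )

# Example: check before resampling dict inputs.
require_module("json", "a stdlib sanity check")  # present, so no error
```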

Contributor Author

Updated for the ASR pipeline in #23953 and this PR in 06751d4

@sanchit-gandhi
Contributor Author

sanchit-gandhi commented Jun 2, 2023

Apologies @sgugger! To clarify, the changes in this PR are one-for-one copied from the input arguments in https://github.com/huggingface/transformers/blob/main/src/transformers/pipelines/automatic_speech_recognition.py

Essentially, the PR allows users to pass a dict of inputs to the pipeline. This aligns the pipeline with datasets, where the audio column returns a dict with `array` (the 1-d audio array) and `sampling_rate` (the sampling rate of the audio):

from datasets import load_dataset

librispeech = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
librispeech[0]["audio"]

Output:

{'path': '/Users/sanchitgandhi/.cache/huggingface/datasets/downloads/extracted/aad76e6f21870761d7a8b9b34436f6f8db846546c68cb2d9388598d7a164fa4b/dev_clean/1272/128104/1272-128104-0000.flac',
 'array': array([0.00238037, 0.0020752 , 0.00198364, ..., 0.00042725, 0.00057983,
        0.0010376 ]),
 'sampling_rate': 16000}

(The `path` column is deprecated and no longer required, but retained for backwards compatibility. This is what removing `path` refers to in the PR.)
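Dropping the unused key is plain dict handling, along these lines (the sample values here are made up for the sketch):

```python
# Illustrative datasets-style sample (values fabricated for the sketch)
sample = {
    "path": "/some/cache/dir/audio.flac",
    "array": [0.002, 0.003, 0.001],
    "sampling_rate": 16000,
}
# The pipeline discards the deprecated "path" key before feature extraction.
sample.pop("path", None)
```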

This PR enables the dict to be passed directly to the pipeline, in the same way that we do for the ASR pipeline and the transformers feature extractors:

pred_labels = pipe(librispeech[0]["audio"])

If there are any API decisions you feel require changing, I'd be happy to update these in the original code before propagating to this file.

Collaborator

@sgugger sgugger left a comment

Thanks for the explanations!

@Narsil
Contributor

Narsil commented Jun 7, 2023

I think what you're trying to do is already supported, but the sampling rate needs to be in the same dict as the array (both are needed to represent a single audio).

That being said, the errors raised when misusing this feature could probably be largely improved (to guide users towards the correct form).
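The guidance Narsil suggests could be sketched as explicit validation before unpacking (hypothetical function and messages, not the actual implementation):

```python
import numpy as np

def validate_audio_dict(inputs: dict) -> None:
    """Raise descriptive errors for malformed dict inputs.

    Hypothetical validation sketch for the error-message improvements
    discussed above; not the actual pipeline code.
    """
    if "sampling_rate" not in inputs:
        raise ValueError(
            'When passing a dict, it must contain a "sampling_rate" key '
            "with the sampling rate of the audio in Hz."
        )
    waveform = inputs.get("raw", inputs.get("array"))
    if waveform is None:
        raise ValueError(
            'When passing a dict, it must contain a "raw" or "array" key '
            "holding the audio waveform as a numpy array."
        )
    if not isinstance(waveform, np.ndarray):
        raise TypeError(f"Expected a numpy array waveform, got {type(waveform)}")

validate_audio_dict({"array": np.zeros(10), "sampling_rate": 16000})  # valid
```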

_inputs = inputs.pop("array", None)
in_sampling_rate = inputs.pop("sampling_rate")
inputs = _inputs
if in_sampling_rate != self.feature_extractor.sampling_rate:
Contributor Author

@sanchit-gandhi sanchit-gandhi Jun 12, 2023

Are you sure about that @Narsil? It is indeed the case that the ASR pipeline respects the sampling_rate argument, but the audio classification pipeline does not. Note that the resampling operation is new; there is currently no sampling-rate check or resampling performed. This PR adds it.

Contributor

Ohh thanks for the ping, not sure how I missed notifications several times here. You're indeed correct, audio-classification didn't have support.

Contributor

@Narsil Narsil left a comment

LGTM (just some fix on error message).

Sorry for the late review, I'm not sure how I missed those notifications.

Overall there might be room to refactor and abstract this for both pipelines so that we can easily reuse later, but it's good enough for now.

@sanchit-gandhi sanchit-gandhi merged commit 8767958 into huggingface:main Jun 23, 2023
@sanchit-gandhi sanchit-gandhi deleted the audio-class-pipeline branch June 23, 2023 12:51