Allow dict input for audio classification pipeline (#23445)
* Allow dict input for audio classification pipeline

* make style

* Empty commit to trigger CI

* Empty commit to trigger CI

* check for torchaudio

* add pip instructions

Co-authored-by: Sylvain <[email protected]>

* Update src/transformers/pipelines/audio_classification.py

Co-authored-by: Nicolas Patry <[email protected]>

* asr -> audio class

* asr -> audio class

---------

Co-authored-by: Sylvain <[email protected]>
Co-authored-by: Nicolas Patry <[email protected]>
3 people authored Jun 23, 2023
1 parent a6f37f8 commit 8767958
Showing 2 changed files with 50 additions and 8 deletions.
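For quick context on what this change enables, here is a minimal usage sketch, not part of the commit. The checkpoint name is illustrative; any audio-classification checkpoint works the same way, and torchaudio is only needed when the dict's sampling rate differs from the model's.

```python
from transformers import pipeline
import numpy as np

# Illustrative checkpoint; substitute any audio-classification model.
classifier = pipeline("audio-classification", model="superb/wav2vec2-base-superb-ks")

# Raw 8 kHz audio passed as a dict; the pipeline resamples it to the
# feature extractor's sampling rate before classification.
audio = {"sampling_rate": 8000, "raw": np.zeros((8000,), dtype=np.float32)}
print(classifier(audio, top_k=2))
```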
54 changes: 46 additions & 8 deletions src/transformers/pipelines/audio_classification.py
@@ -17,7 +17,7 @@
 import numpy as np
 import requests

-from ..utils import add_end_docstrings, is_torch_available, logging
+from ..utils import add_end_docstrings, is_torch_available, is_torchaudio_available, logging
 from .base import PIPELINE_INIT_ARGS, Pipeline


@@ -110,12 +110,18 @@ def __call__(
         information.
         Args:
-            inputs (`np.ndarray` or `bytes` or `str`):
-                The inputs is either a raw waveform (`np.ndarray` of shape (n, ) of type `np.float32` or `np.float64`)
-                at the correct sampling rate (no further check will be done) or a `str` that is the filename of the
-                audio file, the file will be read at the correct sampling rate to get the waveform using *ffmpeg*. This
-                requires *ffmpeg* to be installed on the system. If *inputs* is `bytes` it is supposed to be the
-                content of an audio file and is interpreted by *ffmpeg* in the same way.
+            inputs (`np.ndarray` or `bytes` or `str` or `dict`):
+                The input is either:
+                - `str`: the filename of an audio file. The file will be read at the correct sampling rate to get
+                  the waveform, using *ffmpeg*. This requires *ffmpeg* to be installed on the system.
+                - `bytes`: taken to be the content of an audio file, and interpreted by *ffmpeg* in the same way.
+                - `np.ndarray` of shape (n,) and dtype `np.float32` or `np.float64`: raw audio at the correct
+                  sampling rate (no further check will be done).
+                - `dict`: used to pass raw audio sampled at an arbitrary `sampling_rate`, letting this pipeline do
+                  the resampling. The dict must be in the format `{"sampling_rate": int, "raw": np.array}` or
+                  `{"sampling_rate": int, "array": np.array}`, where the key `"raw"` or `"array"` denotes the raw
+                  audio waveform.
             top_k (`int`, *optional*, defaults to None):
                 The number of top labels that will be returned by the pipeline. If the provided number is `None` or
                 higher than the number of labels available in the model configuration, it will default to the number of
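To make the accepted dict formats concrete, a small sketch with dummy values; the `"path"` key mirrors what `datasets` yields and is simply dropped by `preprocess()`, per the hunk below.

```python
import numpy as np

waveform = np.zeros((16000,), dtype=np.float32)  # dummy 1-second clip at 16 kHz

# Explicit "raw" key:
sample = {"sampling_rate": 16000, "raw": waveform}

# `datasets`-style record using the "array" key; "path" is illustrative and unused:
ds_sample = {"sampling_rate": 16000, "array": waveform, "path": "/tmp/clip.wav"}
```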
@@ -151,10 +157,42 @@ def preprocess(self, inputs):
         if isinstance(inputs, bytes):
             inputs = ffmpeg_read(inputs, self.feature_extractor.sampling_rate)

+        if isinstance(inputs, dict):
+            # Also accept `"array"`, which is the key defined in `datasets`, for
+            # better integration
+            if not ("sampling_rate" in inputs and ("raw" in inputs or "array" in inputs)):
+                raise ValueError(
+                    "When passing a dictionary to AudioClassificationPipeline, the dict needs to contain a "
+                    '"raw" or "array" key containing the numpy array representing the audio, and a '
+                    '"sampling_rate" key containing the sampling rate associated with that array'
+                )
+
+            _inputs = inputs.pop("raw", None)
+            if _inputs is None:
+                # Remove the `datasets`-specific "path" key, which will not be used.
+                inputs.pop("path", None)
+                _inputs = inputs.pop("array", None)
+            in_sampling_rate = inputs.pop("sampling_rate")
+            inputs = _inputs
+            if in_sampling_rate != self.feature_extractor.sampling_rate:
+                import torch
+
+                if is_torchaudio_available():
+                    from torchaudio import functional as F
+                else:
+                    raise ImportError(
+                        "torchaudio is required to resample audio samples in AudioClassificationPipeline. "
+                        "The torchaudio package can be installed through: `pip install torchaudio`."
+                    )
+
+                inputs = F.resample(
+                    torch.from_numpy(inputs), in_sampling_rate, self.feature_extractor.sampling_rate
+                ).numpy()
+
         if not isinstance(inputs, np.ndarray):
             raise ValueError("We expect a numpy ndarray as input")
         if len(inputs.shape) != 1:
-            raise ValueError("We expect a single channel audio input for AutomaticSpeechRecognitionPipeline")
+            raise ValueError("We expect a single channel audio input for AudioClassificationPipeline")

         processed = self.feature_extractor(
             inputs, sampling_rate=self.feature_extractor.sampling_rate, return_tensors="pt"
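The resampling branch above can be exercised on its own; a minimal sketch, assuming torchaudio is installed:

```python
import numpy as np
import torch
from torchaudio import functional as F

in_sr, target_sr = 8000, 16000
waveform = np.random.randn(in_sr).astype(np.float32)  # 1 second of noise at 8 kHz

# Same call the pipeline makes: torch tensor in, numpy array back out.
resampled = F.resample(torch.from_numpy(waveform), in_sr, target_sr).numpy()
assert resampled.shape == (target_sr,)  # 1 second at 16 kHz
```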
4 changes: 4 additions & 0 deletions tests/pipelines/test_pipelines_audio_classification.py
@@ -103,6 +103,10 @@ def test_small_model_pt(self):
         ]
         self.assertIn(nested_simplify(output, decimals=4), [EXPECTED_OUTPUT, EXPECTED_OUTPUT_PT_2])

+        audio_dict = {"array": np.ones((8000,)), "sampling_rate": audio_classifier.feature_extractor.sampling_rate}
+        output = audio_classifier(audio_dict, top_k=4)
+        self.assertIn(nested_simplify(output, decimals=4), [EXPECTED_OUTPUT, EXPECTED_OUTPUT_PT_2])
+
     @require_torch
     @slow
     def test_large_model_pt(self):
