diff --git a/src/transformers/pipelines/audio_classification.py b/src/transformers/pipelines/audio_classification.py
index 7f104e74e12824..8ac2a4a554e363 100644
--- a/src/transformers/pipelines/audio_classification.py
+++ b/src/transformers/pipelines/audio_classification.py
@@ -17,7 +17,7 @@
 import numpy as np
 import requests
 
-from ..utils import add_end_docstrings, is_torch_available, logging
+from ..utils import add_end_docstrings, is_torch_available, is_torchaudio_available, logging
 from .base import PIPELINE_INIT_ARGS, Pipeline
 
 
@@ -110,12 +110,18 @@ def __call__(
             information.
 
         Args:
-            inputs (`np.ndarray` or `bytes` or `str`):
-                The inputs is either a raw waveform (`np.ndarray` of shape (n, ) of type `np.float32` or `np.float64`)
-                at the correct sampling rate (no further check will be done) or a `str` that is the filename of the
-                audio file, the file will be read at the correct sampling rate to get the waveform using *ffmpeg*. This
-                requires *ffmpeg* to be installed on the system. If *inputs* is `bytes` it is supposed to be the
-                content of an audio file and is interpreted by *ffmpeg* in the same way.
+            inputs (`np.ndarray` or `bytes` or `str` or `dict`):
+                The inputs are either:
+                    - `str` that is the filename of the audio file; the file will be read at the correct sampling
+                      rate to get the waveform using *ffmpeg*. This requires *ffmpeg* to be installed on the system.
+                    - `bytes` that is supposed to be the content of an audio file and is interpreted by *ffmpeg* in
+                      the same way.
+                    - (`np.ndarray` of shape (n, ) of type `np.float32` or `np.float64`)
+                      Raw audio at the correct sampling rate (no further check will be done).
+                    - `dict` form can be used to pass raw audio sampled at an arbitrary `sampling_rate` and let this
+                      pipeline do the resampling. The dict must be in either the format `{"sampling_rate": int,
+                      "raw": np.array}` or `{"sampling_rate": int, "array": np.array}`, where the key `"raw"` or
+                      `"array"` is used to denote the raw audio waveform.
             top_k (`int`, *optional*, defaults to None):
                 The number of top labels that will be returned by the pipeline. If the provided number is `None` or
                 higher than the number of labels available in the model configuration, it will default to the number of
@@ -151,10 +157,42 @@ def preprocess(self, inputs):
         if isinstance(inputs, bytes):
             inputs = ffmpeg_read(inputs, self.feature_extractor.sampling_rate)
 
+        if isinstance(inputs, dict):
+            # Accepting `"array"` which is the key defined in `datasets` for
+            # better integration
+            if not ("sampling_rate" in inputs and ("raw" in inputs or "array" in inputs)):
+                raise ValueError(
+                    "When passing a dictionary to AudioClassificationPipeline, the dict needs to contain a "
+                    '"raw" or "array" key containing the numpy array representing the audio, and a '
+                    '"sampling_rate" key containing the sampling rate associated with that array'
+                )
+
+            _inputs = inputs.pop("raw", None)
+            if _inputs is None:
+                # Remove path which will not be used from `datasets`.
+                inputs.pop("path", None)
+                _inputs = inputs.pop("array", None)
+            in_sampling_rate = inputs.pop("sampling_rate")
+            inputs = _inputs
+            if in_sampling_rate != self.feature_extractor.sampling_rate:
+                import torch
+
+                if is_torchaudio_available():
+                    from torchaudio import functional as F
+                else:
+                    raise ImportError(
+                        "torchaudio is required to resample audio samples in AudioClassificationPipeline. "
+                        "The torchaudio package can be installed through: `pip install torchaudio`."
+                    )
+
+                inputs = F.resample(
+                    torch.from_numpy(inputs), in_sampling_rate, self.feature_extractor.sampling_rate
+                ).numpy()
+
         if not isinstance(inputs, np.ndarray):
             raise ValueError("We expect a numpy ndarray as input")
         if len(inputs.shape) != 1:
-            raise ValueError("We expect a single channel audio input for AutomaticSpeechRecognitionPipeline")
+            raise ValueError("We expect a single channel audio input for AudioClassificationPipeline")
 
         processed = self.feature_extractor(
             inputs, sampling_rate=self.feature_extractor.sampling_rate, return_tensors="pt"
diff --git a/tests/pipelines/test_pipelines_audio_classification.py b/tests/pipelines/test_pipelines_audio_classification.py
index 8f2e46e0a50bb5..48c39ff663fbe8 100644
--- a/tests/pipelines/test_pipelines_audio_classification.py
+++ b/tests/pipelines/test_pipelines_audio_classification.py
@@ -103,6 +103,10 @@ def test_small_model_pt(self):
         ]
         self.assertIn(nested_simplify(output, decimals=4), [EXPECTED_OUTPUT, EXPECTED_OUTPUT_PT_2])
 
+        audio_dict = {"array": np.ones((8000,)), "sampling_rate": audio_classifier.feature_extractor.sampling_rate}
+        output = audio_classifier(audio_dict, top_k=4)
+        self.assertIn(nested_simplify(output, decimals=4), [EXPECTED_OUTPUT, EXPECTED_OUTPUT_PT_2])
+
     @require_torch
     @slow
     def test_large_model_pt(self):
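
Usage sketch (not part of the patch): after this change, the pipeline accepts `datasets`-style dicts and resamples through torchaudio when the incoming rate differs from the feature extractor's. The checkpoint name below is illustrative; any audio-classification checkpoint works, and torch plus torchaudio are assumed to be installed.

```python
import numpy as np

from transformers import pipeline

# Illustrative checkpoint; swap in any audio-classification model.
classifier = pipeline("audio-classification", model="superb/wav2vec2-base-superb-ks")

# Raw audio at 8 kHz; the pipeline resamples it to the feature extractor's
# rate via torchaudio.functional.resample before running inference.
audio = {"sampling_rate": 8000, "raw": np.zeros((8000,), dtype=np.float32)}
print(classifier(audio, top_k=2))

# `datasets`-style rows also work, since the `"array"` key is accepted and a
# leftover `"path"` key is dropped: {"array": ..., "sampling_rate": ..., "path": ...}
```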