ASR Pipeline is not super user-friendly #20414

Closed
sanchit-gandhi opened this issue Nov 23, 2022 · 11 comments

Comments

@sanchit-gandhi
Contributor

Feature request

Firstly, thank you to @Narsil for developing the speech recognition pipeline - it's incredibly helpful for running the full speech-to-text mapping in one call, pre- and post-processing included.

There are a couple of things that currently mean the pipeline is not super compatible with 🤗 Datasets. I'll motivate them below with an example.

Motivation

Let's take the example of evaluating a (dummy) Wav2Vec2 checkpoint on the (dummy) LibriSpeech ASR dataset:

from transformers import pipeline
from datasets import load_dataset

pipe = pipeline("automatic-speech-recognition", model="hf-internal-testing/tiny-random-wav2vec2")
dataset = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation[:10]")

Printing the first audio sample of the dataset:

print(dataset[0]["audio"])

Print Output:

{'path': '/home/sanchit_huggingface_co/.cache/huggingface/datasets/downloads/extracted/0393f71a8093c6541f95c89f60982213cf086569876e1195926741f097ad47fc/dev_clean/1272/128104/1272-128104-0000.flac', 
'array': array([0.00238037, 0.0020752 , 0.00198364, ..., 0.00042725, 0.00057983,
       0.0010376 ], dtype=float32), 
'sampling_rate': 16000}

So the audio samples are in the format: {"path": str, "array": np.array, "sampling_rate": int}. The NumPy audio array values are stored under the key "array". This format is ubiquitous in 🤗 Datasets: all audio datasets use it.

However, pipeline expects the audio samples in the format {"sampling_rate": int, "raw": np.array}:

- `dict` form can be used to pass raw audio sampled at arbitrary `sampling_rate` and let this
pipeline do the resampling. The dict must be in the format `{"sampling_rate": int, "raw":
np.array}` with optionally a `"stride": (left: int, right: int)` than can ask the pipeline to

This means we have to do some hacking around to get the audio samples into the right format for pipeline:

def predict(batch):
    audios = batch["audio"]
    # hacky renaming
    audios = [{"raw": sample["array"], "sampling_rate": sample["sampling_rate"]} for sample in audios]

    predictions = pipe(audios)

    # unpack and index predictions (List[Dict])
    batch["predictions"] = [pred["text"] for pred in predictions]
    return batch

And then apply the function to our dataset using the map method:

batch_size = 4

result_set = dataset.map(
    predict,
    batched=True,
    batch_size=batch_size,
    remove_columns=dataset.features.keys(),
)

If pipeline's __call__ method was matched to Datasets' audio features, we'd be able to use any audio dataset directly with pipeline (no hacky feature renaming):

def hypothetical_predict(batch):
    # pass the Datasets audio column straight to the pipeline, no renaming
    predictions = pipe(batch["audio"])
    batch["predictions"] = [pred["text"] for pred in predictions]
    return batch

This would be very nice for the user!

Furthermore, the outputs returned by pipeline are a list of dicts (List[Dict]):

return {"text": text, **optional, **extra}

This means we have to unpack and index them before we can use them for any downstream use (such as WER calculations).

It would be nice if pipeline returned a ModelOutput class. That way, we could index the text column directly from the returned object:

def hypothetical_predict(batch):
    batch["predictions"] = pipe(batch["audio"]).text
    return batch

IMO this is more intuitive to the user than renaming their audio column and then iterating over the returned Dict object to get the predicted text.

Your contribution

WDYT @Narsil @patrickvonplaten? Happy to add these changes to smooth out the user experience!

@sanchit-gandhi
Contributor Author

sanchit-gandhi commented Nov 23, 2022

One additional point! We can't pass generation kwargs to the generate method:

tokens = self.model.generate(
    encoder_outputs=encoder(inputs, attention_mask=attention_mask),
    attention_mask=attention_mask,
)

This means our stdout is bombarded with UserWarnings from the generate method:

/usr/local/lib/python3.7/dist-packages/transformers/generation_utils.py:1364: UserWarning: Neither `max_length` nor `max_new_tokens` has been set, `max_length` will default to 448 (`self.config.max_length`). Controlling `max_length` via the config is deprecated and `max_length` will be removed from the config in v5 of Transformers -- we recommend using `max_new_tokens` to control the maximum length of the generation.

Would be nice to be able to override generation kwargs to prevent these messages and have flexibility over max length, beams, temperature, length penalty, etc

cc @Vaibhavs10

@sanchit-gandhi
Contributor Author

sanchit-gandhi commented Nov 23, 2022

Just went through the code in more detail and found that "array" is popped from the input dict!

_inputs = inputs.pop("raw", None)
if _inputs is None:
    _inputs = inputs.pop("array", None)

Maybe we can add this to the docstring to highlight!
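To make the fallback concrete, here is a minimal standalone sketch of the lookup quoted above. `extract_audio` is a hypothetical helper name (not a function in transformers), and plain lists stand in for NumPy arrays so the snippet runs without dependencies.

```python
def extract_audio(inputs):
    """Return (samples, sampling_rate) from a pipeline-style or Datasets-style dict."""
    inputs = dict(inputs)  # copy so the caller's dict is not mutated by pop
    samples = inputs.pop("raw", None)
    if samples is None:
        # Datasets-style dicts store the samples under "array"
        samples = inputs.pop("array", None)
    if samples is None:
        raise ValueError('expected audio under the key "raw" or "array"')
    return samples, inputs["sampling_rate"]


datasets_style = {"path": "sample.flac", "array": [0.1, 0.2], "sampling_rate": 16000}
pipeline_style = {"raw": [0.1, 0.2], "sampling_rate": 16000}
print(extract_audio(datasets_style))
print(extract_audio(pipeline_style))
```

Both dict layouts resolve to the same samples, which is why the inner audio dict from a dataset can be passed without renaming its keys.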

@Narsil
Contributor

Narsil commented Nov 23, 2022

Multiple points:

However, pipeline expects the audio samples in the format

As far as I remember, we can also accept array for that reason (raw came before datasets had normalized, IIRC, which explains the discrepancy; since we don't break things, neither key is going to go away in pipeline, I'm afraid).
The problem is not array, it's audio. See more in the docs about KeyDataset (or the iterator, which I think is more elegant but lacks the number of items): https://huggingface.co/docs/transformers/main_classes/pipelines#transformers.pipeline

It would be nice if pipeline returned a ModelOutput class. That way, we could index the text column directly from the returned object:

This is not going to happen for reasons I'll explain in following points

We can't pass generation kwargs to the generate method:

We can add it as generate_kwargs, but I think we wanted to change the configs of the affected models instead (they were not defined for Whisper, I think) @ArthurZucker. If max_length is the actual maximal capacity of the model, everything should be fine: no warnings, no nothing.

We could also make the warning appear only once, @sgugger, since reducing noise seems desirable.
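Until that happens in the library, a user-side approximation is possible with Python's standard `warnings` filters. The sketch below only demonstrates the stdlib "once" action on a repeated UserWarning; `count_shown` is a hypothetical helper, and nothing here reflects how transformers itself emits or deduplicates its warnings.

```python
import warnings


def count_shown(message, n_calls):
    """Emit the same UserWarning n_calls times and return how many get through."""
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("once")  # show each distinct message only the first time
        for _ in range(n_calls):
            warnings.warn(message, UserWarning)
    return len(caught)


print(count_shown("Neither `max_length` nor `max_new_tokens` has been set", 5))
```

Outside of a demo, calling `warnings.simplefilter("once", UserWarning)` near the top of a script should have a similar effect for warnings raised through the `warnings` module.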

Being able to send generate_kwargs would still be nice. (Careful: I mean pipe(..., generate_kwargs={"max_new_tokens": 20}), NOT pipe(..., max_new_tokens=20). The reason is that generate kwargs have clashed in the past with tokenizer kwargs, for instance, and it's impossible to disentangle them after the fact.) That's for passing generic kwargs (all of them, through time and eternity), but we can definitely add some first-class parameters (like max_new_tokens, for instance).
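The clash can be illustrated with toy stand-ins. In this sketch `tokenize`, `generate` and `run` are hypothetical functions written for the example (they are not the transformers tokenizer, `generate` method or pipeline); the point is only that a name such as `max_length` can mean different things at two stages, and a nested dict keeps the two apart.

```python
def tokenize(text, max_length=None):
    """Stand-in for a tokenizer call that truncates to max_length tokens."""
    tokens = text.split()
    return tokens if max_length is None else tokens[:max_length]


def generate(tokens, max_length=None):
    """Stand-in for generate: caps the number of produced tokens."""
    produced = [t.upper() for t in tokens]
    return produced if max_length is None else produced[:max_length]


def run(text, generate_kwargs=None, **tokenizer_kwargs):
    """Top-level kwargs go to the tokenizer; generate_kwargs go only to generate."""
    generate_kwargs = generate_kwargs or {}
    return generate(tokenize(text, **tokenizer_kwargs), **generate_kwargs)


# The same parameter name is routed to two different stages without ambiguity.
print(run("a b c d e", max_length=4, generate_kwargs={"max_length": 2}))
```

With a flat signature, a single `max_length=...` would have to be guessed as belonging to one stage or the other, which is the ambiguity described above.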

Maybe we can add this to the docstring to highlight!

Totally !

dataset = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation[:10]")

I highly recommend NOT loading the entire array of the datasets in memory when working on datasets. That means NOT passing around lists, and not being able to batch with ModelOutputs.

That's because objects are meant to be consumed one by one in an iterable fashion.
This is true for datasets, but also for webservers, you can have pretty much the same code, do dynamic batching and such crazy stuff and still keep the code the same for instance.
This is not relevant for dataset.map since it does the slicing and batching on its own, but it is relevant when pipe.preprocess can leverage the streaming mode to compute multiple things at once.

Using generators and streams is much more efficient, and the pipeline will actually do the batching too; passing lists around to the pipeline will NOT batch things. (More on batching in pipelines: https://huggingface.co/docs/transformers/main_classes/pipelines#pipeline-batching)
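The idea of batching over a stream can be sketched in a few lines. This is an illustration of lazy batching in general, not the pipeline's actual implementation; `batched` and `samples` are hypothetical helpers, and the log only exists to show when items are pulled from the source.

```python
from itertools import islice


def batched(iterable, batch_size):
    """Lazily group items from any iterable into lists of at most batch_size."""
    iterator = iter(iterable)
    while True:
        batch = list(islice(iterator, batch_size))
        if not batch:
            return
        yield batch


def samples(n, log):
    """Generator that records when each item is actually produced."""
    for i in range(n):
        log.append(i)
        yield i


log = []
stream = batched(samples(5, log), batch_size=2)
first = next(stream)
# Only the items of the first batch have been pulled from the source so far.
print(first, log)
```

Because nothing is materialized ahead of time, the same consumer code works whether the source is a dataset, a file listing or requests arriving at a web server.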

Here is the recommendation from the docs: https://huggingface.co/docs/transformers/main_classes/pipelines#transformers.pipeline (Still need to upgrade that part to make it to the tutorial).

Here is a gist with several examples: https://gist.github.com/Narsil/4f5b088f4dd23200d16dd2cc575fdc16

Method 1 (pipe) 0:00:00.294485
Method 2 (dataset) 0:00:00.308238
Method 3 (raw file) 0:00:00.635527

The 5% speedup is pretty consistent on this smallish data.

Method 3 is slower, but because you don't need to decode the audio files within the dataset, this can save some disk space (at a compute cost). Keep in mind the num_workers=1 means the actual decompression of audio files happens in a different thread (and even process since we're relying on ffmpeg for it).

I tried actually batching inputs, but it seems it's detrimental in this case (just add , batch_size=2 during pipeline initialization).
Method 1 is 20% faster than method 2 with actual batching, but 50% slower than without :

https://huggingface.co/docs/transformers/main_classes/pipelines#pipeline-batching for more info on why batching can hurt.

I had to add a "warmup" to do fair comparisons. It seems dataset is decompressing the files on first access (that's my best guess), and it seems to do so more slowly than the raw pipeline (because of the threading and because librosa is actually slower than raw ffmpeg, I think; at least I remember it being "slow" to decompress).

Happy to discuss further how to make the integration easier. I should mention that KeyDataset is probably the nicest to use as it should keep the length, it's just one weird import away

from transformers.pipelines.pt_utils import KeyDataset

...
for out in pipe(KeyDataset(dataset, "audio")):
    pass

It has the same performance as Method 1 but plays better with tqdm. It's just less flexible, IMO.
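For intuition about why the length is kept, here is a sketch of the idea, assuming the wrapper simply exposes one column of an indexable dataset. `KeyDatasetSketch` is a hypothetical reimplementation for illustration (use the real `KeyDataset` import above in practice), and a list of dicts stands in for the dataset.

```python
class KeyDatasetSketch:
    """Expose one column of an indexable dataset while keeping its length."""

    def __init__(self, dataset, key):
        self.dataset = dataset
        self.key = key

    def __len__(self):
        # a known length is what lets progress bars show a total
        return len(self.dataset)

    def __getitem__(self, i):
        return self.dataset[i][self.key]


rows = [
    {"audio": {"array": [0.1], "sampling_rate": 16000}, "text": "hello"},
    {"audio": {"array": [0.2], "sampling_rate": 16000}, "text": "world"},
]
audio_only = KeyDatasetSketch(rows, "audio")
print(len(audio_only), audio_only[1]["sampling_rate"])
```

A plain generator has no `__len__`, which is the trade-off between the two approaches: the generator is more flexible, the indexable wrapper reports progress.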

@sanchit-gandhi
Contributor Author

sanchit-gandhi commented Nov 24, 2022

Thanks for the super in-depth explanation, @Narsil! Incredibly helpful and much appreciated 🤗

Maybe I'm missing the point a bit with why pipelines exist - are they geared more towards maximising performance for inference (or at least giving you the option to)? Rather than just being a nice wrapper around the feature extractor, model and tokenizer?

Sounds good regarding:

  1. Updating the doc string to reflect the fact that we can pass array as well as raw as the keys for audio input
  2. Passing the gen kwargs as a specified dict to the generate method

Thanks for explaining why ModelOutputs is not viable! It makes sense using a generator and streams, rather than throwing a list into pipe.

(Still need to upgrade that part to make it to the tutorial).

Is there a tutorial that's been published or a WIP? That'll be super handy!

Here is a gist with several examples: https://gist.github.com/Narsil/4f5b088f4dd23200d16dd2cc575fdc16

Super comprehensive, thanks for these benchmarks! Interesting to see how .map compares to the generator method!

I should mention that KeyDataset is probably the nicest to use as it should keep the length, it's just one weird import away

Thanks for flagging this! I had a follow-up question - are there docs / examples for using pipe when loading a dataset in streaming mode? Here, we can't use KeyDataset (as we can't index a streamed dataset):

return self.dataset[i][self.key]

Is the best option just to go for a generator here?

def data():
    for i, sample in enumerate(dataset):
        yield sample["audio"]

output = []
for out in pipe(data(), batch_size=2):
    output.append(out["text"])

With this generator method, we currently yield the audio samples which we pass to the pipe. Is there a way of iterating over the streaming dataset to get the target transcriptions (sample["text"]) as well? Here, we would not need to pass the target text to the pipe, but simply return it in the generator. Ideally, we want the target transcriptions sample["text"] so that we can assess our predictions.

(this is the actual example I'm working with: https://github.com/sanchit-gandhi/codesnippets/blob/main/benchmark_inference_whisper.ipynb)

@Narsil
Contributor

Narsil commented Nov 24, 2022

Thanks for the super in-depth explanation, @Narsil! Incredibly helpful and much appreciated 🤗

Well, your initial issue was also pretty comprehensive, so thanks for creating it.

Maybe I'm missing the point a bit with why pipelines exist - are they geared more towards maximising performance for inference (or at least giving you the option to)? Rather than just being a nice wrapper around the feature extractor, model and tokenizer?

Pipelines started without any real guidelines on what they should or should not do.
Currently the first and foremost goal is to make ML accessible for users who have no clue what is a model or tensors, it's the primary target.
That being said, being efficient for inference goes along since we don't want to provide a 10x slowdown experience for those users.
It's not the primary focus though, otherwise it would not be written in python, and it would not be that convenient :).

Let's say there are 2 kinds of performance:

  • Don't do useless work (Remove crap code, or code which is not really useful, or work that's discarded, useless copies etc..)
  • Actual performance by making sure every inch of your hardware is properly used at the appropriate time. (Read: understanding CPU instructions, looking at SIMD, optimizing threading layout, maximizing L1 cache hits, minimizing branch mispredictions, using custom GPU kernels, etc.)

We're only doing the first kind here. (Maybe a little of 2 for the GPU feeding that needs to be as fast as possible because CPU-GPU is a bottleneck really quick otherwise)

Is there a tutorial that's been published or a WIP? That'll be super handy!

There is this tutorial, https://huggingface.co/docs/transformers/pipeline_tutorial, which I find less comprehensive than https://huggingface.co/docs/transformers/main_classes/pipelines, unfortunately.

I'm in the process of rewriting it, as it seems most people read only that. And you're not the first person to not be aware of those cool features, so I'd say it's a doc problem.

Super comprehensive, thanks for these benchmarks! Interesting to see how .map compares to the generator method!

Can't tell you why there is a difference, but I can tell you I went to great lengths to optimize everything I could in the pipeline directly. (Only the first kind of optimization, and it's still written in Python, so far from perfect, but hey ... :) )

With this generator method, we currently yield the audio samples which we pass to the pipe. Is there a way of iterating over the streaming dataset to get the target transcriptions (sample["text"]) as well?

Actually, if you pass along other keys in your data, they should be passed along all the way to the result with the ASR pipeline.
I would like that to be the case for all pipelines, but never got down to doing it.
But since it is streaming, yes, you need to pass things around, since otherwise it's tricky to start matching results with inputs at the end.

def data():
    for item in streaming_data:
        yield {**item["audio"], "expected": item["text"]}

for out in pipe(data()):
    generated = out["text"]
    expected = out["expected"]
    # Do your WER thing.

Would that work ? (I haven't tested this)
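For readers who want to try the pattern without downloading a model, here is a dependency-free sketch. `fake_pipe` is a hypothetical stub that stands in for the ASR pipeline and assumes the pass-through behaviour described above (extra keys come back with the result); the "transcription" it returns is deliberately fake.

```python
def fake_pipe(items):
    """Stub for the ASR pipeline: 'transcribes' and passes extra keys through."""
    for item in items:
        item = dict(item)
        samples = item.pop("array")
        item.pop("sampling_rate")
        item.pop("path", None)
        # pretend transcription: just report the number of samples
        yield {"text": f"{len(samples)} samples", **item}


streaming_data = [
    {"audio": {"array": [0.0, 0.1], "sampling_rate": 16000}, "text": "first"},
    {"audio": {"array": [0.2], "sampling_rate": 16000}, "text": "second"},
]


def data():
    for item in streaming_data:
        # flatten the audio dict and carry the reference text alongside it
        yield {**item["audio"], "expected": item["text"]}


pairs = [(out["text"], out["expected"]) for out in fake_pipe(data())]
print(pairs)
```

Each prediction stays paired with its reference, so no index bookkeeping is needed on the consumer side.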

If it doesn't, something like this might be a useful hack (provided you're running in a single thread for the server looping):

GLOBAL_INDEX = {}

def data():
    for i, item in enumerate(streaming_data):
        GLOBAL_INDEX[i] = item["text"]
        # yield only the audio dict, which is what the pipeline consumes
        yield item["audio"]

for i, out in enumerate(pipe(data())):
    generated = out["text"]
    expected = GLOBAL_INDEX.pop(i)  # pop removes the entry so its memory can be released
    # Do your WER thing.
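As for the "WER thing" itself, a minimal sketch for a single reference/hypothesis pair could look like the following. It assumes whitespace tokenization and no text normalization, and `word_error_rate` is a hypothetical helper; for real evaluations a dedicated metrics library is likely the better choice.

```python
def word_error_rate(reference, hypothesis):
    """Word-level Levenshtein distance divided by the number of reference words."""
    ref, hyp = reference.split(), hypothesis.split()
    if not ref:
        raise ValueError("reference must contain at least one word")
    # previous[j] = edit distance between the ref words seen so far and hyp[:j]
    previous = list(range(len(hyp) + 1))
    for i, ref_word in enumerate(ref, start=1):
        current = [i]
        for j, hyp_word in enumerate(hyp, start=1):
            substitution = previous[j - 1] + (ref_word != hyp_word)
            current.append(min(substitution, previous[j] + 1, current[j - 1] + 1))
        previous = current
    return previous[-1] / len(ref)


print(word_error_rate("the cat sat on the mat", "the cat sat on mat"))
```

Note that a corpus-level WER is usually computed by summing the edits over all pairs and dividing by the total number of reference words, rather than by averaging per-sample scores.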

@sanchit-gandhi
Contributor Author

sanchit-gandhi commented Nov 25, 2022

Thank you again for the super comprehensive reply, really appreciate the time given to answering this thread!

make ML accessible for users who have no clue what is a model or tensors

Awesome! Think it's fantastic in this regard. Having some easy examples that show you how to run pipeline in different scenarios / tasks like a little 'recipe' book would be great to further this.

otherwise it would not be written in python, and it would not be that convenient :)

Did someone say Rust 👀

Thanks for linking the tutorials - I learnt quite a lot from this thread + docs after knowing where to look. I guess you have two camps of people that will be using pipeline:

  1. Those migrating from the transformers approach (feature extractor + model + processor)
  2. Those who don't use transformers

For me, it was making the link between my transformers approach and pipeline that made the penny drop. There's a bit of a different mindset which you have to adopt vs the usual datasets .map method. I think some more examples showing how to make actual transformers tasks work in pipeline would go a long way! In this regard, your updated tutorial looks amazing (doing exactly this)! Happy to do a pass of the PR when it's in a review ready state!

Would that work ? (I haven't tested this)

It did indeed work, thanks 🙌

@patrickvonplaten
Contributor

I think we should definitely try to avoid displaying warnings by default when running the ASRPipeline.
Also, since Whisper is an Encoder-Decoder model architecture, the main use case for speech recognition might soon switch from Wav2Vec2CTC to Encoder-Decoder, so we should also try to adapt the ASR pipeline in this direction.

Short term:
Let's try to not display any warnings by default, and I agree with @sanchit-gandhi: it'd also be nice for pipelines to be directly usable in combination with datasets. Could we maybe adapt the pipeline API from:

{"sampling_rate": int, "raw": np.array}

to

{"sampling_rate": int, "raw": Optional[np.array], "array": Optional[np.array]}

to just allow both use cases? What is the big drawback of this?

Mid/Long term
As already discussed a bit with @sgugger and @sanchit-gandhi, I think we should really think about creating a new generate method just for audio. The current generate method is a) too bloated and b) just not adapted for speech recognition. Chunking, long-range audio recognition and streamed audio recognition are much more of a required use case for speech recognition than for NLP. Also, we could design the new generate method to be future compatible with models like the Transducer.

This would then also render the ASR pipeline much easier IMO.

@Narsil
Contributor

Narsil commented Nov 30, 2022

What is the big drawback of this?

This is already done, it's a doc issue. And specifically for sanchit, datasets are using {"audio": {"sampling_rate": .., "array": ..}} instead of the inner dict.

The current generate method is a) too bloated and b) just not adapted for speech recognition.

True, I have potential suggestions for it, which mainly are going full on Processor/StoppingCriteria route. This is what was necessary to enable complex batching within bloom inference.
Splitting specifically for audio might be necessary but I am under the impression it's only a matter of defaults for those objects.

@patrickvonplaten
Contributor

Maybe a bigger discussion, but could it make sense to move some more complicated tasks such as real-time speech recognition to something like: https://github.com/huggingface/speechbox ?

@flozi00
Contributor

flozi00 commented Dec 30, 2022

For cases like realtime ASR more optimized methods, for example as rust modules, would be super cool.
Maybe with functionality for community pipelines as in diffusers, just for speech ?

@github-actions

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.
