
[Flax Examples] Seq2Seq ASR Fine-Tuning Script #21764

Merged
19 commits merged into huggingface:main on Sep 29, 2023

Conversation

@sanchit-gandhi (Contributor) commented on Feb 23, 2023

What does this PR do?

Adds a Flax example script that can be used to fine-tune Whisper for speech recognition.

Tested and verified as working with the following (dummy) config:

run_flax_speech_recognition_seq2seq.py \
            --model_name_or_path openai/whisper-tiny.en \
            --dataset_name hf-internal-testing/librispeech_asr_dummy \
            --dataset_config clean \
            --train_split_name validation \
            --eval_split_name validation \
            --output_dir whisper-tiny-ft-dummy \
            --overwrite_output_dir \
            --num_train_epochs=2 \
            --max_train_samples 10 \
            --max_eval_samples 10 \
            --warmup_steps=8 \
            --do_train \
            --do_eval \
            --learning_rate=2e-4 \
            --per_device_train_batch_size=2 \
            --per_device_eval_batch_size=1 \
            --predict_with_generate
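
For reference, here is a minimal sketch of how the resulting checkpoint could be loaded for inference. The output directory and dataset mirror the dummy config above; exactly what gets saved to the output dir depends on the script, so treat this as illustrative rather than the script's own evaluation path:

```python
# Minimal sketch: load the fine-tuned Flax Whisper checkpoint and transcribe one sample.
# Paths and dataset mirror the dummy config above; purely illustrative.
from datasets import load_dataset
from transformers import FlaxWhisperForConditionalGeneration, WhisperProcessor

processor = WhisperProcessor.from_pretrained("whisper-tiny-ft-dummy")
model = FlaxWhisperForConditionalGeneration.from_pretrained("whisper-tiny-ft-dummy")

# one audio sample from the same dummy dataset used for fine-tuning
sample = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")[0]
input_features = processor(
    sample["audio"]["array"], sampling_rate=16000, return_tensors="np"
).input_features

# greedy decoding; .sequences holds the predicted token ids
pred_ids = model.generate(input_features).sequences
print(processor.batch_decode(pred_ids, skip_special_tokens=True)[0])
```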

Will add a README with preliminary training configs / results later this week after doing a full fine-tuning run.

cc @peregilk @andyehrenberg for interest

@HuggingFaceDocBuilderDev commented on Feb 23, 2023

The documentation is not available anymore as the PR was closed or merged.

@peregilk (Contributor) commented on Apr 3, 2023

@sanchit-gandhi @andyehrenberg

We have made a version of this script that supports streaming and training on TPU pods.

The current version of the script is available here:
https://github.com/NbAiLab/nb-whisper/blob/main/run_flax_speech_recognition_seq2seq_streaming.py

We are, however, struggling with a bug at the moment. The script works for training the Tiny model on multiple pod sizes, both for scaling for speed and for increasing the batch size. All the other model sizes (small, base, medium, large) also work on a single TPU v4-8. However, training the non-Tiny model sizes on pods runs for a few steps and then freezes.

If anyone has any idea about why this could be happening, I would really appreciate it.

@huggingface huggingface deleted a comment from github-actions bot May 15, 2023
@sanchit-gandhi sanchit-gandhi mentioned this pull request Jun 12, 2023
@huggingface huggingface deleted a comment from github-actions bot Jun 12, 2023
@huggingface huggingface deleted a comment from github-actions bot Jun 12, 2023
@github-actions github-actions bot closed this Jul 15, 2023
@huggingface huggingface deleted a comment from github-actions bot Jul 28, 2023
@sanchit-gandhi sanchit-gandhi marked this pull request as ready for review August 11, 2023 11:08
@github-actions bot commented on Sep 5, 2023

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

@github-actions github-actions bot closed this Sep 13, 2023
@sanchit-gandhi (Contributor, Author) commented:

Given the popularity of the PyTorch fine-tuning script and Whisper JAX, a Whisper fine-tuning script in JAX/Flax is a pretty easy addition.

Note: this is largely based off the distil-whisper training script (https://github.com/huggingface/distil-whisper#training), but simplified to run offline, with just one training dataset and the cross-entropy objective.
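
For context, the objective itself is just token-level cross-entropy over the decoder outputs, with padded label positions masked out of the loss. A minimal sketch (assuming padded labels are marked with -100; this label-padding convention is illustrative rather than necessarily the script's exact one):

```python
# Minimal sketch of the cross-entropy objective with label padding masked out.
# Assumes padded label positions are filled with -100 (illustrative convention).
import jax.numpy as jnp
import optax

def cross_entropy_loss(logits, labels, label_pad_id=-100):
    # mask of real (non-padded) label positions
    padding_mask = labels != label_pad_id
    # replace pad ids with a valid class index so the per-token loss is well-defined
    safe_labels = jnp.where(padding_mask, labels, 0)
    per_token_loss = optax.softmax_cross_entropy_with_integer_labels(logits, safe_labels)
    # average only over real (non-padded) tokens
    return (per_token_loss * padding_mask).sum() / padding_mask.sum()
```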

@ArthurZucker (Collaborator) left a comment

Thanks a lot! Looks great 😉

@sanchit-gandhi sanchit-gandhi merged commit 68e85fc into huggingface:main Sep 29, 2023
3 checks passed
blbadger pushed a commit to blbadger/transformers that referenced this pull request Nov 8, 2023
* from seq2seq speech

* [Flax] Example script for speech seq2seq

* tests and fixes

* make style

* fix: label padding tokens

* fix: label padding tokens over list

* update ln names for Whisper

* try datasets iter loader

* create readme and append results

* style

* make style

* adjust lr

* use pt dataloader

* make fast

* pin gen max len

* finish

* add pt to requirements for test

* fix pt -> torch

* add accelerate
EduardoPach pushed a commit to EduardoPach/transformers that referenced this pull request Nov 18, 2023 (same commits as above)