[Flax Examples] Seq2Seq ASR Fine-Tuning Script #21764
Conversation
The documentation is not available anymore as the PR was closed or merged.
Force-pushed from f3a2c44 to 13b6487
@sanchit-gandhi @andyehrenberg We have made a version of this script that supports streaming and training on TPU pods. The current version of the script is available here: We are, however, struggling with a bug at the moment. The script seems to work for training the Tiny models on multiple pod sizes, both when scaling for speed and when increasing the batch size. All the other model sizes (small, base, medium, large) also work on a single TPU v4-8. However, training on the non-Tiny model sizes runs for a few steps and then freezes. If anyone has any idea why this could be happening, I would really appreciate it.
Force-pushed from 6470651 to 575c7fd
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread. Please note that issues that do not follow the contributing guidelines are likely to be ignored.
Force-pushed from f6fc0fb to bca0ad6
Given the popularity of the PyTorch fine-tuning script and Whisper JAX, a Whisper fine-tuning script in JAX/Flax is a pretty easy addition. Note: this is largely based on the distil-whisper training script, but simplified to run offline, with just one training dataset and the cross-entropy objective: https://github.com/huggingface/distil-whisper#training
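For context, the training objective mentioned above is the standard token-level cross-entropy over the decoder outputs, with padded label positions masked out (cf. the "fix: label padding tokens" commits further down). A minimal JAX sketch, with illustrative names and shapes that are not taken verbatim from the script:

```python
import jax
import optax


def cross_entropy_loss(logits, labels, label_mask):
    """Masked token-level cross-entropy for seq2seq ASR fine-tuning.

    logits:     [batch, seq_len, vocab_size] decoder output scores
    labels:     [batch, seq_len]             target token ids
    label_mask: [batch, seq_len]             1.0 for real tokens, 0.0 for padding
    """
    per_token = optax.softmax_cross_entropy(
        logits, jax.nn.one_hot(labels, logits.shape[-1])
    )
    # Average only over non-padded positions so label padding
    # does not dilute the loss.
    return (per_token * label_mask).sum() / label_mask.sum()
```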
Thanks a lot! Looks great 😉
* from seq2seq speech
* [Flax] Example script for speech seq2seq
* tests and fixes
* make style
* fix: label padding tokens
* fix: label padding tokens over list
* update ln names for Whisper
* try datasets iter loader
* create readme and append results
* style
* make style
* adjust lr
* use pt dataloader
* make fast
* pin gen max len
* finish
* add pt to requirements for test
* fix pt -> torch
* add accelerate
What does this PR do?
Can be used to fine-tune Flax Whisper for speech recognition.
Tested and verified as working with the following (dummy) config:
Will add a README with preliminary training configs / results later this week after doing a full fine-tuning run.
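As a rough illustration of how a finished checkpoint could be sanity-checked in Flax (the checkpoint path is a placeholder and the dummy LibriSpeech split is only an assumption for a quick smoke test, not a config from this PR):

```python
from datasets import load_dataset
from transformers import FlaxWhisperForConditionalGeneration, WhisperProcessor

# Hypothetical output directory of a fine-tuning run.
checkpoint = "./whisper-small-finetuned"
processor = WhisperProcessor.from_pretrained(checkpoint)
model = FlaxWhisperForConditionalGeneration.from_pretrained(checkpoint)

# A single utterance from a tiny public test split, purely as a smoke test.
sample = load_dataset(
    "hf-internal-testing/librispeech_asr_dummy", "clean", split="validation"
)[0]["audio"]
inputs = processor(
    sample["array"], sampling_rate=sample["sampling_rate"], return_tensors="np"
)

outputs = model.generate(inputs.input_features, max_length=128)
print(processor.batch_decode(outputs.sequences, skip_special_tokens=True))
```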
cc @peregilk @andyehrenberg for interest