- The current preprint, accepted at Interspeech 2024, is available on arXiv
- Code for the old 2023 preprint is in the 2023-preprint branch
- As this repo is a work in progress, if you cannot figure out how to use anything, please feel free to contact me by creating an issue!
- Requires PyTorch 2.0 or greater
- For best performance, install Flash Attention 2.0 (https://github.com/Dao-AILab/flash-attention) and fused_dense_lib for fused MLP layers (https://github.com/Dao-AILab/flash-attention/tree/main/csrc/fused_dense_lib)
- Apex is used for fused RMSNorm/LayerNorm (and fused Adam if not using MADGRAD); if it is not installed, we default to the PyTorch implementations
- Clone the repo and run `pip install ./`
- Here is a colab that outlines basic installation and demonstrates how to load and transcribe with a pretrained model (a rough sketch of the inference flow is shown below)!
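If you just want a feel for the inference flow without opening the colab, the sketch below shows the rough shape. Only `torch.load` is a real call here; the helper names (`load_model`, `load_spectrogram`, `transcribe`) and the checkpoint keys are assumptions for illustration, not the repo's actual API, so check the colab for the real calls.

```python
import torch

# Rough sketch of inference with a pretrained checkpoint. The commented-out
# helpers and the checkpoint keys are assumptions - see the colab for the
# actual API of this repo.
checkpoint = torch.load("path/to/checkpoint.pt", map_location="cpu")
config = checkpoint["config"]  # assumed key: the config is stored inside the checkpoint

# model = load_model(config)                  # hypothetical: build the model from its config
# model.load_state_dict(checkpoint["model"])  # assumed key for the weights
# model.eval()

# spec = load_spectrogram("example.wav")      # hypothetical: precomputed log-mel spectrogram
# with torch.no_grad():
#     print(transcribe(model, spec))          # hypothetical decoding helper
```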
- For training models you must request access to receive the Spotify training data, which can be done via the following link (unfortunately Spotify is no longer maintaining this dataset)
- For training models this code will work with any set of data where you have unsegmented (not segmented into utterances) precomputed spectrograms and corresponding transcriptions with word-level alignment. Word-level alignment is needed to be able to chunk files into arbitrary sequence lengths (a rough sketch of this chunking is shown after this list)
- Earnings-22 and Tedlium datasets can be found via the following links: Earnings-22 Tedlium
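To make the word-level alignment requirement concrete, here is a minimal sketch of how a long recording could be cut into chunks at word boundaries. The field names, the 100 frames-per-second rate, and the function itself are assumptions for illustration; this is not the repo's actual dataloader.

```python
import torch

FRAMES_PER_SECOND = 100  # assumed 10 ms hop, consistent with 1024 frames ≈ 10.24 s

def chunk_recording(spectrogram: torch.Tensor, words: list, max_frames: int = 1024):
    """Cut an unsegmented (n_mels, T) spectrogram and its word-aligned transcript
    into chunks of roughly max_frames, only ever splitting at word boundaries.
    `words` is assumed to be [{"word": str, "start": seconds, "end": seconds}, ...]."""
    chunks, current, chunk_start = [], [], 0.0
    for w in words:
        start_frame = int(chunk_start * FRAMES_PER_SECOND)
        end_frame = int(w["end"] * FRAMES_PER_SECOND)
        if current and end_frame - start_frame > max_frames:
            # close the current chunk at the last fully-contained word
            last_end = int(current[-1]["end"] * FRAMES_PER_SECOND)
            chunks.append((spectrogram[:, start_frame:last_end],
                           " ".join(x["word"] for x in current)))
            chunk_start, current = current[-1]["end"], []
        current.append(w)
    if current:  # flush whatever is left at the end of the recording
        start_frame = int(chunk_start * FRAMES_PER_SECOND)
        chunks.append((spectrogram[:, start_frame:],
                       " ".join(x["word"] for x in current)))
    return chunks
```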
Config files for all pretrained models are provided within the checkpoint files
All checkpoints from the paper are currently hosted on Hugging Face. The table below provides links for each model type as well as its configuration and performance. Checkpoints for each sequence length and repeat are contained in folders inside each linked repository; for example, a model with 10s of context and repeat 1 (out of 3) would be in the folder n_seq_sched_1024_rp_1, where 1024 is the spectrogram length (1024 / 100 = 10.24 seconds) and rp_1 means repeat 1.
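As a quick illustration of this naming scheme, here is a hypothetical helper (not a utility from the repo) that decodes a folder name into frames, seconds and repeat:

```python
def parse_checkpoint_folder(name: str) -> dict:
    """Decode folder names like 'n_seq_sched_1024_rp_1'.
    Assumes 100 spectrogram frames per second, as described above."""
    parts = name.split("_")
    frames = int(parts[parts.index("sched") + 1])  # e.g. 1024
    repeat = int(parts[-1])                        # e.g. 1
    return {"frames": frames, "seconds": frames / 100, "repeat": repeat}

print(parse_checkpoint_folder("n_seq_sched_1024_rp_1"))
# {'frames': 1024, 'seconds': 10.24, 'repeat': 1}
```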
WERs in the table are given for 10s/2.7min/20min context lengths
All models from this table/paper use this model class. If anything is unclear, let me know!
Download | D_model | Layers | Params (M) | Attn head dim | Epochs | Pos Enc | SpecAugment | Subsampling | Tedlium (WER) | Earnings-22 (WER) |
---|---|---|---|---|---|---|---|---|---|---|
here | 768 | 9 | ~120 | 128 | 1 | Rotary (\theta=1.5M) | No | 8X Depthwise 256D | 6.8/6.0/5.9 | 26.6/23.1/22.7 |
here | 768 | 6 | 90 | 128 | 1 | Rotary (\theta=1.5M) | No | 8X Depthwise 256D | 6.8/6.4/6.2 | 27.7/24.6/24.4 |
here | 768 | 6 | 90 | 64 | 1 | Rotary (\theta=1.5M) | No | 8X Depthwise 256D | ... | 27.5/24.8/24.4 |
here | 768 | 6 | 90 | 32 | 1 | Rotary (\theta=1.5M) | No | 8X Depthwise 256D | 6.8/6.4/6.4 | 26.7/24.6/24.8 |
see 2023-preprint branch | 768 | 6 | 90 | 128 | 1 | Rotary (\theta=10K) | No | 8X Depthwise 256D | ... | 27.2/24.9/25.0* |
here | 768 | 6 | 90 | 128 | 1 | Sine | No | 8X Depthwise 256D | 7.2/6.7/6.6 | 27.8/25.3/25.3 |
here | 768 | 6 | 90 | 128 | 1 | None | No | 8X Depthwise 256D | 7.7/6.8/6.6 | 27.5/25.3/25.2 |
here | 2048 | 3 | 315 | 128 | 1 | Rotary (\theta=1.5M) | No | 8X Depthwise 256D | ... | 28.7/26.1/26.1 |
here | 768 | 3 | ~50 | 128 | 1 | Rotary (\theta=1.5M) | No | 8X Depthwise 256D | 8.2/7.8/7.4 | 32.3/29.6/30.2 |
here | 256 | 12 | ~20 | 32 | 1 | Rotary (\theta=1.5M) | No | 8X Depthwise 256D | 7.6/6.9/6.9 | 28.6/26.3/26.4 |
here | 256 | 6 | ~10 | 32 | 1 | Rotary (\theta=1.5M) | No | 8X Depthwise 256D | 8.8/8.0/8.2 | 32.2/29.8/29.9 |
*measured at 1 hour of context