diff --git a/README.md b/README.md
index 0077d99..810d1a2 100644
--- a/README.md
+++ b/README.md
@@ -32,17 +32,17 @@ pip install --upgrade --no-deps --force-reinstall git+https://github.com/sanchit
## Pipeline Usage
-The recommended way of running Whisper JAX is through the [`FlaxWhisperPipline`](https://github.com/sanchit-gandhi/whisper-jax/blob/main/whisper_jax/pipeline.py#L57) abstraction class. This class handles all
+The recommended way of running Whisper JAX is through the [`FlaxWhisperPipeline`](https://github.com/sanchit-gandhi/whisper-jax/blob/main/whisper_jax/pipeline.py#L57) abstraction class. This class handles all
the necessary pre- and post-processing, as well as wrapping the generate method for data parallelism across accelerator devices.
Whisper JAX makes use of JAX's [`pmap`](https://jax.readthedocs.io/en/latest/_autosummary/jax.pmap.html) function for data parallelism across GPU/TPU devices. This function is _Just In Time (JIT)_
compiled the first time it is called. Thereafter, the function will be _cached_, enabling it to be run in super-fast time:
```python
-from whisper_jax import FlaxWhisperPipline
+from whisper_jax import FlaxWhisperPipeline
# instantiate pipeline
-pipeline = FlaxWhisperPipline("openai/whisper-large-v2")
+pipeline = FlaxWhisperPipeline("openai/whisper-large-v2")
# JIT compile the forward call - slow, but we only do it once
text = pipeline("audio.mp3")
@@ -59,11 +59,11 @@ of the model weights.
For most GPUs, the dtype should be set to `jnp.float16`. For A100 GPUs or TPUs, the dtype should be set to `jnp.bfloat16`:
```python
-from whisper_jax import FlaxWhisperPipline
+from whisper_jax import FlaxWhisperPipeline
import jax.numpy as jnp
# instantiate pipeline in bfloat16
-pipeline = FlaxWhisperPipline("openai/whisper-large-v2", dtype=jnp.bfloat16)
+pipeline = FlaxWhisperPipeline("openai/whisper-large-v2", dtype=jnp.bfloat16)
```
### Batching
@@ -75,10 +75,10 @@ provides a 10x speed-up compared to transcribing the audio samples sequentially,
To enable batching, pass the `batch_size` parameter when you instantiate the pipeline:
```python
-from whisper_jax import FlaxWhisperPipline
+from whisper_jax import FlaxWhisperPipeline
# instantiate pipeline with batching
-pipeline = FlaxWhisperPipline("openai/whisper-large-v2", batch_size=16)
+pipeline = FlaxWhisperPipeline("openai/whisper-large-v2", batch_size=16)
```
### Task
@@ -93,7 +93,7 @@ text = pipeline("audio.mp3", task="translate")
### Timestamps
-The [`FlaxWhisperPipline`](https://github.com/sanchit-gandhi/whisper-jax/blob/main/whisper_jax/pipeline.py#L57) also supports timestamp prediction. Note that enabling timestamps will require a second JIT compilation of the
+The [`FlaxWhisperPipeline`](https://github.com/sanchit-gandhi/whisper-jax/blob/main/whisper_jax/pipeline.py#L57) also supports timestamp prediction. Note that enabling timestamps will require a second JIT compilation of the
forward call, this time including the timestamp outputs:
```python
@@ -108,11 +108,11 @@ In the following code snippet, we instantiate the model in bfloat16 precision wi
returning timestamp tokens:
```python
-from whisper_jax import FlaxWhisperPipline
+from whisper_jax import FlaxWhisperPipeline
import jax.numpy as jnp
# instantiate pipeline with bfloat16 and enable batching
-pipeline = FlaxWhisperPipline("openai/whisper-large-v2", dtype=jnp.bfloat16, batch_size=16)
+pipeline = FlaxWhisperPipeline("openai/whisper-large-v2", dtype=jnp.bfloat16, batch_size=16)
# transcribe and return timestamps
outputs = pipeline("audio.mp3", task="transcribe", return_timestamps=True)
@@ -188,7 +188,7 @@ the next time they are required. Note that converting weights from PyTorch to Fl
For example, to convert the fine-tuned checkpoint [`sanchit-gandhi/whisper-small-hi`](https://huggingface.co/sanchit-gandhi/whisper-small-hi) from the blog post [Fine-Tuning Whisper](https://huggingface.co/blog/fine-tune-whisper):
```python
-from whisper_jax import FlaxWhisperForConditionalGeneration, FlaxWhisperPipline
+from whisper_jax import FlaxWhisperForConditionalGeneration, FlaxWhisperPipeline
import jax.numpy as jnp
checkpoint_id = "sanchit-gandhi/whisper-small-hi"
@@ -198,7 +198,7 @@ model = FlaxWhisperForConditionalGeneration.from_pretrained(checkpoint_id, from_
model.push_to_hub(checkpoint_id)
# now we can load the Flax weights directly as required
-pipeline = FlaxWhisperPipline(checkpoint_id, dtype=jnp.bfloat16, batch_size=16)
+pipeline = FlaxWhisperPipeline(checkpoint_id, dtype=jnp.bfloat16, batch_size=16)
```
## Advanced Usage
@@ -212,7 +212,7 @@ The following code snippet demonstrates how data parallelism can be achieved usi
an entirely equivalent way to `pmap`:
```python
-from whisper_jax import FlaxWhisperPipline
+from whisper_jax import FlaxWhisperPipeline
import jax.numpy as jnp
# 2D parameter and activation partitioning for DP
@@ -230,7 +230,7 @@ logical_axis_rules_dp = (
("channels", None),
)
-pipeline = FlaxWhisperPipline("openai/whisper-large-v2", dtype=jnp.bfloat16, batch_size=16)
+pipeline = FlaxWhisperPipeline("openai/whisper-large-v2", dtype=jnp.bfloat16, batch_size=16)
pipeline.shard_params(num_mp_partitions=1, logical_axis_rules=logical_axis_rules_dp)
```
diff --git a/app/app.py b/app/app.py
index dc762d5..d2ffde1 100644
--- a/app/app.py
+++ b/app/app.py
@@ -13,7 +13,7 @@
from transformers.models.whisper.tokenization_whisper import TO_LANGUAGE_CODE
from transformers.pipelines.audio_utils import ffmpeg_read
-from whisper_jax import FlaxWhisperPipline
+from whisper_jax import FlaxWhisperPipeline
cc.initialize_cache("./jax_cache")
@@ -73,7 +73,7 @@ def format_timestamp(seconds: float, always_include_hours: bool = False, decimal
if __name__ == "__main__":
- pipeline = FlaxWhisperPipline(checkpoint, dtype=jnp.bfloat16, batch_size=BATCH_SIZE)
+ pipeline = FlaxWhisperPipeline(checkpoint, dtype=jnp.bfloat16, batch_size=BATCH_SIZE)
stride_length_s = CHUNK_LENGTH_S / 6
chunk_len = round(CHUNK_LENGTH_S * pipeline.feature_extractor.sampling_rate)
stride_left = stride_right = round(stride_length_s * pipeline.feature_extractor.sampling_rate)
diff --git a/whisper-jax-tpu.ipynb b/whisper-jax-tpu.ipynb
index 54db980..cba4339 100644
--- a/whisper-jax-tpu.ipynb
+++ b/whisper-jax-tpu.ipynb
@@ -1 +1 @@
-{"metadata":{"kernelspec":{"language":"python","display_name":"Python 3","name":"python3"},"language_info":{"name":"python","version":"3.8.16","mimetype":"text/x-python","codemirror_mode":{"name":"ipython","version":3},"pygments_lexer":"ipython3","nbconvert_exporter":"python","file_extension":".py"}},"nbformat_minor":4,"nbformat":4,"cells":[{"source":"","metadata":{},"cell_type":"markdown"},{"cell_type":"markdown","source":"## Whisper JAX ⚡️\n\nThis Kaggle notebook demonstratese how to run Whisper JAX on a TPU v3-8. Whisper JAX is a highly optimised JAX implementation of the Whisper model by OpenAI, largely built on the 🤗 Hugging Face Transformers Whisper implementation. Compared to OpenAI's PyTorch code, Whisper JAX runs over **70x faster**, making it the fastest Whisper implementation available.\n\nThe Whisper JAX model is also running as a [demo](https://huggingface.co/spaces/sanchit-gandhi/whisper-jax) on the Hugging Face Hub. You can find the code [here](https://github.com/sanchit-gandhi/whisper-jax).","metadata":{}},{"cell_type":"markdown","source":"## Let's get started!\n\nThe first thing we need to do is connect to a TPU. Kaggle offers 20 hours of TPU v3-8 usage per month for free, which we'll make use of for this notebook. Refer to the guide [Introducing TPUs to Kaggle](https://www.kaggle.com/product-feedback/129828) for more information on TPU quotas in Kaggle.\n\nYou will need to register a Kaggle account and verify your phone number if you haven't done so already. Once verified, open up the settings menu in the Notebook editor (the small arrow in the bottom right). Then under _Notebook options_, select ‘TPU VM v3-8’ from the _Accelerator_ menu. You will also need to toggle the internet switch so that it is set to \"on\".\n\nOnce we've got a TPU allocated (there might be a queue to get one!), we can run the following to see the TPU devices we have available:","metadata":{}},{"cell_type":"code","source":"import jax\njax.devices()","metadata":{"execution":{"iopub.status.busy":"2023-04-25T15:05:34.178195Z","iopub.execute_input":"2023-04-25T15:05:34.178923Z","iopub.status.idle":"2023-04-25T15:05:41.356736Z","shell.execute_reply.started":"2023-04-25T15:05:34.178889Z","shell.execute_reply":"2023-04-25T15:05:41.355848Z"},"trusted":true},"execution_count":1,"outputs":[{"execution_count":1,"output_type":"execute_result","data":{"text/plain":"[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0),\n TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1),\n TpuDevice(id=2, process_index=0, coords=(1,0,0), core_on_chip=0),\n TpuDevice(id=3, process_index=0, coords=(1,0,0), core_on_chip=1),\n TpuDevice(id=4, process_index=0, coords=(0,1,0), core_on_chip=0),\n TpuDevice(id=5, process_index=0, coords=(0,1,0), core_on_chip=1),\n TpuDevice(id=6, process_index=0, coords=(1,1,0), core_on_chip=0),\n TpuDevice(id=7, process_index=0, coords=(1,1,0), core_on_chip=1)]"},"metadata":{}}]},{"cell_type":"markdown","source":"Cool! We've got 8 TPU devices packaged into one accelerator.\n\nKaggle TPUs come with JAX pre-installed, so we can directly install the remaining Python packages. If you're running the notebook on a Cloud TPU, ensure you have installed JAX according to the official [installation guide](https://github.com/google/jax#pip-installation-google-cloud-tpu). 
\n\nWe'll install [Whisper JAX](https://github.com/sanchit-gandhi/whisper-jax) from main, as well as `datasets`, `soundfile` and `librosa` for loading audio files:","metadata":{}},{"cell_type":"code","source":"!pip install --quiet --upgrade pip\n!pip install --quiet git+https://github.com/sanchit-gandhi/whisper-jax.git datasets soundfile librosa","metadata":{"_cell_guid":"b1076dfc-b9ad-4769-8c92-a6c4dae69d19","_uuid":"8f2839f25d086af736a60e9eeb907d3b93b6e0e5","tags":[],"execution":{"iopub.status.busy":"2023-04-25T15:05:41.358251Z","iopub.execute_input":"2023-04-25T15:05:41.358603Z","iopub.status.idle":"2023-04-25T15:06:26.680298Z","shell.execute_reply.started":"2023-04-25T15:05:41.358577Z","shell.execute_reply":"2023-04-25T15:06:26.678988Z"},"trusted":true},"execution_count":2,"outputs":[{"name":"stdout","text":"\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n\u001b[0m\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n\u001b[0m","output_type":"stream"}]},{"cell_type":"markdown","source":"## Loading the Pipeline\n\nThe recommended way of running Whisper JAX is through the [`FlaxWhisperPipline`](https://github.com/sanchit-gandhi/whisper-jax/blob/main/whisper_jax/pipeline.py#L57) class. This class handles all the necessary pre- and post-processing for the model, as well as wrapping the generate method for data parallelism across all available accelerator devices.\n\nWhisper JAX makes use of JAX's [`pmap`](https://jax.readthedocs.io/en/latest/_autosummary/jax.pmap.html) function for data parallelism across GPU/TPU devices. This function is Just In Time (JIT) compiled the first time it is called. Thereafter, the function will be cached, enabling it to be run in super-fast time.\n\n\nLet's load the large-v2 model in bfloat16 (half-precision). Using half-precision will speed-up the computation quite considerably by storing intermediate tensors in half-precision. There is no change to the precision of the model weights.\n\nWe'll also make use of _batching_ for single audio inputs: the audio is first chunked into 30 second segments, and then chunks dispatched to the model to be transcribed in parallel. By batching an audio input and transcribing it in parallel, we get a ~10x speed-up compared to transcribing the audio samples sequentially.","metadata":{}},{"cell_type":"code","source":"from whisper_jax import FlaxWhisperPipline\nimport jax.numpy as jnp\n\npipeline = FlaxWhisperPipline(\"openai/whisper-large-v2\", dtype=jnp.bfloat16, batch_size=16)","metadata":{"execution":{"iopub.status.busy":"2023-04-25T15:06:26.681714Z","iopub.execute_input":"2023-04-25T15:06:26.682018Z","iopub.status.idle":"2023-04-25T15:07:53.135957Z","shell.execute_reply.started":"2023-04-25T15:06:26.681991Z","shell.execute_reply":"2023-04-25T15:07:53.134715Z"},"trusted":true},"execution_count":3,"outputs":[{"name":"stderr","text":"/usr/local/lib/python3.8/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. 
See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n from .autonotebook import tqdm as notebook_tqdm\nDownloading (…)rocessor_config.json: 100%|██████████| 185k/185k [00:00<00:00, 5.29MB/s]\nDownloading (…)okenizer_config.json: 100%|██████████| 800/800 [00:00<00:00, 120kB/s]\nDownloading (…)olve/main/vocab.json: 100%|██████████| 836k/836k [00:00<00:00, 15.9MB/s]\nDownloading (…)/main/tokenizer.json: 100%|██████████| 2.20M/2.20M [00:00<00:00, 39.1MB/s]\nDownloading (…)olve/main/merges.txt: 100%|██████████| 494k/494k [00:00<00:00, 28.1MB/s]\nDownloading (…)main/normalizer.json: 100%|██████████| 52.7k/52.7k [00:00<00:00, 24.5MB/s]\nDownloading (…)in/added_tokens.json: 100%|██████████| 2.08k/2.08k [00:00<00:00, 1.11MB/s]\nDownloading (…)cial_tokens_map.json: 100%|██████████| 2.08k/2.08k [00:00<00:00, 1.22MB/s]\nDownloading (…)lve/main/config.json: 100%|██████████| 1.99k/1.99k [00:00<00:00, 320kB/s]\nDownloading flax_model.msgpack: 100%|██████████| 6.17G/6.17G [00:30<00:00, 203MB/s] \nDownloading (…)neration_config.json: 100%|██████████| 3.51k/3.51k [00:00<00:00, 1.99MB/s]\n","output_type":"stream"}]},{"cell_type":"markdown","source":"We'll then initialise a compilation cache, which will speed-up the compilation time if we close our kernel and want to compile the model again:","metadata":{}},{"cell_type":"code","source":"from jax.experimental.compilation_cache import compilation_cache as cc\n\ncc.initialize_cache(\"./jax_cache\")","metadata":{"tags":[],"execution":{"iopub.status.busy":"2023-04-25T15:07:53.139631Z","iopub.execute_input":"2023-04-25T15:07:53.139958Z","iopub.status.idle":"2023-04-25T15:07:53.146144Z","shell.execute_reply.started":"2023-04-25T15:07:53.139931Z","shell.execute_reply":"2023-04-25T15:07:53.145279Z"},"trusted":true},"execution_count":4,"outputs":[{"name":"stderr","text":"WARNING:jax.experimental.compilation_cache.compilation_cache:Initialized persistent compilation cache at ./jax_cache\n","output_type":"stream"}]},{"cell_type":"markdown","source":"## 🎶 Load an audio file\n\nLet's load up a long audio file for our tests. We provide 5 and 30 mins audio files created by contatenating consecutive sample of the [LibriSpeech ASR](https://huggingface.co/datasets/librispeech_asr) corpus, which we can load in one line through Hugging Face Datastes' [`load_dataset`](https://huggingface.co/docs/datasets/loading#load) function. 
Note that you can also pass in any `.mp3`, `.wav` or `.flac` audio file directly to the Whisper JAX pipeline, and it will take care of loading the audio file for you.","metadata":{}},{"cell_type":"code","source":"from datasets import load_dataset\n\ntest_dataset = load_dataset(\"sanchit-gandhi/whisper-jax-test-files\", split=\"train\")\naudio = test_dataset[0][\"audio\"] # load the first sample (5 mins) and get the audio array","metadata":{"execution":{"iopub.status.busy":"2023-04-25T15:07:53.147081Z","iopub.execute_input":"2023-04-25T15:07:53.147354Z","iopub.status.idle":"2023-04-25T15:08:17.685945Z","shell.execute_reply.started":"2023-04-25T15:07:53.147331Z","shell.execute_reply":"2023-04-25T15:08:17.684706Z"},"trusted":true},"execution_count":5,"outputs":[{"name":"stderr","text":"Downloading readme: 100%|██████████| 371/371 [00:00<00:00, 188kB/s]\n","output_type":"stream"},{"name":"stdout","text":"Downloading and preparing dataset None/None to /root/.cache/huggingface/datasets/sanchit-gandhi___parquet/sanchit-gandhi--whisper-jax-test-files-95479fe55e88baac/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec...\n","output_type":"stream"},{"name":"stderr","text":"Downloading data files: 0%| | 0/1 [00:00, ?it/s]\nDownloading data: 0%| | 0.00/113M [00:00, ?B/s]\u001b[A\nDownloading data: 1%|▏ | 1.63M/113M [00:00<00:06, 16.3MB/s]\u001b[A\nDownloading data: 6%|▌ | 6.39M/113M [00:00<00:03, 34.7MB/s]\u001b[A\nDownloading data: 10%|▉ | 11.3M/113M [00:00<00:02, 41.2MB/s]\u001b[A\nDownloading data: 14%|█▍ | 16.1M/113M [00:00<00:02, 43.8MB/s]\u001b[A\nDownloading data: 18%|█▊ | 20.8M/113M [00:00<00:02, 45.0MB/s]\u001b[A\nDownloading data: 23%|██▎ | 25.8M/113M [00:00<00:01, 46.7MB/s]\u001b[A\nDownloading data: 27%|██▋ | 30.9M/113M [00:00<00:01, 48.1MB/s]\u001b[A\nDownloading data: 31%|███▏ | 35.7M/113M [00:00<00:02, 38.7MB/s]\u001b[A\nDownloading data: 36%|███▌ | 40.6M/113M [00:00<00:01, 41.5MB/s]\u001b[A\nDownloading data: 40%|████ | 45.6M/113M [00:01<00:01, 44.0MB/s]\u001b[A\nDownloading data: 45%|████▍ | 50.7M/113M [00:01<00:01, 45.8MB/s]\u001b[A\nDownloading data: 49%|████▉ | 55.6M/113M [00:01<00:01, 46.9MB/s]\u001b[A\nDownloading data: 53%|█████▎ | 60.4M/113M [00:01<00:01, 47.3MB/s]\u001b[A\nDownloading data: 58%|█████▊ | 65.4M/113M [00:01<00:00, 48.1MB/s]\u001b[A\nDownloading data: 62%|██████▏ | 70.5M/113M [00:01<00:00, 48.7MB/s]\u001b[A\nDownloading data: 67%|██████▋ | 75.5M/113M [00:01<00:00, 49.2MB/s]\u001b[A\nDownloading data: 71%|███████ | 80.5M/113M [00:01<00:00, 49.3MB/s]\u001b[A\nDownloading data: 75%|███████▌ | 85.5M/113M [00:01<00:00, 49.6MB/s]\u001b[A\nDownloading data: 80%|███████▉ | 90.6M/113M [00:01<00:00, 50.0MB/s]\u001b[A\nDownloading data: 84%|████████▍ | 95.7M/113M [00:02<00:00, 50.2MB/s]\u001b[A\nDownloading data: 89%|████████▉ | 101M/113M [00:02<00:00, 50.3MB/s] \u001b[A\nDownloading data: 93%|█████████▎| 106M/113M [00:02<00:00, 50.5MB/s]\u001b[A\nDownloading data: 100%|██████████| 113M/113M [00:02<00:00, 46.7MB/s]\u001b[A\nDownloading data files: 100%|██████████| 1/1 [00:03<00:00, 3.04s/it]\nExtracting data files: 100%|██████████| 1/1 [00:00<00:00, 880.42it/s]\n \r","output_type":"stream"},{"name":"stdout","text":"Dataset parquet downloaded and prepared to /root/.cache/huggingface/datasets/sanchit-gandhi___parquet/sanchit-gandhi--whisper-jax-test-files-95479fe55e88baac/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec. 
Subsequent calls will reuse this data.\n","output_type":"stream"}]},{"cell_type":"markdown","source":"We can take a listen to the audio file that we've loaded - we'll see that it's approximately 5 mins long:","metadata":{}},{"cell_type":"code","source":"from IPython.display import Audio\n\nAudio(audio[\"array\"], rate=audio[\"sampling_rate\"])","metadata":{"execution":{"iopub.status.busy":"2023-04-25T15:08:17.687425Z","iopub.execute_input":"2023-04-25T15:08:17.688181Z","iopub.status.idle":"2023-04-25T15:08:17.898857Z","shell.execute_reply.started":"2023-04-25T15:08:17.68815Z","shell.execute_reply":"2023-04-25T15:08:17.897795Z"},"trusted":true},"execution_count":6,"outputs":[{"execution_count":6,"output_type":"execute_result","data":{"text/plain":"","text/html":"\n