From 7cae0ae46a4b3418b8b21074efcd744bf0bb40be Mon Sep 17 00:00:00 2001 From: Alex McKinney Date: Wed, 21 Jun 2023 16:51:23 +0100 Subject: [PATCH] Rename `FlaxWhisperPipline` -> `FlaxWhisperPipeline` --- README.md | 28 ++++++++++++++-------------- app/app.py | 4 ++-- whisper-jax-tpu.ipynb | 2 +- whisper_jax/__init__.py | 2 +- whisper_jax/pipeline.py | 2 +- 5 files changed, 19 insertions(+), 19 deletions(-) diff --git a/README.md b/README.md index 0077d99..810d1a2 100644 --- a/README.md +++ b/README.md @@ -32,17 +32,17 @@ pip install --upgrade --no-deps --force-reinstall git+https://github.com/sanchit ## Pipeline Usage -The recommended way of running Whisper JAX is through the [`FlaxWhisperPipline`](https://github.com/sanchit-gandhi/whisper-jax/blob/main/whisper_jax/pipeline.py#L57) abstraction class. This class handles all +The recommended way of running Whisper JAX is through the [`FlaxWhisperPipeline`](https://github.com/sanchit-gandhi/whisper-jax/blob/main/whisper_jax/pipeline.py#L57) abstraction class. This class handles all the necessary pre- and post-processing, as well as wrapping the generate method for data parallelism across accelerator devices. Whisper JAX makes use of JAX's [`pmap`](https://jax.readthedocs.io/en/latest/_autosummary/jax.pmap.html) function for data parallelism across GPU/TPU devices. This function is _Just In Time (JIT)_ compiled the first time it is called. Thereafter, the function will be _cached_, enabling it to be run in super-fast time: ```python -from whisper_jax import FlaxWhisperPipline +from whisper_jax import FlaxWhisperPipeline # instantiate pipeline -pipeline = FlaxWhisperPipline("openai/whisper-large-v2") +pipeline = FlaxWhisperPipeline("openai/whisper-large-v2") # JIT compile the forward call - slow, but we only do once text = pipeline("audio.mp3") @@ -59,11 +59,11 @@ of the model weights. For most GPUs, the dtype should be set to `jnp.float16`. For A100 GPUs or TPUs, the dtype should be set to `jnp.bfloat16`: ```python -from whisper_jax import FlaxWhisperPipline +from whisper_jax import FlaxWhisperPipeline import jax.numpy as jnp # instantiate pipeline in bfloat16 -pipeline = FlaxWhisperPipline("openai/whisper-large-v2", dtype=jnp.bfloat16) +pipeline = FlaxWhisperPipeline("openai/whisper-large-v2", dtype=jnp.bfloat16) ``` ### Batching @@ -75,10 +75,10 @@ provides a 10x speed-up compared to transcribing the audio samples sequentially, To enable batching, pass the `batch_size` parameter when you instantiate the pipeline: ```python -from whisper_jax import FlaxWhisperPipline +from whisper_jax import FlaxWhisperPipeline # instantiate pipeline with batching -pipeline = FlaxWhisperPipline("openai/whisper-large-v2", batch_size=16) +pipeline = FlaxWhisperPipeline("openai/whisper-large-v2", batch_size=16) ``` ### Task @@ -93,7 +93,7 @@ text = pipeline("audio.mp3", task="translate") ### Timestamps -The [`FlaxWhisperPipline`](https://github.com/sanchit-gandhi/whisper-jax/blob/main/whisper_jax/pipeline.py#L57) also supports timestamp prediction. Note that enabling timestamps will require a second JIT compilation of the +The [`FlaxWhisperPipeline`](https://github.com/sanchit-gandhi/whisper-jax/blob/main/whisper_jax/pipeline.py#L57) also supports timestamp prediction. Note that enabling timestamps will require a second JIT compilation of the forward call, this time including the timestamp outputs: ```python @@ -108,11 +108,11 @@ In the following code snippet, we instantiate the model in bfloat16 precision wi returning timestamps tokens: ```python -from whisper_jax import FlaxWhisperPipline +from whisper_jax import FlaxWhisperPipeline import jax.numpy as jnp # instantiate pipeline with bfloat16 and enable batching -pipeline = FlaxWhisperPipline("openai/whisper-large-v2", dtype=jnp.bfloat16, batch_size=16) +pipeline = FlaxWhisperPipeline("openai/whisper-large-v2", dtype=jnp.bfloat16, batch_size=16) # transcribe and return timestamps outputs = pipeline("audio.mp3", task="transcribe", return_timestamps=True) @@ -188,7 +188,7 @@ the next time they are required. Note that converting weights from PyTorch to Fl For example, to convert the fine-tuned checkpoint [`sanchit-gandhi/whisper-small-hi`](https://huggingface.co/sanchit-gandhi/whisper-small-hi) from the blog post [Fine-Tuning Whisper](https://huggingface.co/blog/fine-tune-whisper): ```python -from whisper_jax import FlaxWhisperForConditionalGeneration, FlaxWhisperPipline +from whisper_jax import FlaxWhisperForConditionalGeneration, FlaxWhisperPipeline import jax.numpy as jnp checkpoint_id = "sanchit-gandhi/whisper-small-hi" @@ -198,7 +198,7 @@ model = FlaxWhisperForConditionalGeneration.from_pretrained(checkpoint_id, from_ model.push_to_hub(checkpoint_id) # now we can load the Flax weights directly as required -pipeline = FlaxWhisperPipline(checkpoint_id, dtype=jnp.bfloat16, batch_size=16) +pipeline = FlaxWhisperPipeline(checkpoint_id, dtype=jnp.bfloat16, batch_size=16) ``` ## Advanced Usage @@ -212,7 +212,7 @@ The following code snippet demonstrates how data parallelism can be achieved usi an entirely equivalent way to `pmap`: ```python -from whisper_jax import FlaxWhisperPipline +from whisper_jax import FlaxWhisperPipeline import jax.numpy as jnp # 2D parameter and activation partitioning for DP @@ -230,7 +230,7 @@ logical_axis_rules_dp = ( ("channels", None), ) -pipeline = FlaxWhisperPipline("openai/whisper-large-v2", dtype=jnp.bfloat16, batch_size=16) +pipeline = FlaxWhisperPipeline("openai/whisper-large-v2", dtype=jnp.bfloat16, batch_size=16) pipeline.shard_params(num_mp_partitions=1, logical_axis_rules=logical_axis_rules_dp) ``` diff --git a/app/app.py b/app/app.py index dc762d5..d2ffde1 100644 --- a/app/app.py +++ b/app/app.py @@ -13,7 +13,7 @@ from transformers.models.whisper.tokenization_whisper import TO_LANGUAGE_CODE from transformers.pipelines.audio_utils import ffmpeg_read -from whisper_jax import FlaxWhisperPipline +from whisper_jax import FlaxWhisperPipeline cc.initialize_cache("./jax_cache") @@ -73,7 +73,7 @@ def format_timestamp(seconds: float, always_include_hours: bool = False, decimal if __name__ == "__main__": - pipeline = FlaxWhisperPipline(checkpoint, dtype=jnp.bfloat16, batch_size=BATCH_SIZE) + pipeline = FlaxWhisperPipeline(checkpoint, dtype=jnp.bfloat16, batch_size=BATCH_SIZE) stride_length_s = CHUNK_LENGTH_S / 6 chunk_len = round(CHUNK_LENGTH_S * pipeline.feature_extractor.sampling_rate) stride_left = stride_right = round(stride_length_s * pipeline.feature_extractor.sampling_rate) diff --git a/whisper-jax-tpu.ipynb b/whisper-jax-tpu.ipynb index 54db980..cba4339 100644 --- a/whisper-jax-tpu.ipynb +++ b/whisper-jax-tpu.ipynb @@ -1 +1 @@ -{"metadata":{"kernelspec":{"language":"python","display_name":"Python 3","name":"python3"},"language_info":{"name":"python","version":"3.8.16","mimetype":"text/x-python","codemirror_mode":{"name":"ipython","version":3},"pygments_lexer":"ipython3","nbconvert_exporter":"python","file_extension":".py"}},"nbformat_minor":4,"nbformat":4,"cells":[{"source":"\"Kaggle\"","metadata":{},"cell_type":"markdown"},{"cell_type":"markdown","source":"## Whisper JAX ⚡️\n\nThis Kaggle notebook demonstratese how to run Whisper JAX on a TPU v3-8. Whisper JAX is a highly optimised JAX implementation of the Whisper model by OpenAI, largely built on the 🤗 Hugging Face Transformers Whisper implementation. Compared to OpenAI's PyTorch code, Whisper JAX runs over **70x faster**, making it the fastest Whisper implementation available.\n\nThe Whisper JAX model is also running as a [demo](https://huggingface.co/spaces/sanchit-gandhi/whisper-jax) on the Hugging Face Hub. You can find the code [here](https://github.com/sanchit-gandhi/whisper-jax).","metadata":{}},{"cell_type":"markdown","source":"## Let's get started!\n\nThe first thing we need to do is connect to a TPU. Kaggle offers 20 hours of TPU v3-8 usage per month for free, which we'll make use of for this notebook. Refer to the guide [Introducing TPUs to Kaggle](https://www.kaggle.com/product-feedback/129828) for more information on TPU quotas in Kaggle.\n\nYou will need to register a Kaggle account and verify your phone number if you haven't done so already. Once verified, open up the settings menu in the Notebook editor (the small arrow in the bottom right). Then under _Notebook options_, select ‘TPU VM v3-8’ from the _Accelerator_ menu. You will also need to toggle the internet switch so that it is set to \"on\".\n\nOnce we've got a TPU allocated (there might be a queue to get one!), we can run the following to see the TPU devices we have available:","metadata":{}},{"cell_type":"code","source":"import jax\njax.devices()","metadata":{"execution":{"iopub.status.busy":"2023-04-25T15:05:34.178195Z","iopub.execute_input":"2023-04-25T15:05:34.178923Z","iopub.status.idle":"2023-04-25T15:05:41.356736Z","shell.execute_reply.started":"2023-04-25T15:05:34.178889Z","shell.execute_reply":"2023-04-25T15:05:41.355848Z"},"trusted":true},"execution_count":1,"outputs":[{"execution_count":1,"output_type":"execute_result","data":{"text/plain":"[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0),\n TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1),\n TpuDevice(id=2, process_index=0, coords=(1,0,0), core_on_chip=0),\n TpuDevice(id=3, process_index=0, coords=(1,0,0), core_on_chip=1),\n TpuDevice(id=4, process_index=0, coords=(0,1,0), core_on_chip=0),\n TpuDevice(id=5, process_index=0, coords=(0,1,0), core_on_chip=1),\n TpuDevice(id=6, process_index=0, coords=(1,1,0), core_on_chip=0),\n TpuDevice(id=7, process_index=0, coords=(1,1,0), core_on_chip=1)]"},"metadata":{}}]},{"cell_type":"markdown","source":"Cool! We've got 8 TPU devices packaged into one accelerator.\n\nKaggle TPUs come with JAX pre-installed, so we can directly install the remaining Python packages. If you're running the notebook on a Cloud TPU, ensure you have installed JAX according to the official [installation guide](https://github.com/google/jax#pip-installation-google-cloud-tpu). \n\nWe'll install [Whisper JAX](https://github.com/sanchit-gandhi/whisper-jax) from main, as well as `datasets`, `soundfile` and `librosa` for loading audio files:","metadata":{}},{"cell_type":"code","source":"!pip install --quiet --upgrade pip\n!pip install --quiet git+https://github.com/sanchit-gandhi/whisper-jax.git datasets soundfile librosa","metadata":{"_cell_guid":"b1076dfc-b9ad-4769-8c92-a6c4dae69d19","_uuid":"8f2839f25d086af736a60e9eeb907d3b93b6e0e5","tags":[],"execution":{"iopub.status.busy":"2023-04-25T15:05:41.358251Z","iopub.execute_input":"2023-04-25T15:05:41.358603Z","iopub.status.idle":"2023-04-25T15:06:26.680298Z","shell.execute_reply.started":"2023-04-25T15:05:41.358577Z","shell.execute_reply":"2023-04-25T15:06:26.678988Z"},"trusted":true},"execution_count":2,"outputs":[{"name":"stdout","text":"\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n\u001b[0m\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n\u001b[0m","output_type":"stream"}]},{"cell_type":"markdown","source":"## Loading the Pipeline\n\nThe recommended way of running Whisper JAX is through the [`FlaxWhisperPipline`](https://github.com/sanchit-gandhi/whisper-jax/blob/main/whisper_jax/pipeline.py#L57) class. This class handles all the necessary pre- and post-processing for the model, as well as wrapping the generate method for data parallelism across all available accelerator devices.\n\nWhisper JAX makes use of JAX's [`pmap`](https://jax.readthedocs.io/en/latest/_autosummary/jax.pmap.html) function for data parallelism across GPU/TPU devices. This function is Just In Time (JIT) compiled the first time it is called. Thereafter, the function will be cached, enabling it to be run in super-fast time.\n\n\nLet's load the large-v2 model in bfloat16 (half-precision). Using half-precision will speed-up the computation quite considerably by storing intermediate tensors in half-precision. There is no change to the precision of the model weights.\n\nWe'll also make use of _batching_ for single audio inputs: the audio is first chunked into 30 second segments, and then chunks dispatched to the model to be transcribed in parallel. By batching an audio input and transcribing it in parallel, we get a ~10x speed-up compared to transcribing the audio samples sequentially.","metadata":{}},{"cell_type":"code","source":"from whisper_jax import FlaxWhisperPipline\nimport jax.numpy as jnp\n\npipeline = FlaxWhisperPipline(\"openai/whisper-large-v2\", dtype=jnp.bfloat16, batch_size=16)","metadata":{"execution":{"iopub.status.busy":"2023-04-25T15:06:26.681714Z","iopub.execute_input":"2023-04-25T15:06:26.682018Z","iopub.status.idle":"2023-04-25T15:07:53.135957Z","shell.execute_reply.started":"2023-04-25T15:06:26.681991Z","shell.execute_reply":"2023-04-25T15:07:53.134715Z"},"trusted":true},"execution_count":3,"outputs":[{"name":"stderr","text":"/usr/local/lib/python3.8/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n from .autonotebook import tqdm as notebook_tqdm\nDownloading (…)rocessor_config.json: 100%|██████████| 185k/185k [00:00<00:00, 5.29MB/s]\nDownloading (…)okenizer_config.json: 100%|██████████| 800/800 [00:00<00:00, 120kB/s]\nDownloading (…)olve/main/vocab.json: 100%|██████████| 836k/836k [00:00<00:00, 15.9MB/s]\nDownloading (…)/main/tokenizer.json: 100%|██████████| 2.20M/2.20M [00:00<00:00, 39.1MB/s]\nDownloading (…)olve/main/merges.txt: 100%|██████████| 494k/494k [00:00<00:00, 28.1MB/s]\nDownloading (…)main/normalizer.json: 100%|██████████| 52.7k/52.7k [00:00<00:00, 24.5MB/s]\nDownloading (…)in/added_tokens.json: 100%|██████████| 2.08k/2.08k [00:00<00:00, 1.11MB/s]\nDownloading (…)cial_tokens_map.json: 100%|██████████| 2.08k/2.08k [00:00<00:00, 1.22MB/s]\nDownloading (…)lve/main/config.json: 100%|██████████| 1.99k/1.99k [00:00<00:00, 320kB/s]\nDownloading flax_model.msgpack: 100%|██████████| 6.17G/6.17G [00:30<00:00, 203MB/s] \nDownloading (…)neration_config.json: 100%|██████████| 3.51k/3.51k [00:00<00:00, 1.99MB/s]\n","output_type":"stream"}]},{"cell_type":"markdown","source":"We'll then initialise a compilation cache, which will speed-up the compilation time if we close our kernel and want to compile the model again:","metadata":{}},{"cell_type":"code","source":"from jax.experimental.compilation_cache import compilation_cache as cc\n\ncc.initialize_cache(\"./jax_cache\")","metadata":{"tags":[],"execution":{"iopub.status.busy":"2023-04-25T15:07:53.139631Z","iopub.execute_input":"2023-04-25T15:07:53.139958Z","iopub.status.idle":"2023-04-25T15:07:53.146144Z","shell.execute_reply.started":"2023-04-25T15:07:53.139931Z","shell.execute_reply":"2023-04-25T15:07:53.145279Z"},"trusted":true},"execution_count":4,"outputs":[{"name":"stderr","text":"WARNING:jax.experimental.compilation_cache.compilation_cache:Initialized persistent compilation cache at ./jax_cache\n","output_type":"stream"}]},{"cell_type":"markdown","source":"## 🎶 Load an audio file\n\nLet's load up a long audio file for our tests. We provide 5 and 30 mins audio files created by contatenating consecutive sample of the [LibriSpeech ASR](https://huggingface.co/datasets/librispeech_asr) corpus, which we can load in one line through Hugging Face Datastes' [`load_dataset`](https://huggingface.co/docs/datasets/loading#load) function. Note that you can also pass in any `.mp3`, `.wav` or `.flac` audio file directly to the Whisper JAX pipeline, and it will take care of loading the audio file for you.","metadata":{}},{"cell_type":"code","source":"from datasets import load_dataset\n\ntest_dataset = load_dataset(\"sanchit-gandhi/whisper-jax-test-files\", split=\"train\")\naudio = test_dataset[0][\"audio\"] # load the first sample (5 mins) and get the audio array","metadata":{"execution":{"iopub.status.busy":"2023-04-25T15:07:53.147081Z","iopub.execute_input":"2023-04-25T15:07:53.147354Z","iopub.status.idle":"2023-04-25T15:08:17.685945Z","shell.execute_reply.started":"2023-04-25T15:07:53.147331Z","shell.execute_reply":"2023-04-25T15:08:17.684706Z"},"trusted":true},"execution_count":5,"outputs":[{"name":"stderr","text":"Downloading readme: 100%|██████████| 371/371 [00:00<00:00, 188kB/s]\n","output_type":"stream"},{"name":"stdout","text":"Downloading and preparing dataset None/None to /root/.cache/huggingface/datasets/sanchit-gandhi___parquet/sanchit-gandhi--whisper-jax-test-files-95479fe55e88baac/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec...\n","output_type":"stream"},{"name":"stderr","text":"Downloading data files: 0%| | 0/1 [00:00","text/html":"\n \n "},"metadata":{}}]},{"cell_type":"markdown","source":"## Run the model\n\nNow we're ready to transcribe! We'll need to compile the `pmap` function the first time we use it. You can expect compilation to take ~2 minutes on a TPU v3-8 with a batch size of 16. Enough time to grab a coffee ☕️\n\nThereafter, we can use our cached `pmap` function, which you'll see is amazingly fast.","metadata":{}},{"cell_type":"code","source":"# JIT compile the forward call - slow, but we only do once\n%time text = pipeline(audio)","metadata":{"execution":{"iopub.status.busy":"2023-04-25T15:08:17.900179Z","iopub.execute_input":"2023-04-25T15:08:17.900475Z","iopub.status.idle":"2023-04-25T15:10:19.031506Z","shell.execute_reply.started":"2023-04-25T15:08:17.90045Z","shell.execute_reply":"2023-04-25T15:10:19.030345Z"},"trusted":true},"execution_count":7,"outputs":[{"name":"stdout","text":"CPU times: user 3min 21s, sys: 1min 15s, total: 4min 37s\nWall time: 2min 1s\n","output_type":"stream"}]},{"cell_type":"code","source":"# used cached function thereafter - super fast!\n%time text = pipeline(audio)","metadata":{"tags":[],"execution":{"iopub.status.busy":"2023-04-25T15:10:19.033137Z","iopub.execute_input":"2023-04-25T15:10:19.033486Z","iopub.status.idle":"2023-04-25T15:10:24.923769Z","shell.execute_reply.started":"2023-04-25T15:10:19.033437Z","shell.execute_reply":"2023-04-25T15:10:24.922709Z"},"trusted":true},"execution_count":8,"outputs":[{"name":"stdout","text":"CPU times: user 28.5 s, sys: 52.5 s, total: 1min 20s\nWall time: 5.88 s\n","output_type":"stream"}]},{"cell_type":"code","source":"# let's check our transcription - looks spot on!\nprint(text)","metadata":{"tags":[],"execution":{"iopub.status.busy":"2023-04-25T15:10:24.925055Z","iopub.execute_input":"2023-04-25T15:10:24.925367Z","iopub.status.idle":"2023-04-25T15:10:24.93081Z","shell.execute_reply.started":"2023-04-25T15:10:24.925339Z","shell.execute_reply":"2023-04-25T15:10:24.929746Z"},"trusted":true},"execution_count":9,"outputs":[{"name":"stdout","text":"{'text': \" Chapter 16. I might have told you of the beginning of this liaison in a few lines, but I wanted you to see every step by which we came, I to agree to whatever Marguerite wished, Marguerite to be unable to live apart from me. It was the day after the evening when she came to see me that I sent her Manon Lescate. From that time, seeing that I could not change my mistress's life, I changed my own. I wished above all not to leave myself time to think over the position I had accepted, for, in spite of myself, it was a great distress to me. Thus my life, generally so calm, assumed all at once an appearance of noise and disorder. Never believe, however disinterested the love of a kept woman may be, that it will cost one nothing. Nothing is so expensive as their caprices, flowers, boxes at the theatre, suppers, days in the country, which one can never refuse to one's mistress. As I have told you, I had little money. My father was, and still is, Receiver General at sea. He has a great reputation there for loyalty, thanks to which he was able to find the security which he needed in order to attain this position. I came to Paris, studied law, was called to the bar, and, like many other young men, put my diploma in my pocket, and let myself drift, as one so easily does in Paris. My expenses were very moderate, only I used up my year's income in eight months, and spent the four summer months with my father, which practically gave me twelve thousand francs a year, and, in addition, the reputation of a good son. For the rest, not a penny of debt. This, then, was my position when I made the acquaintance of Marguerite. You can well understand that, in spite of myself, my expenses soon increased. Marguerite's nature was very capricious, and, like so many women, she never regarded as a serious expense those thousand and one distractions which made up her life. So, wishing to spend as much time with me as possible, she would write to me in the morning that she would dine with me, not at home, but at some restaurant in Paris, or in the country. I would call for her, and we would dine and go on to the theatre, often having supper as well. Forgive me if I give you all these details, but you will see that they were the cause of what was to follow. What I tell you is a true and simple story, and I leave to it all the naivete of its details, and all the simplicity of its developments. I realized then that as nothing in the world would make me forget my mistress, it was needful for me to find some way of meeting the expenses into which she drew me. Then, too, my love for her had so disturbing an influence upon me, that every moment I spent away from Marguerite was like a year, and that I felt the need of consuming these moments in the fire of some sort of passion, as not to know that I was living them. I began by borrowing five or six thousand francs on my little capital, and with this I took to gambling. Since gambling-houses were destroyed, gambling goes on everywhere. Formerly, when one went to Frascati, one had the chance of making a fortune, one played against money, and if one lost, there was always the consolation of saying that one might have gained. Whereas now, except in the clubs, where there is still a certain rigour in regard to payments, one is almost certain, the moment one gains a considerable sum, not to receive it. You will readily understand why. Gambling is only likely to be carried on by young people very much in need of money, and not possessing the fortune necessary for supporting the life they lead. They gamble, then, and with this result, or else they gain, and then those who lose serve to pay for their horses and mistresses, which is very disagreeable. Debts are contracted, acquaintances begun about a green table, and by quarrels in which life or honour comes to grief, and though one may be an honest man, one finds oneself ruined by very honest men, whose only defect is that they have not two hundred thousand francs a year. I need not tell you of those who cheat at play. I flung myself into this rapid, noisy, and volcanic life, which had formerly terrified me when I thought of it, and which had become for me the necessary complement of my love for Marguerite. What else could I have done? The nights that I did not spend in the Rue d'Antin, if I had spent them alone in my own room, I could not have slept. Jealousy would have kept me awake, and inflamed my blood and my thoughts.\"}\n","output_type":"stream"}]},{"cell_type":"markdown","source":"## Run it again!","metadata":{}},{"cell_type":"markdown","source":"Now let's step it up a notch. Let's try transcribing 30 minutes of audio from the LibriSpeech dataset. We'll first load up the second sample from our dataset, which corresponds to the 30 min audio file. We'll then pass the audio to the model for transcription, again timing how long the foward pass takes:","metadata":{}},{"cell_type":"code","source":"audio = test_dataset[1][\"audio\"] # load the second sample (30 mins) and get the audio array\n\naudio_length_in_mins = len(audio[\"array\"]) / audio[\"sampling_rate\"] / 60\nprint(f\"Audio is {audio_length_in_mins} mins.\")","metadata":{"execution":{"iopub.status.busy":"2023-04-25T15:10:24.933638Z","iopub.execute_input":"2023-04-25T15:10:24.933953Z","iopub.status.idle":"2023-04-25T15:10:25.123771Z","shell.execute_reply.started":"2023-04-25T15:10:24.933927Z","shell.execute_reply":"2023-04-25T15:10:25.122674Z"},"trusted":true},"execution_count":10,"outputs":[{"name":"stdout","text":"Audio is 30.252252083333335 mins.\n","output_type":"stream"}]},{"cell_type":"code","source":"# transcribe using cached function\n%time text = pipeline(audio)","metadata":{"execution":{"iopub.status.busy":"2023-04-25T15:10:25.125014Z","iopub.execute_input":"2023-04-25T15:10:25.125314Z","iopub.status.idle":"2023-04-25T15:11:00.199377Z","shell.execute_reply.started":"2023-04-25T15:10:25.125287Z","shell.execute_reply":"2023-04-25T15:11:00.198092Z"},"trusted":true},"execution_count":11,"outputs":[{"name":"stdout","text":"CPU times: user 3min 3s, sys: 5min 14s, total: 8min 18s\nWall time: 35.1 s\n","output_type":"stream"}]},{"cell_type":"markdown","source":"Just 35s to transcribe for 30 mins of audio! That means you could transcribe an entire 2 hour movie in under 2.5 minutes 🤯 By increasing the batch size, we could also reduce the transcription time for long audio files further: increasing the batch size by 2x roughly decreases the transcription time by 2x, provided the overall batch size is less than the total audio time.\n\nIf you're fortunate enough to have access to a TPU v4, you'll find that the transcription times a factor of 2 faster than on a v3 - you can quickly see how we can get super fast transcription times using Whisper JAX on TPU!","metadata":{}},{"cell_type":"markdown","source":"## ⏰ Timestamps and more\n\nWe can also get timestamps from the model by passing `return_timestamps=True`, but this will require a recompilation since we change the signature of the forward pass. \n\nThe timestamps compilation takes longer than the non-timestamps one. Luckily, because we initialised our compilation cache above, we're not starting from scratch in compiling this time. This is the last compilation we need to do!","metadata":{}},{"cell_type":"code","source":"# compile the forward call with timestamps - slow but we only do once\n%time outputs = pipeline(audio, return_timestamps=True)\ntext = outputs[\"text\"] # transcription\nchunks = outputs[\"chunks\"] # transcription + timestamps","metadata":{"execution":{"iopub.status.busy":"2023-04-25T15:11:00.200803Z","iopub.execute_input":"2023-04-25T15:11:00.201152Z","iopub.status.idle":"2023-04-25T15:13:18.16395Z","shell.execute_reply.started":"2023-04-25T15:11:00.201111Z","shell.execute_reply":"2023-04-25T15:13:18.16294Z"},"trusted":true},"execution_count":12,"outputs":[{"name":"stdout","text":"CPU times: user 7min 8s, sys: 7min 31s, total: 14min 39s\nWall time: 2min 17s\n","output_type":"stream"}]},{"cell_type":"code","source":"# use cached timestamps function - super fast!\n%time outputs = pipeline(audio, return_timestamps=True)\ntext = outputs[\"text\"] \nchunks = outputs[\"chunks\"]","metadata":{"execution":{"iopub.status.busy":"2023-04-25T15:13:18.165286Z","iopub.execute_input":"2023-04-25T15:13:18.165597Z","iopub.status.idle":"2023-04-25T15:13:36.482166Z","shell.execute_reply.started":"2023-04-25T15:13:18.165567Z","shell.execute_reply":"2023-04-25T15:13:36.48114Z"},"trusted":true},"execution_count":13,"outputs":[{"name":"stdout","text":"CPU times: user 3min 16s, sys: 5min 35s, total: 8min 52s\nWall time: 18.3 s\n","output_type":"stream"}]},{"cell_type":"markdown","source":"We've shown how you can transcibe an audio file in English. The pipeline is also compatible with two further arguments that you can use to control the generation process. It's perfectly fine to omit these if you want speech transcription and the Whisper model to automatically detect which language the audio is in. Otherwise, you can change them depending on your task/language:\n\n\n* `task`: task to use for generation, either `\"transcribe\"` or `\"translate\"`. Defaults to `\"transcribe\"`.\n* `language`: language token to use for generation, can be either in the form of `\"<|en|>\"`, `\"en\"` or `\"english\"`. Defaults to `None`, meaning the language is automatically inferred from the audio input. Optional, and only relevant if the source audio language is known a-priori.","metadata":{}}]} \ No newline at end of file +{"metadata":{"kernelspec":{"language":"python","display_name":"Python 3","name":"python3"},"language_info":{"name":"python","version":"3.8.16","mimetype":"text/x-python","codemirror_mode":{"name":"ipython","version":3},"pygments_lexer":"ipython3","nbconvert_exporter":"python","file_extension":".py"}},"nbformat_minor":4,"nbformat":4,"cells":[{"source":"\"Kaggle\"","metadata":{},"cell_type":"markdown"},{"cell_type":"markdown","source":"## Whisper JAX ⚡️\n\nThis Kaggle notebook demonstratese how to run Whisper JAX on a TPU v3-8. Whisper JAX is a highly optimised JAX implementation of the Whisper model by OpenAI, largely built on the 🤗 Hugging Face Transformers Whisper implementation. Compared to OpenAI's PyTorch code, Whisper JAX runs over **70x faster**, making it the fastest Whisper implementation available.\n\nThe Whisper JAX model is also running as a [demo](https://huggingface.co/spaces/sanchit-gandhi/whisper-jax) on the Hugging Face Hub. You can find the code [here](https://github.com/sanchit-gandhi/whisper-jax).","metadata":{}},{"cell_type":"markdown","source":"## Let's get started!\n\nThe first thing we need to do is connect to a TPU. Kaggle offers 20 hours of TPU v3-8 usage per month for free, which we'll make use of for this notebook. Refer to the guide [Introducing TPUs to Kaggle](https://www.kaggle.com/product-feedback/129828) for more information on TPU quotas in Kaggle.\n\nYou will need to register a Kaggle account and verify your phone number if you haven't done so already. Once verified, open up the settings menu in the Notebook editor (the small arrow in the bottom right). Then under _Notebook options_, select ‘TPU VM v3-8’ from the _Accelerator_ menu. You will also need to toggle the internet switch so that it is set to \"on\".\n\nOnce we've got a TPU allocated (there might be a queue to get one!), we can run the following to see the TPU devices we have available:","metadata":{}},{"cell_type":"code","source":"import jax\njax.devices()","metadata":{"execution":{"iopub.status.busy":"2023-04-25T15:05:34.178195Z","iopub.execute_input":"2023-04-25T15:05:34.178923Z","iopub.status.idle":"2023-04-25T15:05:41.356736Z","shell.execute_reply.started":"2023-04-25T15:05:34.178889Z","shell.execute_reply":"2023-04-25T15:05:41.355848Z"},"trusted":true},"execution_count":1,"outputs":[{"execution_count":1,"output_type":"execute_result","data":{"text/plain":"[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0),\n TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1),\n TpuDevice(id=2, process_index=0, coords=(1,0,0), core_on_chip=0),\n TpuDevice(id=3, process_index=0, coords=(1,0,0), core_on_chip=1),\n TpuDevice(id=4, process_index=0, coords=(0,1,0), core_on_chip=0),\n TpuDevice(id=5, process_index=0, coords=(0,1,0), core_on_chip=1),\n TpuDevice(id=6, process_index=0, coords=(1,1,0), core_on_chip=0),\n TpuDevice(id=7, process_index=0, coords=(1,1,0), core_on_chip=1)]"},"metadata":{}}]},{"cell_type":"markdown","source":"Cool! We've got 8 TPU devices packaged into one accelerator.\n\nKaggle TPUs come with JAX pre-installed, so we can directly install the remaining Python packages. If you're running the notebook on a Cloud TPU, ensure you have installed JAX according to the official [installation guide](https://github.com/google/jax#pip-installation-google-cloud-tpu). \n\nWe'll install [Whisper JAX](https://github.com/sanchit-gandhi/whisper-jax) from main, as well as `datasets`, `soundfile` and `librosa` for loading audio files:","metadata":{}},{"cell_type":"code","source":"!pip install --quiet --upgrade pip\n!pip install --quiet git+https://github.com/sanchit-gandhi/whisper-jax.git datasets soundfile librosa","metadata":{"_cell_guid":"b1076dfc-b9ad-4769-8c92-a6c4dae69d19","_uuid":"8f2839f25d086af736a60e9eeb907d3b93b6e0e5","tags":[],"execution":{"iopub.status.busy":"2023-04-25T15:05:41.358251Z","iopub.execute_input":"2023-04-25T15:05:41.358603Z","iopub.status.idle":"2023-04-25T15:06:26.680298Z","shell.execute_reply.started":"2023-04-25T15:05:41.358577Z","shell.execute_reply":"2023-04-25T15:06:26.678988Z"},"trusted":true},"execution_count":2,"outputs":[{"name":"stdout","text":"\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n\u001b[0m\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n\u001b[0m","output_type":"stream"}]},{"cell_type":"markdown","source":"## Loading the Pipeline\n\nThe recommended way of running Whisper JAX is through the [`FlaxWhisperPipeline`](https://github.com/sanchit-gandhi/whisper-jax/blob/main/whisper_jax/pipeline.py#L57) class. This class handles all the necessary pre- and post-processing for the model, as well as wrapping the generate method for data parallelism across all available accelerator devices.\n\nWhisper JAX makes use of JAX's [`pmap`](https://jax.readthedocs.io/en/latest/_autosummary/jax.pmap.html) function for data parallelism across GPU/TPU devices. This function is Just In Time (JIT) compiled the first time it is called. Thereafter, the function will be cached, enabling it to be run in super-fast time.\n\n\nLet's load the large-v2 model in bfloat16 (half-precision). Using half-precision will speed-up the computation quite considerably by storing intermediate tensors in half-precision. There is no change to the precision of the model weights.\n\nWe'll also make use of _batching_ for single audio inputs: the audio is first chunked into 30 second segments, and then chunks dispatched to the model to be transcribed in parallel. By batching an audio input and transcribing it in parallel, we get a ~10x speed-up compared to transcribing the audio samples sequentially.","metadata":{}},{"cell_type":"code","source":"from whisper_jax import FlaxWhisperPipeline\nimport jax.numpy as jnp\n\npipeline = FlaxWhisperPipeline(\"openai/whisper-large-v2\", dtype=jnp.bfloat16, batch_size=16)","metadata":{"execution":{"iopub.status.busy":"2023-04-25T15:06:26.681714Z","iopub.execute_input":"2023-04-25T15:06:26.682018Z","iopub.status.idle":"2023-04-25T15:07:53.135957Z","shell.execute_reply.started":"2023-04-25T15:06:26.681991Z","shell.execute_reply":"2023-04-25T15:07:53.134715Z"},"trusted":true},"execution_count":3,"outputs":[{"name":"stderr","text":"/usr/local/lib/python3.8/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n from .autonotebook import tqdm as notebook_tqdm\nDownloading (…)rocessor_config.json: 100%|██████████| 185k/185k [00:00<00:00, 5.29MB/s]\nDownloading (…)okenizer_config.json: 100%|██████████| 800/800 [00:00<00:00, 120kB/s]\nDownloading (…)olve/main/vocab.json: 100%|██████████| 836k/836k [00:00<00:00, 15.9MB/s]\nDownloading (…)/main/tokenizer.json: 100%|██████████| 2.20M/2.20M [00:00<00:00, 39.1MB/s]\nDownloading (…)olve/main/merges.txt: 100%|██████████| 494k/494k [00:00<00:00, 28.1MB/s]\nDownloading (…)main/normalizer.json: 100%|██████████| 52.7k/52.7k [00:00<00:00, 24.5MB/s]\nDownloading (…)in/added_tokens.json: 100%|██████████| 2.08k/2.08k [00:00<00:00, 1.11MB/s]\nDownloading (…)cial_tokens_map.json: 100%|██████████| 2.08k/2.08k [00:00<00:00, 1.22MB/s]\nDownloading (…)lve/main/config.json: 100%|██████████| 1.99k/1.99k [00:00<00:00, 320kB/s]\nDownloading flax_model.msgpack: 100%|██████████| 6.17G/6.17G [00:30<00:00, 203MB/s] \nDownloading (…)neration_config.json: 100%|██████████| 3.51k/3.51k [00:00<00:00, 1.99MB/s]\n","output_type":"stream"}]},{"cell_type":"markdown","source":"We'll then initialise a compilation cache, which will speed-up the compilation time if we close our kernel and want to compile the model again:","metadata":{}},{"cell_type":"code","source":"from jax.experimental.compilation_cache import compilation_cache as cc\n\ncc.initialize_cache(\"./jax_cache\")","metadata":{"tags":[],"execution":{"iopub.status.busy":"2023-04-25T15:07:53.139631Z","iopub.execute_input":"2023-04-25T15:07:53.139958Z","iopub.status.idle":"2023-04-25T15:07:53.146144Z","shell.execute_reply.started":"2023-04-25T15:07:53.139931Z","shell.execute_reply":"2023-04-25T15:07:53.145279Z"},"trusted":true},"execution_count":4,"outputs":[{"name":"stderr","text":"WARNING:jax.experimental.compilation_cache.compilation_cache:Initialized persistent compilation cache at ./jax_cache\n","output_type":"stream"}]},{"cell_type":"markdown","source":"## 🎶 Load an audio file\n\nLet's load up a long audio file for our tests. We provide 5 and 30 mins audio files created by contatenating consecutive sample of the [LibriSpeech ASR](https://huggingface.co/datasets/librispeech_asr) corpus, which we can load in one line through Hugging Face Datastes' [`load_dataset`](https://huggingface.co/docs/datasets/loading#load) function. Note that you can also pass in any `.mp3`, `.wav` or `.flac` audio file directly to the Whisper JAX pipeline, and it will take care of loading the audio file for you.","metadata":{}},{"cell_type":"code","source":"from datasets import load_dataset\n\ntest_dataset = load_dataset(\"sanchit-gandhi/whisper-jax-test-files\", split=\"train\")\naudio = test_dataset[0][\"audio\"] # load the first sample (5 mins) and get the audio array","metadata":{"execution":{"iopub.status.busy":"2023-04-25T15:07:53.147081Z","iopub.execute_input":"2023-04-25T15:07:53.147354Z","iopub.status.idle":"2023-04-25T15:08:17.685945Z","shell.execute_reply.started":"2023-04-25T15:07:53.147331Z","shell.execute_reply":"2023-04-25T15:08:17.684706Z"},"trusted":true},"execution_count":5,"outputs":[{"name":"stderr","text":"Downloading readme: 100%|██████████| 371/371 [00:00<00:00, 188kB/s]\n","output_type":"stream"},{"name":"stdout","text":"Downloading and preparing dataset None/None to /root/.cache/huggingface/datasets/sanchit-gandhi___parquet/sanchit-gandhi--whisper-jax-test-files-95479fe55e88baac/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec...\n","output_type":"stream"},{"name":"stderr","text":"Downloading data files: 0%| | 0/1 [00:00","text/html":"\n \n "},"metadata":{}}]},{"cell_type":"markdown","source":"## Run the model\n\nNow we're ready to transcribe! We'll need to compile the `pmap` function the first time we use it. You can expect compilation to take ~2 minutes on a TPU v3-8 with a batch size of 16. Enough time to grab a coffee ☕️\n\nThereafter, we can use our cached `pmap` function, which you'll see is amazingly fast.","metadata":{}},{"cell_type":"code","source":"# JIT compile the forward call - slow, but we only do once\n%time text = pipeline(audio)","metadata":{"execution":{"iopub.status.busy":"2023-04-25T15:08:17.900179Z","iopub.execute_input":"2023-04-25T15:08:17.900475Z","iopub.status.idle":"2023-04-25T15:10:19.031506Z","shell.execute_reply.started":"2023-04-25T15:08:17.90045Z","shell.execute_reply":"2023-04-25T15:10:19.030345Z"},"trusted":true},"execution_count":7,"outputs":[{"name":"stdout","text":"CPU times: user 3min 21s, sys: 1min 15s, total: 4min 37s\nWall time: 2min 1s\n","output_type":"stream"}]},{"cell_type":"code","source":"# used cached function thereafter - super fast!\n%time text = pipeline(audio)","metadata":{"tags":[],"execution":{"iopub.status.busy":"2023-04-25T15:10:19.033137Z","iopub.execute_input":"2023-04-25T15:10:19.033486Z","iopub.status.idle":"2023-04-25T15:10:24.923769Z","shell.execute_reply.started":"2023-04-25T15:10:19.033437Z","shell.execute_reply":"2023-04-25T15:10:24.922709Z"},"trusted":true},"execution_count":8,"outputs":[{"name":"stdout","text":"CPU times: user 28.5 s, sys: 52.5 s, total: 1min 20s\nWall time: 5.88 s\n","output_type":"stream"}]},{"cell_type":"code","source":"# let's check our transcription - looks spot on!\nprint(text)","metadata":{"tags":[],"execution":{"iopub.status.busy":"2023-04-25T15:10:24.925055Z","iopub.execute_input":"2023-04-25T15:10:24.925367Z","iopub.status.idle":"2023-04-25T15:10:24.93081Z","shell.execute_reply.started":"2023-04-25T15:10:24.925339Z","shell.execute_reply":"2023-04-25T15:10:24.929746Z"},"trusted":true},"execution_count":9,"outputs":[{"name":"stdout","text":"{'text': \" Chapter 16. I might have told you of the beginning of this liaison in a few lines, but I wanted you to see every step by which we came, I to agree to whatever Marguerite wished, Marguerite to be unable to live apart from me. It was the day after the evening when she came to see me that I sent her Manon Lescate. From that time, seeing that I could not change my mistress's life, I changed my own. I wished above all not to leave myself time to think over the position I had accepted, for, in spite of myself, it was a great distress to me. Thus my life, generally so calm, assumed all at once an appearance of noise and disorder. Never believe, however disinterested the love of a kept woman may be, that it will cost one nothing. Nothing is so expensive as their caprices, flowers, boxes at the theatre, suppers, days in the country, which one can never refuse to one's mistress. As I have told you, I had little money. My father was, and still is, Receiver General at sea. He has a great reputation there for loyalty, thanks to which he was able to find the security which he needed in order to attain this position. I came to Paris, studied law, was called to the bar, and, like many other young men, put my diploma in my pocket, and let myself drift, as one so easily does in Paris. My expenses were very moderate, only I used up my year's income in eight months, and spent the four summer months with my father, which practically gave me twelve thousand francs a year, and, in addition, the reputation of a good son. For the rest, not a penny of debt. This, then, was my position when I made the acquaintance of Marguerite. You can well understand that, in spite of myself, my expenses soon increased. Marguerite's nature was very capricious, and, like so many women, she never regarded as a serious expense those thousand and one distractions which made up her life. So, wishing to spend as much time with me as possible, she would write to me in the morning that she would dine with me, not at home, but at some restaurant in Paris, or in the country. I would call for her, and we would dine and go on to the theatre, often having supper as well. Forgive me if I give you all these details, but you will see that they were the cause of what was to follow. What I tell you is a true and simple story, and I leave to it all the naivete of its details, and all the simplicity of its developments. I realized then that as nothing in the world would make me forget my mistress, it was needful for me to find some way of meeting the expenses into which she drew me. Then, too, my love for her had so disturbing an influence upon me, that every moment I spent away from Marguerite was like a year, and that I felt the need of consuming these moments in the fire of some sort of passion, as not to know that I was living them. I began by borrowing five or six thousand francs on my little capital, and with this I took to gambling. Since gambling-houses were destroyed, gambling goes on everywhere. Formerly, when one went to Frascati, one had the chance of making a fortune, one played against money, and if one lost, there was always the consolation of saying that one might have gained. Whereas now, except in the clubs, where there is still a certain rigour in regard to payments, one is almost certain, the moment one gains a considerable sum, not to receive it. You will readily understand why. Gambling is only likely to be carried on by young people very much in need of money, and not possessing the fortune necessary for supporting the life they lead. They gamble, then, and with this result, or else they gain, and then those who lose serve to pay for their horses and mistresses, which is very disagreeable. Debts are contracted, acquaintances begun about a green table, and by quarrels in which life or honour comes to grief, and though one may be an honest man, one finds oneself ruined by very honest men, whose only defect is that they have not two hundred thousand francs a year. I need not tell you of those who cheat at play. I flung myself into this rapid, noisy, and volcanic life, which had formerly terrified me when I thought of it, and which had become for me the necessary complement of my love for Marguerite. What else could I have done? The nights that I did not spend in the Rue d'Antin, if I had spent them alone in my own room, I could not have slept. Jealousy would have kept me awake, and inflamed my blood and my thoughts.\"}\n","output_type":"stream"}]},{"cell_type":"markdown","source":"## Run it again!","metadata":{}},{"cell_type":"markdown","source":"Now let's step it up a notch. Let's try transcribing 30 minutes of audio from the LibriSpeech dataset. We'll first load up the second sample from our dataset, which corresponds to the 30 min audio file. We'll then pass the audio to the model for transcription, again timing how long the foward pass takes:","metadata":{}},{"cell_type":"code","source":"audio = test_dataset[1][\"audio\"] # load the second sample (30 mins) and get the audio array\n\naudio_length_in_mins = len(audio[\"array\"]) / audio[\"sampling_rate\"] / 60\nprint(f\"Audio is {audio_length_in_mins} mins.\")","metadata":{"execution":{"iopub.status.busy":"2023-04-25T15:10:24.933638Z","iopub.execute_input":"2023-04-25T15:10:24.933953Z","iopub.status.idle":"2023-04-25T15:10:25.123771Z","shell.execute_reply.started":"2023-04-25T15:10:24.933927Z","shell.execute_reply":"2023-04-25T15:10:25.122674Z"},"trusted":true},"execution_count":10,"outputs":[{"name":"stdout","text":"Audio is 30.252252083333335 mins.\n","output_type":"stream"}]},{"cell_type":"code","source":"# transcribe using cached function\n%time text = pipeline(audio)","metadata":{"execution":{"iopub.status.busy":"2023-04-25T15:10:25.125014Z","iopub.execute_input":"2023-04-25T15:10:25.125314Z","iopub.status.idle":"2023-04-25T15:11:00.199377Z","shell.execute_reply.started":"2023-04-25T15:10:25.125287Z","shell.execute_reply":"2023-04-25T15:11:00.198092Z"},"trusted":true},"execution_count":11,"outputs":[{"name":"stdout","text":"CPU times: user 3min 3s, sys: 5min 14s, total: 8min 18s\nWall time: 35.1 s\n","output_type":"stream"}]},{"cell_type":"markdown","source":"Just 35s to transcribe for 30 mins of audio! That means you could transcribe an entire 2 hour movie in under 2.5 minutes 🤯 By increasing the batch size, we could also reduce the transcription time for long audio files further: increasing the batch size by 2x roughly decreases the transcription time by 2x, provided the overall batch size is less than the total audio time.\n\nIf you're fortunate enough to have access to a TPU v4, you'll find that the transcription times a factor of 2 faster than on a v3 - you can quickly see how we can get super fast transcription times using Whisper JAX on TPU!","metadata":{}},{"cell_type":"markdown","source":"## ⏰ Timestamps and more\n\nWe can also get timestamps from the model by passing `return_timestamps=True`, but this will require a recompilation since we change the signature of the forward pass. \n\nThe timestamps compilation takes longer than the non-timestamps one. Luckily, because we initialised our compilation cache above, we're not starting from scratch in compiling this time. This is the last compilation we need to do!","metadata":{}},{"cell_type":"code","source":"# compile the forward call with timestamps - slow but we only do once\n%time outputs = pipeline(audio, return_timestamps=True)\ntext = outputs[\"text\"] # transcription\nchunks = outputs[\"chunks\"] # transcription + timestamps","metadata":{"execution":{"iopub.status.busy":"2023-04-25T15:11:00.200803Z","iopub.execute_input":"2023-04-25T15:11:00.201152Z","iopub.status.idle":"2023-04-25T15:13:18.16395Z","shell.execute_reply.started":"2023-04-25T15:11:00.201111Z","shell.execute_reply":"2023-04-25T15:13:18.16294Z"},"trusted":true},"execution_count":12,"outputs":[{"name":"stdout","text":"CPU times: user 7min 8s, sys: 7min 31s, total: 14min 39s\nWall time: 2min 17s\n","output_type":"stream"}]},{"cell_type":"code","source":"# use cached timestamps function - super fast!\n%time outputs = pipeline(audio, return_timestamps=True)\ntext = outputs[\"text\"] \nchunks = outputs[\"chunks\"]","metadata":{"execution":{"iopub.status.busy":"2023-04-25T15:13:18.165286Z","iopub.execute_input":"2023-04-25T15:13:18.165597Z","iopub.status.idle":"2023-04-25T15:13:36.482166Z","shell.execute_reply.started":"2023-04-25T15:13:18.165567Z","shell.execute_reply":"2023-04-25T15:13:36.48114Z"},"trusted":true},"execution_count":13,"outputs":[{"name":"stdout","text":"CPU times: user 3min 16s, sys: 5min 35s, total: 8min 52s\nWall time: 18.3 s\n","output_type":"stream"}]},{"cell_type":"markdown","source":"We've shown how you can transcibe an audio file in English. The pipeline is also compatible with two further arguments that you can use to control the generation process. It's perfectly fine to omit these if you want speech transcription and the Whisper model to automatically detect which language the audio is in. Otherwise, you can change them depending on your task/language:\n\n\n* `task`: task to use for generation, either `\"transcribe\"` or `\"translate\"`. Defaults to `\"transcribe\"`.\n* `language`: language token to use for generation, can be either in the form of `\"<|en|>\"`, `\"en\"` or `\"english\"`. Defaults to `None`, meaning the language is automatically inferred from the audio input. Optional, and only relevant if the source audio language is known a-priori.","metadata":{}}]} diff --git a/whisper_jax/__init__.py b/whisper_jax/__init__.py index c34ee69..642aba4 100644 --- a/whisper_jax/__init__.py +++ b/whisper_jax/__init__.py @@ -17,5 +17,5 @@ from .modeling_flax_whisper import FlaxWhisperForConditionalGeneration from .partitioner import PjitPartitioner -from .pipeline import FlaxWhisperPipline +from .pipeline import FlaxWhisperPipeline from .train_state import InferenceState diff --git a/whisper_jax/pipeline.py b/whisper_jax/pipeline.py index 06dd60d..e6b84c5 100644 --- a/whisper_jax/pipeline.py +++ b/whisper_jax/pipeline.py @@ -54,7 +54,7 @@ ) -class FlaxWhisperPipline: +class FlaxWhisperPipeline: def __init__( self, checkpoint="openai/whisper-large-v2",