From 4b8bc0d73725c2132943ade857d5a028a0b1bd49 Mon Sep 17 00:00:00 2001 From: Somshubra Majumdar Date: Tue, 16 Jul 2024 15:03:40 -0400 Subject: [PATCH] Canary Adapters tutorial (#9670) * Fix issue with prompt_defaults Signed-off-by: smajumdar * Add core level support for grad map tracking Signed-off-by: smajumdar * Add core level support for grad map tracking Signed-off-by: smajumdar * Apply isort and black reformatting Signed-off-by: titu1994 * Add tutorial and update repr of formatters Signed-off-by: smajumdar * Update docs Signed-off-by: smajumdar --------- Signed-off-by: smajumdar Signed-off-by: titu1994 --- .../asr/models/aed_multitask_models.py | 4 +- nemo/collections/asr/models/ctc_models.py | 14 - .../asr/models/hybrid_rnnt_ctc_models.py | 2 +- .../asr/models/transformer_bpe_models.py | 2 +- .../asr/parts/mixins/transcription.py | 6 +- nemo/collections/common/prompts/formatter.py | 5 +- nemo/core/classes/module.py | 99 +- tests/core/test_neural_module.py | 89 + .../asr_adapters/Multi_Task_Adapters.ipynb | 1660 +++++++++++++++++ tutorials/asr/asr_adapters/README.md | 2 + 10 files changed, 1848 insertions(+), 35 deletions(-) create mode 100644 tests/core/test_neural_module.py create mode 100644 tutorials/asr/asr_adapters/Multi_Task_Adapters.ipynb diff --git a/nemo/collections/asr/models/aed_multitask_models.py b/nemo/collections/asr/models/aed_multitask_models.py index 5ec7a8298beef..dbf8013af3316 100644 --- a/nemo/collections/asr/models/aed_multitask_models.py +++ b/nemo/collections/asr/models/aed_multitask_models.py @@ -434,7 +434,7 @@ def change_prompt( prompt_cls = PromptFormatter.resolve(self.prompt_format) self.prompt = prompt_cls( tokenizer=self.tokenizer, - defaults=OmegaConf.to_container(pd) if (pd := self.cfg.prompt_defaults) is not None else None, + defaults=OmegaConf.to_container(pd) if (pd := self.cfg.get('prompt_defaults')) is not None else None, ) # Update config @@ -979,7 +979,7 @@ def _transcribe_on_end(self, trcfg: MultiTaskTranscriptionConfig): """ super()._transcribe_on_end(trcfg) - self.transf_decoder.unfreeze() + self.transf_decoder.unfreeze(partial=True) def _may_be_make_dict_and_fix_paths(self, json_items, manifest_path, trcfg: MultiTaskTranscriptionConfig): """ diff --git a/nemo/collections/asr/models/ctc_models.py b/nemo/collections/asr/models/ctc_models.py index b6d8945b6c6b0..76233d57622b0 100644 --- a/nemo/collections/asr/models/ctc_models.py +++ b/nemo/collections/asr/models/ctc_models.py @@ -665,20 +665,6 @@ def test_dataloader(self): """ Transcription related methods """ - def _transcribe_on_begin(self, audio, trcfg: TranscribeConfig): - super()._transcribe_on_begin(audio, trcfg) - - # Freeze the encoder and decoder modules - self.encoder.freeze() - self.decoder.freeze() - - def _transcribe_on_end(self, trcfg: TranscribeConfig): - super()._transcribe_on_end(trcfg) - - # Unfreeze the encoder and decoder modules - self.encoder.unfreeze() - self.decoder.unfreeze() - def _transcribe_forward(self, batch: Any, trcfg: TranscribeConfig): logits, logits_len, greedy_predictions = self.forward(input_signal=batch[0], input_signal_length=batch[1]) output = dict(logits=logits, logits_len=logits_len) diff --git a/nemo/collections/asr/models/hybrid_rnnt_ctc_models.py b/nemo/collections/asr/models/hybrid_rnnt_ctc_models.py index c7c09739be647..f161454c9bae1 100644 --- a/nemo/collections/asr/models/hybrid_rnnt_ctc_models.py +++ b/nemo/collections/asr/models/hybrid_rnnt_ctc_models.py @@ -157,7 +157,7 @@ def _transcribe_on_end(self, trcfg: TranscribeConfig): 
super()._transcribe_on_end(trcfg) if hasattr(self, 'ctc_decoder'): - self.ctc_decoder.unfreeze() + self.ctc_decoder.unfreeze(partial=True) def _transcribe_forward(self, batch: Any, trcfg: TranscribeConfig): if self.cur_decoder == "rnnt": diff --git a/nemo/collections/asr/models/transformer_bpe_models.py b/nemo/collections/asr/models/transformer_bpe_models.py index 79de83f1d4a19..9970b49702363 100644 --- a/nemo/collections/asr/models/transformer_bpe_models.py +++ b/nemo/collections/asr/models/transformer_bpe_models.py @@ -633,4 +633,4 @@ def _transcribe_on_end(self, trcfg: TranscribeConfig): super()._transcribe_on_end(trcfg) # Unfreeze the encoder and decoder modules - self.transf_decoder.unfreeze() + self.transf_decoder.unfreeze(partial=True) diff --git a/nemo/collections/asr/parts/mixins/transcription.py b/nemo/collections/asr/parts/mixins/transcription.py index b6238cad4534a..2105097d0afff 100644 --- a/nemo/collections/asr/parts/mixins/transcription.py +++ b/nemo/collections/asr/parts/mixins/transcription.py @@ -770,13 +770,13 @@ def _transcribe_on_end(self, trcfg: TranscribeConfig): # Unfreeze the encoder and decoder modules if hasattr(self, 'encoder'): - self.encoder.unfreeze() + self.encoder.unfreeze(partial=True) if hasattr(self, 'decoder'): - self.decoder.unfreeze() + self.decoder.unfreeze(partial=True) if hasattr(self, 'joint'): - self.joint.unfreeze() + self.joint.unfreeze(partial=True) @classmethod def get_transcribe_config(cls) -> TranscribeConfig: diff --git a/nemo/collections/common/prompts/formatter.py b/nemo/collections/common/prompts/formatter.py index 8a82563ebbaa6..6d2c67f5311d5 100644 --- a/nemo/collections/common/prompts/formatter.py +++ b/nemo/collections/common/prompts/formatter.py @@ -25,6 +25,9 @@ class BaseModalityType: def matches(value: Any) -> bool: raise NotImplementedError + def __repr__(self): + return f"Modality.{self.__class__.__name__}()" + class Text(BaseModalityType): """Modality for text values.""" @@ -42,7 +45,7 @@ def matches(self, value: str) -> bool: return isinstance(value, str) and value in self.allowed_values def __repr__(self): - return f"{self.__class__.__name__}({self.allowed_values})" + return f"Modality.{self.__class__.__name__}(allowed_values={self.allowed_values})" class Modality: diff --git a/nemo/core/classes/module.py b/nemo/core/classes/module.py index 2d7bd0179447f..ef80467c8c7a2 100644 --- a/nemo/core/classes/module.py +++ b/nemo/core/classes/module.py @@ -18,6 +18,7 @@ from torch.nn import Module from nemo.core.classes.common import FileIO, Serialization, Typing +from nemo.utils import logging __all__ = ['NeuralModule'] @@ -54,39 +55,111 @@ def input_example(self, max_batch=None, max_dim=None): def freeze(self) -> None: r""" Freeze all params for inference. + + This method sets `requires_grad` to False for all parameters of the module. + It also stores the original `requires_grad` state of each parameter in a dictionary, + so that `unfreeze()` can restore the original state if `partial=True` is set in `unfreeze()`. """ - for param in self.parameters(): + grad_map = {} + + for pname, param in self.named_parameters(): + # Store the original grad state + grad_map[pname] = param.requires_grad + # Freeze the parameter param.requires_grad = False + # Store the frozen grad map + if not hasattr(self, '_frozen_grad_map'): + self._frozen_grad_map = grad_map + else: + self._frozen_grad_map.update(grad_map) + self.eval() - def unfreeze(self) -> None: + def unfreeze(self, partial: bool = False) -> None: """ Unfreeze all parameters for training. 
+ + Allows for either total unfreeze or partial unfreeze (if the module was explicitly frozen previously with `freeze()`). + The `partial` argument is used to determine whether to unfreeze all parameters or only the parameters that were + previously unfrozen prior to calling `freeze()`. + + Example: + Consider a model that has an encoder and a decoder module. Assume we want the encoder to be frozen always. + + ```python + model.encoder.freeze() # Freezes all parameters in the encoder explicitly + ``` + + During inference, all parameters of the model should be frozen - we do this by calling the model's freeze method. + This step records that the encoder module parameters were already frozen, and so if partial unfreeze is called, + we should keep the encoder parameters frozen. + + ```python + model.freeze() # Freezes all parameters in the model; encoder remains frozen + ``` + + Now, during fine-tuning, we want to unfreeze the decoder but keep the encoder frozen. We can do this by calling + `unfreeze(partial=True)`. + + ```python + model.unfreeze(partial=True) # Unfreezes only the decoder; encoder remains frozen + ``` + + Args: + partial: If True, only unfreeze parameters that were trainable before `freeze()` was called. If a parameter was already frozen + when calling `freeze()`, it will remain frozen after calling `unfreeze(partial=True)`. """ - for param in self.parameters(): - param.requires_grad = True + if partial and not hasattr(self, '_frozen_grad_map'): + raise ValueError("Cannot unfreeze partially without first freezing the module with `freeze()`") + + for pname, param in self.named_parameters(): + if not partial: + # Unfreeze all parameters + param.requires_grad = True + else: + # Restore only the requires_grad states recorded by `freeze()` + + # Check if the parameter was frozen + if pname in self._frozen_grad_map: + param.requires_grad = self._frozen_grad_map[pname] + else: + # Log a warning if the parameter was not found in the frozen grad map + logging.warning( + f"Parameter {pname} not found in list of previously frozen parameters. " + f"Unfreezing this parameter." + ) + param.requires_grad = True + + # Clean up the frozen grad map + if hasattr(self, '_frozen_grad_map'): + delattr(self, '_frozen_grad_map') self.train() @contextmanager def as_frozen(self): """ - Context manager which temporarily freezes a module, yields control and finally unfreezes the module. + Context manager which temporarily freezes a module, yields control and finally unfreezes the module partially + to return it to its original state. + + Allows for either total unfreeze or partial unfreeze (if the module was explicitly frozen previously with `freeze()`). + The `partial` argument is used to determine whether to unfreeze all parameters or only the parameters that were + previously unfrozen prior to calling `freeze()`.
+ + Example: + with model.as_frozen(): # by default, partial = True + # Do something with the model + pass + + # Model's parameters are now back to original state of requires_grad """ training_mode = self.training - grad_map = {} - for pname, param in self.named_parameters(): - grad_map[pname] = param.requires_grad - self.freeze() try: yield finally: - self.unfreeze() - - for pname, param in self.named_parameters(): - param.requires_grad = grad_map[pname] + self.unfreeze(partial=True) if training_mode: self.train() diff --git a/tests/core/test_neural_module.py b/tests/core/test_neural_module.py new file mode 100644 index 0000000000000..73617f55635cc --- /dev/null +++ b/tests/core/test_neural_module.py @@ -0,0 +1,89 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import tempfile + +import pytest +import torch + +from nemo.core.classes.module import NeuralModule + + +class TempModule(NeuralModule): + + def __init__(self): + super().__init__() + + self.layer1 = torch.nn.Linear(10, 10, bias=False) + self.layer2 = torch.nn.Linear(10, 10, bias=False) + + +class TestNeuralModule: + + @pytest.mark.unit + def test_num_weights(self): + module = TempModule() + assert module.num_weights == 200 + + @pytest.mark.unit + def test_freeze(self): + module = TempModule() + module.freeze() + for p in module.parameters(): + assert not p.requires_grad + + @pytest.mark.unit + def test_unfreeze(self): + module = TempModule() + module.freeze() + module.unfreeze() + for p in module.parameters(): + assert p.requires_grad + + @pytest.mark.unit + def test_as_frozen(self): + module = TempModule() + + for p in module.parameters(): + assert p.requires_grad + + with module.as_frozen(): + for p in module.parameters(): + assert not p.requires_grad + + for p in module.parameters(): + assert p.requires_grad + + @pytest.mark.unit + def test_partial_unfreeze(self): + module = TempModule() + + for param in module.layer1.parameters(): + param.requires_grad = False + + module.freeze() + + for param in module.layer1.parameters(): + assert not param.requires_grad + + assert module._frozen_grad_map is not None + assert len(module._frozen_grad_map) == 2 + assert module._frozen_grad_map['layer1.weight'] is False + + module.unfreeze(partial=True) + + # layer1 should still be frozen due to partial unfreeze + assert module.layer1.weight.requires_grad is False + assert not hasattr(module, '_frozen_grad_map') diff --git a/tutorials/asr/asr_adapters/Multi_Task_Adapters.ipynb b/tutorials/asr/asr_adapters/Multi_Task_Adapters.ipynb new file mode 100644 index 0000000000000..51877b53fb8ac --- /dev/null +++ b/tutorials/asr/asr_adapters/Multi_Task_Adapters.ipynb @@ -0,0 +1,1660 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "id": "b0373c4a-e565-4e8f-a87f-aae932d3aeed", + "metadata": { + "id": "b0373c4a-e565-4e8f-a87f-aae932d3aeed" + }, + "outputs": [], + "source": [ + "\"\"\"\n", + "You can run either this notebook locally (if you have all the 
dependencies and a GPU) or on Google Colab.\n", + "\n", + "Instructions for setting up Colab are as follows:\n", + "1. Open a new Python 3 notebook.\n", + "2. Import this notebook from GitHub (File -> Upload Notebook -> \"GitHub\" tab -> copy/paste GitHub URL)\n", + "3. Connect to an instance with a GPU (Runtime -> Change runtime type -> select \"GPU\" for hardware accelerator)\n", + "4. Run this cell to set up dependencies.\n", + "5. Restart the runtime (Runtime -> Restart Runtime) for any upgraded packages to take effect\n", + "\n", + "\n", + "NOTE: User is responsible for checking the content of datasets and the applicable licenses and determining if suitable for the intended use.\n", + "\"\"\"\n", + "# If you're using Google Colab and not running locally, run this cell.\n", + "import os\n", + "\n", + "# Install dependencies\n", + "!pip install wget\n", + "!apt-get install sox libsndfile1 ffmpeg\n", + "!pip install text-unidecode\n", + "!pip install matplotlib>=3.3.2\n", + "\n", + "## Install NeMo\n", + "BRANCH = 'main'\n", + "!python -m pip install \"nemo_toolkit[asr] @ git+https://github.com/NVIDIA/NeMo.git@$BRANCH\"" + ] + }, + { + "cell_type": "markdown", + "id": "6c021f07-0576-491d-b73c-6c65c8501351", + "metadata": { + "id": "6c021f07-0576-491d-b73c-6c65c8501351" + }, + "source": [ + "# Multi Task Adaptation with Adapters\n", + "\n", + "\n", + "In earliier tutorials, we utilized a specific model for one task - for example, an ASR model (CTC, RNN-T etc) for the singular task of Speech Recognition. This is very useful if we want to specialize one task per model, but it can be expensive to deploy a fleet of models for each task, and learn routers to pass user tasks to correct models.\n", + "\n", + "We now support Multi Task models in NeMo, such that a single model can perform multiple tasks such as speech recognition, speech translation, voice activity detection, and more in the future. With one model supporting multiple tasks, we can simplify the task of deploying models and also hope to leverage individual tasks to improve each other (for example: you do need strong speech recognition first before you start doing translation).\n", + "\n", + "---\n", + "\n", + "Multi Task (Canary) models are highly capable large neural networks capable of things like speech recognition, X to English and English to X translation and able to select whether to transcribe speech with punctuation and capitalization. These huge models are trained on several thousand hours of speech and text data, making it challenging to adapt to new datasets.\n", + "\n", + "In the previous tutorial for [ASR Adapters](https://github.com/NVIDIA/NeMo/blob/main/tutorials/asr/asr_adapters/ASR_with_Adapters.ipynb), we used small adapter modules to tune a large ASR model on a small amount of data. In this tutorial, we will adapt a [Nvidia Canary](https://huggingface.co/nvidia/canary-1b) model onto a small amount of speech data for both Automatic Speech Recognition (ASR) and Automatic Speech Translation (AST).\n", + "\n", + "In this tutorial, we will also demonstrate a simple way of creating custom Data Modules from PyTorch Lightning to design custom datasets and data loaders for the highly flexible Multi Task Models in NeMo ASR. This offers users more flexibility in designing new tasks, and finetuning the models on small amounts of data." 
+ ] + }, + { + "cell_type": "markdown", + "id": "cbe2f8eb-204f-4d90-bb0a-a49d994f1ed7", + "metadata": { + "id": "cbe2f8eb-204f-4d90-bb0a-a49d994f1ed7" + }, + "source": [ + "----\n", + "\n", + "First, let's instantiate the [Canary](https://huggingface.co/nvidia/canary-1b) model" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "46c3e5c1-b4f2-4f84-89d6-c77bbe7ebe4f", + "metadata": { + "id": "46c3e5c1-b4f2-4f84-89d6-c77bbe7ebe4f" + }, + "outputs": [], + "source": [ + "import os\n", + "import json\n", + "\n", + "import nemo.collections.asr as nemo_asr" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "48b9677b-b1d9-4361-becf-ee84fe8d53ca", + "metadata": { + "id": "48b9677b-b1d9-4361-becf-ee84fe8d53ca" + }, + "outputs": [], + "source": [ + "model = nemo_asr.models.ASRModel.from_pretrained(\"nvidia/canary-1b\")" + ] + }, + { + "cell_type": "markdown", + "id": "6c0c87c9-5290-4634-9338-818f181c936a", + "metadata": { + "id": "6c0c87c9-5290-4634-9338-818f181c936a" + }, + "source": [ + "# Enable Adapter Support in Model\n", + "\n", + "New in NeMo 2.0, we now have a simple utility function to convert the model into one that supports adapters, called `replace_adapter_compatible_modules()`.\n", + "\n", + "This will go through the full model, check whether each module supports adapters, and then enable that ability. Once this is done, you can freely use adapter methods." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "bfd72316-630b-43c3-9a02-65bb2dabe624", + "metadata": { + "scrolled": true, + "id": "bfd72316-630b-43c3-9a02-65bb2dabe624" + }, + "outputs": [], + "source": [ + "model.replace_adapter_compatible_modules()" + ] + }, + { + "cell_type": "markdown", + "id": "30505bd5-323f-4e90-a941-d0de3f6e55e3", + "metadata": { + "id": "30505bd5-323f-4e90-a941-d0de3f6e55e3" + }, + "source": [ + "## Check Which Targets Are Supported For This Model\n", + "\n", + "Now that the model has adapter support enabled, let's take a look at which of its modules allow adapter modules to be attached to them.\n", + "\n", + "**Note**\n", + "Below, you might see an adapter module with no name `''` - this corresponds to the \"default\" model target if the target isn't specified. Users can choose to simply skip the module name when adding an adapter, and the model will by default add adapters to the encoder module." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "13bcf42e-d33a-4364-8d0f-ab59a26ffa7c", + "metadata": { + "id": "13bcf42e-d33a-4364-8d0f-ab59a26ffa7c" + }, + "outputs": [], + "source": [ + "model.adapter_module_names" + ] + }, + { + "cell_type": "markdown", + "id": "67324f6a-ffff-47a7-9ee5-dc93819f6ffd", + "metadata": { + "id": "67324f6a-ffff-47a7-9ee5-dc93819f6ffd" + }, + "source": [ + "## Prepare the Adapter\n", + "\n", + "Now that we know which modules are supported, let's create a simple adapter module for the encoder and decoder modules."
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "65ec3b2b-3f84-43ed-8a90-085aee383ea6", + "metadata": { + "id": "65ec3b2b-3f84-43ed-8a90-085aee383ea6" + }, + "outputs": [], + "source": [ + "from nemo.collections.common.parts import LinearAdapterConfig" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "47aab832-bfec-4cca-b4ee-868ea1af9869", + "metadata": { + "id": "47aab832-bfec-4cca-b4ee-868ea1af9869" + }, + "outputs": [], + "source": [ + "input_dim = model.cfg.encoder.d_model\n", + "adapter_dim = 8" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "cd519281-ad45-4719-9ad6-561e6192717f", + "metadata": { + "id": "cd519281-ad45-4719-9ad6-561e6192717f" + }, + "outputs": [], + "source": [ + "enc_adapter_cfg = LinearAdapterConfig(in_features=input_dim, dim=adapter_dim)\n", + "dec_adapter_cfg = LinearAdapterConfig(in_features=input_dim, dim=adapter_dim)" + ] + }, + { + "cell_type": "markdown", + "id": "f147fc89-ab93-4454-ad6b-909288a452a2", + "metadata": { + "id": "f147fc89-ab93-4454-ad6b-909288a452a2" + }, + "source": [ + "## Add Adapter Modules\n", + "\n", + "Now that we have the adapter configs prepared, lets add them to the model !\n", + "\n", + "We provide the target module by using `target:adapter_name` when calling `add_adapter()` - this tells the model to setup an adapter called `adapter_name` to the module denoted by `target` with the config `cfg`." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a23256ce-bc09-4fb0-8c3b-214519b8774b", + "metadata": { + "id": "a23256ce-bc09-4fb0-8c3b-214519b8774b" + }, + "outputs": [], + "source": [ + "model.add_adapter(name=\"encoder:enc\", cfg=enc_adapter_cfg)\n", + "model.add_adapter(name=\"transf_decoder:dec\", cfg=dec_adapter_cfg)\n", + "\n", + "print(\"Added adapters!\")" + ] + }, + { + "cell_type": "markdown", + "id": "2dbe9b7b-9a3d-4504-a652-1d90701cbbf8", + "metadata": { + "id": "2dbe9b7b-9a3d-4504-a652-1d90701cbbf8" + }, + "source": [ + "## Freeze Original Module Parameters and Unfreeze Adapter Weights Only\n", + "\n", + "When tuning adapters, we usually freeze the entire base model and only tune the adapters. This prevents the need for large amounts of data, preserves a lot of memory (since the full model doesnt need backward pass, only the adapters) and makes it easier to adapt huge models." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2f8162dd-0373-4e65-aa8a-f458a1633578", + "metadata": { + "scrolled": true, + "id": "2f8162dd-0373-4e65-aa8a-f458a1633578" + }, + "outputs": [], + "source": [ + "model.freeze()\n", + "model.unfreeze_enabled_adapters()" + ] + }, + { + "cell_type": "markdown", + "id": "0b3795a4-fcfe-49ee-a76f-1cb77d99ace1", + "metadata": { + "id": "0b3795a4-fcfe-49ee-a76f-1cb77d99ace1" + }, + "source": [ + "----\n", + "\n", + "Lets make sure that the number of trainable parameters is a lot smaller (< 1 M) than the total number of params (1 B)." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "58453f40-d72d-4f9b-a427-3fb63787f3d6", + "metadata": { + "id": "58453f40-d72d-4f9b-a427-3fb63787f3d6" + }, + "outputs": [], + "source": [ + "model.summarize()" + ] + }, + { + "cell_type": "markdown", + "id": "aa713f4a-ec16-4e2a-aeb3-ac7c4090f20f", + "metadata": { + "id": "aa713f4a-ec16-4e2a-aeb3-ac7c4090f20f" + }, + "source": [ + "## Check Enabled Adapters\n", + "\n", + "Here, we check that the adapters that we named above (`enc` and `dec`) are both setup and enabled." 
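In addition to `model.summarize()` above, it can be reassuring to check the trainable-parameter count programmatically. The following is a minimal sketch using plain PyTorch (no NeMo-specific API assumed); `model` refers to the Canary model loaded earlier, and the exact counts depend on the adapter configuration:

```python
# Illustrative sanity check: adapters should account for well under 1% of the weights.
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
total_params = sum(p.numel() for p in model.parameters())

print(f"Trainable parameters: {trainable_params:,}")
print(f"Total parameters    : {total_params:,}")
print(f"Trainable fraction  : {100.0 * trainable_params / total_params:.4f} %")
```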
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d69f09d9-411e-420e-8f17-c86391e88fc3", + "metadata": { + "id": "d69f09d9-411e-420e-8f17-c86391e88fc3" + }, + "outputs": [], + "source": [ + "model.get_enabled_adapters()" + ] + }, + { + "cell_type": "markdown", + "source": [ + "# Customizing Multi Task Models\n", + "\n", + "In the following section, we will take a deeper look into what are the components that compose a Multi Task Model and how users can override each of these parts to create their own customizable multi task models.\n", + "\n", + "---\n", + "\n", + "In this tutorial, we will only see the internal components such as the prompt format and dataset construction, but not change them.\n", + "\n", + "In a following tutorial, we will show how to add an additional task to a pre-trained Multi Task Model using a pre-trained model as a starting point." + ], + "metadata": { + "id": "f_XpTJx9hQXy" + }, + "id": "f_XpTJx9hQXy" + }, + { + "cell_type": "markdown", + "id": "6f0beb8c-7b12-4169-a3f7-1639bdaf6160", + "metadata": { + "id": "6f0beb8c-7b12-4169-a3f7-1639bdaf6160" + }, + "source": [ + "# Prompt Handling for Multi Task Models\n", + "Nvidia Canary is our first model that is a Multi Task Model.\n", + "\n", + "Multi Task models utilize a prompt format, similar to those used in Large Language Models, in order to denote to the model which task is to be performed, which langauge is being spoken and what language should the output transcript be in, whether to provide punctuation and capitalization or not, and so much more in the future !\n", + "\n", + "Lets take a look at the model's `prompt` for the Canary model that we have created -" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "56a78cd0-afaf-4272-898f-d9e13ba871d3", + "metadata": { + "id": "56a78cd0-afaf-4272-898f-d9e13ba871d3" + }, + "outputs": [], + "source": [ + "model.prompt_format" + ] + }, + { + "cell_type": "markdown", + "id": "9cbaf28a-1f10-4da3-a3ed-53b2239baa49", + "metadata": { + "id": "9cbaf28a-1f10-4da3-a3ed-53b2239baa49" + }, + "source": [ + "----\n", + "\n", + "This gives us the prompt format functions name, which we will see below points to a prompt format function that reads in manifest items and maps it to the template." + ] + }, + { + "cell_type": "markdown", + "id": "087d1f60-3679-4593-840f-8d0fbd8a0e3e", + "metadata": { + "id": "087d1f60-3679-4593-840f-8d0fbd8a0e3e" + }, + "source": [ + "## Reuse / Register a Prompt Format Function\n", + "\n", + "When we print `model.prompt_format` it writes `canary` which is one of the registered prompt templates available in NeMo ASR.\n", + "For simplicity's sake, we will continue to use the same prompt format for this tutorial. However, we enable users to define their own prompt formats and register them as needed.\n", + "\n", + "Let's see what the `canary` prompt format looks like:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c202abaf-63ca-4475-a2bb-3b487be8e375", + "metadata": { + "id": "c202abaf-63ca-4475-a2bb-3b487be8e375" + }, + "outputs": [], + "source": [ + "from nemo.collections.asr.data.audio_to_text_lhotse_prompted import get_prompt_format_fn, registered_prompt_format_fn" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "07c56dc3-fe42-49fc-936c-770ec17a29ac", + "metadata": { + "scrolled": true, + "id": "07c56dc3-fe42-49fc-936c-770ec17a29ac" + }, + "outputs": [], + "source": [ + "canary_prompt_format_fn = get_prompt_format_fn(\"canary\")\n", + "canary_prompt_format_fn?" 
+ ] + }, + { + "cell_type": "markdown", + "id": "1170b57c-f4c7-432f-91bb-1dbf73063d60", + "metadata": { + "id": "1170b57c-f4c7-432f-91bb-1dbf73063d60" + }, + "source": [ + "### Registering a New Prompt Format Function" + ] + }, + { + "cell_type": "markdown", + "id": "d11a8a05-6ba7-41f3-97ab-43453a59c860", + "metadata": { + "id": "d11a8a05-6ba7-41f3-97ab-43453a59c860" + }, + "source": [ + "Just to show that this is user-configurable, we show how to register a dummy prompt format below:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f77378ff-d5de-4b86-bfaf-e62b51c7f9ce", + "metadata": { + "id": "f77378ff-d5de-4b86-bfaf-e62b51c7f9ce" + }, + "outputs": [], + "source": [ + "@registered_prompt_format_fn\n", + "def canary2(cuts, tokenizer, inference: bool):\n", + " \"\"\" Users can implement this as needed \"\"\"\n", + " raise NotImplementedError()\n", + "\n", + "print(\"Registered prompt\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "cb02f068-8fee-46e1-8096-910062668173", + "metadata": { + "id": "cb02f068-8fee-46e1-8096-910062668173" + }, + "outputs": [], + "source": [ + "temp = get_prompt_format_fn('canary2')\n", + "temp.__name__" + ] + }, + { + "cell_type": "markdown", + "id": "f14aa85b-71cb-4813-837b-b28a384685dc", + "metadata": { + "id": "f14aa85b-71cb-4813-837b-b28a384685dc" + }, + "source": [ + "## Create / Reuse a Prompt Format\n", + "\n", + "Canary Multi Task Model comes with a pre-defined prompt template, so we need to provide it data in a format that can be handled by that prompt format class.\n", + "\n", + "A `PromptFormatter` is a special class that defines the dialog template of the order of turns that occur in a model's prompt. For example, in Language Models, we normally may begin with either a `System` or `User` turn, followed by an `Assistant` turn which produces an output from the model. Similarly in Multi Task models, we enable support for such a usage pattern.\n", + "\n", + "Do note: Current generation of Canary models are not trained to operate on multi turn conversations, however future variants of Multi Task models may support such usage." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "35530cad-84d7-422b-82c5-1bda5c1a4497", + "metadata": { + "scrolled": true, + "id": "35530cad-84d7-422b-82c5-1bda5c1a4497" + }, + "outputs": [], + "source": [ + "# Let's review the actual prompt formatter clas docs\n", + "model.prompt?" 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0cd0c0d1-da8a-4de6-9efc-86a7dd3ed660", + "metadata": { + "id": "0cd0c0d1-da8a-4de6-9efc-86a7dd3ed660" + }, + "outputs": [], + "source": [ + "# Let's see the actual template of this prompt formatter\n", + "model.prompt.TEMPLATE" + ] + }, + { + "cell_type": "markdown", + "id": "72956a2f-f051-42d2-9e08-47e954d88e5c", + "metadata": { + "id": "72956a2f-f051-42d2-9e08-47e954d88e5c" + }, + "source": [ + "---\n", + "\n", + "We see that the template contains two turns - `user` and `assistant`.\n", + "\n", + "User template looks as follows: `<|startoftranscript|>|source_lang||task||target_lang||pnc|`\n", + "During execution, we remove the `|` in order to fill in the actual value of the slots provided by the the data loader.\n", + "\n", + "User holds the following allowed slots -\n", + "* `source_lang`\n", + "* `target_lang`\n", + "* `task`\n", + "* `pnc`\n", + "\n", + "Similarly, for Assistant template : `|text|<|endoftext|>`\n", + "\n", + "Assistant holds the following allowed slots -\n", + "* `text`" + ] + }, + { + "cell_type": "markdown", + "id": "540c04af-34d1-4b46-b935-40b16f54ca03", + "metadata": { + "id": "540c04af-34d1-4b46-b935-40b16f54ca03" + }, + "source": [ + "### Creating and Using a Custom Prompt Formatter\n", + "\n", + "While we provide a pre-trained model with a pre-defined prompt format, we also enable users to create their own PromptFormatter subclass and change it as needed.\n", + "\n", + "Below, we show a simple modification to the model's PromptFormatter and show how to change it." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0adb576c-df58-4b66-b8fa-8e653da6fead", + "metadata": { + "id": "0adb576c-df58-4b66-b8fa-8e653da6fead" + }, + "outputs": [], + "source": [ + "# Create a new prompt formatter using the original CanaryPromptFormatter class as baseclass\n", + "class CanaryPromptFormatterV2(model.prompt.__class__):\n", + "\n", + " # make sure to provide a new name\n", + " NAME: str = \"canary2\"\n", + "\n", + " # Make any changes as necessary.\n", + " # For this demonstration, we will not change anything other than the name" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f7d85683-ddd0-40c5-956d-e14d09243424", + "metadata": { + "id": "f7d85683-ddd0-40c5-956d-e14d09243424" + }, + "outputs": [], + "source": [ + "# Next, lets update the model's prompt formatter\n", + "model.change_prompt(\"canary2\")" + ] + }, + { + "cell_type": "markdown", + "id": "6581f934-a55b-41df-864a-351d1fb0029e", + "metadata": { + "id": "6581f934-a55b-41df-864a-351d1fb0029e" + }, + "source": [ + "---\n", + "\n", + "We have now successfully changed the prompt format to `canary2`.\n", + "\n", + "**Note**: It is important to know that when changing the prompt format, the name of the new prompt format class (`canary2` in this case) **has to match** the name of the prompt function registered with `@registered_prompt_format_fn`!" 
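To make the naming rule above concrete, here is a minimal sketch that places the two pieces shown earlier side by side. It simply restates the `canary2` prompt format function and the `CanaryPromptFormatterV2` class from the cells above, and is not meant to be re-run here since both names were already registered:

```python
# Sketch: the registered function name and the formatter's NAME must both be "canary2".
@registered_prompt_format_fn
def canary2(cuts, tokenizer, inference: bool):
    """Users can implement this as needed."""
    raise NotImplementedError()


class CanaryPromptFormatterV2(model.prompt.__class__):
    # Must match the function name registered above.
    NAME: str = "canary2"


# With both registered under the same name, the model can switch to the new format:
model.change_prompt("canary2")
```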
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c1d84948-8f73-4c31-923f-eaf01d877835", + "metadata": { + "scrolled": true, + "id": "c1d84948-8f73-4c31-923f-eaf01d877835" + }, + "outputs": [], + "source": [ + "# Check if everything is ok -\n", + "model.prompt.__class__.__name__" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f617cda0-d16b-400a-b495-dac213d318e1", + "metadata": { + "id": "f617cda0-d16b-400a-b495-dac213d318e1" + }, + "outputs": [], + "source": [ + "model.prompt_format" + ] + }, + { + "cell_type": "markdown", + "id": "cb964964-e978-43e9-befa-9bb0904db82f", + "metadata": { + "id": "cb964964-e978-43e9-befa-9bb0904db82f" + }, + "source": [ + "---\n", + "For the rest of the tutorial, we will revert to the original prompt formatter" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "526093a8-86ba-48f0-a60b-55642720fc4e", + "metadata": { + "id": "526093a8-86ba-48f0-a60b-55642720fc4e" + }, + "outputs": [], + "source": [ + "model.change_prompt('canary')" + ] + }, + { + "cell_type": "markdown", + "id": "9c4d2986-89b4-4589-ab0e-69683084cfd4", + "metadata": { + "id": "9c4d2986-89b4-4589-ab0e-69683084cfd4" + }, + "source": [ + "## Creating / Using a Multi Task Dataset\n", + "\n", + "Now that we have learned how to modify the model's prompt formatter and the underlying format function that maps manifest items into slots to inject into the prompt template, next let's take a look at how to use and create custom datasets for training multi task models.\n", + "\n", + "---\n", + "\n", + "Unlike previous tutorials that showcase how to use pre-defined datasets and point them to your manifest files, we will take a slightly more hands-on approach for multi task models. This is due to the sheer flexibility of multi task models - they can do almost any task that you can formulate into a \"speech in - text out\" problem.\n", + "\n", + "So it is not easy to have a pre-defined dataset class that can handle all new ideas and tasks that researchers can come up with.\n", + "\n", + "Instead, we showcase how to build a custom dataset yourself and use it with the Multi Task model." + ] + }, + { + "cell_type": "markdown", + "id": "b35ca0c2-8ceb-423f-b9ef-7dd6ec5a6952", + "metadata": { + "id": "b35ca0c2-8ceb-423f-b9ef-7dd6ec5a6952" + }, + "source": [ + "---\n", + "\n", + "However, we also provide a base class that can be used as is by users if they don't want the hassle of writing their own datasets.\n", + "\n", + "This is handled by the `PromptedAudioToTextLhotseDataset` - it maps user-defined manifest items to the items defined in the prompt template of the model, so as long as the manifest corresponds to the slots supported by the model, it will be managed by the Dataset automatically." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3d35d513-8538-4bcb-b892-898f16ad3f0f", + "metadata": { + "scrolled": true, + "id": "3d35d513-8538-4bcb-b892-898f16ad3f0f" + }, + "outputs": [], + "source": [ + "from nemo.collections.asr.data.audio_to_text_lhotse_prompted import PromptedAudioToTextLhotseDataset\n", + "\n", + "# Uncomment below line to see the class definition of PromptedAudioToTextLhotseDataset\n", + "# PromptedAudioToTextLhotseDataset??"
+ ] + }, + { + "cell_type": "markdown", + "id": "51e3a150-40b9-4599-8c6e-0f01698989b4", + "metadata": { + "id": "51e3a150-40b9-4599-8c6e-0f01698989b4" + }, + "source": [ + "### Creating a New Prompted Dataset" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "56208452-ea18-44c8-8c71-0daef431dc31", + "metadata": { + "id": "56208452-ea18-44c8-8c71-0daef431dc31" + }, + "outputs": [], + "source": [ + "import torch.utils.data\n", + "from lhotse import CutSet\n", + "from lhotse.cut import MixedCut, MonoCut\n", + "from lhotse.dataset import AudioSamples\n", + "from lhotse.dataset.collation import collate_vectors\n", + "\n", + "from nemo.collections.asr.data.audio_to_text_lhotse import TokenizerWrapper\n", + "from nemo.collections.asr.data.audio_to_text_lhotse_prompted import PromptedAudioToTextLhotseDataset, get_prompt_format_fn\n", + "\n", + "class MyCanaryPromptedAudioToTextLhotseDataset(torch.utils.data.Dataset):\n", + " \"\"\"\n", + " This dataset is based on :class:`~nemo.collections.asr.data.audio_to_text_lhotse.LhotseSpeechToTextBpeDataset`.\n", + " It is a Lhotse-style dataset that converts a mini-batch of Cuts into tensors.\n", + " The main difference from ``LhotseSpeechToTextBpeDataset`` is that we introduce\n", + " a special prompt format for multitask encoder-decoder models.\n", + "\n", + " To perform the prompt formatting, we accept a ``prompt_format_fn``.\n", + " It's expected to accept:\n", + " * a ``CutSet`` which it will internally iterate over for utterances, and\n", + " * a ``TokenizerWrapper`` object that will be internally used to tokenize the utterances\n", + "\n", + " Tokenized utterances will be extended with special prompt tokens according to ``prompt_format_fn`` logic.\n", + " We support cuts with multiple supervision segments -- their tokenized texts will be concatenated before we add the prompt tokens.\n", + " This is useful, for example, in code-switched scenarios where each segment is spoken in a different language.\n", + " \"\"\"\n", + "\n", + " def __init__(\n", + " self,\n", + " tokenizer: 'TokenizerSpec',\n", + " inference: bool = False,\n", + " ):\n", + " super().__init__()\n", + " self.tokenizer = TokenizerWrapper(tokenizer)\n", + " self.load_audio = AudioSamples(fault_tolerant=True)\n", + " self.padding_value = self.tokenizer._tokenizer.pad_id\n", + " self.prompt_format_fn = get_prompt_format_fn('canary') # Use the default canary prompt function\n", + " self.inference = inference\n", + "\n", + " def __getitem__(self, cuts: CutSet) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:\n", + " audio, audio_lens, cuts = self.load_audio(cuts)\n", + "\n", + " prompts_with_answers, prompts = self.prompt_format_fn(cuts, self.tokenizer, inference=self.inference)\n", + "\n", + " prompts_with_answers = [torch.as_tensor(t) for t in prompts_with_answers]\n", + " prompts_with_answers_lens = torch.tensor([t.size(0) for t in prompts_with_answers], dtype=torch.long)\n", + " prompts_with_answers = collate_vectors(prompts_with_answers, padding_value=self.padding_value)\n", + "\n", + " if self.inference:\n", + " prompts = [torch.as_tensor(t) for t in prompts]\n", + " prompts_lens = torch.tensor([t.size(0) for t in prompts], dtype=torch.long)\n", + " prompts = collate_vectors(prompts, padding_value=self.padding_value)\n", + " else:\n", + " prompts = None\n", + " prompts_lens = None\n", + "\n", + " return audio, audio_lens, prompts_with_answers, prompts_with_answers_lens, prompts, prompts_lens" + ] + }, + { + "cell_type": "markdown", + "id": 
"5cb71ba1-ce2e-49c7-8126-be7e7851c812", + "metadata": { + "id": "5cb71ba1-ce2e-49c7-8126-be7e7851c812" + }, + "source": [ + "---\n", + "\n", + "The above class is mostly a demonstration, but it showcases how users might flexibly change the prompt formatter, prompt format function and even the data set that handles these two in a flexible way.\n", + "\n", + "The order of operations is usually this -\n", + "\n", + "1) Create a new Prompt Formatter class - this denotes the slots that each turn can have (including new task inputs or other values). This class is auto registered.\n", + "2) Create a new Prompt Format function - Using `@registered_prompt_format_fn` decorator, write a custom function that accepts args and processes the provided input data from a manifest.\n", + "3) Create a new Dataset class (usually based on the `PromptedAudioToTextLhotseDataset` dataset) that uses the Prompt Format function to convert manifest items into nicely formatted samples that can be passed to the Prompt Formatter." + ] + }, + { + "cell_type": "markdown", + "id": "a7bf8078-663e-43cb-b045-0c8b6ef08e30", + "metadata": { + "id": "a7bf8078-663e-43cb-b045-0c8b6ef08e30" + }, + "source": [ + "# Preparing a Canary Dataset\n", + "\n", + "Now that we have all the pieces together on the model side, let's take a look on the data side." + ] + }, + { + "cell_type": "markdown", + "id": "83c9eabc-0473-463e-be1f-ab6d5f519a79", + "metadata": { + "id": "83c9eabc-0473-463e-be1f-ab6d5f519a79" + }, + "source": [ + "## Required Roles Defined by Prompt Format\n", + "\n", + "These are the available 'roles' available in the prompt format - they denote at each turn, one role can be enabled and its input or output can be calculated." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "11ff9641-53fd-4481-b414-0edc12bf4dc3", + "metadata": { + "id": "11ff9641-53fd-4481-b414-0edc12bf4dc3" + }, + "outputs": [], + "source": [ + "model.prompt.get_roles()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "203a67e2-74fd-440c-9658-451f41239f36", + "metadata": { + "id": "203a67e2-74fd-440c-9658-451f41239f36" + }, + "outputs": [], + "source": [ + "for role in model.prompt.get_roles():\n", + " print(role, model.prompt.get_slots(role))\n", + " print()" + ] + }, + { + "cell_type": "markdown", + "id": "8e887f9d-94e7-4843-9da8-f914e24651f3", + "metadata": { + "id": "8e887f9d-94e7-4843-9da8-f914e24651f3" + }, + "source": [ + "## Create a Data Module\n", + "\n", + "Data Modules are one way of organizing datasets in PyTorch Lightning. It provides a unified place where data loading and processing can be potentially handled.\n", + "\n", + "**Note**: This isnt strictly necessary - you can achieve the same using just Pytorch dataloaders directly and passing it to Trainer.fit() but we showcase a data module codebase that can be extended by the user." + ] + }, + { + "cell_type": "markdown", + "id": "51d58931-4166-4ab9-a755-4c5268001192", + "metadata": { + "id": "51d58931-4166-4ab9-a755-4c5268001192" + }, + "source": [ + "----\n", + "\n", + "In our CanaryAN4DataModule - we will perform two tasks. One is En ASR - transcribing the AN4 English dataset. Another is En to De AST - directly translating the english audio to German text.\n", + "\n", + "For simplicity's sake, we will use a small off-the-shelf model to perform the translation of English Transcripts to German." 
+ ] + }, + { + "cell_type": "markdown", + "id": "91ed74ca-5d5e-412d-a813-0659014aa9a3", + "metadata": { + "id": "91ed74ca-5d5e-412d-a813-0659014aa9a3" + }, + "source": [ + "---\n", + "\n", + "In NeMo 2.0, we utilize [Lhotse](https://github.com/lhotse-speech/lhotse) as our data backbone for speech tasks, which simplifies using custom speech datasets.\n", + "\n", + "Most of the magic is handled by the following code\n", + "\n", + "```python\n", + "from nemo.collections.common.data.lhotse import get_lhotse_dataloader_from_config\n", + "\n", + "get_lhotse_dataloader_from_config(\n", + " OmegaConf.create(config), # Pass in a config that points to the manifest files and other arguments\n", + " global_rank=self.trainer.global_rank,\n", + " world_size=self.trainer.world_size,\n", + " # Pass in the dataset class for Lhotse to handle. This class now receives CutSet as input.\n", + " dataset=MyCanaryPromptedAudioToTextLhotseDataset(tokenizer=self.tokenizer, inference=inference),\n", + ")\n", + "```" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4a15ab9b-7603-4ac5-890c-92a541a0527c", + "metadata": { + "id": "4a15ab9b-7603-4ac5-890c-92a541a0527c" + }, + "outputs": [], + "source": [ + "import os\n", + "import glob\n", + "import json\n", + "import copy\n", + "import subprocess\n", + "import tarfile\n", + "import wget\n", + "import librosa\n", + "import tqdm\n", + "from omegaconf import OmegaConf\n", + "\n", + "from torch.utils.data import DataLoader, Dataset\n", + "\n", + "import pytorch_lightning as L\n", + "\n", + "from transformers import T5Tokenizer, T5ForConditionalGeneration\n", + "\n", + "from nemo.collections.asr.parts.utils.manifest_utils import read_manifest, write_manifest\n", + "from nemo.collections.common.data.lhotse import get_lhotse_dataloader_from_config\n", + "\n", + "\n", + "# Function to build a manifest\n", + "def build_manifest(transcripts_path, manifest_path, wav_path, data_dir):\n", + " with open(transcripts_path, 'r') as fin:\n", + " with open(manifest_path, 'w') as fout:\n", + " for line in fin:\n", + " # Lines look like this:\n", + " # transcript (fileID)\n", + " transcript = line[: line.find('(')-1].lower()\n", + " transcript = transcript.replace('', '').replace('', '')\n", + " transcript = transcript.strip()\n", + "\n", + " file_id = line[line.find('(')+1 : -2] # e.g. 
\"cen4-fash-b\"\n", + " audio_path = os.path.join(\n", + " data_dir, wav_path,\n", + " file_id[file_id.find('-')+1 : file_id.rfind('-')],\n", + " file_id + '.wav')\n", + "\n", + " duration = librosa.core.get_duration(path=audio_path)\n", + "\n", + " # Write the metadata to the manifest\n", + " metadata = {\n", + " \"audio_filepath\": audio_path,\n", + " \"duration\": duration,\n", + " \"text\": transcript,\n", + " \"pnc\": \"no\",\n", + " \"source_lang\": \"en\",\n", + " \"target_lang\": \"en\",\n", + " \"task\": \"asr\",\n", + " }\n", + " json.dump(metadata, fout)\n", + " fout.write('\\n')\n", + "\n", + " return manifest_path\n", + "\n", + "\n", + "class CanaryAN4DataModule(L.LightningDataModule):\n", + "\n", + " def __init__(self, tokenizer, data_dir: str = \"./an4/\", batch_size=8):\n", + " super().__init__()\n", + " self.tokenizer = tokenizer\n", + " self.data_dir = data_dir\n", + " self.batch_size = batch_size\n", + "\n", + " # ASR manifests\n", + " self.train_manifest = data_dir + '/an4/train_manifest.json'\n", + " self.test_manifest = data_dir + '/an4/test_manifest.json'\n", + "\n", + " # AST manifests\n", + " self.ast_train_manifest = data_dir + '/an4/ast_train_manifest.json'\n", + " self.ast_test_manifest = data_dir + '/an4/ast_test_manifest.json'\n", + "\n", + " # Combined manifests\n", + " self.combined_train_manifest = data_dir + '/an4/combined_train_manifest.json'\n", + " self.combined_test_manifest = data_dir + '/an4/combined_test_manifest.json'\n", + "\n", + " def setup(self, stage):\n", + " # make assignments here (val/train/test split)\n", + " # called on every process in DDP\n", + " # Assign train/val datasets for use in dataloaders\n", + " pass\n", + "\n", + " def train_dataloader(self):\n", + " config = {'manifest_filepath': self.combined_train_manifest, 'batch_size': self.batch_size,\n", + " 'num_workers': 4, 'shuffle': True, 'min_duration': 0.3, 'max_duration': 10.0}\n", + " return self._setup_dataloader(config)\n", + "\n", + " def val_dataloader(self):\n", + " config = {'manifest_filepath': self.combined_test_manifest, 'batch_size': self.batch_size,\n", + " 'num_workers': 4, 'shuffle': False, 'min_duration': 0.3, 'max_duration': 10.0}\n", + " return self._setup_dataloader(config, inference=True)\n", + "\n", + " def test_dataloader(self):\n", + " config = {'manifest_filepath': self.combined_test_manifest, 'batch_size': self.batch_size,\n", + " 'num_workers': 4, 'shuffle': False, 'min_duration': 0.3, 'max_duration': 10.0}\n", + " return self._setup_dataloader(config, inference=True)\n", + "\n", + " def teardown(self, stage):\n", + " # clean up after fit or test\n", + " # called on every process in DDP\n", + " pass\n", + "\n", + " def _setup_dataloader(self, config, inference: bool = False):\n", + " \"\"\"\n", + " The main function that creates the data loader using Lhotse's integration with NeMo.\n", + " \"\"\"\n", + " return get_lhotse_dataloader_from_config(\n", + " OmegaConf.create(config),\n", + " global_rank=self.trainer.global_rank,\n", + " world_size=self.trainer.world_size,\n", + " # Note the passing of our custom dataset\n", + " dataset=MyCanaryPromptedAudioToTextLhotseDataset(tokenizer=self.tokenizer, inference=inference),\n", + " )\n", + "\n", + " def prepare_data(self):\n", + " # download, split, etc...\n", + " # only called on 1 GPU/TPU in distributed\n", + " if not os.path.exists(self.data_dir):\n", + " os.makedirs(self.data_dir)\n", + "\n", + " data_dir = self.data_dir\n", + " if not os.path.exists(data_dir + '/an4_sphere.tar.gz'):\n", + " an4_url = 
'https://dldata-public.s3.us-east-2.amazonaws.com/an4_sphere.tar.gz'\n", + " an4_path = wget.download(an4_url, data_dir)\n", + " print(f\"Dataset downloaded at: {an4_path}\")\n", + " else:\n", + " print(\"Tarfile already exists.\")\n", + " an4_path = data_dir + '/an4_sphere.tar.gz'\n", + "\n", + " if not os.path.exists(data_dir + '/an4/'):\n", + " # Untar and convert .sph to .wav (using sox)\n", + " tar = tarfile.open(an4_path)\n", + " tar.extractall(path=data_dir)\n", + "\n", + " print(\"Converting .sph to .wav...\")\n", + " sph_list = glob.glob(data_dir + '/an4/**/*.sph', recursive=True)\n", + " for sph_path in sph_list:\n", + " wav_path = sph_path[:-4] + '.wav'\n", + " cmd = [\"sox\", sph_path, wav_path]\n", + " subprocess.run(cmd)\n", + " print(\"Finished conversion.\\n******\")\n", + "\n", + " # Building Manifests\n", + " print(\"******\")\n", + " train_transcripts = data_dir + '/an4/etc/an4_train.transcription'\n", + " train_manifest = self.train_manifest\n", + " if not os.path.isfile(train_manifest):\n", + " build_manifest(train_transcripts, train_manifest, 'an4/wav/an4_clstk', data_dir)\n", + " print(\"Training manifest created.\")\n", + "\n", + " test_transcripts = data_dir + '/an4/etc/an4_test.transcription'\n", + " test_manifest = self.test_manifest\n", + " if not os.path.isfile(test_manifest):\n", + " build_manifest(test_transcripts, test_manifest, 'an4/wav/an4test_clstk', data_dir)\n", + " print(\"Test manifest created.\")\n", + " print(\"*** Wrote manifests for Eng ***\")\n", + "\n", + " train_manifest_data = read_manifest(self.train_manifest)\n", + " test_manifest_data = read_manifest(self.test_manifest)\n", + "\n", + " if not os.path.isfile(self.ast_train_manifest) or not os.path.isfile(self.ast_test_manifest) or not os.path.isfile(self.combined_train_manifest) or not os.path.isfile(self.combined_test_manifest):\n", + " tokenizer = T5Tokenizer.from_pretrained(\"google-t5/t5-small\")\n", + " t5_model = T5ForConditionalGeneration.from_pretrained(\"google-t5/t5-small\")\n", + "\n", + " if torch.cuda.is_available():\n", + " t5_model = t5_model.cuda()\n", + "\n", + " def pipe(text):\n", + " if isinstance(text, str):\n", + " text = [text]\n", + "\n", + " prefix = \"translate English to German\"\n", + " prompts = [prefix + \": \" + x for x in text]\n", + " input_ids = tokenizer(prompts, return_tensors=\"pt\", padding=True, truncation=True).input_ids\n", + " input_ids = input_ids.to(t5_model.device)\n", + " outputs = t5_model.generate(input_ids, max_new_tokens=64)\n", + " return [tokenizer.decode(output, skip_special_tokens=True) for output in outputs]\n", + "\n", + " ast_train_manifest_data = copy.deepcopy(train_manifest_data)\n", + " ast_test_manifest_data = copy.deepcopy(test_manifest_data)\n", + "\n", + " print(\"Translating train set\")\n", + " train_texts = [x['text'] for x in train_manifest_data]\n", + " BATCH_SIZE = 32\n", + "\n", + " for i in tqdm.tqdm(range(0, len(train_texts), BATCH_SIZE), total=len(train_texts) // BATCH_SIZE):\n", + " batch_texts = train_texts[i:i+BATCH_SIZE]\n", + " batch_texts = pipe(batch_texts)\n", + " for j, text in enumerate(batch_texts):\n", + " ast_train_manifest_data[i+j]['text'] = text\n", + " ast_train_manifest_data[i+j]['task'] = 'ast'\n", + " ast_train_manifest_data[i+j]['target_lang'] = 'de'\n", + "\n", + " print(\"Translating test set\")\n", + " for data in tqdm.tqdm(ast_test_manifest_data, total=len(ast_test_manifest_data)):\n", + " data['text'] = pipe(data['text'])[0]\n", + " data['task'] = 'ast'\n", + " data['target_lang'] = 'de'\n", + 
"\n", + " write_manifest(self.ast_train_manifest, ast_train_manifest_data)\n", + " write_manifest(self.ast_test_manifest, ast_test_manifest_data)\n", + "\n", + " print(\"*** Wrote ast manifests ***\")\n", + "\n", + " combined_train, combined_test = [], []\n", + " combined_train.extend(train_manifest_data)\n", + " combined_train.extend(ast_train_manifest_data)\n", + "\n", + " combined_test.extend(test_manifest_data)\n", + " combined_test.extend(ast_test_manifest_data)\n", + "\n", + " write_manifest(self.combined_train_manifest, combined_train)\n", + " write_manifest(self.combined_test_manifest, combined_test)\n", + " print(\"*** Wrote combined manifests ***\")\n", + "\n", + " else:\n", + " print(\"*** Wrote ast and combined manifests ***\")\n" + ] + }, + { + "cell_type": "markdown", + "id": "e06e697d-7dc2-489f-a52f-195946bfbf6e", + "metadata": { + "id": "e06e697d-7dc2-489f-a52f-195946bfbf6e" + }, + "source": [ + "---\n", + "\n", + "Each item in the prepared manifest has the following items by default.\n", + "\n", + "As you will recognize, these are the same keys provided by the `CanaryPromptFormatter` classes `slots` argument, so each of these values in the is mapped back to those slots.\n", + "\n", + "```python\n", + "metadata = {\n", + " \"audio_filepath\": audio_path,\n", + " \"duration\": duration,\n", + " \"text\": transcript,\n", + " \"pnc\": \"no\",\n", + " \"source_lang\": \"en\",\n", + " \"target_lang\": \"en\",\n", + " \"task\": \"asr\",\n", + "}\n", + "```\n", + "\n", + "The most important function in the Data Module above is `prepare_data()`:\n", + "\n", + "1) It first downloads and converts the AN4 audio files to wav files.\n", + "2) Then it writes a new manifest file with the above keys for ASR task\n", + "3) It then translates the En transcripts with a `t5-small` model to generate German transcripts\n", + "4) Finally it writes another manifest for the AST task with these translated texts.\n", + "5) Finally it builds a combined manifest item for both ASR (en) and AST (en to de) multi-task training\n", + "\n", + "**Note**: We are using prepare_data() only for demonstration. Normally, users should process before experimentation, and so they would only need to implement methods above prepare_data() in their Data Module." 
+ ] + }, + { + "cell_type": "markdown", + "id": "739f0141-1e0e-4db7-b1f6-9d13589bf50c", + "metadata": { + "id": "739f0141-1e0e-4db7-b1f6-9d13589bf50c" + }, + "source": [ + "## Download and Prepare Dataset" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "323287f1-9a44-49ab-8438-dcbf34bf2ebe", + "metadata": { + "id": "323287f1-9a44-49ab-8438-dcbf34bf2ebe" + }, + "outputs": [], + "source": [ + "data_module = CanaryAN4DataModule(tokenizer=model.tokenizer, batch_size=16)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "123faf0d-05b2-4f12-850f-350a175ba7c1", + "metadata": { + "scrolled": true, + "id": "123faf0d-05b2-4f12-850f-350a175ba7c1" + }, + "outputs": [], + "source": [ + "data_module.prepare_data()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "fbec085b-9600-49bd-8739-73e5e8e3773f", + "metadata": { + "id": "fbec085b-9600-49bd-8739-73e5e8e3773f" + }, + "outputs": [], + "source": [ + "!head -n 5 {data_module.train_manifest}" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "66bad9ac-3bad-4d84-8b30-830856c06804", + "metadata": { + "id": "66bad9ac-3bad-4d84-8b30-830856c06804" + }, + "outputs": [], + "source": [ + "!head -n 5 {data_module.ast_train_manifest}" + ] + }, + { + "cell_type": "markdown", + "id": "cde19c46-e78c-4d7c-adbf-f1559c9203e1", + "metadata": { + "id": "cde19c46-e78c-4d7c-adbf-f1559c9203e1" + }, + "source": [ + "# Evaluate Model before Training\n", + "\n", + "Canary Multi Task model is already very capable, achieving strong scores on multiple benchmarks. So we first evaluate the baseline numbers on the two tasks\n", + "\n", + "1) ASR: WER calculation on transcripts\n", + "\n", + "2) AST: SacreBLEU calculation on translations" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "eb4588b4-7d52-4c4e-bb81-2bcb5a227afd", + "metadata": { + "id": "eb4588b4-7d52-4c4e-bb81-2bcb5a227afd" + }, + "outputs": [], + "source": [ + "from nemo.collections.asr.metrics.wer import word_error_rate\n", + "from torchmetrics.text import SacreBLEUScore" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a1c71044-3cb3-453c-bfcd-ee551cecdddf", + "metadata": { + "id": "a1c71044-3cb3-453c-bfcd-ee551cecdddf" + }, + "outputs": [], + "source": [ + "asr_test = read_manifest(data_module.test_manifest)\n", + "ast_test = read_manifest(data_module.ast_test_manifest)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f1d8acd2-aa08-4ba0-b0c6-c5d662243b00", + "metadata": { + "id": "f1d8acd2-aa08-4ba0-b0c6-c5d662243b00" + }, + "outputs": [], + "source": [ + "asr_filepaths = [x['audio_filepath'] for x in asr_test]\n", + "asr_gt = [x['text'] for x in asr_test]\n", + "\n", + "ast_filepaths = [x['audio_filepath'] for x in ast_test]\n", + "ast_gt = [x['text'] for x in ast_test]\n", + "\n", + "print(\"Num files:\", len(asr_filepaths))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "85ace700-97bf-4697-8e1a-5793eb21e678", + "metadata": { + "id": "85ace700-97bf-4697-8e1a-5793eb21e678" + }, + "outputs": [], + "source": [ + "if torch.cuda.is_available():\n", + " model = model.cuda() # move model to gpu\n", + " model = model.to(torch.bfloat16) # cast full model to bfloat16" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "00f2607a-2f67-47fe-9903-0adae4d9adf5", + "metadata": { + "id": "00f2607a-2f67-47fe-9903-0adae4d9adf5" + }, + "outputs": [], + "source": [ + "asr_preds = model.transcribe(asr_filepaths, pnc='no', task='asr', 
source_lang='en', target_lang='en', batch_size=32)" + ] + },
+ { + "cell_type": "code", + "execution_count": null, + "id": "eea5ab20-60d4-4e19-87fb-71f6835941e8", + "metadata": { + "id": "eea5ab20-60d4-4e19-87fb-71f6835941e8" + }, + "outputs": [], + "source": [ + "ast_preds = model.transcribe(ast_filepaths, pnc='no', task='ast', source_lang='en', target_lang='de', batch_size=32)" + ] + },
+ { + "cell_type": "code", + "execution_count": null, + "id": "69e5bb54-5193-4268-98e1-dc6daae8f6eb", + "metadata": { + "id": "69e5bb54-5193-4268-98e1-dc6daae8f6eb" + }, + "outputs": [], + "source": [ + "wer = word_error_rate(asr_preds, asr_gt)\n", + "print(\"WER\", wer)\n", + "\n", + "sacrebleu = SacreBLEUScore(n_gram=4)\n", + "preds = []\n", + "gts = []\n", + "for pred, gt in zip(ast_preds, ast_gt):\n", + " preds.append(pred)\n", + " gts.append([gt])\n", + "\n", + "sacrebleu.update(preds, gts)\n", + "bleu = sacrebleu.compute()\n", + "print(\"BLEU\", bleu.item() * 100)" + ] + },
+ { + "cell_type": "markdown", + "id": "5ee530c9-36a3-47d2-83b9-b2a64080c0eb", + "metadata": { + "id": "5ee530c9-36a3-47d2-83b9-b2a64080c0eb" + }, + "source": [ + "# Train Model\n", + "\n", + "Now that the adapters have been added, the baseline has been evaluated, and the dataset is prepared, it's time to train the adapter weights on the new datasets.\n", + "\n", + "---\n", + "\n", + "First, we update the optimizer and scheduler config" + ] + },
+ { + "cell_type": "code", + "execution_count": null, + "id": "d0a40461-d739-436c-967a-1a0f8a3ad197", + "metadata": { + "id": "d0a40461-d739-436c-967a-1a0f8a3ad197" + }, + "outputs": [], + "source": [ + "print(OmegaConf.to_yaml(model.cfg.optim))" + ] + },
+ { + "cell_type": "code", + "execution_count": null, + "id": "4ba5811a-fc42-4de5-add5-0d26d1c84219", + "metadata": { + "id": "4ba5811a-fc42-4de5-add5-0d26d1c84219" + }, + "outputs": [], + "source": [ + "# Setup optimization\n", + "model.cfg.optim.lr = 3e-4\n", + "model.cfg.optim.sched.warmup_steps = 25" + ] + },
+ { + "cell_type": "markdown", + "id": "d1de270a-d1cb-4080-b571-7acf365d7b99", + "metadata": { + "id": "d1de270a-d1cb-4080-b571-7acf365d7b99" + }, + "source": [ + "---\n", + "\n", + "Next, we set up a Lightning Trainer and the Experiment Manager" + ] + },
+ { + "cell_type": "code", + "execution_count": null, + "id": "b9e34369-21ec-41bf-beae-30b60ab46c14", + "metadata": { + "id": "b9e34369-21ec-41bf-beae-30b60ab46c14" + }, + "outputs": [], + "source": [ + "from omegaconf import OmegaConf\n", + "from nemo.utils import exp_manager" + ] + },
+ { + "cell_type": "code", + "execution_count": null, + "id": "46f74863-a34d-4ad0-9d8e-3337ea5edd63", + "metadata": { + "id": "46f74863-a34d-4ad0-9d8e-3337ea5edd63" + }, + "outputs": [], + "source": [ + "trainer = L.Trainer(max_steps=200, accumulate_grad_batches=1, logger=False, enable_checkpointing=False, check_val_every_n_epoch=5)" + ] + },
+ { + "cell_type": "code", + "execution_count": null, + "id": "414d7887-bed5-46a2-bfe1-8349db1e6b5b", + "metadata": { + "id": "414d7887-bed5-46a2-bfe1-8349db1e6b5b" + }, + "outputs": [], + "source": [ + "# # Environment variable generally used for multi-node multi-gpu training.\n", + "# # In notebook environments, this flag is unnecessary and can cause logs of multiple training runs to overwrite each other.\n", + "# os.environ.pop('NEMO_EXPM_VERSION', None)\n", + "\n", + "# config = exp_manager.ExpManagerConfig(\n", + "# exp_dir=f'experiments/canary/',\n",
name=f\"Canary-Model-Adapter-Training\",\n", + "# checkpoint_callback_params=exp_manager.CallbackParams(\n", + "# monitor=\"val_wer\",\n", + "# mode=\"min\",\n", + "# always_save_nemo=False,\n", + "# save_best_model=False,\n", + "# ),\n", + "# )\n", + "\n", + "# config = OmegaConf.structured(config)\n", + "\n", + "# logdir = exp_manager.exp_manager(trainer, config)" + ] + }, + { + "cell_type": "markdown", + "id": "60769859-8ed5-4f9c-b93a-a6875c7c1c73", + "metadata": { + "id": "60769859-8ed5-4f9c-b93a-a6875c7c1c73" + }, + "source": [ + "---\n", + "\n", + "Begin training !" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2adb8607-a011-440d-bfa8-976c2871e8ef", + "metadata": { + "scrolled": true, + "id": "2adb8607-a011-440d-bfa8-976c2871e8ef" + }, + "outputs": [], + "source": [ + "trainer.fit(model, data_module)" + ] + }, + { + "cell_type": "markdown", + "source": [ + "---\n", + "\n", + "Save just the adapter parameters - which is less than 2 MB !" + ], + "metadata": { + "id": "MImbKiqQ6ng-" + }, + "id": "MImbKiqQ6ng-" + }, + { + "cell_type": "code", + "source": [ + "model.save_adapters(\"adapters.pt\")\n", + "!ls -l -- *.pt\n", + "!du -sh *.pt" + ], + "metadata": { + "id": "-akTdyGM6gum" + }, + "id": "-akTdyGM6gum", + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "id": "2525bec5-c42b-48c1-b03c-e8126c346238", + "metadata": { + "id": "2525bec5-c42b-48c1-b03c-e8126c346238" + }, + "source": [ + "# Evaluate after Adaptatation\n", + "\n", + "Now that the model is done training, lets evalaute its scores on the test set again.\n", + "We should see a markedly higher translastion BLEU and lower WER from above." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6edb5528-b1b6-4505-8cdc-ee68c715415e", + "metadata": { + "id": "6edb5528-b1b6-4505-8cdc-ee68c715415e" + }, + "outputs": [], + "source": [ + "asr_test = read_manifest(data_module.test_manifest)\n", + "ast_test = read_manifest(data_module.ast_test_manifest)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "384aa5f2-89d5-4080-a717-4d65776fae6b", + "metadata": { + "id": "384aa5f2-89d5-4080-a717-4d65776fae6b" + }, + "outputs": [], + "source": [ + "asr_filepaths = [x['audio_filepath'] for x in asr_test]\n", + "asr_gt = [x['text'] for x in asr_test]\n", + "\n", + "ast_filepaths = [x['audio_filepath'] for x in ast_test]\n", + "ast_gt = [x['text'] for x in ast_test]\n", + "\n", + "print(\"Num files:\", len(asr_filepaths))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "48ce5b4c-d349-4d86-ad3c-ee930bb569ee", + "metadata": { + "id": "48ce5b4c-d349-4d86-ad3c-ee930bb569ee" + }, + "outputs": [], + "source": [ + "if torch.cuda.is_available():\n", + " model = model.cuda()\n", + " model = model.to(torch.bfloat16)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "49a37806-286e-4954-8f27-3829cf61d755", + "metadata": { + "id": "49a37806-286e-4954-8f27-3829cf61d755" + }, + "outputs": [], + "source": [ + "asr_preds = model.transcribe(asr_filepaths, pnc='no', task='asr', source_lang='en', target_lang='en', batch_size=32)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b701e014-2f71-487c-9300-a3ea89a43a45", + "metadata": { + "id": "b701e014-2f71-487c-9300-a3ea89a43a45" + }, + "outputs": [], + "source": [ + "ast_preds = model.transcribe(ast_filepaths, pnc='no', task='ast', source_lang='en', target_lang='de', batch_size=32)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": 
"087054e5-c511-4094-a115-faf4a3b49d51", + "metadata": { + "id": "087054e5-c511-4094-a115-faf4a3b49d51" + }, + "outputs": [], + "source": [ + "from nemo.collections.asr.metrics.wer import word_error_rate\n", + "from torchmetrics.text import SacreBLEUScore" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ef938f8f-b2db-45f6-9b30-4b3bbce2423f", + "metadata": { + "id": "ef938f8f-b2db-45f6-9b30-4b3bbce2423f" + }, + "outputs": [], + "source": [ + "wer = word_error_rate(asr_preds, asr_gt)\n", + "print(\"WER\", wer)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5a7c2820-d394-4627-8438-0d810d89b72d", + "metadata": { + "id": "5a7c2820-d394-4627-8438-0d810d89b72d" + }, + "outputs": [], + "source": [ + "sacrebleu = SacreBLEUScore(n_gram=4)\n", + "scores = []\n", + "preds = []\n", + "gts = []\n", + "for pred, gt in zip(ast_preds, ast_gt):\n", + " preds.append(pred)\n", + " gts.append([gt])\n", + "\n", + "# bleu = sum(scores) / len(scores)\n", + "sacrebleu.update(preds, gts)\n", + "bleu = sacrebleu.compute()\n", + "print(\"BLEU\", bleu.item() * 100)" + ] + }, + { + "cell_type": "markdown", + "id": "521df0e6-1d3c-4709-a080-63638315c514", + "metadata": { + "id": "521df0e6-1d3c-4709-a080-63638315c514" + }, + "source": [ + "# Conclusion\n", + "\n", + "In this tutorial we added adapters to a Multi Task model (Nvidia Canary) and show how to create a custom dataset to finetune a canary model to a new dataset with previous tasks such as ASR and AST. The primary goal of this tutorial was to show how to flexibly adapt a Canary model to any of the pre-existing tasks.\n", + "\n", + "In a future tutorial, we will show how to add additional tasks to a pre-trained Canary, so that you can leverage the pre-trained encoder and decoder for your own custom tasks!" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.14" + }, + "colab": { + "provenance": [], + "gpuType": "T4" + }, + "accelerator": "GPU" + }, + "nbformat": 4, + "nbformat_minor": 5 +} \ No newline at end of file diff --git a/tutorials/asr/asr_adapters/README.md b/tutorials/asr/asr_adapters/README.md index 8408be56a218a..393a119938e25 100644 --- a/tutorials/asr/asr_adapters/README.md +++ b/tutorials/asr/asr_adapters/README.md @@ -10,4 +10,6 @@ In this repository, you will find several tutorials discussing how to utilize Ad 1) `ASR_with_Adapters`: An introduction of adapters and their use case with ASR models. Dives into domain adaptation of a pre-trained model with adapter modules, general advantages and disadvantages of adapters and finally trains a model to adapt on a toy dataset. +2) `Multi_Task_Adapters`: An introduction of how to customize multi-task models with adapters. We will train a model on two tasks, one being ASR and the other being a downstream task. We will discuss how to use adapters to finetune a model for Speech Recognition and Speech Translation task on a toy dataset, and dive into construction of custom datasets and prompt formatters for Multi Task Models. + ------------ \ No newline at end of file