diff --git a/nemo/collections/asr/models/aed_multitask_models.py b/nemo/collections/asr/models/aed_multitask_models.py
index edb591921782..dcebb9ab2a6c 100644
--- a/nemo/collections/asr/models/aed_multitask_models.py
+++ b/nemo/collections/asr/models/aed_multitask_models.py
@@ -14,13 +14,14 @@
 
 import os
 import warnings
+from collections.abc import Mapping, Sequence
 from dataclasses import dataclass, field
 from math import ceil
 from typing import Any, Dict, List, Optional, Union
 
 import numpy as np
 import torch
-from omegaconf import DictConfig, OmegaConf, open_dict
+from omegaconf import DictConfig, ListConfig, OmegaConf, open_dict
 from pytorch_lightning import Trainer
 from torch.utils.data import DataLoader
 
@@ -387,6 +388,59 @@ def change_vocabulary(
 
         logging.info(f"Changed decoder to output to {vocabulary} vocabulary.")
 
+    def change_prompt(
+        self, prompt_format: Optional[str] = None, prompt_defaults: Optional[List[Dict[str, Any]]] = None
+    ):
+        """
+        Changes the prompt format used during the Multi Task decoding process.
+
+        Args:
+            prompt_format: A string alias of the object that represents the prompt structure.
+                If not None, it will be used to update the prompt format.
+            prompt_defaults: A list of dictionaries (one per role) with default slot values for the prompt format.
+        """
+        if prompt_format is not None:
+            self.prompt_format = prompt_format
+
+        if prompt_defaults is not None:
+            # Perform some assertions on the prompt defaults contents
+            # Must be a list-like object
+            if not isinstance(prompt_defaults, Sequence):
+                raise ValueError("`prompt_defaults` must be a list of dictionaries")
+
+            # Must contain dict-like objects
+            for item in prompt_defaults:
+                if not isinstance(item, Mapping):
+                    raise ValueError("`prompt_defaults` must be a list of dictionaries")
+
+                # Each dict item must have a `role` key
+                if 'role' not in item:
+                    raise ValueError(
+                        "`prompt_defaults` must have a `role` key for each item in the list of dictionaries"
+                    )
+
+                if 'slots' not in item:
+                    raise ValueError(
+                        "`prompt_defaults` must have a `slots` key for each item in the list of dictionaries"
+                    )
+
+            # Cast to OmegaConf if not already
+            if not isinstance(prompt_defaults, ListConfig):
+                prompt_defaults = OmegaConf.create(prompt_defaults)
+
+        # Update config first so the prompt formatter below is built from the new defaults
+        with open_dict(self.cfg):
+            self.cfg.prompt_format = self.prompt_format
+            self.cfg.prompt_defaults = prompt_defaults
+
+        prompt_cls = PromptFormatter.resolve(self.prompt_format)
+        self.prompt = prompt_cls(
+            tokenizer=self.tokenizer,
+            defaults=OmegaConf.to_container(pd) if (pd := self.cfg.prompt_defaults) is not None else None,
+        )
+
+        logging.info(f"Changed prompt format to `{self.prompt_format}`")
+
     @torch.no_grad()
     def transcribe(
         self,
diff --git a/nemo/collections/common/prompts/canary.py b/nemo/collections/common/prompts/canary.py
index aadc976ba474..e511368a1edf 100644
--- a/nemo/collections/common/prompts/canary.py
+++ b/nemo/collections/common/prompts/canary.py
@@ -16,9 +16,9 @@ class CanaryPromptFormatter(PromptFormatter):
             "template": f"{CANARY_BOS}|source_lang||task||target_lang||pnc|",
             "slots": {
                 "source_lang": Modality.Text,
-                "task": Modality.Text,
+                "task": Modality.TextLiteral("asr", "ast", "s2t_translation", "<|transcribe|>", "<|translate|>"),
                 "target_lang": Modality.Text,
-                "pnc": Modality.Text,
+                "pnc": Modality.TextLiteral("yes", "no", "<|pnc|>", "<|nopnc|>"),
             },
         },
         OUTPUT_ROLE: {
diff --git a/nemo/collections/common/prompts/formatter.py b/nemo/collections/common/prompts/formatter.py
index 524b2e62c5a3..8a82563ebbaa 100644
--- a/nemo/collections/common/prompts/formatter.py
+++ b/nemo/collections/common/prompts/formatter.py
@@ -20,22 +20,38 @@
 EOS_SLOT = "|eos|"
 
 
-class Modality(Enum):
+class BaseModalityType:
+    @staticmethod
+    def matches(value: Any) -> bool:
+        raise NotImplementedError
+
+
+class Text(BaseModalityType):
+    """Modality for text values."""
+
+    @staticmethod
+    def matches(value: str) -> bool:
+        return isinstance(value, str)
+
+
+class TextLiteral(BaseModalityType):
+    def __init__(self, *items):
+        self.allowed_values = items
+
+    def matches(self, value: str) -> bool:
+        return isinstance(value, str) and value in self.allowed_values
+
+    def __repr__(self):
+        return f"{self.__class__.__name__}({self.allowed_values})"
+
+
+class Modality:
     """
     Modalities supported as PromptFormatter slot values.
     """
 
-    Text = "text"
-
-    def matches(self, value: Any) -> bool:
-        """
-        Checks if the provided value is compatible with an instance of Modality.
-        """
-        match self:
-            case Modality.Text:
-                return isinstance(value, str)
-            case _:
-                return False
+    Text = Text
+    TextLiteral = TextLiteral
 
 
 class PromptFormatter(ABC):
diff --git a/tests/collections/asr/test_asr_multitask_model_bpe.py b/tests/collections/asr/test_asr_multitask_model_bpe.py
index 986df09deacb..4e805c8f34de 100644
--- a/tests/collections/asr/test_asr_multitask_model_bpe.py
+++ b/tests/collections/asr/test_asr_multitask_model_bpe.py
@@ -22,6 +22,7 @@
 from nemo.collections.asr.models.aed_multitask_models import EncDecMultiTaskModel
 from nemo.collections.asr.parts.submodules import multitask_beam_decoding as beam_decode
 from nemo.collections.asr.parts.utils.rnnt_utils import Hypothesis
+from nemo.collections.common.prompts.canary import CanaryPromptFormatter
 from nemo.collections.common.tokenizers import CanaryTokenizer
 
 
@@ -275,6 +276,51 @@ def test_decoding_change(self, asr_model):
         assert isinstance(asr_model.decoding.decoding, beam_decode.TransformerAEDBeamInfer)
         assert asr_model.decoding.decoding.search_type == "default"
 
+    @pytest.mark.unit
+    def test_prompt_change(self, asr_model):
+        assert asr_model.prompt_format == 'canary'
+        assert isinstance(asr_model.prompt, CanaryPromptFormatter)
+
+        # Default change prompt
+        asr_model.change_prompt()
+        assert asr_model.cfg.prompt_defaults is None
+
+        prompt_defaults = asr_model.prompt.get_default_dialog_slots()
+        prompt_defaults[0]['slots']['pnc'] = 'no'
+        asr_model.change_prompt(prompt_defaults=prompt_defaults)
+
+        assert asr_model.cfg.prompt_defaults[0]['slots']['pnc'] == 'no'
+
+    @pytest.mark.unit
+    def test_prompt_change_subclass(self, asr_model):
+        assert asr_model.prompt_format == 'canary'
+        assert isinstance(asr_model.prompt, CanaryPromptFormatter)
+
+        class CanaryPromptFormatterSubclass(CanaryPromptFormatter):
+            NAME = "canary2"
+
+        # Default change prompt
+        asr_model.change_prompt()
+        assert asr_model.cfg.prompt_defaults is None
+
+        prompt_defaults = asr_model.prompt.get_default_dialog_slots()
+        prompt_defaults[0]['slots']['pnc'] = 'no'
+        asr_model.change_prompt(prompt_format='canary2', prompt_defaults=prompt_defaults)
+
+        assert asr_model.cfg.prompt_format == 'canary2'
+        assert asr_model.cfg.prompt_defaults[0]['slots']['pnc'] == 'no'
+        assert isinstance(asr_model.prompt, CanaryPromptFormatterSubclass)
+
+        user_prompt = asr_model.prompt.get_default_dialog_slots()[0]
+        slots = user_prompt['slots']
+        slots['source_lang'] = 'en'
+        slots['target_lang'] = 'en'
+        slots['task'] = 'asr'
+        slots['pnc'] = 'no'
+        ans = asr_model.prompt.encode_dialog([user_prompt])
+        recovered = asr_model.tokenizer.ids_to_text(ans["input_ids"])
+        assert recovered == "<|startoftranscript|><|en|><|transcribe|><|en|><|nopnc|>"
+
     @pytest.mark.unit
     def test_transcribe_single_file(self, asr_model, test_data_dir):
         audio_file = os.path.join(test_data_dir, "asr", "train", "an4", "wav", "an46-mmap-b.wav")
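A minimal usage sketch of the new `change_prompt` API (not part of the diff), assuming a Canary checkpoint restored via `from_pretrained`; the checkpoint name below is an illustrative assumption:

    from nemo.collections.asr.models.aed_multitask_models import EncDecMultiTaskModel

    # Assumed checkpoint name for illustration; any EncDecMultiTaskModel checkpoint works the same way.
    model = EncDecMultiTaskModel.from_pretrained("nvidia/canary-1b")

    # Start from the formatter's default dialog slots and disable punctuation/capitalization.
    prompt_defaults = model.prompt.get_default_dialog_slots()
    prompt_defaults[0]["slots"]["pnc"] = "no"

    # Re-resolve the prompt formatter and persist the new defaults into the model config.
    model.change_prompt(prompt_format="canary", prompt_defaults=prompt_defaults)
    assert model.cfg.prompt_defaults[0]["slots"]["pnc"] == "no"

With the `TextLiteral` modality introduced above, slot values are additionally constrained to a fixed set of strings, e.g. `Modality.TextLiteral("yes", "no").matches("maybe")` returns `False`.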