From 013ee8c49b8e0f56e762e7be729d6ac32110c3ca Mon Sep 17 00:00:00 2001 From: Somshubra Majumdar Date: Tue, 25 Jun 2024 15:41:44 -0700 Subject: [PATCH 1/7] Add support to change Multi task model prompt Signed-off-by: smajumdar --- .../asr/models/aed_multitask_models.py | 51 ++++++++++++++++++- nemo/collections/common/prompts/canary.py | 2 +- nemo/collections/common/prompts/formatter.py | 39 +++++++++----- .../asr/test_asr_multitask_model_bpe.py | 34 +++++++++++++ 4 files changed, 112 insertions(+), 14 deletions(-) diff --git a/nemo/collections/asr/models/aed_multitask_models.py b/nemo/collections/asr/models/aed_multitask_models.py index edb591921782..2b0755d19a6e 100644 --- a/nemo/collections/asr/models/aed_multitask_models.py +++ b/nemo/collections/asr/models/aed_multitask_models.py @@ -20,7 +20,7 @@ import numpy as np import torch -from omegaconf import DictConfig, OmegaConf, open_dict +from omegaconf import DictConfig, ListConfig, OmegaConf, open_dict from pytorch_lightning import Trainer from torch.utils.data import DataLoader @@ -387,6 +387,55 @@ def change_vocabulary( logging.info(f"Changed decoder to output to {vocabulary} vocabulary.") + def change_prompt(self, prompt_format: Optional[str] = None, prompt_defaults: Optional[List[Dict[str, Any]]] = None): + """ + Changes the prompt format used during Multi Task decoding process. + + Args: + prompt_format: A string alias of the object that represents the prompt structure. + If not None, it will be used to update the prompt format. + prompt_defaults: A dictionary of default values for the prompt format. + """ + if prompt_format is not None: + self.prompt_format = prompt_format + + if prompt_defaults is not None: + # Perform some assertions on the prompt defaults contents + # Must be a list-like object + if not isinstance(prompt_defaults, (list, tuple, ListConfig)): + raise ValueError("`prompt_defaults` must be a list of dictionaries") + + # Must contain dict-like objects + for item in prompt_defaults: + if not isinstance(item, (dict, DictConfig)): + raise ValueError("`prompt_defaults` must be a list of dictionaries") + + # Each dict item must have a `role` key + if 'role' not in item: + raise ValueError("`prompt_defaults` must have a `role` key for each item " + "in the list of dictionaries") + + if 'slots' not in item: + raise ValueError("`prompt_defaults` must have a `slots` key for each item " + "in the list of dictionaries") + + # Cast to OmegaConf if not already + if not isinstance(prompt_defaults, ListConfig): + prompt_defaults = OmegaConf.create(prompt_defaults) + + prompt_cls = PromptFormatter.resolve(self.prompt_format) + self.prompt = prompt_cls( + tokenizer=self.tokenizer, + defaults=OmegaConf.to_container(pd) if (pd := self.cfg.prompt_defaults) is not None else None, + ) + + # Update config + with open_dict(self.cfg): + self.cfg.prompt_format = self.prompt_format + self.cfg.prompt_defaults = prompt_defaults + + logging.info(f"Changed prompt format to `{self.prompt_format}`") + @torch.no_grad() def transcribe( self, diff --git a/nemo/collections/common/prompts/canary.py b/nemo/collections/common/prompts/canary.py index aadc976ba474..bf8025d75502 100644 --- a/nemo/collections/common/prompts/canary.py +++ b/nemo/collections/common/prompts/canary.py @@ -18,7 +18,7 @@ class CanaryPromptFormatter(PromptFormatter): "source_lang": Modality.Text, "task": Modality.Text, "target_lang": Modality.Text, - "pnc": Modality.Text, + "pnc": Modality.TextLiteral("yes", "no"), }, }, OUTPUT_ROLE: { diff --git a/nemo/collections/common/prompts/formatter.py b/nemo/collections/common/prompts/formatter.py index 524b2e62c5a3..1eef6cb0d547 100644 --- a/nemo/collections/common/prompts/formatter.py +++ b/nemo/collections/common/prompts/formatter.py @@ -20,22 +20,37 @@ EOS_SLOT = "|eos|" -class Modality(Enum): +class BaseModalityType: + @staticmethod + def matches(value: Any) -> bool: + raise NotImplementedError + + +class Text(BaseModalityType): + """ Modality for text values. """ + @staticmethod + def matches(value: str) -> bool: + return isinstance(value, str) + +class TextLiteral(BaseModalityType): + def __init__(self, *items): + self.allowed_values = items + + def matches(self, value: str): + if "<|" in value and "|>" in value: + return True # Special token + return isinstance(value, str) and value in self.allowed_values + + def __repr__(self): + return f"{self.__class__.__name__}(allowed_values={self.allowed_values})" + +class Modality: """ Modalities supported as PromptFormatter slot values. """ - Text = "text" - - def matches(self, value: Any) -> bool: - """ - Checks if the provided value is compatible with an instance of Modality. - """ - match self: - case Modality.Text: - return isinstance(value, str) - case _: - return False + Text = Text + TextLiteral = TextLiteral class PromptFormatter(ABC): diff --git a/tests/collections/asr/test_asr_multitask_model_bpe.py b/tests/collections/asr/test_asr_multitask_model_bpe.py index 986df09deacb..ef085b1b809c 100644 --- a/tests/collections/asr/test_asr_multitask_model_bpe.py +++ b/tests/collections/asr/test_asr_multitask_model_bpe.py @@ -23,6 +23,7 @@ from nemo.collections.asr.parts.submodules import multitask_beam_decoding as beam_decode from nemo.collections.asr.parts.utils.rnnt_utils import Hypothesis from nemo.collections.common.tokenizers import CanaryTokenizer +from nemo.collections.common.prompts.canary import CanaryPromptFormatter @pytest.fixture() @@ -275,6 +276,39 @@ def test_decoding_change(self, asr_model): assert isinstance(asr_model.decoding.decoding, beam_decode.TransformerAEDBeamInfer) assert asr_model.decoding.decoding.search_type == "default" + @pytest.mark.unit + def test_prompt_change(self, asr_model): + assert asr_model.prompt_format == 'canary' + assert isinstance(asr_model.prompt, CanaryPromptFormatter) + + # Default change prompt + asr_model.change_prompt() + assert asr_model.cfg.prompt_defaults is None + + prompt_defaults = [{'role': 'user', 'slots': {'pnc': 'no'}}] + asr_model.change_prompt(prompt_defaults=prompt_defaults) + + assert asr_model.cfg.prompt_defaults[0]['slots']['pnc'] == 'no' + + @pytest.mark.unit + def test_prompt_change_subclass(self, asr_model): + assert asr_model.prompt_format == 'canary' + assert isinstance(asr_model.prompt, CanaryPromptFormatter) + + class CanaryPromptFormatterSubclass(CanaryPromptFormatter): + NAME = "canary2" + + # Default change prompt + asr_model.change_prompt() + assert asr_model.cfg.prompt_defaults is None + + prompt_defaults = [{'role': 'user', 'slots': {'pnc': 'no'}}] + asr_model.change_prompt(prompt_format='canary2', prompt_defaults=prompt_defaults) + + assert asr_model.cfg.prompt_format == 'canary2' + assert asr_model.cfg.prompt_defaults[0]['slots']['pnc'] == 'no' + assert isinstance(asr_model.prompt, CanaryPromptFormatterSubclass) + @pytest.mark.unit def test_transcribe_single_file(self, asr_model, test_data_dir): audio_file = os.path.join(test_data_dir, "asr", "train", "an4", "wav", "an46-mmap-b.wav") From bd0a6e43ad55a6f3b302da307fcb12e528d58c6b Mon Sep 17 00:00:00 2001 From: Somshubra Majumdar Date: Tue, 25 Jun 2024 16:02:06 -0700 Subject: [PATCH 2/7] Add support to change Multi task model prompt Signed-off-by: smajumdar --- .../asr/test_asr_multitask_model_bpe.py | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/tests/collections/asr/test_asr_multitask_model_bpe.py b/tests/collections/asr/test_asr_multitask_model_bpe.py index ef085b1b809c..ab4e7c84385d 100644 --- a/tests/collections/asr/test_asr_multitask_model_bpe.py +++ b/tests/collections/asr/test_asr_multitask_model_bpe.py @@ -285,7 +285,8 @@ def test_prompt_change(self, asr_model): asr_model.change_prompt() assert asr_model.cfg.prompt_defaults is None - prompt_defaults = [{'role': 'user', 'slots': {'pnc': 'no'}}] + prompt_defaults = asr_model.prompt.get_default_dialog_slots() + prompt_defaults[0]['slots']['pnc'] = 'no' asr_model.change_prompt(prompt_defaults=prompt_defaults) assert asr_model.cfg.prompt_defaults[0]['slots']['pnc'] == 'no' @@ -302,13 +303,24 @@ class CanaryPromptFormatterSubclass(CanaryPromptFormatter): asr_model.change_prompt() assert asr_model.cfg.prompt_defaults is None - prompt_defaults = [{'role': 'user', 'slots': {'pnc': 'no'}}] + prompt_defaults = asr_model.prompt.get_default_dialog_slots() + prompt_defaults[0]['slots']['pnc'] = 'no' asr_model.change_prompt(prompt_format='canary2', prompt_defaults=prompt_defaults) assert asr_model.cfg.prompt_format == 'canary2' assert asr_model.cfg.prompt_defaults[0]['slots']['pnc'] == 'no' assert isinstance(asr_model.prompt, CanaryPromptFormatterSubclass) + user_prompt = asr_model.prompt.get_default_dialog_slots()[0] + slots = user_prompt['slots'] + slots['source_lang'] = 'en' + slots['target_lang'] = 'en' + slots['task'] = 'asr' + slots['pnc'] = 'no' + ans = asr_model.prompt.encode_dialog([user_prompt]) + recovered = asr_model.tokenizer.ids_to_text(ans["input_ids"]) + assert recovered == "<|startoftranscript|><|en|><|transcribe|><|en|><|nopnc|>" + @pytest.mark.unit def test_transcribe_single_file(self, asr_model, test_data_dir): audio_file = os.path.join(test_data_dir, "asr", "train", "an4", "wav", "an46-mmap-b.wav") From ffcaebf966548fe51de3774810bef9518d33cd41 Mon Sep 17 00:00:00 2001 From: titu1994 Date: Tue, 25 Jun 2024 23:05:59 +0000 Subject: [PATCH 3/7] Apply isort and black reformatting Signed-off-by: titu1994 --- .../collections/asr/models/aed_multitask_models.py | 14 +++++++++----- nemo/collections/common/prompts/formatter.py | 5 ++++- .../asr/test_asr_multitask_model_bpe.py | 2 +- 3 files changed, 14 insertions(+), 7 deletions(-) diff --git a/nemo/collections/asr/models/aed_multitask_models.py b/nemo/collections/asr/models/aed_multitask_models.py index 2b0755d19a6e..496b3fc6a329 100644 --- a/nemo/collections/asr/models/aed_multitask_models.py +++ b/nemo/collections/asr/models/aed_multitask_models.py @@ -387,7 +387,9 @@ def change_vocabulary( logging.info(f"Changed decoder to output to {vocabulary} vocabulary.") - def change_prompt(self, prompt_format: Optional[str] = None, prompt_defaults: Optional[List[Dict[str, Any]]] = None): + def change_prompt( + self, prompt_format: Optional[str] = None, prompt_defaults: Optional[List[Dict[str, Any]]] = None + ): """ Changes the prompt format used during Multi Task decoding process. @@ -412,12 +414,14 @@ def change_prompt(self, prompt_format: Optional[str] = None, prompt_defaults: Op # Each dict item must have a `role` key if 'role' not in item: - raise ValueError("`prompt_defaults` must have a `role` key for each item " - "in the list of dictionaries") + raise ValueError( + "`prompt_defaults` must have a `role` key for each item " "in the list of dictionaries" + ) if 'slots' not in item: - raise ValueError("`prompt_defaults` must have a `slots` key for each item " - "in the list of dictionaries") + raise ValueError( + "`prompt_defaults` must have a `slots` key for each item " "in the list of dictionaries" + ) # Cast to OmegaConf if not already if not isinstance(prompt_defaults, ListConfig): diff --git a/nemo/collections/common/prompts/formatter.py b/nemo/collections/common/prompts/formatter.py index 1eef6cb0d547..bb93f3eaee00 100644 --- a/nemo/collections/common/prompts/formatter.py +++ b/nemo/collections/common/prompts/formatter.py @@ -27,11 +27,13 @@ def matches(value: Any) -> bool: class Text(BaseModalityType): - """ Modality for text values. """ + """Modality for text values.""" + @staticmethod def matches(value: str) -> bool: return isinstance(value, str) + class TextLiteral(BaseModalityType): def __init__(self, *items): self.allowed_values = items @@ -44,6 +46,7 @@ def matches(self, value: str): def __repr__(self): return f"{self.__class__.__name__}(allowed_values={self.allowed_values})" + class Modality: """ Modalities supported as PromptFormatter slot values. diff --git a/tests/collections/asr/test_asr_multitask_model_bpe.py b/tests/collections/asr/test_asr_multitask_model_bpe.py index ab4e7c84385d..4e805c8f34de 100644 --- a/tests/collections/asr/test_asr_multitask_model_bpe.py +++ b/tests/collections/asr/test_asr_multitask_model_bpe.py @@ -22,8 +22,8 @@ from nemo.collections.asr.models.aed_multitask_models import EncDecMultiTaskModel from nemo.collections.asr.parts.submodules import multitask_beam_decoding as beam_decode from nemo.collections.asr.parts.utils.rnnt_utils import Hypothesis -from nemo.collections.common.tokenizers import CanaryTokenizer from nemo.collections.common.prompts.canary import CanaryPromptFormatter +from nemo.collections.common.tokenizers import CanaryTokenizer @pytest.fixture() From 4ccf4c6bf8da226a4da8b9b9bfd3b5c35c722d2f Mon Sep 17 00:00:00 2001 From: Somshubra Majumdar Date: Wed, 26 Jun 2024 10:53:48 -0700 Subject: [PATCH 4/7] Update nemo/collections/common/prompts/formatter.py MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Piotr Żelasko Signed-off-by: Somshubra Majumdar --- nemo/collections/common/prompts/formatter.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nemo/collections/common/prompts/formatter.py b/nemo/collections/common/prompts/formatter.py index bb93f3eaee00..ea55aaa8faad 100644 --- a/nemo/collections/common/prompts/formatter.py +++ b/nemo/collections/common/prompts/formatter.py @@ -44,7 +44,7 @@ def matches(self, value: str): return isinstance(value, str) and value in self.allowed_values def __repr__(self): - return f"{self.__class__.__name__}(allowed_values={self.allowed_values})" + return f"{self.__class__.__name__}{self.allowed_values}" class Modality: From 4ba5408f03bde5afbfeaccbff1899c041f5ac154 Mon Sep 17 00:00:00 2001 From: Somshubra Majumdar Date: Thu, 27 Jun 2024 15:48:56 -0700 Subject: [PATCH 5/7] Address comments Signed-off-by: smajumdar --- nemo/collections/asr/models/aed_multitask_models.py | 9 +++++---- nemo/collections/common/prompts/canary.py | 4 ++-- nemo/collections/common/prompts/formatter.py | 6 ++---- 3 files changed, 9 insertions(+), 10 deletions(-) diff --git a/nemo/collections/asr/models/aed_multitask_models.py b/nemo/collections/asr/models/aed_multitask_models.py index 496b3fc6a329..6fbe3fd395b8 100644 --- a/nemo/collections/asr/models/aed_multitask_models.py +++ b/nemo/collections/asr/models/aed_multitask_models.py @@ -17,6 +17,7 @@ from dataclasses import dataclass, field from math import ceil from typing import Any, Dict, List, Optional, Union +from collections.abc import Sequence, Mapping import numpy as np import torch @@ -404,23 +405,23 @@ def change_prompt( if prompt_defaults is not None: # Perform some assertions on the prompt defaults contents # Must be a list-like object - if not isinstance(prompt_defaults, (list, tuple, ListConfig)): + if not isinstance(prompt_defaults, Sequence): raise ValueError("`prompt_defaults` must be a list of dictionaries") # Must contain dict-like objects for item in prompt_defaults: - if not isinstance(item, (dict, DictConfig)): + if not isinstance(item, Mapping): raise ValueError("`prompt_defaults` must be a list of dictionaries") # Each dict item must have a `role` key if 'role' not in item: raise ValueError( - "`prompt_defaults` must have a `role` key for each item " "in the list of dictionaries" + "`prompt_defaults` must have a `role` key for each item in the list of dictionaries" ) if 'slots' not in item: raise ValueError( - "`prompt_defaults` must have a `slots` key for each item " "in the list of dictionaries" + "`prompt_defaults` must have a `slots` key for each item in the list of dictionaries" ) # Cast to OmegaConf if not already diff --git a/nemo/collections/common/prompts/canary.py b/nemo/collections/common/prompts/canary.py index bf8025d75502..77db277cadd4 100644 --- a/nemo/collections/common/prompts/canary.py +++ b/nemo/collections/common/prompts/canary.py @@ -16,9 +16,9 @@ class CanaryPromptFormatter(PromptFormatter): "template": f"{CANARY_BOS}|source_lang||task||target_lang||pnc|", "slots": { "source_lang": Modality.Text, - "task": Modality.Text, + "task": Modality.TextLiteral("asr", "asr", "s2t_translation", "<|transcribe|>", "<|translate|>"), "target_lang": Modality.Text, - "pnc": Modality.TextLiteral("yes", "no"), + "pnc": Modality.TextLiteral("yes", "no", "<|pnc|>", "<|nopnc|>"), }, }, OUTPUT_ROLE: { diff --git a/nemo/collections/common/prompts/formatter.py b/nemo/collections/common/prompts/formatter.py index ea55aaa8faad..8a82563ebbaa 100644 --- a/nemo/collections/common/prompts/formatter.py +++ b/nemo/collections/common/prompts/formatter.py @@ -38,13 +38,11 @@ class TextLiteral(BaseModalityType): def __init__(self, *items): self.allowed_values = items - def matches(self, value: str): - if "<|" in value and "|>" in value: - return True # Special token + def matches(self, value: str) -> bool: return isinstance(value, str) and value in self.allowed_values def __repr__(self): - return f"{self.__class__.__name__}{self.allowed_values}" + return f"{self.__class__.__name__}({self.allowed_values})" class Modality: From 5dc7d0a199d263e99b0a2182bc334f5db2d814f6 Mon Sep 17 00:00:00 2001 From: titu1994 Date: Thu, 27 Jun 2024 22:49:41 +0000 Subject: [PATCH 6/7] Apply isort and black reformatting Signed-off-by: titu1994 --- nemo/collections/asr/models/aed_multitask_models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nemo/collections/asr/models/aed_multitask_models.py b/nemo/collections/asr/models/aed_multitask_models.py index 6fbe3fd395b8..dcebb9ab2a6c 100644 --- a/nemo/collections/asr/models/aed_multitask_models.py +++ b/nemo/collections/asr/models/aed_multitask_models.py @@ -14,10 +14,10 @@ import os import warnings +from collections.abc import Mapping, Sequence from dataclasses import dataclass, field from math import ceil from typing import Any, Dict, List, Optional, Union -from collections.abc import Sequence, Mapping import numpy as np import torch From 729853ee0f2c21c527055200c79470900da4983c Mon Sep 17 00:00:00 2001 From: Somshubra Majumdar Date: Fri, 28 Jun 2024 15:04:02 -0700 Subject: [PATCH 7/7] Address comments Signed-off-by: smajumdar --- nemo/collections/common/prompts/canary.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nemo/collections/common/prompts/canary.py b/nemo/collections/common/prompts/canary.py index 77db277cadd4..e511368a1edf 100644 --- a/nemo/collections/common/prompts/canary.py +++ b/nemo/collections/common/prompts/canary.py @@ -16,7 +16,7 @@ class CanaryPromptFormatter(PromptFormatter): "template": f"{CANARY_BOS}|source_lang||task||target_lang||pnc|", "slots": { "source_lang": Modality.Text, - "task": Modality.TextLiteral("asr", "asr", "s2t_translation", "<|transcribe|>", "<|translate|>"), + "task": Modality.TextLiteral("asr", "ast", "s2t_translation", "<|transcribe|>", "<|translate|>"), "target_lang": Modality.Text, "pnc": Modality.TextLiteral("yes", "no", "<|pnc|>", "<|nopnc|>"), },