Skip to content

Commit

Permalink
Add support to change Multi task model prompt (NVIDIA#9542)
Browse files Browse the repository at this point in the history
* Add support to change Multi task model prompt

Signed-off-by: smajumdar <[email protected]>

* Add support to change Multi task model prompt

Signed-off-by: smajumdar <[email protected]>

* Apply isort and black reformatting

Signed-off-by: titu1994 <[email protected]>

* Update nemo/collections/common/prompts/formatter.py

Co-authored-by: Piotr Żelasko <[email protected]>
Signed-off-by: Somshubra Majumdar <[email protected]>

* Address comments

Signed-off-by: smajumdar <[email protected]>

* Apply isort and black reformatting

Signed-off-by: titu1994 <[email protected]>

* Address comments

Signed-off-by: smajumdar <[email protected]>

---------

Signed-off-by: smajumdar <[email protected]>
Signed-off-by: titu1994 <[email protected]>
Signed-off-by: Somshubra Majumdar <[email protected]>
Co-authored-by: Piotr Żelasko <[email protected]>
Signed-off-by: tonyjie <[email protected]>
  • Loading branch information
2 people authored and tonyjie committed Aug 6, 2024
1 parent efe528c commit d6c0a06
Show file tree
Hide file tree
Showing 4 changed files with 131 additions and 15 deletions.
56 changes: 55 additions & 1 deletion nemo/collections/asr/models/aed_multitask_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,14 @@

import os
import warnings
from collections.abc import Mapping, Sequence
from dataclasses import dataclass, field
from math import ceil
from typing import Any, Dict, List, Optional, Union

import numpy as np
import torch
from omegaconf import DictConfig, OmegaConf, open_dict
from omegaconf import DictConfig, ListConfig, OmegaConf, open_dict
from pytorch_lightning import Trainer
from torch.utils.data import DataLoader

Expand Down Expand Up @@ -387,6 +388,59 @@ def change_vocabulary(

logging.info(f"Changed decoder to output to {vocabulary} vocabulary.")

def change_prompt(
    self, prompt_format: Optional[str] = None, prompt_defaults: Optional[List[Dict[str, Any]]] = None
):
    """
    Changes the prompt format used during Multi Task decoding process.

    Args:
        prompt_format: A string alias of the object that represents the prompt structure.
            If not None, it will be used to update the prompt format.
        prompt_defaults: A list of dictionaries (one per dialog turn) providing default
            slot values for the prompt format. Each item must contain a `role` key and
            a `slots` key. If None, any previously configured defaults are cleared.

    Raises:
        ValueError: If `prompt_defaults` is not a list of dicts, or an item is missing
            the `role` or `slots` key.
    """
    if prompt_format is not None:
        self.prompt_format = prompt_format

    if prompt_defaults is not None:
        # Perform some assertions on the prompt defaults contents
        # Must be a list-like object
        if not isinstance(prompt_defaults, Sequence):
            raise ValueError("`prompt_defaults` must be a list of dictionaries")

        # Must contain dict-like objects
        for item in prompt_defaults:
            if not isinstance(item, Mapping):
                raise ValueError("`prompt_defaults` must be a list of dictionaries")

            # Each dict item must have a `role` key
            if 'role' not in item:
                raise ValueError(
                    "`prompt_defaults` must have a `role` key for each item in the list of dictionaries"
                )

            if 'slots' not in item:
                raise ValueError(
                    "`prompt_defaults` must have a `slots` key for each item in the list of dictionaries"
                )

        # Cast to OmegaConf if not already
        if not isinstance(prompt_defaults, ListConfig):
            prompt_defaults = OmegaConf.create(prompt_defaults)

    # Update the config *before* rebuilding the formatter, so the formatter is
    # constructed from the new defaults rather than the stale config value.
    with open_dict(self.cfg):
        self.cfg.prompt_format = self.prompt_format
        self.cfg.prompt_defaults = prompt_defaults

    # Rebuild the prompt formatter from the (possibly updated) alias and defaults.
    prompt_cls = PromptFormatter.resolve(self.prompt_format)
    self.prompt = prompt_cls(
        tokenizer=self.tokenizer,
        defaults=OmegaConf.to_container(pd) if (pd := self.cfg.prompt_defaults) is not None else None,
    )

    logging.info(f"Changed prompt format to `{self.prompt_format}`")

@torch.no_grad()
def transcribe(
self,
Expand Down
4 changes: 2 additions & 2 deletions nemo/collections/common/prompts/canary.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,9 @@ class CanaryPromptFormatter(PromptFormatter):
"template": f"{CANARY_BOS}|source_lang||task||target_lang||pnc|",
"slots": {
"source_lang": Modality.Text,
"task": Modality.Text,
"task": Modality.TextLiteral("asr", "ast", "s2t_translation", "<|transcribe|>", "<|translate|>"),
"target_lang": Modality.Text,
"pnc": Modality.Text,
"pnc": Modality.TextLiteral("yes", "no", "<|pnc|>", "<|nopnc|>"),
},
},
OUTPUT_ROLE: {
Expand Down
40 changes: 28 additions & 12 deletions nemo/collections/common/prompts/formatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,22 +20,38 @@
EOS_SLOT = "|eos|"


class BaseModalityType:
    """Base type for PromptFormatter slot-value modalities.

    Subclasses implement ``matches`` to validate that a user-provided slot
    value is acceptable for that modality.
    """

    @staticmethod
    def matches(value: Any) -> bool:
        """Return True if ``value`` is valid for this modality. Must be overridden."""
        raise NotImplementedError


class Text(BaseModalityType):
    """Modality for text values: any ``str`` is accepted."""

    @staticmethod
    def matches(value: str) -> bool:
        return isinstance(value, str)


class TextLiteral(BaseModalityType):
    """Modality restricted to a fixed set of allowed string values.

    Args:
        *items: The allowed string values for this slot.
    """

    def __init__(self, *items):
        self.allowed_values = items

    def matches(self, value: str) -> bool:
        # Value must be a string AND one of the whitelisted literals.
        return isinstance(value, str) and value in self.allowed_values

    def __repr__(self):
        return f"{self.__class__.__name__}({self.allowed_values})"


class Modality:
    """
    Modalities supported as PromptFormatter slot values.
    """

    Text = Text
    TextLiteral = TextLiteral


class PromptFormatter(ABC):
Expand Down
46 changes: 46 additions & 0 deletions tests/collections/asr/test_asr_multitask_model_bpe.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from nemo.collections.asr.models.aed_multitask_models import EncDecMultiTaskModel
from nemo.collections.asr.parts.submodules import multitask_beam_decoding as beam_decode
from nemo.collections.asr.parts.utils.rnnt_utils import Hypothesis
from nemo.collections.common.prompts.canary import CanaryPromptFormatter
from nemo.collections.common.tokenizers import CanaryTokenizer


Expand Down Expand Up @@ -275,6 +276,51 @@ def test_decoding_change(self, asr_model):
assert isinstance(asr_model.decoding.decoding, beam_decode.TransformerAEDBeamInfer)
assert asr_model.decoding.decoding.search_type == "default"

@pytest.mark.unit
def test_prompt_change(self, asr_model):
    # Model ships with the canary prompt formatter by default.
    assert asr_model.prompt_format == 'canary'
    assert isinstance(asr_model.prompt, CanaryPromptFormatter)

    # A no-argument call resets the configured defaults.
    asr_model.change_prompt()
    assert asr_model.cfg.prompt_defaults is None

    # Override a single slot default and verify it lands in the config.
    new_defaults = asr_model.prompt.get_default_dialog_slots()
    new_defaults[0]['slots']['pnc'] = 'no'
    asr_model.change_prompt(prompt_defaults=new_defaults)
    assert asr_model.cfg.prompt_defaults[0]['slots']['pnc'] == 'no'

@pytest.mark.unit
def test_prompt_change_subclass(self, asr_model):
    # Model ships with the canary prompt formatter by default.
    assert asr_model.prompt_format == 'canary'
    assert isinstance(asr_model.prompt, CanaryPromptFormatter)

    class CanaryPromptFormatterSubclass(CanaryPromptFormatter):
        NAME = "canary2"

    # A no-argument call resets the configured defaults.
    asr_model.change_prompt()
    assert asr_model.cfg.prompt_defaults is None

    # Switch to the subclass alias and tweak one slot default.
    new_defaults = asr_model.prompt.get_default_dialog_slots()
    new_defaults[0]['slots']['pnc'] = 'no'
    asr_model.change_prompt(prompt_format='canary2', prompt_defaults=new_defaults)

    assert asr_model.cfg.prompt_format == 'canary2'
    assert asr_model.cfg.prompt_defaults[0]['slots']['pnc'] == 'no'
    assert isinstance(asr_model.prompt, CanaryPromptFormatterSubclass)

    # Fill in a full user turn and check the rendered prompt round-trips
    # through the tokenizer to the expected special-token sequence.
    turn = asr_model.prompt.get_default_dialog_slots()[0]
    turn_slots = turn['slots']
    turn_slots['source_lang'] = 'en'
    turn_slots['target_lang'] = 'en'
    turn_slots['task'] = 'asr'
    turn_slots['pnc'] = 'no'
    encoded = asr_model.prompt.encode_dialog([turn])
    recovered = asr_model.tokenizer.ids_to_text(encoded["input_ids"])
    assert recovered == "<|startoftranscript|><|en|><|transcribe|><|en|><|nopnc|>"

@pytest.mark.unit
def test_transcribe_single_file(self, asr_model, test_data_dir):
audio_file = os.path.join(test_data_dir, "asr", "train", "an4", "wav", "an46-mmap-b.wav")
Expand Down

0 comments on commit d6c0a06

Please sign in to comment.