Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support to change Multi task model prompt #9542

Merged
merged 7 commits into from
Jun 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 55 additions & 1 deletion nemo/collections/asr/models/aed_multitask_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,14 @@

import os
import warnings
from collections.abc import Mapping, Sequence
from dataclasses import dataclass, field
from math import ceil
from typing import Any, Dict, List, Optional, Union

import numpy as np
import torch
from omegaconf import DictConfig, OmegaConf, open_dict
from omegaconf import DictConfig, ListConfig, OmegaConf, open_dict
from pytorch_lightning import Trainer
from torch.utils.data import DataLoader

Expand Down Expand Up @@ -387,6 +388,59 @@ def change_vocabulary(

logging.info(f"Changed decoder to output to {vocabulary} vocabulary.")

def change_prompt(
    self, prompt_format: Optional[str] = None, prompt_defaults: Optional[List[Dict[str, Any]]] = None
):
    """
    Changes the prompt format used during Multi Task decoding process.

    Args:
        prompt_format: A string alias of the object that represents the prompt structure.
            If not None, it will be used to update the prompt format.
        prompt_defaults: A list of dictionaries of default values for the prompt format.
            Each item must provide a `role` key and a `slots` key. If not None, it
            replaces the current prompt defaults (passing None leaves the format alone
            but clears the stored defaults in the config).
    """
    if prompt_format is not None:
        self.prompt_format = prompt_format

    if prompt_defaults is not None:
        # Perform some assertions on the prompt defaults contents
        # Must be a list-like object
        if not isinstance(prompt_defaults, Sequence):
            raise ValueError("`prompt_defaults` must be a list of dictionaries")

        # Must contain dict-like objects
        for item in prompt_defaults:
            if not isinstance(item, Mapping):
                raise ValueError("`prompt_defaults` must be a list of dictionaries")

            # Each dict item must have a `role` key
            if 'role' not in item:
                raise ValueError(
                    "`prompt_defaults` must have a `role` key for each item in the list of dictionaries"
                )

            if 'slots' not in item:
                raise ValueError(
                    "`prompt_defaults` must have a `slots` key for each item in the list of dictionaries"
                )

        # Cast to OmegaConf if not already
        if not isinstance(prompt_defaults, ListConfig):
            prompt_defaults = OmegaConf.create(prompt_defaults)

    # Update the config FIRST: the prompt formatter below is constructed from
    # `self.cfg.prompt_defaults`, so building it before this update would silently
    # reuse the stale defaults and discard the caller-supplied `prompt_defaults`.
    with open_dict(self.cfg):
        self.cfg.prompt_format = self.prompt_format
        self.cfg.prompt_defaults = prompt_defaults

    # Rebuild the prompt formatter from the (now updated) config.
    prompt_cls = PromptFormatter.resolve(self.prompt_format)
    self.prompt = prompt_cls(
        tokenizer=self.tokenizer,
        defaults=OmegaConf.to_container(pd) if (pd := self.cfg.prompt_defaults) is not None else None,
    )

    logging.info(f"Changed prompt format to `{self.prompt_format}`")

@torch.no_grad()
def transcribe(
self,
Expand Down
4 changes: 2 additions & 2 deletions nemo/collections/common/prompts/canary.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,9 @@ class CanaryPromptFormatter(PromptFormatter):
"template": f"{CANARY_BOS}|source_lang||task||target_lang||pnc|",
"slots": {
"source_lang": Modality.Text,
"task": Modality.Text,
"task": Modality.TextLiteral("asr", "ast", "s2t_translation", "<|transcribe|>", "<|translate|>"),
"target_lang": Modality.Text,
"pnc": Modality.Text,
"pnc": Modality.TextLiteral("yes", "no", "<|pnc|>", "<|nopnc|>"),
},
},
OUTPUT_ROLE: {
Expand Down
40 changes: 28 additions & 12 deletions nemo/collections/common/prompts/formatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,22 +20,38 @@
EOS_SLOT = "|eos|"


class Modality(Enum):
class BaseModalityType:
    """Abstract base for prompt-slot modality types.

    Subclasses define `matches(value)`, which reports whether a candidate
    slot value is acceptable for that modality.
    """

    @staticmethod
    def matches(value: Any) -> bool:
        """Return True when `value` is acceptable for this modality; must be overridden."""
        raise NotImplementedError


class Text(BaseModalityType):
    """Modality accepting any plain-text (``str``) slot value."""

    @staticmethod
    def matches(value: str) -> bool:
        """True iff `value` is a string; no further restriction is applied."""
        return isinstance(value, str)


class TextLiteral(BaseModalityType):
    """Text modality restricted to a fixed whitelist of string literals."""

    def __init__(self, *items):
        # Keep the accepted values in declaration order (also keeps repr stable).
        self.allowed_values = items

    def matches(self, value: str) -> bool:
        """True iff `value` is a string and one of the whitelisted literals."""
        if not isinstance(value, str):
            return False
        return value in self.allowed_values

    def __repr__(self):
        return f"{self.__class__.__name__}({self.allowed_values})"


class Modality:
    """
    Modalities supported as PromptFormatter slot values.

    Acts as a namespace over the concrete modality types so existing call
    sites can keep referring to `Modality.Text` / `Modality.TextLiteral`.
    """

    # NOTE(review): the scraped diff interleaved the removed Enum-based
    # implementation with this namespace version; this is the merged result.
    Text = Text
    TextLiteral = TextLiteral


class PromptFormatter(ABC):
Expand Down
46 changes: 46 additions & 0 deletions tests/collections/asr/test_asr_multitask_model_bpe.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from nemo.collections.asr.models.aed_multitask_models import EncDecMultiTaskModel
from nemo.collections.asr.parts.submodules import multitask_beam_decoding as beam_decode
from nemo.collections.asr.parts.utils.rnnt_utils import Hypothesis
from nemo.collections.common.prompts.canary import CanaryPromptFormatter
from nemo.collections.common.tokenizers import CanaryTokenizer


Expand Down Expand Up @@ -275,6 +276,51 @@ def test_decoding_change(self, asr_model):
assert isinstance(asr_model.decoding.decoding, beam_decode.TransformerAEDBeamInfer)
assert asr_model.decoding.decoding.search_type == "default"

@pytest.mark.unit
def test_prompt_change(self, asr_model):
    """change_prompt(): a no-arg call clears stored defaults; explicit defaults land in cfg."""
    # Sanity: the model ships with the canary prompt formatter.
    assert asr_model.prompt_format == 'canary'
    assert isinstance(asr_model.prompt, CanaryPromptFormatter)

    # Calling with no arguments resets the stored defaults.
    asr_model.change_prompt()
    assert asr_model.cfg.prompt_defaults is None

    # Override a single slot and confirm it is persisted in the config.
    new_defaults = asr_model.prompt.get_default_dialog_slots()
    new_defaults[0]['slots']['pnc'] = 'no'
    asr_model.change_prompt(prompt_defaults=new_defaults)

    assert asr_model.cfg.prompt_defaults[0]['slots']['pnc'] == 'no'

@pytest.mark.unit
def test_prompt_change_subclass(self, asr_model):
    """Switching the model to a PromptFormatter subclass registered under a new alias.

    NOTE(review): the scraped page spliced a PR review-comment thread into the
    middle of this method; this is the reconstructed contiguous test body.
    """
    assert asr_model.prompt_format == 'canary'
    assert isinstance(asr_model.prompt, CanaryPromptFormatter)

    # Defining the subclass registers it under its NAME alias ('canary2'),
    # making it resolvable by change_prompt below.
    class CanaryPromptFormatterSubclass(CanaryPromptFormatter):
        NAME = "canary2"

    # Default change prompt: no-arg call clears the stored defaults.
    asr_model.change_prompt()
    assert asr_model.cfg.prompt_defaults is None

    # Switch format and defaults in one call.
    prompt_defaults = asr_model.prompt.get_default_dialog_slots()
    prompt_defaults[0]['slots']['pnc'] = 'no'
    asr_model.change_prompt(prompt_format='canary2', prompt_defaults=prompt_defaults)

    assert asr_model.cfg.prompt_format == 'canary2'
    assert asr_model.cfg.prompt_defaults[0]['slots']['pnc'] == 'no'
    assert isinstance(asr_model.prompt, CanaryPromptFormatterSubclass)

    # Round-trip: encode a filled-in dialog and verify the canary token sequence.
    user_prompt = asr_model.prompt.get_default_dialog_slots()[0]
    slots = user_prompt['slots']
    slots['source_lang'] = 'en'
    slots['target_lang'] = 'en'
    slots['task'] = 'asr'
    slots['pnc'] = 'no'
    ans = asr_model.prompt.encode_dialog([user_prompt])
    recovered = asr_model.tokenizer.ids_to_text(ans["input_ids"])
    assert recovered == "<|startoftranscript|><|en|><|transcribe|><|en|><|nopnc|>"

@pytest.mark.unit
def test_transcribe_single_file(self, asr_model, test_data_dir):
audio_file = os.path.join(test_data_dir, "asr", "train", "an4", "wav", "an46-mmap-b.wav")
Expand Down
Loading