Commit
Fix Canary not stripping prompt from reference + more test coverage (NVIDIA#9987)

* Fix not stripping Canary prompt from the reference and add extra test coverage

Signed-off-by: Piotr Żelasko <[email protected]>

* Fix transcripts for Canary containing EOS

Signed-off-by: Piotr Żelasko <[email protected]>

* Fix validation_pass

Signed-off-by: Piotr Żelasko <[email protected]>

* Review

Signed-off-by: Piotr Żelasko <[email protected]>

* Review 2

Signed-off-by: Piotr Żelasko <[email protected]>

* Fix non-deterministic unit test assertions

Signed-off-by: Piotr Żelasko <[email protected]>

---------

Signed-off-by: Piotr Żelasko <[email protected]>
pzelasko authored and XuesongYang committed Jan 18, 2025
1 parent 3647e8e commit 4912d0e
Showing 8 changed files with 694 additions and 121 deletions.
78 changes: 54 additions & 24 deletions nemo/collections/asr/data/audio_to_text_lhotse_prompted.py
@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import Callable, Sequence

import torch.utils.data
@@ -25,6 +26,26 @@
from nemo.collections.common.tokenizers.canary_tokenizer import CANARY_SPECIAL_TOKENIZER


@dataclass
class PromptedAudioToTextMiniBatch:
audio: torch.Tensor
audio_lens: torch.Tensor
transcript: torch.Tensor
transcript_lens: torch.Tensor
prompt: torch.Tensor
prompt_lens: torch.Tensor
prompted_transcript: torch.Tensor
prompted_transcript_lens: torch.Tensor

def get_decoder_inputs_outputs(self) -> tuple[torch.Tensor, torch.Tensor]:
"""
Returns the inputs and outputs of transformer decoder for training.
The input is ``prompted_transcript`` (minus last token),
and the output is ``prompted_transcript`` (minus first token).
"""
return self.prompted_transcript[:, :-1], self.prompted_transcript[:, 1:]
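The method above is the standard teacher-forcing shift. A minimal sketch, not part of this diff and using made-up token IDs, of what it returns:

import torch

# prompted_transcript = prompt tokens + transcript tokens + EOS (IDs are illustrative)
prompted_transcript = torch.tensor([[5, 6, 7, 42, 43, 2]])
decoder_inputs = prompted_transcript[:, :-1]  # tensor([[ 5,  6,  7, 42, 43]]) fed to the decoder
decoder_labels = prompted_transcript[:, 1:]   # tensor([[ 6,  7, 42, 43,  2]]) used as the loss target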


class PromptedAudioToTextLhotseDataset(torch.utils.data.Dataset):
"""
This dataset is based on :class:`~nemo.collections.asr.data.audio_to_text_lhotse.LhotseSpeechToTextBpeDataset`.
@@ -45,41 +66,46 @@ class PromptedAudioToTextLhotseDataset(torch.utils.data.Dataset):
def __init__(
self,
tokenizer: TokenizerSpec,
prompt_format_fn: Callable[[CutSet, TokenizerWrapper, bool], Sequence[Sequence[int]]],
inference: bool = False,
prompt_format_fn: Callable[[CutSet, TokenizerWrapper], Sequence[Sequence[int]]],
):
super().__init__()
self.tokenizer = TokenizerWrapper(tokenizer)
self.load_audio = AudioSamples(fault_tolerant=True)
self.padding_value = self.tokenizer._tokenizer.pad_id
self.prompt_format_fn = prompt_format_fn
self.inference = inference

def __getitem__(self, cuts: CutSet) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
def __getitem__(self, cuts: CutSet) -> PromptedAudioToTextMiniBatch:
audio, audio_lens, cuts = self.load_audio(cuts)

prompts_with_answers, prompts = self.prompt_format_fn(cuts, self.tokenizer, inference=self.inference)

prompts_with_answers = [torch.as_tensor(t) for t in prompts_with_answers]
prompts_with_answers_lens = torch.tensor([t.size(0) for t in prompts_with_answers], dtype=torch.long)
prompts_with_answers = collate_vectors(prompts_with_answers, padding_value=self.padding_value)

if self.inference:
prompts = [torch.as_tensor(t) for t in prompts]
prompts_lens = torch.tensor([t.size(0) for t in prompts], dtype=torch.long)
prompts = collate_vectors(prompts, padding_value=self.padding_value)
else:
prompts = None
prompts_lens = None
prompts_with_answers, prompts, answers = self.prompt_format_fn(cuts, self.tokenizer)

transcript, transcript_lens = self._collate_tokens(answers)
prompts_with_answers, prompts_with_answers_lens = self._collate_tokens(prompts_with_answers)
prompts, prompt_lens = self._collate_tokens(prompts)

return PromptedAudioToTextMiniBatch(
audio=audio,
audio_lens=audio_lens,
transcript=transcript,
transcript_lens=transcript_lens,
prompt=prompts,
prompt_lens=prompt_lens,
prompted_transcript=prompts_with_answers,
prompted_transcript_lens=prompts_with_answers_lens,
)

return audio, audio_lens, prompts_with_answers, prompts_with_answers_lens, prompts, prompts_lens
def _collate_tokens(self, tokens: list[list[int]]) -> tuple[torch.Tensor, torch.Tensor]:
tokens = [torch.as_tensor(t) for t in tokens]
token_lens = torch.tensor([t.size(0) for t in tokens], dtype=torch.long)
tokens = collate_vectors(tokens, padding_value=self.padding_value)
return tokens, token_lens
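_collate_tokens pads each list of token IDs to a common length and records the true lengths. A rough stand-in using plain PyTorch, assuming a pad_id of 0 and made-up IDs:

import torch
from torch.nn.utils.rnn import pad_sequence

token_lists = [[4, 7, 9], [4, 8]]
tokens = [torch.as_tensor(t) for t in token_lists]
token_lens = torch.tensor([t.size(0) for t in tokens])            # tensor([3, 2])
padded = pad_sequence(tokens, batch_first=True, padding_value=0)  # pad to a rectangle
# padded == tensor([[4, 7, 9],
#                   [4, 8, 0]])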


# Mapping from a string name to a known prompt formatter function.
PROMPT_FORMAT_FNS = {}


def registered_prompt_format_fn(prompt_fn: Callable[[CutSet, TokenizerWrapper, bool], Sequence[Sequence[int]]]):
def registered_prompt_format_fn(prompt_fn: Callable[[CutSet, TokenizerWrapper], Sequence[Sequence[int]]]):
"""
Decorator for registering prompt functions under a name.
@@ -97,7 +123,7 @@ def registered_prompt_format_fn(prompt_fn: Callable[[CutSet, TokenizerWrapper, b
return prompt_fn


def get_prompt_format_fn(name: str) -> Callable[[CutSet, TokenizerWrapper, bool], Sequence[Sequence[int]]]:
def get_prompt_format_fn(name: str) -> Callable[[CutSet, TokenizerWrapper], Sequence[Sequence[int]]]:
if name not in PROMPT_FORMAT_FNS:
raise ValueError(
f"Unknown prompt format function name: {name} " f"(must be one of: {list(PROMPT_FORMAT_FNS.keys())}"
@@ -107,8 +133,8 @@ def get_prompt_format_fn(name: str) -> Callable[[CutSet, TokenizerWrapper, bool]

@registered_prompt_format_fn
def canary(
cuts: CutSet, tokenizer: TokenizerWrapper, inference: bool = False
) -> tuple[list[torch.Tensor], list[torch.Tensor]]:
cuts: CutSet, tokenizer: TokenizerWrapper
) -> tuple[list[torch.Tensor], list[torch.Tensor], list[torch.Tensor]]:
"""
Prepend and append control tokens to the token sequence as per Canary format.
@@ -137,7 +163,7 @@ def canary(
), "To use 'canary' prompt format, you must use the CanaryTokenizer."
formatter = CanaryPromptFormatter(tokenizer._tokenizer)

prompts_with_answers, prompts = [], []
prompts_with_answers, prompts, answers = [], [], []
for cut in cuts:
if isinstance(cut, MixedCut):
cut = cut._first_non_padding_cut
@@ -180,8 +206,12 @@ def canary(
)
prompts_with_answers.append(encoded["input_ids"])
prompts.append(encoded["context_ids"])
assert (
encoded["answer_ids"][-1].item() == formatter.tokenizer.eos
), f"Expected the last token in answer_ids to be EOS, but we got {encoded['answer_ids']=}"
answers.append(encoded["answer_ids"][:-1]) # Strip Canary's EOS

return prompts_with_answers, prompts
return prompts_with_answers, prompts, answers
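For a single cut, the function now returns three aligned token sequences. An illustrative sketch with invented IDs (not Canary's real vocabulary):

prompt = [3, 10, 11, 11, 12]            # context_ids: start-of-transcript plus task/lang/PnC tokens
prompted_answer = prompt + [42, 43, 2]  # input_ids: prompt + transcript + EOS
answer = [42, 43]                       # answer_ids with the trailing EOS stripped, so the
                                        # reference carries neither the prompt nor EOS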


class ProbablyIncorrectLanguageKeyError(RuntimeError):
104 changes: 58 additions & 46 deletions nemo/collections/asr/models/aed_multitask_models.py
@@ -27,6 +27,7 @@

from nemo.collections.asr.data.audio_to_text_lhotse_prompted import (
PromptedAudioToTextLhotseDataset,
PromptedAudioToTextMiniBatch,
get_prompt_format_fn,
)
from nemo.collections.asr.metrics import BLEU, WER
@@ -498,7 +499,7 @@ def transcribe(

return super().transcribe(audio=audio, override_config=trcfg)

def _setup_dataloader_from_config(self, config: Optional[Dict], inference: bool = False):
def _setup_dataloader_from_config(self, config: Optional[Dict]):
assert config.get("use_lhotse", False), (
"Multi-task model only supports dataloading with Lhotse. "
"Please set config.{train,validation,test}_ds.use_lhotse=True"
@@ -510,7 +511,6 @@ def _setup_dataloader_from_config(self, config: Optional[Dict], inference: bool
dataset=PromptedAudioToTextLhotseDataset(
tokenizer=self.tokenizer,
prompt_format_fn=get_prompt_format_fn(self.prompt_format),
inference=inference,
),
)

@@ -554,7 +554,7 @@ def setup_validation_data(self, val_data_config: Optional[Union[DictConfig, Dict

# preserve config
self._update_dataset_config(dataset_name='validation', config=val_data_config)
self._validation_dl = self._setup_dataloader_from_config(config=val_data_config, inference=True)
self._validation_dl = self._setup_dataloader_from_config(config=val_data_config)

def setup_test_data(self, test_data_config: Optional[Union[DictConfig, Dict]]):
"""
@@ -570,7 +570,7 @@ def setup_test_data(self, test_data_config: Optional[Union[DictConfig, Dict]]):

# preserve config
self._update_dataset_config(dataset_name='test', config=test_data_config)
self._test_dl = self._setup_dataloader_from_config(config=test_data_config, inference=True)
self._test_dl = self._setup_dataloader_from_config(config=test_data_config)

@property
def input_types(self) -> Optional[Dict[str, NeuralType]]:
@@ -664,20 +664,18 @@ def forward(
return transf_log_probs, encoded_len, enc_states, enc_mask

# PTL-specific methods
def training_step(self, batch, batch_nb):
def training_step(self, batch: PromptedAudioToTextMiniBatch, batch_nb):

if batch is None:
return torch.tensor([0.0])

# During training prompt and prompt_len are null, ignore.
signal, signal_len, transcript, transcript_len, prompt, prompt_len = batch
input_ids, labels = transcript[:, :-1], transcript[:, 1:]
input_ids, labels = batch.get_decoder_inputs_outputs()

transf_log_probs, encoded_len, enc_states, enc_mask = self.forward(
input_signal=signal,
input_signal_length=signal_len,
input_signal=batch.audio,
input_signal_length=batch.audio_lens,
transcript=input_ids,
transcript_length=transcript_len,
transcript_length=batch.prompted_transcript_lens,
)

audio_loss = self.loss(log_probs=transf_log_probs, labels=labels)
@@ -689,16 +687,14 @@ def training_step(self, batch, batch_nb):

return {'loss': audio_loss, 'log': tensorboard_logs}

def validation_pass(self, batch, batch_idx, dataloader_idx=0, eval_mode="val"):
# During inference, dataloader passes pure prompt without transcript text.
signal, signal_len, transcript, transcript_len, prompt, prompt_len = batch
input_ids, labels = transcript[:, :-1], transcript[:, 1:]
def validation_pass(self, batch: PromptedAudioToTextMiniBatch, batch_idx, dataloader_idx=0, eval_mode="val"):
input_ids, labels = batch.get_decoder_inputs_outputs()

transf_log_probs, encoded_len, enc_states, enc_mask = self.forward(
input_signal=signal,
input_signal_length=signal_len,
input_signal=batch.audio,
input_signal_length=batch.audio_lens,
transcript=input_ids,
transcript_length=transcript_len,
transcript_length=batch.prompted_transcript_lens,
)

transf_loss = self.loss(log_probs=transf_log_probs, labels=labels)
@@ -710,10 +706,10 @@ def validation_pass(self, batch, batch_idx, dataloader_idx=0, eval_mode="val"):
self.wer.update(
predictions=enc_states,
predictions_lengths=encoded_len,
targets=transcript,
targets_lengths=transcript_len,
targets=batch.transcript,
targets_lengths=batch.transcript_lens,
predictions_mask=enc_mask,
input_ids=prompt,
input_ids=batch.prompt,
)
wer, wer_num, wer_denom = self.wer.compute()
output_dict.update({"val_wer": wer, "val_wer_num": wer_num, "val_wer_denom": wer_denom})
@@ -722,10 +718,10 @@ def validation_pass(self, batch, batch_idx, dataloader_idx=0, eval_mode="val"):
self.bleu.update(
predictions=enc_states,
predictions_lengths=encoded_len,
targets=transcript,
targets_lengths=transcript_len,
targets=batch.transcript,
targets_lengths=batch.transcript_lens,
predictions_mask=enc_mask,
input_ids=prompt,
input_ids=batch.prompt,
)
bleu_metrics = self.bleu.compute(prefix=f"{eval_mode}_")
output_dict.update(bleu_metrics)
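Switching the targets to batch.transcript matters because WER/BLEU references should contain only the spoken words. An illustrative example (made-up strings) of how a prompt leaking into the reference skews the score:

# A perfect two-word hypothesis scored against a reference that still contains
# three prompt tokens incurs three deletions: WER becomes 3/5 instead of 0/2.
hyp = "hello world".split()
ref_clean = "hello world".split()                    # detokenized batch.transcript
ref_leaky = "<asr> <en> <pnc> hello world".split()   # prompt tokens left in the reference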
@@ -823,7 +819,9 @@ def _transcribe_input_manifest_processing(

return super()._transcribe_input_manifest_processing(audio_files, temp_dir, trcfg)

def _transcribe_forward(self, batch: Any, trcfg: MultiTaskTranscriptionConfig):
def _transcribe_forward(
self, batch: PromptedAudioToTextMiniBatch | tuple[torch.Tensor, ...], trcfg: MultiTaskTranscriptionConfig
) -> dict:
"""
Internal function to perform the model's custom forward pass to return outputs that are processed by
`_transcribe_output_processing()`.
@@ -836,13 +834,25 @@ def _transcribe_forward(self, batch: Any, trcfg: MultiTaskTranscriptionConfig):
Returns:
The model's outputs that are processed by `_transcribe_output_processing()`.
"""
log_probs, encoded_len, enc_states, enc_mask = self.forward(
input_signal=batch[0], input_signal_length=batch[1]
)
if len(batch) == 6:
# Prompt provided by the dataloader.
decoder_input_ids = batch[4]
if isinstance(batch, PromptedAudioToTextMiniBatch):
# Handling regular Canary DataLoader
audio = batch.audio
audio_lens = batch.audio_lens
decoder_input_ids = batch.prompted_transcript
else:
# Handling TensorDataset / external DataLoader
audio, audio_lens = batch[0], batch[1]
if len(batch) == 6:
# Prompt provided by the user.
decoder_input_ids = batch[4]
else:
# Prompt to be built dynamically.
decoder_input_ids = None
batch_size = audio.shape[0]

log_probs, encoded_len, enc_states, enc_mask = self.forward(input_signal=audio, input_signal_length=audio_lens)

if decoder_input_ids is None:
# The dataloader provided only audio + audio_lens, so we
# are constructing the prompt dynamically using TranscribeConfig.

Expand Down Expand Up @@ -877,17 +887,17 @@ def _transcribe_forward(self, batch: Any, trcfg: MultiTaskTranscriptionConfig):
decoder_input_ids = (
self.prompt.encode_dialog(turns=turns)["context_ids"]
.unsqueeze(0)
.repeat(batch[0].shape[0], 1)
.repeat(batch_size, 1)
.to(trcfg._internal.device)
)
output = dict(

return dict(
log_probs=log_probs,
encoded_lengths=encoded_len,
encoder_states=enc_states,
encoder_mask=enc_mask,
decoder_input_ids=decoder_input_ids,
)
return output
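A hedged sketch of the two batch layouts this branch accepts (tensor contents are dummies; in the 6-element tuple, index 4 carries the prompt IDs as in the code above, and indices 2 and 3 are left as None purely for illustration):

import torch

audio = torch.randn(2, 16000)
audio_lens = torch.tensor([16000, 12000])
prompt_ids = torch.randint(4, 100, (2, 5))
prompt_lens = torch.tensor([5, 5])

batch_without_prompt = (audio, audio_lens)  # prompt is then built from TranscribeConfig
batch_with_prompt = (audio, audio_lens, None, None, prompt_ids, prompt_lens)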

def _transcribe_output_processing(self, outputs, trcfg: MultiTaskTranscriptionConfig) -> GenericTranscriptionType:
"""
@@ -954,7 +964,7 @@ def _setup_transcribe_dataloader(self, config: Dict) -> 'torch.utils.data.DataLo
'channel_selector': config.get('channel_selector', None),
}

temporary_datalayer = self._setup_dataloader_from_config(config=DictConfig(dl_config), inference=True)
temporary_datalayer = self._setup_dataloader_from_config(config=DictConfig(dl_config))
return temporary_datalayer

def _transcribe_on_end(self, trcfg: MultiTaskTranscriptionConfig):
@@ -1017,34 +1027,36 @@ def get_transcribe_config(cls) -> MultiTaskTranscriptionConfig:
"""
return MultiTaskTranscriptionConfig()

def predict_step(self, batch, batch_idx=0, dataloader_idx=0, has_processed_signal=False):
signal, signal_len, _, _, prompt, prompt_len = batch

processed_signal = None
processed_signal_length = None
def predict_step(
self, batch: PromptedAudioToTextMiniBatch, batch_idx=0, dataloader_idx=0, has_processed_signal=False
):
if has_processed_signal:
processed_signal = signal
processed_signal_length = signal_len
processed_signal = batch.audio
processed_signal_length = batch.audio_lens
signal = None
signal_len = None
else:
processed_signal = None
processed_signal_length = None
signal = batch.audio
signal_len = batch.audio_lens

transf_log_probs, encoded_len, enc_states, enc_mask = self.forward(
input_signal=signal,
input_signal_length=signal_len,
processed_signal=processed_signal,
processed_signal_length=processed_signal_length,
transcript=prompt,
transcript_length=prompt_len,
transcript=batch.prompt,
transcript_length=batch.prompt_lens,
)

text = self.decoding.decode_predictions_tensor(
encoder_hidden_states=enc_states,
encoder_input_mask=enc_mask,
decoder_input_ids=prompt,
decoder_input_ids=batch.prompt,
return_hypotheses=False,
)[0]

text = [self.decoding.strip_special_tokens(t) for t in text]
return text
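As a rough illustration of this last cleanup step (the real special-token surface forms come from CanaryTokenizer and may differ), strip_special_tokens is expected to reduce a raw decode to plain text:

raw_decode = "<|startoftranscript|><|en|><|transcribe|><|en|><|pnc|> hello world<|endoftext|>"
clean_text = "hello world"   # after self.decoding.strip_special_tokens(raw_decode)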

@property