Fix Canary not stripping prompt from reference + more test coverage #9987

Merged
merged 6 commits into from Aug 1, 2024
Changes from 3 commits
68 changes: 44 additions & 24 deletions nemo/collections/asr/data/audio_to_text_lhotse_prompted.py
@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import Callable, Sequence

import torch.utils.data
@@ -25,6 +26,18 @@
from nemo.collections.common.tokenizers.canary_tokenizer import CANARY_SPECIAL_TOKENIZER


@dataclass
class PromptedAudioToTextMiniBatch:
audio: torch.Tensor
audio_lens: torch.Tensor
transcript: torch.Tensor
transcript_lens: torch.Tensor
prompt: torch.Tensor
prompt_lens: torch.Tensor
prompted_transcript: torch.Tensor
prompted_transcript_lens: torch.Tensor


class PromptedAudioToTextLhotseDataset(torch.utils.data.Dataset):
"""
This dataset is based on :class:`~nemo.collections.asr.data.audio_to_text_lhotse.LhotseSpeechToTextBpeDataset`.
@@ -45,41 +58,47 @@ class PromptedAudioToTextLhotseDataset(torch.utils.data.Dataset):
def __init__(
self,
tokenizer: TokenizerSpec,
prompt_format_fn: Callable[[CutSet, TokenizerWrapper, bool], Sequence[Sequence[int]]],
inference: bool = False,
prompt_format_fn: Callable[[CutSet, TokenizerWrapper], Sequence[Sequence[int]]],
):
super().__init__()
self.tokenizer = TokenizerWrapper(tokenizer)
self.load_audio = AudioSamples(fault_tolerant=True)
self.padding_value = self.tokenizer._tokenizer.pad_id
self.prompt_format_fn = prompt_format_fn
self.inference = inference

def __getitem__(self, cuts: CutSet) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
def __getitem__(self, cuts: CutSet) -> PromptedAudioToTextMiniBatch:
audio, audio_lens, cuts = self.load_audio(cuts)

prompts_with_answers, prompts = self.prompt_format_fn(cuts, self.tokenizer, inference=self.inference)

prompts_with_answers = [torch.as_tensor(t) for t in prompts_with_answers]
prompts_with_answers_lens = torch.tensor([t.size(0) for t in prompts_with_answers], dtype=torch.long)
prompts_with_answers = collate_vectors(prompts_with_answers, padding_value=self.padding_value)

if self.inference:
prompts = [torch.as_tensor(t) for t in prompts]
prompts_lens = torch.tensor([t.size(0) for t in prompts], dtype=torch.long)
prompts = collate_vectors(prompts, padding_value=self.padding_value)
else:
prompts = None
prompts_lens = None
prompts_with_answers, prompts, answers = self.prompt_format_fn(cuts, self.tokenizer)

transcript, transcript_lens = self._collate_tokens(answers)
prompts_with_answers, prompts_with_answers_lens = self._collate_tokens(prompts_with_answers)
prompts, prompt_lens = self._collate_tokens(prompts)

# return audio, audio_lens, prompts_with_answers, prompts_with_answers_lens, prompts, prompts_lens
return PromptedAudioToTextMiniBatch(
audio=audio,
audio_lens=audio_lens,
transcript=transcript,
transcript_lens=transcript_lens,
prompt=prompts,
prompt_lens=prompt_lens,
prompted_transcript=prompts_with_answers,
prompted_transcript_lens=prompts_with_answers_lens,
)

return audio, audio_lens, prompts_with_answers, prompts_with_answers_lens, prompts, prompts_lens
def _collate_tokens(self, tokens: list[list[int]]) -> tuple[torch.Tensor, torch.Tensor]:
tokens = [torch.as_tensor(t) for t in tokens]
token_lens = torch.tensor([t.size(0) for t in tokens], dtype=torch.long)
tokens = collate_vectors(tokens, padding_value=self.padding_value)
return tokens, token_lens


# Mapping from a string name to a known prompt formatter function.
PROMPT_FORMAT_FNS = {}


def registered_prompt_format_fn(prompt_fn: Callable[[CutSet, TokenizerWrapper, bool], Sequence[Sequence[int]]]):
def registered_prompt_format_fn(prompt_fn: Callable[[CutSet, TokenizerWrapper], Sequence[Sequence[int]]]):
"""
Decorator for registering prompt functions under a name.

@@ -97,7 +116,7 @@ def registered_prompt_format_fn(prompt_fn: Callable[[CutSet, TokenizerWrapper, b
return prompt_fn


def get_prompt_format_fn(name: str) -> Callable[[CutSet, TokenizerWrapper, bool], Sequence[Sequence[int]]]:
def get_prompt_format_fn(name: str) -> Callable[[CutSet, TokenizerWrapper], Sequence[Sequence[int]]]:
if name not in PROMPT_FORMAT_FNS:
raise ValueError(
f"Unknown prompt format function name: {name} " f"(must be one of: {list(PROMPT_FORMAT_FNS.keys())}"
@@ -107,8 +126,8 @@ def get_prompt_format_fn(name: str) -> Callable[[CutSet, TokenizerWrapper, bool]

@registered_prompt_format_fn
def canary(
cuts: CutSet, tokenizer: TokenizerWrapper, inference: bool = False
) -> tuple[list[torch.Tensor], list[torch.Tensor]]:
cuts: CutSet, tokenizer: TokenizerWrapper
) -> tuple[list[torch.Tensor], list[torch.Tensor], list[torch.Tensor]]:
"""
Prepend and append control tokens to the token sequence as per Canary format.

@@ -137,7 +156,7 @@ def canary(
), "To use 'canary' prompt format, you must use the CanaryTokenizer."
formatter = CanaryPromptFormatter(tokenizer._tokenizer)

prompts_with_answers, prompts = [], []
prompts_with_answers, prompts, answers = [], [], []
for cut in cuts:
if isinstance(cut, MixedCut):
cut = cut._first_non_padding_cut
@@ -180,8 +199,9 @@ def canary(
)
prompts_with_answers.append(encoded["input_ids"])
prompts.append(encoded["context_ids"])
answers.append(encoded["answer_ids"][:-1]) # Strip Canary's EOS

return prompts_with_answers, prompts
return prompts_with_answers, prompts, answers


class ProbablyIncorrectLanguageKeyError(RuntimeError):
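For orientation, a minimal sketch of how the three token sequences returned by canary line up once collated into a PromptedAudioToTextMiniBatch. This is not part of the diff; the token ids, pad id (0) and EOS id (3) are made up, and the local collate helper simply mirrors _collate_tokens:

import torch
from torch.nn.utils.rnn import pad_sequence

prompt = [5, 6, 7]                 # encoded["context_ids"]: Canary task/language control tokens
answer = [11, 12, 13]              # encoded["answer_ids"] with the trailing EOS stripped
prompted = prompt + answer + [3]   # encoded["input_ids"]: prompt + answer + EOS

def collate(seqs, pad_id=0):
    # Pad to the longest sequence in the batch and record the original lengths.
    tensors = [torch.as_tensor(s) for s in seqs]
    lens = torch.tensor([t.size(0) for t in tensors], dtype=torch.long)
    return pad_sequence(tensors, batch_first=True, padding_value=pad_id), lens

transcript, transcript_lens = collate([answer])                       # prompt-free reference for WER/BLEU
prompts, prompt_lens = collate([prompt])                              # decoder context
prompted_transcript, prompted_transcript_lens = collate([prompted])   # teacher-forcing target

Keeping the prompt-free transcript as its own field is what allows the metrics to compare hypotheses against the reference text alone.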
103 changes: 58 additions & 45 deletions nemo/collections/asr/models/aed_multitask_models.py
@@ -27,6 +27,7 @@

from nemo.collections.asr.data.audio_to_text_lhotse_prompted import (
PromptedAudioToTextLhotseDataset,
PromptedAudioToTextMiniBatch,
get_prompt_format_fn,
)
from nemo.collections.asr.metrics import BLEU, WER
@@ -498,7 +499,7 @@ def transcribe(

return super().transcribe(audio=audio, override_config=trcfg)

def _setup_dataloader_from_config(self, config: Optional[Dict], inference: bool = False):
def _setup_dataloader_from_config(self, config: Optional[Dict]):
assert config.get("use_lhotse", False), (
"Multi-task model only supports dataloading with Lhotse. "
"Please set config.{train,validation,test}_ds.use_lhotse=True"
@@ -510,7 +511,6 @@ def _setup_dataloader_from_config(self, config: Optional[Dict], inference: bool
dataset=PromptedAudioToTextLhotseDataset(
tokenizer=self.tokenizer,
prompt_format_fn=get_prompt_format_fn(self.prompt_format),
inference=inference,
),
)

@@ -554,7 +554,7 @@ def setup_validation_data(self, val_data_config: Optional[Union[DictConfig, Dict

# preserve config
self._update_dataset_config(dataset_name='validation', config=val_data_config)
self._validation_dl = self._setup_dataloader_from_config(config=val_data_config, inference=True)
self._validation_dl = self._setup_dataloader_from_config(config=val_data_config)

def setup_test_data(self, test_data_config: Optional[Union[DictConfig, Dict]]):
"""
@@ -570,7 +570,7 @@ def setup_test_data(self, test_data_config: Optional[Union[DictConfig, Dict]]):

# preserve config
self._update_dataset_config(dataset_name='test', config=test_data_config)
self._test_dl = self._setup_dataloader_from_config(config=test_data_config, inference=True)
self._test_dl = self._setup_dataloader_from_config(config=test_data_config)

@property
def input_types(self) -> Optional[Dict[str, NeuralType]]:
@@ -664,20 +664,18 @@ def forward(
return transf_log_probs, encoded_len, enc_states, enc_mask

# PTL-specific methods
def training_step(self, batch, batch_nb):
def training_step(self, batch: PromptedAudioToTextMiniBatch, batch_nb):

if batch is None:
return torch.tensor([0.0])

# During training prompt and prompt_len are null, ignore.
signal, signal_len, transcript, transcript_len, prompt, prompt_len = batch
input_ids, labels = transcript[:, :-1], transcript[:, 1:]
input_ids, labels = batch.prompted_transcript[:, :-1], batch.prompted_transcript[:, 1:]

transf_log_probs, encoded_len, enc_states, enc_mask = self.forward(
input_signal=signal,
input_signal_length=signal_len,
input_signal=batch.audio,
input_signal_length=batch.audio_lens,
transcript=input_ids,
transcript_length=transcript_len,
transcript_length=batch.prompted_transcript_lens,
)

audio_loss = self.loss(log_probs=transf_log_probs, labels=labels)
@@ -689,16 +687,15 @@ def training_step(self, batch, batch_nb):

return {'loss': audio_loss, 'log': tensorboard_logs}
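As an aside, the [:, :-1] / [:, 1:] split used above is the standard teacher-forcing shift; a tiny illustration with made-up token ids (not from this PR):

import torch

# One prompted transcript: <prompt> <prompt> a b c <eos>
prompted_transcript = torch.tensor([[5, 6, 11, 12, 13, 3]])
input_ids = prompted_transcript[:, :-1]  # fed to the decoder:      [[5, 6, 11, 12, 13]]
labels = prompted_transcript[:, 1:]      # predicted at each step:  [[6, 11, 12, 13, 3]]

Each decoder position is trained to predict the next token of the prompted sequence.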

def validation_pass(self, batch, batch_idx, dataloader_idx=0, eval_mode="val"):
def validation_pass(self, batch: PromptedAudioToTextMiniBatch, batch_idx, dataloader_idx=0, eval_mode="val"):
# During inference, dataloader passes pure prompt without transcript text.
signal, signal_len, transcript, transcript_len, prompt, prompt_len = batch
input_ids, labels = transcript[:, :-1], transcript[:, 1:]
input_ids, labels = batch.prompted_transcript[:, :-1], batch.prompted_transcript[:, 1:]

transf_log_probs, encoded_len, enc_states, enc_mask = self.forward(
input_signal=signal,
input_signal_length=signal_len,
input_signal=batch.audio,
input_signal_length=batch.audio_lens,
transcript=input_ids,
transcript_length=transcript_len,
transcript_length=batch.prompted_transcript_lens,
)

transf_loss = self.loss(log_probs=transf_log_probs, labels=labels)
@@ -710,10 +707,10 @@ def validation_pass(self, batch, batch_idx, dataloader_idx=0, eval_mode="val"):
self.wer.update(
predictions=enc_states,
predictions_lengths=encoded_len,
targets=transcript,
targets_lengths=transcript_len,
targets=batch.transcript,
targets_lengths=batch.transcript_lens,
predictions_mask=enc_mask,
input_ids=prompt,
input_ids=batch.prompt,
)
wer, wer_num, wer_denom = self.wer.compute()
output_dict.update({"val_wer": wer, "val_wer_num": wer_num, "val_wer_denom": wer_denom})
@@ -722,10 +719,10 @@ def validation_pass(self, batch, batch_idx, dataloader_idx=0, eval_mode="val"):
self.bleu.update(
predictions=enc_states,
predictions_lengths=encoded_len,
targets=transcript,
targets_lengths=transcript_len,
targets=batch.transcript,
targets_lengths=batch.transcript_lens,
predictions_mask=enc_mask,
input_ids=prompt,
input_ids=batch.prompt,
)
bleu_metrics = self.bleu.compute(prefix=f"{eval_mode}_")
output_dict.update(bleu_metrics)
@@ -823,7 +820,9 @@ def _transcribe_input_manifest_processing(

return super()._transcribe_input_manifest_processing(audio_files, temp_dir, trcfg)

def _transcribe_forward(self, batch: Any, trcfg: MultiTaskTranscriptionConfig):
def _transcribe_forward(
self, batch: PromptedAudioToTextMiniBatch | tuple[torch.Tensor, ...], trcfg: MultiTaskTranscriptionConfig
) -> dict:
"""
Internal function to perform the model's custom forward pass to return outputs that are processed by
`_transcribe_output_processing()`.
@@ -836,13 +835,25 @@ def _transcribe_forward(self, batch: Any, trcfg: MultiTaskTranscriptionConfig):
Returns:
The model's outputs that are processed by `_transcribe_output_processing()`.
"""
log_probs, encoded_len, enc_states, enc_mask = self.forward(
input_signal=batch[0], input_signal_length=batch[1]
)
if len(batch) == 6:
# Prompt provided by the dataloader.
decoder_input_ids = batch[4]
if isinstance(batch, PromptedAudioToTextMiniBatch):
# Handling regular Canary DataLoader
audio = batch.audio
audio_lens = batch.audio_lens
decoder_input_ids = batch.prompted_transcript
else:
# Handling TensorDataset / external DataLoader
audio, audio_lens = batch[0], batch[1]
if len(batch) == 6:
# Prompt provided by the user.
decoder_input_ids = batch[4]
else:
# Prompt to be built dynamically.
decoder_input_ids = None
batch_size = audio.shape[0]

log_probs, encoded_len, enc_states, enc_mask = self.forward(input_signal=audio, input_signal_length=audio_lens)

if decoder_input_ids is None:
# The dataloader provided only audio + audio_lens, so we
# are constructing the prompt dynamically using TranscribeConfig.

@@ -877,17 +888,17 @@ def _transcribe_forward(self, batch: Any, trcfg: MultiTaskTranscriptionConfig):
decoder_input_ids = (
self.prompt.encode_dialog(turns=turns)["context_ids"]
.unsqueeze(0)
.repeat(batch[0].shape[0], 1)
.repeat(batch_size, 1)
.to(trcfg._internal.device)
)
output = dict(

return dict(
log_probs=log_probs,
encoded_lengths=encoded_len,
encoder_states=enc_states,
encoder_mask=enc_mask,
decoder_input_ids=decoder_input_ids,
)
return output

def _transcribe_output_processing(self, outputs, trcfg: MultiTaskTranscriptionConfig) -> GenericTranscriptionType:
"""
@@ -954,7 +965,7 @@ def _setup_transcribe_dataloader(self, config: Dict) -> 'torch.utils.data.DataLo
'channel_selector': config.get('channel_selector', None),
}

temporary_datalayer = self._setup_dataloader_from_config(config=DictConfig(dl_config), inference=True)
temporary_datalayer = self._setup_dataloader_from_config(config=DictConfig(dl_config))
return temporary_datalayer

def _transcribe_on_end(self, trcfg: MultiTaskTranscriptionConfig):
@@ -1017,34 +1028,36 @@ def get_transcribe_config(cls) -> MultiTaskTranscriptionConfig:
"""
return MultiTaskTranscriptionConfig()

def predict_step(self, batch, batch_idx=0, dataloader_idx=0, has_processed_signal=False):
signal, signal_len, _, _, prompt, prompt_len = batch

processed_signal = None
processed_signal_length = None
def predict_step(
self, batch: PromptedAudioToTextMiniBatch, batch_idx=0, dataloader_idx=0, has_processed_signal=False
):
if has_processed_signal:
processed_signal = signal
processed_signal_length = signal_len
processed_signal = batch.audio
processed_signal_length = batch.audio_lens
signal = None
signal_len = None
else:
processed_signal = None
processed_signal_length = None
signal = batch.audio
signal_len = batch.audio_lens

transf_log_probs, encoded_len, enc_states, enc_mask = self.forward(
input_signal=signal,
input_signal_length=signal_len,
processed_signal=processed_signal,
processed_signal_length=processed_signal_length,
transcript=prompt,
transcript_length=prompt_len,
transcript=batch.prompt,
transcript_length=batch.prompt_lens,
)

text = self.decoding.decode_predictions_tensor(
encoder_hidden_states=enc_states,
encoder_input_mask=enc_mask,
decoder_input_ids=prompt,
decoder_input_ids=batch.prompt,
return_hypotheses=False,
)[0]

text = [self.decoding.strip_special_tokens(t) for t in text]
return text

@property
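To summarize the behavioural fix in a self-contained sketch (hypothetical token ids, not code from this PR): previously the reference handed to WER/BLEU was the full prompted sequence, whereas the metrics now receive the prompt-free transcript while the decoder still consumes the prompted one:

import torch

# Hypothetical ids: prompt control tokens [5, 6], transcript tokens [11, 12, 13], EOS 3.
prompt = torch.tensor([5, 6])
transcript = torch.tensor([11, 12, 13])
prompted_transcript = torch.cat([prompt, transcript, torch.tensor([3])])

# Old behaviour: the reference passed as targets included the prompt (and EOS),
# inflating the error counts.
old_reference = prompted_transcript
# New behaviour: targets=batch.transcript is the bare text, and the prompt is passed
# separately as input_ids=batch.prompt.
new_reference = transcript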