From faf3730bbac4ed90f9f44455669a50bd4083b14d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Piotr=20=C5=BBelasko?= Date: Wed, 31 Jul 2024 12:46:25 -0400 Subject: [PATCH 1/6] Fix not stripping Canary prompt from the reference and add extra test coverage MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Piotr Żelasko --- .../asr/data/audio_to_text_lhotse_prompted.py | 60 +++-- .../asr/models/aed_multitask_models.py | 102 ++++---- .../transformer/transformer_generators.py | 74 ++++-- .../submodules/multitask_beam_decoding.py | 3 + .../parts/submodules/multitask_decoding.py | 53 ++-- .../submodules/multitask_greedy_decoding.py | 242 ++++++++++++++++++ .../asr/parts/submodules/token_classifier.py | 37 ++- .../common/parts/multi_layer_perceptron.py | 6 +- .../asr/decoding/test_multi_task_decoding.py | 182 +++++++++++++ 9 files changed, 648 insertions(+), 111 deletions(-) create mode 100644 nemo/collections/asr/parts/submodules/multitask_greedy_decoding.py create mode 100644 tests/collections/asr/decoding/test_multi_task_decoding.py diff --git a/nemo/collections/asr/data/audio_to_text_lhotse_prompted.py b/nemo/collections/asr/data/audio_to_text_lhotse_prompted.py index 4779e3677b05..0e8fdf22034f 100644 --- a/nemo/collections/asr/data/audio_to_text_lhotse_prompted.py +++ b/nemo/collections/asr/data/audio_to_text_lhotse_prompted.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from dataclasses import dataclass from typing import Callable, Sequence import torch.utils.data @@ -25,6 +26,18 @@ from nemo.collections.common.tokenizers.canary_tokenizer import CANARY_SPECIAL_TOKENIZER +@dataclass +class PromptedAudioToTextMiniBatch: + audio: torch.Tensor + audio_lens: torch.Tensor + transcript: torch.Tensor + transcript_lens: torch.Tensor + prompt: torch.Tensor + prompt_lens: torch.Tensor + prompted_transcript: torch.Tensor + prompted_transcript_lens: torch.Tensor + + class PromptedAudioToTextLhotseDataset(torch.utils.data.Dataset): """ This dataset is based on :class:`~nemo.collections.asr.data.audio_to_text_lhotse.LhotseSpeechToTextBpeDataset`. 
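For illustration, a minimal sketch of the contract this new mini-batch type carries, assuming the patch series is applied. All token IDs and tensor shapes below are invented, and it is patch 2 later in this series that makes ``transcript`` carry the answer without Canary's EOS:

import torch
from nemo.collections.asr.data.audio_to_text_lhotse_prompted import PromptedAudioToTextMiniBatch

prompt = torch.tensor([[4, 8, 9]])     # (B=1, prompt_len), made-up special-token ids
answer = torch.tensor([[11, 12, 13]])  # (B=1, answer_len), reference transcript tokens (no EOS)
eos = torch.tensor([[2]])

batch = PromptedAudioToTextMiniBatch(
    audio=torch.zeros(1, 16000),
    audio_lens=torch.tensor([16000]),
    transcript=answer,
    transcript_lens=torch.tensor([answer.shape[1]]),
    prompt=prompt,
    prompt_lens=torch.tensor([prompt.shape[1]]),
    prompted_transcript=torch.cat([prompt, answer, eos], dim=1),
    prompted_transcript_lens=torch.tensor([prompt.shape[1] + answer.shape[1] + 1]),
)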
@@ -46,33 +59,40 @@ def __init__( self, tokenizer: TokenizerSpec, prompt_format_fn: Callable[[CutSet, TokenizerWrapper, bool], Sequence[Sequence[int]]], - inference: bool = False, ): super().__init__() self.tokenizer = TokenizerWrapper(tokenizer) self.load_audio = AudioSamples(fault_tolerant=True) self.padding_value = self.tokenizer._tokenizer.pad_id self.prompt_format_fn = prompt_format_fn - self.inference = inference - def __getitem__(self, cuts: CutSet) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + def __getitem__(self, cuts: CutSet) -> PromptedAudioToTextMiniBatch: audio, audio_lens, cuts = self.load_audio(cuts) - prompts_with_answers, prompts = self.prompt_format_fn(cuts, self.tokenizer, inference=self.inference) - - prompts_with_answers = [torch.as_tensor(t) for t in prompts_with_answers] - prompts_with_answers_lens = torch.tensor([t.size(0) for t in prompts_with_answers], dtype=torch.long) - prompts_with_answers = collate_vectors(prompts_with_answers, padding_value=self.padding_value) - - if self.inference: - prompts = [torch.as_tensor(t) for t in prompts] - prompts_lens = torch.tensor([t.size(0) for t in prompts], dtype=torch.long) - prompts = collate_vectors(prompts, padding_value=self.padding_value) - else: - prompts = None - prompts_lens = None + prompts_with_answers, prompts = self.prompt_format_fn(cuts, self.tokenizer) + + transcript = [pa[len(p) :] for pa, p in zip(prompts_with_answers, prompts)] + transcript, transcript_lens = self._collate_tokens(transcript) + prompts_with_answers, prompts_with_answers_lens = self._collate_tokens(prompts_with_answers) + prompts, prompt_lens = self._collate_tokens(prompts) + + # return audio, audio_lens, prompts_with_answers, prompts_with_answers_lens, prompts, prompts_lens + return PromptedAudioToTextMiniBatch( + audio=audio, + audio_lens=audio_lens, + transcript=transcript, + transcript_lens=transcript_lens, + prompt=prompts, + prompt_lens=prompt_lens, + prompted_transcript=prompts_with_answers, + prompted_transcript_lens=prompts_with_answers_lens, + ) - return audio, audio_lens, prompts_with_answers, prompts_with_answers_lens, prompts, prompts_lens + def _collate_tokens(self, tokens: list[list[int]]) -> tuple[torch.Tensor, torch.Tensor]: + tokens = [torch.as_tensor(t) for t in tokens] + token_lens = torch.tensor([t.size(0) for t in tokens], dtype=torch.long) + tokens = collate_vectors(tokens, padding_value=self.padding_value) + return tokens, token_lens # Mapping from a string name to a known prompt formatter function. @@ -97,7 +117,7 @@ def registered_prompt_format_fn(prompt_fn: Callable[[CutSet, TokenizerWrapper, b return prompt_fn -def get_prompt_format_fn(name: str) -> Callable[[CutSet, TokenizerWrapper, bool], Sequence[Sequence[int]]]: +def get_prompt_format_fn(name: str) -> Callable[[CutSet, TokenizerWrapper], Sequence[Sequence[int]]]: if name not in PROMPT_FORMAT_FNS: raise ValueError( f"Unknown prompt format function name: {name} " f"(must be one of: {list(PROMPT_FORMAT_FNS.keys())}" @@ -106,9 +126,7 @@ def get_prompt_format_fn(name: str) -> Callable[[CutSet, TokenizerWrapper, bool] @registered_prompt_format_fn -def canary( - cuts: CutSet, tokenizer: TokenizerWrapper, inference: bool = False -) -> tuple[list[torch.Tensor], list[torch.Tensor]]: +def canary(cuts: CutSet, tokenizer: TokenizerWrapper) -> tuple[list[torch.Tensor], list[torch.Tensor]]: """ Prepend and append control tokens to the token sequence as per Canary format. 
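To make the fix concrete, a toy sketch (invented token IDs) of what the prompt-format function returns and how the reference transcript is now recovered by slicing the prompt prefix off the prompted transcript, which the previous code did not do:

prompts_with_answers = [[4, 8, 9, 11, 12, 13, 2]]  # prompt tokens + answer tokens + EOS
prompts = [[4, 8, 9]]                              # prompt tokens only

transcript = [pa[len(p):] for pa, p in zip(prompts_with_answers, prompts)]
assert transcript == [[11, 12, 13, 2]]  # EOS is still attached here; patch 2 below strips it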
diff --git a/nemo/collections/asr/models/aed_multitask_models.py b/nemo/collections/asr/models/aed_multitask_models.py index 5ad91e75a867..a5ce0d64e467 100644 --- a/nemo/collections/asr/models/aed_multitask_models.py +++ b/nemo/collections/asr/models/aed_multitask_models.py @@ -27,6 +27,7 @@ from nemo.collections.asr.data.audio_to_text_lhotse_prompted import ( PromptedAudioToTextLhotseDataset, + PromptedAudioToTextMiniBatch, get_prompt_format_fn, ) from nemo.collections.asr.metrics import BLEU, WER @@ -498,7 +499,7 @@ def transcribe( return super().transcribe(audio=audio, override_config=trcfg) - def _setup_dataloader_from_config(self, config: Optional[Dict], inference: bool = False): + def _setup_dataloader_from_config(self, config: Optional[Dict]): assert config.get("use_lhotse", False), ( "Multi-task model only supports dataloading with Lhotse. " "Please set config.{train,validation,test}_ds.use_lhotse=True" @@ -510,7 +511,6 @@ def _setup_dataloader_from_config(self, config: Optional[Dict], inference: bool dataset=PromptedAudioToTextLhotseDataset( tokenizer=self.tokenizer, prompt_format_fn=get_prompt_format_fn(self.prompt_format), - inference=inference, ), ) @@ -554,7 +554,7 @@ def setup_validation_data(self, val_data_config: Optional[Union[DictConfig, Dict # preserve config self._update_dataset_config(dataset_name='validation', config=val_data_config) - self._validation_dl = self._setup_dataloader_from_config(config=val_data_config, inference=True) + self._validation_dl = self._setup_dataloader_from_config(config=val_data_config) def setup_test_data(self, test_data_config: Optional[Union[DictConfig, Dict]]): """ @@ -570,7 +570,7 @@ def setup_test_data(self, test_data_config: Optional[Union[DictConfig, Dict]]): # preserve config self._update_dataset_config(dataset_name='test', config=test_data_config) - self._test_dl = self._setup_dataloader_from_config(config=test_data_config, inference=True) + self._test_dl = self._setup_dataloader_from_config(config=test_data_config) @property def input_types(self) -> Optional[Dict[str, NeuralType]]: @@ -664,20 +664,18 @@ def forward( return transf_log_probs, encoded_len, enc_states, enc_mask # PTL-specific methods - def training_step(self, batch, batch_nb): + def training_step(self, batch: PromptedAudioToTextMiniBatch, batch_nb): if batch is None: return torch.tensor([0.0]) - # During training prompt and prompt_len are null, ignore. - signal, signal_len, transcript, transcript_len, prompt, prompt_len = batch - input_ids, labels = transcript[:, :-1], transcript[:, 1:] + input_ids, labels = batch.prompted_transcript[:, :-1], batch.prompted_transcript[:, 1:] transf_log_probs, encoded_len, enc_states, enc_mask = self.forward( - input_signal=signal, - input_signal_length=signal_len, + input_signal=batch.audio, + input_signal_length=batch.audio_lens, transcript=input_ids, - transcript_length=transcript_len, + transcript_length=batch.prompted_transcript_lens, ) audio_loss = self.loss(log_probs=transf_log_probs, labels=labels) @@ -689,16 +687,16 @@ def training_step(self, batch, batch_nb): return {'loss': audio_loss, 'log': tensorboard_logs} - def validation_pass(self, batch, batch_idx, dataloader_idx=0, eval_mode="val"): + def validation_pass(self, batch: PromptedAudioToTextMiniBatch, batch_idx, dataloader_idx=0, eval_mode="val"): # During inference, dataloader passes pure prompt without transcript text. 
signal, signal_len, transcript, transcript_len, prompt, prompt_len = batch - input_ids, labels = transcript[:, :-1], transcript[:, 1:] + input_ids, labels = batch.prompted_transcript[:, :-1], batch.prompted_transcript[:, 1:] transf_log_probs, encoded_len, enc_states, enc_mask = self.forward( - input_signal=signal, - input_signal_length=signal_len, + input_signal=batch.audio, + input_signal_length=batch.audio_lens, transcript=input_ids, - transcript_length=transcript_len, + transcript_length=batch.prompted_transcript_lens, ) transf_loss = self.loss(log_probs=transf_log_probs, labels=labels) @@ -710,10 +708,10 @@ def validation_pass(self, batch, batch_idx, dataloader_idx=0, eval_mode="val"): self.wer.update( predictions=enc_states, predictions_lengths=encoded_len, - targets=transcript, - targets_lengths=transcript_len, + targets=batch.transcript, + targets_lengths=batch.transcript_lens, predictions_mask=enc_mask, - input_ids=prompt, + input_ids=batch.prompt, ) wer, wer_num, wer_denom = self.wer.compute() output_dict.update({"val_wer": wer, "val_wer_num": wer_num, "val_wer_denom": wer_denom}) @@ -722,10 +720,10 @@ def validation_pass(self, batch, batch_idx, dataloader_idx=0, eval_mode="val"): self.bleu.update( predictions=enc_states, predictions_lengths=encoded_len, - targets=transcript, - targets_lengths=transcript_len, + targets=batch.transcript, + targets_lengths=batch.transcript_lens, predictions_mask=enc_mask, - input_ids=prompt, + input_ids=batch.prompt, ) bleu_metrics = self.bleu.compute(prefix=f"{eval_mode}_") output_dict.update(bleu_metrics) @@ -823,7 +821,9 @@ def _transcribe_input_manifest_processing( return super()._transcribe_input_manifest_processing(audio_files, temp_dir, trcfg) - def _transcribe_forward(self, batch: Any, trcfg: MultiTaskTranscriptionConfig): + def _transcribe_forward( + self, batch: PromptedAudioToTextMiniBatch | tuple[torch.Tensor, ...], trcfg: MultiTaskTranscriptionConfig + ) -> dict: """ Internal function to perform the model's custom forward pass to return outputs that are processed by `_transcribe_output_processing()`. @@ -836,13 +836,25 @@ def _transcribe_forward(self, batch: Any, trcfg: MultiTaskTranscriptionConfig): Returns: The model's outputs that are processed by `_transcribe_output_processing()`. """ - log_probs, encoded_len, enc_states, enc_mask = self.forward( - input_signal=batch[0], input_signal_length=batch[1] - ) - if len(batch) == 6: - # Prompt provided by the dataloader. - decoder_input_ids = batch[4] + if isinstance(batch, PromptedAudioToTextMiniBatch): + # Handling regular Canary DataLoader + audio = batch.audio + audio_lens = batch.audio_lens + decoder_input_ids = batch.prompted_transcript else: + # Handling TensorDataset / external DataLoader + audio, audio_lens = batch[0], batch[1] + if len(batch) == 6: + # Prompt provided by the user. + decoder_input_ids = batch[4] + else: + # Prompt to be built dynamically. + decoder_input_ids = None + batch_size = audio.shape[0] + + log_probs, encoded_len, enc_states, enc_mask = self.forward(input_signal=audio, input_signal_length=audio_lens) + + if decoder_input_ids is None: # The dataloader provided only audio + audio_lens, so we # are constructing the prompt dynamically using TranscribeConfig. 
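The batch layouts accepted by the reworked ``_transcribe_forward`` can be summarized with a toy example (tensor values are placeholders):

import torch

audio, audio_lens = torch.zeros(1, 16000), torch.tensor([16000])
prompt = torch.tensor([[4, 8, 9]])

plain_batch = (audio, audio_lens)                               # prompt is built from TranscribeConfig
prompted_batch = (audio, audio_lens, None, None, prompt, None)  # 6-tuple with a user-supplied prompt at index 4
# A PromptedAudioToTextMiniBatch from the Lhotse dataloader is also accepted; in that case
# the decoder input ids are taken from batch.prompted_transcript.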
@@ -877,17 +889,17 @@ def _transcribe_forward(self, batch: Any, trcfg: MultiTaskTranscriptionConfig): decoder_input_ids = ( self.prompt.encode_dialog(turns=turns)["context_ids"] .unsqueeze(0) - .repeat(batch[0].shape[0], 1) + .repeat(batch_size, 1) .to(trcfg._internal.device) ) - output = dict( + + return dict( log_probs=log_probs, encoded_lengths=encoded_len, encoder_states=enc_states, encoder_mask=enc_mask, decoder_input_ids=decoder_input_ids, ) - return output def _transcribe_output_processing(self, outputs, trcfg: MultiTaskTranscriptionConfig) -> GenericTranscriptionType: """ @@ -954,7 +966,7 @@ def _setup_transcribe_dataloader(self, config: Dict) -> 'torch.utils.data.DataLo 'channel_selector': config.get('channel_selector', None), } - temporary_datalayer = self._setup_dataloader_from_config(config=DictConfig(dl_config), inference=True) + temporary_datalayer = self._setup_dataloader_from_config(config=DictConfig(dl_config)) return temporary_datalayer def _transcribe_on_end(self, trcfg: MultiTaskTranscriptionConfig): @@ -1017,34 +1029,36 @@ def get_transcribe_config(cls) -> MultiTaskTranscriptionConfig: """ return MultiTaskTranscriptionConfig() - def predict_step(self, batch, batch_idx=0, dataloader_idx=0, has_processed_signal=False): - signal, signal_len, _, _, prompt, prompt_len = batch - - processed_signal = None - processed_signal_length = None + def predict_step( + self, batch: PromptedAudioToTextMiniBatch, batch_idx=0, dataloader_idx=0, has_processed_signal=False + ): if has_processed_signal: - processed_signal = signal - processed_signal_length = signal_len + processed_signal = batch.audio + processed_signal_length = batch.audio_lens signal = None signal_len = None + else: + processed_signal = None + processed_signal_length = None + signal = batch.audio + signal_len = batch.audio_lens transf_log_probs, encoded_len, enc_states, enc_mask = self.forward( input_signal=signal, input_signal_length=signal_len, processed_signal=processed_signal, processed_signal_length=processed_signal_length, - transcript=prompt, - transcript_length=prompt_len, + transcript=batch.prompt, + transcript_length=batch.prompt_lens, ) text = self.decoding.decode_predictions_tensor( encoder_hidden_states=enc_states, encoder_input_mask=enc_mask, - decoder_input_ids=prompt, + decoder_input_ids=batch.prompt, return_hypotheses=False, )[0] - text = [self.decoding.strip_special_tokens(t) for t in text] return text @property diff --git a/nemo/collections/asr/modules/transformer/transformer_generators.py b/nemo/collections/asr/modules/transformer/transformer_generators.py index 1a38e7fa4b6c..6d4ef223a75b 100644 --- a/nemo/collections/asr/modules/transformer/transformer_generators.py +++ b/nemo/collections/asr/modules/transformer/transformer_generators.py @@ -15,7 +15,9 @@ from contextlib import contextmanager import torch +from torch.distributions import Categorical +from nemo.collections.asr.parts.submodules.token_classifier import TokenClassifier from nemo.collections.common.parts import NEG_INF, mask_padded_tokens __all__ = [ @@ -30,12 +32,13 @@ class GreedySequenceGenerator: """ Greedy sequence generator based on the decoder followed by log_softmax. + Optionally supports temperature sampling with ``n_samples`` and ``temperature`` options. 
Args: embedding: nn.Module, transforms input_ids into vector embeddings decoder: nn.Module, takes embeddings and produces hidden_states - log_softmax: nn.Module, takes hidden_states and produces log_probs - which correspond to probability distribution of tokens (ids) + classifier: nn.Module, takes hidden_states and produces + logits or log-probability distribution of tokens (ids) pad: index of padding token in the vocabulary bos: index of beginning of sequence token in the vocabulary eos: index of end of sequence token in the vocabulary @@ -45,28 +48,43 @@ class GreedySequenceGenerator: source sequences plus max_delta_length batch_size: size of the batch of generated sequences if neither source nor target starting sequences are provided + n_samples: number of sequences to generate (requires ``temperature`` to be set) + temperature: temperature for temperature sampling. Even with ``n_samples`` set to 1, + enabling temperature will sample hypotheses instead of returning the best ones. """ def __init__( self, embedding, decoder, - log_softmax, + classifier: TokenClassifier, pad=0, bos=1, eos=2, max_sequence_length=512, max_delta_length=20, batch_size=1, + n_samples=1, + temperature=None, ): super().__init__() self.embedding = embedding self.decoder = decoder - self.log_softmax = log_softmax + if hasattr(classifier, "set_log_softmax_enabled"): + classifier = classifier.set_log_softmax_enabled(False) + else: + assert temperature is None, ( + "The module passed as 'classifier' does not support disabling log-softmax, but we require it " + "for temperature sampling since 'temperature' was set. " + "Your model architecture may not support temperature sampling: we suggest disabling temperature." + ) + self.classifier = classifier self.pad, self.bos, self.eos = pad, bos, eos self.max_seq_length = max_sequence_length self.max_delta_len = max_delta_length self.batch_size = batch_size + self.n_samples = n_samples + self.temperature = temperature def _one_step_forward( self, @@ -107,8 +125,8 @@ def _one_step_forward( decoder_mems_list = self.decoder.forward( decoder_hidden_states, decoder_input_mask, decoder_mems_list, return_mems=True ) - log_probs = self.log_softmax.forward(hidden_states=decoder_mems_list[-1][:, -1:]) - return log_probs, decoder_mems_list + logits = self.classifier.forward(hidden_states=decoder_mems_list[-1][:, -1:], temperature=self.temperature) + return logits, decoder_mems_list def _prepare_for_search(self, decoder_input_ids=None, encoder_hidden_states=None): """ @@ -145,30 +163,52 @@ def _forward( self, decoder_input_ids=None, encoder_hidden_states=None, encoder_input_mask=None, return_beam_scores=False ): assert not return_beam_scores + is_sampling = self.temperature is not None and self.n_samples > 1 + tgt, batch_size, max_generation_length = self._prepare_for_search(decoder_input_ids, encoder_hidden_states) + if is_sampling: + tgt = torch.repeat_interleave(tgt, self.n_samples, dim=0) + encoder_hidden_states = torch.repeat_interleave(encoder_hidden_states, self.n_samples, dim=0) + encoder_input_mask = torch.repeat_interleave(encoder_input_mask, self.n_samples, dim=0) + orig_batch_size = batch_size + batch_size = batch_size * self.n_samples # pad profile tracks sequences ending with token to replace # everything after with token decoder_parameter = next(self.decoder.parameters()) - pad_profile = torch.zeros(batch_size, 1).long().to(decoder_parameter.device) + pad_profile = torch.zeros(batch_size).long().to(decoder_parameter.device) decoder_mems_list = None for i in 
range(max_generation_length): - log_probs, decoder_mems_list = self._one_step_forward( - tgt[:, -1:], encoder_hidden_states, encoder_input_mask, decoder_mems_list, i + if i == 0: + input_ids = tgt + else: + input_ids = tgt[:, -1:] + + logits, decoder_mems_list = self._one_step_forward( + input_ids, encoder_hidden_states, encoder_input_mask, decoder_mems_list, i ) - next_tokens = torch.argmax(log_probs[:, -1], dim=-1, keepdim=True) + if self.temperature is None: # Greedy decoding + next_tokens = torch.argmax(logits[:, -1], dim=-1) + else: # Temperature sampling + next_tokens = Categorical(logits=logits[:, -1] / self.temperature).sample() + next_tokens = self.pad * pad_profile + next_tokens * (1 - pad_profile) pad_profile = torch.max(pad_profile, (next_tokens == self.eos).long()) - tgt = torch.cat((tgt, next_tokens), dim=-1) + tgt = torch.cat((tgt, next_tokens.unsqueeze(1)), dim=-1) # abort generation if all sequences end with if pad_profile.sum() == batch_size: break - return tgt + samples = None + if is_sampling: + samples = list(tgt.view(orig_batch_size, self.n_samples, -1)) + tgt = tgt[:: self.n_samples] + + return tgt, samples def __call__( self, decoder_input_ids=None, encoder_hidden_states=None, encoder_input_mask=None, return_beam_scores=False @@ -195,9 +235,9 @@ def freeze(self) -> None: for param in self.decoder.parameters(): param.requires_grad = False self.decoder.eval() - for param in self.log_softmax.parameters(): + for param in self.classifier.parameters(): param.requires_grad = False - self.log_softmax.eval() + self.classifier.eval() def unfreeze(self) -> None: """Unfreeze weights of embedding, decoder, and classification layers.""" @@ -207,14 +247,14 @@ def unfreeze(self) -> None: for param in self.decoder.parameters(): param.requires_grad = True self.decoder.train() - for param in self.log_softmax.parameters(): + for param in self.classifier.parameters(): param.requires_grad = True - self.log_softmax.train() + self.classifier.train() @contextmanager def as_frozen(self): """ - Context manager which temporarily freezes embedding, decoder, and log_softmax modules, + Context manager which temporarily freezes embedding, decoder, and classifier modules, yields control and finally unfreezes the modules. 
""" self.freeze() diff --git a/nemo/collections/asr/parts/submodules/multitask_beam_decoding.py b/nemo/collections/asr/parts/submodules/multitask_beam_decoding.py index ab3938eebe35..de2d63cd99de 100644 --- a/nemo/collections/asr/parts/submodules/multitask_beam_decoding.py +++ b/nemo/collections/asr/parts/submodules/multitask_beam_decoding.py @@ -231,9 +231,12 @@ def format_hypotheses(self, packed_result: List[Hypothesis], decoder_input_ids: hyp.y_sequence = hyp.y_sequence[prefix.shape[0] :] for hyp in packed_result: ids = hyp.y_sequence + ids_len = ids.shape[0] pos = -1 while ids[pos] == self.pad or ids[pos] == self.eos: pos -= 1 + if ids_len + pos == -1: + break # empty sequence if pos < -1: hyp.y_sequence = ids[: pos + 1] diff --git a/nemo/collections/asr/parts/submodules/multitask_decoding.py b/nemo/collections/asr/parts/submodules/multitask_decoding.py index e2ed2ca5c4bf..46a62604323f 100644 --- a/nemo/collections/asr/parts/submodules/multitask_decoding.py +++ b/nemo/collections/asr/parts/submodules/multitask_decoding.py @@ -25,6 +25,10 @@ AEDBeamInferConfig, TransformerAEDBeamInfer, ) +from nemo.collections.asr.parts.submodules.multitask_greedy_decoding import ( + AEDGreedyInferConfig, + TransformerAEDGreedyInfer, +) from nemo.collections.asr.parts.utils.rnnt_utils import Hypothesis, NBestHypotheses from nemo.collections.common.tokenizers.aggregate_tokenizer import AggregateTokenizer from nemo.collections.common.tokenizers.tokenizer_spec import TokenizerSpec @@ -60,11 +64,9 @@ class AbstractMultiTaskDecoding(ABC): The config may further contain the following sub-dictionaries: "greedy": - max_symbols: int, describing the maximum number of target tokens to decode per - timestep during greedy decoding. Setting to larger values allows longer sentences - to be decoded, at the cost of increased execution time. - preserve_frame_confidence: Same as above, overrides above value. - confidence_method_cfg: Same as above, overrides confidence_cfg.method_cfg. + temperature: None (disabled) or float, specifying this enables temperature sampling instead of greedy decoding. + max_generation_delta: int = -1 # -1 means up to the max length of the decoder + preserve_alignments: bool = False (unsupported) "beam": beam_size: int, defining the beam size for beam search. Must be >= 1. 
@@ -103,30 +105,43 @@ def __init__( self.preserve_alignments = self.cfg.get('preserve_alignments', None) self.compute_langs = self.cfg.get('compute_langs', False) self.compute_hypothesis_token_set = self.cfg.get('compute_hypothesis_token_set', False) + self.transformer_decoder = transformer_decoder + self.log_softmax_module = log_softmax_module + self.tokenizer = tokenizer + + self.change_strategy(self.cfg.strategy) + def change_strategy(self, strategy: str) -> "AbstractMultiTaskDecoding": possible_strategies = ['greedy', 'greedy_batch', 'beam'] - if self.cfg.strategy not in possible_strategies: - raise ValueError(f"Decoding strategy must be one of {possible_strategies}") + if strategy not in possible_strategies: + raise ValueError(f"Decoding strategy must be one of {possible_strategies}" f"but was provided {strategy}") # Update preserve alignments if self.preserve_alignments is None: - if self.cfg.strategy in ['greedy', 'greedy_batch']: + if strategy in ['greedy', 'greedy_batch']: self.preserve_alignments = self.cfg.greedy.get('preserve_alignments', False) - elif self.cfg.strategy in ['beam']: + elif strategy in ['beam']: self.preserve_alignments = self.cfg.beam.get('preserve_alignments', False) - if self.cfg.strategy == 'greedy' or self.cfg.strategy == 'greedy_batch': + if strategy in ['greedy', 'greedy_batch']: - # self.decoding = None - raise NotImplementedError("Greedy decoding is not implemented yet.") + self.decoding = TransformerAEDGreedyInfer( + transformer_decoder=self.transformer_decoder, + log_softmax_module=self.log_softmax_module, + tokenizer=self.tokenizer, + max_generation_delta=self.cfg.greedy.get('max_generation_delta', 50), + preserve_alignments=self.preserve_alignments, + temperature=self.cfg.greedy.temperature, + n_samples=self.cfg.greedy.n_samples, + ) - elif self.cfg.strategy == 'beam': + elif strategy == 'beam': self.decoding = TransformerAEDBeamInfer( - transformer_decoder=transformer_decoder, - log_softmax_module=log_softmax_module, - tokenizer=tokenizer, + transformer_decoder=self.transformer_decoder, + log_softmax_module=self.log_softmax_module, + tokenizer=self.tokenizer, search_type=self.cfg.beam.get('search_type', 'default'), beam_size=self.cfg.beam.beam_size, length_penalty=self.cfg.beam.get('length_penalty', 0.0), @@ -139,7 +154,7 @@ def __init__( raise ValueError( f"Incorrect decoding strategy provided. Must be one of {possible_strategies}\n" - f"but was provided {self.cfg.strategy}" + f"but was provided {strategy}" ) def decode_predictions_tensor( @@ -465,9 +480,7 @@ class MultiTaskDecodingConfig: compute_langs: bool = False # greedy decoding config - # greedy: rnnt_greedy_decoding.GreedyBatchedRNNTInferConfig = field( - # default_factory=rnnt_greedy_decoding.GreedyBatchedRNNTInferConfig - # ) + greedy: AEDGreedyInferConfig = field(default_factory=AEDGreedyInferConfig) # beam decoding config beam: AEDBeamInferConfig = field(default_factory=lambda: AEDBeamInferConfig(beam_size=1)) diff --git a/nemo/collections/asr/parts/submodules/multitask_greedy_decoding.py b/nemo/collections/asr/parts/submodules/multitask_greedy_decoding.py new file mode 100644 index 000000000000..891d003bd001 --- /dev/null +++ b/nemo/collections/asr/parts/submodules/multitask_greedy_decoding.py @@ -0,0 +1,242 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from abc import ABC, abstractmethod +from dataclasses import dataclass +from typing import List, Optional + +import torch + +from nemo.collections.asr.modules.transformer import GreedySequenceGenerator +from nemo.collections.asr.parts.utils.rnnt_utils import Hypothesis, NBestHypotheses +from nemo.collections.common.tokenizers.tokenizer_spec import TokenizerSpec +from nemo.core import Typing, typecheck +from nemo.core.neural_types import ChannelType, HypothesisType, LabelsType, MaskType, NeuralType +from nemo.utils import logging + + +def pack_hypotheses( + hypotheses: List[Hypothesis], beam_hypotheses: torch.Tensor, scores: List[Optional[float]] +) -> List[Hypothesis]: + + for idx, hyp in enumerate(hypotheses): # type: Hypothesis + if scores[idx] is not None: + hyp.score = scores[idx] + + hypi = beam_hypotheses[idx] + if torch.is_tensor(hypi): + hyp.y_sequence = hypi.long() + else: + hyp.y_sequence = torch.tensor(hypi, dtype=torch.long) + + if hyp.dec_state is not None: + hyp.dec_state = _states_to_device(hyp.dec_state) + + return hypotheses + + +def _states_to_device(dec_state, device='cpu'): + if torch.is_tensor(dec_state): + dec_state = dec_state.to(device) + + elif isinstance(dec_state, (list, tuple)): + dec_state = tuple(_states_to_device(dec_i, device) for dec_i in dec_state) + + return dec_state + + +class AEDGreedyInfer(ABC): + def __init__( + self, + transformer_decoder: torch.nn.Module, + log_softmax_module: torch.nn.Module, + tokenizer: TokenizerSpec, + search_type: str = 'default', + preserve_alignments: bool = False, + ): + super().__init__() + + self.transformer_decoder = transformer_decoder + self.log_softmax_module = log_softmax_module + self.tokenizer = tokenizer + self.search_type = search_type + + self.preserve_alignments = preserve_alignments + + def __call__(self, *args, **kwargs): + return self.forward(*args, **kwargs) + + @abstractmethod + def forward( + self, + encoder_hidden_states: torch.Tensor, + encoder_input_mask: torch.Tensor, + decoder_input_ids: Optional[torch.Tensor] = None, + partial_hypotheses: Optional[List[Hypothesis]] = None, + ): + raise NotImplementedError() + + def set_decoding_type(self, decoding_type: str): + self.decoding_type = decoding_type + + +class TransformerAEDGreedyInfer(AEDGreedyInfer, Typing): + """ + A greedy decoder engine for AED Transformer models with support for temperature sampling. 
+ """ + + @property + def input_types(self): + """Returns definitions of module input ports.""" + # Input can be of dimention - + # ('B', 'T', 'D') [Log probs] or ('B', 'T') [Labels] + + return { + "encoder_hidden_states": NeuralType(tuple(('B', 'T', 'D')), ChannelType()), + "encoder_input_mask": NeuralType(tuple(('B', 'T')), MaskType()), + "decoder_input_ids": NeuralType(('B', 'T'), LabelsType()), + "partial_hypotheses": NeuralType(optional=True), + } + + @property + def output_types(self): + """Returns definitions of module output ports.""" + return {"predictions": [NeuralType(elements_type=HypothesisType())]} + + def __init__( + self, + transformer_decoder: torch.nn.Module, + log_softmax_module: torch.nn.Module, + tokenizer: TokenizerSpec, + temperature: float | None = None, + max_generation_delta: int = 50, + preserve_alignments: bool = False, + n_samples: int = 1, + ): + super().__init__( + transformer_decoder=transformer_decoder, + log_softmax_module=log_softmax_module, + tokenizer=tokenizer, + preserve_alignments=preserve_alignments, + ) + self.temperature = temperature + self.n_samples = n_samples + self.bos = tokenizer.bos + self.pad = tokenizer.pad + self.eos = tokenizer.eos + self.greedy_search = GreedySequenceGenerator( + embedding=transformer_decoder.embedding, + decoder=transformer_decoder.decoder, + classifier=log_softmax_module, + max_sequence_length=transformer_decoder.max_sequence_length, + bos=self.bos, + pad=self.pad, + eos=self.eos, + max_delta_length=max_generation_delta, + temperature=self.temperature, + n_samples=n_samples, + ) + + self.preserve_alignments = preserve_alignments + if self.preserve_alignments: + logging.info( + "Preservation of alignments was requested but {} does not implement it.".format( + self.__class__.__name__ + ) + ) + + @typecheck() + def forward( + self, + encoder_hidden_states: torch.Tensor, + encoder_input_mask: torch.Tensor, + decoder_input_ids: Optional[torch.Tensor] = None, + partial_hypotheses: Optional[List[Hypothesis]] = None, + ): + """Returns a list of hypotheses given an input batch of the encoder hidden embedding. + Output token is generated auto-repressively. + + Args: + decoder_output: A tensor of size (batch, timesteps, features) or (batch, timesteps) (each timestep is a label). + decoder_lengths: list of int representing the length of each sequence + output sequence. + + Returns: + packed list containing batch number of sentences (Hypotheses). 
+ """ + with torch.inference_mode(): + best_hypo, topk_hypotheses = self.greedy_search( + encoder_hidden_states=encoder_hidden_states, + encoder_input_mask=encoder_input_mask, + decoder_input_ids=decoder_input_ids, + ) + + if topk_hypotheses is not None: + topk_hypotheses = [x.detach().cpu() for x in topk_hypotheses] # each item is [beam, seq_len] + beam_scores = [[None] * self.n_samples for _ in topk_hypotheses] # each item is [beam,] + packed_result = [] + for i in range(len(topk_hypotheses)): + # Pack results into Hypotheses + hypotheses = [Hypothesis(score=0.0, y_sequence=[], timestep=[]) for _ in range(self.n_samples)] + self.format_hypotheses(hypotheses, decoder_input_ids) + packed_result.append( + NBestHypotheses(pack_hypotheses(hypotheses, topk_hypotheses[i], beam_scores[i])) + ) + else: + beam_scores = [None for _ in range(len(best_hypo))] + best_hypo = best_hypo.cpu() + hypotheses = [ + Hypothesis(score=0.0, y_sequence=[], timestep=[]) for _ in range(encoder_hidden_states.shape[0]) + ] + # Pack results into Hypotheses + packed_result = pack_hypotheses(hypotheses, best_hypo, beam_scores) + self.format_hypotheses(packed_result, decoder_input_ids) + + return (packed_result,) + + def format_hypotheses(self, packed_result: List[Hypothesis], decoder_input_ids: torch.Tensor | None) -> None: + """ + For each hypothesis in the mini-batch: + * Remove the decoder input ids (prompt) from the predictions + * Remove BOS, EOS, and PAD ids from the predictions. + Modifies results in-place. + """ + if decoder_input_ids is not None: + assert ( + len(packed_result) == decoder_input_ids.shape[0] + ), f"Mismatching number of examples {len(packed_result)=} {decoder_input_ids.shape[0]=}" + decoder_input_ids = decoder_input_ids.detach().cpu() + for hyp, prefix in zip(packed_result, decoder_input_ids): + assert ( + hyp.y_sequence[: prefix.shape[0]] == prefix + ).all(), f"The decoder input IDs were not found at the beginning of prediction: {hyp.y_sequence=} {prefix=})" + hyp.y_sequence = hyp.y_sequence[prefix.shape[0] :] + for hyp in packed_result: + ids = hyp.y_sequence + ids_len = ids.shape[0] + pos = -1 + while ids[pos] == self.pad or ids[pos] == self.eos: + pos -= 1 + if ids_len + pos == -1: + break # empty sequence + if pos < -1: + hyp.y_sequence = ids[: pos + 1] + + +@dataclass +class AEDGreedyInferConfig: + temperature: float | None = None + max_generation_delta: int = -1 # -1 means up to the max length of the decoder + preserve_alignments: bool = False + n_samples: int = 1 diff --git a/nemo/collections/asr/parts/submodules/token_classifier.py b/nemo/collections/asr/parts/submodules/token_classifier.py index 4061d19d9015..5f50308f0a14 100644 --- a/nemo/collections/asr/parts/submodules/token_classifier.py +++ b/nemo/collections/asr/parts/submodules/token_classifier.py @@ -15,12 +15,13 @@ from dataclasses import dataclass from typing import Dict, Optional +import torch from torch import nn as nn from nemo.collections.asr.parts.submodules.classifier import Classifier from nemo.collections.common.parts import MultiLayerPerceptron from nemo.core.classes import typecheck -from nemo.core.neural_types import LogitsType, LogprobsType, NeuralType +from nemo.core.neural_types import ChannelType, FloatType, LogitsType, LogprobsType, NeuralType __all__ = ['BertPretrainingTokenClassifier', 'TokenClassifier'] @@ -42,7 +43,14 @@ class TokenClassifier(Classifier): """ @property - def output_types(self) -> Optional[Dict[str, NeuralType]]: + def input_types(self) -> Dict[str, NeuralType]: + return { + 
"hidden_states": NeuralType(('B', 'T', 'D'), ChannelType()), + "temperature": NeuralType(None, FloatType(), optional=True), + } + + @property + def output_types(self) -> Dict[str, NeuralType]: """ Returns definitions of module output ports. """ @@ -61,7 +69,6 @@ def __init__( dropout: float = 0.0, use_transformer_init: bool = True, ) -> None: - """ Initializes the Token Classifier module. @@ -81,8 +88,12 @@ def __init__( ) self.post_init(use_transformer_init=use_transformer_init) + def set_log_softmax_enabled(self, value: bool) -> "TokenClassifier": + self.log_softmax = value + return self + @typecheck() - def forward(self, hidden_states): + def forward(self, hidden_states: torch.Tensor, temperature: float | None = None) -> torch.Tensor: """ Performs the forward step of the module. Args: @@ -91,7 +102,7 @@ def forward(self, hidden_states): Returns: logits value for each class [BATCH_SIZE x SEQ_LENGTH x NUM_CLASSES] """ hidden_states = self.dropout(hidden_states) - logits = self.mlp(hidden_states) + logits = self.mlp(hidden_states, temperature=temperature) return logits @@ -100,6 +111,13 @@ class BertPretrainingTokenClassifier(Classifier): A module to perform token level classification tasks for Bert pretraining. """ + @property + def input_types(self) -> Dict[str, NeuralType]: + return { + "hidden_states": NeuralType(('B', 'T', 'D'), ChannelType()), + "temperature": NeuralType(None, FloatType(), optional=True), + } + @property def output_types(self) -> Optional[Dict[str, NeuralType]]: """ @@ -120,7 +138,6 @@ def __init__( dropout: float = 0.0, use_transformer_init: bool = True, ) -> None: - """ Initializes the Token Classifier module. @@ -147,8 +164,12 @@ def __init__( ) self.post_init(use_transformer_init=use_transformer_init) + def set_log_softmax_enabled(self, value: bool) -> "BertPretrainingTokenClassifier": + self.log_softmax = value + return self + @typecheck() - def forward(self, hidden_states): + def forward(self, hidden_states: torch.Tensor, temperature: float | None = None) -> torch.Tensor: """ Performs the forward step of the module. 
Args: @@ -160,5 +181,5 @@ def forward(self, hidden_states): hidden_states = self.dense(hidden_states) hidden_states = self.act(hidden_states) transform = self.norm(hidden_states) - logits = self.mlp(transform) + logits = self.mlp(transform, temperature=temperature) return logits diff --git a/nemo/collections/common/parts/multi_layer_perceptron.py b/nemo/collections/common/parts/multi_layer_perceptron.py index 76c06bf23ea6..5110406fedfd 100644 --- a/nemo/collections/common/parts/multi_layer_perceptron.py +++ b/nemo/collections/common/parts/multi_layer_perceptron.py @@ -51,11 +51,15 @@ def __init__( def last_linear_layer(self): return getattr(self, f'layer{self.layers - 1}') - def forward(self, hidden_states): + def forward(self, hidden_states, temperature: float | None = None): output_states = hidden_states[:] for i in range(self.layers): output_states = getattr(self, f'layer{i}')(output_states) + if temperature is not None: + output_states = output_states / temperature + if self.log_softmax: output_states = torch.log_softmax(output_states, dim=-1) + return output_states diff --git a/tests/collections/asr/decoding/test_multi_task_decoding.py b/tests/collections/asr/decoding/test_multi_task_decoding.py new file mode 100644 index 000000000000..ee76d42ca411 --- /dev/null +++ b/tests/collections/asr/decoding/test_multi_task_decoding.py @@ -0,0 +1,182 @@ +from unittest.mock import Mock + +import pytest +import torch + +from nemo.collections.asr.modules.transformer.transformer import TransformerDecoderNM +from nemo.collections.asr.modules.transformer.transformer_generators import ( + BeamSearchSequenceGenerator, + GreedySequenceGenerator, +) +from nemo.collections.asr.parts.submodules.multitask_beam_decoding import TransformerAEDBeamInfer +from nemo.collections.asr.parts.submodules.multitask_greedy_decoding import TransformerAEDGreedyInfer +from nemo.collections.asr.parts.submodules.token_classifier import TokenClassifier + + +@pytest.fixture() +def deterministic_rng(): + state = torch.get_rng_state() + torch.manual_seed(0) + yield + torch.set_rng_state(state) + + +@pytest.fixture() +def decoder_nm(deterministic_rng): + return TransformerDecoderNM( + vocab_size=8, + hidden_size=2, + num_layers=1, + inner_size=4, + num_attention_heads=1, + max_sequence_length=32, + ).eval() + + +@pytest.fixture() +def nnet(decoder_nm): + ans = ( + decoder_nm.embedding, + decoder_nm.decoder, + TokenClassifier(hidden_size=2, num_classes=8), + ) + ans = tuple(m.eval() for m in ans) + return ans + + +@pytest.fixture() +def inputs(): + B, T, C = 1, 5, 2 + return ( + torch.tensor([[1]], dtype=torch.long), # decoder_input_ids + torch.ones(B, T, C, dtype=torch.float), # encoder_hidden_states + torch.ones(B, T, dtype=torch.float), # encoder_input_mask + ) + + +@pytest.fixture() +def tokenizer(): + tok = Mock() + tok.pad = 0 + tok.bos = 1 + tok.eos = 2 + return tok + + +def test_greedy_decoding(inputs, nnet): + gen = GreedySequenceGenerator(*nnet) + output = gen(*inputs) + + assert len(output) == 2 + best_path, hypotheses = output + + assert best_path is not None + assert torch.is_tensor(best_path) + assert best_path.shape == (1, 25) + + assert hypotheses is None + + +def test_temperature_sampling_decoding(inputs, nnet): + gen = GreedySequenceGenerator(*nnet, temperature=10.0, n_samples=2) + output = gen(*inputs) + + assert len(output) == 2 + best_path, hypotheses = output + + assert best_path is not None + assert torch.is_tensor(best_path) + assert best_path.shape == (1, 25) + + assert isinstance(hypotheses, list) + assert 
len(hypotheses) == 1 + (seq0,) = hypotheses + assert seq0.shape == (2, 25) + + +def test_beam_decoding_beam_scores_false(inputs, nnet): + gen = BeamSearchSequenceGenerator(*nnet, beam_size=2) + output = gen(*inputs, return_beam_scores=False) + + assert len(output) == 1 + (best_path,) = output + + assert best_path is not None + assert torch.is_tensor(best_path) + assert best_path.shape == (26,) + + +def test_beam_decoding_beam_scores_true(inputs, nnet): + gen = BeamSearchSequenceGenerator(*nnet, beam_size=2) + output = gen(*inputs, return_beam_scores=True) + + assert len(output) == 3 + beam_paths, scores, best_path = output + + assert beam_paths is not None + assert isinstance(beam_paths, list) + assert len(beam_paths) == 1 + (beam_paths_seq0,) = beam_paths + assert torch.is_tensor(beam_paths_seq0) + assert beam_paths_seq0.shape == (2, 26) + + assert scores is not None + assert isinstance(scores, list) + assert len(scores) == 1 + (scores_seq0,) = scores + assert torch.is_tensor(scores_seq0) + assert scores_seq0.shape == (2,) + + assert best_path is not None + assert torch.is_tensor(best_path) + assert best_path.shape == (1, 26) + + +def test_transformer_aed_beam_infer_strips_prompt(inputs, decoder_nm, nnet, tokenizer): + decoder_input_ids, encoder_hidden_states, encoder_input_mask = inputs + *_, classifier = nnet + + # Run the actual top-level module used by MultiTask AED model for decoding. + # This module is expected to trim the prompt from the beginning, and eos and pad from the end. + gen = TransformerAEDBeamInfer(decoder_nm, classifier, tokenizer) + ans = gen( + encoder_hidden_states=encoder_hidden_states, + encoder_input_mask=encoder_input_mask, + decoder_input_ids=decoder_input_ids, + ) + best_path = ans[0][0].y_sequence + assert best_path is not None + assert torch.is_tensor(best_path) + + # Now run the underlying beam search generator that doesn't trim anything. + *_, (untrimmed,) = gen.beam_search(*inputs, return_beam_scores=True) + assert untrimmed is not None + assert torch.is_tensor(untrimmed) + + # Check that the expected trimming has indeed been done. + assert (untrimmed[decoder_input_ids.shape[1] :] == best_path).all() # stripped the prompt [1,] from the beggining + + +def test_transformer_aed_greedy_infer_strips_prompt(inputs, decoder_nm, nnet, tokenizer): + decoder_input_ids, encoder_hidden_states, encoder_input_mask = inputs + *_, classifier = nnet + + # Run the actual top-level module used by MultiTask AED model for decoding. + # This module is expected to trim the prompt from the beginning, and eos and pad from the end. + gen = TransformerAEDGreedyInfer(decoder_nm, classifier, tokenizer) + ans = gen( + encoder_hidden_states=encoder_hidden_states, + encoder_input_mask=encoder_input_mask, + decoder_input_ids=decoder_input_ids, + ) + best_path = ans[0][0].y_sequence + assert best_path is not None + assert torch.is_tensor(best_path) + + # Now run the underlying beam search generator that doesn't trim anything. + (untrimmed,), _ = gen.greedy_search(*inputs) + assert untrimmed is not None + assert torch.is_tensor(untrimmed) + + # Check that the expected trimming has indeed been done. 
+ assert (untrimmed[decoder_input_ids.shape[1] :] == best_path).all() # stripped the prompt [1,] from the beggining From 5b0d0b61bb6b0173966af0ae57a6b652a9c5fcde Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Piotr=20=C5=BBelasko?= Date: Wed, 31 Jul 2024 13:11:08 -0400 Subject: [PATCH 2/6] Fix transcripts for Canary containing EOS MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Piotr Żelasko --- .../asr/data/audio_to_text_lhotse_prompted.py | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/nemo/collections/asr/data/audio_to_text_lhotse_prompted.py b/nemo/collections/asr/data/audio_to_text_lhotse_prompted.py index 0e8fdf22034f..71a73d2b0936 100644 --- a/nemo/collections/asr/data/audio_to_text_lhotse_prompted.py +++ b/nemo/collections/asr/data/audio_to_text_lhotse_prompted.py @@ -58,7 +58,7 @@ class PromptedAudioToTextLhotseDataset(torch.utils.data.Dataset): def __init__( self, tokenizer: TokenizerSpec, - prompt_format_fn: Callable[[CutSet, TokenizerWrapper, bool], Sequence[Sequence[int]]], + prompt_format_fn: Callable[[CutSet, TokenizerWrapper], Sequence[Sequence[int]]], ): super().__init__() self.tokenizer = TokenizerWrapper(tokenizer) @@ -69,10 +69,9 @@ def __init__( def __getitem__(self, cuts: CutSet) -> PromptedAudioToTextMiniBatch: audio, audio_lens, cuts = self.load_audio(cuts) - prompts_with_answers, prompts = self.prompt_format_fn(cuts, self.tokenizer) + prompts_with_answers, prompts, answers = self.prompt_format_fn(cuts, self.tokenizer) - transcript = [pa[len(p) :] for pa, p in zip(prompts_with_answers, prompts)] - transcript, transcript_lens = self._collate_tokens(transcript) + transcript, transcript_lens = self._collate_tokens(answers) prompts_with_answers, prompts_with_answers_lens = self._collate_tokens(prompts_with_answers) prompts, prompt_lens = self._collate_tokens(prompts) @@ -99,7 +98,7 @@ def _collate_tokens(self, tokens: list[list[int]]) -> tuple[torch.Tensor, torch. PROMPT_FORMAT_FNS = {} -def registered_prompt_format_fn(prompt_fn: Callable[[CutSet, TokenizerWrapper, bool], Sequence[Sequence[int]]]): +def registered_prompt_format_fn(prompt_fn: Callable[[CutSet, TokenizerWrapper], Sequence[Sequence[int]]]): """ Decorator for registering prompt functions under a name. @@ -126,7 +125,9 @@ def get_prompt_format_fn(name: str) -> Callable[[CutSet, TokenizerWrapper], Sequ @registered_prompt_format_fn -def canary(cuts: CutSet, tokenizer: TokenizerWrapper) -> tuple[list[torch.Tensor], list[torch.Tensor]]: +def canary( + cuts: CutSet, tokenizer: TokenizerWrapper +) -> tuple[list[torch.Tensor], list[torch.Tensor], list[torch.Tensor]]: """ Prepend and append control tokens to the token sequence as per Canary format. @@ -155,7 +156,7 @@ def canary(cuts: CutSet, tokenizer: TokenizerWrapper) -> tuple[list[torch.Tensor ), "To use 'canary' prompt format, you must use the CanaryTokenizer." 
formatter = CanaryPromptFormatter(tokenizer._tokenizer) - prompts_with_answers, prompts = [], [] + prompts_with_answers, prompts, answers = [], [], [] for cut in cuts: if isinstance(cut, MixedCut): cut = cut._first_non_padding_cut @@ -198,8 +199,9 @@ def canary(cuts: CutSet, tokenizer: TokenizerWrapper) -> tuple[list[torch.Tensor ) prompts_with_answers.append(encoded["input_ids"]) prompts.append(encoded["context_ids"]) + answers.append(encoded["answer_ids"][:-1]) # Strip Canary's EOS - return prompts_with_answers, prompts + return prompts_with_answers, prompts, answers class ProbablyIncorrectLanguageKeyError(RuntimeError): From 176603ede77a6ff2ee6fd48967bf2ccb07d9b4ab Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Piotr=20=C5=BBelasko?= Date: Wed, 31 Jul 2024 13:12:27 -0400 Subject: [PATCH 3/6] Fix validation_pass MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Piotr Żelasko --- nemo/collections/asr/models/aed_multitask_models.py | 1 - 1 file changed, 1 deletion(-) diff --git a/nemo/collections/asr/models/aed_multitask_models.py b/nemo/collections/asr/models/aed_multitask_models.py index a5ce0d64e467..9888fb72fb3b 100644 --- a/nemo/collections/asr/models/aed_multitask_models.py +++ b/nemo/collections/asr/models/aed_multitask_models.py @@ -689,7 +689,6 @@ def training_step(self, batch: PromptedAudioToTextMiniBatch, batch_nb): def validation_pass(self, batch: PromptedAudioToTextMiniBatch, batch_idx, dataloader_idx=0, eval_mode="val"): # During inference, dataloader passes pure prompt without transcript text. - signal, signal_len, transcript, transcript_len, prompt, prompt_len = batch input_ids, labels = batch.prompted_transcript[:, :-1], batch.prompted_transcript[:, 1:] transf_log_probs, encoded_len, enc_states, enc_mask = self.forward( From 3008378b6783581e8a498619e5cbc9d026ea0cee Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Piotr=20=C5=BBelasko?= Date: Wed, 31 Jul 2024 15:13:34 -0400 Subject: [PATCH 4/6] Review MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Piotr Żelasko --- .../asr/data/audio_to_text_lhotse_prompted.py | 12 ++++- .../asr/models/aed_multitask_models.py | 5 +-- .../transformer/transformer_generators.py | 27 ++++++----- .../parts/submodules/multitask_decoding.py | 4 +- .../asr/parts/submodules/token_classifier.py | 45 +++++++++++-------- .../common/parts/multi_layer_perceptron.py | 6 +-- .../asr/decoding/test_multi_task_decoding.py | 31 +++++++++---- 7 files changed, 82 insertions(+), 48 deletions(-) diff --git a/nemo/collections/asr/data/audio_to_text_lhotse_prompted.py b/nemo/collections/asr/data/audio_to_text_lhotse_prompted.py index 71a73d2b0936..53ffa1c00e65 100644 --- a/nemo/collections/asr/data/audio_to_text_lhotse_prompted.py +++ b/nemo/collections/asr/data/audio_to_text_lhotse_prompted.py @@ -37,6 +37,14 @@ class PromptedAudioToTextMiniBatch: prompted_transcript: torch.Tensor prompted_transcript_lens: torch.Tensor + def get_decoder_inputs_outputs(self) -> tuple[torch.Tensor, torch.Tensor]: + """ + Returns the inputs and outputs of transformer decoder for training. + The input is ``prompted_transcript`` (minus last token), + and the output is ``prompted_transcript`` (minus first token). 
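        A worked example: for a prompted transcript ``[P1, P2, A1, A2, EOS]``, the decoder input is
        ``[P1, P2, A1, A2]`` and the training labels are ``[P2, A1, A2, EOS]``.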
+ """ + return self.prompted_transcript[:, :-1], self.prompted_transcript[:, 1:] + class PromptedAudioToTextLhotseDataset(torch.utils.data.Dataset): """ @@ -75,7 +83,6 @@ def __getitem__(self, cuts: CutSet) -> PromptedAudioToTextMiniBatch: prompts_with_answers, prompts_with_answers_lens = self._collate_tokens(prompts_with_answers) prompts, prompt_lens = self._collate_tokens(prompts) - # return audio, audio_lens, prompts_with_answers, prompts_with_answers_lens, prompts, prompts_lens return PromptedAudioToTextMiniBatch( audio=audio, audio_lens=audio_lens, @@ -199,6 +206,9 @@ def canary( ) prompts_with_answers.append(encoded["input_ids"]) prompts.append(encoded["context_ids"]) + assert ( + encoded["answer_ids"][-1].item() == formatter.tokenizer.eos + ), f"Expected the last token in answer_ids to be EOS, but we got {encoded['answer_ids']=}" answers.append(encoded["answer_ids"][:-1]) # Strip Canary's EOS return prompts_with_answers, prompts, answers diff --git a/nemo/collections/asr/models/aed_multitask_models.py b/nemo/collections/asr/models/aed_multitask_models.py index 9888fb72fb3b..cbfbb0ba3a83 100644 --- a/nemo/collections/asr/models/aed_multitask_models.py +++ b/nemo/collections/asr/models/aed_multitask_models.py @@ -669,7 +669,7 @@ def training_step(self, batch: PromptedAudioToTextMiniBatch, batch_nb): if batch is None: return torch.tensor([0.0]) - input_ids, labels = batch.prompted_transcript[:, :-1], batch.prompted_transcript[:, 1:] + input_ids, labels = batch.get_decoder_inputs_outputs() transf_log_probs, encoded_len, enc_states, enc_mask = self.forward( input_signal=batch.audio, @@ -688,8 +688,7 @@ def training_step(self, batch: PromptedAudioToTextMiniBatch, batch_nb): return {'loss': audio_loss, 'log': tensorboard_logs} def validation_pass(self, batch: PromptedAudioToTextMiniBatch, batch_idx, dataloader_idx=0, eval_mode="val"): - # During inference, dataloader passes pure prompt without transcript text. - input_ids, labels = batch.prompted_transcript[:, :-1], batch.prompted_transcript[:, 1:] + input_ids, labels = batch.get_decoder_inputs_outputs() transf_log_probs, encoded_len, enc_states, enc_mask = self.forward( input_signal=batch.audio, diff --git a/nemo/collections/asr/modules/transformer/transformer_generators.py b/nemo/collections/asr/modules/transformer/transformer_generators.py index 6d4ef223a75b..fea1a346b637 100644 --- a/nemo/collections/asr/modules/transformer/transformer_generators.py +++ b/nemo/collections/asr/modules/transformer/transformer_generators.py @@ -70,14 +70,6 @@ def __init__( super().__init__() self.embedding = embedding self.decoder = decoder - if hasattr(classifier, "set_log_softmax_enabled"): - classifier = classifier.set_log_softmax_enabled(False) - else: - assert temperature is None, ( - "The module passed as 'classifier' does not support disabling log-softmax, but we require it " - "for temperature sampling since 'temperature' was set. " - "Your model architecture may not support temperature sampling: we suggest disabling temperature." - ) self.classifier = classifier self.pad, self.bos, self.eos = pad, bos, eos self.max_seq_length = max_sequence_length @@ -93,6 +85,7 @@ def _one_step_forward( encoder_input_mask=None, decoder_mems_list=None, pos=0, + need_scores: bool = True, ): """ One step of autoregressive output generation. 
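For clarity, a toy illustration of how the generator now picks the next token in its two modes; the numbers are made up, and only the argmax-versus-sampling distinction mirrors the decoding loop shown earlier in this file:

import torch
from torch.distributions import Categorical

logits = torch.tensor([[2.0, 1.0, 0.5, 0.1]])  # (batch=1, vocab=4), raw logits from the classifier head
greedy_next = torch.argmax(logits, dim=-1)     # temperature is None: deterministic argmax
temperature = 2.0
sampled_next = Categorical(logits=logits / temperature).sample()  # temperature sampling: stochastic draw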
@@ -125,7 +118,8 @@ def _one_step_forward( decoder_mems_list = self.decoder.forward( decoder_hidden_states, decoder_input_mask, decoder_mems_list, return_mems=True ) - logits = self.classifier.forward(hidden_states=decoder_mems_list[-1][:, -1:], temperature=self.temperature) + with self.classifier.with_log_softmax_enabled(need_scores) as clf: + logits = clf.forward(hidden_states=decoder_mems_list[-1][:, -1:]) return logits, decoder_mems_list def _prepare_for_search(self, decoder_input_ids=None, encoder_hidden_states=None): @@ -187,7 +181,12 @@ def _forward( input_ids = tgt[:, -1:] logits, decoder_mems_list = self._one_step_forward( - input_ids, encoder_hidden_states, encoder_input_mask, decoder_mems_list, i + input_ids, + encoder_hidden_states, + encoder_input_mask, + decoder_mems_list, + i, + need_scores=return_beam_scores, ) if self.temperature is None: # Greedy decoding @@ -292,9 +291,15 @@ def _one_step_forward( encoder_input_mask=None, decoder_mems_list=None, pos=0, + need_scores: bool = True, ): log_probs, decoder_mems_list = super()._one_step_forward( - decoder_input_ids, encoder_hidden_states, encoder_input_mask, decoder_mems_list, pos + decoder_input_ids, + encoder_hidden_states, + encoder_input_mask, + decoder_mems_list, + pos, + need_scores=need_scores, ) batch_size, seq_len, vocab_size = log_probs.size() diff --git a/nemo/collections/asr/parts/submodules/multitask_decoding.py b/nemo/collections/asr/parts/submodules/multitask_decoding.py index 46a62604323f..715ee7168037 100644 --- a/nemo/collections/asr/parts/submodules/multitask_decoding.py +++ b/nemo/collections/asr/parts/submodules/multitask_decoding.py @@ -130,7 +130,7 @@ def change_strategy(self, strategy: str) -> "AbstractMultiTaskDecoding": transformer_decoder=self.transformer_decoder, log_softmax_module=self.log_softmax_module, tokenizer=self.tokenizer, - max_generation_delta=self.cfg.greedy.get('max_generation_delta', 50), + max_generation_delta=self.cfg.greedy.get('max_generation_delta', -1), preserve_alignments=self.preserve_alignments, temperature=self.cfg.greedy.temperature, n_samples=self.cfg.greedy.n_samples, @@ -145,7 +145,7 @@ def change_strategy(self, strategy: str) -> "AbstractMultiTaskDecoding": search_type=self.cfg.beam.get('search_type', 'default'), beam_size=self.cfg.beam.beam_size, length_penalty=self.cfg.beam.get('length_penalty', 0.0), - max_generation_delta=self.cfg.beam.get('max_generation_delta', 50), + max_generation_delta=self.cfg.beam.get('max_generation_delta', -1), return_best_hypothesis=self.cfg.beam.get('return_best_hypothesis', True), preserve_alignments=self.preserve_alignments, ) diff --git a/nemo/collections/asr/parts/submodules/token_classifier.py b/nemo/collections/asr/parts/submodules/token_classifier.py index 5f50308f0a14..cc435308fcae 100644 --- a/nemo/collections/asr/parts/submodules/token_classifier.py +++ b/nemo/collections/asr/parts/submodules/token_classifier.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
- +from contextlib import contextmanager from dataclasses import dataclass from typing import Dict, Optional @@ -46,7 +46,6 @@ class TokenClassifier(Classifier): def input_types(self) -> Dict[str, NeuralType]: return { "hidden_states": NeuralType(('B', 'T', 'D'), ChannelType()), - "temperature": NeuralType(None, FloatType(), optional=True), } @property @@ -54,7 +53,7 @@ def output_types(self) -> Dict[str, NeuralType]: """ Returns definitions of module output ports. """ - if not self.log_softmax: + if not self.mlp.log_softmax: return {"logits": NeuralType(('B', 'T', 'C'), LogitsType())} else: return {"log_probs": NeuralType(('B', 'T', 'C'), LogprobsType())} @@ -82,18 +81,24 @@ def __init__( use_transformer_init: whether to initialize the weights of the classifier head with the same approach used in Transformer """ super().__init__(hidden_size=hidden_size, dropout=dropout) - self.log_softmax = log_softmax self.mlp = MultiLayerPerceptron( hidden_size, num_classes, num_layers=num_layers, activation=activation, log_softmax=log_softmax ) self.post_init(use_transformer_init=use_transformer_init) - def set_log_softmax_enabled(self, value: bool) -> "TokenClassifier": - self.log_softmax = value - return self + @property + def log_softmax(self) -> bool: + return self.mlp.log_softmax + + @contextmanager + def with_log_softmax_enabled(self, value: bool) -> "TokenClassifier": + prev = self.mlp.log_softmax + self.mlp.log_softmax = value + yield self + self.mlp.log_softmax = prev @typecheck() - def forward(self, hidden_states: torch.Tensor, temperature: float | None = None) -> torch.Tensor: + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: """ Performs the forward step of the module. Args: @@ -102,7 +107,7 @@ def forward(self, hidden_states: torch.Tensor, temperature: float | None = None) Returns: logits value for each class [BATCH_SIZE x SEQ_LENGTH x NUM_CLASSES] """ hidden_states = self.dropout(hidden_states) - logits = self.mlp(hidden_states, temperature=temperature) + logits = self.mlp(hidden_states) return logits @@ -115,7 +120,6 @@ class BertPretrainingTokenClassifier(Classifier): def input_types(self) -> Dict[str, NeuralType]: return { "hidden_states": NeuralType(('B', 'T', 'D'), ChannelType()), - "temperature": NeuralType(None, FloatType(), optional=True), } @property @@ -123,7 +127,7 @@ def output_types(self) -> Optional[Dict[str, NeuralType]]: """ Returns definitions of module output ports. 
""" - if not self.log_softmax: + if not self.mlp.log_softmax: return {"logits": NeuralType(('B', 'T', 'C'), LogitsType())} else: return {"log_probs": NeuralType(('B', 'T', 'C'), LogprobsType())} @@ -152,8 +156,6 @@ def __init__( """ super().__init__(hidden_size=hidden_size, dropout=dropout) - self.log_softmax = log_softmax - if activation not in ACT2FN: raise ValueError(f'activation "{activation}" not found') self.dense = nn.Linear(hidden_size, hidden_size) @@ -164,12 +166,19 @@ def __init__( ) self.post_init(use_transformer_init=use_transformer_init) - def set_log_softmax_enabled(self, value: bool) -> "BertPretrainingTokenClassifier": - self.log_softmax = value - return self + @property + def log_softmax(self) -> bool: + return self.mlp.log_softmax + + @contextmanager + def with_log_softmax_enabled(self, value: bool) -> "TokenClassifier": + prev = self.mlp.log_softmax + self.mlp.log_softmax = value + yield self + self.mlp.log_softmax = prev @typecheck() - def forward(self, hidden_states: torch.Tensor, temperature: float | None = None) -> torch.Tensor: + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: """ Performs the forward step of the module. Args: @@ -181,5 +190,5 @@ def forward(self, hidden_states: torch.Tensor, temperature: float | None = None) hidden_states = self.dense(hidden_states) hidden_states = self.act(hidden_states) transform = self.norm(hidden_states) - logits = self.mlp(transform, temperature=temperature) + logits = self.mlp(transform) return logits diff --git a/nemo/collections/common/parts/multi_layer_perceptron.py b/nemo/collections/common/parts/multi_layer_perceptron.py index 5110406fedfd..76c06bf23ea6 100644 --- a/nemo/collections/common/parts/multi_layer_perceptron.py +++ b/nemo/collections/common/parts/multi_layer_perceptron.py @@ -51,15 +51,11 @@ def __init__( def last_linear_layer(self): return getattr(self, f'layer{self.layers - 1}') - def forward(self, hidden_states, temperature: float | None = None): + def forward(self, hidden_states): output_states = hidden_states[:] for i in range(self.layers): output_states = getattr(self, f'layer{i}')(output_states) - if temperature is not None: - output_states = output_states / temperature - if self.log_softmax: output_states = torch.log_softmax(output_states, dim=-1) - return output_states diff --git a/tests/collections/asr/decoding/test_multi_task_decoding.py b/tests/collections/asr/decoding/test_multi_task_decoding.py index ee76d42ca411..72ffeda9097f 100644 --- a/tests/collections/asr/decoding/test_multi_task_decoding.py +++ b/tests/collections/asr/decoding/test_multi_task_decoding.py @@ -132,8 +132,18 @@ def test_beam_decoding_beam_scores_true(inputs, nnet): assert best_path.shape == (1, 26) -def test_transformer_aed_beam_infer_strips_prompt(inputs, decoder_nm, nnet, tokenizer): - decoder_input_ids, encoder_hidden_states, encoder_input_mask = inputs +@pytest.fixture() +def prompted_inputs(): + B, T, C = 1, 5, 2 + return ( + torch.tensor([[1, 0, 2, 3, 4]], dtype=torch.long), # prompt + torch.ones(B, T, C, dtype=torch.float), # encoder_hidden_states + torch.ones(B, T, dtype=torch.float), # encoder_input_mask + ) + + +def test_transformer_aed_beam_infer_strips_prompt(prompted_inputs, decoder_nm, nnet, tokenizer): + decoder_input_ids, encoder_hidden_states, encoder_input_mask = prompted_inputs *_, classifier = nnet # Run the actual top-level module used by MultiTask AED model for decoding. 
@@ -149,16 +159,19 @@ def test_transformer_aed_beam_infer_strips_prompt(inputs, decoder_nm, nnet, toke
     assert torch.is_tensor(best_path)
 
     # Now run the underlying beam search generator that doesn't trim anything.
-    *_, (untrimmed,) = gen.beam_search(*inputs, return_beam_scores=True)
+    *_, (untrimmed,) = gen.beam_search(*prompted_inputs, return_beam_scores=True)
     assert untrimmed is not None
     assert torch.is_tensor(untrimmed)
 
     # Check that the expected trimming has indeed been done.
-    assert (untrimmed[decoder_input_ids.shape[1] :] == best_path).all()  # stripped the prompt [1,] from the beggining
+    torch.testing.assert_close(
+        untrimmed[decoder_input_ids.shape[1] :], best_path
+    )  # stripped the prompt from the beginning
 
 
-def test_transformer_aed_greedy_infer_strips_prompt(inputs, decoder_nm, nnet, tokenizer):
-    decoder_input_ids, encoder_hidden_states, encoder_input_mask = inputs
+def test_transformer_aed_greedy_infer_strips_prompt(prompted_inputs, decoder_nm, nnet, tokenizer):
+    decoder_input_ids, encoder_hidden_states, encoder_input_mask = prompted_inputs
+    decoder_input_ids = torch.tensor([[1, 0, 2, 3, 4]], dtype=torch.long)  # prompt
     *_, classifier = nnet
 
     # Run the actual top-level module used by MultiTask AED model for decoding.
@@ -174,9 +187,11 @@ def test_transformer_aed_greedy_infer_strips_prompt(inputs, decoder_nm, nnet, to
     assert torch.is_tensor(best_path)
 
     # Now run the underlying beam search generator that doesn't trim anything.
-    (untrimmed,), _ = gen.greedy_search(*inputs)
+    (untrimmed,), _ = gen.greedy_search(*prompted_inputs)
     assert untrimmed is not None
     assert torch.is_tensor(untrimmed)
 
     # Check that the expected trimming has indeed been done.
-    assert (untrimmed[decoder_input_ids.shape[1] :] == best_path).all()  # stripped the prompt [1,] from the beggining
+    torch.testing.assert_close(
+        untrimmed[decoder_input_ids.shape[1] :], best_path
+    )  # stripped the prompt from the beginning

From e4a7fd053ad9273a4bb837bb46e25992242e5764 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Piotr=20=C5=BBelasko?=
Date: Wed, 31 Jul 2024 16:02:52 -0400
Subject: [PATCH 5/6] Review 2
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Signed-off-by: Piotr Żelasko
---
 .../asr/modules/transformer/transformer_generators.py | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/nemo/collections/asr/modules/transformer/transformer_generators.py b/nemo/collections/asr/modules/transformer/transformer_generators.py
index fea1a346b637..e6775a48f635 100644
--- a/nemo/collections/asr/modules/transformer/transformer_generators.py
+++ b/nemo/collections/asr/modules/transformer/transformer_generators.py
@@ -85,7 +85,7 @@ def _one_step_forward(
         encoder_input_mask=None,
         decoder_mems_list=None,
         pos=0,
-        need_scores: bool = True,
+        return_scores: bool = True,
     ):
         """
         One step of autoregressive output generation.
@@ -118,7 +118,7 @@ def _one_step_forward( decoder_mems_list = self.decoder.forward( decoder_hidden_states, decoder_input_mask, decoder_mems_list, return_mems=True ) - with self.classifier.with_log_softmax_enabled(need_scores) as clf: + with self.classifier.with_log_softmax_enabled(return_scores) as clf: logits = clf.forward(hidden_states=decoder_mems_list[-1][:, -1:]) return logits, decoder_mems_list @@ -186,7 +186,7 @@ def _forward( encoder_input_mask, decoder_mems_list, i, - need_scores=return_beam_scores, + return_scores=return_beam_scores, ) if self.temperature is None: # Greedy decoding @@ -291,7 +291,7 @@ def _one_step_forward( encoder_input_mask=None, decoder_mems_list=None, pos=0, - need_scores: bool = True, + return_scores: bool = True, ): log_probs, decoder_mems_list = super()._one_step_forward( decoder_input_ids, @@ -299,7 +299,7 @@ def _one_step_forward( encoder_input_mask, decoder_mems_list, pos, - need_scores=need_scores, + return_scores=return_scores, ) batch_size, seq_len, vocab_size = log_probs.size() From 51a40b01aef97050ec47e4682f379cf86e571201 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Piotr=20=C5=BBelasko?= Date: Wed, 31 Jul 2024 16:23:41 -0400 Subject: [PATCH 6/6] Fix non-deterministic unit test assertions MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Piotr Żelasko --- tests/collections/asr/decoding/test_multi_task_decoding.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/tests/collections/asr/decoding/test_multi_task_decoding.py b/tests/collections/asr/decoding/test_multi_task_decoding.py index 72ffeda9097f..906caccad396 100644 --- a/tests/collections/asr/decoding/test_multi_task_decoding.py +++ b/tests/collections/asr/decoding/test_multi_task_decoding.py @@ -63,7 +63,7 @@ def tokenizer(): return tok -def test_greedy_decoding(inputs, nnet): +def test_greedy_decoding(inputs, nnet, deterministic_rng): gen = GreedySequenceGenerator(*nnet) output = gen(*inputs) @@ -86,12 +86,13 @@ def test_temperature_sampling_decoding(inputs, nnet): assert best_path is not None assert torch.is_tensor(best_path) - assert best_path.shape == (1, 25) + assert best_path.shape[0] == 1 assert isinstance(hypotheses, list) assert len(hypotheses) == 1 (seq0,) = hypotheses - assert seq0.shape == (2, 25) + assert seq0.shape[0] == 2 + assert (seq0[0] != seq0[1]).any() def test_beam_decoding_beam_scores_false(inputs, nnet):
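
Context for the data plumbing changed by this series: training and validation now derive the decoder inputs and labels from a single prompted transcript via get_decoder_inputs_outputs(), and the dataset strips the prompt prefix to recover the plain transcript. A minimal sketch with made-up token ids (the prompt/answer values below are placeholders, not real Canary ids):

    import torch

    prompt = [3, 5, 7]                    # placeholder ids standing in for <bos> + task/language tokens
    answer = [11, 12, 13, 2]              # placeholder transcript ids followed by <eos>
    prompted_transcript = torch.tensor([prompt + answer])

    # Teacher-forcing split used in training_step / validation_pass:
    input_ids = prompted_transcript[:, :-1]   # drop the last token
    labels = prompted_transcript[:, 1:]       # drop the first token

    # Prompt stripping done in the dataset to build the plain transcript
    # (the canary() prompt function additionally drops the trailing <eos> from its answers):
    transcript = prompted_transcript[0, len(prompt):]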
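The with_log_softmax_enabled context manager replaces the earlier set_log_softmax_enabled/temperature plumbing: the generators disable log-softmax only for the steps where raw logits suffice and the previous setting is restored when the block exits. A usage sketch, assuming default constructor arguments for TokenClassifier; the tensor shapes are illustrative only:

    import torch
    from nemo.collections.asr.parts.submodules.token_classifier import TokenClassifier

    classifier = TokenClassifier(hidden_size=4, num_classes=8, log_softmax=True)
    hidden = torch.randn(2, 3, 4)  # [batch, time, hidden]

    with classifier.with_log_softmax_enabled(False) as clf:
        logits = clf(hidden_states=hidden)        # raw logits, e.g. for temperature sampling
    log_probs = classifier(hidden_states=hidden)  # setting restored: log-probabilities again

Because the flag lives on the wrapped MLP and is restored after the block, the same classifier instance can serve both greedy/beam scoring (log-probabilities) and temperature sampling (raw logits) without being rebuilt.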