Add Canary support for decoding with return_hypotheses=True (NVIDIA#8338)

* change default decoding to beam=1 and length up to decoder max len

Signed-off-by: stevehuang52 <[email protected]>

* add Canary support for return_hypotheses=True

Signed-off-by: stevehuang52 <[email protected]>

* change len_pen default to 0

Signed-off-by: stevehuang52 <[email protected]>

* fix for nbest hypotheses

Signed-off-by: stevehuang52 <[email protected]>

* fix for return best hypo

Signed-off-by: stevehuang52 <[email protected]>

---------

Signed-off-by: stevehuang52 <[email protected]>
Signed-off-by: Zeeshan Patel <[email protected]>
stevehuang52 authored and zpx01 committed Mar 8, 2024
1 parent 4995981 commit cb2b5e7
Showing 5 changed files with 46 additions and 33 deletions.
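For orientation, here is a minimal usage sketch of what this change enables when transcribing with a Canary (AED multitask) model. The checkpoint name and the exact transcribe() call are illustrative assumptions, not taken from this commit:

```python
# Hypothetical usage sketch; the checkpoint name and transcribe() arguments are
# assumptions for illustration, not part of this commit.
from nemo.collections.asr.models import EncDecMultiTaskModel

model = EncDecMultiTaskModel.from_pretrained("nvidia/canary-1b")  # assumed checkpoint

# With return_hypotheses=True, each result is a Hypothesis object whose .text has
# special tokens stripped (what this commit adds), instead of a plain string.
results = model.transcribe(["sample.wav"], return_hypotheses=True)

# When the beam decoder keeps all candidates (return_best_hypothesis=False),
# the output is a (best_hypotheses, all_hypotheses) tuple instead of a flat list.
if isinstance(results, tuple):
    best_hypotheses, all_hypotheses = results
else:
    best_hypotheses, all_hypotheses = results, None

for hyp in best_hypotheses:
    print(hyp.score, hyp.text)
```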
examples/asr/transcribe_speech.py (2 changes: 1 addition & 1 deletion)
@@ -375,7 +375,7 @@ def autocast(dtype=None):
logging.info(f"Finished transcribing {len(filepaths)} files !")
logging.info(f"Writing transcriptions into file: {cfg.output_filename}")

- # if transcriptions form a tuple (from RNNT), extract just "best" hypothesis
+ # if transcriptions form a tuple of (best_hypotheses, all_hypotheses), extract just best hypothesis
if type(transcriptions) == tuple and len(transcriptions) == 2:
transcriptions = transcriptions[0]

nemo/collections/asr/models/aed_multitask_models.py (21 changes: 14 additions & 7 deletions)
@@ -745,23 +745,30 @@ def _transcribe_output_processing(self, outputs, trcfg: MultiTaskTranscriptionCo

del log_probs, encoded_len

- beam_hypotheses = self.decoding.decode_predictions_tensor(
+ best_hypotheses, all_hypotheses = self.decoding.decode_predictions_tensor(
encoder_hidden_states=enc_states,
encoder_input_mask=enc_mask,
decoder_input_ids=decoder_input_ids if self.context_len_for_AR_decoding > 0 else None,
return_hypotheses=trcfg.return_hypotheses,
- )[
- 0
- ]  # type: List[str] | List[Hypothesis]
+ )

if trcfg.return_hypotheses:
- for hyp in beam_hypotheses:
+ for hyp in best_hypotheses:
hyp.text = self.decoding.strip_special_tokens(hyp.text)
+ if all_hypotheses is not None:
+ for i in range(len(all_hypotheses)):
+ for j in range(len(all_hypotheses[i])):
+ all_hypotheses[i][j].text = self.decoding.strip_special_tokens(all_hypotheses[i][j].text)
else:
- beam_hypotheses = [self.decoding.strip_special_tokens(text) for text in beam_hypotheses]
+ best_hypotheses = [self.decoding.strip_special_tokens(text) for text in best_hypotheses]
+ if all_hypotheses is not None:
+ for i in range(len(all_hypotheses)):
+ all_hypotheses[i] = [self.decoding.strip_special_tokens(text) for text in all_hypotheses[i]]

del enc_states, enc_mask, decoder_input_ids
- return beam_hypotheses
+ if all_hypotheses is None:
+ return best_hypotheses
+ return best_hypotheses, all_hypotheses

def _setup_transcribe_dataloader(self, config: Dict) -> 'torch.utils.data.DataLoader':
"""
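To make the new return contract of _transcribe_output_processing concrete: best_hypotheses holds one entry per utterance in the batch, and all_hypotheses (when not None) holds a beam-sized list of candidates per utterance. A hedged sketch of consuming that structure, with placeholder names:

```python
# Illustrative only: `processed` stands for the value returned by
# _transcribe_output_processing; it is either a flat list or a 2-tuple.
def iter_results(processed):
    if isinstance(processed, tuple):
        best_hypotheses, all_hypotheses = processed
    else:
        best_hypotheses, all_hypotheses = processed, None

    for utt_idx, best in enumerate(best_hypotheses):
        # Entries are Hypothesis objects when return_hypotheses=True, else strings.
        text = best.text if hasattr(best, "text") else best
        beams = all_hypotheses[utt_idx] if all_hypotheses is not None else None
        yield utt_idx, text, beams
```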
@@ -174,9 +174,18 @@ def __call__(
self, decoder_input_ids=None, encoder_hidden_states=None, encoder_input_mask=None, return_beam_scores=False
):
with self.as_frozen():
- return self._forward(
+ results = self._forward(
decoder_input_ids, encoder_hidden_states, encoder_input_mask, return_beam_scores=return_beam_scores
)
+ if not return_beam_scores:
+ return results
+ else:
+ prefixes, scores, tgt = results
+ prefixes = prefixes.view(-1, self.beam_size, tgt.size(1)).split(1, dim=0)
+ scores = scores.view(-1, self.beam_size).split(1, dim=0)
+ prefixes = [x.squeeze(0) for x in prefixes]  # each item is [beam, seq_len]
+ scores = [x.squeeze(0) for x in scores]  # each item is [beam,]
+ return prefixes, scores, tgt

def freeze(self) -> None:
"""Freeze weights of embedding, decoder, and classification layers to prevent memory leak.
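A self-contained sketch of the reshaping in the new __call__ branch above: the flat [batch * beam, seq_len] prefix tensor and the [batch * beam] score tensor are split into per-sample lists of [beam, seq_len] and [beam] tensors. Sizes below are made up for illustration:

```python
import torch

batch, beam_size, seq_len = 2, 4, 7  # illustrative sizes
prefixes = torch.zeros(batch * beam_size, seq_len, dtype=torch.long)
scores = torch.randn(batch * beam_size)

# Same reshaping as in __call__ when return_beam_scores=True:
prefixes = prefixes.view(-1, beam_size, seq_len).split(1, dim=0)
scores = scores.view(-1, beam_size).split(1, dim=0)
prefixes = [x.squeeze(0) for x in prefixes]  # each item is [beam, seq_len]
scores = [x.squeeze(0) for x in scores]      # each item is [beam,]

assert len(prefixes) == batch and prefixes[0].shape == (beam_size, seq_len)
assert len(scores) == batch and scores[0].shape == (beam_size,)
```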
nemo/collections/asr/parts/submodules/multitask_beam_decoding.py (39 changes: 21 additions & 18 deletions)
@@ -19,7 +19,7 @@
import torch

from nemo.collections.asr.modules.transformer import BeamSearchSequenceGenerator
- from nemo.collections.asr.parts.utils.rnnt_utils import Hypothesis
+ from nemo.collections.asr.parts.utils.rnnt_utils import Hypothesis, NBestHypotheses
from nemo.collections.common.tokenizers.tokenizer_spec import TokenizerSpec
from nemo.core import Typing, typecheck
from nemo.core.neural_types import ChannelType, HypothesisType, LabelsType, MaskType, NeuralType
@@ -85,7 +85,6 @@ def forward(
encoder_hidden_states: torch.Tensor,
encoder_input_mask: torch.Tensor,
decoder_input_ids: Optional[torch.Tensor] = None,
- return_scores: bool = False,
partial_hypotheses: Optional[List[Hypothesis]] = None,
):
raise NotImplementedError()
@@ -112,7 +111,6 @@ def input_types(self):
"encoder_hidden_states": NeuralType(tuple(('B', 'T', 'D')), ChannelType()),
"encoder_input_mask": NeuralType(tuple(('B', 'T')), MaskType()),
"decoder_input_ids": NeuralType(('B', 'T'), LabelsType()),
"return_scores": NeuralType(optional=True),
"partial_hypotheses": NeuralType(optional=True),
}

@@ -142,7 +140,7 @@ def __init__(
return_best_hypothesis=return_best_hypothesis,
preserve_alignments=preserve_alignments,
)

+ self.beam_size = beam_size
self.beam_search = BeamSearchSequenceGenerator(
embedding=transformer_decoder.embedding,
decoder=transformer_decoder.decoder,
@@ -170,7 +168,6 @@ def forward(
encoder_hidden_states: torch.Tensor,
encoder_input_mask: torch.Tensor,
decoder_input_ids: Optional[torch.Tensor] = None,
- return_scores: bool = False,
partial_hypotheses: Optional[List[Hypothesis]] = None,
):
"""Returns a list of hypotheses given an input batch of the encoder hidden embedding.
@@ -185,25 +182,31 @@
packed list containing batch number of sentences (Hypotheses).
"""
with torch.inference_mode():
- hypotheses = [
- Hypothesis(score=0.0, y_sequence=[], timestep=[]) for _ in range(encoder_hidden_states.shape[0])
- ]
- beam_hypotheses = self.beam_search(
+ topk_hypotheses, beam_scores, best_hypo = self.beam_search(
encoder_hidden_states=encoder_hidden_states,
encoder_input_mask=encoder_input_mask,
decoder_input_ids=decoder_input_ids,
- return_beam_scores=return_scores,
+ return_beam_scores=True,
)

- if return_scores:
- _, beam_scores, beam_hypotheses = beam_hypotheses
- beam_scores = beam_scores.detach().cpu()
+ if not self.return_best_hypothesis:
+ topk_hypotheses = [x.detach().cpu() for x in topk_hypotheses]  # each item is [beam, seq_len]
+ beam_scores = [x.detach().cpu() for x in beam_scores]  # each item is [beam,]
+ packed_result = []
+ for i in range(len(topk_hypotheses)):
+ hypotheses = [Hypothesis(score=0.0, y_sequence=[], timestep=[]) for _ in range(self.beam_size)]
+ # Pack results into Hypotheses
+ packed_result.append(
+ NBestHypotheses(pack_hypotheses(hypotheses, topk_hypotheses[i], beam_scores[i]))
+ )
else:
- beam_scores = [None for _ in range(len(beam_hypotheses))]
- beam_hypotheses = beam_hypotheses.detach().cpu()

- # Pack results into Hypotheses
- packed_result = pack_hypotheses(hypotheses, beam_hypotheses, beam_scores)
+ beam_scores = [None for _ in range(len(best_hypo))]
+ best_hypo = best_hypo.detach().cpu()
+ hypotheses = [
+ Hypothesis(score=0.0, y_sequence=[], timestep=[]) for _ in range(encoder_hidden_states.shape[0])
+ ]
+ # Pack results into Hypotheses
+ packed_result = pack_hypotheses(hypotheses, best_hypo, beam_scores)

return (packed_result,)

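Downstream of the packing above, the result list contains plain Hypothesis objects when return_best_hypothesis=True and NBestHypotheses wrappers otherwise. A hedged sketch of normalizing the two cases; it assumes NBestHypotheses exposes its candidates as n_best_hypotheses ordered best-first, as defined in rnnt_utils:

```python
from nemo.collections.asr.parts.utils.rnnt_utils import Hypothesis, NBestHypotheses

def best_per_utterance(packed_result):
    """Return one Hypothesis per utterance regardless of the packing mode."""
    best = []
    for item in packed_result:
        if isinstance(item, NBestHypotheses):
            # Assumed attribute name; candidates are expected best-first.
            best.append(item.n_best_hypotheses[0])
        else:
            best.append(item)
    return best
```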
nemo/collections/asr/parts/submodules/multitask_decoding.py (6 changes: 0 additions & 6 deletions)
@@ -42,8 +42,6 @@ class AbstractMultiTaskDecoding(ABC):
- greedy, greedy_batch (for greedy decoding).
- beam, tsd, alsd (for beam search decoding).
- return_scores: bool flag which determines whether to return the scores of the hypotheses.
compute_langs: a bool flag, which allows to compute language id (LID) information per token,
word, and the entire sample (most likely language id). The LIDS will be available
in the returned Hypothesis object as a dictionary
@@ -104,7 +102,6 @@ def __init__(

self.preserve_alignments = self.cfg.get('preserve_alignments', None)
self.compute_langs = self.cfg.get('compute_langs', False)
- self.return_scores = self.cfg.get('return_scores', False)
self.compute_hypothesis_token_set = self.cfg.get('compute_hypothesis_token_set', False)

possible_strategies = ['greedy', 'greedy_batch', 'beam']
@@ -181,7 +178,6 @@ def decode_predictions_tensor(
encoder_hidden_states=encoder_hidden_states,
encoder_input_mask=encoder_input_mask,
decoder_input_ids=decoder_input_ids,
- return_scores=self.return_scores,
partial_hypotheses=partial_hypotheses,
) # type: [List[Hypothesis]]

@@ -322,8 +318,6 @@ class MultiTaskDecoding(AbstractMultiTaskDecoding):
- greedy, greedy_batch (for greedy decoding).
- beam, tsd, alsd (for beam search decoding).
- return_scores: bool flag which determines whether to return the scores of the hypotheses.
compute_langs: a bool flag, which allows to compute language id (LID) information per token,
word, and the entire sample (most likely language id). The LIDS will be available
in the returned Hypothesis object as a dictionary
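With return_scores removed from the decoding config, n-best behaviour is governed by the beam section's return_best_hypothesis flag, and the commit message above also changes the beam defaults (beam=1, len_pen=0). A hedged sketch of what such a decoding config might look like; the exact field layout should be checked against MultiTaskDecodingConfig:

```python
from omegaconf import OmegaConf

# Hypothetical config sketch; key names mirror this file's classes but are not
# verified against MultiTaskDecodingConfig, and `return_scores` is intentionally absent.
decoding_cfg = OmegaConf.create(
    {
        "strategy": "beam",
        "compute_langs": False,
        "beam": {
            "beam_size": 1,               # new default per the commit message
            "len_pen": 0.0,               # new default per the commit message
            "return_best_hypothesis": True,
        },
    }
)
```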
