Skip to content

Commit

Permalink
Support Alignment Extraction for all RNNT Beam decoding methods (NVID…
Browse files Browse the repository at this point in the history
…IA#5925)

* Partial impl of ALSD alignment extraction

Signed-off-by: smajumdar <[email protected]>

* Partial impl of ALSD alignment extraction

Signed-off-by: smajumdar <[email protected]>

* Remove everything else

Signed-off-by: smajumdar <[email protected]>

* Support dataclass in AbstractRNNTDecoding

Signed-off-by: smajumdar <[email protected]>

* Add first draft unittest

Signed-off-by: smajumdar <[email protected]>

* Correct the logic to more to the next timestep in the alignment

Signed-off-by: smajumdar <[email protected]>

* Finalize ALSD alignment generation

Signed-off-by: smajumdar <[email protected]>

* Add support for TSD greedy alignment extraction

Signed-off-by: smajumdar <[email protected]>

* Add support for mAES greedy alignment extraction

Signed-off-by: smajumdar <[email protected]>

* Finalize extraction of alignments from all beam algorithms for RNNT

Signed-off-by: smajumdar <[email protected]>

* Style fixes

Signed-off-by: smajumdar <[email protected]>

* Add copyright

Signed-off-by: smajumdar <[email protected]>

* Address comments

Signed-off-by: smajumdar <[email protected]>

---------

Signed-off-by: smajumdar <[email protected]>
Signed-off-by: Jason <[email protected]>
  • Loading branch information
titu1994 authored and blisc committed Feb 10, 2023
1 parent 368f57e commit 8a879c4
Show file tree
Hide file tree
Showing 3 changed files with 436 additions and 28 deletions.
8 changes: 7 additions & 1 deletion nemo/collections/asr/metrics/rnnt_wer.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,13 @@

import copy
from abc import abstractmethod
from dataclasses import dataclass
from dataclasses import dataclass, is_dataclass
from typing import Callable, Dict, List, Optional, Tuple, Union

import editdistance
import numpy as np
import torch
from omegaconf import OmegaConf
from torchmetrics import Metric

from nemo.collections.asr.metrics.wer import move_dimension_to_the_front
Expand Down Expand Up @@ -193,6 +194,11 @@ class AbstractRNNTDecoding(ConfidenceMixin):

def __init__(self, decoding_cfg, decoder, joint, blank_id: int):
super(AbstractRNNTDecoding, self).__init__()

# Convert dataclass to config object
if is_dataclass(decoding_cfg):
decoding_cfg = OmegaConf.structured(decoding_cfg)

self.cfg = decoding_cfg
self.blank_id = blank_id
self.num_extra_outputs = joint.num_extra_outputs
Expand Down
Loading

0 comments on commit 8a879c4

Please sign in to comment.