Prepare RNNT to switch to Numba loss for compatibility (NVIDIA#1995)
* Prepare RNNT to switch to Numba loss for compatibility

Signed-off-by: smajumdar <[email protected]>

* Update tests for RNNT

Signed-off-by: smajumdar <[email protected]>

* Address comments

Signed-off-by: smajumdar <[email protected]>

* Address comments

Signed-off-by: smajumdar <[email protected]>

* Fix wrong resolution of gradient calculation

Signed-off-by: smajumdar <[email protected]>

* Fix wrong resolution of gradient calculation

Signed-off-by: smajumdar <[email protected]>

* Drop WarpRNNT requirement

Signed-off-by: smajumdar <[email protected]>
Signed-off-by: Virginia Adams <[email protected]>
titu1994 authored and vadam5 committed Apr 9, 2021
1 parent fd85443 commit bd3f883
Showing 10 changed files with 120 additions and 81 deletions.
2 changes: 1 addition & 1 deletion Dockerfile
@@ -49,7 +49,7 @@ RUN git clone --branch v0.8.1 https://github.com/pytorch/text.git && \
cd text && \
git submodule update --init --recursive && \
python setup.py clean install && \
cd .. && rm -r text

# build RNN-T loss
WORKDIR /workspace/deps/rnnt
39 changes: 39 additions & 0 deletions Jenkinsfile
@@ -291,6 +291,45 @@ pipeline {
}
}

// TODO: UNCOMMENT TESTS AFTER 21.04 release (numba 0.53 min requirement)
// stage('L2: ASR RNNT dev run') {
// when {
// anyOf {
// branch 'main'
// changeRequest target: 'main'
// }
// }
// failFast true
// parallel {
// stage('Speech to Text - RNNT') {
// steps {
// sh 'python examples/asr/speech_to_text_rnnt.py \
// model.train_ds.manifest_filepath=/home/TestData/an4_dataset/an4_train.json \
// model.validation_ds.manifest_filepath=/home/TestData/an4_dataset/an4_val.json \
// model.train_ds.batch_size=8 \
// trainer.gpus=[0] \
// +trainer.fast_dev_run=True \
// exp_manager.exp_dir=examples/asr/speech_to_text_rnnt_results'
// sh 'rm -rf examples/asr/speech_to_text_rnnt_results'
// }
// }
// stage('L2: Speech to Text RNNT WPE') {
// steps {
// sh 'python examples/asr/speech_to_text_rnnt_bpe.py \
// --config-path="experimental/contextnet_rnnt/" --config-name="config_rnnt_bpe.yaml" \
// model.train_ds.manifest_filepath=/home/TestData/an4_dataset/an4_train.json \
// model.validation_ds.manifest_filepath=/home/TestData/an4_dataset/an4_val.json \
// model.tokenizer.dir="/home/TestData/asr_tokenizers/an4_wpe_128/" \
// model.tokenizer.type="wpe" \
// trainer.gpus=[0] \
// +trainer.fast_dev_run=True \
// exp_manager.exp_dir=examples/asr/speech_to_text_rnnt_wpe_results'
// sh 'rm -rf examples/asr/speech_to_text_rnnt_wpe_results'
// }
// }
// }
// }

stage('L2: ASR Multi-dataloader dev run') {
when {
anyOf {
10 changes: 9 additions & 1 deletion nemo/collections/asr/losses/rnnt.py
@@ -96,7 +96,7 @@ class RNNTLossConfig:
),
}

- RNNT_LOSS_RESOLVER['default'] = RNNT_LOSS_RESOLVER['warprnnt']
RNNT_LOSS_RESOLVER['default'] = RNNT_LOSS_RESOLVER['warprnnt_numba']


def _warn_unused_additional_kwargs(loss_name, kwargs):
@@ -108,6 +108,10 @@ def _warn_unused_additional_kwargs(loss_name, kwargs):
)


def resolve_rnnt_default_loss_name() -> str:
return RNNT_LOSS_RESOLVER['default'].loss_name


def resolve_rnnt_loss(loss_name: str, blank_idx: int, loss_kwargs: dict = None) -> torch.nn.Module:
loss_function_names = list(RNNT_LOSS_RESOLVER.keys())
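Context for the two hunks above: RNNT_LOSS_RESOLVER is a dict from loss name to an RNNTLossConfig entry, and 'default' is an alias into that same dict, which is why resolve_rnnt_default_loss_name() can return a concrete name. A minimal self-contained sketch of the pattern (LossConfig, RESOLVER, and the min_version field are illustrative stand-ins, not NeMo's actual definitions):

from dataclasses import dataclass

@dataclass
class LossConfig:
    loss_name: str
    min_version: str = "0.0.0"

# Illustrative registry; the real RNNT_LOSS_RESOLVER carries more metadata.
RESOLVER = {
    'warprnnt': LossConfig(loss_name='warprnnt'),
    'warprnnt_numba': LossConfig(loss_name='warprnnt_numba', min_version='0.53.0'),
}
RESOLVER['default'] = RESOLVER['warprnnt_numba']  # an alias, not a copy

def default_loss_name() -> str:
    # 'default' points at a concrete entry, so this resolves to its real name.
    return RESOLVER['default'].loss_name

assert default_loss_name() == 'warprnnt_numba'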

@@ -152,6 +156,9 @@ def resolve_rnnt_loss(loss_name: str, blank_idx: int, loss_kwargs: dict = None)
if loss_name == 'default':
loss_name = loss_config.loss_name

"""
Resolve RNNT loss functions
"""
if loss_name == 'warprnnt':
loss_func = warprnnt.RNNTLoss(blank=blank_idx, reduction='none')
_warn_unused_additional_kwargs(loss_name, loss_kwargs)
@@ -244,6 +251,7 @@ def forward(self, log_probs, targets, input_lengths, target_lengths):
max_targets_len = target_lengths.max()

# Force cast joint to float32
# TODO: Remove once Numba supports FP16
if log_probs.dtype != torch.float32:
logits_orig = log_probs
log_probs = log_probs.float()
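Since .float() is itself differentiable, gradients from the float32 loss flow back to the original half-precision tensor without extra bookkeeping. A sketch of the wrapper pattern, assuming any loss_fn whose kernels only accept float32 (names are illustrative):

import torch

def run_loss_in_fp32(loss_fn, log_probs: torch.Tensor, *args):
    # Upcast FP16/BF16 activations before a float32-only kernel; autograd
    # records the cast, so backward returns gradients in the original dtype.
    if log_probs.dtype != torch.float32:
        log_probs = log_probs.float()
    return loss_fn(log_probs, *args)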
17 changes: 0 additions & 17 deletions nemo/collections/asr/models/rnnt_bpe_models.py
@@ -29,13 +29,6 @@
from nemo.core.classes.common import PretrainedModelInfo
from nemo.utils import logging, model_utils

- try:
- import warprnnt_pytorch as warprnnt
-
- WARP_RNNT_AVAILABLE = True
- except (ImportError, ModuleNotFoundError):
- WARP_RNNT_AVAILABLE = False


class EncDecRNNTBPEModel(EncDecRNNTModel, ASRBPEMixin):
"""Base class for encoder decoder RNNT-based models with subword tokenization."""
@@ -52,16 +45,6 @@ def list_available_models(cls) -> Optional[PretrainedModelInfo]:
return result

def __init__(self, cfg: DictConfig, trainer: Trainer = None):
- # Required loss function
- if not WARP_RNNT_AVAILABLE:
- raise ImportError(
- "Could not import `warprnnt_pytorch`.\n"
- "Please visit https://github.com/HawkAaron/warp-transducer "
- "and follow the steps in the readme to build and install the "
- "pytorch bindings for RNNT Loss, or use the provided docker "
- "container that supports RNN-T loss."
- )

# Convert to Hydra 1.0 compatible DictConfig
cfg = model_utils.convert_model_config_to_dict_config(cfg)
cfg = model_utils.maybe_update_config_version(cfg)
70 changes: 50 additions & 20 deletions nemo/collections/asr/models/rnnt_models.py
@@ -26,21 +26,14 @@

from nemo.collections.asr.data import audio_to_text_dataset
from nemo.collections.asr.data.audio_to_text_dali import DALIOutputs
- from nemo.collections.asr.losses.rnnt import RNNTLoss
from nemo.collections.asr.losses.rnnt import RNNTLoss, resolve_rnnt_default_loss_name
from nemo.collections.asr.metrics.rnnt_wer import RNNTWER, RNNTDecoding
from nemo.collections.asr.models.asr_model import ASRModel
from nemo.collections.asr.parts.perturb import process_augmentations
from nemo.core.classes.common import PretrainedModelInfo, typecheck
from nemo.core.neural_types import AcousticEncodedRepresentation, AudioSignal, LengthsType, NeuralType, SpectrogramType
from nemo.utils import logging

- try:
- import warprnnt_pytorch as warprnnt
-
- WARP_RNNT_AVAILABLE = True
- except (ImportError, ModuleNotFoundError):
- WARP_RNNT_AVAILABLE = False


class EncDecRNNTModel(ASRModel):
"""Base class for encoder decoder RNNT-based models."""
@@ -57,16 +50,6 @@ def list_available_models(cls) -> Optional[PretrainedModelInfo]:
return result

def __init__(self, cfg: DictConfig, trainer: Trainer = None):
- # Required loss function
- if not WARP_RNNT_AVAILABLE:
- raise ImportError(
- "Could not import `warprnnt_pytorch`.\n"
- "Please visit https://github.com/HawkAaron/warp-transducer "
- "and follow the steps in the readme to build and install the "
- "pytorch bindings for RNNT Loss, or use the provided docker "
- "container that supports RNN-T loss."
- )

# Get global rank and total number of GPU workers for IterableDataset partitioning, if applicable
# global_rank and local_rank are set by LightningModule in Lightning 1.2.0
self.world_size = 1
@@ -91,7 +74,13 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None):

self.decoder = EncDecRNNTModel.from_config_dict(self.cfg.decoder)
self.joint = EncDecRNNTModel.from_config_dict(self.cfg.joint)
- self.loss = RNNTLoss(num_classes=self.joint.num_classes_with_blank - 1)

# Setup RNNT Loss
loss_name, loss_kwargs = self.extract_rnnt_loss_cfg(self.cfg.get("loss", None))

self.loss = RNNTLoss(
num_classes=self.joint.num_classes_with_blank - 1, loss_name=loss_name, loss_kwargs=loss_kwargs
)

if hasattr(self.cfg, 'spec_augment') and self._cfg.spec_augment is not None:
self.spec_augmentation = EncDecRNNTModel.from_config_dict(self.cfg.spec_augment)
@@ -130,6 +119,44 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None):
self._optim_variational_noise_std = 0
self._optim_variational_noise_start = 0

def extract_rnnt_loss_cfg(self, cfg: Optional[DictConfig]):
    """
    Helper method to extract the RNNT loss name and, optionally, the kwargs to be passed with it.

    Args:
        cfg: Should contain `loss_name` as a string, which is resolved to an RNNT loss name.
            If the default should be used, then `default` can be used.
            Optionally, additional kwargs may be passed to the loss function. The sub-dict
            should use the key `{loss_name}_kwargs`.
            Whichever loss_name is selected, the corresponding kwargs are used; for the
            "default" case, the `{resolved_default}_kwargs` sub-dict is used.

    Examples:
        .. code-block:: yaml

            loss_name: "default"
            warprnnt_numba_kwargs:
                kwargs2: some_other_val

    Returns:
        A tuple of the resolved loss name and its kwargs (if found).
    """
    if cfg is None:
        cfg = DictConfig({})

    loss_name = cfg.get("loss_name", "default")

    if loss_name == "default":
        loss_name = resolve_rnnt_default_loss_name()

    loss_kwargs = cfg.get(f"{loss_name}_kwargs", None)

    logging.info(f"Using RNNT loss: {loss_name}\nLoss {loss_name}_kwargs: {loss_kwargs}")

    return loss_name, loss_kwargs
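As a usage sketch, here is the same resolution logic restated as a standalone function (resolve_loss_cfg and the hard-coded default are illustrative assumptions; the method above consults the loss resolver instead):

from typing import Optional, Tuple

from omegaconf import DictConfig

def resolve_loss_cfg(cfg: Optional[DictConfig], default_name: str = "warprnnt_numba") -> Tuple[str, Optional[DictConfig]]:
    cfg = cfg if cfg is not None else DictConfig({})
    loss_name = cfg.get("loss_name", "default")
    if loss_name == "default":
        loss_name = default_name
    # Kwargs are looked up under the resolved name, per the docstring above.
    return loss_name, cfg.get(f"{loss_name}_kwargs", None)

print(resolve_loss_cfg(DictConfig({"loss_name": "default", "warprnnt_numba_kwargs": {"kwargs2": "some_other_val"}})))
# -> ('warprnnt_numba', {'kwargs2': 'some_other_val'})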

@torch.no_grad()
def transcribe(
self, paths2audio_files: List[str], batch_size: int = 4, return_hypotheses: bool = False
@@ -231,7 +258,10 @@ def change_vocabulary(self, new_vocabulary: List[str], decoding_cfg: Optional[DictConfig]
self.decoder = EncDecRNNTModel.from_config_dict(new_decoder_config)

del self.loss
- self.loss = RNNTLoss(num_classes=self.joint.num_classes_with_blank - 1)
loss_name, loss_kwargs = self.extract_rnnt_loss_cfg(self.cfg.get('loss', None))
self.loss = RNNTLoss(
num_classes=self.joint.num_classes_with_blank - 1, loss_name=loss_name, loss_kwargs=loss_kwargs
)

if decoding_cfg is None:
# Assume same decoding config as before
6 changes: 6 additions & 0 deletions nemo/collections/asr/parts/numba/__init__.py
@@ -12,7 +12,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import logging

from nemo.collections.asr.parts.numba.numba_utils import numba_cuda_is_supported
from nemo.collections.asr.parts.numba.rnnt_loss.rnnt_pytorch import RNNTLossNumba

# Prevent Numba CUDA logs from showing at info level
cuda_logger = logging.getLogger('numba.cuda.cudadrv.driver')
cuda_logger.setLevel(logging.ERROR) # only show error

__NUMBA_MINIMUM_VERSION__ = "0.53.0"
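A hedged sketch of the kind of gate numba_cuda_is_supported implies, checking both the interpreter-side numba version and CUDA availability (the helper's actual implementation may differ; numba_meets_minimum is an illustrative name):

import numba
from numba import cuda
from packaging import version

def numba_meets_minimum(minimum: str = "0.53.0") -> bool:
    # The 0.53.0 floor matches __NUMBA_MINIMUM_VERSION__ above; the CUDA
    # check guards against environments with numba but no visible GPU.
    new_enough = version.parse(numba.__version__) >= version.parse(minimum)
    return new_enough and cuda.is_available()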
1 change: 1 addition & 0 deletions nemo/collections/asr/parts/numba/rnnt_loss/rnnt.py
@@ -172,6 +172,7 @@ def rnnt_loss_gpu(
if status != global_constants.RNNTStatus.RNNT_STATUS_SUCCESS:
raise RuntimeError("Invalid parameter passed when calculating working space memory")

# Select GPU index
cuda.select_device(acts.device.index)
gpu_workspace = torch.zeros(gpu_size, device=acts.device, dtype=acts.dtype, requires_grad=False)
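Selecting the activations' device before allocating the workspace keeps Numba's kernel launches and the scratch buffer on the same GPU in multi-device setups. A sketch of that pattern (allocate_gpu_workspace is an illustrative name, not the module's API):

import torch
from numba import cuda

def allocate_gpu_workspace(acts: torch.Tensor, gpu_size: int) -> torch.Tensor:
    # Bind Numba's current CUDA context to the GPU holding the activations,
    # then allocate untracked scratch memory there.
    cuda.select_device(acts.device.index)
    return torch.zeros(gpu_size, device=acts.device, dtype=acts.dtype, requires_grad=False)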

2 changes: 1 addition & 1 deletion nemo/collections/asr/parts/numba/rnnt_loss/rnnt_pytorch.py
@@ -50,7 +50,7 @@ def forward(ctx, acts, labels, act_lens, label_lens, blank, reduction):
certify_inputs(acts, labels, act_lens, label_lens)

loss_func = rnnt.rnnt_loss_gpu if is_cuda else rnnt.rnnt_loss_cpu
- grads = torch.zeros_like(acts) if acts.requires_grad else torch.zeros(0).to(acts)
grads = torch.zeros_like(acts) if acts.requires_grad else None
minibatch_size = acts.size(0)
costs = torch.zeros(minibatch_size, device=acts.device, dtype=acts.dtype)
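Returning None instead of an empty tensor lets the downstream kernels skip gradient computation entirely when the loss runs without autograd tracking. The surrounding custom-autograd pattern, reduced to a runnable sketch (the sum() cost and fill_(1.0) gradients are stand-ins for the real RNNT kernels):

import torch

class SketchLoss(torch.autograd.Function):
    @staticmethod
    def forward(ctx, acts):
        # Allocate a gradient buffer only if autograd will ever ask for it.
        grads = torch.zeros_like(acts) if acts.requires_grad else None
        cost = acts.sum()        # stand-in for the RNNT cost kernel
        if grads is not None:
            grads.fill_(1.0)     # stand-in for kernel-computed gradients
        ctx.grads = grads
        return cost

    @staticmethod
    def backward(ctx, grad_output):
        # Only reached when grads was allocated in forward.
        return ctx.grads * grad_output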

25 changes: 6 additions & 19 deletions tests/collections/asr/test_asr_rnnt_encdec_model.py
@@ -21,15 +21,10 @@
from nemo.collections.asr.models import EncDecRNNTModel
from nemo.collections.asr.parts import rnnt_beam_decoding as beam_decode
from nemo.collections.asr.parts import rnnt_greedy_decoding as greedy_decode
from nemo.collections.asr.parts.numba import __NUMBA_MINIMUM_VERSION__, numba_utils
from nemo.utils.config_utils import assert_dataclass_signature_match

- try:
- from warprnnt_pytorch import RNNTLoss
-
- WARP_RNNT_AVAILABLE = True
-
- except (ImportError, ModuleNotFoundError):
- WARP_RNNT_AVAILABLE = False
NUMBA_RNNT_LOSS_AVAILABLE = numba_utils.numba_cuda_is_supported(__NUMBA_MINIMUM_VERSION__)


@pytest.fixture()
@@ -99,9 +94,7 @@ def asr_model():

class TestEncDecRNNTModel:
@pytest.mark.skipif(
- not WARP_RNNT_AVAILABLE,
- reason='RNNTLoss has not been compiled. Please compile and install '
- 'RNNT Loss first before running this test',
not NUMBA_RNNT_LOSS_AVAILABLE, reason='RNNTLoss has not been compiled with an appropriate numba version.',
)
@pytest.mark.unit
def test_constructor(self, asr_model):
@@ -113,9 +106,7 @@ def test_constructor(self, asr_model):
assert isinstance(instance2, EncDecRNNTModel)

@pytest.mark.skipif(
- not WARP_RNNT_AVAILABLE,
- reason='RNNTLoss has not been compiled. Please compile and install '
- 'RNNT Loss first before running this test',
not NUMBA_RNNT_LOSS_AVAILABLE, reason='RNNTLoss has not been compiled with an appropriate numba version.',
)
@pytest.mark.unit
def test_forward(self, asr_model):
@@ -149,9 +140,7 @@ def test_forward(self, asr_model):
assert diff <= 1e-6

@pytest.mark.skipif(
- not WARP_RNNT_AVAILABLE,
- reason='RNNTLoss has not been compiled. Please compile and install '
- 'RNNT Loss first before running this test',
not NUMBA_RNNT_LOSS_AVAILABLE, reason='RNNTLoss has not been compiled with an appropriate numba version.',
)
@pytest.mark.unit
def test_vocab_change(self, asr_model):
@@ -172,9 +161,7 @@ def test_vocab_change(self, asr_model):
assert asr_model.num_weights == (nw1 + (pred_embedding + joint_joint))

@pytest.mark.skipif(
- not WARP_RNNT_AVAILABLE,
- reason='RNNTLoss has not been compiled. Please compile and install '
- 'RNNT Loss first before running this test',
not NUMBA_RNNT_LOSS_AVAILABLE, reason='RNNTLoss has not been compiled with an appropriate numba version.',
)
@pytest.mark.unit
def test_decoding_change(self, asr_model):
