From bd3f8833e38154952f391343980e5b060c5d0319 Mon Sep 17 00:00:00 2001 From: Somshubra Majumdar Date: Thu, 1 Apr 2021 18:43:16 -0700 Subject: [PATCH] Prepare RNNT to switch to Numba loss for compatibility (#1995) * Prepare RNNT to switch to Numba loss for compatibility Signed-off-by: smajumdar * Update tests for RNNT Signed-off-by: smajumdar * Address comments Signed-off-by: smajumdar * Address comments Signed-off-by: smajumdar * Fix wrong resolution of gradient calculation Signed-off-by: smajumdar * Fix wrong resolution of gradient calculation Signed-off-by: smajumdar * Drop WarpRNNT requirement Signed-off-by: smajumdar Signed-off-by: Virginia Adams --- Dockerfile | 2 +- Jenkinsfile | 39 +++++++++++ nemo/collections/asr/losses/rnnt.py | 10 ++- .../collections/asr/models/rnnt_bpe_models.py | 17 ----- nemo/collections/asr/models/rnnt_models.py | 70 +++++++++++++------ nemo/collections/asr/parts/numba/__init__.py | 6 ++ .../asr/parts/numba/rnnt_loss/rnnt.py | 1 + .../asr/parts/numba/rnnt_loss/rnnt_pytorch.py | 2 +- .../asr/test_asr_rnnt_encdec_model.py | 25 ++----- .../asr/test_asr_rnnt_encoder_model_bpe.py | 29 ++------ 10 files changed, 120 insertions(+), 81 deletions(-) diff --git a/Dockerfile b/Dockerfile index 8dc2750496317..370c9015c25a4 100644 --- a/Dockerfile +++ b/Dockerfile @@ -49,7 +49,7 @@ RUN git clone --branch v0.8.1 https://github.com/pytorch/text.git && \ cd text && \ git submodule update --init --recursive && \ python setup.py clean install && \ - cd .. && rm -r text + cd .. && rm -r text # build RNN-T loss WORKDIR /workspace/deps/rnnt diff --git a/Jenkinsfile b/Jenkinsfile index f7f3bd855d9cb..24863415ea19d 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -291,6 +291,45 @@ pipeline { } } +// TODO: UNCOMMENT TESTS AFTER 21.04 release (numba 0.53 min requirement) +// stage('L2: ASR RNNT dev run') { +// when { +// anyOf { +// branch 'main' +// changeRequest target: 'main' +// } +// } +// failFast true +// parallel { +// stage('Speech to Text - RNNT') { +// steps { +// sh 'python examples/asr/speech_to_text_rnnt.py \ +// model.train_ds.manifest_filepath=/home/TestData/an4_dataset/an4_train.json \ +// model.validation_ds.manifest_filepath=/home/TestData/an4_dataset/an4_val.json \ +// model.train_ds.batch_size=8 \ +// trainer.gpus=[0] \ +// +trainer.fast_dev_run=True \ +// exp_manager.exp_dir=examples/asr/speech_to_text_rnnt_results' +// sh 'rm -rf examples/asr/speech_to_text_rnnt_results' +// } +// } +// stage('L2: Speech to Text RNNT WPE') { +// steps { +// sh 'python examples/asr/speech_to_text_rnnt_bpe.py \ +// --config-path="experimental/contextnet_rnnt/" --config-name="config_rnnt_bpe.yaml" \ +// model.train_ds.manifest_filepath=/home/TestData/an4_dataset/an4_train.json \ +// model.validation_ds.manifest_filepath=/home/TestData/an4_dataset/an4_val.json \ +// model.tokenizer.dir="/home/TestData/asr_tokenizers/an4_wpe_128/" \ +// model.tokenizer.type="wpe" \ +// trainer.gpus=[0] \ +// +trainer.fast_dev_run=True \ +// exp_manager.exp_dir=examples/asr/speech_to_text_rnnt_wpe_results' +// sh 'rm -rf examples/asr/speech_to_text_rnnt_wpe_results' +// } +// } +// } +// } + stage('L2: ASR Multi-dataloader dev run') { when { anyOf { diff --git a/nemo/collections/asr/losses/rnnt.py b/nemo/collections/asr/losses/rnnt.py index 499119a54bcab..7e5a7b7c2b6f7 100644 --- a/nemo/collections/asr/losses/rnnt.py +++ b/nemo/collections/asr/losses/rnnt.py @@ -96,7 +96,7 @@ class RNNTLossConfig: ), } -RNNT_LOSS_RESOLVER['default'] = RNNT_LOSS_RESOLVER['warprnnt'] +RNNT_LOSS_RESOLVER['default'] 
= RNNT_LOSS_RESOLVER['warprnnt_numba'] def _warn_unused_additional_kwargs(loss_name, kwargs): @@ -108,6 +108,10 @@ def _warn_unused_additional_kwargs(loss_name, kwargs): ) +def resolve_rnnt_default_loss_name() -> str: + return RNNT_LOSS_RESOLVER['default'].loss_name + + def resolve_rnnt_loss(loss_name: str, blank_idx: int, loss_kwargs: dict = None) -> torch.nn.Module: loss_function_names = list(RNNT_LOSS_RESOLVER.keys()) @@ -152,6 +156,9 @@ def resolve_rnnt_loss(loss_name: str, blank_idx: int, loss_kwargs: dict = None) if loss_name == 'default': loss_name = loss_config.loss_name + """ + Resolve RNNT loss functions + """ if loss_name == 'warprnnt': loss_func = warprnnt.RNNTLoss(blank=blank_idx, reduction='none') _warn_unused_additional_kwargs(loss_name, loss_kwargs) @@ -244,6 +251,7 @@ def forward(self, log_probs, targets, input_lengths, target_lengths): max_targets_len = target_lengths.max() # Force cast joint to float32 + # TODO: Remove once Numba supports FP16 if log_probs.dtype != torch.float32: logits_orig = log_probs log_probs = log_probs.float() diff --git a/nemo/collections/asr/models/rnnt_bpe_models.py b/nemo/collections/asr/models/rnnt_bpe_models.py index fc2c34e2a39c5..27a2e62757e2c 100644 --- a/nemo/collections/asr/models/rnnt_bpe_models.py +++ b/nemo/collections/asr/models/rnnt_bpe_models.py @@ -29,13 +29,6 @@ from nemo.core.classes.common import PretrainedModelInfo from nemo.utils import logging, model_utils -try: - import warprnnt_pytorch as warprnnt - - WARP_RNNT_AVAILABLE = True -except (ImportError, ModuleNotFoundError): - WARP_RNNT_AVAILABLE = False - class EncDecRNNTBPEModel(EncDecRNNTModel, ASRBPEMixin): """Base class for encoder decoder RNNT-based models with subword tokenization.""" @@ -52,16 +45,6 @@ def list_available_models(cls) -> Optional[PretrainedModelInfo]: return result def __init__(self, cfg: DictConfig, trainer: Trainer = None): - # Required loss function - if not WARP_RNNT_AVAILABLE: - raise ImportError( - "Could not import `warprnnt_pytorch`.\n" - "Please visit https://github.com/HawkAaron/warp-transducer " - "and follow the steps in the readme to build and install the " - "pytorch bindings for RNNT Loss, or use the provided docker " - "container that supports RNN-T loss." 
- ) - # Convert to Hydra 1.0 compatible DictConfig cfg = model_utils.convert_model_config_to_dict_config(cfg) cfg = model_utils.maybe_update_config_version(cfg) diff --git a/nemo/collections/asr/models/rnnt_models.py b/nemo/collections/asr/models/rnnt_models.py index 240bba71ea47f..c9e665db0ea13 100644 --- a/nemo/collections/asr/models/rnnt_models.py +++ b/nemo/collections/asr/models/rnnt_models.py @@ -26,7 +26,7 @@ from nemo.collections.asr.data import audio_to_text_dataset from nemo.collections.asr.data.audio_to_text_dali import DALIOutputs -from nemo.collections.asr.losses.rnnt import RNNTLoss +from nemo.collections.asr.losses.rnnt import RNNTLoss, resolve_rnnt_default_loss_name from nemo.collections.asr.metrics.rnnt_wer import RNNTWER, RNNTDecoding from nemo.collections.asr.models.asr_model import ASRModel from nemo.collections.asr.parts.perturb import process_augmentations @@ -34,13 +34,6 @@ from nemo.core.neural_types import AcousticEncodedRepresentation, AudioSignal, LengthsType, NeuralType, SpectrogramType from nemo.utils import logging -try: - import warprnnt_pytorch as warprnnt - - WARP_RNNT_AVAILABLE = True -except (ImportError, ModuleNotFoundError): - WARP_RNNT_AVAILABLE = False - class EncDecRNNTModel(ASRModel): """Base class for encoder decoder RNNT-based models.""" @@ -57,16 +50,6 @@ def list_available_models(cls) -> Optional[PretrainedModelInfo]: return result def __init__(self, cfg: DictConfig, trainer: Trainer = None): - # Required loss function - if not WARP_RNNT_AVAILABLE: - raise ImportError( - "Could not import `warprnnt_pytorch`.\n" - "Please visit https://github.com/HawkAaron/warp-transducer " - "and follow the steps in the readme to build and install the " - "pytorch bindings for RNNT Loss, or use the provided docker " - "container that supports RNN-T loss." - ) - # Get global rank and total number of GPU workers for IterableDataset partitioning, if applicable # Global_rank and local_rank is set by LightningModule in Lightning 1.2.0 self.world_size = 1 @@ -91,7 +74,13 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None): self.decoder = EncDecRNNTModel.from_config_dict(self.cfg.decoder) self.joint = EncDecRNNTModel.from_config_dict(self.cfg.joint) - self.loss = RNNTLoss(num_classes=self.joint.num_classes_with_blank - 1) + + # Setup RNNT Loss + loss_name, loss_kwargs = self.extract_rnnt_loss_cfg(self.cfg.get("loss", None)) + + self.loss = RNNTLoss( + num_classes=self.joint.num_classes_with_blank - 1, loss_name=loss_name, loss_kwargs=loss_kwargs + ) if hasattr(self.cfg, 'spec_augment') and self._cfg.spec_augment is not None: self.spec_augmentation = EncDecRNNTModel.from_config_dict(self.cfg.spec_augment) @@ -130,6 +119,44 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None): self._optim_variational_noise_std = 0 self._optim_variational_noise_start = 0 + def extract_rnnt_loss_cfg(self, cfg: Optional[DictConfig]): + """ + Helper method to extract the rnnt loss name, and potentially its kwargs + to be passed. + + Args: + cfg: Should contain `loss_name` as a string which is resolved to a RNNT loss name. + If the default should be used, then `default` can be used. + Optionally, one can pass additional kwargs to the loss function. The subdict + should have a keyname as follows : `{loss_name}_kwargs`. + + Note that whichever loss_name is selected, that corresponding kwargs will be + selected. For the "default" case, the "{resolved_default}_kwargs" will be used. + + Examples: + .. 
code-block:: yaml + loss_name: "default" + + warprnnt_numba_kwargs: + kwargs2: some_other_val + + Returns: + A tuple, the resolved loss name as well as its kwargs (if found). + """ + if cfg is None: + cfg = DictConfig({}) + + loss_name = cfg.get("loss_name", "default") + + if loss_name == "default": + loss_name = resolve_rnnt_default_loss_name() + + loss_kwargs = cfg.get(f"{loss_name}_kwargs", None) + + logging.info(f"Using RNNT Loss : {loss_name}\n" f"Loss {loss_name}_kwargs: {loss_kwargs}") + + return loss_name, loss_kwargs + @torch.no_grad() def transcribe( self, paths2audio_files: List[str], batch_size: int = 4, return_hypotheses: bool = False @@ -231,7 +258,10 @@ def change_vocabulary(self, new_vocabulary: List[str], decoding_cfg: Optional[Di self.decoder = EncDecRNNTModel.from_config_dict(new_decoder_config) del self.loss - self.loss = RNNTLoss(num_classes=self.joint.num_classes_with_blank - 1) + loss_name, loss_kwargs = self.extract_rnnt_loss_cfg(self.cfg.get('loss', None)) + self.loss = RNNTLoss( + num_classes=self.joint.num_classes_with_blank - 1, loss_name=loss_name, loss_kwargs=loss_kwargs + ) if decoding_cfg is None: # Assume same decoding config as before diff --git a/nemo/collections/asr/parts/numba/__init__.py b/nemo/collections/asr/parts/numba/__init__.py index e77d166fabbf7..e72a10034355d 100644 --- a/nemo/collections/asr/parts/numba/__init__.py +++ b/nemo/collections/asr/parts/numba/__init__.py @@ -12,7 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. +import logging + from nemo.collections.asr.parts.numba.numba_utils import numba_cuda_is_supported from nemo.collections.asr.parts.numba.rnnt_loss.rnnt_pytorch import RNNTLossNumba +# Prevent Numba CUDA logs from showing at info level +cuda_logger = logging.getLogger('numba.cuda.cudadrv.driver') +cuda_logger.setLevel(logging.ERROR) # only show error + __NUMBA_MINIMUM_VERSION__ = "0.53.0" diff --git a/nemo/collections/asr/parts/numba/rnnt_loss/rnnt.py b/nemo/collections/asr/parts/numba/rnnt_loss/rnnt.py index e05d5e8df0a13..038713f30a911 100644 --- a/nemo/collections/asr/parts/numba/rnnt_loss/rnnt.py +++ b/nemo/collections/asr/parts/numba/rnnt_loss/rnnt.py @@ -172,6 +172,7 @@ def rnnt_loss_gpu( if status != global_constants.RNNTStatus.RNNT_STATUS_SUCCESS: raise RuntimeError("Invalid parameter passed when calculating working space memory") + # Select GPU index cuda.select_device(acts.device.index) gpu_workspace = torch.zeros(gpu_size, device=acts.device, dtype=acts.dtype, requires_grad=False) diff --git a/nemo/collections/asr/parts/numba/rnnt_loss/rnnt_pytorch.py b/nemo/collections/asr/parts/numba/rnnt_loss/rnnt_pytorch.py index f25987b621fea..016f2dce9d8a4 100644 --- a/nemo/collections/asr/parts/numba/rnnt_loss/rnnt_pytorch.py +++ b/nemo/collections/asr/parts/numba/rnnt_loss/rnnt_pytorch.py @@ -50,7 +50,7 @@ def forward(ctx, acts, labels, act_lens, label_lens, blank, reduction): certify_inputs(acts, labels, act_lens, label_lens) loss_func = rnnt.rnnt_loss_gpu if is_cuda else rnnt.rnnt_loss_cpu - grads = torch.zeros_like(acts) if acts.requires_grad else torch.zeros(0).to(acts) + grads = torch.zeros_like(acts) if acts.requires_grad else None minibatch_size = acts.size(0) costs = torch.zeros(minibatch_size, device=acts.device, dtype=acts.dtype) diff --git a/tests/collections/asr/test_asr_rnnt_encdec_model.py b/tests/collections/asr/test_asr_rnnt_encdec_model.py index 81f51e3544025..6a2e4eb76d88c 100644 --- a/tests/collections/asr/test_asr_rnnt_encdec_model.py +++ 
b/tests/collections/asr/test_asr_rnnt_encdec_model.py @@ -21,15 +21,10 @@ from nemo.collections.asr.models import EncDecRNNTModel from nemo.collections.asr.parts import rnnt_beam_decoding as beam_decode from nemo.collections.asr.parts import rnnt_greedy_decoding as greedy_decode +from nemo.collections.asr.parts.numba import __NUMBA_MINIMUM_VERSION__, numba_utils from nemo.utils.config_utils import assert_dataclass_signature_match -try: - from warprnnt_pytorch import RNNTLoss - - WARP_RNNT_AVAILABLE = True - -except (ImportError, ModuleNotFoundError): - WARP_RNNT_AVAILABLE = False +NUMBA_RNNT_LOSS_AVAILABLE = numba_utils.numba_cuda_is_supported(__NUMBA_MINIMUM_VERSION__) @pytest.fixture() @@ -99,9 +94,7 @@ def asr_model(): class TestEncDecRNNTModel: @pytest.mark.skipif( - not WARP_RNNT_AVAILABLE, - reason='RNNTLoss has not been compiled. Please compile and install ' - 'RNNT Loss first before running this test', + not NUMBA_RNNT_LOSS_AVAILABLE, reason='RNNTLoss has not been compiled with appropriate numba version.', ) @pytest.mark.unit def test_constructor(self, asr_model): @@ -113,9 +106,7 @@ def test_constructor(self, asr_model): assert isinstance(instance2, EncDecRNNTModel) @pytest.mark.skipif( - not WARP_RNNT_AVAILABLE, - reason='RNNTLoss has not been compiled. Please compile and install ' - 'RNNT Loss first before running this test', + not NUMBA_RNNT_LOSS_AVAILABLE, reason='RNNTLoss has not been compiled with appropriate numba version.', ) @pytest.mark.unit def test_forward(self, asr_model): @@ -149,9 +140,7 @@ def test_forward(self, asr_model): assert diff <= 1e-6 @pytest.mark.skipif( - not WARP_RNNT_AVAILABLE, - reason='RNNTLoss has not been compiled. Please compile and install ' - 'RNNT Loss first before running this test', + not NUMBA_RNNT_LOSS_AVAILABLE, reason='RNNTLoss has not been compiled with appropriate numba version.', ) @pytest.mark.unit def test_vocab_change(self, asr_model): @@ -172,9 +161,7 @@ def test_vocab_change(self, asr_model): assert asr_model.num_weights == (nw1 + (pred_embedding + joint_joint)) @pytest.mark.skipif( - not WARP_RNNT_AVAILABLE, - reason='RNNTLoss has not been compiled. Please compile and install ' - 'RNNT Loss first before running this test', + not NUMBA_RNNT_LOSS_AVAILABLE, reason='RNNTLoss has not been compiled with appropriate numba version.', ) @pytest.mark.unit def test_decoding_change(self, asr_model): diff --git a/tests/collections/asr/test_asr_rnnt_encoder_model_bpe.py b/tests/collections/asr/test_asr_rnnt_encoder_model_bpe.py index f5813eb69dad6..e0a5004ba7dec 100644 --- a/tests/collections/asr/test_asr_rnnt_encoder_model_bpe.py +++ b/tests/collections/asr/test_asr_rnnt_encoder_model_bpe.py @@ -24,14 +24,9 @@ from nemo.collections.asr.models.rnnt_bpe_models import EncDecRNNTBPEModel from nemo.collections.asr.parts import rnnt_beam_decoding as beam_decode from nemo.collections.asr.parts import rnnt_greedy_decoding as greedy_decode +from nemo.collections.asr.parts.numba import __NUMBA_MINIMUM_VERSION__, numba_utils -try: - from warprnnt_pytorch import RNNTLoss - - WARP_RNNT_AVAILABLE = True - -except (ImportError, ModuleNotFoundError): - WARP_RNNT_AVAILABLE = False +NUMBA_RNNT_LOSS_AVAILABLE = numba_utils.numba_cuda_is_supported(__NUMBA_MINIMUM_VERSION__) @pytest.fixture() @@ -95,9 +90,7 @@ def asr_model(test_data_dir): class TestEncDecRNNTBPEModel: @pytest.mark.skipif( - not WARP_RNNT_AVAILABLE, - reason='RNNTLoss has not been compiled. 
Please compile and install ' - 'RNNT Loss first before running this test', + not NUMBA_RNNT_LOSS_AVAILABLE, reason='RNNTLoss has not been compiled with appropriate numba version.', ) @pytest.mark.unit def test_constructor(self, asr_model): @@ -109,9 +102,7 @@ def test_constructor(self, asr_model): assert isinstance(instance2, EncDecRNNTBPEModel) @pytest.mark.skipif( - not WARP_RNNT_AVAILABLE, - reason='RNNTLoss has not been compiled. Please compile and install ' - 'RNNT Loss first before running this test', + not NUMBA_RNNT_LOSS_AVAILABLE, reason='RNNTLoss has not been compiled with appropriate numba version.', ) @pytest.mark.unit def test_forward(self, asr_model): @@ -145,9 +136,7 @@ def test_forward(self, asr_model): assert diff <= 1e-6 @pytest.mark.skipif( - not WARP_RNNT_AVAILABLE, - reason='RNNTLoss has not been compiled. Please compile and install ' - 'RNNT Loss first before running this test', + not NUMBA_RNNT_LOSS_AVAILABLE, reason='RNNTLoss has not been compiled with appropriate numba version.', ) @pytest.mark.unit def test_save_restore_artifact(self, asr_model): @@ -164,9 +153,7 @@ def test_save_restore_artifact(self, asr_model): assert len(new_model.tokenizer.tokenizer.get_vocab()) == 128 @pytest.mark.skipif( - not WARP_RNNT_AVAILABLE, - reason='RNNTLoss has not been compiled. Please compile and install ' - 'RNNT Loss first before running this test', + not NUMBA_RNNT_LOSS_AVAILABLE, reason='RNNTLoss has not been compiled with appropriate numba version.', ) @pytest.mark.unit def test_vocab_change(self, test_data_dir, asr_model): @@ -195,9 +182,7 @@ def test_vocab_change(self, test_data_dir, asr_model): assert asr_model.num_weights == (nw1 + (pred_embedding + joint_joint)) @pytest.mark.skipif( - not WARP_RNNT_AVAILABLE, - reason='RNNTLoss has not been compiled. Please compile and install ' - 'RNNT Loss first before running this test', + not NUMBA_RNNT_LOSS_AVAILABLE, reason='RNNTLoss has not been compiled with appropriate numba version.', ) @pytest.mark.unit def test_decoding_change(self, asr_model):
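
---

Note (not part of the patch above): a minimal sketch of how loss selection is expected to behave after this change, assuming an environment with numba >= 0.53 installed. It mirrors the config convention described in the `extract_rnnt_loss_cfg` docstring; the `num_classes` value and the config contents are illustrative only, not taken from the patch.

```python
# Sketch only: illustrates the intended resolution path of the Numba-backed
# RNNT loss after this patch. Assumes numba >= 0.53 is available.
from omegaconf import DictConfig

from nemo.collections.asr.losses.rnnt import RNNTLoss, resolve_rnnt_default_loss_name

# 'default' now resolves to the Numba implementation rather than warprnnt_pytorch.
print(resolve_rnnt_default_loss_name())  # expected: 'warprnnt_numba'

# A model config selects the loss by name and may pass per-loss kwargs via
# a `{loss_name}_kwargs` sub-dict, mirroring EncDecRNNTModel.extract_rnnt_loss_cfg().
loss_cfg = DictConfig({"loss_name": "default"})
loss_name = loss_cfg.get("loss_name", "default")
if loss_name == "default":
    loss_name = resolve_rnnt_default_loss_name()
loss_kwargs = loss_cfg.get(f"{loss_name}_kwargs", None)

# num_classes is illustrative (e.g. a small character vocabulary).
loss = RNNTLoss(num_classes=28, loss_name=loss_name, loss_kwargs=loss_kwargs)
```

If a `{loss_name}_kwargs` sub-dict is present it is forwarded to the resolved loss; otherwise the loss falls back to its defaults, and any unused kwargs are reported by `_warn_unused_additional_kwargs`.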