From d904b60535c3652cd31293c0bb5f6f95190209c4 Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" <41898282+github-actions[bot]@users.noreply.github.com> Date: Fri, 7 Jul 2023 14:52:07 -0700 Subject: [PATCH] Hybrid conformer export (#6983) (#6995) * Implemented generic kv-pair setting of export_config from args * Hybrid conformer export * Hybrid decoder export * Cleanup * Changed from **kwargs * Docstring * Docs added * Stringify args * Added docs for ASR export configs * lowercase ctc --------- Signed-off-by: Boris Fomitchev Co-authored-by: Boris Fomitchev Signed-off-by: Gerald Shen --- docs/source/asr/models.rst | 10 ++++++ docs/source/core/export.rst | 31 +++++++++++++++++++ nemo/collections/asr/models/asr_model.py | 8 +++++ .../asr/models/hybrid_rnnt_ctc_models.py | 14 +++++++++ nemo/collections/asr/models/rnnt_models.py | 12 +++++-- nemo/core/classes/exportable.py | 14 +++++++++ scripts/export.py | 19 +++++++++--- 7 files changed, 102 insertions(+), 6 deletions(-) diff --git a/docs/source/asr/models.rst b/docs/source/asr/models.rst index 80a0fd90f0fbc..697a898271455 100644 --- a/docs/source/asr/models.rst +++ b/docs/source/asr/models.rst @@ -215,6 +215,11 @@ It is recommended to train a model in streaming model with limited context for t You may find FastConformer variants of cache-aware streaming models under ``/examples/asr/conf/fastconformer/``. +Note cache-aware streaming models are being exported without caching support by default. +To include caching support, `model.set_export_config({'cache_support' : 'True'})` should be called before export. +Or, if ``/scripts/export.py`` is being used: +`python export.py cache_aware_conformer.nemo cache_aware_conformer.onnx --config cache_support=True` + .. _LSTM-Transducer_model: LSTM-Transducer @@ -291,6 +296,11 @@ Similar example configs for FastConformer variants of Hybrid models can be found ``/examples/asr/conf/fastconformer/hybrid_transducer_ctc/`` ``/examples/asr/conf/fastconformer/hybrid_cache_aware_streaming/`` +Note Hybrid models are being exported as RNNT (encoder and decoder+joint parts) by default. +To export as CTC (single encoder+decoder graph), `model.set_export_config({'decoder_type' : 'ctc'})` should be called before export. +Or, if ``/scripts/export.py`` is being used: +`python export.py hybrid_transducer.nemo hybrid_transducer.onnx --config decoder_type=ctc` + .. _Conformer-HAT_model: Conformer-HAT (Hybrid Autoregressive Transducer) diff --git a/docs/source/core/export.rst b/docs/source/core/export.rst index 0e598e215dbfd..f54daffe9c9c5 100644 --- a/docs/source/core/export.rst +++ b/docs/source/core/export.rst @@ -177,6 +177,37 @@ Another common requirement for models that are being exported is to run certain # call base method for common set of modifications Exportable._prepare_for_export(self, **kwargs) +Some models that require control flow, need to be exported in multiple parts. Typical examples are RNNT nets. +To facilitate that, the hooks below are provided. To export, for example, 'encoder' and 'decoder' subnets of the model, overload list_export_subnets to return ['encoder', 'decoder']. + +.. code-block:: Python + + def get_export_subnet(self, subnet=None): + """ + Returns Exportable subnet model/module to export + """ + + + def list_export_subnets(self): + """ + Returns default set of subnet names exported for this model + First goes the one receiving input (input_example) + """ + +Some nertworks may be exported differently according to user-settable options (like ragged batch support for TTS or cache support for ASR). To facilitate that - `set_export_config()` method is provided by Exportable to set key/value pairs to predefined model.export_config dictionary, to be used during the export: + +.. code-block:: Python + def set_export_config(self, args): + """ + Sets/updates export_config dictionary + """ +Also, if an action hook on setting config is desired, this method may be overloaded by `Exportable` descendants to include one. +An example can be found in ``/nemo/collections/asr/models/rnnt_models.py``. + +Here is example on now `set_export_config()` call is being tied to command line arguments in ``/scripts/export.py`` : + +.. code-block:: Python + python scripts/export.py hybrid_conformer.nemo hybrid_conformer.onnx --config decoder_type=ctc Exportable Model Code ~~~~~~~~~~~~~~~~~~~~~ diff --git a/nemo/collections/asr/models/asr_model.py b/nemo/collections/asr/models/asr_model.py index 6ac3633201e2c..7e03d587139f1 100644 --- a/nemo/collections/asr/models/asr_model.py +++ b/nemo/collections/asr/models/asr_model.py @@ -215,3 +215,11 @@ def disabled_deployment_input_names(self): @property def disabled_deployment_output_names(self): return self.encoder.disabled_deployment_output_names + + def set_export_config(self, args): + if 'cache_support' in args: + enable = bool(args['cache_support']) + self.encoder.export_cache_support = enable + logging.info(f"Caching support enabled: {enable}") + self.encoder.setup_streaming_params() + super().set_export_config(args) diff --git a/nemo/collections/asr/models/hybrid_rnnt_ctc_models.py b/nemo/collections/asr/models/hybrid_rnnt_ctc_models.py index 5ca6124ecfd7c..11c616b1257f6 100644 --- a/nemo/collections/asr/models/hybrid_rnnt_ctc_models.py +++ b/nemo/collections/asr/models/hybrid_rnnt_ctc_models.py @@ -645,6 +645,20 @@ def multi_test_epoch_end(self, outputs, dataloader_idx: int = 0): self.finalize_interctc_metrics(metrics, outputs, prefix="test_") return metrics + # EncDecRNNTModel is exported in 2 parts + def list_export_subnets(self): + if self.cur_decoder == 'rnnt': + return ['encoder', 'decoder_joint'] + else: + return ['self'] + + @property + def output_module(self): + if self.cur_decoder == 'rnnt': + return self.decoder + else: + return self.ctc_decoder + @classmethod def list_available_models(cls) -> Optional[PretrainedModelInfo]: """ diff --git a/nemo/collections/asr/models/rnnt_models.py b/nemo/collections/asr/models/rnnt_models.py index 92bb04fd2a3ed..0c1da97c5012d 100644 --- a/nemo/collections/asr/models/rnnt_models.py +++ b/nemo/collections/asr/models/rnnt_models.py @@ -28,7 +28,7 @@ from nemo.collections.asr.data.audio_to_text_dali import AudioToCharDALIDataset, DALIOutputs from nemo.collections.asr.losses.rnnt import RNNTLoss, resolve_rnnt_default_loss_name from nemo.collections.asr.metrics.rnnt_wer import RNNTWER, RNNTDecoding, RNNTDecodingConfig -from nemo.collections.asr.models.asr_model import ASRModel +from nemo.collections.asr.models.asr_model import ASRModel, ExportableEncDecModel from nemo.collections.asr.modules.rnnt import RNNTDecoderJoint from nemo.collections.asr.parts.mixins import ASRModuleMixin from nemo.collections.asr.parts.utils.audio_utils import ChannelSelectorType @@ -39,7 +39,7 @@ from nemo.utils import logging -class EncDecRNNTModel(ASRModel, ASRModuleMixin, Exportable): +class EncDecRNNTModel(ASRModel, ASRModuleMixin, ExportableEncDecModel): """Base class for encoder decoder RNNT-based models.""" def __init__(self, cfg: DictConfig, trainer: Trainer = None): @@ -960,6 +960,14 @@ def list_export_subnets(self): def decoder_joint(self): return RNNTDecoderJoint(self.decoder, self.joint) + def set_export_config(self, args): + if 'decoder_type' in args: + if hasattr(self, 'change_decoding_strategy'): + self.change_decoding_strategy(decoder_type=args['decoder_type']) + else: + raise Exception("Model does not have decoder type option") + super().set_export_config(args) + @classmethod def list_available_models(cls) -> List[PretrainedModelInfo]: """ diff --git a/nemo/core/classes/exportable.py b/nemo/core/classes/exportable.py index 3d2682f2304e8..8469e80219d60 100644 --- a/nemo/core/classes/exportable.py +++ b/nemo/core/classes/exportable.py @@ -302,3 +302,17 @@ def list_export_subnets(self): First goes the one receiving input (input_example) """ return ['self'] + + def get_export_config(self): + """ + Returns export_config dictionary + """ + return getattr(self, 'export_config', {}) + + def set_export_config(self, args): + """ + Sets/updates export_config dictionary + """ + ex_config = self.get_export_config() + ex_config.update(args) + self.export_config = ex_config diff --git a/scripts/export.py b/scripts/export.py index fe3b79ebdf280..4b21bc4ffd734 100644 --- a/scripts/export.py +++ b/scripts/export.py @@ -62,6 +62,15 @@ def get_args(argv): ) parser.add_argument("--device", default="cuda", help="Device to export for") parser.add_argument("--check-tolerance", type=float, default=0.01, help="tolerance for verification") + parser.add_argument( + "--config", + metavar="KEY=VALUE", + nargs='+', + help="Set a number of key-value pairs to model.export_config dictionary " + "(do not put spaces before or after the = sign). " + "Note that values are always treated as strings.", + ) + args = parser.parse_args(argv) return args @@ -130,10 +139,12 @@ def nemo_export(argv): in_args["max_dim"] = args.max_dim max_dim = args.max_dim - if args.cache_support and hasattr(model, "encoder") and hasattr(model.encoder, "export_cache_support"): - model.encoder.export_cache_support = True - logging.info("Caching support is enabled.") - model.encoder.setup_streaming_params() + if args.cache_support: + model.set_export_config({"cache_support": "True"}) + + if args.config: + kv = dict(map(lambda s: s.split('='), args.config)) + model.set_export_config(kv) autocast = nullcontext if args.autocast: