diff --git a/CHANGELOG.md b/CHANGELOG.md index 2eac5e58c3af..d56bfdf4c471 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -94,6 +94,8 @@ To release a new version, please update the changelog as followed: - Updated licenses - Updated nemo's use of the logging library. from nemo import logging is now the reccomended way of using the nemo logger. neural_factory.logger and all other instances of logger are now deprecated and planned for removal in the next version. Please see PR 267 for complete change information. ([PR #267](https://github.com/NVIDIA/NeMo/pull/267), [PR #283](https://github.com/NVIDIA/NeMo/pull/283), [PR #305](https://github.com/NVIDIA/NeMo/pull/305), [PR #311](https://github.com/NVIDIA/NeMo/pull/311)) - @blisc +- Changed Distributed Data Parallel from Apex to Torch +([PR #336](https://github.com/NVIDIA/NeMo/pull/336)) - @blisc - Added TRADE (dialogue state tracking model) on MultiWOZ dataset ([PR #322](https://github.com/NVIDIA/NeMo/pull/322)) - @chiphuyen, @VahidooX @@ -108,6 +110,8 @@ To release a new version, please update the changelog as followed: ([PR #308](https://github.com/NVIDIA/NeMo/pull/309)) - @tkornuta-nvidia ### Removed +- gradient_predivide_factor arg of train() now has no effect +([PR #336](https://github.com/NVIDIA/NeMo/pull/336)) - @blisc - Dropped support of the following ASR configs: jasper10x4.yaml, quartznet10x5.yaml, quartznet15x5_in.yaml, quartznet5x3.yaml, quartznet5x5.yaml, quartznet_an4.yaml. They are moved to experimental/configs and can still be used with v0.9 for use in replicating paper results ([PR #354](https://github.com/NVIDIA/NeMo/pull/354)) - @blisc diff --git a/examples/nlp/asr_postprocessor/asr_postprocessor.py b/examples/nlp/asr_postprocessor/asr_postprocessor.py index 204e9db5664f..187529ddd2e4 100644 --- a/examples/nlp/asr_postprocessor/asr_postprocessor.py +++ b/examples/nlp/asr_postprocessor/asr_postprocessor.py @@ -26,6 +26,7 @@ eval_epochs_done_callback_wer, eval_iter_callback, ) +from nemo.core import WeightShareTransform from nemo.core.callbacks import CheckpointCallback from nemo.utils.lr_policies import SquareAnnealing @@ -126,9 +127,33 @@ ) # tie all embeddings weights -t_log_softmax.mlp.layer0.weight = encoder.bert.embeddings.word_embeddings.weight -decoder.embedding_layer.token_embedding.weight = encoder.bert.embeddings.word_embeddings.weight -decoder.embedding_layer.position_embedding.weight = encoder.bert.embeddings.position_embeddings.weight +# t_log_softmax.mlp.layer0.weight = encoder.bert.embeddings.word_embeddings.weight +# decoder.embedding_layer.token_embedding.weight = encoder.bert.embeddings.word_embeddings.weight +# decoder.embedding_layer.position_embedding.weight = encoder.bert.embeddings.position_embeddings.weight +t_log_softmax.tie_weights_with( + encoder, + weight_names=["mlp.layer0.weight"], + name2name_and_transform={ + "mlp.layer0.weight": ("bert.embeddings.word_embeddings.weight", WeightShareTransform.SAME) + }, +) +decoder.tie_weights_with( + encoder, + weight_names=["embedding_layer.token_embedding.weight"], + name2name_and_transform={ + "embedding_layer.token_embedding.weight": ("bert.embeddings.word_embeddings.weight", WeightShareTransform.SAME) + }, +) +decoder.tie_weights_with( + encoder, + weight_names=["embedding_layer.position_embedding.weight"], + name2name_and_transform={ + "embedding_layer.position_embedding.weight": ( + "bert.embeddings.position_embeddings.weight", + WeightShareTransform.SAME, + ) + }, +) def 
create_pipeline(dataset, tokens_in_batch, clean=False, training=True): diff --git a/examples/nlp/language_modeling/bert_pretraining.py b/examples/nlp/language_modeling/bert_pretraining.py index 7ca3871b7d32..857556b89b07 100644 --- a/examples/nlp/language_modeling/bert_pretraining.py +++ b/examples/nlp/language_modeling/bert_pretraining.py @@ -13,7 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================= - """ To pretrain BERT on raw text dataset run @@ -224,7 +223,14 @@ # tie weights of MLM softmax layer and embedding layer of the encoder if mlm_classifier.mlp.last_linear_layer.weight.shape != bert_model.bert.embeddings.word_embeddings.weight.shape: raise ValueError("Final classification layer does not match embedding " "layer.") -mlm_classifier.mlp.last_linear_layer.weight = bert_model.bert.embeddings.word_embeddings.weight +# mlm_classifier.mlp.last_linear_layer.weight = bert_model.bert.embeddings.word_embeddings.weight +mlm_classifier.tie_weights_with( + bert_model, + weight_names=["mlp.last_linear_layer.weight"], + name2name_and_transform={ + "mlp.last_linear_layer.weight": ("bert.embeddings.word_embeddings.weight", nemo_core.WeightShareTransform.SAME) + }, +) def create_pipeline(data_file, batch_size, preprocessed_data=False, batches_per_step=1, **kwargs): diff --git a/examples/nlp/language_modeling/language_modeling_transformer.py b/examples/nlp/language_modeling/language_modeling_transformer.py index d49040949538..2572b90af785 100644 --- a/examples/nlp/language_modeling/language_modeling_transformer.py +++ b/examples/nlp/language_modeling/language_modeling_transformer.py @@ -13,7 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# ============================================================================= - import math import nemo @@ -22,6 +21,7 @@ import nemo.collections.nlp.nm.trainables.common.token_classification_nm from nemo.collections.nlp.callbacks.lm_transformer_callback import eval_epochs_done_callback, eval_iter_callback from nemo.collections.nlp.data.datasets.lm_transformer_dataset import LanguageModelDataDesc +from nemo.core import WeightShareTransform from nemo.utils.lr_policies import CosineAnnealing parser = nemo.utils.NemoArgParser(description='LM Transformer') @@ -114,7 +114,14 @@ ) # tie weight of embedding and log_softmax layers -log_softmax.mlp.last_linear_layer.weight = encoder.embedding_layer.token_embedding.weight +# log_softmax.mlp.last_linear_layer.weight = encoder.embedding_layer.token_embedding.weight +log_softmax.tie_weights_with( + encoder, + weight_names=["mlp.layer0.weight"], + name2name_and_transform={ + "mlp.layer0.weight": ("embedding_layer.token_embedding.weight", WeightShareTransform.SAME) + }, +) def create_pipeline( diff --git a/examples/nlp/neural_machine_translation/machine_translation_tutorial.py b/examples/nlp/neural_machine_translation/machine_translation_tutorial.py index 8cda90810521..ae05afa88e32 100644 --- a/examples/nlp/neural_machine_translation/machine_translation_tutorial.py +++ b/examples/nlp/neural_machine_translation/machine_translation_tutorial.py @@ -24,6 +24,7 @@ import nemo import nemo.collections.nlp as nemo_nlp from nemo.collections.nlp.callbacks.machine_translation_callback import eval_epochs_done_callback, eval_iter_callback +from nemo.core import WeightShareTransform from nemo.utils.lr_policies import get_lr_policy parser = nemo.utils.NemoArgParser(description='Transformer for Neural Machine Translation') @@ -165,8 +166,25 @@ ) if tie_weight: - log_softmax.mlp.last_linear_layer.weight = encoder.embedding_layer.token_embedding.weight - decoder.embedding_layer.token_embedding.weight = encoder.embedding_layer.token_embedding.weight + # log_softmax.mlp.last_linear_layer.weight = encoder.embedding_layer.token_embedding.weight + log_softmax.tie_weights_with( + encoder, + weight_names=["mlp.last_linear_layer.weight"], + name2name_and_transform={ + "mlp.last_linear_layer.weight": ("embedding_layer.token_embedding.weight", WeightShareTransform.SAME) + }, + ) + # decoder.embedding_layer.token_embedding.weight = encoder.embedding_layer.token_embedding.weight + decoder.tie_weights_with( + encoder, + weight_names=["embedding_layer.token_embedding.weight"], + name2name_and_transform={ + "embedding_layer.token_embedding.weight": ( + "embedding_layer.token_embedding.weight", + WeightShareTransform.SAME, + ) + }, + ) def create_pipeline(dataset_src, dataset_tgt, tokens_in_batch, clean=False, training=True): diff --git a/nemo/backends/pytorch/actions.py b/nemo/backends/pytorch/actions.py index 0516d08d68ce..95ed41e553cf 100644 --- a/nemo/backends/pytorch/actions.py +++ b/nemo/backends/pytorch/actions.py @@ -5,13 +5,15 @@ import json import os from collections import defaultdict +from contextlib import ExitStack from pathlib import Path -from typing import Dict, List, Optional +from typing import List, Optional import torch import torch.distributed as dist import torch.nn as nn import torch.optim as optim +from torch.nn.parallel import DistributedDataParallel as DDP from nemo import logging from nemo.backends.pytorch.module_wrapper import TrainableNeuralModuleWrapper @@ -25,9 +27,8 @@ # these imports will happen on as-needed basis amp = None -convert_syncbn = None 
-create_syncbn_process_group = None -DDP = None +# convert_syncbn = None +# create_syncbn_process_group = None LARC = None FusedLAMB = None FusedAdam = None @@ -59,18 +60,16 @@ def __init__( global amp amp = importlib.import_module('apex.amp') if local_rank is not None: - global convert_syncbn - global create_syncbn_process_group - global DDP + # global convert_syncbn + # global create_syncbn_process_group global LARC global FusedLAMB global FusedAdam global FusedNovoGrad parallel = importlib.import_module('apex.parallel') apex_optimizer = importlib.import_module('apex.optimizers') - convert_syncbn = parallel.convert_syncbn_model - create_syncbn_process_group = parallel.create_syncbn_process_group - DDP = parallel.DistributedDataParallel + # convert_syncbn = parallel.convert_syncbn_model + # create_syncbn_process_group = parallel.create_syncbn_process_group LARC = parallel.LARC FusedLAMB = apex_optimizer.FusedLAMB FusedAdam = apex_optimizer.FusedAdam @@ -379,7 +378,7 @@ def __initialize_amp( return optimizer def __nm_graph_forward_pass( - self, call_chain, registered_tensors, mode=ModelMode.train, disable_allreduce=False, use_cache=False, + self, call_chain, registered_tensors, mode=ModelMode.train, use_cache=False, ): for ind in range(1, len(call_chain)): if use_cache: @@ -399,12 +398,12 @@ def __nm_graph_forward_pass( m_id = call_chain[ind][0].unique_instance_id pmodule = self.module_reference_table[m_id][1] - if self._local_rank is not None: - if isinstance(pmodule, DDP): - if disable_allreduce: - pmodule.disable_allreduce() - else: - pmodule.enable_allreduce() + # if self._local_rank is not None: + # if isinstance(pmodule, DDP): + # if disable_allreduce: + # pmodule.disable_allreduce() + # else: + # pmodule.enable_allreduce() if mode == ModelMode.train: # if module.is_trainable(): @@ -935,9 +934,8 @@ def __extract_dynamic_axes(port_name: str, ntype: NeuralType, dynamic_axes: defa outputs_to_drop = set() if type(module).__name__ == "JasperEncoder": logging.info( - f"Module is JasperEncoder. We are removing" - f"input and output length ports since they " - f"are not needed for deployment" + "Module is JasperEncoder. We are removing input and output length ports since they are not needed for " + "deployment" ) inputs_to_drop.add("length") outputs_to_drop.add("encoded_lengths") @@ -1072,6 +1070,11 @@ def train( gradient_predivide=False, amp_max_loss_scale=2.0 ** 24, ): + if gradient_predivide: + logging.error( + "gradient_predivide is currently disabled, and is under consideration for removal in future versions. " + "If this functionality is needed, please raise a github issue." + ) if not optimization_params: optimization_params = {} num_epochs = optimization_params.get("num_epochs", None) @@ -1213,23 +1216,44 @@ def train( key = call_chain[i][0].unique_instance_id pmodule = self.module_reference_table[key][1] if not isinstance(pmodule, DDP) and isinstance(pmodule, torch.nn.Module): - gpf = 1 - if gradient_predivide: - gpf = dist.get_world_size() - pmodule = DDP(pmodule, gradient_predivide_factor=gpf) - - # Convert batchnorm modules to synced if applicable - if synced_batchnorm and isinstance(pmodule, torch.nn.Module): - world_size = dist.get_world_size() - if synced_batchnorm_groupsize > 0 and world_size % synced_batchnorm_groupsize != 0: - raise ValueError( - f"Synchronized batch norm group size" - f" ({synced_batchnorm_groupsize}) must be 0" - f" or divide total number of GPUs" - f" ({world_size})." 
+ # gpf = 1 + # if gradient_predivide: + # gpf = dist.get_world_size() + # pmodule = DDP(pmodule, gradient_predivide_factor=gpf) # Old Apex Method + + # Per pytorch docs, convert sync bn prior to DDP + if synced_batchnorm: + world_size = dist.get_world_size() + sync_batchnorm_group = None + if synced_batchnorm_groupsize > 0: + if world_size % synced_batchnorm_groupsize != 0: + raise ValueError( + f"Synchronized batch norm group size ({synced_batchnorm_groupsize}) must be 0" + f" or divide total number of GPUs ({world_size})." + ) + sync_batchnorm_group = torch.distributed.new_group(synced_batchnorm_groupsize) + pmodule = nn.SyncBatchNorm.convert_sync_batchnorm( + pmodule, process_group=sync_batchnorm_group ) - process_group = create_syncbn_process_group(synced_batchnorm_groupsize) - pmodule = convert_syncbn(pmodule, process_group=process_group) + + # By default, disable broadcast_buffers. This disables batch norm synchronization on forward + # pass + pmodule = DDP( + pmodule, device_ids=[self.local_rank], broadcast_buffers=False, find_unused_parameters=True + ) + + # # Convert batchnorm modules to synced if applicable + # if synced_batchnorm and isinstance(pmodule, torch.nn.Module): + # world_size = dist.get_world_size() + # if synced_batchnorm_groupsize > 0 and world_size % synced_batchnorm_groupsize != 0: + # raise ValueError( + # f"Synchronized batch norm group size" + # f" ({synced_batchnorm_groupsize}) must be 0" + # f" or divide total number of GPUs" + # f" ({world_size})." + # ) + # process_group = create_syncbn_process_group(synced_batchnorm_groupsize) + # pmodule = convert_syncbn(pmodule, process_group=process_group) self.module_reference_table[key] = ( self.module_reference_table[key][0], @@ -1308,9 +1332,7 @@ def train( } disable_allreduce = batch_counter < (batches_per_step - 1) self.__nm_graph_forward_pass( - call_chain=curr_call_chain, - registered_tensors=registered_tensors, - disable_allreduce=disable_allreduce, + call_chain=curr_call_chain, registered_tensors=registered_tensors, ) curr_tensors_to_optimize = training_loop[self.step % len(training_loop)][1] @@ -1331,19 +1353,31 @@ def train( if nan: continue if self._optim_level in AmpOptimizations and self._optim_level != Optimization.mxprO0: - with amp.scale_loss(final_loss, curr_optimizer, delay_unscale=disable_allreduce,) as scaled_loss: + with amp.scale_loss(final_loss, curr_optimizer, delay_unscale=disable_allreduce) as scaled_loss: if torch.isnan(scaled_loss).any() or torch.isinf(scaled_loss).any(): if stop_on_nan_loss: raise ValueError('Loss is NaN or inf -' ' exiting') logging.warning('WARNING: Loss is NaN or inf') curr_optimizer.zero_grad() continue - scaled_loss.backward(bps_scale.to(scaled_loss.get_device())) + if disable_allreduce: + with ExitStack() as stack: + for mod in self.get_DDP_modules(curr_call_chain): + stack.enter_context(mod.no_sync()) + scaled_loss.backward(bps_scale.to(scaled_loss.get_device())) + else: + scaled_loss.backward(bps_scale.to(scaled_loss.get_device())) # no AMP optimizations needed else: # multi-GPU, float32 if self._local_rank is not None: - final_loss.backward(bps_scale.to(final_loss.get_device())) + if disable_allreduce: + with ExitStack() as stack: + for mod in self.get_DDP_modules(curr_call_chain): + stack.enter_context(mod.no_sync()) + final_loss.backward(bps_scale.to(final_loss.get_device())) + else: + final_loss.backward(bps_scale.to(final_loss.get_device())) # single device (CPU or GPU) else: # Fix (workaround?) enabling to backpropagate gradiens on CPUs. 
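The hunks above replace Apex's DistributedDataParallel with the torch-native wrapper: batch-norm layers are converted with `nn.SyncBatchNorm.convert_sync_batchnorm` before wrapping, `broadcast_buffers=False` keeps buffers from being synchronized on every forward pass, and gradient-accumulation steps suppress the all-reduce by entering one `no_sync()` context per DDP module through an `ExitStack`. A minimal standalone sketch of that pattern, assuming `torch.distributed` is already initialized; the names `wrap_for_ddp`, `backward_step`, and `accumulating` are illustrative, not NeMo APIs:

```python
from contextlib import ExitStack
from typing import List

import torch
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP


def wrap_for_ddp(module: nn.Module, local_rank: int, synced_batchnorm: bool = False) -> DDP:
    # Per the PyTorch docs, convert batch-norm layers *before* wrapping with DDP.
    if synced_batchnorm:
        module = nn.SyncBatchNorm.convert_sync_batchnorm(module)
    # broadcast_buffers=False disables buffer (batch norm) synchronization on forward.
    return DDP(
        module.cuda(local_rank),
        device_ids=[local_rank],
        broadcast_buffers=False,
        find_unused_parameters=True,
    )


def backward_step(loss: torch.Tensor, ddp_modules: List[DDP], accumulating: bool) -> None:
    if accumulating:
        # Intermediate accumulation step: skip the gradient all-reduce entirely.
        with ExitStack() as stack:
            for mod in ddp_modules:
                stack.enter_context(mod.no_sync())
            loss.backward()
    else:
        # Final step of the effective batch: backward with the usual synchronization.
        loss.backward()
```

Entering one `no_sync()` context per wrapped module is why the new `get_DDP_modules` helper (next hunk) walks the call chain and collects every DDP instance before the backward pass.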
@@ -1438,3 +1472,13 @@ def infer( use_cache=use_cache, offload_to_cpu=offload_to_cpu, ) + + def get_DDP_modules(self, call_chain): + modules = [] + for ind in range(1, len(call_chain)): + m_id = call_chain[ind][0].unique_instance_id + module = self.module_reference_table[m_id][1] + if isinstance(module, DDP): + modules.append(module) + + return modules diff --git a/nemo/backends/pytorch/common/other.py b/nemo/backends/pytorch/common/other.py index c9b9040dd32c..86cff8ff6221 100644 --- a/nemo/backends/pytorch/common/other.py +++ b/nemo/backends/pytorch/common/other.py @@ -10,8 +10,7 @@ import torch import torch.nn as nn -from nemo.backends.pytorch.nm import TrainableNM -from nemo.core import NeuralModule +from nemo.backends.pytorch.nm import NonTrainableNM, TrainableNM from nemo.core.neural_types import * @@ -21,14 +20,14 @@ def input_ports(self): """Returns definitions of module input ports. """ # return {"input_seq": NeuralType({0: AxisType(TimeTag), 1: AxisType(BatchTag)})} - return {"input_seq": NeuralModule(ChannelType(), ('T', 'B'))} + return {"input_seq": NeuralType(('B', 'T'))} @property def output_ports(self): """Returns definitions of module output ports. """ # return {"outputs": NeuralType({0: AxisType(TimeTag), 1: AxisType(BatchTag), 2: AxisType(ChannelTag),})} - return {"outputs": NeuralType(('T', 'B', 'D'), ChannelType())} + return {"outputs": NeuralType(('B', 'T', 'C'), ChannelType())} def __init__(self, voc_size, hidden_size, dropout=0.0): super().__init__() @@ -39,6 +38,7 @@ def __init__(self, voc_size, hidden_size, dropout=0.0): self.embedding = nn.Embedding(self.voc_size, self.hidden_size) if self.dropout != 0.0: self.embedding_dropout = nn.Dropout(self.dropout) + self.to(self._device) def forward(self, input_seq): embedded = self.embedding(input_seq) @@ -47,7 +47,7 @@ def forward(self, input_seq): return embedded -class ZerosLikeNM(TrainableNM): +class ZerosLikeNM(NonTrainableNM): @property def input_ports(self): """Returns definitions of module input ports. 
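The `SequenceEmbedding` ports above switch from time-major `('T', 'B')` to batch-first: token ids come in as `('B', 'T')` and embeddings leave as `('B', 'T', 'C')`. A tiny plain-PyTorch illustration of that shape convention (arbitrary sizes, not the NeMo module itself):

```python
import torch
import torch.nn as nn

vocab_size, hidden_size = 32, 8
embedding = nn.Embedding(vocab_size, hidden_size)

input_seq = torch.randint(high=vocab_size, size=(4, 10))  # ('B', 'T') token ids
outputs = embedding(input_seq)                            # ('B', 'T', 'C') embeddings
assert outputs.shape == (4, 10, hidden_size)
```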
diff --git a/nemo/backends/pytorch/common/rnn.py b/nemo/backends/pytorch/common/rnn.py index fbf7dbb7eb97..ca67154786d0 100644 --- a/nemo/backends/pytorch/common/rnn.py +++ b/nemo/backends/pytorch/common/rnn.py @@ -124,7 +124,7 @@ def __init__( ) self.out = nn.Linear(hidden_size, voc_size) if tie_emb_out_weights: - self.out.weight = self.embedding.weight # Weight tying + self.out.weight = nn.Parameter(self.embedding.weight) # Weight tying self.attention = Attention(hidden_size, attention_method, dropout=attn_dropout) # self.apply(init_weights) diff --git a/nemo/backends/pytorch/module_wrapper.py b/nemo/backends/pytorch/module_wrapper.py index f439a847411d..1a5de7595d6a 100644 --- a/nemo/backends/pytorch/module_wrapper.py +++ b/nemo/backends/pytorch/module_wrapper.py @@ -86,7 +86,7 @@ def set_weights(self, name2weight, name2name_and_transform=None): def tie_weights_with(self, module, weight_names): for name in weight_names: - rsetattr(self._pt_module, name, rgetattr(module, name)) + rsetattr(self._pt_module, name, nn.Parameter(rgetattr(module, name))) @property def num_weights(self): diff --git a/nemo/backends/pytorch/nm.py b/nemo/backends/pytorch/nm.py index e759035f6a9d..5da1a861e903 100644 --- a/nemo/backends/pytorch/nm.py +++ b/nemo/backends/pytorch/nm.py @@ -75,18 +75,18 @@ def tie_weights_with(self, module, weight_names, name2name_and_transform=None): if name2name_and_transform is None: for name in weight_names: - rsetattr(self, name, rgetattr(module, name)) + rsetattr(self, name, nn.Parameter(rgetattr(module, name))) else: for self_w_name in weight_names: if self_w_name in name2name_and_transform: if name2name_and_transform[self_w_name][1] == WeightShareTransform.SAME: rsetattr( - self, self_w_name, rgetattr(module, name2name_and_transform[self_w_name][0]), + self, self_w_name, nn.Parameter(rgetattr(module, name2name_and_transform[self_w_name][0])), ) elif name2name_and_transform[self_w_name][1] == WeightShareTransform.TRANSPOSE: raise NotImplementedError("Sorry, currently this is not implemented.") else: - rsetattr(self, self_w_name, rgetattr(module, self_w_name)) + rsetattr(self, self_w_name, nn.Parameter(rgetattr(module, self_w_name))) @t.jit.ignore def save_to(self, path): diff --git a/nemo/core/neural_types/neural_type.py b/nemo/core/neural_types/neural_type.py index a2070c354b3c..80bda4aa01d9 100644 --- a/nemo/core/neural_types/neural_type.py +++ b/nemo/core/neural_types/neural_type.py @@ -21,7 +21,6 @@ 'NeuralTypeError', 'NeuralPortNameMismatchError', 'NeuralPortNmTensorMismatchError', - 'NeuralPortNmTensorMismatchError', 'CanNotInferResultNeuralType', ] import uuid @@ -46,6 +45,12 @@ class NeuralType(object): type can be optional. 
""" + def __str__(self): + return ( + f"axes: {[(c.kind, c.size, c.is_list) for c in self.axes]}\n" + f"elements_type: {self.elements_type.__class__.__name__}" + ) + def __init__(self, axes: Optional[Tuple] = None, elements_type: ElementType = VoidType(), optional=False): if not isinstance(elements_type, ElementType): raise ValueError( diff --git a/tests/asr/test_asr.py b/tests/asr/test_asr.py index 86bec0de6d63..38bd05826ee8 100644 --- a/tests/asr/test_asr.py +++ b/tests/asr/test_asr.py @@ -93,13 +93,13 @@ def setUpClass(cls) -> None: else: logging.info("ASR data found in: {0}".format(os.path.join(data_folder, "asr"))) - @classmethod - def tearDownClass(cls) -> None: - super().tearDownClass() - data_folder = os.path.abspath(os.path.join(os.path.dirname(__file__), "../data/")) - logging.info("Looking up for test ASR data") - if os.path.exists(os.path.join(data_folder, "asr")): - shutil.rmtree(os.path.join(data_folder, "asr")) + # @classmethod + # def tearDownClass(cls) -> None: + # super().tearDownClass() + # data_folder = os.path.abspath(os.path.join(os.path.dirname(__file__), "../data/")) + # logging.info("Looking up for test ASR data") + # if os.path.exists(os.path.join(data_folder, "asr")): + # shutil.rmtree(os.path.join(data_folder, "asr")) def test_transcript_normalizers(self): # Create test json diff --git a/tests/asr/test_zeroDS.py b/tests/asr/test_zeroDS.py index a413e1f2e514..2a7b05e14b55 100644 --- a/tests/asr/test_zeroDS.py +++ b/tests/asr/test_zeroDS.py @@ -78,13 +78,13 @@ def setUpClass(cls) -> None: else: logging.info("ASR data found in: {0}".format(os.path.join(data_folder, "asr"))) - @classmethod - def tearDownClass(cls) -> None: - super().tearDownClass() - data_folder = os.path.abspath(os.path.join(os.path.dirname(__file__), "../data/")) - logging.info("Looking up for test ASR data") - if os.path.exists(os.path.join(data_folder, "asr")): - shutil.rmtree(os.path.join(data_folder, "asr")) + # @classmethod + # def tearDownClass(cls) -> None: + # super().tearDownClass() + # data_folder = os.path.abspath(os.path.join(os.path.dirname(__file__), "../data/")) + # logging.info("Looking up for test ASR data") + # if os.path.exists(os.path.join(data_folder, "asr")): + # shutil.rmtree(os.path.join(data_folder, "asr")) def test_asr_with_zero_ds(self): logging.info("Testing ASR NMs with ZeroDS and without pre-processing") diff --git a/tests/core/test_deploy_export.py b/tests/core/test_deploy_export.py index 7c7959e004c8..303209084802 100644 --- a/tests/core/test_deploy_export.py +++ b/tests/core/test_deploy_export.py @@ -33,10 +33,11 @@ import nemo import nemo.collections.asr as nemo_asr -import nemo.collections.nlp as nemo_nlp import nemo.collections.nlp.nm.trainables.common.token_classification_nm from tests.common_setup import NeMoUnitTest +logging = nemo.logging + class TestDeployExport(NeMoUnitTest): def setUp(self): @@ -82,7 +83,7 @@ def __test_export_route(self, module, out_name, mode, input_example=None): inputs[input_name] = ( input_example[i].cpu().numpy() if isinstance(input_example, tuple) else input_example.cpu().numpy() ) - print('Execution Providers: ', ort_session.get_providers()) + logging.info('Execution Providers: ', ort_session.get_providers()) outputs_scr = ort_session.run(None, inputs) outputs_scr = torch.from_numpy(outputs_scr[0]).cuda() elif mode == nemo.core.DeploymentFormat.TORCHSCRIPT: @@ -106,6 +107,8 @@ def __test_export_route(self, module, out_name, mode, input_example=None): if out.exists(): os.remove(out) + if mode == nemo.core.DeploymentFormat.PYTORCH 
and out.with_suffix(out.suffix + ".json").exists(): + os.remove(out.with_suffix(out.suffix + ".json")) def __test_export_route_all(self, module, out_name, input_example=None): if input_example is not None: diff --git a/tests/core/test_weight_share.py b/tests/core/test_weight_share.py index 92f82ce18061..6317052ae77d 100644 --- a/tests/core/test_weight_share.py +++ b/tests/core/test_weight_share.py @@ -1,220 +1,336 @@ -# # ! /usr/bin/python -# # -*- coding: utf-8 -*- +# ! /usr/bin/python +# -*- coding: utf-8 -*- + +# Copyright 2019 NVIDIA. All Rights Reserved. # -# # Copyright 2019 NVIDIA. All Rights Reserved. -# # -# # Licensed under the Apache License, Version 2.0 (the "License"); -# # you may not use this file except in compliance with the License. -# # You may obtain a copy of the License at -# # -# # http://www.apache.org/licenses/LICENSE-2.0 -# # -# # Unless required by applicable law or agreed to in writing, software -# # distributed under the License is distributed on an "AS IS" BASIS, -# # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# # See the License for the specific language governing permissions and -# # limitations under the License. -# # ============================================================================= +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at # -# import os -# import shutil -# import tarfile -# import unittest -# from typing import Dict +# http://www.apache.org/licenses/LICENSE-2.0 # -# import numpy as np -# import torch -# from ruamel.yaml import YAML -# -# import nemo -# import nemo.collections.asr as nemo_asr -# from nemo.core import WeightShareTransform -# from nemo.core.neural_types import * -# from tests.common_setup import NeMoUnitTest -# -# logging = nemo.logging -# -# -# class TestWeightSharing(NeMoUnitTest): -# labels = [ -# "'", -# "a", -# "b", -# "c", -# "d", -# "e", -# "f", -# "g", -# "h", -# "i", -# "j", -# "k", -# "l", -# "m", -# "n", -# "o", -# "p", -# "q", -# "r", -# "s", -# "t", -# "u", -# "v", -# "w", -# "x", -# "y", -# "z", -# " ", -# ] -# manifest_filepath = os.path.abspath(os.path.join(os.path.dirname(__file__), "../data/asr/an4_train.json")) -# featurizer_config = { -# 'window': 'hann', -# 'dither': 1e-05, -# 'normalize': 'per_feature', -# 'frame_splicing': 1, -# 'int_values': False, -# 'window_stride': 0.01, -# 'sample_rate': 16000, -# 'features': 64, -# 'n_fft': 512, -# 'window_size': 0.02, -# } -# yaml = YAML(typ="safe") -# -# @classmethod -# def setUpClass(cls) -> None: -# super().setUpClass() -# data_folder = os.path.abspath(os.path.join(os.path.dirname(__file__), "../data/")) -# logging.info("Looking up for test ASR data") -# if not os.path.exists(os.path.join(data_folder, "asr")): -# logging.info("Extracting ASR data to: {0}".format(os.path.join(data_folder, "asr"))) -# tar = tarfile.open(os.path.join(data_folder, "asr.tar.gz"), "r:gz") -# tar.extractall(path=data_folder) -# tar.close() -# else: -# logging.info("ASR data found in: {0}".format(os.path.join(data_folder, "asr"))) -# -# @classmethod -# def tearDownClass(cls) -> None: -# super().tearDownClass() -# data_folder = os.path.abspath(os.path.join(os.path.dirname(__file__), "../data/")) -# logging.info("Looking up for test ASR data") -# if os.path.exists(os.path.join(data_folder, "asr")): -# shutil.rmtree(os.path.join(data_folder, "asr")) -# -# def __check_if_weights_are_equal(self, w1: Dict, w2: Dict): -# all_same = 
set(w1.keys()) == set(w2.keys()) -# if not all_same: -# return False -# else: -# for key in w1.keys(): -# all_same = all_same and np.array_equal( -# w1[key][0].cpu().detach().numpy(), w2[key][0].cpu().detach().numpy(), -# ) -# return all_same -# -# def test_TaylorNet_get_weights(self): -# tn1 = nemo.backends.pytorch.tutorials.TaylorNet(dim=4) -# tn2 = nemo.backends.pytorch.tutorials.TaylorNet(dim=4) -# # because of randomness, actual weights should be different -# self.assertFalse(self.__check_if_weights_are_equal(tn1.get_weights(), tn2.get_weights())) -# tn3 = nemo.backends.pytorch.tutorials.TaylorNet(dim=4) -# tn3.set_weights(tn1.get_weights()) -# # check than weights are the same -# self.assertTrue(self.__check_if_weights_are_equal(tn1.get_weights(), tn3.get_weights())) -# # change weights on one module - another module should not change -# tn1.fc1.bias.data = torch.tensor([0.1]) -# self.assertFalse(self.__check_if_weights_are_equal(tn1.get_weights(), tn3.get_weights())) -# -# def test_TaylorNet_tie_weights(self): -# tn1 = nemo.backends.pytorch.tutorials.TaylorNet(dim=4) -# tn2 = nemo.backends.pytorch.tutorials.TaylorNet(dim=4) -# # because of randomness, actual weights should be different -# self.assertFalse(self.__check_if_weights_are_equal(tn1.get_weights(), tn2.get_weights())) -# tn2.tie_weights_with(tn1, list(tn1.get_weights().keys())) -# # change weights on one module - another module should change too -# tn1.fc1.bias.data = torch.tensor([0.1]) -# self.assertTrue(self.__check_if_weights_are_equal(tn1.get_weights(), tn2.get_weights())) -# -# def test_tie_weights2(self): -# voc_size = 3 -# dim = 2 -# embd = nemo.backends.pytorch.common.SequenceEmbedding(voc_size=voc_size, hidden_size=dim) -# proj = nemo.backends.pytorch.common.SequenceProjection(from_dim=dim, to_dim=voc_size) -# embd.tie_weights_with( -# proj, -# weight_names=["embedding.weight"], -# name2name_and_transform={"embedding.weight": ("projection.weight", WeightShareTransform.SAME,)}, -# ) -# self.assertTrue( -# np.array_equal(embd.embedding.weight.detach().numpy(), proj.projection.weight.detach().numpy(),) -# ) -# was = embd.embedding.weight.detach().numpy() -# embd.embedding.weight.data = torch.tensor(np.random.randint(0, 10, (3, 2)) * 1.0) -# after = embd.embedding.weight.detach().numpy() -# self.assertTrue( -# np.array_equal(embd.embedding.weight.detach().numpy(), proj.projection.weight.detach().numpy(),) -# ) -# self.assertFalse(np.array_equal(was, after)) -# -# def test_set_weights(self): -# voc_size = 3 -# dim = 2 -# embd = nemo.backends.pytorch.common.SequenceEmbedding(voc_size=voc_size, hidden_size=dim) -# weights = torch.tensor(np.random.randint(0, 10, (3, 2)) * 1.0) -# name2weights = {"embedding.weight": (weights, True)} -# embd.set_weights(name2weight=name2weights) -# self.assertTrue(np.array_equal(embd.embedding.weight.detach().numpy(), weights.detach().numpy(),)) -# weights = torch.tensor(np.random.randint(0, 10, (3, 2)) * 1.0) -# self.assertFalse(np.array_equal(embd.embedding.weight.detach().numpy(), weights.detach().numpy(),)) -# -# def test_freeze_unfreeze_TrainableNM(self): -# path = os.path.abspath(os.path.join(os.path.dirname(__file__), "../data/jasper_smaller.yaml")) -# with open(path) as file: -# jasper_model_definition = self.yaml.load(file) -# dl = nemo_asr.AudioToTextDataLayer( -# # featurizer_config=self.featurizer_config, -# manifest_filepath=self.manifest_filepath, -# labels=self.labels, -# batch_size=4, -# ) -# pre_process_params = { -# #'int_values': False, -# 'frame_splicing': 1, -# 
'features': 64, -# 'window_size': 0.02, -# 'n_fft': 512, -# 'dither': 1e-05, -# 'window': 'hann', -# 'sample_rate': 16000, -# 'normalize': 'per_feature', -# 'window_stride': 0.01, -# } -# preprocessing = nemo_asr.AudioToMelSpectrogramPreprocessor(**pre_process_params) -# jasper_encoder = nemo_asr.JasperEncoder( -# feat_in=jasper_model_definition['AudioToMelSpectrogramPreprocessor']['features'], -# **jasper_model_definition['JasperEncoder'], -# ) -# jasper_decoder = nemo_asr.JasperDecoderForCTC(feat_in=1024, num_classes=len(self.labels)) -# ctc_loss = nemo_asr.CTCLossNM(num_classes=len(self.labels)) -# jasper_encoder.freeze() -# jasper_encoder.unfreeze(set(['encoder.4.conv.1.weight'])) -# jasper_decoder.unfreeze() -# # DAG -# audio_signal, a_sig_length, transcript, transcript_len = dl() -# processed_signal, p_length = preprocessing(input_signal=audio_signal, length=a_sig_length) -# -# encoded, encoded_len = jasper_encoder(audio_signal=processed_signal, length=p_length) -# # logging.info(jasper_encoder) -# log_probs = jasper_decoder(encoder_output=encoded) -# loss = ctc_loss( -# log_probs=log_probs, targets=transcript, input_length=encoded_len, target_length=transcript_len, -# ) -# -# callback = nemo.core.SimpleLossLoggerCallback( -# tensors=[loss], print_func=lambda x: logging.info(f'Train Loss: {str(x[0].item())}'), -# ) -# optimizer = self.nf.get_trainer() -# optimizer.train( -# [loss], callbacks=[callback], optimizer="sgd", optimization_params={"num_epochs": 2, "lr": 0.0003}, -# ) +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================= + +import os +import shutil +import tarfile +from typing import Dict + +import numpy as np +import torch +from ruamel.yaml import YAML + +import nemo +import nemo.collections.asr as nemo_asr +from nemo.backends.pytorch.nm import DataLayerNM +from nemo.collections.nlp.nm.losses import PaddedSmoothedCrossEntropyLossNM +from nemo.collections.nlp.nm.trainables.common import TokenClassifier +from nemo.core import WeightShareTransform +from nemo.core.neural_types import * +from tests.common_setup import NeMoUnitTest + +logging = nemo.logging + + +class TestWeightSharing(NeMoUnitTest): + labels = [ + "'", + "a", + "b", + "c", + "d", + "e", + "f", + "g", + "h", + "i", + "j", + "k", + "l", + "m", + "n", + "o", + "p", + "q", + "r", + "s", + "t", + "u", + "v", + "w", + "x", + "y", + "z", + " ", + ] + manifest_filepath = os.path.abspath(os.path.join(os.path.dirname(__file__), "../data/asr/an4_train.json")) + featurizer_config = { + 'window': 'hann', + 'dither': 1e-05, + 'normalize': 'per_feature', + 'frame_splicing': 1, + 'int_values': False, + 'window_stride': 0.01, + 'sample_rate': 16000, + 'features': 64, + 'n_fft': 512, + 'window_size': 0.02, + } + yaml = YAML(typ="safe") + + @classmethod + def setUpClass(cls) -> None: + super().setUpClass() + data_folder = os.path.abspath(os.path.join(os.path.dirname(__file__), "../data/")) + logging.info("Looking up for test ASR data") + if not os.path.exists(os.path.join(data_folder, "asr")): + logging.info("Extracting ASR data to: {0}".format(os.path.join(data_folder, "asr"))) + tar = tarfile.open(os.path.join(data_folder, "asr.tar.gz"), "r:gz") + tar.extractall(path=data_folder) + tar.close() + else: + logging.info("ASR data found in: {0}".format(os.path.join(data_folder, "asr"))) + + # @classmethod + # def tearDownClass(cls) -> None: + # super().tearDownClass() + # data_folder = os.path.abspath(os.path.join(os.path.dirname(__file__), "../data/")) + # logging.info("Looking up for test ASR data") + # if os.path.exists(os.path.join(data_folder, "asr")): + # logging.info("Removing test ASR data") + # shutil.rmtree(os.path.join(data_folder, "asr")) + + def __check_if_weights_are_equal(self, w1: Dict, w2: Dict): + all_same = set(w1.keys()) == set(w2.keys()) + if not all_same: + return False + else: + for key in w1.keys(): + all_same = all_same and np.array_equal( + w1[key][0].cpu().detach().numpy(), w2[key][0].cpu().detach().numpy(), + ) + return all_same + + def test_TaylorNet_get_weights(self): + tn1 = nemo.backends.pytorch.tutorials.TaylorNet(dim=4) + tn2 = nemo.backends.pytorch.tutorials.TaylorNet(dim=4) + # because of randomness, actual weights should be different + self.assertFalse(self.__check_if_weights_are_equal(tn1.get_weights(), tn2.get_weights())) + tn3 = nemo.backends.pytorch.tutorials.TaylorNet(dim=4) + tn3.set_weights(tn1.get_weights()) + # check than weights are the same + self.assertTrue(self.__check_if_weights_are_equal(tn1.get_weights(), tn3.get_weights())) + # change weights on one module - another module should not change + tn1.fc1.bias.data = torch.tensor([0.1]) + self.assertFalse(self.__check_if_weights_are_equal(tn1.get_weights(), tn3.get_weights())) + + # def test_TaylorNet_tie_weights(self): + # tn1 = nemo.backends.pytorch.tutorials.TaylorNet(dim=4) + # tn2 = nemo.backends.pytorch.tutorials.TaylorNet(dim=4) + # # because of randomness, actual weights should be different + # self.assertFalse(self.__check_if_weights_are_equal(tn1.get_weights(), 
tn2.get_weights())) + # tn2.tie_weights_with(tn1, list(tn1.get_weights().keys())) + # # change weights on one module - another module should change too + # tn2.fc1.bias.data = torch.tensor([0.1]) + # self.assertTrue(self.__check_if_weights_are_equal(tn1.get_weights(), tn2.get_weights())) + + def test_tie_weights(self): + class DummyDataLayer(DataLayerNM): + def __init__(self, vocab_size): + super().__init__() + self.vocab_size = vocab_size + + class DummyDS(torch.utils.data.Dataset): + def __init__(self, vocab_size): + super().__init__() + + def __getitem__(self, index): + model_inputs = torch.randint(high=vocab_size, size=[10]) + model_outputs = torch.randint(high=vocab_size, size=[10]) + return (model_inputs, model_outputs) + + def __len__(self): + return 10 + + self._dataset = DummyDS(vocab_size) + + @property + def output_ports(self): + return { + "model_inputs": NeuralType(('B', 'T')), + "model_outputs": NeuralType(('B', 'T'), LabelsType()), + } + + def __len__(self): + return len(self._dataset) + + @property + def dataset(self): + return self._dataset + + def data_iterator(self): + pass + + voc_size = 10 + dim = 10 + embd = nemo.backends.pytorch.common.other.SequenceEmbedding(voc_size=voc_size, hidden_size=dim) + proj = TokenClassifier(hidden_size=dim, num_classes=voc_size) + data = DummyDataLayer(voc_size) + loss = PaddedSmoothedCrossEntropyLossNM(0) + embd.tie_weights_with( + proj, + weight_names=["embedding.weight"], + name2name_and_transform={"embedding.weight": ("mlp.layer2.weight", WeightShareTransform.SAME)}, + ) + self.assertTrue( + np.array_equal(embd.embedding.weight.detach().cpu().numpy(), proj.mlp.layer2.weight.detach().cpu().numpy()) + ) + _in, _out = data() + pred = embd(input_seq=_in) + pred = proj(hidden_states=pred) + loss_t = loss(target_ids=_out, logits=pred) + + self.nf.train( + [loss_t], optimizer="sgd", optimization_params={"max_steps": 5, "lr": 0.0003}, + ) + + self.assertTrue( + np.array_equal(embd.embedding.weight.detach().cpu().numpy(), proj.mlp.layer2.weight.detach().cpu().numpy()) + ) + + def test_untied_weights(self): + class DummyDataLayer(DataLayerNM): + def __init__(self, vocab_size): + super().__init__() + self.vocab_size = vocab_size + + class DummyDS(torch.utils.data.Dataset): + def __init__(self, vocab_size): + super().__init__() + + def __getitem__(self, index): + model_inputs = torch.randint(high=vocab_size, size=[10]) + model_outputs = torch.randint(high=vocab_size, size=[10]) + return (model_inputs, model_outputs) + + def __len__(self): + return 10 + + self._dataset = DummyDS(vocab_size) + + @property + def output_ports(self): + return { + "model_inputs": NeuralType(('B', 'T')), + "model_outputs": NeuralType(('B', 'T'), LabelsType()), + } + + def __len__(self): + return len(self._dataset) + + @property + def dataset(self): + return self._dataset + + def data_iterator(self): + pass + + voc_size = 10 + dim = 10 + embd = nemo.backends.pytorch.common.other.SequenceEmbedding(voc_size=voc_size, hidden_size=dim) + proj = TokenClassifier(hidden_size=dim, num_classes=voc_size) + data = DummyDataLayer(voc_size) + loss = PaddedSmoothedCrossEntropyLossNM(0) + # embd.tie_weights_with( + # proj, + # weight_names=["embedding.weight"], + # name2name_and_transform={"embedding.weight": ("mlp.layer2.weight", WeightShareTransform.SAME)}, + # ) + self.assertFalse( + np.array_equal(embd.embedding.weight.detach().cpu().numpy(), proj.mlp.layer2.weight.detach().cpu().numpy()) + ) + _in, _out = data() + pred = embd(input_seq=_in) + pred = proj(hidden_states=pred) + 
loss_t = loss(target_ids=_out, logits=pred) + + self.nf.train( + [loss_t], optimizer="sgd", optimization_params={"max_steps": 5, "lr": 0.0003}, + ) + + self.assertFalse( + np.array_equal(embd.embedding.weight.detach().cpu().numpy(), proj.mlp.layer2.weight.detach().cpu().numpy()) + ) + + def test_set_weights(self): + voc_size = 3 + dim = 2 + embd = nemo.backends.pytorch.common.SequenceEmbedding(voc_size=voc_size, hidden_size=dim) + weights = torch.tensor(np.random.randint(0, 10, (3, 2)) * 1.0) + name2weights = {"embedding.weight": (weights, True)} + embd.set_weights(name2weight=name2weights) + self.assertTrue(np.array_equal(embd.embedding.weight.detach().cpu().numpy(), weights.detach().cpu().numpy())) + weights = torch.tensor(np.random.randint(0, 10, (3, 2)) * 1.0) + self.assertFalse(np.array_equal(embd.embedding.weight.detach().cpu().numpy(), weights.detach().cpu().numpy())) + + def test_freeze_unfreeze_TrainableNM(self): + path = os.path.abspath(os.path.join(os.path.dirname(__file__), "../data/jasper_smaller.yaml")) + with open(path) as file: + jasper_model_definition = self.yaml.load(file) + dl = nemo_asr.AudioToTextDataLayer( + # featurizer_config=self.featurizer_config, + manifest_filepath=self.manifest_filepath, + labels=self.labels, + batch_size=4, + ) + pre_process_params = { + #'int_values': False, + 'frame_splicing': 1, + 'features': 64, + 'window_size': 0.02, + 'n_fft': 512, + 'dither': 1e-05, + 'window': 'hann', + 'sample_rate': 16000, + 'normalize': 'per_feature', + 'window_stride': 0.01, + } + preprocessing = nemo_asr.AudioToMelSpectrogramPreprocessor(**pre_process_params) + jasper_encoder = nemo_asr.JasperEncoder( + feat_in=jasper_model_definition['AudioToMelSpectrogramPreprocessor']['features'], + **jasper_model_definition['JasperEncoder'], + ) + jasper_decoder = nemo_asr.JasperDecoderForCTC(feat_in=1024, num_classes=len(self.labels)) + ctc_loss = nemo_asr.CTCLossNM(num_classes=len(self.labels)) + jasper_encoder.freeze() + jasper_encoder.unfreeze(set(['encoder.4.mconv.0.conv.weight'])) + frozen_weight = jasper_encoder.encoder[1].mconv[0].conv.weight.detach().cpu().numpy() + unfrozen_weight = jasper_encoder.encoder[4].mconv[0].conv.weight.detach().cpu().numpy() + # jasper_decoder.unfreeze() + # DAG + audio_signal, a_sig_length, transcript, transcript_len = dl() + processed_signal, p_length = preprocessing(input_signal=audio_signal, length=a_sig_length) + + encoded, encoded_len = jasper_encoder(audio_signal=processed_signal, length=p_length) + # logging.info(jasper_encoder) + log_probs = jasper_decoder(encoder_output=encoded) + loss = ctc_loss( + log_probs=log_probs, targets=transcript, input_length=encoded_len, target_length=transcript_len, + ) + + callback = nemo.core.SimpleLossLoggerCallback( + tensors=[loss], print_func=lambda x: logging.info(f'Train Loss: {str(x[0].item())}'), + ) + optimizer = self.nf.get_trainer() + optimizer.train( + [loss], callbacks=[callback], optimizer="sgd", optimization_params={"max_steps": 5, "lr": 0.0003}, + ) + new_frozen_weight = jasper_encoder.encoder[1].mconv[0].conv.weight.data + new_unfrozen_weight = jasper_encoder.encoder[4].mconv[0].conv.weight.data + self.assertTrue(np.array_equal(frozen_weight, new_frozen_weight.detach().cpu().numpy())) + self.assertFalse(np.array_equal(unfrozen_weight, new_unfrozen_weight.detach().cpu().numpy()))
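The rewritten tests tie the embedding matrix to the classifier's output projection with `tie_weights_with` and assert the two stay identical through five training steps, while the untied variant must drift apart. The same sharing can be sketched in plain PyTorch by pointing both modules at a single `nn.Parameter`; `embedding` and `projection` here are generic stand-ins, not the NeMo classes under test:

```python
import numpy as np
import torch
import torch.nn as nn

vocab_size, hidden_size = 10, 10
embedding = nn.Embedding(vocab_size, hidden_size)
projection = nn.Linear(hidden_size, vocab_size, bias=False)

# Register the very same nn.Parameter object on both modules.
projection.weight = embedding.weight

assert np.array_equal(embedding.weight.detach().numpy(), projection.weight.detach().numpy())

# Because the storage is shared, an in-place update on one side shows up on the other.
with torch.no_grad():
    embedding.weight.copy_(torch.randn(vocab_size, hidden_size))
assert np.array_equal(embedding.weight.detach().numpy(), projection.weight.detach().numpy())
```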