From 672fd491327ffed659c118adfa651a52bec95f88 Mon Sep 17 00:00:00 2001 From: Jason Date: Thu, 6 Feb 2020 14:56:57 -0800 Subject: [PATCH 01/15] stage 1 of switch DDP Signed-off-by: Jason --- nemo/backends/pytorch/actions.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/nemo/backends/pytorch/actions.py b/nemo/backends/pytorch/actions.py index 4c9db09fd67b..5bb775d351a5 100644 --- a/nemo/backends/pytorch/actions.py +++ b/nemo/backends/pytorch/actions.py @@ -5,11 +5,12 @@ import os from collections import defaultdict from pathlib import Path -from typing import Dict, List, Optional +from typing import List, Optional import torch import torch.distributed as dist import torch.nn as nn +from torch.nn.parallel import DistributedDataParallel as DDP import torch.optim as optim from nemo import logging @@ -26,7 +27,6 @@ amp = None convert_syncbn = None create_syncbn_process_group = None -DDP = None LARC = None FusedLAMB = None FusedAdam = None @@ -60,7 +60,6 @@ def __init__( if local_rank is not None: global convert_syncbn global create_syncbn_process_group - global DDP global LARC global FusedLAMB global FusedAdam @@ -69,7 +68,6 @@ def __init__( apex_optimizer = importlib.import_module('apex.optimizers') convert_syncbn = parallel.convert_syncbn_model create_syncbn_process_group = parallel.create_syncbn_process_group - DDP = parallel.DistributedDataParallel LARC = parallel.LARC FusedLAMB = apex_optimizer.FusedLAMB FusedAdam = apex_optimizer.FusedAdam From d535720ba325e5833f4521794d07ef546f2386d7 Mon Sep 17 00:00:00 2001 From: Jason Date: Thu, 6 Feb 2020 18:10:59 -0800 Subject: [PATCH 02/15] switch from apex to torch Signed-off-by: Jason --- nemo/backends/pytorch/actions.py | 130 ++++++++++++++++++++++--------- 1 file changed, 95 insertions(+), 35 deletions(-) diff --git a/nemo/backends/pytorch/actions.py b/nemo/backends/pytorch/actions.py index 3955648f924d..b364eea442ba 100644 --- a/nemo/backends/pytorch/actions.py +++ b/nemo/backends/pytorch/actions.py @@ -5,6 +5,7 @@ import json import os from collections import defaultdict +from contextlib import contextmanager from pathlib import Path from typing import List, Optional @@ -26,8 +27,8 @@ # these imports will happen on as-needed basis amp = None -convert_syncbn = None -create_syncbn_process_group = None +# convert_syncbn = None +# create_syncbn_process_group = None LARC = None FusedLAMB = None FusedAdam = None @@ -59,16 +60,16 @@ def __init__( global amp amp = importlib.import_module('apex.amp') if local_rank is not None: - global convert_syncbn - global create_syncbn_process_group + # global convert_syncbn + # global create_syncbn_process_group global LARC global FusedLAMB global FusedAdam global FusedNovoGrad parallel = importlib.import_module('apex.parallel') apex_optimizer = importlib.import_module('apex.optimizers') - convert_syncbn = parallel.convert_syncbn_model - create_syncbn_process_group = parallel.create_syncbn_process_group + # convert_syncbn = parallel.convert_syncbn_model + # create_syncbn_process_group = parallel.create_syncbn_process_group LARC = parallel.LARC FusedLAMB = apex_optimizer.FusedLAMB FusedAdam = apex_optimizer.FusedAdam @@ -377,7 +378,7 @@ def __initialize_amp( return optimizer def __nm_graph_forward_pass( - self, call_chain, registered_tensors, mode=ModelMode.train, disable_allreduce=False, use_cache=False, + self, call_chain, registered_tensors, mode=ModelMode.train, use_cache=False, ): for ind in range(1, len(call_chain)): if use_cache: @@ -397,12 +398,12 @@ def __nm_graph_forward_pass( 
m_id = call_chain[ind][0].unique_instance_id pmodule = self.module_reference_table[m_id][1] - if self._local_rank is not None: - if isinstance(pmodule, DDP): - if disable_allreduce: - pmodule.disable_allreduce() - else: - pmodule.enable_allreduce() + # if self._local_rank is not None: + # if isinstance(pmodule, DDP): + # if disable_allreduce: + # pmodule.disable_allreduce() + # else: + # pmodule.enable_allreduce() if mode == ModelMode.train: # if module.is_trainable(): @@ -1064,6 +1065,11 @@ def train( gradient_predivide=False, amp_max_loss_scale=2.0 ** 24, ): + if gradient_predivide: + logging.error( + "gradient_predivide is currently disabled, and is under consideration for removal in future versions. " + "If this functionality is needed, please raise a github issue." + ) if not optimization_params: optimization_params = {} num_epochs = optimization_params.get("num_epochs", None) @@ -1205,23 +1211,42 @@ def train( key = call_chain[i][0].unique_instance_id pmodule = self.module_reference_table[key][1] if not isinstance(pmodule, DDP) and isinstance(pmodule, torch.nn.Module): - gpf = 1 - if gradient_predivide: - gpf = dist.get_world_size() - pmodule = DDP(pmodule, gradient_predivide_factor=gpf) - - # Convert batchnorm modules to synced if applicable - if synced_batchnorm and isinstance(pmodule, torch.nn.Module): - world_size = dist.get_world_size() - if synced_batchnorm_groupsize > 0 and world_size % synced_batchnorm_groupsize != 0: - raise ValueError( - f"Synchronized batch norm group size" - f" ({synced_batchnorm_groupsize}) must be 0" - f" or divide total number of GPUs" - f" ({world_size})." + # gpf = 1 + # if gradient_predivide: + # gpf = dist.get_world_size() + # pmodule = DDP(pmodule, gradient_predivide_factor=gpf) # Old Apex Method + + # Per pytorch docs, convert sync bn prior to DDP + if synced_batchnorm: + world_size = dist.get_world_size() + sync_batchnorm_group = None + if synced_batchnorm_groupsize > 0: + if world_size % synced_batchnorm_groupsize != 0: + raise ValueError( + f"Synchronized batch norm group size ({synced_batchnorm_groupsize}) must be 0" + f" or divide total number of GPUs ({world_size})." + ) + sync_batchnorm_group = torch.distributed.new_group(synced_batchnorm_groupsize) + pmodule = nn.SyncBatchNorm.convert_sync_batchnorm( + pmodule, process_group=sync_batchnorm_group ) - process_group = create_syncbn_process_group(synced_batchnorm_groupsize) - pmodule = convert_syncbn(pmodule, process_group=process_group) + + # By default, disable broadcast_buffers. This disables batch norm synchronization on forward + # pass + pmodule = DDP(pmodule, device_ids=[self.local_rank], broadcast_buffers=False) + + # # Convert batchnorm modules to synced if applicable + # if synced_batchnorm and isinstance(pmodule, torch.nn.Module): + # world_size = dist.get_world_size() + # if synced_batchnorm_groupsize > 0 and world_size % synced_batchnorm_groupsize != 0: + # raise ValueError( + # f"Synchronized batch norm group size" + # f" ({synced_batchnorm_groupsize}) must be 0" + # f" or divide total number of GPUs" + # f" ({world_size})." 
+ # ) + # process_group = create_syncbn_process_group(synced_batchnorm_groupsize) + # pmodule = convert_syncbn(pmodule, process_group=process_group) self.module_reference_table[key] = ( self.module_reference_table[key][0], @@ -1300,9 +1325,7 @@ def train( } disable_allreduce = batch_counter < (batches_per_step - 1) self.__nm_graph_forward_pass( - call_chain=curr_call_chain, - registered_tensors=registered_tensors, - disable_allreduce=disable_allreduce, + call_chain=curr_call_chain, registered_tensors=registered_tensors, ) curr_tensors_to_optimize = training_loop[self.step % len(training_loop)][1] @@ -1323,19 +1346,27 @@ def train( if nan: continue if self._optim_level in AmpOptimizations and self._optim_level != Optimization.mxprO0: - with amp.scale_loss(final_loss, curr_optimizer, delay_unscale=disable_allreduce,) as scaled_loss: + with amp.scale_loss(final_loss, curr_optimizer, delay_unscale=disable_allreduce) as scaled_loss: if torch.isnan(scaled_loss).any() or torch.isinf(scaled_loss).any(): if stop_on_nan_loss: raise ValueError('Loss is NaN or inf -' ' exiting') logging.warning('WARNING: Loss is NaN or inf') curr_optimizer.zero_grad() continue - scaled_loss.backward(bps_scale.to(scaled_loss.get_device())) + if disable_allreduce: + with self.no_sync(curr_call_chain): + scaled_loss.backward(bps_scale.to(scaled_loss.get_device())) + else: + scaled_loss.backward(bps_scale.to(scaled_loss.get_device())) # no AMP optimizations needed else: # multi-GPU, float32 if self._local_rank is not None: - final_loss.backward(bps_scale.to(final_loss.get_device())) + if disable_allreduce: + with self.no_sync(curr_call_chain): + final_loss.backward(bps_scale.to(final_loss.get_device())) + else: + final_loss.backward(bps_scale.to(final_loss.get_device())) # single device (CPU or GPU) else: # Fix (workaround?) enabling to backpropagate gradiens on CPUs. 
@@ -1430,3 +1461,32 @@ def infer(
             use_cache=use_cache,
             offload_to_cpu=offload_to_cpu,
         )
+
+    @contextmanager
+    def no_sync(self, call_chain):
+        """
+        Wrapper contextmanager around pytorch DDP's @no_sync since pytorch requires ALL wrapped DDP modules to be
+        inside the @no_sync context manager for gradient accumulation
+        """
+        modules = []
+        for ind in range(1, len(call_chain)):
+            m_id = call_chain[ind][0].unique_instance_id
+            module = self.module_reference_table[m_id][1]
+            if isinstance(module, DDP):
+                modules.append(module)
+
+        @contextmanager
+        def recursive_yield(list_of_modules):
+            mod = list_of_modules.pop()
+            with mod.no_sync() as ctx:
+                try:
+                    recursive_yield(list_of_modules)
+                    yield [ctx]
+                finally:
+                    pass
+
+        with recursive_yield(modules):
+            try:
+                yield
+            finally:
+                pass

From a9be01ad3ba038957904c9d71c6cd40e1a6d86e3 Mon Sep 17 00:00:00 2001
From: Jason
Date: Thu, 6 Feb 2020 18:27:40 -0800
Subject: [PATCH 03/15] simplify contextmanager

Signed-off-by: Jason

---
 nemo/backends/pytorch/actions.py | 35 ++++++++++++++------------------
 1 file changed, 15 insertions(+), 20 deletions(-)

diff --git a/nemo/backends/pytorch/actions.py b/nemo/backends/pytorch/actions.py
index b364eea442ba..d70b9f52f795 100644
--- a/nemo/backends/pytorch/actions.py
+++ b/nemo/backends/pytorch/actions.py
@@ -1354,7 +1354,7 @@ def train(
                                 curr_optimizer.zero_grad()
                                 continue
                     if disable_allreduce:
-                        with self.no_sync(curr_call_chain):
+                        with self.no_sync(self.get_DDP_modules(curr_call_chain)):
                             scaled_loss.backward(bps_scale.to(scaled_loss.get_device()))
                     else:
                         scaled_loss.backward(bps_scale.to(scaled_loss.get_device()))
@@ -1363,7 +1363,7 @@ def train(
                 # multi-GPU, float32
                 if self._local_rank is not None:
                     if disable_allreduce:
-                        with self.no_sync(curr_call_chain):
+                        with self.no_sync(self.get_DDP_modules(curr_call_chain)):
                             final_loss.backward(bps_scale.to(final_loss.get_device()))
                     else:
                         final_loss.backward(bps_scale.to(final_loss.get_device()))
@@ -1462,12 +1462,7 @@ def infer(
             offload_to_cpu=offload_to_cpu,
         )

-    @contextmanager
-    def no_sync(self, call_chain):
-        """
-        Wrapper contextmanager around pytorch DDP's @no_sync since pytorch requires ALL wrapped DDP modules to be
-        inside the @no_sync context manager for gradient accumulation
-        """
+    def get_DDP_modules(self, callchain):
         modules = []
         for ind in range(1, len(call_chain)):
             m_id = call_chain[ind][0].unique_instance_id
             module = self.module_reference_table[m_id][1]
             if isinstance(module, DDP):
                 modules.append(module)

-        @contextmanager
-        def recursive_yield(list_of_modules):
-            mod = list_of_modules.pop()
-            with mod.no_sync() as ctx:
-                try:
-                    recursive_yield(list_of_modules)
-                    yield [ctx]
-                finally:
-                    pass
-
-        with recursive_yield(modules):
+        return modules
+
+    @contextmanager
+    def no_sync(self, modules):
+        """
+        Wrapper contextmanager around pytorch DDP's @no_sync since pytorch requires ALL wrapped DDP modules to be
+        inside the @no_sync context manager for gradient accumulation
+        """
+        mod = modules.pop()
+        with mod.no_sync() as ctx:
             try:
-                yield
+                self.no_sync(modules)
+                yield [ctx]
             finally:
                 pass

From 71d4bff1303c35574fd35594931e2253b59a7ca6 Mon Sep 17 00:00:00 2001
From: Jason
Date: Thu, 6 Feb 2020 18:54:17 -0800
Subject: [PATCH 04/15] cleaner version with exitstack as opposed to nested with statements

Signed-off-by: Jason

---
 nemo/backends/pytorch/actions.py | 26 ++++++++------------------
 1 file changed, 8 insertions(+), 18 deletions(-)

diff --git a/nemo/backends/pytorch/actions.py b/nemo/backends/pytorch/actions.py
index d70b9f52f795..26a3d52f6098 100644
--- a/nemo/backends/pytorch/actions.py
+++ b/nemo/backends/pytorch/actions.py
@@ -5,7 +5,7 @@
 import json
 import os
 from collections import defaultdict
-from contextlib import contextmanager
+from contextlib import contextmanager, ExitStack
 from pathlib import Path
 from typing import List, Optional

@@ -1354,7 +1354,9 @@ def train(
                                 curr_optimizer.zero_grad()
                                 continue
                     if disable_allreduce:
-                        with self.no_sync(self.get_DDP_modules(curr_call_chain)):
+                        with ExitStack() as stack:
+                            for mod in self.get_DDP_modules(curr_call_chain):
+                                stack.enter_context(mod.no_sync())
                             scaled_loss.backward(bps_scale.to(scaled_loss.get_device()))
                     else:
                         scaled_loss.backward(bps_scale.to(scaled_loss.get_device()))
@@ -1363,7 +1365,9 @@ def train(
                 # multi-GPU, float32
                 if self._local_rank is not None:
                     if disable_allreduce:
-                        with self.no_sync(self.get_DDP_modules(curr_call_chain)):
+                        with ExitStack() as stack:
+                            for mod in self.get_DDP_modules(curr_call_chain):
+                                stack.enter_context(mod.no_sync())
                             final_loss.backward(bps_scale.to(final_loss.get_device()))
                     else:
                         final_loss.backward(bps_scale.to(final_loss.get_device()))
@@ -1462,7 +1466,7 @@ def infer(
             offload_to_cpu=offload_to_cpu,
         )

-    def get_DDP_modules(self, callchain):
+    def get_DDP_modules(self, call_chain):
         modules = []
         for ind in range(1, len(call_chain)):
             m_id = call_chain[ind][0].unique_instance_id
@@ -1471,17 +1475,3 @@ def no_sync(self, call_chain):
             if isinstance(module, DDP):
                 modules.append(module)

         return modules
-
-    @contextmanager
-    def no_sync(self, modules):
-        """
-        Wrapper contextmanager around pytorch DDP's @no_sync since pytorch requires ALL wrapped DDP modules to be
-        inside the @no_sync context manager for gradient accumulation
-        """
-        mod = modules.pop()
-        with mod.no_sync() as ctx:
-            try:
-                self.no_sync(modules)
-                yield [ctx]
-            finally:
-                pass

From f627f4416ebfedd8ff7521b39094c00d19066498 Mon Sep 17 00:00:00 2001
From: Jason
Date: Fri, 7 Feb 2020 16:33:45 -0800
Subject: [PATCH 05/15] isort; update changelog

Signed-off-by: Jason

---
 CHANGELOG.md                     | 4 ++++
 nemo/backends/pytorch/actions.py | 4 ++--
 2 files changed, 6 insertions(+), 2 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 1fbf6d6ac532..e7fba3a8748f 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -90,6 +90,8 @@ To release a new version, please update the changelog as followed:
 - Updated licenses
 - Updated nemo's use of the logging library. from nemo import logging is now the reccomended way of using the nemo logger. neural_factory.logger and all other instances of logger are now deprecated and planned for removal in the next version. Please see PR 267 for complete change information.
([PR #267](https://github.com/NVIDIA/NeMo/pull/267), [PR #283](https://github.com/NVIDIA/NeMo/pull/283), [PR #305](https://github.com/NVIDIA/NeMo/pull/305), [PR #311](https://github.com/NVIDIA/NeMo/pull/311)) - @blisc +- Changed Distributed Data Parallel from Apex to Torch +([PR #336](https://github.com/NVIDIA/NeMo/pull/336)) - @blisc - Added TRADE (dialogue state tracking model) on MultiWOZ dataset ([PR #322](https://github.com/NVIDIA/NeMo/pull/322)) - @chiphuyen, @VahidooX @@ -104,6 +106,8 @@ To release a new version, please update the changelog as followed: ([PR #308](https://github.com/NVIDIA/NeMo/pull/309)) - @tkornuta-nvidia ### Removed +- gradient_predivide_factor arg of train() now has no effect +([PR #336](https://github.com/NVIDIA/NeMo/pull/336)) - @blisc ### Security diff --git a/nemo/backends/pytorch/actions.py b/nemo/backends/pytorch/actions.py index 26a3d52f6098..00158b0c7c50 100644 --- a/nemo/backends/pytorch/actions.py +++ b/nemo/backends/pytorch/actions.py @@ -5,15 +5,15 @@ import json import os from collections import defaultdict -from contextlib import contextmanager, ExitStack +from contextlib import ExitStack, contextmanager from pathlib import Path from typing import List, Optional import torch import torch.distributed as dist import torch.nn as nn -from torch.nn.parallel import DistributedDataParallel as DDP import torch.optim as optim +from torch.nn.parallel import DistributedDataParallel as DDP from nemo import logging from nemo.backends.pytorch.module_wrapper import TrainableNeuralModuleWrapper From f627f4416ebfedd8ff7521b39094c00d19066498 Mon Sep 17 00:00:00 2001 From: Jason Date: Fri, 7 Feb 2020 16:46:27 -0800 Subject: [PATCH 06/15] update ZerosLikeNM to be NonTrainable Signed-off-by: Jason --- nemo/backends/pytorch/actions.py | 2 +- nemo/backends/pytorch/common/other.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/nemo/backends/pytorch/actions.py b/nemo/backends/pytorch/actions.py index 00158b0c7c50..864795662dda 100644 --- a/nemo/backends/pytorch/actions.py +++ b/nemo/backends/pytorch/actions.py @@ -5,7 +5,7 @@ import json import os from collections import defaultdict -from contextlib import ExitStack, contextmanager +from contextlib import ExitStack from pathlib import Path from typing import List, Optional diff --git a/nemo/backends/pytorch/common/other.py b/nemo/backends/pytorch/common/other.py index 80d43dadae15..45e04f94c1aa 100644 --- a/nemo/backends/pytorch/common/other.py +++ b/nemo/backends/pytorch/common/other.py @@ -15,7 +15,7 @@ import torch import torch.nn as nn -from nemo.backends.pytorch.nm import TrainableNM +from nemo.backends.pytorch.nm import TrainableNM, NonTrainableNM from nemo.core import NeuralModule from nemo.core.neural_types import * @@ -328,7 +328,7 @@ def forward(self, input_seq): return p -class ZerosLikeNM(TrainableNM): +class ZerosLikeNM(NonTrainableNM): @property def input_ports(self): """Returns definitions of module input ports. 
From b85cb40792c5c5c259e091c27348212d4867b3fd Mon Sep 17 00:00:00 2001 From: Jason Date: Fri, 7 Feb 2020 17:28:55 -0800 Subject: [PATCH 07/15] style Signed-off-by: Jason --- nemo/backends/pytorch/common/other.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nemo/backends/pytorch/common/other.py b/nemo/backends/pytorch/common/other.py index 45e04f94c1aa..e7f6fd5d9cf0 100644 --- a/nemo/backends/pytorch/common/other.py +++ b/nemo/backends/pytorch/common/other.py @@ -15,7 +15,7 @@ import torch import torch.nn as nn -from nemo.backends.pytorch.nm import TrainableNM, NonTrainableNM +from nemo.backends.pytorch.nm import NonTrainableNM, TrainableNM from nemo.core import NeuralModule from nemo.core.neural_types import * From c61ea90341c183a2d9cddf60a9fec26b439554bf Mon Sep 17 00:00:00 2001 From: Jason Date: Thu, 13 Feb 2020 12:57:18 -0800 Subject: [PATCH 08/15] weight sharing fix attempt Signed-off-by: Jason --- .../asr_postprocessor/asr_postprocessor.py | 31 +++++++++++++++++-- .../nlp/language_modeling/bert_pretraining.py | 4 +++ .../language_modeling_transformer.py | 3 ++ .../machine_translation_tutorial.py | 25 +++++++++++++-- nemo/backends/pytorch/common/rnn.py | 2 +- nemo/backends/pytorch/nm.py | 6 ++-- 6 files changed, 62 insertions(+), 9 deletions(-) diff --git a/examples/nlp/asr_postprocessor/asr_postprocessor.py b/examples/nlp/asr_postprocessor/asr_postprocessor.py index 204e9db5664f..c4be5d5f033c 100644 --- a/examples/nlp/asr_postprocessor/asr_postprocessor.py +++ b/examples/nlp/asr_postprocessor/asr_postprocessor.py @@ -27,6 +27,7 @@ eval_iter_callback, ) from nemo.core.callbacks import CheckpointCallback +from nemo.core import WeightShareTransform from nemo.utils.lr_policies import SquareAnnealing parser = nemo.utils.NemoArgParser(description='ASR postprocessor') @@ -126,9 +127,33 @@ ) # tie all embeddings weights -t_log_softmax.mlp.layer0.weight = encoder.bert.embeddings.word_embeddings.weight -decoder.embedding_layer.token_embedding.weight = encoder.bert.embeddings.word_embeddings.weight -decoder.embedding_layer.position_embedding.weight = encoder.bert.embeddings.position_embeddings.weight +# t_log_softmax.mlp.layer0.weight = encoder.bert.embeddings.word_embeddings.weight +t_log_softmax.tie_weights_with( + encoder, + weight_names=["mlp.layer0.weight"], + name2name_and_transform={ + "mlp.layer0.weight": ("bert.embeddings.word_embeddings.weight", WeightShareTransform.SAME) + }, +) +# decoder.embedding_layer.token_embedding.weight = encoder.bert.embeddings.word_embeddings.weight +decoder.tie_weights_with( + encoder, + weight_names=["embedding_layer.token_embedding.weight"], + name2name_and_transform={ + "embedding_layer.token_embedding.weight": ("bert.embeddings.word_embeddings.weight", WeightShareTransform.SAME) + }, +) +# decoder.embedding_layer.position_embedding.weight = encoder.bert.embeddings.position_embeddings.weight +decoder.tie_weights_with( + encoder, + weight_names=["embedding_layer.position_embedding.weight"], + name2name_and_transform={ + "embedding_layer.position_embedding.weight": ( + "bert.embeddings.position_embeddings.weight", + WeightShareTransform.SAME, + ) + }, +) def create_pipeline(dataset, tokens_in_batch, clean=False, training=True): diff --git a/examples/nlp/language_modeling/bert_pretraining.py b/examples/nlp/language_modeling/bert_pretraining.py index eaf40dc454ed..3ba9184abe98 100644 --- a/examples/nlp/language_modeling/bert_pretraining.py +++ b/examples/nlp/language_modeling/bert_pretraining.py @@ -14,6 +14,10 @@ # limitations under the 
License. # ============================================================================= + +# Todo: fix weight tying + + """ To pretrain BERT on raw text dataset run diff --git a/examples/nlp/language_modeling/language_modeling_transformer.py b/examples/nlp/language_modeling/language_modeling_transformer.py index d49040949538..285487f0dfdb 100644 --- a/examples/nlp/language_modeling/language_modeling_transformer.py +++ b/examples/nlp/language_modeling/language_modeling_transformer.py @@ -14,6 +14,9 @@ # limitations under the License. # ============================================================================= +# Todo: fix weight tying + + import math import nemo diff --git a/examples/nlp/neural_machine_translation/machine_translation_tutorial.py b/examples/nlp/neural_machine_translation/machine_translation_tutorial.py index 8cda90810521..06a1af674365 100644 --- a/examples/nlp/neural_machine_translation/machine_translation_tutorial.py +++ b/examples/nlp/neural_machine_translation/machine_translation_tutorial.py @@ -19,12 +19,16 @@ https://nvidia.github.io/NeMo/nlp/ neural-machine-translation.html#translation-with-pretrained-model """ + +# Todo: fix weight tying + import torch import nemo import nemo.collections.nlp as nemo_nlp from nemo.collections.nlp.callbacks.machine_translation_callback import eval_epochs_done_callback, eval_iter_callback from nemo.utils.lr_policies import get_lr_policy +from nemo.core import WeightShareTransform parser = nemo.utils.NemoArgParser(description='Transformer for Neural Machine Translation') parser.set_defaults( @@ -165,8 +169,25 @@ ) if tie_weight: - log_softmax.mlp.last_linear_layer.weight = encoder.embedding_layer.token_embedding.weight - decoder.embedding_layer.token_embedding.weight = encoder.embedding_layer.token_embedding.weight + # log_softmax.mlp.last_linear_layer.weight = encoder.embedding_layer.token_embedding.weight + log_softmax.tie_weights_with( + encoder, + weight_names=["mlp.last_linear_layer.weight"], + name2name_and_transform={ + "mlp.last_linear_layer.weight": ("embedding_layer.token_embedding.weight", WeightShareTransform.SAME) + }, + ) + # decoder.embedding_layer.token_embedding.weight = encoder.embedding_layer.token_embedding.weight + decoder.tie_weights_with( + encoder, + weight_names=["embedding_layer.token_embedding.weight"], + name2name_and_transform={ + "embedding_layer.token_embedding.weight": ( + "embedding_layer.token_embedding.weight", + WeightShareTransform.SAME, + ) + }, + ) def create_pipeline(dataset_src, dataset_tgt, tokens_in_batch, clean=False, training=True): diff --git a/nemo/backends/pytorch/common/rnn.py b/nemo/backends/pytorch/common/rnn.py index fbf7dbb7eb97..ca67154786d0 100644 --- a/nemo/backends/pytorch/common/rnn.py +++ b/nemo/backends/pytorch/common/rnn.py @@ -124,7 +124,7 @@ def __init__( ) self.out = nn.Linear(hidden_size, voc_size) if tie_emb_out_weights: - self.out.weight = self.embedding.weight # Weight tying + self.out.weight = nn.Parameter(self.embedding.weight) # Weight tying self.attention = Attention(hidden_size, attention_method, dropout=attn_dropout) # self.apply(init_weights) diff --git a/nemo/backends/pytorch/nm.py b/nemo/backends/pytorch/nm.py index e759035f6a9d..5da1a861e903 100644 --- a/nemo/backends/pytorch/nm.py +++ b/nemo/backends/pytorch/nm.py @@ -75,18 +75,18 @@ def tie_weights_with(self, module, weight_names, name2name_and_transform=None): if name2name_and_transform is None: for name in weight_names: - rsetattr(self, name, rgetattr(module, name)) + rsetattr(self, name, 
nn.Parameter(rgetattr(module, name))) else: for self_w_name in weight_names: if self_w_name in name2name_and_transform: if name2name_and_transform[self_w_name][1] == WeightShareTransform.SAME: rsetattr( - self, self_w_name, rgetattr(module, name2name_and_transform[self_w_name][0]), + self, self_w_name, nn.Parameter(rgetattr(module, name2name_and_transform[self_w_name][0])), ) elif name2name_and_transform[self_w_name][1] == WeightShareTransform.TRANSPOSE: raise NotImplementedError("Sorry, currently this is not implemented.") else: - rsetattr(self, self_w_name, rgetattr(module, self_w_name)) + rsetattr(self, self_w_name, nn.Parameter(rgetattr(module, self_w_name))) @t.jit.ignore def save_to(self, path): From f1a57bbb9367d3acf78d5f703d82f8b4bb08097a Mon Sep 17 00:00:00 2001 From: Jason Date: Thu, 13 Feb 2020 12:58:52 -0800 Subject: [PATCH 09/15] isorT Signed-off-by: Jason --- examples/nlp/asr_postprocessor/asr_postprocessor.py | 2 +- .../neural_machine_translation/machine_translation_tutorial.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/nlp/asr_postprocessor/asr_postprocessor.py b/examples/nlp/asr_postprocessor/asr_postprocessor.py index c4be5d5f033c..f91638b02d74 100644 --- a/examples/nlp/asr_postprocessor/asr_postprocessor.py +++ b/examples/nlp/asr_postprocessor/asr_postprocessor.py @@ -26,8 +26,8 @@ eval_epochs_done_callback_wer, eval_iter_callback, ) -from nemo.core.callbacks import CheckpointCallback from nemo.core import WeightShareTransform +from nemo.core.callbacks import CheckpointCallback from nemo.utils.lr_policies import SquareAnnealing parser = nemo.utils.NemoArgParser(description='ASR postprocessor') diff --git a/examples/nlp/neural_machine_translation/machine_translation_tutorial.py b/examples/nlp/neural_machine_translation/machine_translation_tutorial.py index 06a1af674365..997aa79ccd34 100644 --- a/examples/nlp/neural_machine_translation/machine_translation_tutorial.py +++ b/examples/nlp/neural_machine_translation/machine_translation_tutorial.py @@ -27,8 +27,8 @@ import nemo import nemo.collections.nlp as nemo_nlp from nemo.collections.nlp.callbacks.machine_translation_callback import eval_epochs_done_callback, eval_iter_callback -from nemo.utils.lr_policies import get_lr_policy from nemo.core import WeightShareTransform +from nemo.utils.lr_policies import get_lr_policy parser = nemo.utils.NemoArgParser(description='Transformer for Neural Machine Translation') parser.set_defaults( From dabaea2cc940f306c3019a31f89945df1448210d Mon Sep 17 00:00:00 2001 From: Jason Date: Thu, 13 Feb 2020 13:10:14 -0800 Subject: [PATCH 10/15] update DDP call Signed-off-by: Jason --- nemo/backends/pytorch/actions.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/nemo/backends/pytorch/actions.py b/nemo/backends/pytorch/actions.py index 7115eebf41af..faad7a28436f 100644 --- a/nemo/backends/pytorch/actions.py +++ b/nemo/backends/pytorch/actions.py @@ -1239,7 +1239,9 @@ def train( # By default, disable broadcast_buffers. 
This disables batch norm synchronization on forward # pass - pmodule = DDP(pmodule, device_ids=[self.local_rank], broadcast_buffers=False) + pmodule = DDP( + pmodule, device_ids=[self.local_rank], broadcast_buffers=False, find_unused_parameters=True + ) # # Convert batchnorm modules to synced if applicable # if synced_batchnorm and isinstance(pmodule, torch.nn.Module): From 4eb90b5f4960461e58921dc322ea8b88da18ce08 Mon Sep 17 00:00:00 2001 From: Jason Date: Thu, 13 Feb 2020 13:44:59 -0800 Subject: [PATCH 11/15] update unittests Signed-off-by: Jason --- nemo/backends/pytorch/actions.py | 5 +- tests/asr/test_asr.py | 14 +- tests/asr/test_zeroDS.py | 14 +- tests/core/test_deploy_export.py | 7 +- tests/core/test_weight_share.py | 446 ++++++++++++++++--------------- 5 files changed, 247 insertions(+), 239 deletions(-) diff --git a/nemo/backends/pytorch/actions.py b/nemo/backends/pytorch/actions.py index faad7a28436f..95ed41e553cf 100644 --- a/nemo/backends/pytorch/actions.py +++ b/nemo/backends/pytorch/actions.py @@ -934,9 +934,8 @@ def __extract_dynamic_axes(port_name: str, ntype: NeuralType, dynamic_axes: defa outputs_to_drop = set() if type(module).__name__ == "JasperEncoder": logging.info( - f"Module is JasperEncoder. We are removing" - f"input and output length ports since they " - f"are not needed for deployment" + "Module is JasperEncoder. We are removing input and output length ports since they are not needed for " + "deployment" ) inputs_to_drop.add("length") outputs_to_drop.add("encoded_lengths") diff --git a/tests/asr/test_asr.py b/tests/asr/test_asr.py index 86bec0de6d63..38bd05826ee8 100644 --- a/tests/asr/test_asr.py +++ b/tests/asr/test_asr.py @@ -93,13 +93,13 @@ def setUpClass(cls) -> None: else: logging.info("ASR data found in: {0}".format(os.path.join(data_folder, "asr"))) - @classmethod - def tearDownClass(cls) -> None: - super().tearDownClass() - data_folder = os.path.abspath(os.path.join(os.path.dirname(__file__), "../data/")) - logging.info("Looking up for test ASR data") - if os.path.exists(os.path.join(data_folder, "asr")): - shutil.rmtree(os.path.join(data_folder, "asr")) + # @classmethod + # def tearDownClass(cls) -> None: + # super().tearDownClass() + # data_folder = os.path.abspath(os.path.join(os.path.dirname(__file__), "../data/")) + # logging.info("Looking up for test ASR data") + # if os.path.exists(os.path.join(data_folder, "asr")): + # shutil.rmtree(os.path.join(data_folder, "asr")) def test_transcript_normalizers(self): # Create test json diff --git a/tests/asr/test_zeroDS.py b/tests/asr/test_zeroDS.py index a413e1f2e514..2a7b05e14b55 100644 --- a/tests/asr/test_zeroDS.py +++ b/tests/asr/test_zeroDS.py @@ -78,13 +78,13 @@ def setUpClass(cls) -> None: else: logging.info("ASR data found in: {0}".format(os.path.join(data_folder, "asr"))) - @classmethod - def tearDownClass(cls) -> None: - super().tearDownClass() - data_folder = os.path.abspath(os.path.join(os.path.dirname(__file__), "../data/")) - logging.info("Looking up for test ASR data") - if os.path.exists(os.path.join(data_folder, "asr")): - shutil.rmtree(os.path.join(data_folder, "asr")) + # @classmethod + # def tearDownClass(cls) -> None: + # super().tearDownClass() + # data_folder = os.path.abspath(os.path.join(os.path.dirname(__file__), "../data/")) + # logging.info("Looking up for test ASR data") + # if os.path.exists(os.path.join(data_folder, "asr")): + # shutil.rmtree(os.path.join(data_folder, "asr")) def test_asr_with_zero_ds(self): logging.info("Testing ASR NMs with ZeroDS and without 
pre-processing") diff --git a/tests/core/test_deploy_export.py b/tests/core/test_deploy_export.py index 7c7959e004c8..303209084802 100644 --- a/tests/core/test_deploy_export.py +++ b/tests/core/test_deploy_export.py @@ -33,10 +33,11 @@ import nemo import nemo.collections.asr as nemo_asr -import nemo.collections.nlp as nemo_nlp import nemo.collections.nlp.nm.trainables.common.token_classification_nm from tests.common_setup import NeMoUnitTest +logging = nemo.logging + class TestDeployExport(NeMoUnitTest): def setUp(self): @@ -82,7 +83,7 @@ def __test_export_route(self, module, out_name, mode, input_example=None): inputs[input_name] = ( input_example[i].cpu().numpy() if isinstance(input_example, tuple) else input_example.cpu().numpy() ) - print('Execution Providers: ', ort_session.get_providers()) + logging.info('Execution Providers: ', ort_session.get_providers()) outputs_scr = ort_session.run(None, inputs) outputs_scr = torch.from_numpy(outputs_scr[0]).cuda() elif mode == nemo.core.DeploymentFormat.TORCHSCRIPT: @@ -106,6 +107,8 @@ def __test_export_route(self, module, out_name, mode, input_example=None): if out.exists(): os.remove(out) + if mode == nemo.core.DeploymentFormat.PYTORCH and out.with_suffix(out.suffix + ".json").exists(): + os.remove(out.with_suffix(out.suffix + ".json")) def __test_export_route_all(self, module, out_name, input_example=None): if input_example is not None: diff --git a/tests/core/test_weight_share.py b/tests/core/test_weight_share.py index 92f82ce18061..f4468bf596fb 100644 --- a/tests/core/test_weight_share.py +++ b/tests/core/test_weight_share.py @@ -1,220 +1,226 @@ -# # ! /usr/bin/python -# # -*- coding: utf-8 -*- -# -# # Copyright 2019 NVIDIA. All Rights Reserved. -# # -# # Licensed under the Apache License, Version 2.0 (the "License"); -# # you may not use this file except in compliance with the License. -# # You may obtain a copy of the License at -# # -# # http://www.apache.org/licenses/LICENSE-2.0 -# # -# # Unless required by applicable law or agreed to in writing, software -# # distributed under the License is distributed on an "AS IS" BASIS, -# # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# # See the License for the specific language governing permissions and -# # limitations under the License. 
-# # ============================================================================= -# -# import os -# import shutil -# import tarfile -# import unittest -# from typing import Dict -# -# import numpy as np -# import torch -# from ruamel.yaml import YAML -# -# import nemo -# import nemo.collections.asr as nemo_asr -# from nemo.core import WeightShareTransform -# from nemo.core.neural_types import * -# from tests.common_setup import NeMoUnitTest -# -# logging = nemo.logging -# -# -# class TestWeightSharing(NeMoUnitTest): -# labels = [ -# "'", -# "a", -# "b", -# "c", -# "d", -# "e", -# "f", -# "g", -# "h", -# "i", -# "j", -# "k", -# "l", -# "m", -# "n", -# "o", -# "p", -# "q", -# "r", -# "s", -# "t", -# "u", -# "v", -# "w", -# "x", -# "y", -# "z", -# " ", -# ] -# manifest_filepath = os.path.abspath(os.path.join(os.path.dirname(__file__), "../data/asr/an4_train.json")) -# featurizer_config = { -# 'window': 'hann', -# 'dither': 1e-05, -# 'normalize': 'per_feature', -# 'frame_splicing': 1, -# 'int_values': False, -# 'window_stride': 0.01, -# 'sample_rate': 16000, -# 'features': 64, -# 'n_fft': 512, -# 'window_size': 0.02, -# } -# yaml = YAML(typ="safe") -# -# @classmethod -# def setUpClass(cls) -> None: -# super().setUpClass() -# data_folder = os.path.abspath(os.path.join(os.path.dirname(__file__), "../data/")) -# logging.info("Looking up for test ASR data") -# if not os.path.exists(os.path.join(data_folder, "asr")): -# logging.info("Extracting ASR data to: {0}".format(os.path.join(data_folder, "asr"))) -# tar = tarfile.open(os.path.join(data_folder, "asr.tar.gz"), "r:gz") -# tar.extractall(path=data_folder) -# tar.close() -# else: -# logging.info("ASR data found in: {0}".format(os.path.join(data_folder, "asr"))) -# -# @classmethod -# def tearDownClass(cls) -> None: -# super().tearDownClass() -# data_folder = os.path.abspath(os.path.join(os.path.dirname(__file__), "../data/")) -# logging.info("Looking up for test ASR data") -# if os.path.exists(os.path.join(data_folder, "asr")): -# shutil.rmtree(os.path.join(data_folder, "asr")) -# -# def __check_if_weights_are_equal(self, w1: Dict, w2: Dict): -# all_same = set(w1.keys()) == set(w2.keys()) -# if not all_same: -# return False -# else: -# for key in w1.keys(): -# all_same = all_same and np.array_equal( -# w1[key][0].cpu().detach().numpy(), w2[key][0].cpu().detach().numpy(), -# ) -# return all_same -# -# def test_TaylorNet_get_weights(self): -# tn1 = nemo.backends.pytorch.tutorials.TaylorNet(dim=4) -# tn2 = nemo.backends.pytorch.tutorials.TaylorNet(dim=4) -# # because of randomness, actual weights should be different -# self.assertFalse(self.__check_if_weights_are_equal(tn1.get_weights(), tn2.get_weights())) -# tn3 = nemo.backends.pytorch.tutorials.TaylorNet(dim=4) -# tn3.set_weights(tn1.get_weights()) -# # check than weights are the same -# self.assertTrue(self.__check_if_weights_are_equal(tn1.get_weights(), tn3.get_weights())) -# # change weights on one module - another module should not change -# tn1.fc1.bias.data = torch.tensor([0.1]) -# self.assertFalse(self.__check_if_weights_are_equal(tn1.get_weights(), tn3.get_weights())) -# -# def test_TaylorNet_tie_weights(self): -# tn1 = nemo.backends.pytorch.tutorials.TaylorNet(dim=4) -# tn2 = nemo.backends.pytorch.tutorials.TaylorNet(dim=4) -# # because of randomness, actual weights should be different -# self.assertFalse(self.__check_if_weights_are_equal(tn1.get_weights(), tn2.get_weights())) -# tn2.tie_weights_with(tn1, list(tn1.get_weights().keys())) -# # change weights on one module - another module 
should change too -# tn1.fc1.bias.data = torch.tensor([0.1]) -# self.assertTrue(self.__check_if_weights_are_equal(tn1.get_weights(), tn2.get_weights())) -# -# def test_tie_weights2(self): -# voc_size = 3 -# dim = 2 -# embd = nemo.backends.pytorch.common.SequenceEmbedding(voc_size=voc_size, hidden_size=dim) -# proj = nemo.backends.pytorch.common.SequenceProjection(from_dim=dim, to_dim=voc_size) -# embd.tie_weights_with( -# proj, -# weight_names=["embedding.weight"], -# name2name_and_transform={"embedding.weight": ("projection.weight", WeightShareTransform.SAME,)}, -# ) -# self.assertTrue( -# np.array_equal(embd.embedding.weight.detach().numpy(), proj.projection.weight.detach().numpy(),) -# ) -# was = embd.embedding.weight.detach().numpy() -# embd.embedding.weight.data = torch.tensor(np.random.randint(0, 10, (3, 2)) * 1.0) -# after = embd.embedding.weight.detach().numpy() -# self.assertTrue( -# np.array_equal(embd.embedding.weight.detach().numpy(), proj.projection.weight.detach().numpy(),) -# ) -# self.assertFalse(np.array_equal(was, after)) -# -# def test_set_weights(self): -# voc_size = 3 -# dim = 2 -# embd = nemo.backends.pytorch.common.SequenceEmbedding(voc_size=voc_size, hidden_size=dim) -# weights = torch.tensor(np.random.randint(0, 10, (3, 2)) * 1.0) -# name2weights = {"embedding.weight": (weights, True)} -# embd.set_weights(name2weight=name2weights) -# self.assertTrue(np.array_equal(embd.embedding.weight.detach().numpy(), weights.detach().numpy(),)) -# weights = torch.tensor(np.random.randint(0, 10, (3, 2)) * 1.0) -# self.assertFalse(np.array_equal(embd.embedding.weight.detach().numpy(), weights.detach().numpy(),)) -# -# def test_freeze_unfreeze_TrainableNM(self): -# path = os.path.abspath(os.path.join(os.path.dirname(__file__), "../data/jasper_smaller.yaml")) -# with open(path) as file: -# jasper_model_definition = self.yaml.load(file) -# dl = nemo_asr.AudioToTextDataLayer( -# # featurizer_config=self.featurizer_config, -# manifest_filepath=self.manifest_filepath, -# labels=self.labels, -# batch_size=4, -# ) -# pre_process_params = { -# #'int_values': False, -# 'frame_splicing': 1, -# 'features': 64, -# 'window_size': 0.02, -# 'n_fft': 512, -# 'dither': 1e-05, -# 'window': 'hann', -# 'sample_rate': 16000, -# 'normalize': 'per_feature', -# 'window_stride': 0.01, -# } -# preprocessing = nemo_asr.AudioToMelSpectrogramPreprocessor(**pre_process_params) -# jasper_encoder = nemo_asr.JasperEncoder( -# feat_in=jasper_model_definition['AudioToMelSpectrogramPreprocessor']['features'], -# **jasper_model_definition['JasperEncoder'], -# ) -# jasper_decoder = nemo_asr.JasperDecoderForCTC(feat_in=1024, num_classes=len(self.labels)) -# ctc_loss = nemo_asr.CTCLossNM(num_classes=len(self.labels)) -# jasper_encoder.freeze() -# jasper_encoder.unfreeze(set(['encoder.4.conv.1.weight'])) -# jasper_decoder.unfreeze() -# # DAG -# audio_signal, a_sig_length, transcript, transcript_len = dl() -# processed_signal, p_length = preprocessing(input_signal=audio_signal, length=a_sig_length) -# -# encoded, encoded_len = jasper_encoder(audio_signal=processed_signal, length=p_length) -# # logging.info(jasper_encoder) -# log_probs = jasper_decoder(encoder_output=encoded) -# loss = ctc_loss( -# log_probs=log_probs, targets=transcript, input_length=encoded_len, target_length=transcript_len, -# ) -# -# callback = nemo.core.SimpleLossLoggerCallback( -# tensors=[loss], print_func=lambda x: logging.info(f'Train Loss: {str(x[0].item())}'), -# ) -# optimizer = self.nf.get_trainer() -# optimizer.train( -# [loss], 
callbacks=[callback], optimizer="sgd", optimization_params={"num_epochs": 2, "lr": 0.0003}, -# ) +# ! /usr/bin/python +# -*- coding: utf-8 -*- + +# Copyright 2019 NVIDIA. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================= + +import os +import shutil +import tarfile +from typing import Dict + +import numpy as np +import torch +from ruamel.yaml import YAML + +import nemo +import nemo.collections.asr as nemo_asr +from nemo.core import WeightShareTransform +from nemo.core.neural_types import * +from tests.common_setup import NeMoUnitTest + +logging = nemo.logging + + +class TestWeightSharing(NeMoUnitTest): + labels = [ + "'", + "a", + "b", + "c", + "d", + "e", + "f", + "g", + "h", + "i", + "j", + "k", + "l", + "m", + "n", + "o", + "p", + "q", + "r", + "s", + "t", + "u", + "v", + "w", + "x", + "y", + "z", + " ", + ] + manifest_filepath = os.path.abspath(os.path.join(os.path.dirname(__file__), "../data/asr/an4_train.json")) + featurizer_config = { + 'window': 'hann', + 'dither': 1e-05, + 'normalize': 'per_feature', + 'frame_splicing': 1, + 'int_values': False, + 'window_stride': 0.01, + 'sample_rate': 16000, + 'features': 64, + 'n_fft': 512, + 'window_size': 0.02, + } + yaml = YAML(typ="safe") + + @classmethod + def setUpClass(cls) -> None: + super().setUpClass() + data_folder = os.path.abspath(os.path.join(os.path.dirname(__file__), "../data/")) + logging.info("Looking up for test ASR data") + if not os.path.exists(os.path.join(data_folder, "asr")): + logging.info("Extracting ASR data to: {0}".format(os.path.join(data_folder, "asr"))) + tar = tarfile.open(os.path.join(data_folder, "asr.tar.gz"), "r:gz") + tar.extractall(path=data_folder) + tar.close() + else: + logging.info("ASR data found in: {0}".format(os.path.join(data_folder, "asr"))) + + # @classmethod + # def tearDownClass(cls) -> None: + # super().tearDownClass() + # data_folder = os.path.abspath(os.path.join(os.path.dirname(__file__), "../data/")) + # logging.info("Looking up for test ASR data") + # if os.path.exists(os.path.join(data_folder, "asr")): + # logging.info("Removing test ASR data") + # shutil.rmtree(os.path.join(data_folder, "asr")) + + def __check_if_weights_are_equal(self, w1: Dict, w2: Dict): + all_same = set(w1.keys()) == set(w2.keys()) + if not all_same: + return False + else: + for key in w1.keys(): + all_same = all_same and np.array_equal( + w1[key][0].cpu().detach().numpy(), w2[key][0].cpu().detach().numpy(), + ) + return all_same + + def test_TaylorNet_get_weights(self): + tn1 = nemo.backends.pytorch.tutorials.TaylorNet(dim=4) + tn2 = nemo.backends.pytorch.tutorials.TaylorNet(dim=4) + # because of randomness, actual weights should be different + self.assertFalse(self.__check_if_weights_are_equal(tn1.get_weights(), tn2.get_weights())) + tn3 = nemo.backends.pytorch.tutorials.TaylorNet(dim=4) + tn3.set_weights(tn1.get_weights()) + # check than weights are the same + 
self.assertTrue(self.__check_if_weights_are_equal(tn1.get_weights(), tn3.get_weights())) + # change weights on one module - another module should not change + tn1.fc1.bias.data = torch.tensor([0.1]) + self.assertFalse(self.__check_if_weights_are_equal(tn1.get_weights(), tn3.get_weights())) + + # def test_TaylorNet_tie_weights(self): + # tn1 = nemo.backends.pytorch.tutorials.TaylorNet(dim=4) + # tn2 = nemo.backends.pytorch.tutorials.TaylorNet(dim=4) + # # because of randomness, actual weights should be different + # self.assertFalse(self.__check_if_weights_are_equal(tn1.get_weights(), tn2.get_weights())) + # tn2.tie_weights_with(tn1, list(tn1.get_weights().keys())) + # # change weights on one module - another module should change too + # tn2.fc1.bias.data = torch.tensor([0.1]) + # self.assertTrue(self.__check_if_weights_are_equal(tn1.get_weights(), tn2.get_weights())) + + # def test_tie_weights2(self): + # voc_size = 3 + # dim = 2 + # embd = nemo.backends.pytorch.common.other.SequenceEmbedding(voc_size=voc_size, hidden_size=dim) + # proj = nemo.backends.pytorch.common.other.SequenceProjection(from_dim=dim, to_dim=voc_size) + # embd.tie_weights_with( + # proj, + # weight_names=["embedding.weight"], + # name2name_and_transform={"embedding.weight": ("projection.weight", WeightShareTransform.SAME,)}, + # ) + # self.assertTrue( + # np.array_equal(embd.embedding.weight.detach().numpy(), proj.projection.weight.detach().numpy(),) + # ) + # was = embd.embedding.weight.detach().numpy() + # embd.embedding.weight.data = torch.tensor(np.random.randint(0, 10, (3, 2)) * 1.0) + # after = embd.embedding.weight.detach().numpy() + # self.assertTrue( + # np.array_equal(embd.embedding.weight.detach().numpy(), proj.projection.weight.detach().numpy(),) + # ) + # self.assertFalse(np.array_equal(was, after)) + + def test_set_weights(self): + voc_size = 3 + dim = 2 + embd = nemo.backends.pytorch.common.SequenceEmbedding(voc_size=voc_size, hidden_size=dim) + weights = torch.tensor(np.random.randint(0, 10, (3, 2)) * 1.0) + name2weights = {"embedding.weight": (weights, True)} + embd.set_weights(name2weight=name2weights) + self.assertTrue(np.array_equal(embd.embedding.weight.detach().numpy(), weights.detach().numpy(),)) + weights = torch.tensor(np.random.randint(0, 10, (3, 2)) * 1.0) + self.assertFalse(np.array_equal(embd.embedding.weight.detach().numpy(), weights.detach().numpy(),)) + + def test_freeze_unfreeze_TrainableNM(self): + path = os.path.abspath(os.path.join(os.path.dirname(__file__), "../data/jasper_smaller.yaml")) + with open(path) as file: + jasper_model_definition = self.yaml.load(file) + dl = nemo_asr.AudioToTextDataLayer( + # featurizer_config=self.featurizer_config, + manifest_filepath=self.manifest_filepath, + labels=self.labels, + batch_size=4, + ) + pre_process_params = { + #'int_values': False, + 'frame_splicing': 1, + 'features': 64, + 'window_size': 0.02, + 'n_fft': 512, + 'dither': 1e-05, + 'window': 'hann', + 'sample_rate': 16000, + 'normalize': 'per_feature', + 'window_stride': 0.01, + } + preprocessing = nemo_asr.AudioToMelSpectrogramPreprocessor(**pre_process_params) + jasper_encoder = nemo_asr.JasperEncoder( + feat_in=jasper_model_definition['AudioToMelSpectrogramPreprocessor']['features'], + **jasper_model_definition['JasperEncoder'], + ) + jasper_decoder = nemo_asr.JasperDecoderForCTC(feat_in=1024, num_classes=len(self.labels)) + ctc_loss = nemo_asr.CTCLossNM(num_classes=len(self.labels)) + jasper_encoder.freeze() + jasper_encoder.unfreeze(set(['encoder.4.mconv.0.conv.weight'])) + 
frozen_weight = jasper_encoder.encoder[1].mconv[0].conv.weight.detach().cpu().numpy() + unfrozen_weight = jasper_encoder.encoder[4].mconv[0].conv.weight.detach().cpu().numpy() + # jasper_decoder.unfreeze() + # DAG + audio_signal, a_sig_length, transcript, transcript_len = dl() + processed_signal, p_length = preprocessing(input_signal=audio_signal, length=a_sig_length) + + encoded, encoded_len = jasper_encoder(audio_signal=processed_signal, length=p_length) + # logging.info(jasper_encoder) + log_probs = jasper_decoder(encoder_output=encoded) + loss = ctc_loss( + log_probs=log_probs, targets=transcript, input_length=encoded_len, target_length=transcript_len, + ) + + callback = nemo.core.SimpleLossLoggerCallback( + tensors=[loss], print_func=lambda x: logging.info(f'Train Loss: {str(x[0].item())}'), + ) + optimizer = self.nf.get_trainer() + optimizer.train( + [loss], callbacks=[callback], optimizer="sgd", optimization_params={"max_steps": 5, "lr": 0.0003}, + ) + new_frozen_weight = jasper_encoder.encoder[1].mconv[0].conv.weight.data + new_unfrozen_weight = jasper_encoder.encoder[4].mconv[0].conv.weight.data + self.assertTrue(np.array_equal(frozen_weight, new_frozen_weight.detach().cpu().numpy())) + self.assertFalse(np.array_equal(unfrozen_weight, new_unfrozen_weight.detach().cpu().numpy())) From a71da025ce9a33bfb2c07d5e6ae99b9a1dc1fdee Mon Sep 17 00:00:00 2001 From: Jason Date: Thu, 13 Feb 2020 14:39:13 -0800 Subject: [PATCH 12/15] make VoidType always return SAME; update tie weights test Signed-off-by: Jason --- nemo/backends/pytorch/common/other.py | 5 +- nemo/core/neural_types/neural_type.py | 12 +++- tests/core/test_weight_share.py | 86 ++++++++++++++++++++------- 3 files changed, 79 insertions(+), 24 deletions(-) diff --git a/nemo/backends/pytorch/common/other.py b/nemo/backends/pytorch/common/other.py index 42fe4ec5e0d0..e86957d22f7d 100644 --- a/nemo/backends/pytorch/common/other.py +++ b/nemo/backends/pytorch/common/other.py @@ -11,7 +11,6 @@ import torch.nn as nn from nemo.backends.pytorch.nm import NonTrainableNM, TrainableNM -from nemo.core import NeuralModule from nemo.core.neural_types import * @@ -21,14 +20,14 @@ def input_ports(self): """Returns definitions of module input ports. """ # return {"input_seq": NeuralType({0: AxisType(TimeTag), 1: AxisType(BatchTag)})} - return {"input_seq": NeuralModule(ChannelType(), ('T', 'B'))} + return {"input_seq": NeuralType(('B', 'T'))} @property def output_ports(self): """Returns definitions of module output ports. """ # return {"outputs": NeuralType({0: AxisType(TimeTag), 1: AxisType(BatchTag), 2: AxisType(ChannelTag),})} - return {"outputs": NeuralType(('T', 'B', 'D'), ChannelType())} + return {"outputs": NeuralType(('B', 'T', 'C'))} def __init__(self, voc_size, hidden_size, dropout=0.0): super().__init__() diff --git a/nemo/core/neural_types/neural_type.py b/nemo/core/neural_types/neural_type.py index a2070c354b3c..a7ad02577adc 100644 --- a/nemo/core/neural_types/neural_type.py +++ b/nemo/core/neural_types/neural_type.py @@ -21,7 +21,6 @@ 'NeuralTypeError', 'NeuralPortNameMismatchError', 'NeuralPortNmTensorMismatchError', - 'NeuralPortNmTensorMismatchError', 'CanNotInferResultNeuralType', ] import uuid @@ -46,6 +45,15 @@ class NeuralType(object): type can be optional. 
""" + def __str__(self): + return ( + f"axes: {[(c.kind, c.size, c.is_list) for c in self.axes]}\n" + f"elements_type: {self.elements_type.__class__.__name__}" + ) + # return f"axes: {self.axes}" # " elements_type: {self.elements_type}" + # return f" elements_type: {self.elements_type.__class__.__name__}" + # return "help" + def __init__(self, axes: Optional[Tuple] = None, elements_type: ElementType = VoidType(), optional=False): if not isinstance(elements_type, ElementType): raise ValueError( @@ -87,6 +95,8 @@ def compare(self, second) -> NeuralTypeComparisonResult: dimensions_pass = NeuralType.__compare_axes(axes_a, axes_b) element_comparison_result = self.elements_type.compare(second.elements_type) + if isinstance(second.elements_type, VoidType): + element_comparison_result = NeuralTypeComparisonResult.SAME # SAME DIMS if dimensions_pass == 0: diff --git a/tests/core/test_weight_share.py b/tests/core/test_weight_share.py index f4468bf596fb..3fd9fcd1b9d1 100644 --- a/tests/core/test_weight_share.py +++ b/tests/core/test_weight_share.py @@ -27,9 +27,12 @@ import nemo import nemo.collections.asr as nemo_asr +from nemo.collections.nlp.nm.trainables.common import TokenClassifier +from nemo.collections.nlp.nm.losses import PaddedSmoothedCrossEntropyLossNM from nemo.core import WeightShareTransform from nemo.core.neural_types import * from tests.common_setup import NeMoUnitTest +from nemo.backends.pytorch.nm import DataLayerNM logging = nemo.logging @@ -136,26 +139,69 @@ def test_TaylorNet_get_weights(self): # tn2.fc1.bias.data = torch.tensor([0.1]) # self.assertTrue(self.__check_if_weights_are_equal(tn1.get_weights(), tn2.get_weights())) - # def test_tie_weights2(self): - # voc_size = 3 - # dim = 2 - # embd = nemo.backends.pytorch.common.other.SequenceEmbedding(voc_size=voc_size, hidden_size=dim) - # proj = nemo.backends.pytorch.common.other.SequenceProjection(from_dim=dim, to_dim=voc_size) - # embd.tie_weights_with( - # proj, - # weight_names=["embedding.weight"], - # name2name_and_transform={"embedding.weight": ("projection.weight", WeightShareTransform.SAME,)}, - # ) - # self.assertTrue( - # np.array_equal(embd.embedding.weight.detach().numpy(), proj.projection.weight.detach().numpy(),) - # ) - # was = embd.embedding.weight.detach().numpy() - # embd.embedding.weight.data = torch.tensor(np.random.randint(0, 10, (3, 2)) * 1.0) - # after = embd.embedding.weight.detach().numpy() - # self.assertTrue( - # np.array_equal(embd.embedding.weight.detach().numpy(), proj.projection.weight.detach().numpy(),) - # ) - # self.assertFalse(np.array_equal(was, after)) + def test_tie_weights(self): + class DummyDataLayer(DataLayerNM): + def __init__(self, vocab_size): + super().__init__() + self.vocab_size = vocab_size + + class DummyDS(torch.utils.data.Dataset): + def __init__(self, vocab_size): + super().__init__() + + def __getitem__(self, index): + model_inputs = torch.randint(high=vocab_size, size=[10]) + model_outputs = torch.randint(high=vocab_size, size=[10]) + return (model_inputs, model_outputs) + + def __len__(self): + return 10 + + self._dataset = DummyDS(vocab_size) + + @property + def output_ports(self): + return { + "model_inputs": NeuralType(('B', 'T')), + "model_outputs": NeuralType(('B', 'T')), + } + + def __len__(self): + return len(self._dataset) + + @property + def dataset(self): + return self._dataset + + def data_iterator(self): + pass + + voc_size = 10 + dim = 10 + embd = nemo.backends.pytorch.common.other.SequenceEmbedding(voc_size=voc_size, hidden_size=dim) + proj = 
TokenClassifier(hidden_size=dim, num_classes=voc_size) + data = DummyDataLayer(voc_size) + loss = PaddedSmoothedCrossEntropyLossNM(0) + embd.tie_weights_with( + proj, + weight_names=["embedding.weight"], + name2name_and_transform={"embedding.weight": ("mlp.layer2.weight", WeightShareTransform.SAME)}, + ) + self.assertTrue( + np.array_equal(embd.embedding.weight.detach().cpu().numpy(), proj.mlp.layer2.weight.detach().cpu().numpy()) + ) + _in, _out = data() + pred = embd(input_seq=_in) + pred = proj(hidden_states=pred) + loss_t = loss(target_ids=_in, logits=pred) + + self.nf.train( + [loss_t], optimizer="sgd", optimization_params={"max_steps": 5, "lr": 0.0003}, + ) + + self.assertTrue( + np.array_equal(embd.embedding.weight.detach().cpu().numpy(), proj.mlp.layer2.weight.detach().cpu().numpy()) + ) def test_set_weights(self): voc_size = 3 From 74ad50c585bc962356168c5401c9c14888f36ad3 Mon Sep 17 00:00:00 2001 From: Jason Date: Thu, 13 Feb 2020 14:52:13 -0800 Subject: [PATCH 13/15] fix nlp examples; do right typing Signed-off-by: Jason --- .../nlp/language_modeling/bert_pretraining.py | 14 ++++--- .../language_modeling_transformer.py | 14 ++++--- .../machine_translation_tutorial.py | 3 -- nemo/backends/pytorch/common/other.py | 38 +------------------ nemo/core/neural_types/neural_type.py | 5 --- tests/core/test_weight_share.py | 4 +- 6 files changed, 20 insertions(+), 58 deletions(-) diff --git a/examples/nlp/language_modeling/bert_pretraining.py b/examples/nlp/language_modeling/bert_pretraining.py index d836080d841f..857556b89b07 100644 --- a/examples/nlp/language_modeling/bert_pretraining.py +++ b/examples/nlp/language_modeling/bert_pretraining.py @@ -13,11 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================= - - -# Todo: fix weight tying - - """ To pretrain BERT on raw text dataset run @@ -228,7 +223,14 @@ # tie weights of MLM softmax layer and embedding layer of the encoder if mlm_classifier.mlp.last_linear_layer.weight.shape != bert_model.bert.embeddings.word_embeddings.weight.shape: raise ValueError("Final classification layer does not match embedding " "layer.") -mlm_classifier.mlp.last_linear_layer.weight = bert_model.bert.embeddings.word_embeddings.weight +# mlm_classifier.mlp.last_linear_layer.weight = bert_model.bert.embeddings.word_embeddings.weight +mlm_classifier.tie_weights_with( + bert_model, + weight_names=["mlp.last_linear_layer.weight"], + name2name_and_transform={ + "mlp.last_linear_layer.weight": ("bert.embeddings.word_embeddings.weight", nemo_core.WeightShareTransform.SAME) + }, +) def create_pipeline(data_file, batch_size, preprocessed_data=False, batches_per_step=1, **kwargs): diff --git a/examples/nlp/language_modeling/language_modeling_transformer.py b/examples/nlp/language_modeling/language_modeling_transformer.py index 285487f0dfdb..185fac3cb8b5 100644 --- a/examples/nlp/language_modeling/language_modeling_transformer.py +++ b/examples/nlp/language_modeling/language_modeling_transformer.py @@ -13,13 +13,10 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# ============================================================================= - -# Todo: fix weight tying - - import math import nemo +from nemo.core import WeightShareTransform import nemo.collections.nlp as nemo_nlp import nemo.collections.nlp.nm.data_layers.lm_transformer_datalayer import nemo.collections.nlp.nm.trainables.common.token_classification_nm @@ -117,7 +114,14 @@ ) # tie weight of embedding and log_softmax layers -log_softmax.mlp.last_linear_layer.weight = encoder.embedding_layer.token_embedding.weight +# log_softmax.mlp.last_linear_layer.weight = encoder.embedding_layer.token_embedding.weight +log_softmax.tie_weights_with( + encoder, + weight_names=["mlp.layer0.weight"], + name2name_and_transform={ + "mlp.layer0.weight": ("embedding_layer.token_embedding.weight", WeightShareTransform.SAME) + }, +) def create_pipeline( diff --git a/examples/nlp/neural_machine_translation/machine_translation_tutorial.py b/examples/nlp/neural_machine_translation/machine_translation_tutorial.py index 997aa79ccd34..ae05afa88e32 100644 --- a/examples/nlp/neural_machine_translation/machine_translation_tutorial.py +++ b/examples/nlp/neural_machine_translation/machine_translation_tutorial.py @@ -19,9 +19,6 @@ https://nvidia.github.io/NeMo/nlp/ neural-machine-translation.html#translation-with-pretrained-model """ - -# Todo: fix weight tying - import torch import nemo diff --git a/nemo/backends/pytorch/common/other.py b/nemo/backends/pytorch/common/other.py index e86957d22f7d..c15f27b79c92 100644 --- a/nemo/backends/pytorch/common/other.py +++ b/nemo/backends/pytorch/common/other.py @@ -27,7 +27,7 @@ def output_ports(self): """Returns definitions of module output ports. """ # return {"outputs": NeuralType({0: AxisType(TimeTag), 1: AxisType(BatchTag), 2: AxisType(ChannelTag),})} - return {"outputs": NeuralType(('B', 'T', 'C'))} + return {"outputs": NeuralType(('B', 'T', 'C'), ChannelType())} def __init__(self, voc_size, hidden_size, dropout=0.0): super().__init__() @@ -46,42 +46,6 @@ def forward(self, input_seq): return embedded -class SequenceProjection(TrainableNM): - @property - def input_ports(self): - """Returns definitions of module input ports. - - input_seq: - Empty Type?!? - """ - return {"input_seq": NeuralType({})} - - @property - def output_ports(self): - """Returns definitions of module output ports. 
- - outputs: - None - """ - return {"outputs": None} - - def __init__(self, from_dim, to_dim, dropout=0.0): - super().__init__() - - self.from_dim = from_dim - self.to_dim = to_dim - self.dropout = dropout - self.projection = nn.Linear(self.from_dim, self.to_dim, bias=False) - if self.dropout != 0.0: - self.embedding_dropout = nn.Dropout(self.dropout) - - def forward(self, input_seq): - p = self.projection(input_seq) - if self.dropout != 0.0: - p = self.dropout(p) - return p - - class ZerosLikeNM(NonTrainableNM): @property def input_ports(self): diff --git a/nemo/core/neural_types/neural_type.py b/nemo/core/neural_types/neural_type.py index a7ad02577adc..80bda4aa01d9 100644 --- a/nemo/core/neural_types/neural_type.py +++ b/nemo/core/neural_types/neural_type.py @@ -50,9 +50,6 @@ def __str__(self): f"axes: {[(c.kind, c.size, c.is_list) for c in self.axes]}\n" f"elements_type: {self.elements_type.__class__.__name__}" ) - # return f"axes: {self.axes}" # " elements_type: {self.elements_type}" - # return f" elements_type: {self.elements_type.__class__.__name__}" - # return "help" def __init__(self, axes: Optional[Tuple] = None, elements_type: ElementType = VoidType(), optional=False): if not isinstance(elements_type, ElementType): @@ -95,8 +92,6 @@ def compare(self, second) -> NeuralTypeComparisonResult: dimensions_pass = NeuralType.__compare_axes(axes_a, axes_b) element_comparison_result = self.elements_type.compare(second.elements_type) - if isinstance(second.elements_type, VoidType): - element_comparison_result = NeuralTypeComparisonResult.SAME # SAME DIMS if dimensions_pass == 0: diff --git a/tests/core/test_weight_share.py b/tests/core/test_weight_share.py index 3fd9fcd1b9d1..6ade311ffb77 100644 --- a/tests/core/test_weight_share.py +++ b/tests/core/test_weight_share.py @@ -163,7 +163,7 @@ def __len__(self): def output_ports(self): return { "model_inputs": NeuralType(('B', 'T')), - "model_outputs": NeuralType(('B', 'T')), + "model_outputs": NeuralType(('B', 'T'), LabelsType()), } def __len__(self): @@ -193,7 +193,7 @@ def data_iterator(self): _in, _out = data() pred = embd(input_seq=_in) pred = proj(hidden_states=pred) - loss_t = loss(target_ids=_in, logits=pred) + loss_t = loss(target_ids=_out, logits=pred) self.nf.train( [loss_t], optimizer="sgd", optimization_params={"max_steps": 5, "lr": 0.0003}, From 390c4101a2dec8280756e3c32a0921799e6ab6d1 Mon Sep 17 00:00:00 2001 From: Jason Date: Thu, 13 Feb 2020 14:52:32 -0800 Subject: [PATCH 14/15] isort Signed-off-by: Jason --- .../nlp/language_modeling/language_modeling_transformer.py | 2 +- tests/core/test_weight_share.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/examples/nlp/language_modeling/language_modeling_transformer.py b/examples/nlp/language_modeling/language_modeling_transformer.py index 185fac3cb8b5..2572b90af785 100644 --- a/examples/nlp/language_modeling/language_modeling_transformer.py +++ b/examples/nlp/language_modeling/language_modeling_transformer.py @@ -16,12 +16,12 @@ import math import nemo -from nemo.core import WeightShareTransform import nemo.collections.nlp as nemo_nlp import nemo.collections.nlp.nm.data_layers.lm_transformer_datalayer import nemo.collections.nlp.nm.trainables.common.token_classification_nm from nemo.collections.nlp.callbacks.lm_transformer_callback import eval_epochs_done_callback, eval_iter_callback from nemo.collections.nlp.data.datasets.lm_transformer_dataset import LanguageModelDataDesc +from nemo.core import WeightShareTransform from nemo.utils.lr_policies import 
CosineAnnealing parser = nemo.utils.NemoArgParser(description='LM Transformer') diff --git a/tests/core/test_weight_share.py b/tests/core/test_weight_share.py index 6ade311ffb77..51ef90683735 100644 --- a/tests/core/test_weight_share.py +++ b/tests/core/test_weight_share.py @@ -27,12 +27,12 @@ import nemo import nemo.collections.asr as nemo_asr -from nemo.collections.nlp.nm.trainables.common import TokenClassifier +from nemo.backends.pytorch.nm import DataLayerNM from nemo.collections.nlp.nm.losses import PaddedSmoothedCrossEntropyLossNM +from nemo.collections.nlp.nm.trainables.common import TokenClassifier from nemo.core import WeightShareTransform from nemo.core.neural_types import * from tests.common_setup import NeMoUnitTest -from nemo.backends.pytorch.nm import DataLayerNM logging = nemo.logging From 8c26247f2d68a2861ff3a9edf7848974a443d233 Mon Sep 17 00:00:00 2001 From: Jason Date: Thu, 13 Feb 2020 15:47:23 -0800 Subject: [PATCH 15/15] small update Signed-off-by: Jason --- .../asr_postprocessor/asr_postprocessor.py | 4 +- nemo/backends/pytorch/common/other.py | 1 + nemo/backends/pytorch/module_wrapper.py | 2 +- tests/core/test_weight_share.py | 68 ++++++++++++++++++- 4 files changed, 70 insertions(+), 5 deletions(-) diff --git a/examples/nlp/asr_postprocessor/asr_postprocessor.py b/examples/nlp/asr_postprocessor/asr_postprocessor.py index f91638b02d74..187529ddd2e4 100644 --- a/examples/nlp/asr_postprocessor/asr_postprocessor.py +++ b/examples/nlp/asr_postprocessor/asr_postprocessor.py @@ -128,6 +128,8 @@ # tie all embeddings weights # t_log_softmax.mlp.layer0.weight = encoder.bert.embeddings.word_embeddings.weight +# decoder.embedding_layer.token_embedding.weight = encoder.bert.embeddings.word_embeddings.weight +# decoder.embedding_layer.position_embedding.weight = encoder.bert.embeddings.position_embeddings.weight t_log_softmax.tie_weights_with( encoder, weight_names=["mlp.layer0.weight"], @@ -135,7 +137,6 @@ "mlp.layer0.weight": ("bert.embeddings.word_embeddings.weight", WeightShareTransform.SAME) }, ) -# decoder.embedding_layer.token_embedding.weight = encoder.bert.embeddings.word_embeddings.weight decoder.tie_weights_with( encoder, weight_names=["embedding_layer.token_embedding.weight"], @@ -143,7 +144,6 @@ "embedding_layer.token_embedding.weight": ("bert.embeddings.word_embeddings.weight", WeightShareTransform.SAME) }, ) -# decoder.embedding_layer.position_embedding.weight = encoder.bert.embeddings.position_embeddings.weight decoder.tie_weights_with( encoder, weight_names=["embedding_layer.position_embedding.weight"], diff --git a/nemo/backends/pytorch/common/other.py b/nemo/backends/pytorch/common/other.py index c15f27b79c92..86cff8ff6221 100644 --- a/nemo/backends/pytorch/common/other.py +++ b/nemo/backends/pytorch/common/other.py @@ -38,6 +38,7 @@ def __init__(self, voc_size, hidden_size, dropout=0.0): self.embedding = nn.Embedding(self.voc_size, self.hidden_size) if self.dropout != 0.0: self.embedding_dropout = nn.Dropout(self.dropout) + self.to(self._device) def forward(self, input_seq): embedded = self.embedding(input_seq) diff --git a/nemo/backends/pytorch/module_wrapper.py b/nemo/backends/pytorch/module_wrapper.py index f439a847411d..1a5de7595d6a 100644 --- a/nemo/backends/pytorch/module_wrapper.py +++ b/nemo/backends/pytorch/module_wrapper.py @@ -86,7 +86,7 @@ def set_weights(self, name2weight, name2name_and_transform=None): def tie_weights_with(self, module, weight_names): for name in weight_names: - rsetattr(self._pt_module, name, rgetattr(module, name)) + 
rsetattr(self._pt_module, name, nn.Parameter(rgetattr(module, name))) @property def num_weights(self): diff --git a/tests/core/test_weight_share.py b/tests/core/test_weight_share.py index 51ef90683735..6317052ae77d 100644 --- a/tests/core/test_weight_share.py +++ b/tests/core/test_weight_share.py @@ -203,6 +203,70 @@ def data_iterator(self): np.array_equal(embd.embedding.weight.detach().cpu().numpy(), proj.mlp.layer2.weight.detach().cpu().numpy()) ) + def test_untied_weights(self): + class DummyDataLayer(DataLayerNM): + def __init__(self, vocab_size): + super().__init__() + self.vocab_size = vocab_size + + class DummyDS(torch.utils.data.Dataset): + def __init__(self, vocab_size): + super().__init__() + + def __getitem__(self, index): + model_inputs = torch.randint(high=vocab_size, size=[10]) + model_outputs = torch.randint(high=vocab_size, size=[10]) + return (model_inputs, model_outputs) + + def __len__(self): + return 10 + + self._dataset = DummyDS(vocab_size) + + @property + def output_ports(self): + return { + "model_inputs": NeuralType(('B', 'T')), + "model_outputs": NeuralType(('B', 'T'), LabelsType()), + } + + def __len__(self): + return len(self._dataset) + + @property + def dataset(self): + return self._dataset + + def data_iterator(self): + pass + + voc_size = 10 + dim = 10 + embd = nemo.backends.pytorch.common.other.SequenceEmbedding(voc_size=voc_size, hidden_size=dim) + proj = TokenClassifier(hidden_size=dim, num_classes=voc_size) + data = DummyDataLayer(voc_size) + loss = PaddedSmoothedCrossEntropyLossNM(0) + # embd.tie_weights_with( + # proj, + # weight_names=["embedding.weight"], + # name2name_and_transform={"embedding.weight": ("mlp.layer2.weight", WeightShareTransform.SAME)}, + # ) + self.assertFalse( + np.array_equal(embd.embedding.weight.detach().cpu().numpy(), proj.mlp.layer2.weight.detach().cpu().numpy()) + ) + _in, _out = data() + pred = embd(input_seq=_in) + pred = proj(hidden_states=pred) + loss_t = loss(target_ids=_out, logits=pred) + + self.nf.train( + [loss_t], optimizer="sgd", optimization_params={"max_steps": 5, "lr": 0.0003}, + ) + + self.assertFalse( + np.array_equal(embd.embedding.weight.detach().cpu().numpy(), proj.mlp.layer2.weight.detach().cpu().numpy()) + ) + def test_set_weights(self): voc_size = 3 dim = 2 @@ -210,9 +274,9 @@ def test_set_weights(self): weights = torch.tensor(np.random.randint(0, 10, (3, 2)) * 1.0) name2weights = {"embedding.weight": (weights, True)} embd.set_weights(name2weight=name2weights) - self.assertTrue(np.array_equal(embd.embedding.weight.detach().numpy(), weights.detach().numpy(),)) + self.assertTrue(np.array_equal(embd.embedding.weight.detach().cpu().numpy(), weights.detach().cpu().numpy())) weights = torch.tensor(np.random.randint(0, 10, (3, 2)) * 1.0) - self.assertFalse(np.array_equal(embd.embedding.weight.detach().numpy(), weights.detach().numpy(),)) + self.assertFalse(np.array_equal(embd.embedding.weight.detach().cpu().numpy(), weights.detach().cpu().numpy())) def test_freeze_unfreeze_TrainableNM(self): path = os.path.abspath(os.path.join(os.path.dirname(__file__), "../data/jasper_smaller.yaml"))
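
Note on the pattern these patches converge on: direct assignments such as log_softmax.mlp.last_linear_layer.weight = encoder.embedding_layer.token_embedding.weight are replaced by tie_weights_with(...) with an explicit WeightShareTransform, and the wrapped-module path in module_wrapper.py now re-wraps the shared tensor in nn.Parameter, presumably because a plain tensor cannot be assigned over an existing parameter attribute of an nn.Module. The snippet below is a minimal usage sketch, not part of the patches themselves; it assumes a NeMo 0.x environment with the NLP collection installed and reuses the module and parameter names that appear in tests/core/test_weight_share.py.

import numpy as np

import nemo
from nemo.collections.nlp.nm.trainables.common import TokenClassifier
from nemo.core import WeightShareTransform

# Modules are instantiated against a default neural factory.
nf = nemo.core.NeuralModuleFactory()

voc_size, dim = 10, 10
embd = nemo.backends.pytorch.common.other.SequenceEmbedding(voc_size=voc_size, hidden_size=dim)
proj = TokenClassifier(hidden_size=dim, num_classes=voc_size)

# Tie the embedding matrix to the classifier's last linear layer instead of
# assigning proj.mlp.layer2.weight = embd.embedding.weight directly.
embd.tie_weights_with(
    proj,
    weight_names=["embedding.weight"],
    name2name_and_transform={"embedding.weight": ("mlp.layer2.weight", WeightShareTransform.SAME)},
)

# Both modules now see identical values; test_tie_weights additionally checks
# that the weights remain equal after a few SGD steps through nf.train(...).
assert np.array_equal(
    embd.embedding.weight.detach().cpu().numpy(),
    proj.mlp.layer2.weight.detach().cpu().numpy(),
)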