Switch Apex with Pytorch #336

Merged: 19 commits, Feb 14, 2020
4 changes: 4 additions & 0 deletions CHANGELOG.md
@@ -94,6 +94,8 @@ To release a new version, please update the changelog as followed:
- Updated licenses
- Updated nemo's use of the logging library. `from nemo import logging` is now the recommended way of using the nemo logger. `neural_factory.logger` and all other instances of logger are now deprecated and planned for removal in the next version. Please see PR 267 for complete change information.
([PR #267](https://github.com/NVIDIA/NeMo/pull/267), [PR #283](https://github.com/NVIDIA/NeMo/pull/283), [PR #305](https://github.com/NVIDIA/NeMo/pull/305), [PR #311](https://github.com/NVIDIA/NeMo/pull/311)) - @blisc
- Changed Distributed Data Parallel from Apex to Torch
([PR #336](https://github.com/NVIDIA/NeMo/pull/336)) - @blisc

- Added TRADE (dialogue state tracking model) on MultiWOZ dataset
([PR #322](https://github.com/NVIDIA/NeMo/pull/322)) - @chiphuyen, @VahidooX
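For the "Changed Distributed Data Parallel from Apex to Torch" entry above: the heart of the change is in the nemo/backends/pytorch/actions.py diff further down, where apex.parallel.DistributedDataParallel is dropped in favor of torch.nn.parallel.DistributedDataParallel. A minimal sketch of the swap (the `model` and `local_rank` names are placeholders, not NeMo's):

```python
# Minimal sketch of the Apex -> PyTorch DDP swap; assumes
# torch.distributed.init_process_group() has already been called.
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP

def wrap_for_ddp(model: nn.Module, local_rank: int) -> nn.Module:
    # Before (Apex): from apex.parallel import DistributedDataParallel as DDP
    #                model = DDP(model, gradient_predivide_factor=gpf)
    # After (native PyTorch), mirroring the arguments used in this PR:
    return DDP(model, device_ids=[local_rank], broadcast_buffers=False, find_unused_parameters=True)
```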
@@ -108,6 +110,8 @@ To release a new version, please update the changelog as followed:
([PR #308](https://github.com/NVIDIA/NeMo/pull/309)) - @tkornuta-nvidia

### Removed
- gradient_predivide_factor arg of train() now has no effect
([PR #336](https://github.com/NVIDIA/NeMo/pull/336)) - @blisc
- Dropped support for the following ASR configs: jasper10x4.yaml, quartznet10x5.yaml, quartznet15x5_in.yaml, quartznet5x3.yaml, quartznet5x5.yaml, quartznet_an4.yaml. They have been moved to experimental/configs and can still be used with v0.9 to replicate paper results
([PR #354](https://github.com/NVIDIA/NeMo/pull/354)) - @blisc
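On the removed gradient_predivide_factor: Apex's pre-division only rearranged when the averaging division happened around the all-reduce (its usual motivation is staying in numerical range under mixed precision), so the end result matches the plain gradient average that torch's DDP computes. A toy check of that arithmetic, with plain tensors standing in for per-rank gradients:

```python
# Toy arithmetic only: apex-style predivide rescales before and after the
# all-reduce, but the net result is the plain average that torch DDP computes.
import torch

world_size, gpf = 8, 4
grads = [torch.randn(3) for _ in range(world_size)]  # stand-ins for per-rank grads

apex_style = sum(g / gpf for g in grads) * (gpf / world_size)  # divide, sum, rescale
torch_style = sum(grads) / world_size                          # plain average

assert torch.allclose(apex_style, torch_style)
```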

31 changes: 28 additions & 3 deletions examples/nlp/asr_postprocessor/asr_postprocessor.py
@@ -26,6 +26,7 @@
eval_epochs_done_callback_wer,
eval_iter_callback,
)
from nemo.core import WeightShareTransform
from nemo.core.callbacks import CheckpointCallback
from nemo.utils.lr_policies import SquareAnnealing

@@ -126,9 +127,33 @@
)

# tie all embeddings weights
t_log_softmax.mlp.layer0.weight = encoder.bert.embeddings.word_embeddings.weight
decoder.embedding_layer.token_embedding.weight = encoder.bert.embeddings.word_embeddings.weight
decoder.embedding_layer.position_embedding.weight = encoder.bert.embeddings.position_embeddings.weight
# t_log_softmax.mlp.layer0.weight = encoder.bert.embeddings.word_embeddings.weight
# decoder.embedding_layer.token_embedding.weight = encoder.bert.embeddings.word_embeddings.weight
# decoder.embedding_layer.position_embedding.weight = encoder.bert.embeddings.position_embeddings.weight
t_log_softmax.tie_weights_with(
encoder,
weight_names=["mlp.layer0.weight"],
name2name_and_transform={
"mlp.layer0.weight": ("bert.embeddings.word_embeddings.weight", WeightShareTransform.SAME)
},
)
decoder.tie_weights_with(
encoder,
weight_names=["embedding_layer.token_embedding.weight"],
name2name_and_transform={
"embedding_layer.token_embedding.weight": ("bert.embeddings.word_embeddings.weight", WeightShareTransform.SAME)
},
)
decoder.tie_weights_with(
encoder,
weight_names=["embedding_layer.position_embedding.weight"],
name2name_and_transform={
"embedding_layer.position_embedding.weight": (
"bert.embeddings.position_embeddings.weight",
WeightShareTransform.SAME,
)
},
)


def create_pipeline(dataset, tokens_in_batch, clean=False, training=True):
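This hunk, like the bert_pretraining.py and machine-translation hunks below, replaces direct parameter assignment with an explicit tie_weights_with call, so the sharing is declared by weight name rather than by reaching into module internals. As a rough illustration of the idea (not NeMo's implementation; `tie_weights` and its arguments are hypothetical names), such a helper just rebinds one module's parameter to another's so both share a single tensor:

```python
# Simplified illustration of name-based weight tying; not NeMo's actual
# tie_weights_with. After the call, both modules share one Parameter object.
from functools import reduce

import torch.nn as nn

def tie_weights(dst: nn.Module, src: nn.Module, dst_name: str, src_name: str) -> None:
    src_param = dict(src.named_parameters())[src_name]
    *path, attr = dst_name.split(".")       # "mlp.layer0.weight" -> ["mlp", "layer0"], "weight"
    submodule = reduce(getattr, path, dst)  # walk down to the owning submodule
    setattr(submodule, attr, src_param)     # rebind: dst now shares src's tensor

# e.g. tie a classifier head to an embedding table (shapes must match):
emb = nn.Embedding(100, 16)
head = nn.Linear(16, 100, bias=False)
tie_weights(head, emb, "weight", "weight")
assert head.weight is emb.weight
```

NeMo's real API additionally takes the name2name_and_transform mapping shown above, so weights with different names (here mlp.layer0.weight and bert.embeddings.word_embeddings.weight) can be tied under a declared transform such as WeightShareTransform.SAME.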
10 changes: 8 additions & 2 deletions examples/nlp/language_modeling/bert_pretraining.py
@@ -13,7 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================

"""

To pretrain BERT on raw text dataset run
@@ -224,7 +223,14 @@
# tie weights of MLM softmax layer and embedding layer of the encoder
if mlm_classifier.mlp.last_linear_layer.weight.shape != bert_model.bert.embeddings.word_embeddings.weight.shape:
raise ValueError("Final classification layer does not match embedding " "layer.")
mlm_classifier.mlp.last_linear_layer.weight = bert_model.bert.embeddings.word_embeddings.weight
# mlm_classifier.mlp.last_linear_layer.weight = bert_model.bert.embeddings.word_embeddings.weight
mlm_classifier.tie_weights_with(
bert_model,
weight_names=["mlp.last_linear_layer.weight"],
name2name_and_transform={
"mlp.last_linear_layer.weight": ("bert.embeddings.word_embeddings.weight", nemo_core.WeightShareTransform.SAME)
},
)


def create_pipeline(data_file, batch_size, preprocessed_data=False, batches_per_step=1, **kwargs):
11 changes: 9 additions & 2 deletions examples/nlp/language_modeling/language_modeling_transformer.py
@@ -13,7 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================

import math

import nemo
@@ -22,6 +21,7 @@
import nemo.collections.nlp.nm.trainables.common.token_classification_nm
from nemo.collections.nlp.callbacks.lm_transformer_callback import eval_epochs_done_callback, eval_iter_callback
from nemo.collections.nlp.data.datasets.lm_transformer_dataset import LanguageModelDataDesc
from nemo.core import WeightShareTransform
from nemo.utils.lr_policies import CosineAnnealing

parser = nemo.utils.NemoArgParser(description='LM Transformer')
@@ -114,7 +114,14 @@
)

# tie weight of embedding and log_softmax layers
log_softmax.mlp.last_linear_layer.weight = encoder.embedding_layer.token_embedding.weight
# log_softmax.mlp.last_linear_layer.weight = encoder.embedding_layer.token_embedding.weight
log_softmax.tie_weights_with(
encoder,
weight_names=["mlp.layer0.weight"],
name2name_and_transform={
"mlp.layer0.weight": ("embedding_layer.token_embedding.weight", WeightShareTransform.SAME)
},
)


def create_pipeline(
@@ -24,6 +24,7 @@
import nemo
import nemo.collections.nlp as nemo_nlp
from nemo.collections.nlp.callbacks.machine_translation_callback import eval_epochs_done_callback, eval_iter_callback
from nemo.core import WeightShareTransform
from nemo.utils.lr_policies import get_lr_policy

parser = nemo.utils.NemoArgParser(description='Transformer for Neural Machine Translation')
@@ -165,8 +166,25 @@
)

if tie_weight:
log_softmax.mlp.last_linear_layer.weight = encoder.embedding_layer.token_embedding.weight
decoder.embedding_layer.token_embedding.weight = encoder.embedding_layer.token_embedding.weight
# log_softmax.mlp.last_linear_layer.weight = encoder.embedding_layer.token_embedding.weight
log_softmax.tie_weights_with(
encoder,
weight_names=["mlp.last_linear_layer.weight"],
name2name_and_transform={
"mlp.last_linear_layer.weight": ("embedding_layer.token_embedding.weight", WeightShareTransform.SAME)
},
)
# decoder.embedding_layer.token_embedding.weight = encoder.embedding_layer.token_embedding.weight
decoder.tie_weights_with(
encoder,
weight_names=["embedding_layer.token_embedding.weight"],
name2name_and_transform={
"embedding_layer.token_embedding.weight": (
"embedding_layer.token_embedding.weight",
WeightShareTransform.SAME,
)
},
)


def create_pipeline(dataset_src, dataset_tgt, tokens_in_batch, clean=False, training=True):
128 changes: 86 additions & 42 deletions nemo/backends/pytorch/actions.py
@@ -5,13 +5,15 @@
import json
import os
from collections import defaultdict
from contextlib import ExitStack
from pathlib import Path
from typing import Dict, List, Optional
from typing import List, Optional

import torch
import torch.distributed as dist
import torch.nn as nn
import torch.optim as optim
from torch.nn.parallel import DistributedDataParallel as DDP

from nemo import logging
from nemo.backends.pytorch.module_wrapper import TrainableNeuralModuleWrapper
@@ -25,9 +27,8 @@

# these imports will happen on as-needed basis
amp = None
convert_syncbn = None
create_syncbn_process_group = None
DDP = None
# convert_syncbn = None
# create_syncbn_process_group = None
LARC = None
FusedLAMB = None
FusedAdam = None
@@ -59,18 +60,16 @@ def __init__(
global amp
amp = importlib.import_module('apex.amp')
if local_rank is not None:
global convert_syncbn
global create_syncbn_process_group
global DDP
# global convert_syncbn
# global create_syncbn_process_group
global LARC
global FusedLAMB
global FusedAdam
global FusedNovoGrad
parallel = importlib.import_module('apex.parallel')
apex_optimizer = importlib.import_module('apex.optimizers')
convert_syncbn = parallel.convert_syncbn_model
create_syncbn_process_group = parallel.create_syncbn_process_group
DDP = parallel.DistributedDataParallel
# convert_syncbn = parallel.convert_syncbn_model
# create_syncbn_process_group = parallel.create_syncbn_process_group
LARC = parallel.LARC
FusedLAMB = apex_optimizer.FusedLAMB
FusedAdam = apex_optimizer.FusedAdam
@@ -379,7 +378,7 @@ def __initialize_amp(
return optimizer

def __nm_graph_forward_pass(
self, call_chain, registered_tensors, mode=ModelMode.train, disable_allreduce=False, use_cache=False,
self, call_chain, registered_tensors, mode=ModelMode.train, use_cache=False,
):
for ind in range(1, len(call_chain)):
if use_cache:
@@ -399,12 +398,12 @@ def __nm_graph_forward_pass(
m_id = call_chain[ind][0].unique_instance_id
pmodule = self.module_reference_table[m_id][1]

if self._local_rank is not None:
if isinstance(pmodule, DDP):
if disable_allreduce:
pmodule.disable_allreduce()
else:
pmodule.enable_allreduce()
# if self._local_rank is not None:
# if isinstance(pmodule, DDP):
# if disable_allreduce:
# pmodule.disable_allreduce()
# else:
# pmodule.enable_allreduce()

if mode == ModelMode.train:
# if module.is_trainable():
@@ -935,9 +934,8 @@ def __extract_dynamic_axes(port_name: str, ntype: NeuralType, dynamic_axes: defa
outputs_to_drop = set()
if type(module).__name__ == "JasperEncoder":
logging.info(
f"Module is JasperEncoder. We are removing"
f"input and output length ports since they "
f"are not needed for deployment"
"Module is JasperEncoder. We are removing input and output length ports since they are not needed for "
"deployment"
)
inputs_to_drop.add("length")
outputs_to_drop.add("encoded_lengths")
@@ -1072,6 +1070,11 @@ def train(
gradient_predivide=False,
amp_max_loss_scale=2.0 ** 24,
):
if gradient_predivide:
logging.error(
"gradient_predivide is currently disabled, and is under consideration for removal in future versions. "
"If this functionality is needed, please raise a github issue."
)
if not optimization_params:
optimization_params = {}
num_epochs = optimization_params.get("num_epochs", None)
@@ -1213,23 +1216,44 @@ def train(
key = call_chain[i][0].unique_instance_id
pmodule = self.module_reference_table[key][1]
if not isinstance(pmodule, DDP) and isinstance(pmodule, torch.nn.Module):
gpf = 1
if gradient_predivide:
gpf = dist.get_world_size()
pmodule = DDP(pmodule, gradient_predivide_factor=gpf)

# Convert batchnorm modules to synced if applicable
if synced_batchnorm and isinstance(pmodule, torch.nn.Module):
world_size = dist.get_world_size()
if synced_batchnorm_groupsize > 0 and world_size % synced_batchnorm_groupsize != 0:
raise ValueError(
f"Synchronized batch norm group size"
f" ({synced_batchnorm_groupsize}) must be 0"
f" or divide total number of GPUs"
f" ({world_size})."
# gpf = 1
# if gradient_predivide:
# gpf = dist.get_world_size()
# pmodule = DDP(pmodule, gradient_predivide_factor=gpf) # Old Apex Method

# Per pytorch docs, convert sync bn prior to DDP
if synced_batchnorm:
world_size = dist.get_world_size()
sync_batchnorm_group = None
if synced_batchnorm_groupsize > 0:
if world_size % synced_batchnorm_groupsize != 0:
raise ValueError(
f"Synchronized batch norm group size ({synced_batchnorm_groupsize}) must be 0"
f" or divide total number of GPUs ({world_size})."
)
# torch.distributed.new_group expects an explicit list of ranks, and
# every rank must participate in every new_group call; keep the group
# this rank belongs to
for group_start in range(0, world_size, synced_batchnorm_groupsize):
    group_ranks = list(range(group_start, group_start + synced_batchnorm_groupsize))
    group = torch.distributed.new_group(ranks=group_ranks)
    if dist.get_rank() in group_ranks:
        sync_batchnorm_group = group
pmodule = nn.SyncBatchNorm.convert_sync_batchnorm(
pmodule, process_group=sync_batchnorm_group
)
process_group = create_syncbn_process_group(synced_batchnorm_groupsize)
pmodule = convert_syncbn(pmodule, process_group=process_group)

# By default, disable broadcast_buffers. This disables batch norm synchronization on forward
# pass
pmodule = DDP(
pmodule, device_ids=[self.local_rank], broadcast_buffers=False, find_unused_parameters=True
)

# # Convert batchnorm modules to synced if applicable
# if synced_batchnorm and isinstance(pmodule, torch.nn.Module):
# world_size = dist.get_world_size()
# if synced_batchnorm_groupsize > 0 and world_size % synced_batchnorm_groupsize != 0:
# raise ValueError(
# f"Synchronized batch norm group size"
# f" ({synced_batchnorm_groupsize}) must be 0"
# f" or divide total number of GPUs"
# f" ({world_size})."
# )
# process_group = create_syncbn_process_group(synced_batchnorm_groupsize)
# pmodule = convert_syncbn(pmodule, process_group=process_group)

self.module_reference_table[key] = (
self.module_reference_table[key][0],
@@ -1308,9 +1332,7 @@ def train(
}
disable_allreduce = batch_counter < (batches_per_step - 1)
self.__nm_graph_forward_pass(
call_chain=curr_call_chain,
registered_tensors=registered_tensors,
disable_allreduce=disable_allreduce,
call_chain=curr_call_chain, registered_tensors=registered_tensors,
)

curr_tensors_to_optimize = training_loop[self.step % len(training_loop)][1]
@@ -1331,19 +1353,31 @@ def train(
if nan:
continue
if self._optim_level in AmpOptimizations and self._optim_level != Optimization.mxprO0:
with amp.scale_loss(final_loss, curr_optimizer, delay_unscale=disable_allreduce,) as scaled_loss:
with amp.scale_loss(final_loss, curr_optimizer, delay_unscale=disable_allreduce) as scaled_loss:
if torch.isnan(scaled_loss).any() or torch.isinf(scaled_loss).any():
if stop_on_nan_loss:
raise ValueError('Loss is NaN or inf -' ' exiting')
logging.warning('WARNING: Loss is NaN or inf')
curr_optimizer.zero_grad()
continue
scaled_loss.backward(bps_scale.to(scaled_loss.get_device()))
if disable_allreduce:
with ExitStack() as stack:
for mod in self.get_DDP_modules(curr_call_chain):
stack.enter_context(mod.no_sync())
scaled_loss.backward(bps_scale.to(scaled_loss.get_device()))
else:
scaled_loss.backward(bps_scale.to(scaled_loss.get_device()))
# no AMP optimizations needed
else:
# multi-GPU, float32
if self._local_rank is not None:
final_loss.backward(bps_scale.to(final_loss.get_device()))
if disable_allreduce:
with ExitStack() as stack:
for mod in self.get_DDP_modules(curr_call_chain):
stack.enter_context(mod.no_sync())
final_loss.backward(bps_scale.to(final_loss.get_device()))
else:
final_loss.backward(bps_scale.to(final_loss.get_device()))
# single device (CPU or GPU)
else:
# Fix (workaround?) enabling backpropagation of gradients on CPUs.
@@ -1438,3 +1472,13 @@ def infer(
use_cache=use_cache,
offload_to_cpu=offload_to_cpu,
)

def get_DDP_modules(self, call_chain):
modules = []
for ind in range(1, len(call_chain)):
m_id = call_chain[ind][0].unique_instance_id
module = self.module_reference_table[m_id][1]
if isinstance(module, DDP):
modules.append(module)

return modules
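Two torch-native idioms in this file are easier to see outside the NeMo plumbing: batch-norm layers must be converted to SyncBatchNorm before the DDP wrap (with broadcast_buffers=False so buffers such as running statistics are not re-broadcast on every forward), and gradient accumulation suppresses the all-reduce on intermediate micro-batches via no_sync(). A condensed sketch under those assumptions; the function and variable names are illustrative:

```python
# Condensed sketch of the two idioms used above; assumes torch.distributed
# is initialized with one process per GPU.
from contextlib import ExitStack

import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP

def wrap(module: nn.Module, local_rank: int, synced_batchnorm: bool) -> DDP:
    if synced_batchnorm:
        # per the pytorch docs, convert BN layers *before* wrapping in DDP
        module = nn.SyncBatchNorm.convert_sync_batchnorm(module)
    # broadcast_buffers=False keeps buffers (e.g. BN running stats) local
    return DDP(module, device_ids=[local_rank], broadcast_buffers=False, find_unused_parameters=True)

def accumulate_backward(ddp_modules, loss, is_last_microbatch: bool):
    if not is_last_microbatch:
        # suppress the gradient all-reduce for every DDP module in the call chain
        with ExitStack() as stack:
            for mod in ddp_modules:
                stack.enter_context(mod.no_sync())
            loss.backward()
    else:
        loss.backward()  # this backward triggers the all-reduce
```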