
Commit

clean code for comments
Signed-off-by: Jason <[email protected]>
blisc committed May 28, 2020
1 parent 31fc556 commit b9e4441
Showing 4 changed files with 4 additions and 69 deletions.
1 change: 0 additions & 1 deletion examples/asr/jasper_an4.py
@@ -238,7 +238,6 @@ def main():
# Delete old graph and make a new one
del g0
nf.reset_trainer()
# [print(p) for p in nemo.utils.app_state.AppState().modules]
loss, eval_tensors, callbacks, total_steps, _, _, new_g = create_dags(args.model_config, vocab, args, nf)

nf.train(
66 changes: 4 additions & 62 deletions nemo/backends/pytorch/actions.py
@@ -30,8 +30,6 @@

# these imports will happen on as-needed basis
amp = None
# convert_syncbn = None
# create_syncbn_process_group = None
LARC = None
FusedLAMB = None
FusedAdam = None
@@ -63,16 +61,12 @@ def __init__(
global amp
amp = importlib.import_module('apex.amp')
if local_rank is not None:
# global convert_syncbn
# global create_syncbn_process_group
global LARC
global FusedLAMB
global FusedAdam
global FusedNovoGrad
parallel = importlib.import_module('apex.parallel')
apex_optimizer = importlib.import_module('apex.optimizers')
# convert_syncbn = parallel.convert_syncbn_model
# create_syncbn_process_group = parallel.create_syncbn_process_group
LARC = parallel.LARC
FusedLAMB = apex_optimizer.FusedLAMB
FusedAdam = apex_optimizer.FusedAdam
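
Note: the hunk above keeps apex as an optional dependency by declaring module-level placeholders and importing on demand. A minimal sketch of that lazy-import pattern, with illustrative names and assuming apex is installed:

import importlib

amp = None
FusedAdam = None

def _load_apex(need_fused_optimizers=False):
    # Import apex only when mixed precision / fused optimizers are actually requested
    global amp, FusedAdam
    if amp is None:
        amp = importlib.import_module('apex.amp')
    if need_fused_optimizers and FusedAdam is None:
        apex_optimizers = importlib.import_module('apex.optimizers')
        FusedAdam = apex_optimizers.FusedAdam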
@@ -150,12 +144,6 @@ def __get_top_sorted_modules_and_dataloader(self, hook: List[NmTensor]):
"distributed mode. Please instantiate NeuralModuleFactory first and pass its instance as "
"`factory` parameter to all your Neural Module objects.".format(str(m[0]))
)
# key = m[0].unique_instance_id
# if key not in self.module_reference_table:
# if isinstance(m[0], TrainableNeuralModuleWrapper):
# self.module_reference_table[key] = (m[0], m[0]._pt_module)
# else:
# self.module_reference_table[key] = (m[0], m[0])

return top_sorted_modules, tdataset

@@ -349,18 +337,9 @@ def __nm_graph_forward_pass(
if in_cache:
continue
call_args = call_chain[ind][1]
# module = call_chain[ind][0]
# pmodule = self.module_reference_table[m_id][1]
m_id = call_chain[ind][0].unique_instance_id
pmodule = self.ddp_module_dict[m_id] if self.ddp_initialized else call_chain[ind][0]

# if self._local_rank is not None:
# if isinstance(pmodule, DDP):
# if disable_allreduce:
# pmodule.disable_allreduce()
# else:
# pmodule.enable_allreduce()

if mode == OperationMode.training:
# if module.is_trainable():
if isinstance(pmodule, nn.Module):
@@ -374,14 +353,8 @@
# prepare call signature for `module`
call_set = {}
for tensor_name, nmtensor in call_args.items():
# _add_uuid_2_name(nmtensor.name, nmtensor.producer._uuid)
key = nmtensor.unique_name
call_set[tensor_name] = registered_tensors[key]
# actual PyTorch module call with signature
# if isinstance(self.module_reference_table[m_id][0], TrainableNeuralModuleWrapper,):
# new_tensors = pmodule(**call_set)
# else:
# new_tensors = pmodule(force_pt=True, **call_set)
new_tensors = pmodule(force_pt=True, **call_set)

if not isinstance(new_tensors, List):
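
Note: the forward pass above looks up the (possibly DDP-wrapped) module by its unique instance id and calls it with keyword arguments assembled from previously produced tensors. A minimal sketch of that dispatch step, using an illustrative helper name (run_module is not NeMo API):

def run_module(pmodule, call_args, registered_tensors):
    # Map each input port name to the tensor produced earlier in the graph
    call_set = {name: registered_tensors[t.unique_name] for name, t in call_args.items()}
    new_tensors = pmodule(force_pt=True, **call_set)
    return new_tensors if isinstance(new_tensors, (list, tuple)) else [new_tensors]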
@@ -462,11 +435,6 @@ def _eval(self, tensors_2_evaluate, callback, step, verbose=False):
assert dist.is_initialized()
is_distributed = True
world_size = torch.distributed.get_world_size()
# logging.info(
# "Doing distributed evaluation. Rank {0} of {1}".format(
# self.local_rank, world_size
# )
# )

if dl_nm.dataset is not None:
sampler = None
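
Note: for distributed evaluation the code above only shards the dataset when a process group is initialized. A hedged sketch of that setup using plain torch APIs (make_eval_loader, dataset and batch_size are illustrative names):

import torch
import torch.distributed as dist

def make_eval_loader(dataset, batch_size):
    sampler = None
    if dist.is_available() and dist.is_initialized():
        # IterableDataset instances cannot be wrapped in a DistributedSampler
        if not isinstance(dataset, torch.utils.data.IterableDataset):
            sampler = torch.utils.data.distributed.DistributedSampler(dataset)
    return torch.utils.data.DataLoader(dataset, batch_size=batch_size, sampler=sampler)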
@@ -638,11 +606,6 @@ def _infer(
assert dist.is_initialized()
is_distributed = True
world_size = torch.distributed.get_world_size()
# logging.info(
# "Doing distributed evaluation. Rank {0} of {1}".format(
# self.local_rank, world_size
# )
# )
if dl_nm.dataset is not None:
sampler = None
if not isinstance(dl_nm.dataset, torch.utils.data.IterableDataset):
@@ -729,12 +692,6 @@ def _infer(
use_cache=use_cache,
)

# if offload_to_cpu:
# # Take all cuda tensors and save them to value_dict as
# # cpu tensors to save GPU memory
# for name, tensor in registered_e_tensors.items():
# if isinstance(tensor, torch.Tensor):
# registered_e_tensors[name] = tensor.cpu()
if cache:
self.append_to_cache(registered_e_tensors, offload_to_cpu)

@@ -913,10 +870,10 @@ def __extract_dynamic_axes(port_name: str, ntype: NeuralType, dynamic_axes: defa

module.eval()
try:
# # Remove NeMo-related things from the module
# # We need to change __call__ method. Note that this will change the
# # whole class, not just this object! Which is why we need to repair it
# # in the finally block
# Remove NeMo-related things from the module
# We need to change __call__ method. Note that this will change the
# whole class, not just this object! Which is why we need to repair it
# in the finally block
__orig_call__ = type(module).__call__
type(module).__call__ = torch.nn.Module.__call__

@@ -1313,10 +1270,6 @@ def save_state_to(self, path):
dataNM = training_loop[0][2][0][0]
placement_gpu = dataNM.placement == DeviceType.AllGpu
if placement_gpu:
# if len(training_loop) > 1:
# raise NotImplementedError(
# "Distributed training does nor work with multiple "
# "optimizers")
logging.info("Doing distributed training")
if t_dataset is not None:
train_sampler = None
@@ -1341,12 +1294,6 @@
else:
train_sampler = None

# for train_iter in training_loop:
# call_chain = train_iter[2]
# for i in range(1, len(call_chain) - 1):
# key = call_chain[i][0].unique_instance_id
# pmodule = self.module_reference_table[key][1]
# num_trainable_weights = self.module_reference_table[key][1].num_weights
self.ddp_initialized = True
module_list = [mod.name for mod in AppState().modules]
module_list = sorted(module_list)
@@ -1356,11 +1303,6 @@
num_trainable_weights = module.num_weights
self.ddp_module_dict[key] = module
if not isinstance(module, DDP) and isinstance(module, torch.nn.Module) and num_trainable_weights > 0:
# gpf = 1
# if gradient_predivide:
# gpf = dist.get_world_size()
# pmodule = DDP(pmodule, gradient_predivide_factor=gpf) # Old Apex Method

# Per pytorch docs, convert sync bn prior to DDP
if synced_batchnorm:
world_size = dist.get_world_size()
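
Note: per the PyTorch docs referenced above, BatchNorm layers should be converted to SyncBatchNorm before the module is wrapped in DDP. A sketch of that standard torch pattern (module and local_rank are assumed; this is not necessarily the exact call NeMo makes):

import torch
from torch.nn.parallel import DistributedDataParallel as DDP

module = torch.nn.SyncBatchNorm.convert_sync_batchnorm(module)
module = DDP(module, device_ids=[local_rank])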
2 changes: 0 additions & 2 deletions nemo/core/callbacks.py
@@ -469,7 +469,6 @@ def __restore_from(self, path, state):
try:
trainer_checkpoints = get_checkpoint_from_dir(["trainer"], path)
state.restore_state_from(trainer_checkpoints[0])
# for tr, checkpoint in zip([self.action], trainer_checkpoints):
except (ValueError) as e:
logging.warning(e)
logging.warning(
@@ -891,7 +890,6 @@ def on_iteration_start(self):
setattr(self.module, self.arg_name, value)
if self.tb_writer is not None:
class_name = self.module.__class__.__name__
# name = f'param/{class_name}.{self.arg_name}'
name = f"param/{class_name}.{self.arg_name}"
self.tb_writer.add_scalar(name, value, self.step)
else:
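
Note: the callback above writes a per-iteration module argument both onto the module and to TensorBoard. A hedged sketch of that pattern (self.policy is an assumed schedule function, not part of the shown code):

def on_iteration_start(self):
    value = self.policy(self.step)  # assumed: maps the global step to the new argument value
    setattr(self.module, self.arg_name, value)
    if self.tb_writer is not None:
        class_name = self.module.__class__.__name__
        self.tb_writer.add_scalar(f"param/{class_name}.{self.arg_name}", value, self.step)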
4 changes: 0 additions & 4 deletions nemo/utils/nemo_logging.py
@@ -366,7 +366,3 @@ def critical(self, msg, *args, mode=LogMode.EACH, **kwargs):
and not self._logged_once(msg, mode)
):
self._logger._log(Logger.CRITICAL, msg, args, **kwargs)


# # Necessary to catch the correct caller
# _logging._srcfile = os.path.normcase(inspect.getfile(Logger.__class__))
