
Commit

adding checkpoint callback
Signed-off-by: Jason <[email protected]>
blisc committed May 16, 2020
1 parent 4f6e1f7 commit 3c7b89e
Showing 3 changed files with 72 additions and 37 deletions.
7 changes: 4 additions & 3 deletions examples/asr/jasper_an4_debug.py
@@ -93,7 +93,7 @@ def create_dags(model_config_file, vocab, args, nf):
predictions.rename("test")
train_callback = nemo.core.SimpleLossLogger(tensors_to_log=["loss", "test"])

# checkpointer_callback = nemo.core.CheckpointCallback(folder=nf.checkpoint_dir, step_freq=args.checkpoint_save_freq)
checkpointer_callback = nemo.core.CheckpointCallback(folder=nf.checkpoint_dir, step_freq=args.checkpoint_save_freq)

# eval_tensors = [loss_e, predictions_e, transcript_e, transcript_len_e]
# eval_callback = nemo.core.EvaluatorCallback(
@@ -105,11 +105,12 @@ def create_dags(model_config_file, vocab, args, nf):
# eval_at_start=not args.do_not_eval_at_start,
# )
# callbacks = [train_callback, checkpointer_callback, eval_callback]
callbacks = [train_callback]
callbacks = [train_callback, checkpointer_callback]

@nemo.core.callbacks.on_step_start
def my_own_func(state):
print(state)
if state["step"] % 100 == 0:
print(state)

callbacks.append(my_own_func)

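A minimal sketch of the registration pattern the example script now uses: CheckpointCallback is built from the neural factory's checkpoint directory, and a plain function decorated with @nemo.core.callbacks.on_step_start receives the same state dict as class-based callbacks. Here nf, args, and train_callback come from the script above; the log_every_100_steps name and the print format are illustrative, not part of the commit.

import nemo

# Sketch only: mirrors examples/asr/jasper_an4_debug.py after this change.
checkpointer_callback = nemo.core.CheckpointCallback(
    folder=nf.checkpoint_dir, step_freq=args.checkpoint_save_freq
)

@nemo.core.callbacks.on_step_start
def log_every_100_steps(state):
    # "step" and "epoch" are keys of the state dict built by get_state() in actions.py below.
    if state["step"] % 100 == 0:
        print(f"epoch {state['epoch']}, step {state['step']}")

callbacks = [train_callback, checkpointer_callback, log_every_100_steps]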
72 changes: 52 additions & 20 deletions nemo/backends/pytorch/actions.py
@@ -21,7 +21,7 @@
from nemo.backends.pytorch.optimizers import AdamW, Novograd, master_params
from nemo.core import DeploymentFormat, DeviceType, NeuralModule, NmTensor
from nemo.core.callbacks import ActionCallback, EvaluatorCallback, NeMoCallback, SimpleLossLoggerCallback
from nemo.core.neural_factory import Actions, OperationMode, Optimization, TrainingState, topological_sort_from_leaves
from nemo.core.neural_factory import Actions, OperationMode, Optimization, topological_sort_from_leaves
from nemo.core.neural_types import *
from nemo.utils.app_state import AppState
from nemo.utils.decorators import deprecated
@@ -1145,8 +1145,40 @@ def _update_callbacks(callbacks=None, registered_tensors=None, final_loss=None):
else: # For now, we can use the old callback function. In the future we should improve this
registered_tensors["loss"] = final_loss

def get_state(self):
return {"step": self.step, "tensors": self._training_state, "epoch_num":self.epoch_num, "optimizer": self.optimizers}
def get_state(action):
class StateWrapper(dict):
def restore_state_from(self, path):
if os.path.isfile(path):
# map_location could be cuda:<device_id> but cpu seems to be more
# general since we are also saving step and epoch_num
# load_state_dict should move the variables to the relevant device
checkpoint = torch.load(path, map_location="cpu")
self.step = checkpoint["step"]
self.epoch_num = checkpoint["epoch_num"]
if checkpoint["optimizer_state"]:
for opt, opt_chkpt in zip(self["optimizers"], checkpoint["optimizer_state"]):
opt.load_state_dict(opt_chkpt)
else:
raise FileNotFoundError("Could not find checkpoint file: {0}".format(path))

def save_state_to(self, path):
state = {
"step": self["step"],
"epoch_num": self["epoch"],
"optimizer_state": [opt.state_dict() for opt in self["optimizers"]],
}
torch.save(state, path)

return StateWrapper(
{
"step": action.step,
"tensors": action._training_state,
"epoch": action.epoch_num,
"local_rank": action.local_rank,
"global_rank": action.global_rank,
"optimizers": action.optimizers,
}
)

self._training_state = TrainingState(self)
# Analyse the arguments passed to train.
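A rough usage sketch of the StateWrapper returned by get_state above, assuming a training action whose optimizers already exist; the variable name action and the checkpoint path are illustrative. save_state_to serializes the step, epoch, and optimizer state dicts with torch.save, and restore_state_from reloads them (raising FileNotFoundError if the file is missing).

# Illustrative only: round-trip the trainer state through the new StateWrapper.
state = get_state(action)  # dict with "step", "epoch", "tensors", "local_rank", "global_rank", "optimizers"
state.save_state_to("checkpoints/trainer-STEP-100.pt")       # torch.save of step, epoch, optimizer state
state.restore_state_from("checkpoints/trainer-STEP-100.pt")  # torch.load + optimizer load_state_dict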
@@ -1181,9 +1213,9 @@ def get_state(self):

if tensors_to_optimize is None:
# This is Evaluation Mode
self._init_callbacks(callbacks)
_init_callbacks(callbacks, self)
# Do action start callbacks
self._perform_on_action_end(callbacks=callbacks)
_perform_on_action_end(callbacks, get_state(self))
return
# Check if tensors_to_optimize is just a list of NmTensors
elif tensors_to_optimize is not None and (
@@ -1385,9 +1417,9 @@ def get_state(self):
train_dataloader = dataNM.data_iterator
train_sampler = None

self._init_callbacks(callbacks)
_init_callbacks(callbacks, self)
# Do action start callbacks
self._perform_on_action_start(callbacks=callbacks)
_perform_on_action_start(callbacks, get_state(self))

nan_or_inf = False

@@ -1400,7 +1432,7 @@ def get_state(self):
break

# Register epochs start with callbacks
self._perform_on_epoch_start(callbacks=callbacks)
_perform_on_epoch_start(callbacks, get_state(self))

# iteration over batches in epoch
batch_counter = 0
@@ -1413,7 +1445,7 @@ def get_state(self):
curr_optimizer = training_loop[self.step % len(training_loop)][0]
curr_optimizer.zero_grad()
# Register iteration start with callbacks
self._perform_on_step_start(callbacks=callbacks)
_perform_on_step_start(callbacks, get_state(self))

# set learning rate policy
if lr_policy is not None:
@@ -1445,18 +1477,18 @@ def get_state(self):

for t, d in zip(curr_call_chain[0][2].values(), tensors):
if t is not None:
self.training_state.set_tensor(t, d)
self._training_state.set_tensor(t, d)
disable_allreduce = batch_counter < (batches_per_step - 1)
self.__nm_graph_forward_pass(
call_chain=curr_call_chain, registered_tensors=self.training_state.tensor_dict,
call_chain=curr_call_chain, registered_tensors=self._training_state.tensor_dict,
)

curr_tensors_to_optimize = training_loop[self.step % len(training_loop)][1]
final_loss = 0
for tensor in curr_tensors_to_optimize:
if (
torch.isnan(self.training_state.tensor_dict[tensor.unique_name]).any()
or torch.isinf(self.training_state.tensor_dict[tensor.unique_name]).any()
torch.isnan(self._training_state.tensor_dict[tensor.unique_name]).any()
or torch.isinf(self._training_state.tensor_dict[tensor.unique_name]).any()
):
if (
(stop_on_nan_loss)
Expand All @@ -1472,7 +1504,7 @@ def get_state(self):
)
else:
logging.warning('Loss is NaN or inf, continuing training')
final_loss += self.training_state.tensor_dict[tensor.unique_name]
final_loss += self._training_state.tensor_dict[tensor.unique_name]

if self._optim_level in AmpOptimizations and self._optim_level != Optimization.mxprO0:
with amp.scale_loss(final_loss, curr_optimizer, delay_unscale=disable_allreduce) as scaled_loss:
@@ -1514,21 +1546,21 @@ def get_state(self):
curr_optimizer.step()
batch_counter = 0
# Register iteration end with callbacks
self._update_callbacks(
callbacks=callbacks, registered_tensors=self.training_state.tensor_dict, final_loss=final_loss
_update_callbacks(
callbacks, registered_tensors=self._training_state.tensor_dict, final_loss=final_loss
)
self._perform_on_step_end(callbacks=callbacks)
_perform_on_step_end(callbacks, get_state(self))
self.step += 1
self.training_state.clear_dict()
self._training_state.clear_dict()
# End of epoch for loop
# Register epochs end with callbacks
self._perform_on_epoch_end(callbacks=callbacks)
_perform_on_epoch_end(callbacks, get_state(self))
self.epoch_num += 1

# Check again if we should stop on NaN/inf
self._check_nan_or_inf(placement_gpu, nan_or_inf)

self._perform_on_action_end(callbacks=callbacks)
_perform_on_action_end(callbacks, get_state(self))

def infer(
self,
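A small sketch of the per-step tensor bookkeeping that the renamed self._training_state handles in the loop above. Only set_tensor, tensor_dict, and clear_dict appear in this diff; the action and loss_tensor names below are illustrative.

# Illustrative only: how forward-pass outputs are cached for callbacks each step.
action._training_state.set_tensor(loss_tensor, loss_value)            # store a produced tensor
cached = action._training_state.tensor_dict[loss_tensor.unique_name]  # callbacks see this via state["tensors"]
action._training_state.clear_dict()                                   # cleared again at the end of the step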
30 changes: 16 additions & 14 deletions nemo/core/callbacks.py
@@ -84,6 +84,7 @@ def on_step_end(self, state):
# raise KeyError(f"{self} was passed {tensor_key} but the tensor was not found in the state_dict. "
# f"Current state tensors include {state['tensors'].tensor_list()}")


class WandBLogger(NeMoCallback):
def __init__(self, step_freq=100, tensors_to_log=["loss"]):
# Step_freq: how often logs are printed
@@ -104,6 +105,7 @@ def on_step_end(self, state):
# raise KeyError(f"{self} was passed {tensor_key} but the tensor was not found in the state_dict. "
# f"Current state tensors include {state['tensors'].tensor_list()}")


class SimpleLossLogger(NeMoCallback):
def __init__(self, step_freq=100, tensors_to_log=["loss"]):
# Step_freq: how often logs are printed
@@ -442,7 +444,7 @@ def __init__(
self._force_load = force_load

def __save_to(self, path, state):
if state.global_rank is not None and state.global_rank != 0:
if state["global_rank"] is not None and state["global_rank"] != 0:
return
if not os.path.isdir(path):
logging.info(f"Creating {path} folder")
@@ -457,19 +459,19 @@ def __save_to(self, path, state):
)
unique_mod_names.add(str(module))
if self._step_freq > -1:
filename = f"{module}-STEP-{state.step}.pt"
filename = f"{module}-STEP-{state['step']}.pt"
else:
filename = f"{module}-EPOCH-{state.epoch_num}.pt"
filename = f"{module}-EPOCH-{state['epoch']}.pt"
module.save_to(os.path.join(path, filename))

if self._step_freq > -1:
filename = f"trainer-STEP-{state.step}.pt"
state.save_state_to(f'{path}/{filename}')
self._saved_ckpts.append(f'-{state.step}.pt')
filename = f"trainer-STEP-{state['step']}.pt"
state.save_state_to(f"{path}/{filename}")
self._saved_ckpts.append(f"-{state['step']}.pt")
else:
filename = f"trainer-EPOCH-{state.epoch_num}.pt"
state.save_state_to(f'{path}/{filename}')
self._saved_ckpts.append(f'-{state.epoch_num}.pt')
filename = f"trainer-EPOCH-{state['epoch']}.pt"
state.save_state_to(f"{path}/{filename}")
self._saved_ckpts.append(f"-{state['epoch']}.pt")

if len(self._saved_ckpts) > self._ckpt2keep:
for end in self._saved_ckpts[: -self._ckpt2keep]:
@@ -495,7 +497,7 @@ def __restore_from(self, path, state):
module_checkpoints = get_checkpoint_from_dir(modules_to_restore_name, path)

for mod, checkpoint in zip(modules_to_restore, module_checkpoints):
mod.restore_from(checkpoint, state.local_rank)
mod.restore_from(checkpoint, state["local_rank"])
except (BaseException, ValueError) as e:
if self._force_load:
raise ValueError(
@@ -536,21 +538,21 @@ def on_train_start(self, state):
for name in unique_mod_names:
logging.info(f"{name}")
logging.info(f"Total model parameters: {num_parameters}")
self.__restore_from(path=self._load_from_folder)
self.__restore_from(self._load_from_folder, state)

def on_step_end(self, state):
step = state["step"]
if self._step_freq > 0 and step % self._step_freq == 0 and step > 0:
self.__save_to(path=self._folder)
self.__save_to(self._folder, state)

def on_train_end(self, state):
if self._step_freq > 0 or self._epoch_freq > 0:
self.__save_to(path=self._folder)
self.__save_to(self._folder, state)

def on_epoch_end(self, state):
epoch = state["epoch"]
if self._epoch_freq > 0 and epoch % self._epoch_freq == 0 and epoch > 0:
self.__save_to(path=self._folder)
self.__save_to(self._folder, state)


class EvaluatorCallback(ActionCallback):
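A minimal sketch of a custom NeMoCallback written against the dict-style state these callback changes standardize on. The hook name on_epoch_end and the "step", "epoch", and "global_rank" keys come from the diff; the class name and print format are assumptions.

from nemo.core.callbacks import NeMoCallback

class PrintEpochCallback(NeMoCallback):
    # Illustrative subclass: reads the same dict keys that CheckpointCallback
    # now uses instead of attribute access on the state object.
    def on_epoch_end(self, state):
        if state["global_rank"] is None or state["global_rank"] == 0:
            print(f"finished epoch {state['epoch']} at step {state['step']}")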
