
Commit

update
Signed-off-by: sam1373 <[email protected]>
sam1373 committed Jan 24, 2022
1 parent da1b142 commit 6d38032
Showing 2 changed files with 111 additions and 53 deletions.
162 changes: 110 additions & 52 deletions nemo/core/classes/modelPT.py
@@ -856,6 +856,36 @@ def get_test_dataloader_prefix(self, dataloader_idx: int = 0) -> str:
"""
return self._test_names[dataloader_idx]

def load_part_of_state_dict(self, state_dict, include, exclude, load_from_string):

excluded_param_names = []
# create dict
dict_to_load = {}
for k, v in state_dict.items():
should_add = False
# if any string in include is present, should add
for p in include:
if p in k:
should_add = True
break
# except for if any string from exclude is present
for e in exclude:
if e in k:
excluded_param_names.append(k)
should_add = False
break
if should_add:
dict_to_load[k] = v

# Restore checkpoint part into current model
self.load_state_dict(dict_to_load, strict=False)
logging.info(f'Model checkpoint partially restored from `{load_from_string}``')
if len(excluded_param_names) > 0:
logging.info(
f'The following parameters were excluded from loading from `{load_from_string}` : `{excluded_param_names}` `'
)
logging.info(f'Make sure that this is what you wanted!')
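Editor's note: a rough standalone illustration of the include/exclude matching used by load_part_of_state_dict above. This is a sketch with made-up parameter names, not part of the commit: a parameter is kept when any string in include occurs in its name, unless any string in exclude also occurs.

# Sketch only: mirrors the substring matching above on a toy list of parameter names.
def split_keys(keys, include, exclude):
    kept, excluded = [], []
    for k in keys:
        should_add = any(p in k for p in include)
        for e in exclude:
            if e in k:
                excluded.append(k)
                should_add = False
                break
        if should_add:
            kept.append(k)
    return kept, excluded

kept, excluded = split_keys(
    ['encoder.layers.0.weight', 'encoder.layers.0.bias', 'decoder.weight'],
    include=['encoder'],
    exclude=['bias'],
)
print(kept)      # ['encoder.layers.0.weight']
print(excluded)  # ['encoder.layers.0.bias']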

    @rank_zero_only
    def maybe_init_from_pretrained_checkpoint(self, cfg: OmegaConf, map_location: str = 'cpu'):
        """
@@ -869,18 +899,20 @@ def maybe_init_from_pretrained_checkpoint(self, cfg: OmegaConf, map_location: st
                    path: Str path to .nemo model
-                   parts_to_load: List of strings, at least one of which needs to be contained in parameter name
+                   include: List of strings, at least one of which needs to be contained in parameter name
                        to be loaded from this .nemo file
-                   excluded: Optional list of strings, which can be used to exclude any parameter containing one of
+                   exclude: Optional list of strings, which can be used to exclude any parameter containing one of
                        these strings from being loaded from this .nemo file
            init_from_pretrained_model: Str name of a pretrained model checkpoint (obtained via cloud).
                The model will be downloaded (or a cached copy will be used), instantiated and then
-               its state dict will be extracted.
+               its state dict will be extracted. If loading from multiple files, you can pass in a dict
+                   with the same format as for init_from_nemo_model, except with "name" instead of "path"
            init_from_ptl_ckpt: Str name of a Pytorch Lightning checkpoint file. It will be loaded and
-               the state dict will extracted.
+               the state dict will be extracted. If loading from multiple files, you can pass in a dict
+                   with the same format as for init_from_nemo_model.
        Args:
            cfg: The config used to instantiate the model. It need only contain one of the above keys.
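Editor's note: a hypothetical config for the multi-file form described above, pieced together from the docstring and the loading code in the next hunk. The entry names "encoder_init" and "decoder_init" are arbitrary and the .nemo paths are placeholders.

from omegaconf import OmegaConf

# Sketch only: restore encoder weights from one .nemo file and decoder weights from another,
# skipping any parameter whose name contains "bias". For init_from_pretrained_model the entries
# would use "name" instead of "path", per the docstring above.
cfg = OmegaConf.create(
    {
        "init_from_nemo_model": {
            "encoder_init": {"path": "encoder_model.nemo", "include": ["encoder"], "exclude": ["bias"]},
            "decoder_init": {"path": "decoder_model.nemo", "include": ["decoder"]},
        }
    }
)
# model.maybe_init_from_pretrained_checkpoint(cfg)  # `model` being any ModelPT subclass instance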
@@ -926,68 +958,94 @@ def maybe_init_from_pretrained_checkpoint(self, cfg: OmegaConf, map_location: st
                            model_path, map_location=map_location, strict=cfg.get("init_strict", True)
                        )

-                        parts = model_load_cfg.pop('parts_to_load', [])
-                        excluded = model_load_cfg.pop('excluded', [])
-
-                        # create dict
-                        dict_to_load = {}
-                        for k, v in restored_model.state_dict().items():
-                            should_add = True
-                            for p in parts:
-                                if not (p in k):
-                                    should_add = False
-                                    break
-                            for e in excluded:
-                                if e in k:
-                                    should_add = False
-                                    break
-                            if should_add:
-                                dict_to_load[k] = v
-
-                        # Restore checkpoint part into current model
-                        self.load_state_dict(dict_to_load, strict=False)
-                        logging.info(f'Model checkpoint partially restored from nemo file with path : `{model_path}``')
+                        include = model_load_cfg.pop('include', [])
+                        exclude = model_load_cfg.pop('exclude', [])
+
+                        self.load_part_of_state_dict(
+                            restored_model.state_dict(), include, exclude, f'nemo file with path `{model_path}`'
+                        )

                        del restored_model

        if 'init_from_pretrained_model' in cfg and cfg.init_from_pretrained_model is not None:
            with open_dict(cfg):
                # Restore model
-                model_name = cfg.pop('init_from_pretrained_model')
-
-                # Check if model is being resumed or not - only works if `Trainer` is attached to model
-                if hasattr(self, 'trainer') and self.trainer is not None:
-                    trainer = self.trainer
-                    if (
-                        hasattr(trainer, 'resume_from_checkpoint')
-                        and trainer.checkpoint_connector.resume_checkpoint_path is not None
-                    ):
-                        logging.info(
-                            "Model training is being resumed via Pytorch Lightning.\n"
-                            "Initialization from pretrained model (via cloud) will be skipped."
-                        )
-                        return
-
-                restored_model = self.from_pretrained(
-                    model_name, map_location=map_location, strict=cfg.get("init_strict", True)
-                )
-
-                # Restore checkpoint into current model
-                self.load_state_dict(restored_model.state_dict(), strict=False)
-                logging.info(f'Model checkpoint restored from pretrained chackpoint with name : `{model_name}`')
-
-                del restored_model
+
+                if isinstance(cfg.init_from_pretrained_model, str):
+                    model_name = cfg.pop('init_from_pretrained_model')
+
+                    # Check if model is being resumed or not - only works if `Trainer` is attached to model
+                    if hasattr(self, 'trainer') and self.trainer is not None:
+                        trainer = self.trainer
+                        if (
+                            hasattr(trainer, 'resume_from_checkpoint')
+                            and trainer.checkpoint_connector.resume_checkpoint_path is not None
+                        ):
+                            logging.info(
+                                "Model training is being resumed via Pytorch Lightning.\n"
+                                "Initialization from pretrained model (via cloud) will be skipped."
+                            )
+                            return
+
+                    restored_model = self.from_pretrained(
+                        model_name, map_location=map_location, strict=cfg.get("init_strict", True)
+                    )
+
+                    # Restore checkpoint into current model
+                    self.load_state_dict(restored_model.state_dict(), strict=False)
+                    logging.info(f'Model checkpoint restored from pretrained checkpoint with name : `{model_name}`')
+
+                    del restored_model
+                else:
+                    model_load_dict = cfg.init_from_pretrained_model
+                    for model_load_cfg in model_load_dict.values():
+                        model_name = model_load_cfg.name
+                        # Restore model
+                        restored_model = self.from_pretrained(
+                            model_name, map_location=map_location, strict=cfg.get("init_strict", True)
+                        )
+
+                        include = model_load_cfg.pop('include', [])
+                        exclude = model_load_cfg.pop('exclude', [])
+
+                        self.load_part_of_state_dict(
+                            restored_model.state_dict(),
+                            include,
+                            exclude,
+                            f'pretrained checkpoint with name `{model_name}`',
+                        )
+
+                        del restored_model

        if 'init_from_ptl_ckpt' in cfg and cfg.init_from_ptl_ckpt is not None:
            with open_dict(cfg):
-                # Restore checkpoint
-                ckpt_path = cfg.pop('init_from_ptl_ckpt')
-                ckpt = torch.load(ckpt_path, map_location=map_location)
-
-                # Restore checkpoint into current model
-                self.load_state_dict(ckpt['state_dict'], strict=False)
-                logging.info(f'Model checkpoint restored from pytorch lightning chackpoint with path : `{ckpt_path}`')
-
-                del ckpt
+                if isinstance(cfg.init_from_ptl_ckpt, str):
+                    # Restore checkpoint
+                    ckpt_path = cfg.pop('init_from_ptl_ckpt')
+                    ckpt = torch.load(ckpt_path, map_location=map_location)
+
+                    # Restore checkpoint into current model
+                    self.load_state_dict(ckpt['state_dict'], strict=False)
+                    logging.info(
+                        f'Model checkpoint restored from pytorch lightning checkpoint with path : `{ckpt_path}`'
+                    )
+
+                    del ckpt
+                else:
+                    model_load_dict = cfg.init_from_ptl_ckpt
+                    for model_load_cfg in model_load_dict.values():
+                        ckpt_path = model_load_cfg.path
+                        # Restore model
+                        ckpt = torch.load(ckpt_path, map_location=map_location)
+
+                        include = model_load_cfg.pop('include', [])
+                        exclude = model_load_cfg.pop('exclude', [])
+
+                        self.load_part_of_state_dict(
+                            ckpt['state_dict'], include, exclude, f'pytorch lightning checkpoint with path `{ckpt_path}`'
+                        )
+
+                        del ckpt

    def teardown(self, stage: str):
        """
2 changes: 1 addition & 1 deletion tutorials/asr/Self_Supervised_Pre_Training.ipynb
@@ -359,7 +359,7 @@
{
"cell_type": "markdown",
"source": [
"Note that for this loss the outputs must match the inputs, so since we are using Citrinet architecture with 8x stride, we would need to either set \"cfg.model.loss.combine_time_steps\" to 8, or put additional stride layers in the decoder. By default for Citrinet with 8x stride we use \"cfg.model.loss.combine_time_steps=4\" and \"cfg.model.decoder_stride_layers=1\" to match the 8x stride."
"Note that for this loss the outputs must match the inputs, so since we are using Citrinet architecture with 8x stride, we would need to either set \"cfg.model.loss.combine_time_steps\" to 8, or put additional stride layers in the decoder. By default for Citrinet with 8x stride we use \"cfg.model.loss.combine_time_steps=4\" and \"cfg.model.decoder.stride_layers=1\" to match the 8x stride."
],
"metadata": {
"id": "4JnepitBZ3ta"
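Editor's note on the notebook cell above (not part of the notebook): the rule it describes is that "cfg.model.loss.combine_time_steps" times 2 raised to the number of decoder stride layers should equal the encoder's total stride, assuming each stride layer changes the time resolution by a factor of 2 (an assumption, not stated in the cell). Both configurations mentioned satisfy this for the 8x-stride Citrinet:

# Sketch only: both settings mentioned in the cell recover the 8x encoder stride,
# assuming each decoder stride layer changes time resolution by a factor of 2.
encoder_stride = 8
for combine_time_steps, decoder_stride_layers in [(8, 0), (4, 1)]:
    assert combine_time_steps * 2 ** decoder_stride_layers == encoder_stride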
