Dynamic freezing in Nemo #5879

Merged · 17 commits · Feb 6, 2023
27 changes: 27 additions & 0 deletions docs/source/core/core.rst
@@ -756,3 +756,30 @@ It also means that ``.forward(...)`` and ``__call__(...)`` methods each produce

.. note:: To temporarily disable typechecking, you can enclose your code in a ``with typecheck.disable_checks():`` block.


Dynamic Layer Freezing
----------------------

You can selectively freeze any modules inside a NeMo model by specifying a freezing schedule in the config YAML. Freezing stops all gradient updates
to that module, so its weights are not changed while it is frozen. This can be useful for combating catastrophic forgetting, for example
when fine-tuning a large pretrained model on a small dataset.
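
In plain PyTorch terms, freezing a submodule amounts to turning off gradient updates for its parameters. The following generic sketch (not NeMo's actual ``freeze()`` implementation, just an illustration of the effect) shows that a frozen layer accumulates no gradients while the rest of the model keeps training:

.. code-block:: python

    # Generic PyTorch sketch of what "freezing" means; it does not use NeMo APIs.
    import torch
    import torch.nn as nn

    model = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 2))
    encoder = model[0]  # stand-in for the submodule we want to freeze

    # "Freeze": its parameters stop receiving gradients, so an optimizer step cannot change them.
    for p in encoder.parameters():
        p.requires_grad = False

    loss = model(torch.randn(2, 4)).sum()
    loss.backward()
    print(all(p.grad is None for p in encoder.parameters()))       # True: frozen layer received no grads
    print(all(p.grad is not None for p in model[1].parameters()))  # True: the rest of the model still trains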

The default approach is to freeze a module for the first N training steps, but you can also enable freezing for a specific range of steps
(for example, from step 20 to step 100), from some step N until the end of training, or for the entire training run.
Dynamic freezing is specified in training steps, not epochs.

To enable freezing, add the following to your config:

.. code-block:: yaml

    model:
      ...
      freeze_updates:
        enabled: true  # set to false if you want to disable freezing

        modules:   # list all of the modules you want to have freezing logic for
          encoder: 200       # module will be frozen for the first 200 training steps
          decoder: [50, -1]  # module will be frozen at step 50 and will remain frozen until training ends
          joint: [10, 100]   # module will be frozen between step 10 and step 100 (step >= 10 and step <= 100)
          transcoder: -1     # module will be frozen for the entire training run
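
These schedule values are interpreted the same way inside the training hooks added to ``modelPT.py`` below. As a rough standalone illustration of that decision logic (not part of this PR; ``should_freeze`` and ``schedule`` are hypothetical names used only for this sketch):

.. code-block:: python

    # A minimal sketch of the freeze_updates schedule semantics, mirroring the
    # checks in on_train_batch_start() below; these names are illustrative only.
    from typing import List, Union

    def should_freeze(num_updates: int, m_steps: Union[int, List[int]]) -> bool:
        """Return True if a module with schedule ``m_steps`` is frozen at this 1-based update count."""
        if isinstance(m_steps, list):
            # [start, end]: frozen while start <= step <= end; end == -1 means "until training ends"
            start, end = m_steps
            return num_updates >= start and (num_updates <= end or end == -1)
        # int N: frozen for the first N steps; N == -1 means "for the entire training run"
        return num_updates <= m_steps or m_steps == -1

    schedule = {"encoder": 200, "decoder": [50, -1], "joint": [10, 100], "transcoder": -1}
    print({name: should_freeze(150, steps) for name, steps in schedule.items()})
    # {'encoder': True, 'decoder': True, 'joint': False, 'transcoder': True}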

54 changes: 53 additions & 1 deletion nemo/core/classes/modelPT.py
@@ -1622,12 +1622,32 @@ def _setup_nsys_profiling(self):
                else:
                    raise ValueError(f'Nsys end_step must be greater than or equal to nsys start_step')

    def on_train_start(self):
        """ PyTorch Lightning hook:
        https://pytorch-lightning.readthedocs.io/en/stable/common/lightning_module.html#on-train-start
        We use it here to copy the relevant config for dynamic freezing.
        """

        # dynamic freezing
        # should fire only once, on the very first batch of training and never again
        if not hasattr(self, '_freeze_cfg'):
            if (
                hasattr(self.cfg, 'freeze_updates')
                and self.cfg.freeze_updates is not None
                and self.cfg.freeze_updates.get('enabled', False)
            ):
                setattr(self, '_freeze_cfg', OmegaConf.to_container(self.cfg.freeze_updates))
                self._freeze_cfg['is_frozen'] = {k: False for k in self._freeze_cfg['modules'].keys()}
            else:
                setattr(self, '_freeze_cfg', None)

    def on_train_batch_start(self, batch: Any, batch_idx: int, unused: int = 0) -> Optional[int]:
        """ PyTorch Lightning hook:
        https://pytorch-lightning.readthedocs.io/en/stable/common/lightning_module.html#on-train-batch-start
        We use it here to enable nsys profiling.
        We use it here to enable nsys profiling and dynamic freezing.
        """

        # nsys profiling
        if self.device.type == 'cuda':
            if hasattr(self, '_nsys_profile_enabled'):
                if self._nsys_profile_enabled:
@@ -1637,6 +1657,28 @@ def on_train_batch_start(self, batch: Any, batch_idx: int, unused: int = 0) -> O
                        if self._nsys_profile_gen_shape:
                            torch.autograd.profiler.emit_nvtx(record_shapes=True).__enter__()

        # dynamic freezing
        if hasattr(self, '_freeze_cfg') and self._freeze_cfg is not None:
            if self.training and hasattr(self, "trainer") and self.trainer is not None:
                num_updates = self.trainer.global_step + 1

                for ml, m_steps in self._freeze_cfg['modules'].items():
                    # we could do hasattr check here, but it's too expensive for each step
                    # consequently you'll throw an error if the module name doesn't exist
                    # or was spelled wrong in the config.yaml
                    if isinstance(m_steps, list):
                        assert len(m_steps) == 2, "freeze_updates modules list cannot have more than two elements"
                        should_freeze = (num_updates >= m_steps[0]) and (num_updates <= m_steps[1] or m_steps[1] == -1)
                    else:
                        should_freeze = num_updates <= m_steps or m_steps == -1
                    if should_freeze and not self._freeze_cfg['is_frozen'][ml]:
                        getattr(self, ml).freeze()
                        getattr(self, ml).train()
                        self._freeze_cfg['is_frozen'][ml] = True
                    elif not should_freeze and self._freeze_cfg['is_frozen'][ml]:
                        getattr(self, ml).unfreeze()
                        self._freeze_cfg['is_frozen'][ml] = False

    def on_train_batch_end(self, outputs, batch: Any, batch_idx: int, unused: int = 0) -> None:
        """ PyTorch Lightning hook:
        https://pytorch-lightning.readthedocs.io/en/stable/common/lightning_module.html#on-train-batch-end
@@ -1650,6 +1692,16 @@ def on_train_batch_end(self, outputs, batch: Any, batch_idx: int, unused: int =
                        logging.info("====== End nsys profiling ======")
                        torch.cuda.cudart().cudaProfilerStop()

    def on_train_end(self):
        """ PyTorch Lightning hook:
        https://pytorch-lightning.readthedocs.io/en/stable/common/lightning_module.html#on-train-end
        We use it here to cleanup the dynamic freezing config.
        """

        # dynamic freezing cleanup
        if hasattr(self, '_freeze_cfg'):
            delattr(self, '_freeze_cfg')

    # TODO: Remove in PTL 1.7.2
    def cuda(self, device=None):
        """ PTL is overriding this method and changing the pytorch behavior of a module.
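
For context on how the hooks consume this configuration, here is a minimal standalone sketch (not taken from this PR) of how ``on_train_start`` above turns a ``freeze_updates`` block into the plain-dict bookkeeping read by ``on_train_batch_start``; the config values are illustrative assumptions:

.. code-block:: python

    # A minimal sketch, assuming only OmegaConf is installed; the config values
    # below are illustrative, not taken from any NeMo example config.
    from omegaconf import OmegaConf

    model_cfg = OmegaConf.create(
        {
            "freeze_updates": {
                "enabled": True,
                "modules": {
                    "encoder": 200,       # frozen for the first 200 training steps
                    "decoder": [50, -1],  # frozen from step 50 until training ends
                },
            }
        }
    )

    # on_train_start() (see the diff above) copies the block into a plain dict
    # and adds per-module frozen-state bookkeeping.
    freeze_cfg = OmegaConf.to_container(model_cfg.freeze_updates)
    freeze_cfg["is_frozen"] = {k: False for k in freeze_cfg["modules"]}
    print(freeze_cfg)
    # {'enabled': True, 'modules': {'encoder': 200, 'decoder': [50, -1]},
    #  'is_frozen': {'encoder': False, 'decoder': False}}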