Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Canary Adapters tutorial #9670

Merged
merged 6 commits into from
Jul 16, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions nemo/collections/asr/models/aed_multitask_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -434,7 +434,7 @@ def change_prompt(
prompt_cls = PromptFormatter.resolve(self.prompt_format)
self.prompt = prompt_cls(
tokenizer=self.tokenizer,
defaults=OmegaConf.to_container(pd) if (pd := self.cfg.prompt_defaults) is not None else None,
defaults=OmegaConf.to_container(pd) if (pd := self.cfg.get('prompt_defaults')) is not None else None,
)

# Update config
Expand Down Expand Up @@ -979,7 +979,7 @@ def _transcribe_on_end(self, trcfg: MultiTaskTranscriptionConfig):
"""
super()._transcribe_on_end(trcfg)

self.transf_decoder.unfreeze()
self.transf_decoder.unfreeze(partial=True)

def _may_be_make_dict_and_fix_paths(self, json_items, manifest_path, trcfg: MultiTaskTranscriptionConfig):
"""
Expand Down
14 changes: 0 additions & 14 deletions nemo/collections/asr/models/ctc_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -665,20 +665,6 @@ def test_dataloader(self):

""" Transcription related methods """

def _transcribe_on_begin(self, audio, trcfg: TranscribeConfig):
super()._transcribe_on_begin(audio, trcfg)

# Freeze the encoder and decoder modules
self.encoder.freeze()
self.decoder.freeze()

def _transcribe_on_end(self, trcfg: TranscribeConfig):
super()._transcribe_on_end(trcfg)

# Unfreeze the encoder and decoder modules
self.encoder.unfreeze()
self.decoder.unfreeze()

def _transcribe_forward(self, batch: Any, trcfg: TranscribeConfig):
logits, logits_len, greedy_predictions = self.forward(input_signal=batch[0], input_signal_length=batch[1])
output = dict(logits=logits, logits_len=logits_len)
Expand Down
2 changes: 1 addition & 1 deletion nemo/collections/asr/models/hybrid_rnnt_ctc_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ def _transcribe_on_end(self, trcfg: TranscribeConfig):
super()._transcribe_on_end(trcfg)

if hasattr(self, 'ctc_decoder'):
self.ctc_decoder.unfreeze()
self.ctc_decoder.unfreeze(partial=True)

def _transcribe_forward(self, batch: Any, trcfg: TranscribeConfig):
if self.cur_decoder == "rnnt":
Expand Down
2 changes: 1 addition & 1 deletion nemo/collections/asr/models/transformer_bpe_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -633,4 +633,4 @@ def _transcribe_on_end(self, trcfg: TranscribeConfig):
super()._transcribe_on_end(trcfg)

# Unfreeze the encoder and decoder modules
self.transf_decoder.unfreeze()
self.transf_decoder.unfreeze(partial=True)
6 changes: 3 additions & 3 deletions nemo/collections/asr/parts/mixins/transcription.py
Original file line number Diff line number Diff line change
Expand Up @@ -770,13 +770,13 @@ def _transcribe_on_end(self, trcfg: TranscribeConfig):

# Unfreeze the encoder and decoder modules
if hasattr(self, 'encoder'):
self.encoder.unfreeze()
self.encoder.unfreeze(partial=True)

if hasattr(self, 'decoder'):
self.decoder.unfreeze()
self.decoder.unfreeze(partial=True)

if hasattr(self, 'joint'):
self.joint.unfreeze()
self.joint.unfreeze(partial=True)

@classmethod
def get_transcribe_config(cls) -> TranscribeConfig:
Expand Down
59 changes: 47 additions & 12 deletions nemo/core/classes/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from torch.nn import Module

from nemo.core.classes.common import FileIO, Serialization, Typing
from nemo.utils import logging

__all__ = ['NeuralModule']

Expand Down Expand Up @@ -54,18 +55,59 @@ def input_example(self, max_batch=None, max_dim=None):
def freeze(self) -> None:
    r"""
    Freeze all params for inference.

    Sets ``requires_grad = False`` on every parameter and switches the module
    to eval mode. The pre-freeze ``requires_grad`` state of each parameter is
    recorded in ``self._frozen_grad_map`` so that ``unfreeze(partial=True)``
    can restore the original state exactly.
    """
    grad_map = {}

    for pname, param in self.named_parameters():
        # Store the original grad state, then freeze the parameter.
        grad_map[pname] = param.requires_grad
        param.requires_grad = False

    # Store the frozen grad map.
    if not hasattr(self, '_frozen_grad_map'):
        self._frozen_grad_map = grad_map
    else:
        # A previous freeze() already recorded the true pre-freeze states.
        # Only record parameters not seen before: a plain update() here would
        # overwrite the originals with the already-frozen (False) values on a
        # repeated freeze(), breaking a later unfreeze(partial=True).
        for pname, state in grad_map.items():
            self._frozen_grad_map.setdefault(pname, state)

    self.eval()

def unfreeze(self) -> None:
def unfreeze(self, partial: bool = False) -> None:
    """
    Unfreeze all parameters for training.

    Args:
        partial: If True, only unfreeze parameters that were previously frozen. If the parameter was already frozen
            when calling `freeze()`, it will remain frozen after calling `unfreeze(partial=True)`.

            Use `partial=True` when the module may contain parameters that were
            deliberately frozen *before* `freeze()` was called — e.g. base weights of a
            module that carries trainable adapters. A plain `unfreeze()` would blindly
            enable gradients on those weights; `partial=True` restores each parameter
            to its recorded pre-`freeze()` state instead, keeping the freeze/unfreeze
            API independent of the adapter API.

    Raises:
        ValueError: if `partial=True` but the module was never frozen with `freeze()`.
    """
    if partial and not hasattr(self, '_frozen_grad_map'):
        raise ValueError("Cannot unfreeze partially without first freezing the module with `freeze()`")

    for pname, param in self.named_parameters():
        if not partial:
            # Unfreeze all parameters unconditionally.
            param.requires_grad = True
        elif pname in self._frozen_grad_map:
            # Restore the recorded pre-freeze state.
            param.requires_grad = self._frozen_grad_map[pname]
        else:
            # Parameter appeared after freeze() — no recorded state to restore.
            logging.warning(
                f"Parameter {pname} not found in list of previously frozen parameters. "
                f"Unfreezing this parameter."
            )
            param.requires_grad = True

    # The saved state is consumed; clean up the frozen grad map.
    if hasattr(self, '_frozen_grad_map'):
        delattr(self, '_frozen_grad_map')

    self.train()

self.train()

Expand All @@ -75,18 +117,11 @@ def as_frozen(self):
Context manager which temporarily freezes a module, yields control and finally unfreezes the module.
"""
training_mode = self.training
grad_map = {}
for pname, param in self.named_parameters():
grad_map[pname] = param.requires_grad

self.freeze()
try:
yield
finally:
self.unfreeze()

for pname, param in self.named_parameters():
param.requires_grad = grad_map[pname]
self.unfreeze(partial=True)

if training_mode:
self.train()
Expand Down
89 changes: 89 additions & 0 deletions tests/core/test_neural_module.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os

Check notice

Code scanning / CodeQL

Unused import Note test

Import of 'os' is not used.
import tempfile

Check notice

Code scanning / CodeQL

Unused import Note test

Import of 'tempfile' is not used.

import pytest
import torch

from nemo.core.classes.module import NeuralModule


class TempModule(NeuralModule):
    """Tiny two-layer NeuralModule used to exercise freeze/unfreeze behavior."""

    def __init__(self):
        super().__init__()
        # Two bias-free 10x10 linears -> exactly two weight tensors, 200 params.
        for attr in ('layer1', 'layer2'):
            setattr(self, attr, torch.nn.Linear(10, 10, bias=False))


class TestNeuralModule:
    """Unit tests for NeuralModule freeze()/unfreeze()/as_frozen() behavior."""

    @pytest.mark.unit
    def test_num_weights(self):
        module = TempModule()
        # Two bias-free 10x10 linear layers -> 200 trainable weights.
        assert module.num_weights == 200

    @pytest.mark.unit
    def test_freeze(self):
        module = TempModule()
        module.freeze()
        for p in module.parameters():
            assert not p.requires_grad

    @pytest.mark.unit
    def test_unfreeze(self):
        module = TempModule()
        module.freeze()
        module.unfreeze()
        for p in module.parameters():
            assert p.requires_grad

    @pytest.mark.unit
    def test_as_frozen(self):
        module = TempModule()

        for p in module.parameters():
            assert p.requires_grad

        with module.as_frozen():
            for p in module.parameters():
                assert not p.requires_grad

        # Original trainability must be restored when the context exits.
        for p in module.parameters():
            assert p.requires_grad

    @pytest.mark.unit
    def test_partial_unfreeze(self):
        module = TempModule()

        # Simulate a module whose layer1 was deliberately frozen before freeze().
        for param in module.layer1.parameters():
            param.requires_grad = False

        module.freeze()

        for param in module.layer1.parameters():
            assert not param.requires_grad

        # freeze() must record the pre-freeze state of every parameter.
        assert module._frozen_grad_map is not None
        assert len(module._frozen_grad_map) == 2
        assert module._frozen_grad_map['layer1.weight'] is False
        assert module._frozen_grad_map['layer2.weight'] is True

        module.unfreeze(partial=True)

        # layer1 should still be frozen due to partial unfreeze...
        assert module.layer1.weight.requires_grad is False
        # ...while layer2, trainable before freeze(), must be trainable again.
        assert module.layer2.weight.requires_grad is True
        # The saved state is consumed by unfreeze().
        assert not hasattr(module, '_frozen_grad_map')
Loading