diff --git a/nemo/backends/pytorch/nm.py b/nemo/backends/pytorch/nm.py
index 734ab9d60b50..07576fe3e6bd 100644
--- a/nemo/backends/pytorch/nm.py
+++ b/nemo/backends/pytorch/nm.py
@@ -6,7 +6,7 @@
 import torch as t
 import torch.nn as nn
 
-from nemo.core import DeviceType, NeuralModule, WeightShareTransform
+from nemo.core import DeviceType, ModuleType, NeuralModule, WeightShareTransform
 from nemo.utils.helpers import get_cuda_device, rgetattr, rsetattr
 
 
@@ -38,6 +38,9 @@ def __init__(self, pretrained_model_name=None, name=None):
         NeuralModule.__init__(self, name)  # For NeuralModule API
         nn.Module.__init__(self)  # For PyTorch API
 
+        # Set module type.
+        self._type = ModuleType.trainable
+
         self._device = get_cuda_device(self.placement)
 
         # Store pretrained model name (to be removed/changed)
@@ -132,6 +135,8 @@ class NonTrainableNM(NeuralModule):
     def __init__(self, name=None):
         NeuralModule.__init__(self, name)  # For NeuralModule API
         self._device = get_cuda_device(self.placement)
+        # Set module type.
+        self._type = ModuleType.nontrainable
 
     def __call__(self, force_pt=False, *input, **kwargs):
         pt_call = len(input) > 0 or force_pt
@@ -191,6 +196,10 @@ class DataLayerNM(NeuralModule):
     def __init__(self, name=None):
         NeuralModule.__init__(self, name)  # For NeuralModule API
+
+        # Set module type.
+        self._type = ModuleType.datalayer
+
         self._device = get_cuda_device(self.placement)
 
         # if 'batch_size' not in kwargs:
@@ -326,6 +335,10 @@ class LossNM(NeuralModule):
     def __init__(self, name=None):
         NeuralModule.__init__(self, name)  # For NeuralModule API
+
+        # Set module type.
+        self._type = ModuleType.loss
+
         self._device = get_cuda_device(self.placement)
 
     def get_weights(self):
diff --git a/nemo/core/neural_graph.py b/nemo/core/neural_graph.py
index 850f2879eefa..7df04db6c34d 100644
--- a/nemo/core/neural_graph.py
+++ b/nemo/core/neural_graph.py
@@ -28,7 +28,7 @@
 
 from nemo.core import OperationMode
 from nemo.core.neural_interface import NeuralInterface
-from nemo.core.neural_modules import NeuralModule
+from nemo.core.neural_modules import ModuleType, NeuralModule
 from nemo.core.neural_types import NeuralPortNameMismatchError, NeuralType, NmTensor
 from nemo.package_info import __version__ as nemo_version
 from nemo.utils import logging
@@ -920,3 +920,45 @@ def summary(self) -> str:
 
         # Return the result.
         return desc
+
+    def freeze(self, module_names: Optional[List[str]] = None):
+        """
+        Freezes the weights of the trainable modules in a graph.
+        Args:
+            module_names: List of modules to be frozen (Optional). If not passed, all modules will be frozen.
+        Raises:
+            KeyError: If a module name is not recognized.
+        """
+        if module_names is None:
+            # Work on all modules.
+            module_names = self._modules.keys()
+        # Iterate through modules one by one.
+        for name in module_names:
+            if name not in self._modules.keys():
+                raise KeyError("Module `{}` not present in the `{}` graph".format(name, self.name))
+            # Check module type.
+            module = self._modules[name]
+            if module.type == ModuleType.trainable:
+                # Freeze weights of the module.
+                module.freeze()
+
+    def unfreeze(self, module_names: Optional[List[str]] = None):
+        """
+        Unfreezes the weights of the trainable modules in a graph.
+        Args:
+            module_names: List of modules to be unfrozen (Optional). If not passed, all modules will be unfrozen.
+        Raises:
+            KeyError: If a module name is not recognized.
+        """
+        if module_names is None:
+            # Work on all modules.
+            module_names = self._modules.keys()
+        # Iterate through modules one by one.
+        for name in module_names:
+            if name not in self._modules.keys():
+                raise KeyError("Module `{}` not present in the `{}` graph".format(name, self.name))
+            # Check module type.
+            module = self._modules[name]
+            if module.type == ModuleType.trainable:
+                # Unfreeze weights of the module.
+                module.unfreeze()
diff --git a/nemo/core/neural_modules.py b/nemo/core/neural_modules.py
index 24abfb46264d..23bad2e35a31 100644
--- a/nemo/core/neural_modules.py
+++ b/nemo/core/neural_modules.py
@@ -15,8 +15,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-"""This file contains NeuralModule and NmTensor classes."""
-__all__ = ['WeightShareTransform', 'NeuralModule']
+__all__ = ['WeightShareTransform', 'NeuralModule', 'ModuleType']
 
 import uuid
 from abc import abstractmethod
@@ -39,6 +38,16 @@
 YAML = YAML(typ='safe')
 
 
+class ModuleType(Enum):
+    """ Back-end independent module types. """
+
+    module = 0
+    datalayer = 1
+    trainable = 2
+    loss = 3
+    nontrainable = 4
+
+
 class WeightShareTransform(Enum):
     """When sharing parameters, what kind of transform to apply."""
 
@@ -69,6 +78,9 @@ def __init__(self, name=None):
         # Register module and store the generated name.
         self._name = self._app_state.register_module(self, name)
 
+        # Set "module" type as default.
+        self._type = ModuleType.module
+
         # Set "both" as default operation mode.
         self._operation_mode = OperationMode.both
 
@@ -478,15 +490,6 @@ def _deserialize_configuration(cls, serialized_init_params: Dict[str, Any]):
             # In this case configuration = init parameters.
             return serialized_init_params
 
-    @deprecated(version=0.11)
-    @staticmethod
-    def create_ports(**kwargs):
-        """ Deprecated method, to be remoted in the next release."""
-        raise Exception(
-            'Deprecated method. Please implement ``inputs`` and ``outputs`` \
-            properties to define module ports instead'
-        )
-
     @property
     @abstractmethod
     def input_ports(self) -> Dict[str, NeuralType]:
@@ -534,6 +537,11 @@ def operation_mode(self):
         """ Returns the operation mode. """
         return self._operation_mode
 
+    @property
+    def type(self):
+        """ Returns the type of the module. """
+        return self._type
+
     @operation_mode.setter
     def operation_mode(self, operation_mode: OperationMode):
         """ Sets the operation mode. """
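
Usage note (reviewer sketch, not part of the patch): a minimal example of how the new graph-level freeze()/unfreeze() API is expected to behave, built on the toy modules from nemo.backends.pytorch.tutorials. The constructor arguments and the `name` property lookup are assumptions based on the existing tutorial code, not verified against this revision.

    import nemo
    from nemo.backends.pytorch.tutorials import MSELoss, RealFunctionDataLayer, TaylorNet
    from nemo.core import NeuralGraph, OperationMode

    # The factory sets up device placement and the application state.
    nf = nemo.core.NeuralModuleFactory()

    # One module of each kind: datalayer, trainable, loss.
    dl = RealFunctionDataLayer(n=10000, batch_size=128)
    fx = TaylorNet(dim=4)
    mse = MSELoss()

    # Record the modules into a graph.
    with NeuralGraph(operation_mode=OperationMode.training) as g:
        x, y = dl()
        p = fx(x=x)
        loss = mse(predictions=p, target=y)

    # Freezes only modules whose type is ModuleType.trainable (here: fx);
    # the datalayer and the loss are skipped rather than rejected.
    g.freeze()

    # Selective unfreeze by registered module name.
    g.unfreeze(module_names=[fx.name])

    # Unknown names raise KeyError.
    g.freeze(module_names=["no_such_module"])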