nemo-ux-state: handle None in state_dict.keys; disable auto-grad when transforming ckpt

Signed-off-by: Alexandros Koumparoulis <[email protected]>
akoumpa committed Jul 26, 2024
1 parent 917f715 commit 209a241
Showing 1 changed file with 12 additions and 8 deletions.
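
For context, a minimal sketch (illustrative only, not code from state.py; copy_weights and the Linear modules below are made up) of why autograd is disabled while transforming a checkpoint: in-place copies into leaf parameters that require grad are only permitted with grad disabled, and recent PyTorch versions accept torch.no_grad used directly as a decorator, which is the form this commit uses.

import torch
from torch import nn

@torch.no_grad  # recent PyTorch; older versions need the @torch.no_grad() form
def copy_weights(src: nn.Linear, dst: nn.Linear) -> None:
    # With autograd disabled, in-place copies into leaf parameters are legal
    # and no computation graph is recorded for the copied tensors.
    dst.weight.copy_(src.weight)
    dst.bias.copy_(src.bias)

src, dst = nn.Linear(4, 4), nn.Linear(4, 4)
copy_weights(src, dst)
assert torch.equal(src.weight, dst.weight)
# Without the decorator, dst.weight.copy_(...) raises
# "a leaf Variable that requires grad is being used in an in-place operation".
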
nemo/lightning/io/state.py (12 additions, 8 deletions)

@@ -5,6 +5,7 @@

 import numpy as np
 from torch import nn
+import torch
 
 SourceModuleT = TypeVar("SourceModuleT", bound=nn.Module)
 TargetModuleT = TypeVar("TargetModuleT", bound=nn.Module)
@@ -18,12 +19,12 @@ class TransformCTX:
     target: nn.Module
     target_state: dict
 
-
+@torch.no_grad
 def apply_transforms(
     source: nn.Module,
     target: TargetModuleT,
     mapping: Dict[str, str],
-    transforms: Optional[List[Callable[[TransformCTX], TransformCTX]]] = None,
+    transforms: Optional[List[Callable[[TransformCTX], TransformCTX]]] = [],
 ) -> TargetModuleT:
     """
     Applies a series of transformations to adapt the state dictionary of a source module to
@@ -101,9 +102,8 @@ def scale_weights(ctx):
     for key, val in mapping.items():
         ctx = StateDictTransform(key, val)(ctx)
 
-    if transforms:
-        for transform in transforms:
-            ctx = transform(ctx)
+    for transform in transforms:
+        ctx = transform(ctx)
 
     _params: Dict[str, nn.Parameter] = {}
     for name, param in _target.named_parameters():
@@ -144,9 +144,9 @@ def scale_weights(ctx):

         _module.register_buffer(_key, val)
 
-    keys = [name for name in list(target_state.keys()) if not name.endswith("_extra_state")]
+    keys = list(filter(lambda x: x is not None and not x.endswith("_extra_state"), target_state.keys()))
     if len(keys) != 0:
-        raise RuntimeError(f"Additional keys: {target_state.keys()} in checkpoint but not in model.")
+        raise RuntimeError(f"Additional keys: {keys} in checkpoint but not in model.")
 
     # TODO: Is this correct?
     # for key in target.state_dict():
@@ -165,7 +165,7 @@ def scale_weights(ctx):


 def _default_transform(inp):
-    return inp.float()
+    return inp
 
 
 class StateDictTransform(Generic[F]):
@@ -325,6 +325,8 @@ def _match_keys(keys: List[str], pattern: str) -> np.ndarray:
     wildcard_matches = [[] for _ in range(pattern.count("*"))]
 
     for key in keys:
+        if key is None:
+            continue
         match = regex_pattern.match(key)
         if match:
             for i, group in enumerate(match.groups()):
@@ -343,6 +345,8 @@

     # Populate the array with the keys, now that we have the correct shape and ordering
     for key in keys:
+        if key is None:
+            continue
         match = regex_pattern.match(key)
         if match:
             # Convert match groups to indices based on their position in wildcard_matches
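
As a small illustration of the None handling above (hypothetical key names, not from an actual checkpoint): both the leftover-key check and _match_keys now simply skip None entries instead of calling .endswith or regex matching on them.

# Hypothetical key list; in apply_transforms the keys come from target_state.keys().
keys = ["decoder.weight", None, "decoder._extra_state"]

# Same filter as the new leftover-key check: skip None and "_extra_state" entries,
# so only genuinely unmapped keys trigger the RuntimeError.
leftover = list(filter(lambda x: x is not None and not x.endswith("_extra_state"), keys))
print(leftover)  # ['decoder.weight']
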