diff --git a/CHANGELOG.md b/CHANGELOG.md
index ca6f3c3..f142207 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -6,6 +6,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 
 ## [Unreleased]
 
+## [0.6.0]
+## Added
+- Flattened state dict export and loading via a `flatten` argument. This feature improves interoperability with complex modules that were not originally constructed with the `co.Sequential` and `co.Parallel` building blocks.
+- Context manager for triggering flattened state_dict export and loading.
+
 ## [0.5.0]
 ## Added
 
diff --git a/README.md b/README.md
index 457c687..3ad2b2a 100644
--- a/README.md
+++ b/README.md
@@ -263,9 +263,6 @@ In addition, we support interoperability with a wide range of modules from `torc
 
 ```python3
-import continual as co
-from torch import nn
-
 mb_conv = co.Residual(
     co.Sequential(
         co.Conv3d(32, 64, kernel_size=(1, 1, 1)),
@@ -288,9 +285,6 @@ mb_conv = co.Residual(
 
 ```python3
-import continual as co
-from torch import nn
-
 def norm_relu(module, channels):
     return co.Sequential(
         module,
@@ -317,7 +311,7 @@ inception_module = co.Parallel(
 ```
 
-## Continual 3D [Squeeze-and-Excitation module](https://arxiv.org/pdf/1709.01507.pdf)
+### Continual 3D [Squeeze-and-Excitation module](https://arxiv.org/pdf/1709.01507.pdf)
diff --git a/continual/__init__.py b/continual/__init__.py
index 018d4a1..0db20d6 100644
--- a/continual/__init__.py
+++ b/continual/__init__.py
@@ -16,3 +16,4 @@
     MaxPool3d,
 )
 from .ptflops import _register_ptflops  # noqa: F401
+from .utils import flat_state_dict  # noqa: F401
diff --git a/continual/container.py b/continual/container.py
index 2c981b2..d9d8ce8 100644
--- a/continual/container.py
+++ b/continual/container.py
@@ -19,7 +19,49 @@ def int_from(tuple_or_int: Union[int, Tuple[int, ...]], dim=0) -> int:
     return tuple_or_int[dim]
 
 
-class Sequential(nn.Sequential, Padded, CoModule):
+class FlattenableStateDict:
+    """Mixes in the ability to flatten state dicts.
+    It is assumed that classes that inherit this mixin also inherit from nn.Module.
+    """
+
+    def state_dict(
+        self, destination=None, prefix="", keep_vars=False, flatten=False
+    ) -> "OrderedDict[str, Tensor]":
+        d = nn.Module.state_dict(self, destination, prefix, keep_vars)
+        from continual.utils import flat_state_dict
+
+        if flatten or flat_state_dict.flatten:
+            flat_keys = [
+                ".".join(part for part in name.split(".") if not part.isdigit())
+                for name in list(d.keys())
+            ]
+            if len(set(flat_keys)) == len(d.keys()):
+                d = OrderedDict(list(zip(flat_keys, d.values())))
+
+        return d
+
+    def load_state_dict(
+        self,
+        state_dict: "OrderedDict[str, Tensor]",
+        strict: bool = True,
+        flatten=False,
+    ):
+        from continual.utils import flat_state_dict
+
+        if flatten or flat_state_dict.flatten:
+            long_keys = nn.Module.state_dict(self, keep_vars=True).keys()
+            short2long = {
+                ".".join(part for part in key.split(".") if not part.isdigit()): key
+                for key in list(long_keys)
+            }
+            state_dict = OrderedDict(
+                [(short2long[key], val) for key, val in state_dict.items()]
+            )
+
+        nn.Module.load_state_dict(self, state_dict, strict)
+
+
+class Sequential(FlattenableStateDict, nn.Sequential, Padded, CoModule):
     """Continual Sequential module
 
     This module is a drop-in replacement for `torch.nn.Sequential`
@@ -129,7 +171,7 @@ def parallel_mul(inputs: Sequence[Tensor]) -> Tensor:
 AggregationFunc = Union[Aggregation, Callable[[Sequence[Tensor]], Tensor]]
 
 
-class Parallel(nn.Sequential, Padded, CoModule):
+class Parallel(FlattenableStateDict, nn.Sequential, Padded, CoModule):
     """Continual parallel container.
 
     Args:
diff --git a/continual/utils.py b/continual/utils.py
index 1145a03..88df873 100644
--- a/continual/utils.py
+++ b/continual/utils.py
@@ -20,3 +20,29 @@ def _getattr(obj, attr):
         return getattr(obj, attr, *args)
 
     return reduce(_getattr, [obj] + attr.split("."))
+
+
+class _FlatStateDict(object):
+    """Context-manager state holder."""
+
+    def __init__(self):
+        self.flatten = False
+
+    def __enter__(self):
+        self.flatten = True
+
+    def __exit__(self, *args, **kwargs):
+        self.flatten = False
+
+
+flat_state_dict = _FlatStateDict()
+"""Context-manager that flattens the state dict of containers.
+
+If the modules of a container were not explicitly named by means of an OrderedDict,
+the state dict keys are flattened during both the `state_dict` and `load_state_dict` operations.
+
+Example:
+    >>> with co.flat_state_dict:
+    >>>     sd = module.state_dict()  # All unnamed nested keys are flattened, e.g. "0.weight" -> "weight"
"0.weight" -> "weight" + >>> module.load_state_dict(sd) # Automatically unflattened during loading "weight" -> "0.weight" +""" diff --git a/setup.py b/setup.py index cf01a95..14b6ff3 100644 --- a/setup.py +++ b/setup.py @@ -29,7 +29,7 @@ def from_file(file_name: str = "requirements.txt", comment_char: str = "#"): setup( name="continual-inference", - version="0.5.0", + version="0.6.0", description="Building blocks for Continual Inference Networks in PyTorch", long_description=long_description(), long_description_content_type="text/markdown", diff --git a/tests/continual/test_container.py b/tests/continual/test_container.py index c313140..0e7649b 100644 --- a/tests/continual/test_container.py +++ b/tests/continual/test_container.py @@ -1,3 +1,5 @@ +from collections import OrderedDict + import torch from torch import nn @@ -191,3 +193,63 @@ def test_parallel(): par.clean_state() out_steps = par.forward_steps(input, pad_end=True) assert torch.allclose(out_steps, out_all) + + +def test_flat_state_dict(): + # >> Part 1: Save both flat and original state dicts + + # If modules are not named, it can be flattened + seq_to_flatten = co.Sequential(nn.Conv1d(1, 1, 3)) + + sd = seq_to_flatten.state_dict() + assert set(sd) == {"0.weight", "0.bias"} + + sd_flat = seq_to_flatten.state_dict(flatten=True) + assert set(sd_flat) == {"weight", "bias"} + + seq_not_to_flatten = co.Sequential(OrderedDict([("c1", nn.Conv1d(1, 1, 3))])) + sd_no_flat = seq_not_to_flatten.state_dict(flatten=True) + assert set(sd_no_flat) == {"c1.weight", "c1.bias"} + + # A nested example: + nested = co.Parallel(seq_to_flatten, seq_not_to_flatten) + sd = nested.state_dict() + assert set(sd) == {"0.0.weight", "0.0.bias", "1.c1.weight", "1.c1.bias"} + + sd_flat = nested.state_dict(flatten=True) + assert set(sd_flat) == {"weight", "bias", "c1.weight", "c1.bias"} + + # >> Part 2: Load flat state dict + nested_new = co.Parallel( + co.Sequential(nn.Conv1d(1, 1, 3)), + co.Sequential(OrderedDict([("c1", nn.Conv1d(1, 1, 3))])), + ) + + assert not torch.equal(nested[0][0].weight, nested_new[0][0].weight) + assert not torch.equal(nested[0][0].bias, nested_new[0][0].bias) + assert not torch.equal(nested[1].c1.weight, nested_new[1].c1.weight) + assert not torch.equal(nested[1].c1.bias, nested_new[1].c1.bias) + + nested_new.load_state_dict(sd_flat, flatten=True) + + assert torch.equal(nested[0][0].weight, nested_new[0][0].weight) + assert torch.equal(nested[0][0].bias, nested_new[0][0].bias) + assert torch.equal(nested[1].c1.weight, nested_new[1].c1.weight) + assert torch.equal(nested[1].c1.bias, nested_new[1].c1.bias) + + # >> Part 3: Test context manager + with co.utils.flat_state_dict: + # Export works as above despite `flatten=False` + sd_flat2 = nested.state_dict(flatten=False) + assert sd_flat.keys() == sd_flat2.keys() + assert all(torch.equal(sd_flat[key], sd_flat2[key]) for key in sd_flat.keys()) + + # Loading works as above despite `flatten=False` + nested_new.load_state_dict(sd_flat, flatten=False) + + assert torch.equal(nested[0][0].weight, nested_new[0][0].weight) + assert torch.equal(nested[0][0].bias, nested_new[0][0].bias) + assert torch.equal(nested[1].c1.weight, nested_new[1].c1.weight) + assert torch.equal(nested[1].c1.bias, nested_new[1].c1.bias) + + assert True # Need to step down here to trigger context manager __exit__