diff --git a/CHANGELOG.md b/CHANGELOG.md
index ca6f3c3..f142207 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -6,6 +6,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
## [Unreleased]
+## [0.6.0]
+## Added
+- Flattened state dict export and loading via a `flatten` argument. This feature improves interoperability with complex modules that were not originally constructed using the `co.Sequential` and `co.Parallel` building blocks.
+- Context manager for triggering flattened `state_dict` export and loading.
+
## [0.5.0]
## Added
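As a quick illustration of the flattened state-dict feature described in the changelog entry above, here is a minimal usage sketch. It mirrors the behaviour exercised in `tests/continual/test_container.py` further down in this diff; the `Conv1d` shape is an arbitrary example.

```python3
import continual as co
from torch import nn

# Unnamed submodules get numeric key prefixes such as "0.weight" in the state dict.
seq = co.Sequential(nn.Conv1d(1, 1, 3))

# Passing `flatten=True` strips the numeric parts, yielding {"weight", "bias"}.
sd_flat = seq.state_dict(flatten=True)

# The context manager enables the same behaviour for both export and loading.
with co.flat_state_dict:
    sd_flat = seq.state_dict()    # flattened keys
    seq.load_state_dict(sd_flat)  # keys are mapped back to "0.weight", "0.bias"
```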
diff --git a/README.md b/README.md
index 457c687..3ad2b2a 100644
--- a/README.md
+++ b/README.md
@@ -263,9 +263,6 @@ In addition, we support interoperability with a wide range of modules from `torc
```python3
-import continual as co
-from torch import nn
-
mb_conv = co.Residual(
co.Sequential(
co.Conv3d(32, 64, kernel_size=(1, 1, 1)),
@@ -288,9 +285,6 @@ mb_conv = co.Residual(
```python3
-import continual as co
-from torch import nn
-
def norm_relu(module, channels):
return co.Sequential(
module,
@@ -317,7 +311,7 @@ inception_module = co.Parallel(
```
-## Continual 3D [Squeeze-and-Excitation module](https://arxiv.org/pdf/1709.01507.pdf)
+### Continual 3D [Squeeze-and-Excitation module](https://arxiv.org/pdf/1709.01507.pdf)
diff --git a/continual/__init__.py b/continual/__init__.py
index 018d4a1..0db20d6 100644
--- a/continual/__init__.py
+++ b/continual/__init__.py
@@ -16,3 +16,4 @@
MaxPool3d,
)
from .ptflops import _register_ptflops # noqa: F401
+from .utils import flat_state_dict # noqa: F401
diff --git a/continual/container.py b/continual/container.py
index 2c981b2..d9d8ce8 100644
--- a/continual/container.py
+++ b/continual/container.py
@@ -19,7 +19,49 @@ def int_from(tuple_or_int: Union[int, Tuple[int, ...]], dim=0) -> int:
return tuple_or_int[dim]
-class Sequential(nn.Sequential, Padded, CoModule):
+class FlattenableStateDict:
+ """Mixes in the ability to flatten state dicts.
+    It is assumed that classes inheriting from this mixin also inherit from nn.Module.
+ """
+
+ def state_dict(
+ self, destination=None, prefix="", keep_vars=False, flatten=False
+ ) -> "OrderedDict[str, Tensor]":
+ d = nn.Module.state_dict(self, destination, prefix, keep_vars)
+ from continual.utils import flat_state_dict
+
+ if flatten or flat_state_dict.flatten:
+ flat_keys = [
+ ".".join(part for part in name.split(".") if not part.isdigit())
+ for name in list(d.keys())
+ ]
+ if len(set(flat_keys)) == len(d.keys()):
+ d = OrderedDict(list(zip(flat_keys, d.values())))
+
+ return d
+
+ def load_state_dict(
+ self,
+ state_dict: "OrderedDict[str, Tensor]",
+ strict: bool = True,
+ flatten=False,
+ ):
+ from continual.utils import flat_state_dict
+
+ if flatten or flat_state_dict.flatten:
+ long_keys = nn.Module.state_dict(self, keep_vars=True).keys()
+ short2long = {
+ ".".join(part for part in key.split(".") if not part.isdigit()): key
+ for key in list(long_keys)
+ }
+ state_dict = OrderedDict(
+ [(short2long[key], val) for key, val in state_dict.items()]
+ )
+
+        return nn.Module.load_state_dict(self, state_dict, strict)
+
+
+class Sequential(FlattenableStateDict, nn.Sequential, Padded, CoModule):
"""Continual Sequential module
This module is a drop-in replacement for `torch.nn.Sequential`
@@ -129,7 +171,7 @@ def parallel_mul(inputs: Sequence[Tensor]) -> Tensor:
AggregationFunc = Union[Aggregation, Callable[[Sequence[Tensor]], Tensor]]
-class Parallel(nn.Sequential, Padded, CoModule):
+class Parallel(FlattenableStateDict, nn.Sequential, Padded, CoModule):
"""Continual parallel container.
Args:
diff --git a/continual/utils.py b/continual/utils.py
index 1145a03..88df873 100644
--- a/continual/utils.py
+++ b/continual/utils.py
@@ -20,3 +20,29 @@ def _getattr(obj, attr):
return getattr(obj, attr, *args)
return reduce(_getattr, [obj] + attr.split("."))
+
+
+class _FlatStateDict(object):
+ """Context-manager state holder."""
+
+ def __init__(self):
+ self.flatten = False
+
+ def __enter__(self):
+ self.flatten = True
+
+ def __exit__(self, *args, **kwargs):
+ self.flatten = False
+
+
+flat_state_dict = _FlatStateDict()
+"""Context-manager that flattens the state dict of containers.
+
+If the submodules of a container were not explicitly named by means of an OrderedDict,
+an attempt is made to flatten their keys during both the `state_dict` and `load_state_dict` operations.
+
+Example:
+ >>> with co.flat_state_dict:
+ >>> sd = module.state_dict() # All unnamed nested keys are flattened, e.g. "0.weight" -> "weight"
+ >>> module.load_state_dict(sd) # Automatically unflattened during loading "weight" -> "0.weight"
+"""
diff --git a/setup.py b/setup.py
index cf01a95..14b6ff3 100644
--- a/setup.py
+++ b/setup.py
@@ -29,7 +29,7 @@ def from_file(file_name: str = "requirements.txt", comment_char: str = "#"):
setup(
name="continual-inference",
- version="0.5.0",
+ version="0.6.0",
description="Building blocks for Continual Inference Networks in PyTorch",
long_description=long_description(),
long_description_content_type="text/markdown",
diff --git a/tests/continual/test_container.py b/tests/continual/test_container.py
index c313140..0e7649b 100644
--- a/tests/continual/test_container.py
+++ b/tests/continual/test_container.py
@@ -1,3 +1,5 @@
+from collections import OrderedDict
+
import torch
from torch import nn
@@ -191,3 +193,63 @@ def test_parallel():
par.clean_state()
out_steps = par.forward_steps(input, pad_end=True)
assert torch.allclose(out_steps, out_all)
+
+
+def test_flat_state_dict():
+ # >> Part 1: Save both flat and original state dicts
+
+ # If modules are not named, it can be flattened
+ seq_to_flatten = co.Sequential(nn.Conv1d(1, 1, 3))
+
+ sd = seq_to_flatten.state_dict()
+ assert set(sd) == {"0.weight", "0.bias"}
+
+ sd_flat = seq_to_flatten.state_dict(flatten=True)
+ assert set(sd_flat) == {"weight", "bias"}
+
+ seq_not_to_flatten = co.Sequential(OrderedDict([("c1", nn.Conv1d(1, 1, 3))]))
+ sd_no_flat = seq_not_to_flatten.state_dict(flatten=True)
+ assert set(sd_no_flat) == {"c1.weight", "c1.bias"}
+
+ # A nested example:
+ nested = co.Parallel(seq_to_flatten, seq_not_to_flatten)
+ sd = nested.state_dict()
+ assert set(sd) == {"0.0.weight", "0.0.bias", "1.c1.weight", "1.c1.bias"}
+
+ sd_flat = nested.state_dict(flatten=True)
+ assert set(sd_flat) == {"weight", "bias", "c1.weight", "c1.bias"}
+
+ # >> Part 2: Load flat state dict
+ nested_new = co.Parallel(
+ co.Sequential(nn.Conv1d(1, 1, 3)),
+ co.Sequential(OrderedDict([("c1", nn.Conv1d(1, 1, 3))])),
+ )
+
+ assert not torch.equal(nested[0][0].weight, nested_new[0][0].weight)
+ assert not torch.equal(nested[0][0].bias, nested_new[0][0].bias)
+ assert not torch.equal(nested[1].c1.weight, nested_new[1].c1.weight)
+ assert not torch.equal(nested[1].c1.bias, nested_new[1].c1.bias)
+
+ nested_new.load_state_dict(sd_flat, flatten=True)
+
+ assert torch.equal(nested[0][0].weight, nested_new[0][0].weight)
+ assert torch.equal(nested[0][0].bias, nested_new[0][0].bias)
+ assert torch.equal(nested[1].c1.weight, nested_new[1].c1.weight)
+ assert torch.equal(nested[1].c1.bias, nested_new[1].c1.bias)
+
+ # >> Part 3: Test context manager
+ with co.utils.flat_state_dict:
+ # Export works as above despite `flatten=False`
+ sd_flat2 = nested.state_dict(flatten=False)
+ assert sd_flat.keys() == sd_flat2.keys()
+ assert all(torch.equal(sd_flat[key], sd_flat2[key]) for key in sd_flat.keys())
+
+ # Loading works as above despite `flatten=False`
+ nested_new.load_state_dict(sd_flat, flatten=False)
+
+ assert torch.equal(nested[0][0].weight, nested_new[0][0].weight)
+ assert torch.equal(nested[0][0].bias, nested_new[0][0].bias)
+ assert torch.equal(nested[1].c1.weight, nested_new[1].c1.weight)
+ assert torch.equal(nested[1].c1.bias, nested_new[1].c1.bias)
+
+    assert True  # Dedenting out of the `with` block triggers the context manager's __exit__, resetting the flatten flag
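One behaviour worth noting from the `FlattenableStateDict.state_dict` implementation above: flattening is best-effort. If stripping the numeric key parts would produce duplicate keys, the guard `len(set(flat_keys)) == len(d.keys())` fails and the original keys are kept. A minimal sketch of that fallback, assuming the container registers unnamed layers under the indices `0` and `1` as in the tests above:

```python3
import continual as co
from torch import nn

# Two unnamed Conv1d layers: "0.weight" and "1.weight" would both flatten to
# "weight", so flattening is skipped and the original keys are returned.
seq = co.Sequential(nn.Conv1d(1, 1, 3), nn.Conv1d(1, 1, 3))
sd = seq.state_dict(flatten=True)
assert set(sd) == {"0.weight", "0.bias", "1.weight", "1.bias"}
```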