Merge pull request #13 from LukasHedegaard/develop
Flattened state_dict export and import
LukasHedegaard authored Aug 23, 2021
2 parents dba1504 + 34fcc0e commit 25b624e
Showing 7 changed files with 140 additions and 10 deletions.
5 changes: 5 additions & 0 deletions CHANGELOG.md
@@ -6,6 +6,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## [Unreleased]

## [0.6.0]
## Added
- Flattened state dict export and loading via a `flatten` argument. This feature improves interoperability with complex modules that were not originally constructed with the `co.Sequential` and `co.Parallel` building blocks.
- Context manager for triggering flattened state_dict export and loading.


## [0.5.0]
## Added
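For context, a minimal sketch of the new `flatten` argument in use (adapted from the tests added in this commit; not part of the CHANGELOG above):

```python3
import continual as co
from torch import nn

seq = co.Sequential(nn.Conv1d(1, 1, 3))

seq.state_dict().keys()            # odict_keys(['0.weight', '0.bias'])
sd = seq.state_dict(flatten=True)  # keys flattened to 'weight', 'bias'

# A freshly constructed module can load the flattened dict back in:
seq_new = co.Sequential(nn.Conv1d(1, 1, 3))
seq_new.load_state_dict(sd, flatten=True)
```

`load_state_dict(..., flatten=True)` maps the flat keys back onto the module's own (possibly nested) key names.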
8 changes: 1 addition & 7 deletions README.md
@@ -263,9 +263,6 @@ In addition, we support interoperability with a wide range of modules from `torc
</div>

```python3
import continual as co
from torch import nn

mb_conv = co.Residual(
co.Sequential(
co.Conv3d(32, 64, kernel_size=(1, 1, 1)),
@@ -288,9 +285,6 @@ mb_conv = co.Residual(
</div>

```python3
import continual as co
from torch import nn

def norm_relu(module, channels):
return co.Sequential(
module,
@@ -317,7 +311,7 @@ inception_module = co.Parallel(
```


## Continual 3D [Squeeze-and-Excitation module](https://arxiv.org/pdf/1709.01507.pdf)
### Continual 3D [Squeeze-and-Excitation module](https://arxiv.org/pdf/1709.01507.pdf)

<div align="center">
<img src="https://raw.githubusercontent.com/LukasHedegaard/continual-inference/main/figures/examples/se_block.png" width="230">
1 change: 1 addition & 0 deletions continual/__init__.py
@@ -16,3 +16,4 @@
MaxPool3d,
)
from .ptflops import _register_ptflops # noqa: F401
from .utils import flat_state_dict # noqa: F401
46 changes: 44 additions & 2 deletions continual/container.py
@@ -19,7 +19,49 @@ def int_from(tuple_or_int: Union[int, Tuple[int, ...]], dim=0) -> int:
return tuple_or_int[dim]


class Sequential(nn.Sequential, Padded, CoModule):
class FlattenableStateDict:
"""Mixes in the ability to flatten state dicts.
It is assumed that classes that inherit from this mixin also inherit from nn.Module
"""

def state_dict(
self, destination=None, prefix="", keep_vars=False, flatten=False
) -> "OrderedDict[str, Tensor]":
d = nn.Module.state_dict(self, destination, prefix, keep_vars)
from continual.utils import flat_state_dict

if flatten or flat_state_dict.flatten:
flat_keys = [
".".join(part for part in name.split(".") if not part.isdigit())
for name in list(d.keys())
]
if len(set(flat_keys)) == len(d.keys()):
d = OrderedDict(list(zip(flat_keys, d.values())))

return d

def load_state_dict(
self,
state_dict: "OrderedDict[str, Tensor]",
strict: bool = True,
flatten=False,
):
from continual.utils import flat_state_dict

if flatten or flat_state_dict.flatten:
long_keys = nn.Module.state_dict(self, keep_vars=True).keys()
short2long = {
".".join(part for part in key.split(".") if not part.isdigit()): key
for key in list(long_keys)
}
state_dict = OrderedDict(
[(short2long[key], val) for key, val in state_dict.items()]
)

nn.Module.load_state_dict(self, state_dict, strict)


class Sequential(FlattenableStateDict, nn.Sequential, Padded, CoModule):
"""Continual Sequential module
This module is a drop-in replacement for `torch.nn.Sequential`
@@ -129,7 +171,7 @@ def parallel_mul(inputs: Sequence[Tensor]) -> Tensor:
AggregationFunc = Union[Aggregation, Callable[[Sequence[Tensor]], Tensor]]


class Parallel(nn.Sequential, Padded, CoModule):
class Parallel(FlattenableStateDict, nn.Sequential, Padded, CoModule):
"""Continual parallel container.
Args:
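The flattening rule in `FlattenableStateDict.state_dict` above drops purely numeric path components and is only applied when the shortened keys stay unique. A stand-alone sketch of that transform (the helper name `flatten_keys` is illustrative, not part of the library):

```python3
from collections import OrderedDict

def flatten_keys(state_dict):
    """Drop purely numeric path components, e.g. '1.c1.weight' -> 'c1.weight'.
    Falls back to the original keys if flattening would cause collisions."""
    flat_keys = [
        ".".join(part for part in name.split(".") if not part.isdigit())
        for name in state_dict.keys()
    ]
    if len(set(flat_keys)) == len(state_dict):  # no collisions
        return OrderedDict(zip(flat_keys, state_dict.values()))
    return state_dict

sd = OrderedDict(
    [("0.0.weight", 1), ("0.0.bias", 2), ("1.c1.weight", 3), ("1.c1.bias", 4)]
)
flatten_keys(sd)  # OrderedDict([('weight', 1), ('bias', 2), ('c1.weight', 3), ('c1.bias', 4)])
```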
26 changes: 26 additions & 0 deletions continual/utils.py
@@ -20,3 +20,29 @@ def _getattr(obj, attr):
return getattr(obj, attr, *args)

return reduce(_getattr, [obj] + attr.split("."))


class _FlatStateDict(object):
"""Context-manager state holder."""

def __init__(self):
self.flatten = False

def __enter__(self):
self.flatten = True

def __exit__(self, *args, **kwargs):
self.flatten = False


flat_state_dict = _FlatStateDict()
"""Context-manager that flattens the state dict of containers.
If a container module was not explicitly named by means of an OrderedDict,
it will attempt to flatten the keys during both the `state_dict` and `load_state_dict` operations.
Example:
>>> with co.flat_state_dict:
>>> sd = module.state_dict() # All unnamed nested keys are flattened, e.g. "0.weight" -> "weight"
>>> module.load_state_dict(sd) # Automatically unflattened during loading "weight" -> "0.weight"
"""
2 changes: 1 addition & 1 deletion setup.py
@@ -29,7 +29,7 @@ def from_file(file_name: str = "requirements.txt", comment_char: str = "#"):

setup(
name="continual-inference",
version="0.5.0",
version="0.6.0",
description="Building blocks for Continual Inference Networks in PyTorch",
long_description=long_description(),
long_description_content_type="text/markdown",
62 changes: 62 additions & 0 deletions tests/continual/test_container.py
@@ -1,3 +1,5 @@
from collections import OrderedDict

import torch
from torch import nn

@@ -191,3 +193,63 @@ def test_parallel():
par.clean_state()
out_steps = par.forward_steps(input, pad_end=True)
assert torch.allclose(out_steps, out_all)


def test_flat_state_dict():
# >> Part 1: Save both flat and original state dicts

# If modules are not named, it can be flattened
seq_to_flatten = co.Sequential(nn.Conv1d(1, 1, 3))

sd = seq_to_flatten.state_dict()
assert set(sd) == {"0.weight", "0.bias"}

sd_flat = seq_to_flatten.state_dict(flatten=True)
assert set(sd_flat) == {"weight", "bias"}

seq_not_to_flatten = co.Sequential(OrderedDict([("c1", nn.Conv1d(1, 1, 3))]))
sd_no_flat = seq_not_to_flatten.state_dict(flatten=True)
assert set(sd_no_flat) == {"c1.weight", "c1.bias"}

# A nested example:
nested = co.Parallel(seq_to_flatten, seq_not_to_flatten)
sd = nested.state_dict()
assert set(sd) == {"0.0.weight", "0.0.bias", "1.c1.weight", "1.c1.bias"}

sd_flat = nested.state_dict(flatten=True)
assert set(sd_flat) == {"weight", "bias", "c1.weight", "c1.bias"}

# >> Part 2: Load flat state dict
nested_new = co.Parallel(
co.Sequential(nn.Conv1d(1, 1, 3)),
co.Sequential(OrderedDict([("c1", nn.Conv1d(1, 1, 3))])),
)

assert not torch.equal(nested[0][0].weight, nested_new[0][0].weight)
assert not torch.equal(nested[0][0].bias, nested_new[0][0].bias)
assert not torch.equal(nested[1].c1.weight, nested_new[1].c1.weight)
assert not torch.equal(nested[1].c1.bias, nested_new[1].c1.bias)

nested_new.load_state_dict(sd_flat, flatten=True)

assert torch.equal(nested[0][0].weight, nested_new[0][0].weight)
assert torch.equal(nested[0][0].bias, nested_new[0][0].bias)
assert torch.equal(nested[1].c1.weight, nested_new[1].c1.weight)
assert torch.equal(nested[1].c1.bias, nested_new[1].c1.bias)

# >> Part 3: Test context manager
with co.utils.flat_state_dict:
# Export works as above despite `flatten=False`
sd_flat2 = nested.state_dict(flatten=False)
assert sd_flat.keys() == sd_flat2.keys()
assert all(torch.equal(sd_flat[key], sd_flat2[key]) for key in sd_flat.keys())

# Loading works as above despite `flatten=False`
nested_new.load_state_dict(sd_flat, flatten=False)

assert torch.equal(nested[0][0].weight, nested_new[0][0].weight)
assert torch.equal(nested[0][0].bias, nested_new[0][0].bias)
assert torch.equal(nested[1].c1.weight, nested_new[1].c1.weight)
assert torch.equal(nested[1].c1.bias, nested_new[1].c1.bias)

assert True # Need to step down here to trigger context manager __exit__
