From 76376200a1416e6e5595a53d60ed010eefe20496 Mon Sep 17 00:00:00 2001
From: LukasHedegaard
Date: Thu, 26 Aug 2021 11:25:41 +0200
Subject: [PATCH 1/6] Add Lambda, Add and Multiply

---
 CHANGELOG.md                    |  4 ++
 README.md                       |  5 +++
 continual/__init__.py           |  1 +
 continual/closure.py            | 51 ++++++++++++++++++++++++++
 continual/conditional.py        |  0
 tests/continual/test_closure.py | 65 +++++++++++++++++++++++++++++++++
 6 files changed, 126 insertions(+)
 create mode 100644 continual/closure.py
 create mode 100644 continual/conditional.py
 create mode 100644 tests/continual/test_closure.py

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 044a773..1c06da7 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -8,6 +8,10 @@ From v1.0.0 and on, the project will adhere strictly to Semantic Versioning.
 
 ## [Unreleased]
+### Added
+- `co.Lambda` module.
+- `co.Add` module.
+- `co.Multiply` module.
 
 
 ## [0.8.1]
diff --git a/README.md b/README.md
index 05dc03f..bf5ab18 100644
--- a/README.md
+++ b/README.md
@@ -203,6 +203,11 @@ Below is a list of the modules and utilities included in the library:
 - `co.Residual` - Residual wrapper for modules.
 - `co.Delay` - Pure delay module (e.g. needed in residuals).
 
+- Functions
+  - `co.Lambda` - Lambda module which wraps any function.
+  - `co.Add` - Adds a constant value.
+  - `co.Multiply` - Multiplies with a constant factor.
+
 - Converters
   - `co.continual` - conversion function from `torch.nn` modules to `co` modules.
diff --git a/continual/__init__.py b/continual/__init__.py
index b489e5e..0ede8d5 100644
--- a/continual/__init__.py
+++ b/continual/__init__.py
@@ -1,3 +1,4 @@
+from .closure import Add, Lambda, Multiply  # noqa: F401
 from .container import Parallel, Residual, Sequential  # noqa: F401
 from .conv import Conv1d, Conv2d, Conv3d  # noqa: F401
 from .convert import continual, forward_stepping  # noqa: F401
diff --git a/continual/closure.py b/continual/closure.py
new file mode 100644
index 0000000..d3dcafd
--- /dev/null
+++ b/continual/closure.py
@@ -0,0 +1,51 @@
+from functools import partial
+from typing import Callable, Union
+
+from torch import Tensor, nn
+
+from .module import CoModule
+
+
+class Lambda(CoModule, nn.Module):
+    """Module wrapper for stateless functions"""
+
+    def __init__(self, fn: Callable[[Tensor], Tensor]):
+        nn.Module.__init__(self)
+        assert callable(fn), "The passed function should be callable."
+        self.fn = fn
+
+    def forward(self, input: Tensor) -> Tensor:
+        return self.fn(input)
+
+    def forward_step(self, input: Tensor, update_state=True) -> Tensor:
+        x = input.unsqueeze(dim=2)
+        x = self.fn(x)
+        x = x.squeeze(dim=2)
+        return x
+
+    def forward_steps(self, input: Tensor, pad_end=False, update_state=True) -> Tensor:
+        return self.fn(input)
+
+    @property
+    def delay(self) -> int:
+        return 0
+
+
+def _multiply(x: Tensor, factor: Union[float, int, Tensor]):
+    return x * factor
+
+
+def Multiply(factor) -> Lambda:
+    """Create Lambda with multiplication function"""
+    fn = partial(_multiply, factor=factor)
+    return Lambda(fn)
+
+
+def _add(x: Tensor, constant: Union[float, int, Tensor]):
+    return x + constant
+
+
+def Add(constant) -> Lambda:
+    """Create Lambda with addition function"""
+    fn = partial(_add, constant=constant)
+    return Lambda(fn)
diff --git a/continual/conditional.py b/continual/conditional.py
new file mode 100644
index 0000000..e69de29
diff --git a/tests/continual/test_closure.py b/tests/continual/test_closure.py
new file mode 100644
index 0000000..cb66c9c
--- /dev/null
+++ b/tests/continual/test_closure.py
@@ -0,0 +1,65 @@
+import torch
+
+from continual.closure import Add, Lambda, Multiply
+
+
+def test_add():
+    x = torch.ones((1, 1, 2, 2))
+
+    # int
+    add_5 = Add(5)
+    assert torch.equal(add_5.forward(x), x + 5)
+    assert add_5.delay == 0
+
+    # float
+    add_pi = Add(3.14)
+    assert torch.equal(add_pi.forward_steps(x), x + 3.14)
+
+    # Tensor
+    constants = torch.tensor([[[[2.0, 3.0]]]])
+    add_constants = Add(constants)
+    assert torch.equal(
+        add_constants.forward_step(x[:, :, 0]), torch.tensor([[[3.0, 4.0]]])
+    )
+
+
+def test_multiply():
+    x = torch.ones((1, 1, 2, 2))
+
+    # int
+    mul_5 = Multiply(5)
+    assert torch.equal(mul_5.forward(x), x * 5)
+    assert mul_5.delay == 0
+
+    # float
+    mul_pi = Multiply(3.14)
+    assert torch.equal(mul_pi.forward_steps(x), x * 3.14)
+
+    # Tensor
+    constants = torch.tensor([[[[2.0, 3.0]]]])
+    mul_constants = Multiply(constants)
+    assert torch.equal(
+        mul_constants.forward_step(x[:, :, 0]), torch.tensor([[[2.0, 3.0]]])
+    )
+
+
+def global_always42(x):
+    return torch.ones_like(x) * 42
+
+
+def test_lambda():
+    x = torch.ones((1, 1, 2, 2))
+    target = torch.ones_like(x) * 42
+
+    def local_always42(x):
+        return torch.ones_like(x) * 42
+
+    # Test if layer works in different scopes
+    # Global
+    assert torch.equal(target, Lambda(global_always42)(x))
+
+    # Local
+    assert torch.equal(target, Lambda(local_always42)(x))
+
+    # Anonymous
+    assert torch.equal(target, Lambda(lambda x: torch.ones_like(x) * 42)(x))
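A minimal usage sketch of the closures introduced above (not part of the patch; the tensor values and the `(batch, channel, time, ...)` layout are taken from the tests):

```python
import torch
from continual.closure import Add, Lambda, Multiply

x = torch.ones((1, 1, 2, 2))  # assumed (batch, channel, time, ...) layout

# Lambda wraps any stateless function as a continual module
clamp = Lambda(lambda t: t.clamp(min=0.0))
assert torch.equal(clamp.forward(x), x)  # already non-negative

# Add / Multiply are Lambda factories over partially applied arithmetic
assert torch.equal(Add(5).forward(x), x + 5)
assert torch.equal(Multiply(2).forward(x), x * 2)

# forward_step consumes a single frame without the time dimension
step = x[:, :, 0]
assert torch.equal(Add(5).forward_step(step), step + 5)
```

Note that `forward_step` re-inserts the time dimension around `fn` and removes it again, so a stateless function behaves identically in step-wise and clip-wise invocation.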
From c69e5655e342cc98abd58b008bc8f642b95086b2 Mon Sep 17 00:00:00 2001
From: LukasHedegaard
Date: Thu, 26 Aug 2021 12:54:48 +0200
Subject: [PATCH 2/6] Add Unity

---
 continual/closure.py            | 13 +++++++++++++
 tests/continual/test_closure.py |  9 +++++++--
 2 files changed, 20 insertions(+), 2 deletions(-)

diff --git a/continual/closure.py b/continual/closure.py
index d3dcafd..c2e339a 100644
--- a/continual/closure.py
+++ b/continual/closure.py
@@ -30,6 +30,10 @@ def forward_steps(self, input: Tensor, pad_end=False, update_state=True) -> Tens
     def delay(self) -> int:
         return 0
 
+    @staticmethod
+    def build_from(fn: Callable[[Tensor], Tensor]) -> "Lambda":
+        return Lambda(fn)
+
 
 def _multiply(x: Tensor, factor: Union[float, int, Tensor]):
     return x * factor
@@ -49,3 +53,12 @@ def Add(constant) -> Lambda:
     """Create Lambda with addition function"""
     fn = partial(_add, constant=constant)
     return Lambda(fn)
+
+
+def _unity(x: Tensor):
+    return x
+
+
+def Unity() -> Lambda:
+    """Create Lambda with identity function"""
+    return Lambda(_unity)
diff --git a/tests/continual/test_closure.py b/tests/continual/test_closure.py
index cb66c9c..f88efd7 100644
--- a/tests/continual/test_closure.py
+++ b/tests/continual/test_closure.py
@@ -1,6 +1,6 @@
 import torch
 
-from continual.closure import Add, Lambda, Multiply
+from continual.closure import Add, Lambda, Multiply, Unity
 
 
 def test_add():
@@ -62,4 +62,9 @@ def local_always42(x):
     assert torch.equal(target, Lambda(local_always42)(x))
 
     # Anonymous
-    assert torch.equal(target, Lambda(lambda x: torch.ones_like(x) * 42)(x))
+    assert torch.equal(target, Lambda.build_from(lambda x: torch.ones_like(x) * 42)(x))
+
+
+def test_unity():
+    x = torch.ones((1, 1, 2, 2))
+    assert torch.equal(x, Unity()(x))
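A short sketch of `Unity` and the new `Lambda.build_from` alias (the asserts mirror the tests; nothing here is in the patch itself):

```python
import torch
from continual.closure import Lambda, Unity

x = torch.ones((1, 1, 2, 2))

# Unity is a named identity module, handy as an explicit pass-through
# branch wherever the API expects a CoModule rather than raw tensors.
assert torch.equal(Unity()(x), x)

# build_from is equivalent to calling the Lambda constructor directly.
double = Lambda.build_from(lambda t: t * 2)
assert torch.equal(double(x), x * 2)
```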
From a53a947a73b0a5b9679797f0f27f02b927d4df6f Mon Sep 17 00:00:00 2001
From: LukasHedegaard
Date: Thu, 26 Aug 2021 12:55:26 +0200
Subject: [PATCH 3/6] Add Conditional module

---
 continual/conditional.py          |  0
 continual/container.py            | 66 +++++++++++++++++++++++++++++++
 tests/continual/test_container.py | 52 ++++++++++++++++++++++++
 3 files changed, 118 insertions(+)
 delete mode 100644 continual/conditional.py

diff --git a/continual/conditional.py b/continual/conditional.py
deleted file mode 100644
index e69de29..0000000
diff --git a/continual/container.py b/continual/container.py
index 3efbd5b..6358a33 100644
--- a/continual/container.py
+++ b/continual/container.py
@@ -326,3 +326,69 @@ def Residual(
         aggregation_fn=aggregation_fn,
         auto_delay=False,
     )
+
+
+class Conditional(FlattenableStateDict, CoModule, nn.Module):
+    """Module wrapper for conditional invocations at runtime"""
+
+    def __init__(
+        self,
+        predicate: Callable[[CoModule, Tensor], bool],
+        on_true: CoModule,
+        on_false: CoModule = None,
+    ):
+        assert callable(predicate), "The passed predicate should be callable."
+        assert isinstance(on_true, CoModule), "on_true should be a CoModule."
+        assert (
+            isinstance(on_false, CoModule) or on_false is None
+        ), "on_false should be a CoModule or None."
+
+        nn.Module.__init__(self)
+
+        self.predicate = predicate
+
+        # Ensure modules have the same delay
+        self._delay = max(on_true.delay, getattr(on_false, "delay", 0))
+
+        self.add_module(
+            "0",
+            on_true
+            if on_true.delay == self._delay
+            else Sequential(Delay(self._delay - on_true.delay), on_true),
+        )
+
+        if on_false is not None:
+            self.add_module(
+                "1",
+                on_false
+                if on_false.delay == self._delay
+                else Sequential(Delay(self._delay - on_false.delay), on_false),
+            )
+
+    def forward(self, input: Tensor) -> Tensor:
+        if self.predicate(self, input):
+            return self._modules["0"].forward(input)
+        elif "1" in self._modules:
+            return self._modules["1"].forward(input)
+        else:
+            return input
+
+    def forward_step(self, input: Tensor, update_state=True) -> Tensor:
+        if self.predicate(self, input):
+            return self._modules["0"].forward_step(input, update_state)
+        elif "1" in self._modules:
+            return self._modules["1"].forward_step(input, update_state)
+        else:
+            return input
+
+    def forward_steps(self, input: Tensor, pad_end=False, update_state=True) -> Tensor:
+        if self.predicate(self, input):
+            return self._modules["0"].forward_steps(input, pad_end, update_state)
+        elif "1" in self._modules:
+            return self._modules["1"].forward_steps(input, pad_end, update_state)
+        else:
+            return input
+
+    @property
+    def delay(self) -> int:
+        return self._delay
diff --git a/tests/continual/test_container.py b/tests/continual/test_container.py
index a48a155..3e055ff 100644
--- a/tests/continual/test_container.py
+++ b/tests/continual/test_container.py
@@ -272,3 +272,55 @@ def test_flat_state_dict():
     assert torch.equal(nested[1].c1.bias, nested_new[1].c1.bias)
 
     assert True  # Need to step down here to trigger context manager __exit__
+
+
+def test_conditional_only_first():
+    x = torch.ones((1, 1, 3))
+
+    def is_training(module, *args):
+        return module.training
+
+    mod = co.Conditional(is_training, co.Multiply(2))
+
+    mod.train()
+    assert torch.equal(mod.forward(x), x * 2)
+    assert torch.equal(mod.forward_steps(x), x * 2)
+    assert torch.equal(mod.forward_step(x[:, :, 0]), x[:, :, 0] * 2)
+
+    mod.eval()
+    assert torch.equal(mod.forward(x), x)
+    assert torch.equal(mod.forward_steps(x), x)
+    assert torch.equal(mod.forward_step(x[:, :, 0]), x[:, :, 0])
+
+
+def test_conditional_both_cases():
+    x = torch.ones((1, 1, 3))
+
+    def is_training(module, *args):
+        return module.training
+
+    mod = co.Conditional(is_training, co.Multiply(2), co.Multiply(3))
+
+    mod.train()
+    assert torch.equal(mod.forward(x), x * 2)
+    assert torch.equal(mod.forward_steps(x), x * 2)
+    assert torch.equal(mod.forward_step(x[:, :, 0]), x[:, :, 0] * 2)
+
+    mod.eval()
+    assert torch.equal(mod.forward(x), x * 3)
+    assert torch.equal(mod.forward_steps(x), x * 3)
+    assert torch.equal(mod.forward_step(x[:, :, 0]), x[:, :, 0] * 3)
+
+
+def test_conditional_delay():
+    # if_true.delay < if_false.delay
+    mod = co.Conditional(lambda a, b: True, co.Delay(2), co.Delay(3))
+    assert mod.delay == 3
+    assert mod._modules["0"].delay == 3
+    assert mod._modules["1"].delay == 3
+
+    # if_true.delay > if_false.delay
+    mod = co.Conditional(lambda a, b: True, co.Delay(3), co.Delay(2))
+    assert mod.delay == 3
+    assert mod._modules["0"].delay == 3
+    assert mod._modules["1"].delay == 3
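To illustrate the intended use of `Conditional`, a sketch based on the tests above (`Multiply` comes from the closures added in the earlier commits; values are illustrative):

```python
import torch
from continual.closure import Multiply
from continual.container import Conditional

# Invoke on_true only while the module is in training mode; with no
# on_false branch, the input passes through unchanged otherwise.
mod = Conditional(lambda module, x: module.training, Multiply(2))

x = torch.ones((1, 1, 3))
mod.train()
assert torch.equal(mod(x), x * 2)  # on_true branch taken
mod.eval()
assert torch.equal(mod(x), x)      # no on_false -> identity
```

When both branches are given, the constructor pads the shorter branch with a `Delay` so the two paths stay time-aligned, as `test_conditional_delay` verifies.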
From e70288dd3765370dc8687a6c698a25f7f3cb3d1d Mon Sep 17 00:00:00 2001
From: LukasHedegaard
Date: Thu, 26 Aug 2021 12:55:42 +0200
Subject: [PATCH 4/6] Add modules to __init__ and convert

---
 continual/__init__.py | 4 ++--
 continual/convert.py  | 5 +++++
 2 files changed, 7 insertions(+), 2 deletions(-)

diff --git a/continual/__init__.py b/continual/__init__.py
index 0ede8d5..b241d15 100644
--- a/continual/__init__.py
+++ b/continual/__init__.py
@@ -1,5 +1,5 @@
-from .closure import Add, Lambda, Multiply  # noqa: F401
-from .container import Parallel, Residual, Sequential  # noqa: F401
+from .closure import Add, Lambda, Multiply, Unity  # noqa: F401
+from .container import Conditional, Parallel, Residual, Sequential  # noqa: F401
 from .conv import Conv1d, Conv2d, Conv3d  # noqa: F401
 from .convert import continual, forward_stepping  # noqa: F401
 from .delay import Delay  # noqa: F401
diff --git a/continual/convert.py b/continual/convert.py
index d91c505..33c4235 100644
--- a/continual/convert.py
+++ b/continual/convert.py
@@ -1,10 +1,12 @@
 """ Register modules with conversion system and 3rd-party libraries """
 
 from functools import wraps
+from types import FunctionType
 from typing import Callable, Type
 
 from torch import Tensor, nn
 
+from .closure import Lambda
 from .container import Sequential
 from .conv import Conv1d, Conv2d, Conv3d
 from .logging import getLogger
@@ -181,6 +183,9 @@ def continual(module: nn.Module) -> CoModule:
 # Container
 register(nn.Sequential, Sequential)
 
+# Closure
+register(FunctionType, Lambda)
+
 
 # Register modules in `ptflops`
 try:
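With `FunctionType` registered, `co.continual` should also accept a plain function and wrap it in a `Lambda`. A sketch of the intended behaviour, assuming `register` dispatches on the type of its first argument (the function name below is hypothetical):

```python
import torch
import continual as co

def halve(x: torch.Tensor) -> torch.Tensor:
    return x * 0.5

# Assumed behaviour: the FunctionType -> Lambda registration lets
# co.continual convert bare functions alongside torch.nn modules.
mod = co.continual(halve)
assert torch.equal(mod.forward(torch.ones(1, 1, 3)), 0.5 * torch.ones(1, 1, 3))
```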
From 970614fe7d839ae1fe35af28af8fd29528bd30b8 Mon Sep 17 00:00:00 2001
From: LukasHedegaard
Date: Thu, 26 Aug 2021 12:55:58 +0200
Subject: [PATCH 5/6] Update README and CHANGELOG with new modules

---
 CHANGELOG.md | 2 ++
 README.md    | 2 ++
 2 files changed, 4 insertions(+)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 1c06da7..bff5d4d 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -12,6 +12,8 @@ From v1.0.0 and on, the project will adhere strictly to Semantic Versioning.
 - `co.Lambda` module.
 - `co.Add` module.
 - `co.Multiply` module.
+- `co.Unity` module.
+- `co.Conditional` module.
 
 
 ## [0.8.1]
diff --git a/README.md b/README.md
index bf5ab18..829e20c 100644
--- a/README.md
+++ b/README.md
@@ -201,12 +201,14 @@ Below is a list of the modules and utilities included in the library:
 - `co.Sequential` - Sequential wrapper for modules. This module automatically performs conversions of the torch.nn modules that are safe during continual inference. These include all batch normalisation and activation functions.
 - `co.Parallel` - Parallel wrapper for modules.
 - `co.Residual` - Residual wrapper for modules.
+- `co.Conditional` - Invokes a module at runtime only when a predicate holds.
 - `co.Delay` - Pure delay module (e.g. needed in residuals).
 
 - Functions
   - `co.Lambda` - Lambda module which wraps any function.
   - `co.Add` - Adds a constant value.
   - `co.Multiply` - Multiplies with a constant factor.
+  - `co.Unity` - Maps input to output without modification.
 
 - Converters
From 3266ac4d9caa36460259045025ab44f8c9784a8 Mon Sep 17 00:00:00 2001
From: LukasHedegaard
Date: Thu, 26 Aug 2021 12:56:38 +0200
Subject: [PATCH 6/6] Bump version to 0.9.0

---
 CHANGELOG.md | 3 +++
 setup.py     | 2 +-
 2 files changed, 4 insertions(+), 1 deletion(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index bff5d4d..107236f 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -8,6 +8,9 @@ From v1.0.0 and on, the project will adhere strictly to Semantic Versioning.
 
 ## [Unreleased]
 
+
+## [0.9.0]
 ### Added
 - `co.Lambda` module.
 - `co.Add` module.
diff --git a/setup.py b/setup.py
index 3fb6aae..eb8f451 100644
--- a/setup.py
+++ b/setup.py
@@ -25,7 +25,7 @@ def from_file(file_name: str = "requirements.txt", comment_char: str = "#"):
 
 setup(
     name="continual-inference",
-    version="0.8.1",
+    version="0.9.0",
     description="Building blocks for Continual Inference Networks in PyTorch",
     long_description=long_description(),
     long_description_content_type="text/markdown",