Merge pull request #21 from LukasHedegaard/develop
Add closure and conditional modules
LukasHedegaard authored Aug 26, 2021
2 parents 93692c9 + 3266ac4 commit 0200295
Showing 9 changed files with 276 additions and 2 deletions.
9 changes: 9 additions & 0 deletions CHANGELOG.md
@@ -10,6 +10,15 @@ From v1.0.0 and on, the project will adhere strictly to Semantic Versioning.
## [Unreleased]


## [0.9.0]
### Added
- `co.Lambda` module.
- `co.Add` module.
- `co.Multiply` module.
- `co.Unity` module.
- `co.Conditional` module.


## [0.8.1]
### Fixed
- Bug in `forward_stepping`.
7 changes: 7 additions & 0 deletions README.md
@@ -201,8 +201,15 @@ Below is a list of the modules and utilities included in the library:
- `co.Sequential` - Sequential wrapper for modules. This module automatically converts the `torch.nn` modules that are safe during continual inference; these include all batch normalisation modules and activation functions.
- `co.Parallel` - Parallel wrapper for modules.
- `co.Residual` - Residual wrapper for modules.
- `co.Conditional` - Conditionally invokes a module at runtime based on a predicate (see the usage sketch after this list).
- `co.Delay` - Pure delay module (e.g. needed in residuals).

- Functions
- `co.Lambda` - Lambda module which wraps any function.
- `co.Add` - Adds a constant value.
- `co.Multiply` - Multiplies with a constant factor.
- `co.Unity` - Maps input to output without modification.

- Converters
<!-- - `co.Residual` - residual connection, which automatically adds delay if needed -->
- `co.continual` - conversion function from `torch.nn` modules to `co` modules.
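A minimal usage sketch of the new wrappers, assuming the package is imported as `co` (as in the test suite); the pipeline below is illustrative only and not taken from the library's documentation:

```python
import torch
import continual as co

# Stateless functions and constants wrapped as continual modules
net = co.Sequential(
    co.Add(1),                # add a constant
    co.Multiply(0.5),         # scale by a constant factor
    co.Lambda(torch.relu),    # wrap any stateless function
    co.Conditional(           # only applied while the module is in train mode
        lambda module, x: module.training,
        co.Multiply(2),
    ),
)

x = torch.ones((1, 1, 4))                # (batch, channels, time)
y_clip = net.forward(x)                  # regular clip-wise inference
y_frame = net.forward_step(x[:, :, 0])   # frame-by-frame continual inference
```

Here it is assumed that `co.Sequential.forward_step` chains the children's `forward_step` calls; all of the wrappers above have zero delay.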
3 changes: 2 additions & 1 deletion continual/__init__.py
@@ -1,4 +1,5 @@
from .container import Parallel, Residual, Sequential # noqa: F401
from .closure import Add, Lambda, Multiply, Unity # noqa: F401
from .container import Conditional, Parallel, Residual, Sequential # noqa: F401
from .conv import Conv1d, Conv2d, Conv3d # noqa: F401
from .convert import continual, forward_stepping # noqa: F401
from .delay import Delay # noqa: F401
64 changes: 64 additions & 0 deletions continual/closure.py
@@ -0,0 +1,64 @@
from functools import partial
from typing import Callable, Union

from torch import Tensor, nn

from .module import CoModule


class Lambda(CoModule, nn.Module):
"""Module wrapper for stateless functions"""

def __init__(self, fn: Callable[[Tensor], Tensor]):
nn.Module.__init__(self)
        assert callable(fn), "The passed function should be callable."
self.fn = fn

def forward(self, input: Tensor) -> Tensor:
return self.fn(input)

def forward_step(self, input: Tensor, update_state=True) -> Tensor:
x = input.unsqueeze(dim=2)
x = self.fn(x)
x = x.squeeze(dim=2)
return x

def forward_steps(self, input: Tensor, pad_end=False, update_state=True) -> Tensor:
return self.fn(input)

@property
def delay(self) -> int:
return 0

@staticmethod
def build_from(fn: Callable[[Tensor], Tensor]) -> "Lambda":
return Lambda(fn)


def _multiply(x: Tensor, factor: Union[float, int, Tensor]):
return x * factor


def Multiply(factor) -> Lambda:
"""Create Lambda with multiplication function"""
fn = partial(_multiply, factor=factor)
return Lambda(fn)


def _add(x: Tensor, constant: Union[float, int, Tensor]):
return x + constant


def Add(constant) -> Lambda:
"""Create Lambda with addition function"""
fn = partial(_add, constant=constant)
return Lambda(fn)


def _unity(x: Tensor):
return x


def Unity() -> Lambda:
"""Create Lambda with addition function"""
return Lambda(_unity)
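A rough usage sketch of the closure modules (mirroring the tests added further down): `forward` and `forward_steps` apply the wrapped function to the full input, whereas `forward_step` receives a single frame without the temporal dimension, which `Lambda` temporarily re-inserts at dim 2 before calling the function. The `clamp` wrapper below is an illustrative example, not part of the library:

```python
import torch
from continual.closure import Add, Lambda, Multiply, Unity

x = torch.ones((1, 1, 2, 2))  # (batch, channels, time, feature)

assert torch.equal(Add(5).forward(x), x + 5)
assert torch.equal(Multiply(2).forward_steps(x), x * 2)
assert torch.equal(Unity()(x), x)

# forward_step takes one frame (no time dimension); Lambda unsqueezes
# dim 2, applies the wrapped function, and squeezes dim 2 again
clamp = Lambda(lambda t: t.clamp(min=0.0))
frame = x[:, :, 0]  # (batch, channels, feature)
assert torch.equal(clamp.forward_step(frame), frame)
```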
66 changes: 66 additions & 0 deletions continual/container.py
@@ -326,3 +326,69 @@ def Residual(
aggregation_fn=aggregation_fn,
auto_delay=False,
)


class Conditional(FlattenableStateDict, CoModule, nn.Module):
"""Module wrapper for conditional invocations at runtime"""

def __init__(
self,
predicate: Callable[[CoModule, Tensor], bool],
on_true: CoModule,
on_false: CoModule = None,
):
        assert callable(predicate), "The passed function should be callable."
assert isinstance(on_true, CoModule), "on_true should be a CoModule."
assert (
isinstance(on_false, CoModule) or on_false is None
), "on_false should be a CoModule or None."

nn.Module.__init__(self)

self.predicate = predicate

# Ensure modules have the same delay
self._delay = max(on_true.delay, getattr(on_false, "delay", 0))

self.add_module(
"0",
on_true
if on_true.delay == self._delay
else Sequential(Delay(self._delay - on_true.delay), on_true),
)

if on_false is not None:
self.add_module(
"1",
on_false
if on_false.delay == self._delay
else Sequential(Delay(self._delay - on_false.delay), on_false),
)

def forward(self, input: Tensor) -> Tensor:
if self.predicate(self, input):
return self._modules["0"].forward(input)
elif "1" in self._modules:
return self._modules["1"].forward(input)
else:
return input

def forward_step(self, input: Tensor, update_state=True) -> Tensor:
if self.predicate(self, input):
return self._modules["0"].forward_step(input)
elif "1" in self._modules:
return self._modules["1"].forward_step(input)
else:
return input

def forward_steps(self, input: Tensor, pad_end=False, update_state=True) -> Tensor:
if self.predicate(self, input):
return self._modules["0"].forward_steps(input)
elif "1" in self._modules:
return self._modules["1"].forward_steps(input)
else:
return input

@property
def delay(self) -> int:
return self._delay
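A short sketch of how `Conditional` behaves, closely following the new tests below; the predicate receives the module itself and the current input:

```python
import torch
import continual as co

def is_training(module, x):
    return module.training

# One branch: the input is passed through unchanged when the predicate is False
mod = co.Conditional(is_training, co.Multiply(2))
x = torch.ones((1, 1, 3))
mod.train()
assert torch.equal(mod.forward(x), x * 2)
mod.eval()
assert torch.equal(mod.forward(x), x)

# Two branches: the faster branch is padded with a Delay so both branches
# (and the wrapper itself) share the larger delay
mod = co.Conditional(lambda m, x: True, co.Delay(2), co.Delay(3))
assert mod.delay == 3
```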
5 changes: 5 additions & 0 deletions continual/convert.py
@@ -1,10 +1,12 @@
""" Register modules with conversion system and 3rd-party libraries """

from functools import wraps
from types import FunctionType
from typing import Callable, Type

from torch import Tensor, nn

from .closure import Lambda
from .container import Sequential
from .conv import Conv1d, Conv2d, Conv3d
from .logging import getLogger
@@ -181,6 +183,9 @@ def continual(module: nn.Module) -> CoModule:
# Container
register(nn.Sequential, Sequential)

# Closure
register(FunctionType, Lambda)


# Register modules in `ptflops`
try:
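The `register(FunctionType, Lambda)` call suggests that plain Python functions can now be converted into `co.Lambda` modules through the conversion system. A tentative sketch under that assumption (the exact conversion entry point for functions is inferred here, not shown in this diff):

```python
import torch
import continual as co

def double(x):
    return x * 2

# Assumption: the FunctionType registration lets the conversion system
# wrap a plain function as a co.Lambda (via Lambda.build_from)
wrapped = co.continual(double)

x = torch.ones((1, 1, 3))
assert torch.equal(wrapped.forward(x), x * 2)
```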
2 changes: 1 addition & 1 deletion setup.py
@@ -25,7 +25,7 @@ def from_file(file_name: str = "requirements.txt", comment_char: str = "#"):

setup(
name="continual-inference",
version="0.8.1",
version="0.9.0",
description="Building blocks for Continual Inference Networks in PyTorch",
long_description=long_description(),
long_description_content_type="text/markdown",
70 changes: 70 additions & 0 deletions tests/continual/test_closure.py
@@ -0,0 +1,70 @@
import torch

from continual.closure import Add, Lambda, Multiply, Unity


def test_add():
x = torch.ones((1, 1, 2, 2))

# int
add_5 = Add(5)
assert torch.equal(add_5.forward(x), x + 5)
assert add_5.delay == 0

# float
add_pi = Add(3.14)
assert torch.equal(add_pi.forward_steps(x), x + 3.14)

# Tensor
constants = torch.tensor([[[[2.0, 3.0]]]])
add_constants = Add(constants)
assert torch.equal(
add_constants.forward_step(x[:, :, 0]), torch.tensor([[[3.0, 4.0]]])
)


def test_multiply():
x = torch.ones((1, 1, 2, 2))

# int
mul_5 = Multiply(5)
assert torch.equal(mul_5.forward(x), x * 5)
assert mul_5.delay == 0

# float
mul_pi = Multiply(3.14)
assert torch.equal(mul_pi.forward_steps(x), x * 3.14)

# Tensor
constants = torch.tensor([[[[2.0, 3.0]]]])
mul_constants = Multiply(constants)
assert torch.equal(
mul_constants.forward_step(x[:, :, 0]), torch.tensor([[[2.0, 3.0]]])
)


def global_always42(x):
return torch.ones_like(x) * 42


def test_lambda():
x = torch.ones((1, 1, 2, 2))
target = torch.ones_like(x) * 42

def local_always42(x):
return torch.ones_like(x) * 42

# Test if layer works in different scopes
# Global
assert torch.equal(target, Lambda(global_always42)(x))

# Local
assert torch.equal(target, Lambda(local_always42)(x))

# Anonymous
assert torch.equal(target, Lambda.build_from(lambda x: torch.ones_like(x) * 42)(x))


def test_unity():
x = torch.ones((1, 1, 2, 2))
assert torch.equal(x, Unity()(x))
52 changes: 52 additions & 0 deletions tests/continual/test_container.py
@@ -272,3 +272,55 @@ def test_flat_state_dict():
assert torch.equal(nested[1].c1.bias, nested_new[1].c1.bias)

assert True # Need to step down here to trigger context manager __exit__


def test_conditional_only_first():
x = torch.ones((1, 1, 3))

def is_training(module, *args):
return module.training

mod = co.Conditional(is_training, co.Multiply(2))

mod.train()
assert torch.equal(mod.forward(x), x * 2)
assert torch.equal(mod.forward_steps(x), x * 2)
assert torch.equal(mod.forward_step(x[:, :, 0]), x[:, :, 0] * 2)

mod.eval()
assert torch.equal(mod.forward(x), x)
assert torch.equal(mod.forward_steps(x), x)
assert torch.equal(mod.forward_step(x[:, :, 0]), x[:, :, 0])


def test_conditional_both_cases():
x = torch.ones((1, 1, 3))

def is_training(module, *args):
return module.training

mod = co.Conditional(is_training, co.Multiply(2), co.Multiply(3))

mod.train()
assert torch.equal(mod.forward(x), x * 2)
assert torch.equal(mod.forward_steps(x), x * 2)
assert torch.equal(mod.forward_step(x[:, :, 0]), x[:, :, 0] * 2)

mod.eval()
assert torch.equal(mod.forward(x), x * 3)
assert torch.equal(mod.forward_steps(x), x * 3)
assert torch.equal(mod.forward_step(x[:, :, 0]), x[:, :, 0] * 3)


def test_conditional_delay():
# if_true.delay < if_false.delay
mod = co.Conditional(lambda a, b: True, co.Delay(2), co.Delay(3))
assert mod.delay == 3
assert mod._modules["0"].delay == 3
assert mod._modules["1"].delay == 3

# if_true.delay > if_false.delay
mod = co.Conditional(lambda a, b: True, co.Delay(3), co.Delay(2))
assert mod.delay == 3
assert mod._modules["0"].delay == 3
assert mod._modules["1"].delay == 3
