-
Notifications
You must be signed in to change notification settings - Fork 10
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #21 from LukasHedegaard/develop
Add closure and conditional modules
- Loading branch information
Showing
9 changed files
with
276 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,64 @@ | ||
from functools import partial | ||
from typing import Callable, Union | ||
|
||
from torch import Tensor, nn | ||
|
||
from .module import CoModule | ||
|
||
|
||
class Lambda(CoModule, nn.Module):
    """Module wrapper for stateless functions.

    The wrapped function is applied directly in ``forward`` and
    ``forward_steps``. In ``forward_step`` a singleton temporal dimension
    (dim 2) is inserted before and removed after the call, so functions
    written for clip-shaped input also work on single steps.
    """

    def __init__(self, fn: Callable[[Tensor], Tensor]):
        """Wrap ``fn`` as a continual module.

        Args:
            fn: stateless function mapping a Tensor to a Tensor.
        """
        nn.Module.__init__(self)
        # Fixed typo in the assertion message ("pased" -> "passed").
        assert callable(fn), "The passed function should be callable."
        self.fn = fn

    def forward(self, input: Tensor) -> Tensor:
        """Apply the wrapped function to a full (clip) input."""
        return self.fn(input)

    def forward_step(self, input: Tensor, update_state=True) -> Tensor:
        """Apply the wrapped function to a single step.

        A singleton temporal dim is inserted at dim 2 so ``fn`` sees the
        same rank as in ``forward``. ``update_state`` is accepted for
        interface compatibility but unused: the module holds no state.
        """
        x = input.unsqueeze(dim=2)
        x = self.fn(x)
        x = x.squeeze(dim=2)
        return x

    def forward_steps(self, input: Tensor, pad_end=False, update_state=True) -> Tensor:
        """Apply the wrapped function to multiple steps.

        ``pad_end`` and ``update_state`` have no effect (stateless).
        """
        return self.fn(input)

    @property
    def delay(self) -> int:
        # Stateless functions introduce no temporal delay.
        return 0

    @staticmethod
    def build_from(fn: Callable[[Tensor], Tensor]) -> "Lambda":
        """Alternate constructor: create a :class:`Lambda` from ``fn``."""
        return Lambda(fn)
|
||
|
||
def _multiply(x: Tensor, factor: Union[float, int, Tensor]): | ||
return x * factor | ||
|
||
|
||
def Multiply(factor) -> Lambda:
    """Create a :class:`Lambda` that multiplies its input by ``factor``."""
    return Lambda(partial(_multiply, factor=factor))
|
||
|
||
def _add(x: Tensor, constant: Union[float, int, Tensor]): | ||
return x + constant | ||
|
||
|
||
def Add(constant) -> Lambda:
    """Create a :class:`Lambda` that adds ``constant`` to its input."""
    return Lambda(partial(_add, constant=constant))
|
||
|
||
def _unity(x: Tensor):
    """Identity function: return the input unchanged."""
    return x
|
||
|
||
def Unity() -> Lambda:
    """Create a :class:`Lambda` wrapping the identity function.

    Note: the previous docstring said "addition function" — a copy-paste
    error from :func:`Add`; this factory performs no arithmetic.
    """
    return Lambda(_unity)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,70 @@ | ||
import torch | ||
|
||
from continual.closure import Add, Lambda, Multiply, Unity | ||
|
||
|
||
def test_add():
    """Add should handle int, float, and (broadcast) Tensor constants."""
    ones = torch.ones((1, 1, 2, 2))

    # Integer constant
    add_five = Add(5)
    assert torch.equal(add_five.forward(ones), ones + 5)
    assert add_five.delay == 0

    # Float constant
    add_float = Add(3.14)
    assert torch.equal(add_float.forward_steps(ones), ones + 3.14)

    # Tensor constant, broadcast over a single step
    tensor_constant = torch.tensor([[[[2.0, 3.0]]]])
    add_tensor = Add(tensor_constant)
    expected = torch.tensor([[[3.0, 4.0]]])
    assert torch.equal(add_tensor.forward_step(ones[:, :, 0]), expected)
|
||
|
||
def test_multiply():
    """Multiply should handle int, float, and (broadcast) Tensor factors."""
    ones = torch.ones((1, 1, 2, 2))

    # Integer factor
    mul_five = Multiply(5)
    assert torch.equal(mul_five.forward(ones), ones * 5)
    assert mul_five.delay == 0

    # Float factor
    mul_float = Multiply(3.14)
    assert torch.equal(mul_float.forward_steps(ones), ones * 3.14)

    # Tensor factor, broadcast over a single step
    tensor_factor = torch.tensor([[[[2.0, 3.0]]]])
    mul_tensor = Multiply(tensor_factor)
    expected = torch.tensor([[[2.0, 3.0]]])
    assert torch.equal(mul_tensor.forward_step(ones[:, :, 0]), expected)
|
||
|
||
def global_always42(x):
    """Module-level helper: return a tensor of 42s shaped like ``x``."""
    return 42 * torch.ones_like(x)
|
||
|
||
def test_lambda():
    """Lambda should accept functions from global, local, and anonymous scope."""
    x = torch.ones((1, 1, 2, 2))
    target = 42 * torch.ones_like(x)

    def local_always42(x):
        return torch.ones_like(x) * 42

    # Global-scope function
    assert torch.equal(Lambda(global_always42)(x), target)

    # Local-scope function
    assert torch.equal(Lambda(local_always42)(x), target)

    # Anonymous function via the static factory
    anonymous = Lambda.build_from(lambda t: torch.ones_like(t) * 42)
    assert torch.equal(anonymous(x), target)
|
||
|
||
def test_unity():
    """Unity should pass its input through unchanged."""
    inp = torch.ones((1, 1, 2, 2))
    assert torch.equal(Unity()(inp), inp)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters