Merge pull request #18 from LukasHedegaard/develop
Update forward_steps interface with pad_end, bug fixes, and implementation restructure
LukasHedegaard authored Aug 24, 2021
2 parents 31d47e4 + b113e8b commit e6739cf
Showing 17 changed files with 349 additions and 269 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/pythonpackage.yml
@@ -8,7 +8,7 @@ jobs:
strategy:
matrix:
os: [macos-latest, ubuntu-latest, windows-latest]
python-version: [3.6, 3.8]
python-version: [3.6, 3.9]

steps:
- uses: actions/checkout@v2
16 changes: 16 additions & 0 deletions CHANGELOG.md
@@ -4,8 +4,23 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).


## [Unreleased]


## [0.8.0]
### Fixed
- Bugs in `forward_step(s)` with `update_state=False`

### Changed
- `forward_steps` interface to always include `pad_end` argument.
- name of "interface.py" to "module.py".
- implementations of `forward_step(s)` to be consolidated in CoModule.

### Removed
- `Padded` interface


## [0.7.0]
### Added
- Independent state_dict and load_state_dict functions.
@@ -20,6 +35,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Continual interface and conversion to support both class and module.
- Replicate padding in `co._ConvNd`


## [0.6.1]
### Changed
- `co.Residual` modules to be unnamed. This allows the module state dicts to be flattened.
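To make the `forward_steps`/`pad_end` change concrete, the sketch below shows the intended call pattern on a single continual layer. The layer sizes, shapes, and the exact equivalence claim are illustrative assumptions, not taken from this diff.

```python
import torch
import continual as co

# Illustrative layer and clip; all sizes here are assumptions.
conv = co.Conv3d(in_channels=3, out_channels=8, kernel_size=(5, 3, 3), padding=(2, 1, 1))
clip = torch.randn(1, 3, 16, 32, 32)  # (batch, channels, time, height, width)

full = conv.forward(clip)  # regular clip-wise computation, as in torch.nn.Conv3d

# As of 0.8.0, forward_steps always exposes pad_end. With pad_end=True the
# trailing temporal padding is flushed as well, so the step-wise output is
# intended to line up with the clip-wise result above.
stepped = conv.forward_steps(clip, pad_end=True)
```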
2 changes: 1 addition & 1 deletion continual/__init__.py
@@ -2,7 +2,7 @@
from .conv import Conv1d, Conv2d, Conv3d # noqa: F401
from .convert import continual, forward_stepping # noqa: F401
from .delay import Delay # noqa: F401
from .interface import CoModule, PaddingMode, TensorPlaceholder # noqa: F401
from .module import CoModule, PaddingMode, TensorPlaceholder # noqa: F401
from .pooling import ( # noqa: F401
AdaptiveAvgPool2d,
AdaptiveAvgPool3d,
30 changes: 11 additions & 19 deletions continual/container.py
@@ -7,7 +7,7 @@
from torch import Tensor, nn

from .delay import Delay
from .interface import CoModule, Padded, PaddingMode, TensorPlaceholder
from .module import CoModule, PaddingMode, TensorPlaceholder
from .utils import load_state_dict, state_dict

__all__ = ["Sequential", "Parallel", "Residual"]
@@ -39,7 +39,7 @@ def load_state_dict(
return load_state_dict(self, state_dict, strict, flatten)


class Sequential(FlattenableStateDict, nn.Sequential, Padded, CoModule):
class Sequential(FlattenableStateDict, nn.Sequential, CoModule):
"""Continual Sequential module
This module is a drop-in replacement for `torch.nn.Sequential`
@@ -65,17 +65,14 @@ def forward_step(self, input, update_state=True):
for module in self:
input = module.forward_step(input, update_state=update_state)
if not isinstance(input, Tensor):
return TensorPlaceholder() # We can't infer output shape
return TensorPlaceholder()
return input

def forward_steps(self, input: Tensor, pad_end=False, update_state=True):
for module in self:
if isinstance(module, Padded):
input = module.forward_steps(
input, pad_end=pad_end, update_state=update_state
)
else:
input = module.forward_steps(input, update_state=update_state)
if not isinstance(input, Tensor) or len(input) == 0:
return TensorPlaceholder() # pragma: no cover
input = module.forward_steps(input, pad_end, update_state)

return input

@@ -159,13 +156,13 @@ def nonempty(fn: AggregationFunc) -> AggregationFunc:
@wraps(fn)
def wrapped(inputs: Sequence[Tensor]) -> Tensor:
if any(len(inp) == 0 for inp in inputs):
return TensorPlaceholder(inputs[0].shape)
return TensorPlaceholder(inputs[0].shape) # pragma: no cover
return fn(inputs)

return wrapped


class Parallel(FlattenableStateDict, nn.Sequential, Padded, CoModule):
class Parallel(FlattenableStateDict, nn.Sequential, CoModule):
"""Continual parallel container.
Args:
@@ -262,7 +259,7 @@ def forward_step(self, input: Tensor, update_state=True) -> Tensor:
# Try to infer shape
shape = tuple()
for o in outs:
if isinstance(o, Tensor):
if isinstance(o, Tensor): # pragma: no cover
shape = o.shape
break
return TensorPlaceholder(shape)
@@ -271,12 +268,7 @@ def forward_step(self, input: Tensor, update_state=True) -> Tensor:
def forward_steps(self, input: Tensor, pad_end=False, update_state=True) -> Tensor:
outs = []
for m in self:
if isinstance(m, Padded):
outs.append(
m.forward_steps(input, pad_end=pad_end, update_state=update_state)
)
else:
outs.append(m.forward_steps(input, update_state=update_state))
outs.append(m.forward_steps(input, pad_end, update_state))

return self.aggregation_fn(outs)

@@ -303,7 +295,7 @@ def delay(self) -> int:

@property
def stride(self) -> int:
return getattr(next(iter(self)), "stride", 1)
return int_from(getattr(next(iter(self)), "stride", 1))

@property
def padding(self) -> int:
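Since `Sequential` now forwards `pad_end` and `update_state` to every child uniformly, a frame-by-frame loop only needs to guard against the placeholder outputs emitted while the temporal buffers fill. A minimal sketch with illustrative layer sizes (not taken from this diff):

```python
import torch
import continual as co

# Illustrative two-layer continual network; sizes are assumptions.
net = co.Sequential(
    co.Conv3d(3, 8, kernel_size=(5, 3, 3), padding=(2, 1, 1)),
    co.Conv3d(8, 4, kernel_size=(3, 3, 3), padding=(1, 1, 1)),
)

clip = torch.randn(1, 3, 16, 32, 32)
outputs = []
for t in range(clip.shape[2]):
    out = net.forward_step(clip[:, :, t])   # one frame at a time
    if isinstance(out, torch.Tensor):       # TensorPlaceholder while buffers fill
        outputs.append(out)
```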
50 changes: 10 additions & 40 deletions continual/conv.py
@@ -14,8 +14,8 @@
_triple,
)

from .interface import CoModule, Padded, PaddingMode, TensorPlaceholder
from .logging import getLogger
from .module import CoModule, PaddingMode, TensorPlaceholder

logger = getLogger(__name__)

@@ -29,7 +29,7 @@
]


class _ConvCoNd(_ConvNd, Padded, CoModule):
class _ConvCoNd(_ConvNd, CoModule):
def __init__(
self,
ConvClass: torch.nn.Module,
@@ -136,10 +136,13 @@ def get_state(self):
):
return (self.state_buffer, self.state_index, self.stride_index)

def set_state(self, state: State):
self.state_buffer, self.state_index, self.stride_index = state

def _forward_step(self, input: Tensor, prev_state: State) -> Tuple[Tensor, State]:
assert (
len(input.shape) == self._input_len - 1
), f"A tensor of shape {(*self.input_shape_desciption[:2], *self.input_shape_desciption[3:])} should be passed as input."
), f"A tensor of shape {(*self.input_shape_desciption[:2], *self.input_shape_desciption[3:])} should be passed as input but got {input.shape}"

# e.g. B, C -> B, C, 1
x = input.unsqueeze(2)
@@ -198,45 +201,12 @@ def _forward_step(self, input: Tensor, prev_state: State) -> Tuple[Tensor, State

return x_out, (next_buffer, next_index, next_stride_index)

def forward_step(self, input: Tensor, update_state=True) -> Tensor:
output, (state_buffer, state_index, stride_index) = self._forward_step(
input, self.get_state()
)
if update_state:
self.state_buffer = state_buffer
self.state_index = state_index
self.stride_index = stride_index
return output

def forward_steps(self, input: Tensor, pad_end=False, update_state=True) -> Tensor:
assert (
len(input.shape) == self._input_len
), f"A tensor of shape {self.input_shape_desciption} should be passed as input."

outs = []

for t in range(input.shape[2]):
o = self.forward_step(input[:, :, t], update_state=update_state)
if isinstance(o, Tensor):
outs.append(o)

if pad_end:
# Don't save state for the end-padding
(tmp_buffer, tmp_index, tmp_stride_index) = self.get_state()
for t, i in enumerate(
[self.make_padding(input[:, :, -1]) for _ in range(self.padding[0])]
):
o, (tmp_buffer, tmp_index, tmp_stride_index) = self._forward_step(
i, (tmp_buffer, tmp_index, tmp_stride_index)
)
if isinstance(o, Tensor):
outs.append(o)

if len(outs) > 0:
outs = torch.stack(outs, dim=2)
else:
outs = torch.tensor([]) # pragma: no cover
return outs
), f"A tensor of shape {self.input_shape_desciption} should be passed as input but got {input.shape}."

return CoModule.forward_steps(self, input, pad_end, update_state)

def forward(self, input: Tensor) -> Tensor:
"""Performs a full forward computation exactly as the regular layer would.
@@ -250,7 +220,7 @@ def forward(self, input: Tensor) -> Tensor:
"""
assert (
len(input.shape) == self._input_len
), f"A tensor of shape {self.input_shape_desciption} should be passed as input."
), f"A tensor of shape {self.input_shape_desciption} should be passed as input but got {input.shape}."
output = self._ConvClass._conv_forward(self, input, self.weight, self.bias)

return output
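The per-class `forward_step(s)` loops deleted above now live on `CoModule` (per the changelog), but `module.py` itself is not part of this excerpt. The function below is therefore only a sketch of the shared pattern, reconstructed from the deleted conv/delay code rather than copied from the new implementation:

```python
import torch
from torch import Tensor

def forward_steps_sketch(module, input: Tensor, pad_end: bool = False, update_state: bool = True) -> Tensor:
    """Sketch of the consolidated step loop, reconstructed from the deleted code."""
    outs = []
    for t in range(input.shape[2]):  # step over the temporal dimension
        o = module.forward_step(input[:, :, t], update_state=update_state)
        if isinstance(o, Tensor):    # placeholders emitted while the buffer fills are skipped
            outs.append(o)

    if pad_end:
        # Flush the end-padding through a temporary state so it is never persisted.
        state = module.get_state()
        for _ in range(module.padding[0]):
            o, state = module._forward_step(module.make_padding(input[:, :, -1]), state)
            if isinstance(o, Tensor):
                outs.append(o)

    return torch.stack(outs, dim=2) if outs else torch.tensor([])
```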
18 changes: 9 additions & 9 deletions continual/convert.py
@@ -7,8 +7,8 @@

from .container import Sequential
from .conv import Conv1d, Conv2d, Conv3d
from .interface import CoModule
from .logging import getLogger
from .module import CoModule
from .pooling import (
AdaptiveAvgPool2d,
AdaptiveAvgPool3d,
@@ -41,32 +41,32 @@ def forward_stepping(module: nn.Module, dim: int = 2):
dim (int, optional): The dimension to unsqueeze during `forward_step`. Defaults to 2.
"""

def unsqueezed(func: Callable[[Tensor], Tensor]):
def forward_step(func: Callable[[Tensor], Tensor]):
@wraps(func)
def call(x: Tensor) -> Tensor:
def call(x: Tensor, update_state=True) -> Tensor:
x = x.unsqueeze(dim)
x = func(x)
x = x.squeeze(dim)
return x

return call

def with_dummy_args(func: Callable[[Tensor], Tensor]):
def forward_steps(func: Callable[[Tensor], Tensor]):
@wraps(func)
def call(x: Tensor, update_state=True) -> Tensor:
def call(x: Tensor, pad_end=False, update_state=True) -> Tensor:
x = func(x)
return x

return call

def dummy(self):
def clean_state(*args, **kwargs):
... # pragma: no cover

module.forward = module.forward
module.forward_steps = with_dummy_args(module.forward)
module.forward_step = with_dummy_args(unsqueezed(module.forward))
module.forward_steps = forward_steps(module.forward)
module.forward_step = forward_step(module.forward)
module.delay = 0
module.clean_state = dummy
module.clean_state = clean_state

return module

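After this change, the wrappers produced by `forward_stepping` accept the same `pad_end` and `update_state` arguments as native continual modules; for a stateless module they are simply ignored. A minimal sketch (the wrapped module is an arbitrary illustrative choice):

```python
import torch
from torch import nn
import continual as co

relu = co.forward_stepping(nn.ReLU())   # patches forward_step/forward_steps onto the module

frame = torch.randn(1, 3, 32, 32)       # a single frame, i.e. no time dimension
step_out = relu.forward_step(frame)     # unsqueezes dim 2, applies ReLU, squeezes back

clip = torch.randn(1, 3, 16, 32, 32)    # a full clip with a time dimension
clip_out = relu.forward_steps(clip, pad_end=True)  # pad_end is accepted but has no effect here
```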
56 changes: 17 additions & 39 deletions continual/delay.py
@@ -3,14 +3,15 @@
import torch
from torch import Tensor

from .interface import CoModule, Padded, PaddingMode, TensorPlaceholder
from .module import CoModule, PaddingMode, TensorPlaceholder
from .utils import temporary_parameter

State = Tuple[Tensor, int]

__all__ = ["Delay"]


class Delay(torch.nn.Module, Padded, CoModule):
class Delay(torch.nn.Module, CoModule):
"""Continual delay modules
This module only introduces a delay in the continual modes, i.e. on `forward_step` and `forward_steps`.
@@ -54,20 +55,9 @@ def get_state(self):
and self.state_buffer is not None
):
return (self.state_buffer, self.state_index)
else:
return None

def forward_step(self, input: Tensor, update_state=True) -> Tensor:
if self._delay == 0:
return input

output, (state_buffer, state_index) = self._forward_step(
input, self.get_state()
)
if update_state:
self.state_buffer = state_buffer
self.state_index = state_index
return output
def set_state(self, state: State):
self.state_buffer, self.state_index = state

def _forward_step(self, input: Tensor, prev_state: State) -> Tuple[Tensor, State]:
if prev_state is None:
@@ -89,32 +79,20 @@ def _forward_step(self, input: Tensor, prev_state: State) -> Tuple[Tensor, State

return output, (buffer, new_index)

def forward_step(self, input: Tensor, update_state=True) -> Tensor:
if self._delay == 0:
return input

return CoModule.forward_step(self, input, update_state)

def forward_steps(self, input: Tensor, pad_end=False, update_state=True) -> Tensor:
outs = []
for t in range(input.shape[2]):
o = self.forward_step(input[:, :, t], update_state=update_state)
if isinstance(o, Tensor):
outs.append(o)

if pad_end:
# Empty out delay values, but don't save state for the end-padding
(tmp_buffer, tmp_index) = self.get_state()
tmp_buffer = tmp_buffer.clone()
for t, i in enumerate(
[self.make_padding(input[:, :, -1]) for _ in range(self.delay)]
):
o, (tmp_buffer, tmp_index) = self._forward_step(
i, (tmp_buffer, tmp_index)
)
if isinstance(o, Tensor):
outs.append(o)

if len(outs) > 0:
outs = torch.stack(outs, dim=2)
else:
outs = torch.tensor([]) # pragma: no cover
if self._delay == 0:
return input

return outs
with temporary_parameter(self, "padding", (self.delay,)):
output = CoModule.forward_steps(self, input, pad_end, update_state)

return output

def forward(self, input: Tensor) -> Tensor:
# No delay during regular forward
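As the `forward` comment above indicates, `Delay` is transparent in clip-wise mode and only delays in the step-wise modes. A small sketch, assuming the delay length is the first constructor argument and that step outputs are placeholders until the buffer is filled:

```python
import torch
import continual as co

delay = co.Delay(2)                 # assumed signature: delay length as first argument
clip = torch.randn(1, 3, 6, 8, 8)   # (batch, channels, time, height, width)

# Clip-wise forward introduces no delay.
assert torch.equal(delay.forward(clip), clip)

# Step-wise, the first two calls are expected to return placeholders while the
# buffer fills; afterwards each call returns the frame fed two steps earlier.
outputs = []
for t in range(clip.shape[2]):
    out = delay.forward_step(clip[:, :, t])
    if isinstance(out, torch.Tensor):
        outputs.append(out)
```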