Skip to content

Commit

Permalink
Merge pull request #20 from LukasHedegaard/develop
Browse files Browse the repository at this point in the history
Bug-fixes in forward_stepping and clean_state
  • Loading branch information
LukasHedegaard authored Aug 26, 2021
2 parents e6739cf + d0fa732 commit 93692c9
Show file tree
Hide file tree
Showing 9 changed files with 43 additions and 17 deletions.
20 changes: 14 additions & 6 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,24 +1,32 @@
# Changelog
All notable changes to this project will be documented in this file.

The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/).

This project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html), with the exception that v0.X updates include backwards-incompatible API changes.
From v1.0.0 and on, the project will adherence strictly to Semantic Versioning.


## [Unreleased]


## [0.8.1]
### Fixed
- Bug in `forward_stepping`.
- Bug in `clean_state`.


## [0.8.0]
### Fixed
- Bugs in `forward_step(s)` with `update_state=False`
- Bugs in `forward_step(s)` with `update_state=False`.

### Changed
- `forward_steps` interface to always include `pad_end` argument.
- name of "interface.py" to "module.py".
- implementations of `forward_step(s)` to be consolidated in CoModule.
- Name of "interface.py" to "module.py".
- Implementations of `forward_step(s)` to be consolidated in CoModule.

### Removed
- `Padded` interface
- `Padded` interface.


## [0.7.0]
Expand Down
9 changes: 6 additions & 3 deletions continual/conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,9 +121,12 @@ def init_state(
return state_buffer, state_index, stride_index

def clean_state(self):
del self.state_buffer
del self.state_index
del self.stride_index
if hasattr(self, "state_buffer"):
del self.state_buffer
if hasattr(self, "state_index"):
del self.state_index
if hasattr(self, "stride_index"):
del self.stride_index

def get_state(self):
if (
Expand Down
11 changes: 9 additions & 2 deletions continual/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,14 +59,21 @@ def call(x: Tensor, pad_end=False, update_state=True) -> Tensor:

return call

def clean_state(*args, **kwargs):
def dummy(*args, **kwargs):
... # pragma: no cover

@staticmethod
def build_from(mod): # pragma: no cover
return module.__class__()

module.forward = module.forward
module.forward_steps = forward_steps(module.forward)
module.forward_step = forward_step(module.forward)
module.delay = 0
module.clean_state = clean_state
module.get_state = dummy
module.set_state = dummy
module.clean_state = dummy
module.build_from = build_from

return module

Expand Down
6 changes: 4 additions & 2 deletions continual/delay.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,10 @@ def init_state(
return state_buffer, state_index

def clean_state(self):
del self.state_buffer
del self.state_index
if hasattr(self, "state_buffer"):
del self.state_buffer
if hasattr(self, "state_index"):
del self.state_index

def get_state(self):
if (
Expand Down
9 changes: 6 additions & 3 deletions continual/pooling.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,9 +158,12 @@ def init_state(
return state_buffer, state_index, stride_index

def clean_state(self):
del self.state_buffer
del self.state_index
del self.stride_index
if hasattr(self, "state_buffer"):
del self.state_buffer
if hasattr(self, "state_index"):
del self.state_index
if hasattr(self, "stride_index"):
del self.stride_index

def get_state(self):
if (
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def from_file(file_name: str = "requirements.txt", comment_char: str = "#"):

setup(
name="continual-inference",
version="0.8.0",
version="0.8.1",
description="Building blocks for Continual Inference Networks in PyTorch",
long_description=long_description(),
long_description_content_type="text/markdown",
Expand Down
1 change: 1 addition & 0 deletions tests/continual/test_conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -454,6 +454,7 @@ def test_update_state_false():
padding_mode="zeros",
)
coconv = co.Conv3d.build_from(conv)
coconv.clean_state() # Nothing should happen

target = conv.forward(sample)

Expand Down
1 change: 1 addition & 0 deletions tests/continual/test_delay.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
def test_delay_3d():
sample = torch.normal(mean=torch.zeros(4 * 3 * 3)).reshape((1, 1, 4, 3, 3))
delay = Delay(delay=2, temporal_fill="zeros")
delay.clean_state() # Nothing should happen

ones = torch.ones_like(sample[:, :, 0])

Expand Down
1 change: 1 addition & 0 deletions tests/continual/test_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ def test_AvgPool1d_padded():
target = pool(sample)

co_pool = AvgPool1d.build_from(pool)
co_pool.clean_state() # Nothing should happen

# forward
output2 = co_pool.forward(sample)
Expand Down

0 comments on commit 93692c9

Please sign in to comment.