Commit

Merge pull request #60 from LukasHedegaard/develop
Develop
LukasHedegaard authored Jan 13, 2023
2 parents 8e71ffe + 6d39f3d commit 88fec55
Showing 5 changed files with 25 additions and 10 deletions.
10 changes: 10 additions & 0 deletions CHANGELOG.md
@@ -8,6 +8,16 @@ From v1.0.0 and on, the project will adhere strictly to Semantic Versioning.

## Unpublished


## [1.1.2] - 2023-01-13

### Added
- `query_index` argument to `SingleOutputTransformerEncoderLayer`.

### Fixed
- `Residual` centred residual and `Delay` `auto_shrink` forward_step.


## [1.1.1] - 2023-01-10

### Added
2 changes: 1 addition & 1 deletion continual/__about__.py
@@ -1,6 +1,6 @@
import time

__version__ = "1.1.1"
__version__ = "1.1.2"
__author__ = "Lukas Hedegaard"
__author_email__ = "[email protected]"
__license__ = "Apache-2.0"
9 changes: 5 additions & 4 deletions continual/delay.py
@@ -68,7 +68,11 @@ def init_state(
    ) -> State:
        padding = self._make_padding(first_output)
        state_buffer = torch.stack([padding for _ in range(self.delay)], dim=0)
-       state_index = torch.tensor(-self.delay)
+       state_index = torch.tensor(
+           -2 * self.delay
+           if self.auto_shrink and isinstance(self.auto_shrink, bool)
+           else -self.delay
+       )
        return state_buffer, state_index

    def clean_state(self):
@@ -113,15 +117,12 @@ def forward_step(self, input: Tensor, update_state=True) -> Tensor:
        return CoModule.forward_step(self, input, update_state)

    def forward_steps(self, input: Tensor, pad_end=False, update_state=True) -> Tensor:
-       first_run = self.get_state() is None
        if self._delay == 0:
            return input

        with temporary_parameter(self, "padding", (self.delay,)):
            output = CoModule.forward_steps(self, input, pad_end, update_state)

-       if first_run and self.auto_shrink in {True, "centered"}:
-           output = output[:, :, self.delay :]
        return output

    def forward(self, input: Tensor) -> Tensor:
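Taken together, the two hunks move the shrink handling out of forward_steps and into the initial state index, so a state primed step by step stays consistent with clip-wise processing. Below is a minimal sketch of the resulting stepwise behaviour; it assumes co.Delay exposes the auto_shrink flag referenced in the diff, and the shapes are purely illustrative.

import torch
import continual as co

# Assumed constructor: a Delay of two steps that shrinks its padding away.
delay = co.Delay(delay=2, auto_shrink=True)

x = torch.arange(6, dtype=torch.float).reshape(1, 1, 6)  # (batch, channel, time)

# While the internal state_index is negative, forward_step withholds output.
# Starting it at -2 * delay (rather than -delay) also accounts for the steps
# that auto_shrink trims, keeping forward_step aligned with forward_steps.
for t in range(x.shape[2]):
    out = delay.forward_step(x[:, :, t])
    if isinstance(out, torch.Tensor):  # non-Tensor placeholders are returned while priming
        print(t, out.flatten())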
8 changes: 5 additions & 3 deletions continual/transformer.py
@@ -164,6 +164,7 @@ def SingleOutputTransformerEncoderLayer(
    dtype=None,
    sequence_len: int = None,
    single_output_forward=False,
+   query_index: int = -1,
):
    """Continual Single-output Transformer Encoder layer.
@@ -191,6 +192,7 @@ def SingleOutputTransformerEncoderLayer(
        dtype: datatype of layer parameters. Defaults to None.
        sequence_len: length of token-sequence to perform attention across. Defaults to None.
        single_output_forward: whether to restrict the attention to the last token during forward. Defaults to False.
+       query_index: the sequence position index to compute the attention for. Defaults to -1.
    Examples::
@@ -225,7 +227,7 @@ def SingleOutputTransformerEncoderLayer(
        bias=True,
        batch_first=True,
        embed_dim_second=True,
-       query_index=-1,
+       query_index=query_index,
        device=device,
        dtype=dtype,
        sequence_len=sequence_len,
@@ -462,7 +464,7 @@ def TransformerEncoderLayerFactory(
    Examples::
-       encoder_layer = co.TransformerEncoderLayerFactory(d_model=512, nhead=8)
+       encoder_layer = co.TransformerEncoderLayerFactory(d_model=512, nhead=8, sequence_len=32)
        transformer_encoder = co.TransformerEncoder(encoder_layer, num_layers=2)
        src = torch.rand(10, 512, 32)
        out = transformer_encoder(src)
@@ -527,7 +529,7 @@ class TransformerEncoder(Sequential):
    Examples::
-       encoder_layer = co.TransformerEncoderLayerFactory(d_model=512, nhead=8)
+       encoder_layer = co.TransformerEncoderLayerFactory(d_model=512, nhead=8, sequence_len=32)
        transformer_encoder = co.TransformerEncoder(encoder_layer, num_layers=2)
        src = torch.rand(10, 512, 32)
        out = transformer_encoder(src)
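With the pass-through above, callers can choose which sequence position the single-output layer queries instead of the previously hard-coded last token. A short sketch of the new argument, assuming the factory signature shown in the diff; d_model, nhead, and the input shape mirror the docstring examples:

import torch
import continual as co

# query_index=-1 reproduces the old hard-coded behaviour (query the last token);
# e.g. query_index=16 would instead query the middle of a 32-token sequence.
layer = co.SingleOutputTransformerEncoderLayer(
    d_model=512,
    nhead=8,
    sequence_len=32,
    query_index=-1,
)

src = torch.rand(10, 512, 32)  # (batch, embedding, sequence), since embed_dim_second=True
out = layer.forward(src)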
6 changes: 4 additions & 2 deletions tests/continual/test_container.py
@@ -280,7 +280,8 @@ def test_residual_shrink_centered():

    # forward_steps
    co_res.clean_state()
-   out_firsts = co_res.forward_steps(input[:, :, :-1], pad_end=False)
+   _ = co_res.forward_step(input[:, :, 0])
+   out_firsts = co_res.forward_steps(input[:, :, 1:-1], pad_end=False)
    assert torch.allclose(out_firsts, target[:, :, :3])

    # forward_step
@@ -312,7 +313,8 @@ def test_residual_shrink_lagging():

    # forward_steps
    co_res.clean_state()
-   out_firsts = co_res.forward_steps(input[:, :, :-1], pad_end=False)
+   _ = co_res.forward_step(input[:, :, 0])
+   out_firsts = co_res.forward_steps(input[:, :, 1:-1], pad_end=False)
    assert torch.allclose(out_firsts, out_manual_res[:, :, :3])

    # forward_step
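Both tests now prime the module with a single forward_step before streaming the remaining frames, which exercises the corrected state index from delay.py. A condensed sketch of the pattern under test; the residual_shrink argument and constructor shape are assumptions inferred from the test names, not confirmed API:

import torch
import continual as co

# Assumed set-up: a centred residual around a conv that shrinks the sequence,
# so the residual branch must be delayed and trimmed to line up with it.
conv = co.Conv1d(2, 2, kernel_size=3, padding=0)
co_res = co.Residual(conv, residual_shrink="centered")

x = torch.rand(1, 2, 5)

# Prime with the first frame, then stream the rest, mirroring the tests above;
# the streamed outputs should match the corresponding slice of the clip output.
_ = co_res.forward_step(x[:, :, 0])
out_firsts = co_res.forward_steps(x[:, :, 1:-1], pad_end=False)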
