Develop #60

Merged · 3 commits · Jan 13, 2023
10 changes: 10 additions & 0 deletions CHANGELOG.md
@@ -8,6 +8,16 @@ From v1.0.0 and on, the project will adhere strictly to Semantic Versioning.

## Unpublished


## [1.1.2] - 2023-01-13

### Added
- `query_index` argument to `SingleOutputTransformerEncoderLayer`.

### Fixed
- `Residual` centered residual and `Delay` `auto_shrink` handling in `forward_step`.


## [1.1.1] - 2023-01-10

### Added
2 changes: 1 addition & 1 deletion continual/__about__.py
@@ -1,6 +1,6 @@
import time

__version__ = "1.1.1"
__version__ = "1.1.2"
__author__ = "Lukas Hedegaard"
__author_email__ = "[email protected]"
__license__ = "Apache-2.0"
9 changes: 5 additions & 4 deletions continual/delay.py
@@ -68,7 +68,11 @@ def init_state(
) -> State:
padding = self._make_padding(first_output)
state_buffer = torch.stack([padding for _ in range(self.delay)], dim=0)
state_index = torch.tensor(-self.delay)
state_index = torch.tensor(
-2 * self.delay
if self.auto_shrink and isinstance(self.auto_shrink, bool)
else -self.delay
)
return state_buffer, state_index

def clean_state(self):
@@ -113,15 +117,12 @@ def forward_step(self, input: Tensor, update_state=True) -> Tensor:
return CoModule.forward_step(self, input, update_state)

def forward_steps(self, input: Tensor, pad_end=False, update_state=True) -> Tensor:
first_run = self.get_state() is None
if self._delay == 0:
return input

with temporary_parameter(self, "padding", (self.delay,)):
output = CoModule.forward_steps(self, input, pad_end, update_state)

if first_run and self.auto_shrink in {True, "centered"}:
output = output[:, :, self.delay :]
return output

def forward(self, input: Tensor) -> Tensor:
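For context on the `delay.py` change: with `auto_shrink` set to the boolean `True`, `init_state` now starts the state index at `-2 * self.delay` instead of `-self.delay`, so the warm-up trimming previously done by the deleted `first_run` branch in `forward_steps` happens implicitly in the per-step state machine. A minimal usage sketch, assuming the public `co.Delay` API and its `auto_shrink` flag as named in the diff:

    import torch
    import continual as co

    delay = co.Delay(delay=2, auto_shrink=True)
    x = torch.randn(1, 4, 8)  # (batch, channels, time)

    # With auto_shrink=True the first `delay` step outputs are consumed
    # internally (state_index starts at -2 * delay), so forward_steps no
    # longer needs to slice off the warm-up frames itself.
    delay.clean_state()
    out = delay.forward_steps(x, pad_end=False)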
8 changes: 5 additions & 3 deletions continual/transformer.py
@@ -164,6 +164,7 @@ def SingleOutputTransformerEncoderLayer(
dtype=None,
sequence_len: int = None,
single_output_forward=False,
query_index: int = -1,
):
"""Continual Single-output Transformer Encoder layer.

@@ -191,6 +192,7 @@ def SingleOutputTransformerEncoderLayer(
dtype: datatype of layer parameters. Defaults to None.
sequence_len: length of token-sequence to perform attention across. Defaults to None.
single_output_forward: whether to restrict the attention to the last token during forward. Defaults to False.
query_index: sequence position for which the attention output is computed. Defaults to -1 (the last token).

Examples::

@@ -225,7 +227,7 @@ def SingleOutputTransformerEncoderLayer(
bias=True,
batch_first=True,
embed_dim_second=True,
query_index=-1,
query_index=query_index,
device=device,
dtype=dtype,
sequence_len=sequence_len,
@@ -462,7 +464,7 @@ def TransformerEncoderLayerFactory(

Examples::

encoder_layer = co.TransformerEncoderLayerFactory(d_model=512, nhead=8)
encoder_layer = co.TransformerEncoderLayerFactory(d_model=512, nhead=8, sequence_len=32)
transformer_encoder = co.TransformerEncoder(encoder_layer, num_layers=2)
src = torch.rand(10, 512, 32)
out = transformer_encoder(src)
@@ -527,7 +529,7 @@ class TransformerEncoder(Sequential):

Examples::

encoder_layer = co.TransformerEncoderLayerFactory(d_model=512, nhead=8)
encoder_layer = co.TransformerEncoderLayerFactory(d_model=512, nhead=8, sequence_len=32)
transformer_encoder = co.TransformerEncoder(encoder_layer, num_layers=2)
src = torch.rand(10, 512, 32)
out = transformer_encoder(src)
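A hedged usage sketch of the new `query_index` argument, mirroring the `Examples::` blocks above. The `d_model`/`nhead` parameters are assumed to match the factory signature shown in the diff; only `query_index`, `sequence_len`, and `single_output_forward` are confirmed by the changes:

    import torch
    import continual as co

    layer = co.SingleOutputTransformerEncoderLayer(
        d_model=512,
        nhead=8,
        sequence_len=32,
        query_index=0,  # attend for the first token instead of the default -1 (last)
    )
    src = torch.rand(10, 512, 32)  # (batch, embed_dim, sequence); embed_dim_second=True
    out = layer(src)

Before this PR, `query_index=-1` was hard-coded in the underlying attention call, so the single-output layer could only produce the attention output for the last sequence position.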
6 changes: 4 additions & 2 deletions tests/continual/test_container.py
@@ -280,7 +280,8 @@ def test_residual_shrink_centered():

# forward_steps
co_res.clean_state()
out_firsts = co_res.forward_steps(input[:, :, :-1], pad_end=False)
_ = co_res.forward_step(input[:, :, 0])
out_firsts = co_res.forward_steps(input[:, :, 1:-1], pad_end=False)
assert torch.allclose(out_firsts, target[:, :, :3])

# forward_step
@@ -312,7 +313,8 @@ def test_residual_shrink_lagging():

# forward_steps
co_res.clean_state()
out_firsts = co_res.forward_steps(input[:, :, :-1], pad_end=False)
_ = co_res.forward_step(input[:, :, 0])
out_firsts = co_res.forward_steps(input[:, :, 1:-1], pad_end=False)
assert torch.allclose(out_firsts, out_manual_res[:, :, :3])

# forward_step
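The test updates encode a calling pattern worth spelling out: after `clean_state()`, the first frame is now fed through `forward_step` to prime the delay buffer before `forward_steps` streams the remainder. A sketch of that pattern, with the wrapped module (`co.Conv1d` inside `co.Residual`) chosen purely for illustration:

    import torch
    import continual as co

    co_res = co.Residual(co.Conv1d(4, 4, kernel_size=3, padding=1))
    input = torch.randn(1, 4, 8)  # (batch, channels, time)

    co_res.clean_state()
    _ = co_res.forward_step(input[:, :, 0])  # prime internal state with frame 0
    # Stream the remaining frames; outputs lag by the residual's delay.
    out_firsts = co_res.forward_steps(input[:, :, 1:-1], pad_end=False)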