Commit 425192c

format

jgrss committed Dec 19, 2023
1 parent 81b4096
Showing 1 changed file with 3 additions and 5 deletions.
src/cultionet/models/ltae.py (8 changes: 3 additions & 5 deletions)
@@ -58,6 +58,7 @@ class MultiHeadAttention(nn.Module):
     def __init__(self, num_head: int, d_in: int, dropout: float = 0.1):
         super(MultiHeadAttention, self).__init__()
 
+        self.num_head = num_head
         d_k = d_in // num_head
         scale = 1.0 / d_k**0.5
 
@@ -73,7 +74,7 @@ def __init__(self, num_head: int, d_in: int, dropout: float = 0.1):
 
     def split(self, x: torch.Tensor) -> torch.Tensor:
         return einops.rearrange(
-            x, 'b t (num_head k) -> num_head b t k', num_head=num_head
+            x, 'b t (num_head k) -> num_head b t k', num_head=self.num_head
         )
 
     def forward(
@@ -129,10 +130,7 @@ def __init__(
         activation_type: str = "SiLU",
         final_activation: Callable = Softmax(dim=1),
     ):
-        """Lightweight Temporal Attention Encoder (L-TAE) for image time
-        series. Attention-based sequence encoding that maps a sequence of
-        images to a single feature map. A shared L-TAE is applied to all pixel
-        positions of the image sequence.
+        """Transformer Self-Attention.
 
         Args:
             in_channels (int): Number of channels of the inputs.
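For context on the first two changes: `num_head` was previously only a constructor argument, so the bare reference inside `split` would raise a `NameError` when the method was called; storing it as `self.num_head` fixes the scope bug. The third change only swaps the class docstring and has no behavioral effect. A minimal runnable sketch of the corrected pattern (the example tensor sizes are illustrative, not from the repository):

import einops
import torch
import torch.nn as nn


class MultiHeadAttention(nn.Module):
    def __init__(self, num_head: int, d_in: int, dropout: float = 0.1):
        super(MultiHeadAttention, self).__init__()

        # Store num_head on the instance so that methods other than
        # __init__ (here, split) can reference it.
        self.num_head = num_head

    def split(self, x: torch.Tensor) -> torch.Tensor:
        # (batch, time, num_head * k) -> (num_head, batch, time, k)
        return einops.rearrange(
            x, 'b t (num_head k) -> num_head b t k', num_head=self.num_head
        )


# Example: 2 heads over an 8-channel input gives k = 4 per head.
mha = MultiHeadAttention(num_head=2, d_in=8)
print(mha.split(torch.randn(3, 5, 8)).shape)  # torch.Size([2, 3, 5, 4])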
