
Commit

Fix up more docstrings
wesselb committed Aug 19, 2024
1 parent 42fec51 commit 47d3e7c
Showing 2 changed files with 80 additions and 25 deletions.
43 changes: 28 additions & 15 deletions aurora/model/patchembed.py
@@ -24,6 +24,18 @@ def __init__(
norm_layer: Optional[nn.Module] = None,
flatten: bool = True,
) -> None:
"""Initialise.
Args:
max_vars (int): Maximum number of variables to embed.
patch_size (int): Patch size.
embed_dim (int): Embedding dimensionality.
history_size (int, optional): Number of history dimensions. Defaults to `1`.
norm_layer (torch.nn.Module, optional): Normalisation layer to be applied at the very
end. Defaults to no normalisation layer.
flatten (bool, optional): At the end of the forward pass, flatten the two spatial dimensions
into a single dimension. Defaults to `True`. See :meth:`LevelPatchEmbed.forward` for more
details.
"""
super().__init__()

self.max_vars = max_vars
@@ -32,7 +44,8 @@ def __init__(
self.embed_dim = embed_dim

weight = torch.cat(
# (C_out, C_in, kT, kH, kW)
# Shape (C_out, C_in, T, H, W). `C_in = 1` here because we're embedding every variable
# separately.
[torch.empty(embed_dim, 1, *self.kernel_size) for _ in range(max_vars)],
dim=1,
)
@@ -61,33 +74,33 @@ def reset_parameters(self) -> None:
bound = 1 / math.sqrt(fan_in)
nn.init.uniform_(self.bias, -bound, bound)

def forward(self, x: torch.Tensor, vars: list[int]) -> torch.Tensor:
def forward(self, x: torch.Tensor, var_ids: list[int]) -> torch.Tensor:
"""Run the embedding.
Args:
x (:class:`torch.Tensor`): Tensor to embed of a shape of `[B, V, T, H, W]`.
vars (list[int]): A list of variable IDs. The length should be equal to `V`.
x (:class:`torch.Tensor`): Tensor to embed of a shape of `(B, V, T, H, W)`.
var_ids (list[int]): A list of variable IDs. The length should be equal to `V`.
Returns:
:class:`torch.Tensor`: Embedded tensor a shape of `[B, L, D]` if flattened,
where `L = H * W / P**2`. Otherwise, the shape is `[B, D, H', W']`.
:class:`torch.Tensor`: Embedded tensor of a shape of `(B, L, D)` if flattened,
where `L = H * W / P^2`. Otherwise, the shape is `(B, D, H', W')`.
"""
B, V, T, H, W = x.shape
assert len(vars) == V, f"{V} != {len(vars)}"
assert self.kernel_size[0] >= T, f"{T} > {self.kernel_size[0]}"
assert H % self.kernel_size[1] == 0, f"{H} % {self.kernel_size[0]} != 0"
assert W % self.kernel_size[2] == 0, f"{W} % {self.kernel_size[1]} != 0"
assert max(vars) < self.max_vars, f"{max(vars)} >= {self.max_vars}"
assert min(vars) >= 0, f"{min(vars)} < 0"
assert len(set(vars)) == len(vars), f"{vars} contains duplicates"
assert len(var_ids) == V, f"{V} != {len(var_ids)}."
assert self.kernel_size[0] >= T, f"{T} > {self.kernel_size[0]}."
assert H % self.kernel_size[1] == 0, f"{H} % {self.kernel_size[1]} != 0."
assert W % self.kernel_size[2] == 0, f"{W} % {self.kernel_size[2]} != 0."
assert max(var_ids) < self.max_vars, f"{max(var_ids)} >= {self.max_vars}."
assert min(var_ids) >= 0, f"{min(var_ids)} < 0."
assert len(set(var_ids)) == len(var_ids), f"{var_ids} contains duplicates."

# Select the weights of the variables and history dimensions that are present in the batch.
weight = self.weight[:, vars, :T, ...] # [C_out, C_in, kT, kH, kW]
weight = self.weight[:, var_ids, :T, ...] # (C_out, C_in, T, H, W)
# Adjust the stride if history is smaller than maximum.
stride = (T,) + self.kernel_size[1:]

# (B, V, T, H, W) -> (B, D, 1, H / P, W / P)
# The convolution maps (B, V, T, H, W) to (B, D, 1, H/P, W/P)
proj = F.conv3d(x, weight, self.bias, stride=stride)
if self.flatten:
proj = proj.reshape(B, self.embed_dim, -1) # (B, D, L)
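For orientation, here is a minimal usage sketch of `LevelPatchEmbed` based on the docstring above. The module path is inferred from the file location, and all sizes are illustrative rather than values taken from the commit:

import torch

from aurora.model.patchembed import LevelPatchEmbed  # module path inferred from the diff

# Hypothetical sizes: up to 4 embeddable variables, 2 of them present in this batch,
# patch size 4, embedding dimension 128, and a history of 2 time steps.
embed = LevelPatchEmbed(max_vars=4, patch_size=4, embed_dim=128, history_size=2)

x = torch.randn(1, 2, 2, 32, 64)  # (B, V, T, H, W)
out = embed(x, var_ids=[0, 3])    # one ID per variable in `x`, each < max_vars, no duplicates
print(out.shape)                  # flattened output (B, L, D) with L = 32 * 64 / 4**2 = 128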
62 changes: 52 additions & 10 deletions aurora/model/perceiver.py
@@ -61,9 +61,20 @@
import torch.nn.functional as F
from einops import rearrange

__all__ = ["MLP", "PerceiverResampler"]


class MLP(nn.Module):
def __init__(self, dim, hidden_features: int, dropout=0.0):
"""A simple one-hidden-layer MLP."""

def __init__(self, dim: int, hidden_features: int, dropout: float = 0.0) -> None:
"""Initialise.
Args:
dim (int): Input dimensionality.
hidden_features (int): Width of the hidden layer.
dropout (float, optional): Drop-out rate. Defaults to no drop-out.
"""
super().__init__()
self.net = nn.Sequential(
nn.Linear(dim, hidden_features),
@@ -72,20 +83,33 @@ def __init__(self, dim, hidden_features: int, dropout=0.0):
nn.Dropout(dropout),
)

def forward(self, x):
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Run the MLP."""
return self.net(x)


class PerceiverAttention(nn.Module):
"""Cross attention module from the Perceiver architecture."""

def __init__(
self, latent_dim: int, context_dim: int, head_dim: int = 64, num_heads: int = 8
self,
latent_dim: int,
context_dim: int,
head_dim: int = 64,
num_heads: int = 8,
) -> None:
"""Initialise.
Args:
latent_dim (int): Dimensionality of the latent features given as input.
context_dim (int): Dimensionality of the context features also given as input.
head_dim (int, optional): Attention head dimensionality. Defaults to `64`.
num_heads (int, optional): Number of heads. Defaults to `8`.
"""
super().__init__()
self.inner_dim = head_dim * num_heads
self.num_heads = num_heads
self.head_dim = head_dim
self.inner_dim = head_dim * num_heads

self.to_q = nn.Linear(latent_dim, self.inner_dim, bias=False)
self.to_kv = nn.Linear(context_dim, self.inner_dim * 2, bias=False)
@@ -105,8 +129,8 @@ def forward(self, latents: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
"""
h = self.num_heads

q = self.to_q(latents) # (B, L1, D2) -> (B, L1, D)
k, v = self.to_kv(x).chunk(2, dim=-1) # (B, L2, D1) -> (B, L2, D) x 2
q = self.to_q(latents) # (B, L1, D2) to (B, L1, D)
k, v = self.to_kv(x).chunk(2, dim=-1) # (B, L2, D1) to twice (B, L2, D)
q, k, v = map(lambda t: rearrange(t, "b l (h d) -> b h l d", h=h), (q, k, v))

out = F.scaled_dot_product_attention(q, k, v)
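To make the shape comments above concrete, here is a small standalone sketch of the multi-head cross-attention core; the class's learned projections are omitted, and all sizes are illustrative:

import torch
import torch.nn.functional as F
from einops import rearrange

B, L1, L2, h, d = 2, 16, 64, 8, 64  # hypothetical batch, latent length, context length, heads, head dim
q = torch.randn(B, L1, h * d)       # projected latents, (B, L1, D) with D = h * d
k = torch.randn(B, L2, h * d)       # projected context, (B, L2, D)
v = torch.randn(B, L2, h * d)       # projected context, (B, L2, D)

# Split the feature dimension into heads, as in the `rearrange` pattern above.
q, k, v = (rearrange(t, "b l (h d) -> b h l d", h=h) for t in (q, k, v))
out = F.scaled_dot_product_attention(q, k, v)  # (B, h, L1, d): every latent attends to the context
out = rearrange(out, "b h l d -> b l (h d)")   # merge the heads back into (B, L1, D)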
@@ -125,10 +149,26 @@ def __init__(
head_dim: int = 64,
num_heads: int = 16,
mlp_ratio: float = 4.0,
drop=0.0,
drop: float = 0.0,
residual_latent: bool = True,
ln_eps: float = 1e-5,
) -> None:
"""Initialise.
Args:
latent_dim (int): Dimensionality of the latent features given as input.
context_dim (int): Dimensionality of the context features also given as input.
depth (int, optional): Number of attention layers.
head_dim (int, optional): Attention head dimensionality. Defaults to `64`.
num_heads (int, optional): Number of heads. Defaults to `16`.
mlp_ratio (float, optional): Dimensionality of the hidden layer divided by that of the
input for all MLPs. Defaults to `4.0`.
drop (float, optional): Drop-out rate. Defaults to no drop-out.
residual_latent (bool, optional): Use residual attention w.r.t. the latent features.
Defaults to `True`.
ln_eps (float, optional): Epsilon in the layer normalisation layers. Defaults to
`1e-5`.
"""
super().__init__()

self.residual_latent = residual_latent
@@ -165,9 +205,11 @@ def forward(self, latents: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
# We use post-res-norm like in Swin v2 and most Transformer architectures these days.
# This empirically works better than the pre-norm used in the original Perceiver.
attn_out = ln1(attn(latents, x))
# HuggingFace suggests using non-residual attention in Perceiver might
# work better when the semantics of the query and the output are different.
# https://github.com/huggingface/transformers/blob/v4.35.2/src/transformers/models/perceiver/modeling_perceiver.py#L398
# HuggingFace suggests that using non-residual attention in Perceiver might work better when
# the semantics of the query and the output are different:
#
# https://github.com/huggingface/transformers/blob/v4.35.2/src/transformers/models/perceiver/modeling_perceiver.py#L398
#
latents = attn_out + latents if self.residual_latent else attn_out
latents = ln2(ff(latents)) + latents
return latents
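
Putting the resampler together, a minimal usage sketch based on the docstring above; the module path is inferred from the file location and the sizes are illustrative:

import torch

from aurora.model.perceiver import PerceiverResampler  # module path inferred from the diff

# Hypothetical sizes: 8 latent tokens of width 256 resample 128 context tokens of width 512.
resampler = PerceiverResampler(latent_dim=256, context_dim=512, depth=2)
latents = torch.randn(2, 8, 256)  # (B, L1, latent_dim)
x = torch.randn(2, 128, 512)      # (B, L2, context_dim)
out = resampler(latents, x)       # the residual structure preserves the latent shape: (2, 8, 256)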
