From 47d3e7c5f77c2b2123052a0481da43eb59c5fe14 Mon Sep 17 00:00:00 2001
From: Wessel Bruinsma
Date: Mon, 19 Aug 2024 11:51:31 +0200
Subject: [PATCH] Fix up more docstrings

---
 aurora/model/patchembed.py | 43 +++++++++++++++++---------
 aurora/model/perceiver.py  | 62 ++++++++++++++++++++++++++++++++------
 2 files changed, 80 insertions(+), 25 deletions(-)

diff --git a/aurora/model/patchembed.py b/aurora/model/patchembed.py
index 328029b..2b19197 100644
--- a/aurora/model/patchembed.py
+++ b/aurora/model/patchembed.py
@@ -24,6 +24,18 @@ def __init__(
         norm_layer: Optional[nn.Module] = None,
         flatten: bool = True,
     ) -> None:
+        """Initialise.
+
+        Args:
+            max_vars (int): Maximum number of variables to embed.
+            patch_size (int): Patch size.
+            embed_dim (int): Embedding dimensionality.
+            history_size (int, optional): Number of history dimensions. Defaults to `1`.
+            norm_layer (torch.nn.Module, optional): Normalisation layer to be applied at the very
+                end. Defaults to no normalisation layer.
+            flatten (bool): At the end of the forward pass, flatten the two spatial dimensions
+                into a single dimension. See :meth:`LevelPatchEmbed.forward` for more details.
+        """
         super().__init__()
 
         self.max_vars = max_vars
@@ -32,7 +44,8 @@ def __init__(
         self.embed_dim = embed_dim
 
         weight = torch.cat(
-            # (C_out, C_in, kT, kH, kW)
+            # Shape (C_out, C_in, T, H, W). `C_in = 1` here because we're embedding every variable
+            # separately.
            [torch.empty(embed_dim, 1, *self.kernel_size) for _ in range(max_vars)],
             dim=1,
         )
@@ -61,33 +74,33 @@ def reset_parameters(self) -> None:
         bound = 1 / math.sqrt(fan_in)
         nn.init.uniform_(self.bias, -bound, bound)
 
-    def forward(self, x: torch.Tensor, vars: list[int]) -> torch.Tensor:
+    def forward(self, x: torch.Tensor, var_ids: list[int]) -> torch.Tensor:
         """Run the embedding.
 
         Args:
-            x (:class:`torch.Tensor`): Tensor to embed of a shape of `[B, V, T, H, W]`.
-            vars (list[int]): A list of variable IDs. The length should be equal to `V`.
+            x (:class:`torch.Tensor`): Tensor to embed of a shape of `(B, V, T, H, W)`.
+            var_ids (list[int]): A list of variable IDs. The length should be equal to `V`.
 
         Returns:
-            :class:`torch.Tensor`: Embedded tensor a shape of `[B, L, D]` if flattened,
-                where `L = H * W / P**2`. Otherwise, the shape is `[B, D, H', W']`.
+            :class:`torch.Tensor`: Embedded tensor of a shape of `(B, L, D)` if flattened,
+                where `L = H * W / P^2`. Otherwise, the shape is `(B, D, H', W')`.
         """
         B, V, T, H, W = x.shape
-        assert len(vars) == V, f"{V} != {len(vars)}"
-        assert self.kernel_size[0] >= T, f"{T} > {self.kernel_size[0]}"
-        assert H % self.kernel_size[1] == 0, f"{H} % {self.kernel_size[0]} != 0"
-        assert W % self.kernel_size[2] == 0, f"{W} % {self.kernel_size[1]} != 0"
-        assert max(vars) < self.max_vars, f"{max(vars)} >= {self.max_vars}"
-        assert min(vars) >= 0, f"{min(vars)} < 0"
-        assert len(set(vars)) == len(vars), f"{vars} contains duplicates"
+        assert len(var_ids) == V, f"{V} != {len(var_ids)}."
+        assert self.kernel_size[0] >= T, f"{T} > {self.kernel_size[0]}."
+        assert H % self.kernel_size[1] == 0, f"{H} % {self.kernel_size[1]} != 0."
+        assert W % self.kernel_size[2] == 0, f"{W} % {self.kernel_size[2]} != 0."
+        assert max(var_ids) < self.max_vars, f"{max(var_ids)} >= {self.max_vars}."
+        assert min(var_ids) >= 0, f"{min(var_ids)} < 0."
+        assert len(set(var_ids)) == len(var_ids), f"{var_ids} contains duplicates."
 
         # Select the weights of the variables and history dimensions that are present in the batch.
-        weight = self.weight[:, vars, :T, ...]  # [C_out, C_in, kT, kH, kW]
+        weight = self.weight[:, var_ids, :T, ...]  # (C_out, C_in, T, H, W)
 
         # Adjust the stride if history is smaller than maximum.
         stride = (T,) + self.kernel_size[1:]
 
-        # (B, V, T, H, W) -> (B, D, 1, H / P, W / P)
+        # The convolution maps (B, V, T, H, W) to (B, D, 1, H/P, W/P).
         proj = F.conv3d(x, weight, self.bias, stride=stride)
         if self.flatten:
             proj = proj.reshape(B, self.embed_dim, -1)  # (B, D, L)
diff --git a/aurora/model/perceiver.py b/aurora/model/perceiver.py
index bc8b49b..dbb4e92 100644
--- a/aurora/model/perceiver.py
+++ b/aurora/model/perceiver.py
@@ -61,9 +61,20 @@
 import torch.nn.functional as F
 from einops import rearrange
 
+__all__ = ["MLP", "PerceiverResampler"]
+
 
 class MLP(nn.Module):
-    def __init__(self, dim, hidden_features: int, dropout=0.0):
+    """A simple one-hidden-layer MLP."""
+
+    def __init__(self, dim: int, hidden_features: int, dropout: float = 0.0) -> None:
+        """Initialise.
+
+        Args:
+            dim (int): Input dimensionality.
+            hidden_features (int): Width of the hidden layer.
+            dropout (float, optional): Drop-out rate. Defaults to no drop-out.
+        """
         super().__init__()
         self.net = nn.Sequential(
             nn.Linear(dim, hidden_features),
@@ -72,7 +83,8 @@ def __init__(self, dim, hidden_features: int, dropout=0.0):
             nn.Dropout(dropout),
         )
 
-    def forward(self, x):
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """Run the MLP."""
         return self.net(x)
 
 
@@ -80,12 +92,24 @@ class PerceiverAttention(nn.Module):
     """Cross attention module from the Perceiver architecture."""
 
     def __init__(
-        self, latent_dim: int, context_dim: int, head_dim: int = 64, num_heads: int = 8
+        self,
+        latent_dim: int,
+        context_dim: int,
+        head_dim: int = 64,
+        num_heads: int = 8,
     ) -> None:
+        """Initialise.
+
+        Args:
+            latent_dim (int): Dimensionality of the latent features given as input.
+            context_dim (int): Dimensionality of the context features also given as input.
+            head_dim (int, optional): Attention head dimensionality. Defaults to `64`.
+            num_heads (int, optional): Number of heads. Defaults to `8`.
+        """
         super().__init__()
-        self.inner_dim = head_dim * num_heads
         self.num_heads = num_heads
         self.head_dim = head_dim
+        self.inner_dim = head_dim * num_heads
 
         self.to_q = nn.Linear(latent_dim, self.inner_dim, bias=False)
         self.to_kv = nn.Linear(context_dim, self.inner_dim * 2, bias=False)
@@ -105,8 +129,8 @@ def forward(self, latents: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
         """
         h = self.num_heads
 
-        q = self.to_q(latents)  # (B, L1, D2) -> (B, L1, D)
-        k, v = self.to_kv(x).chunk(2, dim=-1)  # (B, L2, D1) -> (B, L2, D) x 2
+        q = self.to_q(latents)  # (B, L1, D2) to (B, L1, D)
+        k, v = self.to_kv(x).chunk(2, dim=-1)  # (B, L2, D1) to twice (B, L2, D)
         q, k, v = map(lambda t: rearrange(t, "b l (h d) -> b h l d", h=h), (q, k, v))
 
         out = F.scaled_dot_product_attention(q, k, v)
@@ -125,10 +149,26 @@ def __init__(
         head_dim: int = 64,
         num_heads: int = 16,
         mlp_ratio: float = 4.0,
-        drop=0.0,
+        drop: float = 0.0,
         residual_latent: bool = True,
         ln_eps: float = 1e-5,
     ) -> None:
+        """Initialise.
+
+        Args:
+            latent_dim (int): Dimensionality of the latent features given as input.
+            context_dim (int): Dimensionality of the context features also given as input.
+            depth (int, optional): Number of attention layers.
+            head_dim (int, optional): Attention head dimensionality. Defaults to `64`.
+            num_heads (int, optional): Number of heads. Defaults to `16`.
+            mlp_ratio (float, optional): Dimensionality of the hidden layer divided by that of the
+                input for all MLPs. Defaults to `4.0`.
+            drop (float, optional): Drop-out rate. Defaults to no drop-out.
+            residual_latent (bool, optional): Use residual attention w.r.t. the latent features.
+                Defaults to `True`.
+            ln_eps (float, optional): Epsilon in the layer normalisation layers. Defaults to
+                `1e-5`.
+        """
         super().__init__()
 
         self.residual_latent = residual_latent
@@ -165,9 +205,11 @@ def forward(self, latents: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
             # We use post-res-norm like in Swin v2 and most Transformer architectures these days.
             # This empirically works better than the pre-norm used in the original Perceiver.
             attn_out = ln1(attn(latents, x))
-            # HuggingFace suggests using non-residual attention in Perceiver might
-            # work better when the semantics of the query and the output are different.
-            # https://github.com/huggingface/transformers/blob/v4.35.2/src/transformers/models/perceiver/modeling_perceiver.py#L398
+            # HuggingFace suggests using non-residual attention in Perceiver might work better when
+            # the semantics of the query and the output are different:
+            #
+            # https://github.com/huggingface/transformers/blob/v4.35.2/src/transformers/models/perceiver/modeling_perceiver.py#L398
+            #
             latents = attn_out + latents if self.residual_latent else attn_out
             latents = ln2(ff(latents)) + latents
         return latents
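For reference, a minimal usage sketch of the two modules whose docstrings are updated above. It assumes that `LevelPatchEmbed` and `PerceiverResampler` are importable from `aurora.model.patchembed` and `aurora.model.perceiver` (the files touched by this patch), and that shapes follow the conventions stated in the docstrings; the concrete sizes are arbitrary example values, not values used by the model.

    import torch

    from aurora.model.patchembed import LevelPatchEmbed
    from aurora.model.perceiver import PerceiverResampler

    # Example sizes: batch, variables, history, height, width (arbitrary values).
    B, V, T, H, W = 2, 3, 1, 32, 64
    patch_size, embed_dim = 4, 128

    # Embed three of at most five variables into patch tokens.
    embed = LevelPatchEmbed(max_vars=5, patch_size=patch_size, embed_dim=embed_dim)
    x = torch.randn(B, V, T, H, W)
    tokens = embed(x, var_ids=[0, 2, 4])  # (B, L, D) with L = (H // patch_size) * (W // patch_size)

    # Resample the patch tokens onto a small set of latent queries.
    resampler = PerceiverResampler(latent_dim=embed_dim, context_dim=embed_dim, depth=2)
    latents = torch.randn(B, 8, embed_dim)  # (B, L1, D) latent queries
    out = resampler(latents, tokens)  # (B, L1, D) resampled latents, per the docstring

    print(tokens.shape, out.shape)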