diff --git a/aurora/model/swin3d.py b/aurora/model/swin3d.py
index 0ec3eda..99bf9ec 100644
--- a/aurora/model/swin3d.py
+++ b/aurora/model/swin3d.py
@@ -67,38 +67,49 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
 
 
 class WindowAttention(nn.Module):
-    """Window based multi-head self attention (W-MSA) module with relative position bias.
+    """Window-based multi-head self-attention (W-MSA).
 
-    It supports both of shifted and non-shifted window.
-
-    Args:
-        dim (int): Number of input channels.
-        window_size (tuple[int]): The height and width of the window.
-        num_heads (int): Number of attention heads.
-        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Defaults to
-            `True`.
-        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
-        attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
-        proj_drop (float, optional): Dropout ratio of output. Default: 0.0
+    It supports both shifted and non-shifted windows.
     """
 
     def __init__(
         self,
-        dim,
-        window_size,
-        num_heads,
-        qkv_bias=True,
-        qk_scale=None,
-        attn_drop=0.0,
-        proj_drop=0.0,
-        lora_r=8,
-        lora_alpha=8,
-        lora_dropout=0.0,
-        lora_steps=40,
+        dim: int,
+        window_size: tuple[int, int, int],
+        num_heads: int,
+        qkv_bias: bool = True,
+        qk_scale: Optional[float] = None,
+        attn_drop: float = 0.0,
+        proj_drop: float = 0.0,
+        lora_r: int = 8,
+        lora_alpha: int = 8,
+        lora_dropout: float = 0.0,
+        lora_steps: int = 40,
         lora_mode: LoRAMode = "single",
         use_lora: bool = False,
-    ):
+    ) -> None:
+        """Initialise.
+
+        Args:
+            dim (int): Number of input channels.
+            window_size (tuple[int, int, int]): The size of the windows.
+            num_heads (int): Number of attention heads.
+            qkv_bias (bool, optional): If `True`, add a learnable bias to the query, key, and
+                value. Defaults to `True`.
+            qk_scale (float, optional): If set, overrides the default query-key scale of
+                `1/sqrt(head_dim)`.
+            attn_drop (float, optional): Drop-out rate of attention weights. Defaults to `0.0`.
+            proj_drop (float, optional): Drop-out rate of the output. Defaults to `0.0`.
+            lora_r (int, optional): LoRA rank. Defaults to `8`.
+            lora_alpha (int, optional): LoRA alpha. Defaults to `8`.
+            lora_dropout (float, optional): LoRA drop-out rate. Defaults to `0.0`.
+            lora_steps (int, optional): Maximum number of LoRA roll-out steps. Defaults to `40`.
+            lora_mode (str, optional): Mode. `"single"` uses the same LoRA for all roll-out steps,
+                and `"all"` uses a different LoRA for every roll-out step. Defaults to `"single"`.
+            use_lora (bool, optional): Enable LoRA. By default, LoRA is disabled.
+        """
         super().__init__()
+
         self.dim = dim
         self.window_size = window_size  # Wh, Ww
         self.num_heads = num_heads
@@ -122,14 +133,16 @@ def __init__(
         self.lora_qkv = lambda *args, **kwargs: 0  # type: ignore
 
     def forward(
-        self, x: torch.Tensor, mask: torch.Tensor | None = None, rollout_step: int = 0
+        self,
+        x: torch.Tensor,
+        mask: torch.Tensor | None = None,
+        rollout_step: int = 0,
     ) -> torch.Tensor:
-        """
-        Runs the forward pass of the window-based multi-head self attention layer.
+        """Run the forward pass of the window-based multi-head self-attention layer.
 
         Args:
             x (torch.Tensor): Input features with shape of `(nW*B, N, C)`.
-            mask (torch.Tensor, optional): Attention mask of floating-points in the range
+            mask (torch.Tensor, optional): Attention mask of floating-point values in the range
                 `[-inf, 0)` with shape of `(nW, ws, ws)`, where `nW` is the number of windows,
                 and `ws` is the window size (i.e. total tokens inside the window).
 
@@ -160,9 +173,9 @@ def extra_repr(self) -> str:
 
 
 def get_two_sidded_padding(H_padding: int, W_padding: int) -> tuple[int, int, int, int]:
-    """Returns the padding for the left, right, top, and bottom sides."""
-    assert H_padding >= 0, f"H_padding ({H_padding}) must be >= 0"
-    assert W_padding >= 0, f"W_padding ({W_padding}) must be >= 0"
+    """Returns the padding for the left, right, top, and bottom sides, in exactly that order."""
+    assert H_padding >= 0, f"H_padding ({H_padding}) must be >= 0."
+    assert W_padding >= 0, f"W_padding ({W_padding}) must be >= 0."
 
     if H_padding:
         padding_top = H_padding // 2
@@ -190,9 +203,9 @@ def window_partition_3d(x: torch.Tensor, ws: tuple[int, int, int]) -> torch.Tens
         torch.Tensor: Partitioning of shape `(num_windows*B, Wc, Wh, Ww, D)`.
     """
     B, C, H, W, D = x.shape
-    assert C % ws[0] == 0, f"C ({C}) % window_size ({ws[0]}) must be 0"
-    assert H % ws[1] == 0, f"H ({H}) % window_size ({ws[1]}) must be 0"
-    assert W % ws[2] == 0, f"W ({W}) % window_size ({ws[2]}) must be 0"
+    assert C % ws[0] == 0, f"C ({C}) % window_size ({ws[0]}) must be 0."
+    assert H % ws[1] == 0, f"H ({H}) % window_size ({ws[1]}) must be 0."
+    assert W % ws[2] == 0, f"W ({W}) % window_size ({ws[2]}) must be 0."
 
     x = x.view(B, C // ws[0], ws[0], H // ws[1], ws[1], W // ws[2], ws[2], D)
     windows = rearrange(x, "B C1 Wc H1 Wh W1 Ww D -> (B C1 H1 W1) Wc Wh Ww D")
@@ -204,7 +217,7 @@ def window_reverse_3d(windows: torch.Tensor, ws: tuple[int, int, int], C: int, H
 
     Args:
         windows (torch.Tensor): Partitioning of shape `(num_windows*B, Wc, Wh, Ww, D)`.
-        ws: (:obj:`tuple[int, int, int]`): The 3D window size
+        ws (tuple[int, int, int]): The 3D window size.
         C (int): Number of levels.
         H (int): Height of image.
         W (int): Width of image.
@@ -212,9 +225,9 @@ def window_reverse_3d(windows: torch.Tensor, ws: tuple[int, int, int], C: int, H
     Returns:
         torch.Tensor: Unpartitioned input of shape `(B, C, H, W, D)`.
     """
-    assert C % ws[0] == 0, f"D ({C}) % window_size ({ws[0]}) must be 0"
-    assert H % ws[1] == 0, f"H ({H}) % window_size ({ws[1]}) must be 0"
-    assert W % ws[2] == 0, f"W ({W}) % window_size ({ws[2]}) must be 0"
+    assert C % ws[0] == 0, f"C ({C}) % window_size ({ws[0]}) must be 0."
+    assert H % ws[1] == 0, f"H ({H}) % window_size ({ws[1]}) must be 0."
+    assert W % ws[2] == 0, f"W ({W}) % window_size ({ws[2]}) must be 0."
 
     C1, H1, W1 = C // ws[0], H // ws[1], W // ws[2]
     B = int(windows.shape[0] / (C1 * H1 * W1))
@@ -233,9 +246,12 @@ def window_reverse_3d(windows: torch.Tensor, ws: tuple[int, int, int], C: int, H
 
 
 def get_three_sidded_padding(
-    C_padding: int, H_padding: int, W_padding: int
+    C_padding: int,
+    H_padding: int,
+    W_padding: int,
 ) -> tuple[int, int, int, int, int, int]:
-    """Returns the padding for the left, right, top, bottom, front, and back sides."""
+    """Returns the padding for the left, right, top, bottom, front, and back sides, in exactly
+    that order."""
     assert C_padding >= 0, f"C_padding ({C_padding}) must be >= 0"
 
     if C_padding:
@@ -271,10 +287,12 @@ def get_3d_merge_groups() -> list[tuple[int, int]]:
     """Returns the groups to be merged for the 3D case to obtain left-right connectivity."""
     merge_groups_2d = [(1, 2), (4, 5), (7, 8)]
     merge_groups_3d = []
-    for i_cslice in range(3):  # i is the index of the `c_slices`
+    for i_c_slice in range(3):
         for grp1_2d, grp2_2d in merge_groups_2d:
-            # The 2D merge groups show up in each of the `c_slices` with an offset of 9.
-            offset = i_cslice * 9  # 9 = num_h_slices * num_w_slices
+            # The 2D merge groups show up in each of the `c_slices` with an offset of 9, which
+            # corresponds to the total number of 2D communication groups
+            # (`num_h_slices * num_w_slices`). See :func:`compute_3d_shifted_window_mask`.
+            offset = i_c_slice * 9
             grp1_3d, grp2_3d = grp1_2d + offset, grp2_2d + offset
             merge_groups_3d.append((grp1_3d, grp2_3d))
     return merge_groups_3d
@@ -291,33 +309,31 @@ def compute_3d_shifted_window_mask(
     dtype: torch.dtype = torch.bfloat16,
     warped: bool = True,
 ) -> tuple[torch.Tensor, torch.Tensor]:
-    """Computes the mask of each window for the shifted window attention.
-
-    For a more detailed explanation of the algorithm used to compute the mask,
-    see the `compute_2d_shifted_window_mask` function in `swin_block.py`.
-    This function generalizes that function to 3D.
+    """Computes the mask of each window for the shifted-window attention.
 
     Args:
         C (int): Number of levels.
         H (int): Height of the image.
         W (int): Width of the image.
-        ws (tuple[int, int, int]): Window size of the form (Wc, Wh, Ww)
-        ss (tuple[int, int, int]): Shift size of the form (Sc, Sh, Sw)
-        dtype (torch.dtype): Data type of the mask.
-        warped (bool): If warped, we assume the left and right sides of the image are connected.
+        ws (tuple[int, int, int]): Window sizes of the form `(Wc, Wh, Ww)`.
+        ss (tuple[int, int, int]): Shift sizes of the form `(Sc, Sh, Sw)`.
+        dtype (torch.dtype, optional): Data type of the mask. Defaults to `torch.bfloat16`.
+        warped (bool, optional): If `True`, assume that the left and right sides of the image
+            are connected. Defaults to `True`.
 
     Returns:
-        attn_mask (torch.tensor): Attention mask for each window. Masked entries are -100 and
-            non-masked entries are 0. This matrix is added to the attention matrix before softmax.
-        img_mask (torch.tensor): Image mask splitting the input patches into groups.
-            Used for debugging purposes.
+        torch.Tensor: Attention mask for each window. Masked entries are -100 and non-masked
+            entries are 0. This matrix is added to the attention matrix before softmax.
+        torch.Tensor: Image mask splitting the input patches into groups. Used for debugging
+            purposes.
     """
-    img_mask = torch.zeros((1, C, H, W, 1), device=device, dtype=dtype)  # (1 C H W 1)
+    img_mask = torch.zeros((1, C, H, W, 1), device=device, dtype=dtype)
     c_slices = (slice(0, -ws[0]), slice(-ws[0], -ss[0]), slice(-ss[0], None))
     h_slices = (slice(0, -ws[1]), slice(-ws[1], -ss[1]), slice(-ss[1], None))
     w_slices = (slice(0, -ws[2]), slice(-ws[2], -ss[2]), slice(-ss[2], None))
 
-    # Assign each patch to a communication group.
+    # Assign each patch to a communication group. The iteration order here must be consistent
+    # with the indices that :func:`get_3d_merge_groups` computes.
     cnt = 0
     for c, h, w in itertools.product(c_slices, h_slices, w_slices):
         img_mask[:, c, h, w, :] = cnt
@@ -327,7 +343,8 @@ def compute_3d_shifted_window_mask(
     for grp1, grp2 in get_3d_merge_groups():
         img_mask = img_mask.masked_fill(img_mask == grp1, grp2)
 
-    # Pad to multiple of window size and assign padded patches to a separate group (cnt).
+    # Pad to a multiple of the window size and assign padded patches to a separate group. After
+    # the above loop, `cnt` is an as-yet-unused group index.
     pad_size = (ws[0] - C % ws[0], ws[1] - H % ws[1], ws[2] - W % ws[2])
     pad_size = (pad_size[0] % ws[0], pad_size[1] % ws[1], pad_size[2] % ws[2])
     img_mask = pad_3d(img_mask, pad_size, value=cnt)
@@ -342,45 +359,50 @@
 
 
 class Swin3DTransformerBlock(nn.Module):
-    """3D Swin Transformer Block."""
+    """3D Swin Transformer block."""
 
     def __init__(
         self,
-        dim,
-        num_heads,
+        dim: int,
+        num_heads: int,
         time_dim: int,
         window_size: tuple[int, int, int] = (2, 7, 7),
         shift_size: tuple[int, int, int] = (0, 0, 0),
-        mlp_ratio=4.0,
-        qkv_bias=True,
-        drop=0.0,
-        attn_drop=0.0,
-        drop_path=0.0,
-        act_layer=nn.GELU,
+        mlp_ratio: float = 4.0,
+        qkv_bias: bool = True,
+        drop: float = 0.0,
+        attn_drop: float = 0.0,
+        drop_path: float = 0.0,
+        act_layer: type = nn.GELU,
         scale_bias: float = 0.0,
-        use_lora: bool = False,
         lora_steps: int = 40,
         lora_mode: LoRAMode = "single",
-    ):
-        """
+        use_lora: bool = False,
+    ) -> None:
+        """Initialise.
+
         Args:
             dim (int): Number of input channels.
             input_resolution (tuple[int, int]): Input resolution.
             num_heads (int): Number of attention heads.
             time_dim (int): Dimension of the lead time embedding.
-            window_size (tuple[int, int]): Window size.
-            shift_size (tuple[int, int]): Shift size for SW-MSA.
-            mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
-            qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Defaults
-                to `True`.
-            drop (float, optional): Dropout rate. Default: 0.0
-            attn_drop (float, optional): Attention dropout rate. Default: 0.0
-            drop_path (float, optional): Stochastic depth rate. Default: 0.0
-            act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
-            scale_bias (float): Scale bias to use for the AdaptiveLayerNorm. Default: 0
-            use_lora (bool): If True, use LoRA. Default: False
-            lora_steps (int): Maximum number of LoRA steps to use for rollouts. Default: 40
-            lora_mode (str): LoRA mode. Default: "single"
+            window_size (tuple[int, int, int]): Window size. Defaults to `(2, 7, 7)`.
+            shift_size (tuple[int, int, int]): Shift size for SW-MSA. Defaults to `(0, 0, 0)`.
+            mlp_ratio (float): Hidden layer dimensionality divided by that of the input for all
+                MLPs. Defaults to `4.0`.
+            qkv_bias (bool, optional): If `True`, add a learnable bias to each query, key, and
+                value. Defaults to `True`.
+            drop (float, optional): Drop-out rate. Defaults to `0.0`.
+            attn_drop (float, optional): Attention drop-out rate. Defaults to `0.0`.
+            drop_path (float, optional): Stochastic depth rate. Defaults to `0.0`.
+            act_layer (type, optional): Activation function to use. Will be instantiated as
+                `act_layer()`. Defaults to `torch.nn.GELU`.
+            scale_bias (float, optional): Scale bias for
+                :class:`aurora.model.film.AdaptiveLayerNorm`. Defaults to `0`.
+            lora_steps (int, optional): Maximum number of LoRA roll-out steps. Defaults to `40`.
+            lora_mode (str, optional): Mode. `"single"` uses the same LoRA for all roll-out steps,
+                and `"all"` uses a different LoRA for every roll-out step. Defaults to `"single"`.
+            use_lora (bool, optional): Enable LoRA. By default, LoRA is disabled.
         """
         super().__init__()
         self.dim = dim
@@ -418,13 +440,25 @@ def forward(
         c: torch.Tensor,
         res: tuple[int, int, int],
         rollout_step: int,
-        warped=True,
-    ):
+        warped: bool = True,
+    ) -> torch.Tensor:
+        """Run the block.
+
+        Args:
+            x (torch.Tensor): Input tokens of shape `(B, L, D)`.
+            c (torch.Tensor): Conditioning context of shape `(B, D)`.
+            res (tuple[int, int, int]): Resolution of the input `x`.
+            rollout_step (int): Roll-out step.
+            warped (bool, optional): Connect the left and right sides. Defaults to `True`.
+
+        Returns:
+            torch.Tensor: Output tokens.
+        """
         C, H, W = res
         B, L, D = x.shape
         assert L == C * H * W, f"Wrong feature size: {L} vs {C}x{H}x{W}={C*H*W}"
 
-        # If window size is larger than input resolution, we don't partition windows
+        # If the window size is larger than the input resolution, we do not partition windows.
         ws, ss = maybe_adjust_windows(self.window_size, self.shift_size, res)
 
         shortcut = x
@@ -448,10 +482,8 @@ def forward(
         x_windows = window_partition_3d(shifted_x, ws)  # (nW*B, ws, ws, D)
         x_windows = x_windows.view(-1, ws[0] * ws[1] * ws[2], D)  # (nW*B, ws*ws, D)
 
-        # W-MSA/SW-MSA.
-        attn_windows = self.attn(
-            x_windows, mask=attn_mask, rollout_step=rollout_step
-        )  # (nW*B, ws*ws, D)
+        # W-MSA/SW-MSA. Has shape (nW*B, ws*ws, D).
+        attn_windows = self.attn(x_windows, mask=attn_mask, rollout_step=rollout_step)
 
         # Merge the windows into the original input (patch) resolution.
         attn_windows = attn_windows.view(-1, ws[0], ws[1], ws[2], D)  # (nW*B, Wc, Wh, Ww, D)
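
The shape contracts that the new docstrings pin down can be exercised with a quick round trip through the two 3D window helpers. This is a minimal sketch, assuming `window_partition_3d` and `window_reverse_3d` are importable from `aurora.model.swin3d` and behave exactly as documented in this diff; the concrete sizes are illustrative only.

```python
import torch

from aurora.model.swin3d import window_partition_3d, window_reverse_3d

# Illustrative sizes: C, H, and W must be divisible by the corresponding window sizes.
B, C, H, W, D = 1, 4, 14, 14, 8
ws = (2, 7, 7)  # (Wc, Wh, Ww)

x = torch.randn(B, C, H, W, D)
windows = window_partition_3d(x, ws)  # (num_windows*B, Wc, Wh, Ww, D)
assert windows.shape == (B * (C // ws[0]) * (H // ws[1]) * (W // ws[2]), *ws, D)

# Reversing the partition should reproduce the input exactly, since both directions are
# pure reshape/permute operations.
x_again = window_reverse_3d(windows, ws, C, H, W)  # (B, C, H, W, D)
assert torch.equal(x, x_again)
```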
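Similarly, the reworded offset comment in `get_3d_merge_groups` is easiest to read against the function's concrete output: the three 2D merge pairs `(1, 2)`, `(4, 5)`, `(7, 8)` repeat once per `c_slice`, shifted by the nine 2D communication groups. A small check, under the same import assumption as above:

```python
from aurora.model.swin3d import get_3d_merge_groups

expected = [
    (grp1 + 9 * i, grp2 + 9 * i)  # Offset of 9 per `c_slice`: 3 `h_slices` times 3 `w_slices`.
    for i in range(3)
    for grp1, grp2 in [(1, 2), (4, 5), (7, 8)]
]
assert get_3d_merge_groups() == expected  # [(1, 2), (4, 5), (7, 8), (10, 11), ..., (25, 26)]
```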