First round of improvements for swin3d
wesselb committed Aug 19, 2024
1 parent b4b7885 commit 08c8092
Showing 1 changed file with 123 additions and 91 deletions.
214 changes: 123 additions & 91 deletions aurora/model/swin3d.py
@@ -67,38 +67,49 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:


class WindowAttention(nn.Module):
"""Window based multi-head self attention (W-MSA) module with relative position bias.
"""Window-based multi-head self-attention (W-MSA).
It supports both of shifted and non-shifted window.
Args:
dim (int): Number of input channels.
window_size (tuple[int]): The height and width of the window.
num_heads (int): Number of attention heads.
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Defaults to
`True`.
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
proj_drop (float, optional): Dropout ratio of output. Default: 0.0
It supports both shifted and non-shifted windows.
"""

def __init__(
self,
dim,
window_size,
num_heads,
qkv_bias=True,
qk_scale=None,
attn_drop=0.0,
proj_drop=0.0,
lora_r=8,
lora_alpha=8,
lora_dropout=0.0,
lora_steps=40,
dim: int,
window_size: tuple[int, int, int],
num_heads: int,
qkv_bias: bool = True,
qk_scale: Optional[float] = None,
attn_drop: float = 0.0,
proj_drop: float = 0.0,
lora_r: int = 8,
lora_alpha: int = 8,
lora_dropout: float = 0.0,
lora_steps: int = 40,
lora_mode: LoRAMode = "single",
use_lora: bool = False,
):
) -> None:
"""Initialise.
Args:
dim (int): Number of input channels.
window_size (tuple[int, int, int]): The size of the windows.
num_heads (int): Number of attention heads.
qkv_bias (bool, optional): If `True`, add a learnable bias to the query, key, and value.
Defaults to `True`.
qk_scale (float, optional): If set, overrides the default query-key scale of
`1/sqrt(head_dim)`.
attn_drop (float, optional): Drop-out rate of the attention weights. Defaults to `0.0`.
proj_drop (float, optional): Drop-out rate of the output. Defaults to `0.0`.
lora_r (int, optional): LoRA rank. Defaults to `8`.
lora_alpha (int, optional): LoRA alpha. Defaults to `8`.
lora_dropout (float, optional): LoRA drop-out rate. Defaults to `0.0`.
lora_steps (int, optional): Maximum number of LoRA roll-out steps. Defaults to `40`.
lora_mode (str, optional): Mode. `"single"` uses the same LoRA for all roll-out steps,
and `"all"` uses a different LoRA for every roll-out step. Defaults to `"single"`.
use_lora (bool, optional): Enable LoRA. By default, LoRA is disabled.
"""
super().__init__()

self.dim = dim
self.window_size = window_size # Wh, Ww
self.num_heads = num_heads
@@ -122,14 +133,16 @@ def __init__(
self.lora_qkv = lambda *args, **kwargs: 0 # type: ignore

def forward(
self, x: torch.Tensor, mask: torch.Tensor | None = None, rollout_step: int = 0
self,
x: torch.Tensor,
mask: torch.Tensor | None = None,
rollout_step: int = 0,
) -> torch.Tensor:
"""
Runs the forward pass of the window-based multi-head self attention layer.
"""Run the forward pass of the window-based multi-head self-attention layer.
Args:
x (torch.Tensor): Input features with shape of `(nW*B, N, C)`.
mask (torch.Tensor, optional): Attention mask of floating-points in the range
mask (torch.Tensor, optional): Attention mask of floating-point values in the range
`[-inf, 0)` with shape of `(nW, ws, ws)`, where `nW` is the number of windows,
and `ws` is the window size (i.e. total tokens inside the window).
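For reference, a minimal usage sketch of the interface documented above; the sizes are illustrative assumptions, and the module path is inferred from the file path shown in this commit rather than stated in it:

import torch
from aurora.model.swin3d import WindowAttention

attn = WindowAttention(dim=128, window_size=(2, 7, 7), num_heads=8)
x = torch.randn(16 * 2, 2 * 7 * 7, 128)   # (nW*B, N, C) with nW=16 windows and B=2
out = attn(x, mask=None, rollout_step=0)  # expected to keep the input shape (nW*B, N, C)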
@@ -160,9 +173,9 @@ def extra_repr(self) -> str:


def get_two_sidded_padding(H_padding: int, W_padding: int) -> tuple[int, int, int, int]:
"""Returns the padding for the left, right, top, and bottom sides."""
assert H_padding >= 0, f"H_padding ({H_padding}) must be >= 0"
assert W_padding >= 0, f"W_padding ({W_padding}) must be >= 0"
"""Returns the padding for the left, right, top, and bottom sides, in exactly that order."""
assert H_padding >= 0, f"H_padding ({H_padding}) must be >= 0."
assert W_padding >= 0, f"W_padding ({W_padding}) must be >= 0."

if H_padding:
padding_top = H_padding // 2
@@ -190,9 +203,9 @@ def window_partition_3d(x: torch.Tensor, ws: tuple[int, int, int]) -> torch.Tens
torch.Tensor: Partitioning of shape `(num_windows*B, Wc, Wh, Ww, D)`.
"""
B, C, H, W, D = x.shape
assert C % ws[0] == 0, f"C ({C}) % window_size ({ws[0]}) must be 0"
assert H % ws[1] == 0, f"H ({H}) % window_size ({ws[1]}) must be 0"
assert W % ws[2] == 0, f"W ({W}) % window_size ({ws[2]}) must be 0"
assert C % ws[0] == 0, f"C ({C}) % window_size ({ws[0]}) must be 0."
assert H % ws[1] == 0, f"H ({H}) % window_size ({ws[1]}) must be 0."
assert W % ws[2] == 0, f"W ({W}) % window_size ({ws[2]}) must be 0."

x = x.view(B, C // ws[0], ws[0], H // ws[1], ws[1], W // ws[2], ws[2], D)
windows = rearrange(x, "B C1 Wc H1 Wh W1 Ww D -> (B C1 H1 W1) Wc Wh Ww D")
@@ -204,17 +217,17 @@ def window_reverse_3d(windows: torch.Tensor, ws: tuple[int, int, int], C: int, H
Args:
windows (torch.Tensor): Partitioning of shape `(num_windows*B, Wc, Wh, Ww, D)`.
ws: (:obj:`tuple[int, int, int]`): The 3D window size
ws (tuple[int, int, int]): The 3D window size.
C (int): Number of levels.
H (int): Height of image.
W (int): Width of image.
Returns:
torch.Tensor: Unpartitioned input of shape `(B, C, H, W, D)`.
"""
assert C % ws[0] == 0, f"D ({C}) % window_size ({ws[0]}) must be 0"
assert H % ws[1] == 0, f"H ({H}) % window_size ({ws[1]}) must be 0"
assert W % ws[2] == 0, f"W ({W}) % window_size ({ws[2]}) must be 0"
assert C % ws[0] == 0, f"D ({C}) % window_size ({ws[0]}) must be 0."
assert H % ws[1] == 0, f"H ({H}) % window_size ({ws[1]}) must be 0."
assert W % ws[2] == 0, f"W ({W}) % window_size ({ws[2]}) must be 0."

C1, H1, W1 = C // ws[0], H // ws[1], W // ws[2]
B = int(windows.shape[0] / (C1 * H1 * W1))
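A quick round-trip sketch of the shape contract documented above (hypothetical sizes; `C`, `H`, and `W` must be divisible by the corresponding window sizes):

import torch
from aurora.model.swin3d import window_partition_3d, window_reverse_3d

x = torch.randn(2, 4, 14, 28, 96)     # (B, C, H, W, D)
ws = (2, 7, 7)
windows = window_partition_3d(x, ws)  # (num_windows*B, Wc, Wh, Ww, D) = (32, 2, 7, 7, 96)
x_rec = window_reverse_3d(windows, ws, 4, 14, 28)
assert torch.equal(x_rec, x)          # reversing the partition should recover the input exactly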
@@ -233,9 +246,12 @@ def window_reverse_3d(windows: torch.Tensor, ws: tuple[int, int, int], C: int, H


def get_three_sidded_padding(
C_padding: int, H_padding: int, W_padding: int
C_padding: int,
H_padding: int,
W_padding: int,
) -> tuple[int, int, int, int, int, int]:
"""Returns the padding for the left, right, top, bottom, front, and back sides."""
"""Returns the padding for the left, right, top, bottom, front, and back sides, in exactly that
order."""
assert C_padding >= 0, f"C_padding ({C_padding}) must be >= 0"

if C_padding:
@@ -271,10 +287,12 @@ def get_3d_merge_groups() -> list[tuple[int, int]]:
"""Returns the groups to be merged for the 3D case to obtain left-right connectivity."""
merge_groups_2d = [(1, 2), (4, 5), (7, 8)]
merge_groups_3d = []
for i_cslice in range(3): # i is the index of the `c_slices`
for i_c_slice in range(3):
for grp1_2d, grp2_2d in merge_groups_2d:
# The 2D merge groups show up in each of the `c_slices` with an offset of 9.
offset = i_cslice * 9 # 9 = num_h_slices * num_w_slices
# The 2D merge groups show up in each of the `c_slices` with an offset of 9, where 9 is the
# total number of 2D communication groups (num_h_slices * num_w_slices). See
# :func:`compute_3d_shifted_window_mask`.
offset = i_c_slice * 9
grp1_3d, grp2_3d = grp1_2d + offset, grp2_2d + offset
merge_groups_3d.append((grp1_3d, grp2_3d))
return merge_groups_3d
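Concretely, with `merge_groups_2d = [(1, 2), (4, 5), (7, 8)]` and offsets 0, 9, and 18, the returned list is `[(1, 2), (4, 5), (7, 8), (10, 11), (13, 14), (16, 17), (19, 20), (22, 23), (25, 26)]`.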
@@ -291,33 +309,31 @@ def compute_3d_shifted_window_mask(
dtype: torch.dtype = torch.bfloat16,
warped: bool = True,
) -> tuple[torch.Tensor, torch.Tensor]:
"""Computes the mask of each window for the shifted window attention.
For a more detailed explanation of the algorithm used to compute the mask,
see the `compute_2d_shifted_window_mask` function in `swin_block.py`.
This function generalizes that function to 3D.
"""Computes the mask of each window for the shifted-window attention.
Args:
C (int): Number of levels.
H (int): Height of the image.
W (int): Width of the image.
ws (tuple[int, int, int]): Window size of the form (Wc, Wh, Ww)
ss (tuple[int, int, int]): Shift size of the form (Sc, Sh, Sw)
dtype (torch.dtype): Data type of the mask.
warped (bool): If warped, we assume the left and right sides of the image are connected.
ws (tuple[int, int, int]): Window sizes of the form `(Wc, Wh, Ww)`.
ss (tuple[int, int, int]): Shift sizes of the form `(Sc, Sh, Sw)`.
dtype (torch.dtype, optional): Data type of the mask. Defaults to `torch.bfloat16`.
warped (bool, optional): If `True`, assume that the left and right sides of the image are connected.
Defaults to `True`.
Returns:
attn_mask (torch.tensor): Attention mask for each window. Masked entries are -100 and
non-masked entries are 0. This matrix is added to the attention matrix before softmax.
img_mask (torch.tensor): Image mask splitting the input patches into groups.
Used for debugging purposes.
torch.Tensor: Attention mask for each window. Masked entries are -100 and non-masked
entries are 0. This matrix is added to the attention matrix before softmax.
torch.Tensor: Image mask splitting the input patches into groups. Used for debugging
purposes.
"""
img_mask = torch.zeros((1, C, H, W, 1), device=device, dtype=dtype) # (1 C H W 1)
img_mask = torch.zeros((1, C, H, W, 1), device=device, dtype=dtype)
c_slices = (slice(0, -ws[0]), slice(-ws[0], -ss[0]), slice(-ss[0], None))
h_slices = (slice(0, -ws[1]), slice(-ws[1], -ss[1]), slice(-ss[1], None))
w_slices = (slice(0, -ws[2]), slice(-ws[2], -ss[2]), slice(-ss[2], None))

# Assign each patch to a communication group.
# Assign each patch to a communication group. The iteration order here must be consistent with
# the indices that :func:`get_3d_merge_groups` computes.
cnt = 0
for c, h, w in itertools.product(c_slices, h_slices, w_slices):
img_mask[:, c, h, w, :] = cnt
@@ -327,7 +343,8 @@ def compute_3d_shifted_window_mask(
for grp1, grp2 in get_3d_merge_groups():
img_mask = img_mask.masked_fill(img_mask == grp1, grp2)

# Pad to multiple of window size and assign padded patches to a separate group (cnt).
# Pad to a multiple of the window size and assign padded patches to a separate group (`cnt` is
# the first unused group index).
pad_size = (ws[0] - C % ws[0], ws[1] - H % ws[1], ws[2] - W % ws[2])
pad_size = (pad_size[0] % ws[0], pad_size[1] % ws[1], pad_size[2] % ws[2])
img_mask = pad_3d(img_mask, pad_size, value=cnt)
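A hedged sketch of how the returned mask plugs into `WindowAttention`; the leading keyword names (including `device`) are collapsed in this diff and are assumed here:

import torch
from aurora.model.swin3d import compute_3d_shifted_window_mask

attn_mask, img_mask = compute_3d_shifted_window_mask(
    C=4, H=14, W=28, ws=(2, 7, 7), ss=(1, 3, 3), device="cpu", dtype=torch.float32
)
# attn_mask holds one (ws, ws) block of 0/-100 entries per window and is passed as `mask=` to
# WindowAttention.forward, where it is added to the attention logits before the softmax.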
@@ -342,45 +359,50 @@


class Swin3DTransformerBlock(nn.Module):
"""3D Swin Transformer Block."""
"""3D Swin Transformer block."""

def __init__(
self,
dim,
num_heads,
dim: int,
num_heads: int,
time_dim: int,
window_size: tuple[int, int, int] = (2, 7, 7),
shift_size: tuple[int, int, int] = (0, 0, 0),
mlp_ratio=4.0,
qkv_bias=True,
drop=0.0,
attn_drop=0.0,
drop_path=0.0,
act_layer=nn.GELU,
mlp_ratio: float = 4.0,
qkv_bias: bool = True,
drop: float = 0.0,
attn_drop: float = 0.0,
drop_path: float = 0.0,
act_layer: type = nn.GELU,
scale_bias: float = 0.0,
use_lora: bool = False,
lora_steps: int = 40,
lora_mode: LoRAMode = "single",
):
"""
use_lora: bool = False,
) -> None:
"""Initialise.
Args:
dim (int): Number of input channels.
input_resolution (tuple[int, int]): Input resolution.
num_heads (int): Number of attention heads.
time_dim (int): Dimension of the lead time embedding.
window_size (tuple[int, int]): Window size.
shift_size (tuple[int, int]): Shift size for SW-MSA.
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Defaults
to `True`.
drop (float, optional): Dropout rate. Default: 0.0
attn_drop (float, optional): Attention dropout rate. Default: 0.0
drop_path (float, optional): Stochastic depth rate. Default: 0.0
act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
scale_bias (float): Scale bias to use for the AdaptiveLayerNorm. Default: 0
use_lora (bool): If True, use LoRA. Default: False
lora_steps (int): Maximum number of LoRA steps to use for rollouts. Default: 40
lora_mode (str): LoRA mode. Default: "single"
window_size (tuple[int, int, int]): Window size. Defaults to `(2, 7, 7)`.
shift_size (tuple[int, int, int]): Shift size for SW-MSA. Defaults to `(0, 0, 0)`.
mlp_ratio (float): Hidden layer dimensionality divided by that of the input for all
MLPs. Defaults to `4.0`.
qkv_bias (bool, optional): If `True`, add a learnable bias to each query, key, and
value. Defaults to `True`.
drop (float, optional): Drop-out rate. Defaults to `0.0`.
attn_drop (float, optional): Attention drop-out rate. Defaults to `0.0`.
drop_path (float, optional): Stochastic depth rate. Defaults to `0.0`.
act_layer (type, optional): Activation function to use. Will be instantiated as
`act_layer()`. Defaults to `torch.nn.GELU`.
scale_bias (float, optional): Scale bias for
:class:`aurora.model.film.AdaptiveLayerNorm`. Defaults to `0.0`.
lora_steps (int, optional): Maximum number of LoRA roll-out steps. Defaults to `40`.
lora_mode (str, optional): Mode. `"single"` uses the same LoRA for all roll-out steps,
and `"all"` uses a different LoRA for every roll-out step. Defaults to `"single"`.
use_lora (bool): Enable LoRA. By default, LoRA is disabled.
"""
super().__init__()
self.dim = dim
@@ -418,13 +440,25 @@ def forward(
c: torch.Tensor,
res: tuple[int, int, int],
rollout_step: int,
warped=True,
):
warped: bool = True,
) -> torch.Tensor:
"""Run the block.
Args:
x (torch.Tensor): Input tokens of shape `(B, L, D)`.
c (torch.Tensor): Conditioning context of shape `(B, D)`.
res (tuple[int, int, int]): Resolution `(C, H, W)` of the input `x`.
rollout_step (int): Roll-out step.
warped (bool, optional): Connect the left and right sides. Defaults to `True`.
Returns:
torch.Tensor: Output tokens.
"""
C, H, W = res
B, L, D = x.shape
assert L == C * H * W, f"Wrong feature size: {L} vs {C}x{H}x{W}={C*H*W}"

# If window size is larger than input resolution, we don't partition windows
# If the window size is larger than the input resolution, we do not partition windows.
ws, ss = maybe_adjust_windows(self.window_size, self.shift_size, res)

shortcut = x
@@ -448,10 +482,8 @@ def forward(
x_windows = window_partition_3d(shifted_x, ws) # (nW*B, ws, ws, D)
x_windows = x_windows.view(-1, ws[0] * ws[1] * ws[2], D) # (nW*B, ws*ws, D)

# W-MSA/SW-MSA.
attn_windows = self.attn(
x_windows, mask=attn_mask, rollout_step=rollout_step
) # (nW*B, ws*ws, D)
# W-MSA/SW-MSA. Has shape (nW*B, ws*ws, D).
attn_windows = self.attn(x_windows, mask=attn_mask, rollout_step=rollout_step)

# Merge the windows into the original input (patch) resolution.
attn_windows = attn_windows.view(-1, ws[0], ws[1], ws[2], D) # (nW*B, Wc, Wh, Ww, D)
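To close, a hedged end-to-end sketch of the block as documented above; all sizes are hypothetical and `time_dim` is simply set equal to `dim`:

import torch
from aurora.model.swin3d import Swin3DTransformerBlock

block = Swin3DTransformerBlock(
    dim=96, num_heads=8, time_dim=96, window_size=(2, 7, 7), shift_size=(1, 3, 3)
)
res = (4, 14, 28)                       # (C, H, W)
x = torch.randn(2, 4 * 14 * 28, 96)     # (B, L, D) with L = C*H*W
c = torch.randn(2, 96)                  # conditioning context of shape (B, D)
out = block(x, c, res, rollout_step=0)  # expected to return tokens of shape (B, L, D)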
