Commit
Comment swin3d.MLP
wesselb committed Aug 19, 2024
1 parent 7c31be6 commit b4b7885
Showing 1 changed file with 27 additions and 8 deletions.
35 changes: 27 additions & 8 deletions aurora/model/swin3d.py
@@ -9,6 +9,7 @@
import itertools
from datetime import timedelta
from functools import lru_cache
from typing import Optional

import torch
import torch.nn as nn
@@ -21,16 +22,32 @@
from aurora.model.lora import LoRAMode, LoRARollout
from aurora.model.util import init_weights, maybe_adjust_windows

__all__ = ["Swin3DTransformerBackbone"]


class MLP(nn.Module):
"""A one-hidden-layer MLP with dropout after the hidden layer and at the end."""

def __init__(
self,
in_features,
hidden_features=None,
out_features=None,
act_layer=nn.GELU,
drop=0.0,
):
in_features: int,
hidden_features: Optional[int] = None,
out_features: Optional[int] = None,
act_layer: type = nn.GELU,
drop: float = 0.0,
) -> None:
"""Initialise.
Args:
in_features (int): Input dimensionality.
hidden_features (int, optional): Hidden layer dimensionality. Defaults to the input
dimensionality.
out_features (int, optional): Output dimensionality. Defaults to the input
dimensionality.
act_layer (type, optional): Activation function to use. Will be instantiated as
`act_layer()`. Defaults to `torch.nn.GELU`.
drop (float, optional): Drop-out rate. Defaults to no drop-out.
"""
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
@@ -39,7 +56,8 @@ def __init__(
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the MLP."""
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
@@ -49,7 +67,8 @@ def forward(self, x):


class WindowAttention(nn.Module):
r"""Window based multi-head self attention (W-MSA) module with relative position bias.
"""Window based multi-head self attention (W-MSA) module with relative position bias.
It supports both of shifted and non-shifted window.
Args:
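For reference, a minimal usage sketch of the MLP documented in this commit. The batch size, widths, and dropout rate below are illustrative assumptions rather than values taken from the diff; note that MLP is not listed in __all__ but can still be imported from aurora.model.swin3d directly.

import torch

from aurora.model.swin3d import MLP

# Illustrative sketch: a batch of 4 tokens with 96 features each.
mlp = MLP(in_features=96, hidden_features=384, drop=0.1)
x = torch.randn(4, 96)
y = mlp(x)
# out_features defaults to in_features, so the feature dimension is preserved.
assert y.shape == (4, 96)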
