Skip to content

Commit

Permalink
Remove unused code
Browse files Browse the repository at this point in the history
  • Loading branch information
wesselb committed Aug 9, 2024
1 parent 765fccd commit e5af5fd
Show file tree
Hide file tree
Showing 5 changed files with 145 additions and 947 deletions.
71 changes: 0 additions & 71 deletions aurora/model/patchembed.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,11 @@
"""Copyright (c) Microsoft Corporation. Licensed under the MIT license."""

import math
from collections.abc import Iterable

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from timm.models.layers.helpers import to_2tuple
from timm.models.vision_transformer import trunc_normal_


class LevelPatchEmbed(nn.Module):
Expand Down Expand Up @@ -92,71 +89,3 @@ def forward(self, x: torch.Tensor, vars: list[int]) -> torch.Tensor:

x = self.norm(proj)
return x


class StableGroupedVarPatchEmbed(nn.Module):
def __init__(
self,
max_vars: int,
patch_size: int,
embed_dim: int,
norm_layer: nn.Module = None,
return_flatten: bool = True,
):
super().__init__()
self.max_vars = max_vars
self.patch_size = to_2tuple(patch_size)
self.embed_dim = embed_dim
self.return_flatten = return_flatten

self.proj = nn.ModuleList(
[
nn.Conv2d(
1,
embed_dim,
kernel_size=patch_size,
stride=patch_size,
bias=bool(norm_layer),
)
for _ in range(max_vars)
]
)

if norm_layer is not None:
self.norm = norm_layer(embed_dim)
else:
self.norm = nn.Identity()

self.apply(self._init_weights)

def _init_weights(self, m):
"""Initialize conv layers and layer norm."""
if isinstance(m, nn.Conv2d):
trunc_normal_(m.weight, std=0.02)
if m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)

def forward(self, x: torch.Tensor, vars: Iterable[int]):
"""Forward fucntion
Args:
x (torch.Tensor): a shape of [BT, V, L, C] tensor
vars (list[int], optional): a list of variable ID
Returns:
proj (torch.Tensor): a shape of [BT V L' C] tensor
"""
proj = []
for i, var in enumerate(vars):
proj.append(self.proj[var](x[:, i : i + 1]))
proj = torch.stack(proj, dim=1) # BT, V, C, H, W

if self.return_flatten:
proj = rearrange(proj, "b v c h w -> b v (h w) c")

proj = self.norm(proj)

return proj
70 changes: 0 additions & 70 deletions aurora/model/posencoding.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,41 +19,6 @@
from aurora.model.fourier import FourierExpansion


def get_great_circle_distance(
lat_min: torch.Tensor, lon_min: torch.Tensor, lat_max: torch.Tensor, lon_max: torch.Tensor
) -> torch.Tensor:
"""Calculate the great-circle distance between two points on a sphere via the Haversine formula.
Latitude and longitude values are used as inputs.
Args:
lat_min (torch.Tensor): Latitude of first point.
lon_min (torch.Tensor): Longitude of first point.
lat_max (torch.Tensor): Latitude of second point.
lon_max (torch.Tensor): Longitude of second point.
Returns:
torch.Tensor: Tensor of great-circle distance between pairs of points multiplied by the
radius of the earth.
"""
delta_lat = torch.deg2rad(lat_min) - torch.deg2rad(lat_max)
delta_lon = torch.deg2rad(lon_min) - torch.deg2rad(lon_max)
# "Haversine" formula where the radius is the radius of the earth = 6371km.
# https://en.wikipedia.org/wiki/Haversine_formula
great_circle_dist = (
2
* 6371
* torch.asin(
torch.sqrt(
torch.sin(delta_lat / 2) ** 2
+ torch.cos(torch.deg2rad(lat_min))
* torch.cos(torch.deg2rad(lat_max))
* torch.sin(delta_lon / 2) ** 2
)
)
)
return great_circle_dist


def get_root_area_on_sphere(
lat_min: torch.Tensor, lon_min: torch.Tensor, lat_max: torch.Tensor, lon_max: torch.Tensor
) -> torch.Tensor:
Expand Down Expand Up @@ -221,38 +186,3 @@ def get_2d_patched_lat_lon_encode(
)

return pos_encode.squeeze(0), scale_encode.squeeze(0) # Return without batch dimension.


def get_flexible_2d_patched_lat_lon_encode(
encode_dim: int,
lat: torch.Tensor,
lon: torch.Tensor,
patch_dims: int | list | tuple,
pos_expansion: FourierExpansion,
scale_expansion: FourierExpansion,
) -> torch.Tensor:
"""Positional encoding of latitude-longitude data that works for non-regular data such as HRRR.
Args:
encode_dim (int): Output encoding dimension `D`.
lat (torch.Tensor): Tensor of latitude values `(B, H, W)`.
lon (torch.Tensor): Tensor of longitude values `(B, H, W)`.
patch (Union[list, tuple]): Patch dimensions. Different x- and y-values are supported.
pos_expansion (:class:`.FourierExpansion`): Fourier expansion for the latitudes and
longitudes.
scale_expansion (:class:`.FourierExpansion`): Fourier expansion for the patch areas.
Returns:
torch.Tensor: Returns positional encoding tensor of shape `(B, H/patch[0] * W/patch[1], D)`.
"""

grid = torch.cat((lat[:, None, ...], lon[:, None, ...]), dim=1)
pos_encode, scale_encode = get_2d_patched_lat_lon_from_grid(
encode_dim,
grid,
to_2tuple(patch_dims),
pos_expansion=pos_expansion,
scale_expansion=scale_expansion,
)

return pos_encode, scale_encode
Loading

0 comments on commit e5af5fd

Please sign in to comment.