Skip to content

Commit

Permalink
Clean up posencoding
Browse files Browse the repository at this point in the history
  • Loading branch information
wesselb committed Aug 19, 2024
1 parent 47d3e7c commit 7c31be6
Show file tree
Hide file tree
Showing 2 changed files with 83 additions and 71 deletions.
8 changes: 4 additions & 4 deletions aurora/model/encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
)
from aurora.model.patchembed import LevelPatchEmbed
from aurora.model.perceiver import MLP, PerceiverResampler
from aurora.model.posencoding import get_2d_patched_lat_lon_encode
from aurora.model.posencoding import pos_scale_enc
from aurora.model.util import (
check_lat_lon_dtype,
create_var_map,
Expand Down Expand Up @@ -223,20 +223,20 @@ def forward(self, batch: Batch, lead_time: timedelta) -> torch.Tensor:
x = torch.cat((x_surf.unsqueeze(1), x_atmos), dim=1)

# Add position and scale embeddings to the 3D tensor.
pos_encode, scale_encode = get_2d_patched_lat_lon_encode(
pos_encode, scale_encode = pos_scale_enc(
self.embed_dim,
lat,
lon,
self.patch_size,
pos_expansion=pos_expansion,
scale_expansion=scale_expansion,
)
# Encodings are (L, D)
# Encodings are (L, D).
pos_encode = self.pos_embed(pos_encode[None, None, :].to(dtype=dtype))
scale_encode = self.scale_embed(scale_encode[None, None, :].to(dtype=dtype))
x = x + pos_encode + scale_encode

# Flatten the tokens
# Flatten the tokens.
x = x.reshape(B, -1, self.embed_dim) # (B, C + 1, L, D) to (B, L', D)

# Add lead time embedding.
Expand Down
146 changes: 79 additions & 67 deletions aurora/model/posencoding.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,28 +11,39 @@

from aurora.model.fourier import FourierExpansion

__all__ = ["pos_scale_enc"]

def get_root_area_on_sphere(
lat_min: torch.Tensor, lon_min: torch.Tensor, lat_max: torch.Tensor, lon_max: torch.Tensor

def patch_root_area(
lat_min: torch.Tensor,
lon_min: torch.Tensor,
lat_max: torch.Tensor,
lon_max: torch.Tensor,
) -> torch.Tensor:
"""Calculate the root area of rectangular grid. Latitude and longitude values are used as
inputs. The root is taken to return units of km, and thus stay scalable between different
"""For a rectangular patch on a sphere, compute the square root of the area of the patch in
units km^2. The root is taken to return units of km, and thus stay scalable between different
resolutions.
Args:
lat_min (torch.Tensor): Latitude of first point.
lon_min (torch.Tensor): Longitude of first point.
lat_max (torch.Tensor): Latitude of second point.
lon_max (torch.Tensor): Longitude of second point.
lat_min (torch.Tensor): Minimum latitutes of patches.
lon_min (torch.Tensor): Minimum longitudes of patches.
lat_max (torch.Tensor): Maximum latitudes of patches.
lon_max (torch.Tensor): Maximum longitudes of patches.
Returns:
torch.Tensor: Tensor of root area on grid.
torch.Tensor: Square root of the area.
"""
# Calculate area of latitude (phi) - longitude (theta) grid using the formula:
# R**2 * pi * (sin(phi_1) - sin(phi_2)) *(theta_1 - theta_2)
# https://www.johndcook.com/blog/2023/02/21/sphere-grid-area/
assert (lat_max > lat_min).all(), f"lat_max - lat_min: {torch.min(lat_max - lat_min)}"
assert (lon_max > lon_min).all(), f"lon_max - lon_min: {torch.min(lon_max - lon_min)}"
# Calculate area of latitude-longitude grid using the following formula. Phis are latitudes
# and thetas are longitudes.
#
# area = R**2 * pi * (sin(phi_1) - sin(phi_2)) * (theta_1 - theta_2)
#
# Taken from
#
# https://www.johndcook.com/blog/2023/02/21/sphere-grid-area/
#
assert (lat_max > lat_min).all(), f"lat_max - lat_min: {torch.min(lat_max - lat_min)}."
assert (lon_max > lon_min).all(), f"lon_max - lon_min: {torch.min(lon_max - lon_min)}."
assert (abs(lat_max) <= 90.0).all() and (abs(lat_min) <= 90.0).all()
assert (lon_max <= 360.0).all() and (lon_min <= 360.0).all()
assert (lon_max >= 0.0).all() and (lon_min >= 0.0).all()
Expand All @@ -47,71 +58,66 @@ def get_root_area_on_sphere(
return torch.sqrt(area)


def get_2d_patched_lat_lon_from_grid(
def pos_scale_enc_grid(
encode_dim: int,
grid: torch.Tensor,
patch_dims: tuple,
pos_expansion: FourierExpansion,
scale_expansion: FourierExpansion,
) -> tuple[torch.Tensor, torch.Tensor]:
"""Calculates 2D patched position encoding from grid. For each patch the mean latitute and
longitude values are calculated.
"""Compute the position and scale encoding for a latitude-longitude grid.
Requires batch dimensions in the input and returns a batch dimension.
Args:
encode_dim (int): Output encoding dimension `D`.
grid (torch.Tensor): Latitude-longitude grid of dimensions `(B, 2, H, W)`
patch_dims (tuple): Patch dimensions. Different x- and y-values are supported.
pos_expansion (:class:`.FourierExpansion`): Fourier expansion for the latitudes and
longitudes.
scale_expansion (:class:`.FourierExpansion`): Fourier expansion for the patch areas.
encode_dim (int): Output encoding dimension `D`. Must be a multiple of four: splits
across latitudes and longitudes and across sines and cosines.
grid (torch.Tensor): Latitude-longitude grid of dimensions `(B, 2, H, W)`. `grid[:, 0]`
should be the latitudes of `grid[:, 1]` should be the longitudes.
patch_dims (tuple): Patch dimensions. Different x-values and y-values are supported.
pos_expansion (:class:`aurora.model.fourier.FourierExpansion`): Fourier expansion for the
latitudes and longitudes.
scale_expansion (:class:`aurora.model.fourier.FourierExpansion`): Fourier expansion for the
patch areas.
Returns:
tuple[torch.Tensor, torch.Tensor]: Returns positional encoding tensor and scale tensor of
shape `(B, H/patch[0] * W/patch[1], D)`.
tuple[torch.Tensor, torch.Tensor]: Positional encoding and scale encoding of shape
`(B, H/patch[0] * W/patch[1], D)`.
"""
# encode_dim has to be % 4 (lat-lon split, sin-cosine split)
assert encode_dim % 4 == 0
assert grid.dim() == 4

# Take the 2D pooled values of the mesh - this is the same as subsequent 1D pooling over x and
# y axis.
# Take the 2D pooled values of the mesh. This is the same as subsequent 1D pooling over the
# x-axis and then ove the y-axis.
grid_h = F.avg_pool2d(grid[:, 0], patch_dims)
grid_w = F.avg_pool2d(grid[:, 1], patch_dims)

# get min and max values for x and y coordinates to calculate the diagonal of each patch
# Compute the square root of the area of the patches specified by the latitude-longitude
# grid.
grid_lat_max = F.max_pool2d(grid[:, 0], patch_dims)
grid_lat_min = -F.max_pool2d(-grid[:, 0], patch_dims)
grid_lon_max = F.max_pool2d(grid[:, 1], patch_dims)
grid_lon_min = -F.max_pool2d(-grid[:, 1], patch_dims)
root_area_on_sphere = get_root_area_on_sphere(
grid_lat_min, grid_lon_min, grid_lat_max, grid_lon_max
)
root_area = patch_root_area(grid_lat_min, grid_lon_min, grid_lat_max, grid_lon_max)

# Use half of dimensions for the latitudes of the midpoints of the patches and the other
# half for the longitudes. Before computing the encodings, flatten over the spatial dimensions.
B = grid_h.shape[0]
encode_h = pos_expansion(grid_h.reshape(B, -1), encode_dim // 2) # (B, L, D/2)
encode_w = pos_expansion(grid_w.reshape(B, -1), encode_dim // 2) # (B, L, D/2)
pos_encode = torch.cat((encode_h, encode_w), axis=-1) # (B, L, D)

# No need to split things up for the scale encoding.
scale_encode = scale_expansion(root_area.reshape(B, -1), encode_dim) # (B, L, D)

# use half of dimensions to encode grid_h
# (B, H, W) -> (B, H*W)
encode_h = pos_expansion(
grid_h.reshape(grid_h.shape[0], -1), encode_dim // 2
) # (B, H*W/patch**2, D/2)
# (B, H, W) -> (B, H*W)
encode_w = pos_expansion(
grid_w.reshape(grid_w.shape[0], -1), encode_dim // 2
) # (B, H*W/patch**2, D/2)

# use all dimensions to encode scale
# (B, H, W) -> (B, H*W)
scale_encode = scale_expansion(
root_area_on_sphere.reshape(root_area_on_sphere.shape[0], -1), encode_dim
) # (B, H*W/patch**2, D)

pos_encode = torch.cat((encode_h, encode_w), axis=-1) # (B, H*W/patch**2, D)
return pos_encode, scale_encode


def get_lat_lon_grid(lat: torch.Tensor, lon: torch.Tensor) -> torch.Tensor:
"""Return meshgrid of latitude and longitude coordinates.
def lat_lon_meshgrid(lat: torch.Tensor, lon: torch.Tensor) -> torch.Tensor:
"""Construct a meshgrid of latitude and longitude coordinates.
`torch.meshgrid(*tensors, indexing='xy')` gives the same behavior as calling
`numpy.meshgrid(*arrays, indexing='ij')`::
`torch.meshgrid(*tensors, indexing="xy")` gives the same behavior as calling
`numpy.meshgrid(*arrays, indexing="ij")`::
lat = torch.tensor([1, 2, 3])
lon = torch.tensor([4, 5, 6])
Expand All @@ -120,22 +126,23 @@ def get_lat_lon_grid(lat: torch.Tensor, lon: torch.Tensor) -> torch.Tensor:
grid_y = tensor([[4, 4, 4], [5, 5, ,5], [6, 6, 6]])
Args:
lat (torch.Tensor): 1D tensor of latitude values
lon (torch.Tensor): 1D tensor of longitude values
lat (torch.Tensor): Vector of latitudes.
lon (torch.Tensor): Vector of longitudes.
Returns:
torch.Tensor: Meshgrid of shape `(2, lat.shape, lon.shape)`
torch.Tensor: Meshgrid of shape `(2, len(lat), len(lon))`.
"""
assert lat.dim() == 1
assert lon.dim() == 1

grid = torch.meshgrid(lat, lon, indexing="xy")
grid = torch.stack(grid, axis=0)
grid = grid.permute(0, 2, 1)

return grid


def get_2d_patched_lat_lon_encode(
def pos_scale_enc(
encode_dim: int,
lat: torch.Tensor,
lon: torch.Tensor,
Expand All @@ -145,20 +152,25 @@ def get_2d_patched_lat_lon_encode(
) -> torch.Tensor:
"""Positional encoding of latitude-longitude data.
Does not support batch dimensions in the input and does not return batch dimensions either.
Args:
encode_dim (int): Output encoding dimension `D`.
lat (torch.Tensor): Tensor of latitude values `H`.
lon (torch.Tensor): Tensor of longitude values `W`.
patch_dims (Union[list, tuple]): Patch dimensions. Different x- and y-values are supported.
pos_expansion (:class:`.FourierExpansion`): Fourier expansion for the latitudes and
longitudes.
scale_expansion (:class:`.FourierExpansion`): Fourier expansion for the patch areas.
lat (torch.Tensor): Latitudes, `H`. Can be either a vector or a matrix.
lon (torch.Tensor): Longitudes, `W`. Can be either a vector or a matrix.
patch_dims (Union[list, tuple]): Patch dimensions. Different x-values and y-values are
supported.
pos_expansion (:class:`aurora.model.fourier.FourierExpansion`): Fourier expansion for the
latitudes and longitudes.
scale_expansion (:class:`aurora.model.fourier.FourierExpansion`): Fourier expansion for the
patch areas.
Returns:
torch.Tensor: Returns positional encoding tensor of shape `(H/patch[0] * W/patch[1], D)`.
tuple[torch.Tensor, torch.Tensor]: Positional encoding and scale encoding of shape
`(H/patch[0] * W/patch[1], D)`.
"""
if lat.dim() == lon.dim() == 1:
grid = get_lat_lon_grid(lat, lon)
grid = lat_lon_meshgrid(lat, lon)
elif lat.dim() == lon.dim() == 2:
grid = torch.stack((lat, lon), dim=0)
else:
Expand All @@ -169,12 +181,12 @@ def get_2d_patched_lat_lon_encode(

grid = grid[None] # Add batch dimension.

pos_encode, scale_encode = get_2d_patched_lat_lon_from_grid(
pos_encoding, scale_encoding = pos_scale_enc_grid(
encode_dim,
grid,
to_2tuple(patch_dims),
pos_expansion=pos_expansion,
scale_expansion=scale_expansion,
)

return pos_encode.squeeze(0), scale_encode.squeeze(0) # Return without batch dimension.
return pos_encoding.squeeze(0), scale_encoding.squeeze(0) # Return without batch dimension.

0 comments on commit 7c31be6

Please sign in to comment.