Rework ID-based parametrisation (#27)
* Rework ID-based param

* Update docs

* Test decoder init

* Update docs/finetuning.md

Co-authored-by: Ana Lucic <[email protected]>

---------

Co-authored-by: Ana Lucic <[email protected]>
wesselb and a-lucic authored Sep 11, 2024
1 parent 8c2f0c2 commit b65b87d
Showing 8 changed files with 151 additions and 101 deletions.
63 changes: 56 additions & 7 deletions aurora/model/aurora.py
@@ -52,14 +52,11 @@ def __init__(
Args:
surf_vars (tuple[str, ...], optional): All surface-level variables supported by the
model. The model is sensitive to the order of `surf_vars`! Currently, adding
one more variable here causes the model to incorrectly load the static variables.
It is possible to hack around this. We are working on a more principled fix. Please
open an issue if this is a problem for you.
model.
static_vars (tuple[str, ...], optional): All static variables supported by the
model. The model is sensitive to the order of `static_vars`!
model.
atmos_vars (tuple[str, ...], optional): All atmospheric variables supported by the
model. The model is sensitive to the order of `atmos-vars`!
model.
window_size (tuple[int, int, int], optional): Vertical height, height, and width of the
window of the underlying Swin transformer.
encoder_depths (tuple[int, ...], optional): Number of blocks in each encoder layer.
@@ -240,12 +237,64 @@ def load_checkpoint_local(self, path: str, strict: bool = True) -> None:

d = torch.load(path, map_location=device, weights_only=True)

# Rename keys to ensure compatibility.
# You can safely ignore the cumbersome processing below. We modified the model after we
# trained it, and the code below manually adapts old checkpoints so that they remain
# compatible with the new model.

# Remove a possible prefix from the keys.
for k, v in list(d.items()):
if k.startswith("net."):
del d[k]
d[k[4:]] = v

# Convert the ID-based parametrisation to a name-based parametrisation.

if "encoder.surf_token_embeds.weight" in d:
weight = d["encoder.surf_token_embeds.weight"]
del d["encoder.surf_token_embeds.weight"]

assert weight.shape[1] == 4 + 3
for i, name in enumerate(("2t", "10u", "10v", "msl", "lsm", "z", "slt")):
d[f"encoder.surf_token_embeds.weights.{name}"] = weight[:, [i]]

if "encoder.atmos_token_embeds.weight" in d:
weight = d["encoder.atmos_token_embeds.weight"]
del d["encoder.atmos_token_embeds.weight"]

assert weight.shape[1] == 5
for i, name in enumerate(("z", "u", "v", "t", "q")):
d[f"encoder.atmos_token_embeds.weights.{name}"] = weight[:, [i]]

if "decoder.surf_head.weight" in d:
weight = d["decoder.surf_head.weight"]
bias = d["decoder.surf_head.bias"]
del d["decoder.surf_head.weight"]
del d["decoder.surf_head.bias"]

assert weight.shape[0] == 4 * self.patch_size**2
assert bias.shape[0] == 4 * self.patch_size**2
weight = weight.reshape(self.patch_size**2, 4, -1)
bias = bias.reshape(self.patch_size**2, 4)

for i, name in enumerate(("2t", "10u", "10v", "msl")):
d[f"decoder.surf_heads.{name}.weight"] = weight[:, i]
d[f"decoder.surf_heads.{name}.bias"] = bias[:, i]

if "decoder.atmos_head.weight" in d:
weight = d["decoder.atmos_head.weight"]
bias = d["decoder.atmos_head.bias"]
del d["decoder.atmos_head.weight"]
del d["decoder.atmos_head.bias"]

assert weight.shape[0] == 5 * self.patch_size**2
assert bias.shape[0] == 5 * self.patch_size**2
weight = weight.reshape(self.patch_size**2, 5, -1)
bias = bias.reshape(self.patch_size**2, 5)

for i, name in enumerate(("z", "u", "v", "t", "q")):
d[f"decoder.atmos_heads.{name}.weight"] = weight[:, i]
d[f"decoder.atmos_heads.{name}.bias"] = bias[:, i]

self.load_state_dict(d, strict=strict)


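For readers who need to adapt old checkpoints by hand, the following is a minimal sketch of the same ID-to-name renaming for the surface patch embedding, assuming `d` is a bare state dict with the old keys; the helper name `convert_surf_embeds` and the commented usage are illustrative, not part of the commit.

import torch

SURF_ORDER = ("2t", "10u", "10v", "msl", "lsm", "z", "slt")  # ordering hard-coded above

def convert_surf_embeds(d: dict) -> dict:
    key = "encoder.surf_token_embeds.weight"
    if key in d:
        weight = d.pop(key)  # (D, V, T, P, P) with V == 4 + 3
        assert weight.shape[1] == len(SURF_ORDER)
        for i, name in enumerate(SURF_ORDER):
            # `[i]` rather than `i` keeps the singleton variable dimension.
            d[f"encoder.surf_token_embeds.weights.{name}"] = weight[:, [i]]
    return d

# d = convert_surf_embeds(torch.load(path, map_location="cpu", weights_only=True))
# model.load_state_dict(d, strict=False)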
27 changes: 13 additions & 14 deletions aurora/model/decoder.py
@@ -11,8 +11,6 @@
from aurora.model.perceiver import PerceiverResampler
from aurora.model.util import (
check_lat_lon_dtype,
create_var_map,
get_ids_for_var_map,
init_weights,
unpatchify,
)
@@ -60,8 +58,6 @@ def __init__(
self.patch_size = patch_size
self.surf_vars = surf_vars
self.atmos_vars = atmos_vars
self.surf_var_map = create_var_map(surf_vars)
self.atmos_var_map = create_var_map(atmos_vars)
self.embed_dim = embed_dim

self.level_decoder = PerceiverResampler(
@@ -76,8 +72,12 @@ def __init__(
ln_eps=perceiver_ln_eps,
)

self.surf_head = nn.Linear(embed_dim, len(surf_vars) * patch_size**2)
self.atmos_head = nn.Linear(embed_dim, len(atmos_vars) * patch_size**2)
self.surf_heads = nn.ParameterDict(
{name: nn.Linear(embed_dim, patch_size**2) for name in surf_vars}
)
self.atmos_heads = nn.ParameterDict(
{name: nn.Linear(embed_dim, patch_size**2) for name in atmos_vars}
)

self.atmos_levels_embed = nn.Linear(embed_dim, embed_dim)

@@ -145,10 +145,10 @@ def forward(
W=patch_res[2],
)

# Decode surface vars.
x_surf = self.surf_head(x[..., :1, :]) # (B, L, 1, V_S*p*p)
surf_var_ids = get_ids_for_var_map(surf_vars, self.surf_var_map, x_surf.device)
surf_preds = unpatchify(x_surf, len(self.surf_vars), H, W, self.patch_size)[:, surf_var_ids]
# Decode surface vars. Run the head for every surface-level variable.
x_surf = torch.stack([self.surf_heads[name](x[..., :1, :]) for name in surf_vars], dim=-1)
x_surf = x_surf.reshape(*x_surf.shape[:3], -1) # (B, L, 1, V_S*p*p)
surf_preds = unpatchify(x_surf, len(surf_vars), H, W, self.patch_size)
surf_preds = surf_preds.squeeze(2) # (B, V_S, H, W)

# Embed the atmospheric levels.
@@ -162,10 +162,9 @@ def forward(
x_atmos = self.deaggregate_levels(levels_embed, x[..., 1:, :]) # (B, L, C_A, D)

# Decode the atmospheric vars.
x_atmos = self.atmos_head(x_atmos) # (B, L, C_A, V_A*p*p)
atmos_var_ids = get_ids_for_var_map(atmos_vars, self.atmos_var_map, x.device)
atmos_preds = unpatchify(x_atmos, len(self.atmos_vars), H, W, self.patch_size)
atmos_preds = atmos_preds[:, atmos_var_ids]
x_atmos = torch.stack([self.atmos_heads[name](x_atmos) for name in atmos_vars], dim=-1)
x_atmos = x_atmos.reshape(*x_atmos.shape[:3], -1) # (B, L, C_A, V_A*p*p)
atmos_preds = unpatchify(x_atmos, len(atmos_vars), H, W, self.patch_size)

return Batch(
{v: surf_preds[:, i] for i, v in enumerate(surf_vars)},
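The decoder change above swaps one wide linear head for a dictionary of per-variable heads. A self-contained sketch of that pattern with made-up dimensions follows; it uses `nn.ModuleDict` as the container, whereas the commit keeps the heads in an `nn.ParameterDict`.

import torch
from torch import nn

embed_dim, patch_size = 8, 4
surf_vars = ("2t", "10u", "10v", "msl")

# One small linear head per variable instead of a single wide head.
surf_heads = nn.ModuleDict({name: nn.Linear(embed_dim, patch_size**2) for name in surf_vars})

x = torch.randn(2, 10, 1, embed_dim)  # (B, L, 1, D)
# Run every head, stack along a trailing variable axis, and flatten, giving the
# same (B, L, 1, V_S*p*p) shape that the old single head produced.
x_surf = torch.stack([surf_heads[name](x) for name in surf_vars], dim=-1)
x_surf = x_surf.reshape(*x_surf.shape[:3], -1)
print(x_surf.shape)  # torch.Size([2, 10, 1, 64])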
20 changes: 10 additions & 10 deletions aurora/model/encoder.py
@@ -19,8 +19,6 @@
from aurora.model.posencoding import pos_scale_enc
from aurora.model.util import (
check_lat_lon_dtype,
create_var_map,
get_ids_for_var_map,
init_weights,
)

@@ -78,8 +76,6 @@ def __init__(

# We treat the static variables as surface variables in the model.
surf_vars = surf_vars + static_vars if static_vars is not None else surf_vars
self.surf_var_map = create_var_map(surf_vars)
self.atmos_var_map = create_var_map(atmos_vars)

# Latent tokens
assert latent_levels > 1, "At least two latent levels are required."
@@ -102,10 +98,16 @@ def __init__(
# Patch embeddings
assert max_history_size > 0, "At least one history step is required."
self.surf_token_embeds = LevelPatchEmbed(
len(surf_vars), patch_size, embed_dim, max_history_size
surf_vars,
patch_size,
embed_dim,
max_history_size,
)
self.atmos_token_embeds = LevelPatchEmbed(
len(atmos_vars), patch_size, embed_dim, max_history_size
atmos_vars,
patch_size,
embed_dim,
max_history_size,
)

# Learnable pressure level aggregation
@@ -194,14 +196,12 @@ def forward(self, batch: Batch, lead_time: timedelta) -> torch.Tensor:

# Patch embed the surface level.
x_surf = rearrange(x_surf, "b t v h w -> b v t h w")
surf_ids = get_ids_for_var_map(surf_vars, self.surf_var_map, x_surf.device)
x_surf = self.surf_token_embeds(x_surf, surf_ids) # (B, L, D)
x_surf = self.surf_token_embeds(x_surf, surf_vars) # (B, L, D)
dtype = x_surf.dtype # When using mixed precision, we need to keep track of the dtype.

# Patch embed the atmospheric levels.
atmos_ids = get_ids_for_var_map(atmos_vars, self.atmos_var_map, x_atmos.device)
x_atmos = rearrange(x_atmos, "b t v c h w -> (b c) v t h w")
x_atmos = self.atmos_token_embeds(x_atmos, atmos_ids)
x_atmos = self.atmos_token_embeds(x_atmos, atmos_vars)
x_atmos = rearrange(x_atmos, "(b c) l d -> b c l d", b=B, c=C)

# Add surface level encoding. This helps the model distinguish between surface and
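The encoder now passes variable names, rather than integer IDs, to the patch embedding. A toy illustration of why name-keyed weights remove the ordering sensitivity flagged in the old docstrings follows; all shapes and variable names are made up.

import torch

embed_dim, patch_size, history = 4, 2, 1

# Old scheme (sketch): one weight tensor indexed by position. Inserting a new
# variable shifts the indices of everything after it, so old checkpoints end
# up pairing weights with the wrong variables.
old_weight = torch.randn(embed_dim, 3, history, patch_size, patch_size)  # vars 0, 1, 2

# New scheme (sketch): one weight per name. Adding a variable leaves existing
# entries untouched, and lookups no longer depend on position.
weights = {
    name: torch.randn(embed_dim, 1, history, patch_size, patch_size)
    for name in ("2t", "10u", "10v")
}
weights["new_var"] = torch.randn(embed_dim, 1, history, patch_size, patch_size)
selected = torch.cat([weights[name] for name in ("2t", "10v")], dim=1)
print(selected.shape)  # torch.Size([4, 2, 1, 2, 2])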
44 changes: 26 additions & 18 deletions aurora/model/patchembed.py
@@ -17,7 +17,7 @@ class LevelPatchEmbed(nn.Module):

def __init__(
self,
max_vars: int,
var_names: tuple[str, ...],
patch_size: int,
embed_dim: int,
history_size: int = 1,
@@ -27,7 +27,7 @@ def __init__(
"""Initialise.
Args:
max_vars (int): Maximum number of variables to embed.
var_names (tuple[str, ...]): Variables to embed.
patch_size (int): Patch size.
embed_dim (int): Embedding dimensionality.
history_size (int, optional): Number of history dimensions. Defaults to `1`.
@@ -38,18 +38,19 @@ def __init__(
"""
super().__init__()

self.max_vars = max_vars
self.var_names = var_names
self.kernel_size = (history_size,) + to_2tuple(patch_size)
self.flatten = flatten
self.embed_dim = embed_dim

weight = torch.cat(
# Shape (C_out, C_in, T, H, W). `C_in = 1` here because we're embedding every variable
# separately.
[torch.empty(embed_dim, 1, *self.kernel_size) for _ in range(max_vars)],
dim=1,
self.weights = nn.ParameterDict(
{
# Shape (C_out, C_in, T, H, W). `C_in = 1` here because we're embedding every
# variable separately.
name: nn.Parameter(torch.empty(embed_dim, 1, *self.kernel_size))
for name in var_names
}
)
self.weight = nn.Parameter(weight)
self.bias = nn.Parameter(torch.empty(embed_dim))
self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()

@@ -63,40 +64,47 @@ def init_weights(self) -> None:
#
# https://github.com/pytorch/pytorch/issues/15314#issuecomment-477448573
#
nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
for weight in self.weights.values():
nn.init.kaiming_uniform_(weight, a=math.sqrt(5))

# The following initialisation is taken from
#
# https://pytorch.org/docs/stable/_modules/torch/nn/modules/conv.html#Conv3d
#
fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight)
fan_in, _ = nn.init._calculate_fan_in_and_fan_out(next(iter(self.weights.values())))
if fan_in != 0:
bound = 1 / math.sqrt(fan_in)
nn.init.uniform_(self.bias, -bound, bound)

def forward(self, x: torch.Tensor, var_ids: list[int]) -> torch.Tensor:
def forward(self, x: torch.Tensor, var_names: tuple[str, ...]) -> torch.Tensor:
"""Run the embedding.
Args:
x (:class:`torch.Tensor`): Tensor to embed of a shape of `(B, V, T, H, W)`.
var_ids (list[int]): A list of variable IDs. The length should be equal to `V`.
var_names (tuple[str, ...]): Names of the variables in `x`. The length should be equal
to `V`.
Returns:
:class:`torch.Tensor`: Embedded tensor of shape `(B, L, D)` if flattened,
where `L = H * W / P^2`. Otherwise, the shape is `(B, D, H', W')`.
"""
B, V, T, H, W = x.shape
assert len(var_ids) == V, f"{V} != {len(var_ids)}."
assert len(var_names) == V, f"{V} != {len(var_names)}."
assert self.kernel_size[0] >= T, f"{T} > {self.kernel_size[0]}."
assert H % self.kernel_size[1] == 0, f"{H} % {self.kernel_size[1]} != 0."
assert W % self.kernel_size[2] == 0, f"{W} % {self.kernel_size[2]} != 0."
assert max(var_ids) < self.max_vars, f"{max(var_ids)} >= {self.max_vars}."
assert min(var_ids) >= 0, f"{min(var_ids)} < 0."
assert len(set(var_ids)) == len(var_ids), f"{var_ids} contains duplicates."
assert len(set(var_names)) == len(var_names), f"{var_names} contains duplicates."

# Select the weights of the variables and history dimensions that are present in the batch.
weight = self.weight[:, var_ids, :T, ...] # (C_out, C_in, T, H, W)
weight = torch.cat(
[
# (C_out, C_in, T, H, W)
self.weights[name][:, :, :T, ...]
for name in var_names
],
dim=1,
)
# Adjust the stride if history is smaller than maximum.
stride = (T,) + self.kernel_size[1:]

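A runnable sketch of the name-keyed embedding forward pass follows; the final `F.conv3d` call lies below the visible part of this hunk and is assumed here, and all dimensions are illustrative.

import math
import torch
import torch.nn.functional as F

embed_dim, patch_size, max_history = 4, 2, 2
var_names = ("2t", "10u")

weights = {
    name: torch.empty(embed_dim, 1, max_history, patch_size, patch_size)
    for name in var_names
}
for w in weights.values():
    torch.nn.init.kaiming_uniform_(w, a=math.sqrt(5))
bias = torch.zeros(embed_dim)

B, T, H, W = 1, 1, 4, 4
x = torch.randn(B, len(var_names), T, H, W)

# Keep only the first T history slices of each per-variable weight, then
# concatenate along the input-channel dimension, as in the forward pass above.
weight = torch.cat([weights[name][:, :, :T] for name in var_names], dim=1)
stride = (T, patch_size, patch_size)
out = F.conv3d(x, weight, bias, stride=stride)
print(out.shape)  # torch.Size([1, 4, 1, 2, 2])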
33 changes: 0 additions & 33 deletions aurora/model/util.py
@@ -9,8 +9,6 @@

__all__ = [
"unpatchify",
"create_var_map",
"get_ids_for_var_map",
"check_lat_lon_dtype",
"maybe_adjust_windows",
"init_weights",
@@ -43,37 +41,6 @@ def unpatchify(x: torch.Tensor, V: int, H: int, W: int, P: int) -> torch.Tensor:
return x


def create_var_map(variables: tuple[str, ...]) -> dict[str, int]:
"""Create dictionary where the keys are variable names and values are unique IDs.
Args:
variables (tuple[str, ...]): Variable strings.
Returns:
dict[str, int]: Variable map dictionary.
"""
return {v: i for i, v in enumerate(variables)}


def get_ids_for_var_map(
variables: tuple,
var_maps: dict,
device: torch.cuda.device,
) -> torch.Tensor:
"""Construct a tensor of variable IDs after retrieving those from a variable map created with
:func:`.create_var_map`.
Args:
variables (tuples[str, ...]): Variables to retrieve the IDs for.
var_maps (dict[str, int]): Variable map constructed with :func:`.create_var_map`.
device (torch.cuda.device): Device.
Returns:
torch.Tensor: Tensor of variable IDs found in `var_map`.
"""
return torch.tensor([var_maps[v] for v in variables], device=device)


def check_lat_lon_dtype(lat: torch.Tensor, lon: torch.Tensor) -> None:
"""Assert that `lat` and `lon` are at least `float32`s."""
assert lat.dtype in [torch.float32, torch.float64], f"Latitude num. unstable: {lat.dtype}."
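For reference, the two deleted helpers amounted to the following; the snippet is only meant to show what the old ID-based lookup did.

import torch

variables = ("2t", "10u", "10v")
var_map = {v: i for i, v in enumerate(variables)}                      # create_var_map
ids = torch.tensor([var_map[v] for v in ("10v", "2t")], device="cpu")  # get_ids_for_var_map
print(ids)  # tensor([2, 0])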
11 changes: 0 additions & 11 deletions docs/beware.md
@@ -42,14 +42,3 @@ If you changed the model and added or removed parameters, you need to set `stric
loading a checkpoint `Aurora.load_checkpoint(..., strict=False)`.
Importantly, enabling or disabling LoRA for a model that was trained respectively without or
with LoRA changes the parameters!

## Extending the Model with New Surface-Level Variables

Whereas we have attempted to design a robust and flexible model,
inevitably some unfortunate design choices slipped through.

A notable unfortunate design choice is that extending the model with a new surface-level
variable breaks compatibility with existing checkpoints.
It is possible to hack around this in a relatively simple way.
We are working on a more principled fix.
Please open an issue if this is a problem for you.