diff --git a/aurora/model/aurora.py b/aurora/model/aurora.py index f734137..0cbbf43 100644 --- a/aurora/model/aurora.py +++ b/aurora/model/aurora.py @@ -52,14 +52,11 @@ def __init__( Args: surf_vars (tuple[str, ...], optional): All surface-level variables supported by the - model. The model is sensitive to the order of `surf_vars`! Currently, adding - one more variable here causes the model to incorrectly load the static variables. - It is possible to hack around this. We are working on a more principled fix. Please - open an issue if this is a problem for you. + model. static_vars (tuple[str, ...], optional): All static variables supported by the - model. The model is sensitive to the order of `static_vars`! + model. atmos_vars (tuple[str, ...], optional): All atmospheric variables supported by the - model. The model is sensitive to the order of `atmos-vars`! + model. window_size (tuple[int, int, int], optional): Vertical height, height, and width of the window of the underlying Swin transformer. encoder_depths (tuple[int, ...], optional): Number of blocks in each encoder layer. @@ -240,12 +237,64 @@ def load_checkpoint_local(self, path: str, strict: bool = True) -> None: d = torch.load(path, map_location=device, weights_only=True) - # Rename keys to ensure compatibility. + # You can safely ignore all cumbersome processing below. We modified the model after we + # trained it. The code below manually adapts the checkpoints, so the checkpoints are + # compatible with the new model. + + # Remove possibly prefix from the keys. for k, v in list(d.items()): if k.startswith("net."): del d[k] d[k[4:]] = v + # Convert the ID-based parametrisation to a name-based parametrisation. + + if "encoder.surf_token_embeds.weight" in d: + weight = d["encoder.surf_token_embeds.weight"] + del d["encoder.surf_token_embeds.weight"] + + assert weight.shape[1] == 4 + 3 + for i, name in enumerate(("2t", "10u", "10v", "msl", "lsm", "z", "slt")): + d[f"encoder.surf_token_embeds.weights.{name}"] = weight[:, [i]] + + if "encoder.atmos_token_embeds.weight" in d: + weight = d["encoder.atmos_token_embeds.weight"] + del d["encoder.atmos_token_embeds.weight"] + + assert weight.shape[1] == 5 + for i, name in enumerate(("z", "u", "v", "t", "q")): + d[f"encoder.atmos_token_embeds.weights.{name}"] = weight[:, [i]] + + if "decoder.surf_head.weight" in d: + weight = d["decoder.surf_head.weight"] + bias = d["decoder.surf_head.bias"] + del d["decoder.surf_head.weight"] + del d["decoder.surf_head.bias"] + + assert weight.shape[0] == 4 * self.patch_size**2 + assert bias.shape[0] == 4 * self.patch_size**2 + weight = weight.reshape(self.patch_size**2, 4, -1) + bias = bias.reshape(self.patch_size**2, 4) + + for i, name in enumerate(("2t", "10u", "10v", "msl")): + d[f"decoder.surf_heads.{name}.weight"] = weight[:, i] + d[f"decoder.surf_heads.{name}.bias"] = bias[:, i] + + if "decoder.atmos_head.weight" in d: + weight = d["decoder.atmos_head.weight"] + bias = d["decoder.atmos_head.bias"] + del d["decoder.atmos_head.weight"] + del d["decoder.atmos_head.bias"] + + assert weight.shape[0] == 5 * self.patch_size**2 + assert bias.shape[0] == 5 * self.patch_size**2 + weight = weight.reshape(self.patch_size**2, 5, -1) + bias = bias.reshape(self.patch_size**2, 5) + + for i, name in enumerate(("z", "u", "v", "t", "q")): + d[f"decoder.atmos_heads.{name}.weight"] = weight[:, i] + d[f"decoder.atmos_heads.{name}.bias"] = bias[:, i] + self.load_state_dict(d, strict=strict) diff --git a/aurora/model/decoder.py b/aurora/model/decoder.py index 1a3bc1c..e58679a 100644 --- a/aurora/model/decoder.py +++ b/aurora/model/decoder.py @@ -11,8 +11,6 @@ from aurora.model.perceiver import PerceiverResampler from aurora.model.util import ( check_lat_lon_dtype, - create_var_map, - get_ids_for_var_map, init_weights, unpatchify, ) @@ -60,8 +58,6 @@ def __init__( self.patch_size = patch_size self.surf_vars = surf_vars self.atmos_vars = atmos_vars - self.surf_var_map = create_var_map(surf_vars) - self.atmos_var_map = create_var_map(atmos_vars) self.embed_dim = embed_dim self.level_decoder = PerceiverResampler( @@ -76,8 +72,12 @@ def __init__( ln_eps=perceiver_ln_eps, ) - self.surf_head = nn.Linear(embed_dim, len(surf_vars) * patch_size**2) - self.atmos_head = nn.Linear(embed_dim, len(atmos_vars) * patch_size**2) + self.surf_heads = nn.ParameterDict( + {name: nn.Linear(embed_dim, patch_size**2) for name in surf_vars} + ) + self.atmos_heads = nn.ParameterDict( + {name: nn.Linear(embed_dim, patch_size**2) for name in atmos_vars} + ) self.atmos_levels_embed = nn.Linear(embed_dim, embed_dim) @@ -145,10 +145,10 @@ def forward( W=patch_res[2], ) - # Decode surface vars. - x_surf = self.surf_head(x[..., :1, :]) # (B, L, 1, V_S*p*p) - surf_var_ids = get_ids_for_var_map(surf_vars, self.surf_var_map, x_surf.device) - surf_preds = unpatchify(x_surf, len(self.surf_vars), H, W, self.patch_size)[:, surf_var_ids] + # Decode surface vars. Run the head for every surface-level variable. + x_surf = torch.stack([self.surf_heads[name](x[..., :1, :]) for name in surf_vars], dim=-1) + x_surf = x_surf.reshape(*x_surf.shape[:3], -1) # (B, L, 1, V_S*p*p) + surf_preds = unpatchify(x_surf, len(surf_vars), H, W, self.patch_size) surf_preds = surf_preds.squeeze(2) # (B, V_S, H, W) # Embed the atmospheric levels. @@ -162,10 +162,9 @@ def forward( x_atmos = self.deaggregate_levels(levels_embed, x[..., 1:, :]) # (B, L, C_A, D) # Decode the atmospheric vars. - x_atmos = self.atmos_head(x_atmos) # (B, L, C_A, V_A*p*p) - atmos_var_ids = get_ids_for_var_map(atmos_vars, self.atmos_var_map, x.device) - atmos_preds = unpatchify(x_atmos, len(self.atmos_vars), H, W, self.patch_size) - atmos_preds = atmos_preds[:, atmos_var_ids] + x_atmos = torch.stack([self.atmos_heads[name](x_atmos) for name in atmos_vars], dim=-1) + x_atmos = x_atmos.reshape(*x_atmos.shape[:3], -1) # (B, L, C_A, V_A*p*p) + atmos_preds = unpatchify(x_atmos, len(atmos_vars), H, W, self.patch_size) return Batch( {v: surf_preds[:, i] for i, v in enumerate(surf_vars)}, diff --git a/aurora/model/encoder.py b/aurora/model/encoder.py index 43edd8e..bcbdafa 100644 --- a/aurora/model/encoder.py +++ b/aurora/model/encoder.py @@ -19,8 +19,6 @@ from aurora.model.posencoding import pos_scale_enc from aurora.model.util import ( check_lat_lon_dtype, - create_var_map, - get_ids_for_var_map, init_weights, ) @@ -78,8 +76,6 @@ def __init__( # We treat the static variables as surface variables in the model. surf_vars = surf_vars + static_vars if static_vars is not None else surf_vars - self.surf_var_map = create_var_map(surf_vars) - self.atmos_var_map = create_var_map(atmos_vars) # Latent tokens assert latent_levels > 1, "At least two latent levels are required." @@ -102,10 +98,16 @@ def __init__( # Patch embeddings assert max_history_size > 0, "At least one history step is required." self.surf_token_embeds = LevelPatchEmbed( - len(surf_vars), patch_size, embed_dim, max_history_size + surf_vars, + patch_size, + embed_dim, + max_history_size, ) self.atmos_token_embeds = LevelPatchEmbed( - len(atmos_vars), patch_size, embed_dim, max_history_size + atmos_vars, + patch_size, + embed_dim, + max_history_size, ) # Learnable pressure level aggregation @@ -194,14 +196,12 @@ def forward(self, batch: Batch, lead_time: timedelta) -> torch.Tensor: # Patch embed the surface level. x_surf = rearrange(x_surf, "b t v h w -> b v t h w") - surf_ids = get_ids_for_var_map(surf_vars, self.surf_var_map, x_surf.device) - x_surf = self.surf_token_embeds(x_surf, surf_ids) # (B, L, D) + x_surf = self.surf_token_embeds(x_surf, surf_vars) # (B, L, D) dtype = x_surf.dtype # When using mixed precision, we need to keep track of the dtype. # Patch embed the atmospheric levels. - atmos_ids = get_ids_for_var_map(atmos_vars, self.atmos_var_map, x_atmos.device) x_atmos = rearrange(x_atmos, "b t v c h w -> (b c) v t h w") - x_atmos = self.atmos_token_embeds(x_atmos, atmos_ids) + x_atmos = self.atmos_token_embeds(x_atmos, atmos_vars) x_atmos = rearrange(x_atmos, "(b c) l d -> b c l d", b=B, c=C) # Add surface level encoding. This helps the model distinguish between surface and diff --git a/aurora/model/patchembed.py b/aurora/model/patchembed.py index f58f1e1..2589d08 100644 --- a/aurora/model/patchembed.py +++ b/aurora/model/patchembed.py @@ -17,7 +17,7 @@ class LevelPatchEmbed(nn.Module): def __init__( self, - max_vars: int, + var_names: tuple[str, ...], patch_size: int, embed_dim: int, history_size: int = 1, @@ -27,7 +27,7 @@ def __init__( """Initialise. Args: - max_vars (int): Maximum number of variables to embed. + var_names (tuple[str, ...]): Variables to embed. patch_size (int): Patch size. embed_dim (int): Embedding dimensionality. history_size (int, optional): Number of history dimensions. Defaults to `1`. @@ -38,18 +38,19 @@ def __init__( """ super().__init__() - self.max_vars = max_vars + self.var_names = var_names self.kernel_size = (history_size,) + to_2tuple(patch_size) self.flatten = flatten self.embed_dim = embed_dim - weight = torch.cat( - # Shape (C_out, C_in, T, H, W). `C_in = 1` here because we're embedding every variable - # separately. - [torch.empty(embed_dim, 1, *self.kernel_size) for _ in range(max_vars)], - dim=1, + self.weights = nn.ParameterDict( + { + # Shape (C_out, C_in, T, H, W). `C_in = 1` here because we're embedding every + # variable separately. + name: nn.Parameter(torch.empty(embed_dim, 1, *self.kernel_size)) + for name in var_names + } ) - self.weight = nn.Parameter(weight) self.bias = nn.Parameter(torch.empty(embed_dim)) self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() @@ -63,23 +64,25 @@ def init_weights(self) -> None: # # https://github.com/pytorch/pytorch/issues/15314#issuecomment-477448573 # - nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5)) + for weight in self.weights.values(): + nn.init.kaiming_uniform_(weight, a=math.sqrt(5)) # The following initialisation is taken from # # https://pytorch.org/docs/stable/_modules/torch/nn/modules/conv.html#Conv3d # - fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight) + fan_in, _ = nn.init._calculate_fan_in_and_fan_out(next(iter(self.weights.values()))) if fan_in != 0: bound = 1 / math.sqrt(fan_in) nn.init.uniform_(self.bias, -bound, bound) - def forward(self, x: torch.Tensor, var_ids: list[int]) -> torch.Tensor: + def forward(self, x: torch.Tensor, var_names: tuple[str, ...]) -> torch.Tensor: """Run the embedding. Args: x (:class:`torch.Tensor`): Tensor to embed of a shape of `(B, V, T, H, W)`. - var_ids (list[int]): A list of variable IDs. The length should be equal to `V`. + var_names (tuple[str, ...]): Names of the variables in `x`. The length should be equal + to `V`. Returns: :class:`torch.Tensor`: Embedded tensor a shape of `(B, L, D]) if flattened, @@ -87,16 +90,21 @@ def forward(self, x: torch.Tensor, var_ids: list[int]) -> torch.Tensor: """ B, V, T, H, W = x.shape - assert len(var_ids) == V, f"{V} != {len(var_ids)}." + assert len(var_names) == V, f"{V} != {len(var_names)}." assert self.kernel_size[0] >= T, f"{T} > {self.kernel_size[0]}." assert H % self.kernel_size[1] == 0, f"{H} % {self.kernel_size[0]} != 0." assert W % self.kernel_size[2] == 0, f"{W} % {self.kernel_size[1]} != 0." - assert max(var_ids) < self.max_vars, f"{max(var_ids)} >= {self.max_vars}." - assert min(var_ids) >= 0, f"{min(var_ids)} < 0." - assert len(set(var_ids)) == len(var_ids), f"{var_ids} contains duplicates." + assert len(set(var_names)) == len(var_names), f"{var_names} contains duplicates." # Select the weights of the variables and history dimensions that are present in the batch. - weight = self.weight[:, var_ids, :T, ...] # (C_out, C_in, T, H, W) + weight = torch.cat( + [ + # (C_out, C_in, T, H, W) + self.weights[name][:, :, :T, ...] + for name in var_names + ], + dim=1, + ) # Adjust the stride if history is smaller than maximum. stride = (T,) + self.kernel_size[1:] diff --git a/aurora/model/util.py b/aurora/model/util.py index 914a79b..a1b69c2 100644 --- a/aurora/model/util.py +++ b/aurora/model/util.py @@ -9,8 +9,6 @@ __all__ = [ "unpatchify", - "create_var_map", - "get_ids_for_var_map", "check_lat_lon_dtype", "maybe_adjust_windows", "init_weights", @@ -43,37 +41,6 @@ def unpatchify(x: torch.Tensor, V: int, H: int, W: int, P: int) -> torch.Tensor: return x -def create_var_map(variables: tuple[str, ...]) -> dict[str, int]: - """Create dictionary where the keys are variable names and values are unique IDs. - - Args: - variables (tuple[str, ...]): Variable strings. - - Returns: - dict[str, int]: Variable map dictionary. - """ - return {v: i for i, v in enumerate(variables)} - - -def get_ids_for_var_map( - variables: tuple, - var_maps: dict, - device: torch.cuda.device, -) -> torch.Tensor: - """Construct a tensor of variable IDs after retrieving those from a variable map created with - :func:`.create_var_map`. - - Args: - variables (tuples[str, ...]): Variables to retrieve the IDs for. - var_maps (dict[str, int]): Variable map constructed with :func:`.create_var_map`. - device (torch.cuda.device): Device. - - Returns: - torch.Tensor: Tensor of variable IDs found in `var_map`. - """ - return torch.tensor([var_maps[v] for v in variables], device=device) - - def check_lat_lon_dtype(lat: torch.Tensor, lon: torch.Tensor) -> None: """Assert that `lat` and `lon` are at least `float32`s.""" assert lat.dtype in [torch.float32, torch.float64], f"Latitude num. unstable: {lat.dtype}." diff --git a/docs/beware.md b/docs/beware.md index f42a485..f8235a1 100644 --- a/docs/beware.md +++ b/docs/beware.md @@ -42,14 +42,3 @@ If you changed the model and added or removed parameters, you need to set `stric loading a checkpoint `Aurora.load_checkpoint(..., strict=False)`. Importantly, enabling or disabling LoRA for a model that was trained respectively without or with LoRA changes the parameters! - -## Extending the Model with New Surface-Level Variables - -Whereas we have attempted to design a robust and flexible model, -inevitably some unfortunate design choices slipped through. - -A notable unfortunate design choice is that extending the model with a new surface-level -variable breaks compatibility with existing checkpoints. -It is possible to hack around this in a relatively simple way. -We are working on a more principled fix. -Please open an issue if this is a problem for you. diff --git a/docs/finetuning.md b/docs/finetuning.md index ce043a3..55f41ff 100644 --- a/docs/finetuning.md +++ b/docs/finetuning.md @@ -1,7 +1,7 @@ # Fine-Tuning -If you wish to fine-tune Aurora for you specific application, -you should use the pretrained version: +Generally, if you wish to fine-tune Aurora for a specific application, +you should build on the pretrained version: ```python from aurora import Aurora @@ -10,21 +10,49 @@ model = Aurora(use_lora=False) # Model is not fine-tuned. model.load_checkpoint("microsoft/aurora", "aurora-0.25-pretrained.ckpt") ``` -You are also free to extend the model for your particular use case. -In that case, it might be that you add or remove parameters. +## Extending Aurora with New Variables + +Aurora can be extended with new variables by adjusting the keyword arguments `surf_vars`, +`static_vars`, and `atmos_vars`. +When you add a new variable, you also need to set the normalisation statistics. + +```python +from aurora import Aurora +from aurora.normalisation import locations, scales + +model = Aurora( + use_lora=False, + surf_vars=("2t", "10u", "10v", "msl", "new_surf_var"), + static_vars=("lsm", "z", "slt", "new_static_var"), + atmos_vars=("z", "u", "v", "t", "q", "new_atmos_var"), +) +model.load_checkpoint("microsoft/aurora", "aurora-0.25-pretrained.ckpt") + +# Normalisation means: +locations["new_surf_var"] = 0.0 +locations["new_static_var"] = 0.0 +locations["new_atmos_var"] = 0.0 + +# Normalisation standard deviations: +scales["new_surf_var"] = 1.0 +scales["new_static_var"] = 1.0 +scales["new_atmos_var"] = 1.0 +``` + +## Other Model Extensions + +It is possible to extend to model in any way you like. +If you do this, you will likely add or remove parameters. Then `Aurora.load_checkpoint` will error, because the existing checkpoint now mismatches with the model's parameters. -Simply set `Aurora.load_checkpoint(..., strict=False)`: +Simply set `Aurora.load_checkpoint(..., strict=False)` to ignore the mismatches: ```python from aurora import Aurora - model = Aurora(...) ... # Modify `model`. model.load_checkpoint("microsoft/aurora", "aurora-0.25-pretrained.ckpt", strict=False) ``` - -More instructions coming soon! diff --git a/tests/test_model.py b/tests/test_model.py index 4e6cdb9..c3df873 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -144,3 +144,13 @@ def assert_approx_equality(v_out: np.ndarray, v_ref: np.ndarray, tol: float) -> np.testing.assert_allclose(pred.metadata.lat, test_output["metadata"]["lat"]) assert pred.metadata.atmos_levels == tuple(test_output["metadata"]["atmos_levels"]) assert pred.metadata.time == tuple(test_output["metadata"]["time"]) + + +def test_aurora_small_decoder_init() -> None: + model = AuroraSmall(use_lora=True) + + # Check that the decoder heads are properly initialised. The biases should be zero, but the + # weights shouldn't. + for layer in [*model.decoder.surf_heads.values(), *model.decoder.atmos_heads.values()]: + assert not torch.all(layer.weight == 0) + assert torch.all(layer.bias == 0)