diff --git a/aurora/model/aurora.py b/aurora/model/aurora.py
index 0cbbf43..9fae270 100644
--- a/aurora/model/aurora.py
+++ b/aurora/model/aurora.py
@@ -1,17 +1,21 @@
 """Copyright (c) Microsoft Corporation. Licensed under the MIT license."""
 
+import contextlib
 import dataclasses
 from datetime import timedelta
 from functools import partial
 
 import torch
 from huggingface_hub import hf_hub_download
+from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
+    apply_activation_checkpointing,
+)
 
 from aurora.batch import Batch
 from aurora.model.decoder import Perceiver3DDecoder
 from aurora.model.encoder import Perceiver3DEncoder
 from aurora.model.lora import LoRAMode
-from aurora.model.swin3d import Swin3DTransformerBackbone
+from aurora.model.swin3d import BasicLayer3D, Swin3DTransformerBackbone
 
 __all__ = ["Aurora", "AuroraSmall", "AuroraHighRes"]
 
@@ -47,6 +51,7 @@ def __init__(
         use_lora: bool = True,
         lora_steps: int = 40,
         lora_mode: LoRAMode = "single",
+        autocast: bool = False,
     ) -> None:
         """Construct an instance of the model.
 
@@ -92,11 +97,14 @@ def __init__(
             lora_mode (str, optional): LoRA mode. `"single"` uses the same LoRA for all roll-out
                 steps, and `"all"` uses a different LoRA for every roll-out step. Defaults to
                 `"single"`.
+            autocast (bool, optional): Use `torch.autocast` to reduce memory usage. Defaults to
+                `False`.
         """
         super().__init__()
         self.surf_vars = surf_vars
         self.atmos_vars = atmos_vars
         self.patch_size = patch_size
+        self.autocast = autocast
 
         self.encoder = Perceiver3DEncoder(
             surf_vars=surf_vars,
@@ -181,12 +189,13 @@ def forward(self, batch: Batch) -> Batch:
             batch,
             lead_time=timedelta(hours=6),
         )
-        x = self.backbone(
-            x,
-            lead_time=timedelta(hours=6),
-            patch_res=patch_res,
-            rollout_step=batch.metadata.rollout_step,
-        )
+        with torch.autocast(device_type="cuda") if self.autocast else contextlib.nullcontext():
+            x = self.backbone(
+                x,
+                lead_time=timedelta(hours=6),
+                patch_res=patch_res,
+                rollout_step=batch.metadata.rollout_step,
+            )
         pred = self.decoder(
             x,
             batch,
@@ -297,6 +306,13 @@ def load_checkpoint_local(self, path: str, strict: bool = True) -> None:
 
         self.load_state_dict(d, strict=strict)
 
+    def configure_activation_checkpointing(self):
+        """Configure activation checkpointing.
+
+        This is required in order to compute gradients without running out of memory.
+        """
+        apply_activation_checkpointing(self, check_fn=lambda x: isinstance(x, BasicLayer3D))
+
 
 AuroraSmall = partial(
     Aurora,
diff --git a/docs/finetuning.md b/docs/finetuning.md
index 55f41ff..8bdda7b 100644
--- a/docs/finetuning.md
+++ b/docs/finetuning.md
@@ -10,6 +10,33 @@ model = Aurora(use_lora=False)  # Model is not fine-tuned.
 model.load_checkpoint("microsoft/aurora", "aurora-0.25-pretrained.ckpt")
 ```
 
+## Computing Gradients
+
+To compute gradients, you will need an A100 with 80 GB of memory.
+In addition, you will need to use [PyTorch AMP](https://pytorch.org/docs/stable/amp.html)
+and gradient checkpointing.
+You can do this as follows:
+
+```python
+from aurora import Aurora
+
+model = Aurora(
+    use_lora=False,  # Model was not fine-tuned.
+    autocast=True,  # Use AMP.
+)
+model.load_checkpoint("microsoft/aurora", "aurora-0.25-pretrained.ckpt")
+
+batch = ...  # Load some data.
+
+model = model.cuda()
+model.train()
+model.configure_activation_checkpointing()
+
+pred = model.forward(batch)
+loss = ...
+loss.backward()
+```
+
 ## Extending Aurora with New Variables
 
 Aurora can be extended with new variables by adjusting the keyword arguments `surf_vars`,
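
For context when reviewing: below is a minimal sketch of how the `autocast` flag and `configure_activation_checkpointing()` added in this diff might fit into a single training step. The `AdamW` optimiser, the learning rate, the `target` batch, and the mean-absolute-error loss on `2t` are hypothetical placeholders chosen for illustration; the data loading and the real training objective remain elided, exactly as in the documentation above.

```python
import torch

from aurora import Aurora, Batch

model = Aurora(
    use_lora=False,  # Model was not fine-tuned.
    autocast=True,   # Run the backbone under AMP, as added in this diff.
)
model.load_checkpoint("microsoft/aurora", "aurora-0.25-pretrained.ckpt")

model = model.cuda()
model.train()
model.configure_activation_checkpointing()  # Checkpoint the Swin3D layers.

# Hypothetical optimiser and learning rate, for illustration only.
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)

batch: Batch = ...   # Load some data, as in the documentation above.
target: Batch = ...  # Verifying fields for the predicted time step.

optimizer.zero_grad()
pred = model.forward(batch)

# Hypothetical loss: mean absolute error on 2 m temperature only. A real
# objective would cover all surface and atmospheric variables.
loss = (pred.surf_vars["2t"] - target.surf_vars["2t"].cuda()).abs().mean()

loss.backward()
optimizer.step()
```

The point of the sketch is the division of labour: `autocast=True` lowers the memory cost of the forward pass through the backbone, while activation checkpointing on `BasicLayer3D` keeps the backward pass within the 80 GB budget mentioned in the documentation.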