diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index 097fb18..23c8ae4 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -13,6 +13,7 @@ jobs: fail-fast: false matrix: version: ["3.10", "3.11"] + environment: huggingface-access name: Test with Python ${{ matrix.version }} steps: @@ -25,6 +26,9 @@ jobs: run: | python -m pip install --upgrade pip python -m pip install --upgrade --no-cache-dir -e '.[dev]' + - name: Log into HugginFace + run: | + huggingface-cli login --token ${{ secrets.HUGGINGFACE_TOKEN }} - name: Run tests run: | pytest -v --cov=aurora --cov-report term-missing diff --git a/README.md b/README.md index a10e29c..4f31065 100644 --- a/README.md +++ b/README.md @@ -57,7 +57,7 @@ Install with `pip`: pip install microsoft-aurora ``` -Run an untrained small model on random data: +Run the pretrained small model on random data: ```python import torch @@ -66,6 +66,8 @@ from aurora import AuroraSmall, Batch, Metadata model = AuroraSmall() +model.load_checkpoint("wbruinsma/aurora", "aurora-0.25-small-pretrained.ckpt") + batch = Batch( surf_vars={k: torch.randn(1, 2, 16, 32) for k in ("2t", "10u", "10v", "msl")}, static_vars={k: torch.randn(1, 2, 16, 32) for k in ("lsm", "z", "slt")}, @@ -83,6 +85,12 @@ prediction = model.forward(batch) print(prediction.surf_vars["2t"]) ``` +Note that this will incur a 500 MB download +and you may need to authenticate with `huggingface-cli login`. + +See the [HuggingFace repository `wbruinsma/aurora`](https://huggingface.co/wbruinsma/aurora) +for an overview of which models are available. + ## Contributing See [`CONTRIBUTING.md`](CONTRIBUTING.md). @@ -148,6 +156,13 @@ First, install the repository in editable mode and setup `pre-commit`: make install ``` +Then configure the HuggingFace repository where the weights can be found and log into HuggingFace: + +```bash +export HUGGINGFACE_REPO=wbruinsma/aurora +huggingface-cli login +``` + To run the tests and print coverage, run ```bash diff --git a/aurora/model/aurora.py b/aurora/model/aurora.py index bee6a73..92b8951 100644 --- a/aurora/model/aurora.py +++ b/aurora/model/aurora.py @@ -3,7 +3,8 @@ from datetime import timedelta from functools import partial -from torch import nn +import torch +from huggingface_hub import hf_hub_download from aurora.batch import Batch from aurora.model.decoder import Perceiver3DDecoder @@ -16,7 +17,7 @@ """type: Tuple of variable names.""" -class Aurora(nn.Module): +class Aurora(torch.nn.Module): """The Aurora model. Defaults to to the 1.3 B parameter configuration. @@ -141,6 +142,18 @@ def forward(self, batch: Batch) -> Batch: return pred + def load_checkpoint(self, repo: str, name: str) -> None: + path = hf_hub_download(repo_id=repo, filename=name) + d = torch.load(path, map_location="cpu") + + # Rename keys to ensure compatibility. + for k, v in list(d.items()): + if k.startswith("net."): + del d[k] + d[k[4:]] = v + + self.load_state_dict(d, strict=True) + AuroraSmall = partial( Aurora, @@ -150,4 +163,5 @@ def forward(self, batch: Batch) -> Batch: decoder_num_heads=(16, 8, 4), embed_dim=256, num_heads=8, + use_lora=False, ) diff --git a/pyproject.toml b/pyproject.toml index 413b80f..7e84b0a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -36,6 +36,7 @@ dependencies = [ "torch", "einops", "timm==0.6.13", + "huggingface-hub", ] [project.optional-dependencies] diff --git a/tests/test_model.py b/tests/test_model.py index 029a53a..c6278c8 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -1,5 +1,6 @@ """Copyright (c) Microsoft Corporation. Licensed under the MIT license.""" +import os from datetime import datetime import torch @@ -10,6 +11,8 @@ def test_aurora_small(): model = AuroraSmall() + model.load_checkpoint(os.environ["HUGGINGFACE_REPO"], "aurora-0.25-small-pretrained.ckpt") + batch = Batch( surf_vars={k: torch.randn(1, 2, 16, 32) for k in ("2t", "10u", "10v", "msl")}, static_vars={k: torch.randn(1, 2, 16, 32) for k in ("lsm", "z", "slt")},