diff --git a/README.md b/README.md index 9d3eee5..a10e29c 100644 --- a/README.md +++ b/README.md @@ -57,7 +57,31 @@ Install with `pip`: pip install microsoft-aurora ``` -Example here. +Run an untrained small model on random data: + +```python +import torch + +from aurora import AuroraSmall, Batch, Metadata + +model = AuroraSmall() + +batch = Batch( + surf_vars={k: torch.randn(1, 2, 16, 32) for k in ("2t", "10u", "10v", "msl")}, + static_vars={k: torch.randn(1, 2, 16, 32) for k in ("lsm", "z", "slt")}, + atmos_vars={k: torch.randn(1, 2, 4, 16, 32) for k in ("z", "u", "v", "t", "q")}, + metadata=Metadata( + lat=torch.linspace(90, -90, 17)[:-1], # Cut off the south pole. + lon=torch.linspace(0, 360, 32 + 1)[:-1], + time=(datetime(2020, 6, 1, 12, 0),), + atmos_levels=(100, 250, 500, 850), + ), +) + +prediction = model.forward(batch) + +print(prediction.surf_vars["2t"]) +``` ## Contributing diff --git a/tests/test_model.py b/tests/test_model.py index dcbfaa4..029a53a 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -11,17 +11,17 @@ def test_aurora_small(): model = AuroraSmall() batch = Batch( - {k: torch.randn(1, 2, 16, 32) for k in ("2t", "10u", "10v", "msl")}, - {k: torch.randn(1, 2, 16, 32) for k in ("lsm", "z", "slt")}, - {k: torch.randn(1, 2, 4, 16, 32) for k in ("z", "u", "v", "t", "q")}, - Metadata( - torch.linspace(90, -90, 17)[:-1], # Cut off the south pole. - torch.linspace(0, 360, 32 + 1)[:-1], - (datetime(2020, 6, 1, 12, 0),), - (100, 250, 500, 850), + surf_vars={k: torch.randn(1, 2, 16, 32) for k in ("2t", "10u", "10v", "msl")}, + static_vars={k: torch.randn(1, 2, 16, 32) for k in ("lsm", "z", "slt")}, + atmos_vars={k: torch.randn(1, 2, 4, 16, 32) for k in ("z", "u", "v", "t", "q")}, + metadata=Metadata( + lat=torch.linspace(90, -90, 17)[:-1], # Cut off the south pole. + lon=torch.linspace(0, 360, 32 + 1)[:-1], + time=(datetime(2020, 6, 1, 12, 0),), + atmos_levels=(100, 250, 500, 850), ), ) - pred = model.forward(batch) + prediction = model.forward(batch) - assert isinstance(pred, Batch) + assert isinstance(prediction, Batch)