From b8a0999801fe81345c7accf487c84998779c9ae3 Mon Sep 17 00:00:00 2001 From: Wessel Date: Thu, 29 Aug 2024 15:07:54 +0200 Subject: [PATCH] Export and fix high-res model (#19) --- aurora/__init__.py | 3 ++- aurora/model/aurora.py | 1 + 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/aurora/__init__.py b/aurora/__init__.py index df060c6..995f9ef 100644 --- a/aurora/__init__.py +++ b/aurora/__init__.py @@ -1,11 +1,12 @@ """Copyright (c) Microsoft Corporation. Licensed under the MIT license.""" from aurora.batch import Batch, Metadata -from aurora.model.aurora import Aurora, AuroraSmall +from aurora.model.aurora import Aurora, AuroraHighRes, AuroraSmall from aurora.rollout import rollout __all__ = [ "Aurora", + "AuroraHighRes", "AuroraSmall", "Batch", "Metadata", diff --git a/aurora/model/aurora.py b/aurora/model/aurora.py index 82f37dd..b08c387 100644 --- a/aurora/model/aurora.py +++ b/aurora/model/aurora.py @@ -253,6 +253,7 @@ def load_checkpoint(self, repo: str, name: str, strict: bool = True) -> None: AuroraHighRes = partial( Aurora, + patch_size=10, encoder_depths=(6, 8, 8), decoder_depths=(8, 8, 6), )