diff --git a/aurora/__init__.py b/aurora/__init__.py index df060c6..995f9ef 100644 --- a/aurora/__init__.py +++ b/aurora/__init__.py @@ -1,11 +1,12 @@ """Copyright (c) Microsoft Corporation. Licensed under the MIT license.""" from aurora.batch import Batch, Metadata -from aurora.model.aurora import Aurora, AuroraSmall +from aurora.model.aurora import Aurora, AuroraHighRes, AuroraSmall from aurora.rollout import rollout __all__ = [ "Aurora", + "AuroraHighRes", "AuroraSmall", "Batch", "Metadata", diff --git a/aurora/model/aurora.py b/aurora/model/aurora.py index 82f37dd..b08c387 100644 --- a/aurora/model/aurora.py +++ b/aurora/model/aurora.py @@ -253,6 +253,7 @@ def load_checkpoint(self, repo: str, name: str, strict: bool = True) -> None: AuroraHighRes = partial( Aurora, + patch_size=10, encoder_depths=(6, 8, 8), decoder_depths=(8, 8, 6), )