Commit
Add new logic to enable stored checkpoint weights to be copied to new history dimensions (#36)

* Add new logic to enable stored checkpoint weights to be copied to new history dimensions

* Refactor checkpoint adaptation logic to allow for more flexibility and a separate function to adapt the history

* Refactor the ability to adapt max_history_size from a checkpoint into its own method

* Add additional test for multiple calls

* Add copyright to the new test file

* Fill with zeros instead of previous weights, matching the previous weights' device and dtype

* Manually fix AuroraHighRes to match main

* Simplify weight copying logic

* Fix pre-commit issues
scottcha authored Sep 23, 2024
1 parent 83727c1 commit c694fca
Showing 2 changed files with 113 additions and 3 deletions.
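
In short: when a model is constructed with a larger max_history_size than the checkpoint was trained with, load_checkpoint_local now zero-pads the stored token-embedding weights along the history dimension instead of failing. Below is a minimal usage sketch of the intent; it is not part of the commit, and the repository ID, checkpoint file name, and the pretrained history size of 2 are assumptions based on the public Aurora API.

from aurora.model.aurora import Aurora

# Construct a model with a larger history window than the pretrained checkpoint
# (assumed here to have been trained with max_history_size=2).
model = Aurora(max_history_size=4)

# load_checkpoint() routes through load_checkpoint_local(); with this commit the stored
# surf/atmos token-embedding weights are zero-padded along dim 2 to match max_history_size.
# The repo and file names below are illustrative.
model.load_checkpoint("microsoft/aurora", "aurora-0.25-pretrained.ckpt")
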
53 changes: 50 additions & 3 deletions aurora/model/aurora.py
@@ -5,7 +5,7 @@
import warnings
from datetime import timedelta
from functools import partial
-from typing import Optional
+from typing import Any, Optional

import torch
from huggingface_hub import hf_hub_download
@@ -112,6 +112,7 @@ def __init__(
self.patch_size = patch_size
self.surf_stats = surf_stats or dict()
self.autocast = autocast
self.max_history_size = max_history_size

if self.surf_stats:
warnings.warn(
@@ -268,8 +269,7 @@ def load_checkpoint_local(self, path: str, strict: bool = True) -> None:
del d[k]
d[k[4:]] = v

-# Convert the ID-based parametrisation to a name-based parametrisation.
+# Convert the ID-based parametrization to a name-based parametrization.
if "encoder.surf_token_embeds.weight" in d:
weight = d["encoder.surf_token_embeds.weight"]
del d["encoder.surf_token_embeds.weight"]
@@ -316,8 +316,55 @@ def load_checkpoint_local(self, path: str, strict: bool = True) -> None:
d[f"decoder.atmos_heads.{name}.weight"] = weight[:, i]
d[f"decoder.atmos_heads.{name}.bias"] = bias[:, i]

# Check whether the checkpoint's history size is compatible and adjust the weights if necessary.
if self.max_history_size > d["encoder.surf_token_embeds.weights.2t"].shape[2]:
d = self.adapt_checkpoint_max_history_size(d)
elif self.max_history_size < d["encoder.surf_token_embeds.weights.2t"].shape[2]:
raise AssertionError(
f"Cannot load checkpoint with max_history_size "
f"{d['encoder.surf_token_embeds.weights.2t'].shape[2]} "
f"into model with max_history_size {self.max_history_size}."
)

self.load_state_dict(d, strict=strict)

def adapt_checkpoint_max_history_size(self, checkpoint) -> Any:
"""Adapt a checkpoint with smaller max_history_size to a model with a larger
max_history_size than the current model.
If a checkpoint was trained with a larger max_history_size than the current model,
this function will assert fail to prevent loading the checkpoint. This is to
prevent loading a checkpoint which will likely cause the checkpoint to degrade is
performance.
This implementation copies weights from the checkpoint to the model and fills 0
for the new history width dimension.
"""
# Find all token-embedding weights, i.e. keys prefixed with
# "encoder.surf_token_embeds.weights." or "encoder.atmos_token_embeds.weights.".
for name, weight in list(checkpoint.items()):
if name.startswith("encoder.surf_token_embeds.weights.") or name.startswith(
"encoder.atmos_token_embeds.weights."
):
# This check shouldn't trigger with the current logic, but it is kept for future-proofing
# and for cases where this method is called outside the current context.
assert (
weight.shape[2] <= self.max_history_size
), f"Cannot load checkpoint with max_history_size {weight.shape[2]} \
into model with max_history_size {self.max_history_size} for weight {name}"

# Initialize the new weight tensor with zeros at the model's max_history_size.
new_weight = torch.zeros(
(weight.shape[0], 1, self.max_history_size, weight.shape[3], weight.shape[4]),
device=weight.device,
dtype=weight.dtype,
)

# Copy the existing histories from the checkpoint into the new tensor;
# any additional history dimensions remain zero.
for j in range(weight.shape[2]):
# Only fill the existing history slots; the rest stay zeros.
new_weight[:, :, j, :, :] = weight[:, :, j, :, :]
checkpoint[name] = new_weight
return checkpoint

def configure_activation_checkpointing(self):
"""Configure activation checkpointing.
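
For reference, here is a minimal standalone sketch (not part of the commit) of the zero-fill behaviour that adapt_checkpoint_max_history_size implements, using toy tensor shapes; the new test file below exercises the same behaviour.

import torch

# A toy checkpoint embedding weight; dim 2 is the history dimension (size 2 here).
weight = torch.rand(2, 1, 2, 4, 4)
max_history_size = 4  # target history size of the model being loaded into

# Allocate zeros at the target history size, matching the checkpoint's device and dtype.
new_weight = torch.zeros(
    (weight.shape[0], 1, max_history_size, weight.shape[3], weight.shape[4]),
    device=weight.device,
    dtype=weight.dtype,
)

# Copy the existing histories; the newly added history slots stay zero.
new_weight[:, :, : weight.shape[2]] = weight
assert new_weight.shape == (2, 1, 4, 4, 4)
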
63 changes: 63 additions & 0 deletions tests/test_checkpoint_adaptation.py
@@ -0,0 +1,63 @@
"""Copyright (c) Microsoft Corporation. Licensed under the MIT license."""

import pytest
import torch

from aurora.model.aurora import AuroraSmall


@pytest.fixture
def model(request):
return AuroraSmall(max_history_size=request.param)


@pytest.fixture
def checkpoint():
return {
"encoder.surf_token_embeds.weights.0": torch.rand((2, 1, 2, 4, 4)),
"encoder.atmos_token_embeds.weights.0": torch.rand((2, 1, 2, 4, 4)),
}


# Check target history sizes that are divisible by the original history size of 2 (i.e. 4) and not (i.e. 5).
@pytest.mark.parametrize("model", [4, 5], indirect=True)
def test_adapt_checkpoint_max_history(model, checkpoint):
# The checkpoint starts with a history dimension, shape[2], of size 2.
assert checkpoint["encoder.surf_token_embeds.weights.0"].shape[2] == 2
adapted_checkpoint = model.adapt_checkpoint_max_history_size(checkpoint)

for name, weight in adapted_checkpoint.items():
assert weight.shape[2] == model.max_history_size
for j in range(weight.shape[2]):
if j >= checkpoint[name].shape[2]:
assert torch.equal(weight[:, :, j, :, :], torch.zeros_like(weight[:, :, j, :, :]))
else:
assert torch.equal(
weight[:, :, j, :, :],
checkpoint[name][:, :, j % checkpoint[name].shape[2], :, :],
)


# Check that an AssertionError is raised when loading a checkpoint with a larger history size
# into a model with a smaller max_history_size.
@pytest.mark.parametrize("model", [1], indirect=True)
def test_adapt_checkpoint_max_history_fail(model, checkpoint):
with pytest.raises(AssertionError):
model.adapt_checkpoint_max_history_size(checkpoint)


# Adapt the checkpoint twice to ensure that the second call does not change the weights.
@pytest.mark.parametrize("model", [4], indirect=True)
def test_adapt_checkpoint_max_history_twice(model, checkpoint):
adapted_checkpoint = model.adapt_checkpoint_max_history_size(checkpoint)
adapted_checkpoint = model.adapt_checkpoint_max_history_size(adapted_checkpoint)

for name, weight in adapted_checkpoint.items():
assert weight.shape[2] == model.max_history_size
for j in range(weight.shape[2]):
if j >= checkpoint[name].shape[2]:
assert torch.equal(weight[:, :, j, :, :], torch.zeros_like(weight[:, :, j, :, :]))
else:
assert torch.equal(
weight[:, :, j, :, :],
checkpoint[name][:, :, j % checkpoint[name].shape[2], :, :],
)
