-
Notifications
You must be signed in to change notification settings - Fork 36
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add new logic to enable stored checkpoint weights to be copied to new…
… history dimensions (#36) * Add new logic to enable stored checkpoint weights to be copied to new history dimensions * Refactor checkpoint adaptation logic to allow for more flexibility and different fn to adapt history * refactor ability to adapt max_history_size from a checkpoint to its own method * Add addiitonal test for multiple calls * Add copyright to the new test file * fill with zeroes instead of previous weights match previous weights device and dtype * manuall fix AuroraHighRes to match main * simplify weight copying logic * Fix pre-commit issues
- Loading branch information
Showing
2 changed files
with
113 additions
and
3 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,63 @@ | ||
"""Copyright (c) Microsoft Corporation. Licensed under the MIT license.""" | ||
|
||
import pytest | ||
import torch | ||
|
||
from aurora.model.aurora import AuroraSmall | ||
|
||
|
||
@pytest.fixture | ||
def model(request): | ||
return AuroraSmall(max_history_size=request.param) | ||
|
||
|
||
@pytest.fixture | ||
def checkpoint(): | ||
return { | ||
"encoder.surf_token_embeds.weights.0": torch.rand((2, 1, 2, 4, 4)), | ||
"encoder.atmos_token_embeds.weights.0": torch.rand((2, 1, 2, 4, 4)), | ||
} | ||
|
||
|
||
# check both history sizes which are divisible by 2 (original shape) and not | ||
@pytest.mark.parametrize("model", [4, 5], indirect=True) | ||
def test_adapt_checkpoint_max_history(model, checkpoint): | ||
# checkpoint starts with history dim, shape[2], as size 2 | ||
assert checkpoint["encoder.surf_token_embeds.weights.0"].shape[2] == 2 | ||
adapted_checkpoint = model.adapt_checkpoint_max_history_size(checkpoint) | ||
|
||
for name, weight in adapted_checkpoint.items(): | ||
assert weight.shape[2] == model.max_history_size | ||
for j in range(weight.shape[2]): | ||
if j >= checkpoint[name].shape[2]: | ||
assert torch.equal(weight[:, :, j, :, :], torch.zeros_like(weight[:, :, j, :, :])) | ||
else: | ||
assert torch.equal( | ||
weight[:, :, j, :, :], | ||
checkpoint[name][:, :, j % checkpoint[name].shape[2], :, :], | ||
) | ||
|
||
|
||
# check that assert is thrown when trying to load a larger checkpoint to a smaller history size | ||
@pytest.mark.parametrize("model", [1], indirect=True) | ||
def test_adapt_checkpoint_max_history_fail(model, checkpoint): | ||
with pytest.raises(AssertionError): | ||
model.adapt_checkpoint_max_history_size(checkpoint) | ||
|
||
|
||
# test adapting the checkpoint twice to ensure that the second time should not change the weights | ||
@pytest.mark.parametrize("model", [4], indirect=True) | ||
def test_adapt_checkpoint_max_history_twice(model, checkpoint): | ||
adapted_checkpoint = model.adapt_checkpoint_max_history_size(checkpoint) | ||
adapted_checkpoint = model.adapt_checkpoint_max_history_size(adapted_checkpoint) | ||
|
||
for name, weight in adapted_checkpoint.items(): | ||
assert weight.shape[2] == model.max_history_size | ||
for j in range(weight.shape[2]): | ||
if j >= checkpoint[name].shape[2]: | ||
assert torch.equal(weight[:, :, j, :, :], torch.zeros_like(weight[:, :, j, :, :])) | ||
else: | ||
assert torch.equal( | ||
weight[:, :, j, :, :], | ||
checkpoint[name][:, :, j % checkpoint[name].shape[2], :, :], | ||
) |