Skip to content

Commit

Permalink
Attempt inference mode
Browse files Browse the repository at this point in the history
  • Loading branch information
wesselb committed Aug 6, 2024
1 parent 217175e commit 091dec6
Showing 1 changed file with 4 additions and 2 deletions.
6 changes: 4 additions & 2 deletions tests/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@

from aurora import AuroraSmall, Batch, Metadata

torch.use_deterministic_algorithms(True)


class SavedMetadata(TypedDict):
"""Type of metadata of a saved test batch."""
Expand Down Expand Up @@ -77,8 +79,8 @@ def test_aurora_small() -> None:

# Load the checkpoint and run the model.
model.load_checkpoint(os.environ["HUGGINGFACE_REPO"], "aurora-0.25-small-pretrained.ckpt")
with torch.no_grad():
torch.manual_seed(0) # Very important to seed! The test data was generated using this.
torch.manual_seed(0) # Very important to seed! The test data was generated using this.
with torch.inference_mode():
pred = model.forward(batch)

def assert_approx_equality(v_out, v_ref) -> None:
Expand Down

0 comments on commit 091dec6

Please sign in to comment.