Skip to content

Commit

Permalink
Fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
frostedoyster committed Jan 29, 2025
1 parent 3962929 commit 37ed487
Show file tree
Hide file tree
Showing 4 changed files with 21 additions and 26 deletions.
7 changes: 1 addition & 6 deletions src/metatrain/cli/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -421,10 +421,7 @@ def train_model(
# TRAIN MODEL #############
###########################

# logger.info("Calling trainer")
# from torch.profiler import profile, ProfilerActivity
# with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
# if True:
logger.info("Calling trainer")
trainer.train(
model=model,
dtype=dtype,
Expand All @@ -433,8 +430,6 @@ def train_model(
val_datasets=val_datasets,
checkpoint_dir=str(checkpoint_dir),
)
# print(prof.key_averages().table(sort_by="self_cpu_time_total", row_limit=20))
# exit()

if not is_main_process():
return # only save and evaluate on the main process
Expand Down
4 changes: 2 additions & 2 deletions src/metatrain/experimental/phace/tests/test_equivariance.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def test_rotational_invariance():
torch.set_default_dtype(torch.float32) # change back


@pytest.mark.parametrize("o3_lambda", [0, 1, 2, 3, 4])
@pytest.mark.parametrize("o3_lambda", [0, 1, 2, 3])
@pytest.mark.parametrize("o3_sigma", [1])
def test_equivariance_rotations(o3_lambda, o3_sigma):
"""Tests that the model is rotationally equivariant when predicting
Expand Down Expand Up @@ -128,7 +128,7 @@ def test_equivariance_rotations(o3_lambda, o3_sigma):


@pytest.mark.parametrize("dataset_path", [DATASET_PATH, DATASET_PATH_PERIODIC])
@pytest.mark.parametrize("o3_lambda", [0, 1, 2, 3, 4])
@pytest.mark.parametrize("o3_lambda", [0, 1, 2, 3])
@pytest.mark.parametrize("o3_sigma", [1])
def test_equivariance_inversion(dataset_path, o3_lambda, o3_sigma):
"""Tests that the model is equivariant with respect to inversions."""
Expand Down
8 changes: 4 additions & 4 deletions src/metatrain/experimental/phace/tests/test_functionality.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ def test_output_last_layer_features():
]
assert features.values.shape == (
4,
32,
192,
)
assert features.properties.names == [
"properties",
Expand All @@ -178,7 +178,7 @@ def test_output_last_layer_features():
assert last_layer_features.values.shape == (
4,
1,
32,
192,
)
assert last_layer_features.properties.names == [
"properties",
Expand Down Expand Up @@ -208,7 +208,7 @@ def test_output_last_layer_features():
]
assert features.values.shape == (
1,
32,
192,
)
assert features.properties.names == [
"properties",
Expand All @@ -220,7 +220,7 @@ def test_output_last_layer_features():
assert outputs["mtt::aux::energy_last_layer_features"].block().values.shape == (
1,
1,
32,
192,
)
assert outputs["mtt::aux::energy_last_layer_features"].block().properties.names == [
"properties",
Expand Down
28 changes: 14 additions & 14 deletions src/metatrain/experimental/phace/tests/test_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,17 +45,17 @@ def test_regression_init():

expected_output = torch.tensor(
[
[-3.736138916016e01],
[6.381711363792e-02],
[1.800283813477e01],
[-2.970426513672e03],
[1.789148712158e01],
[0.002085668733],
[-0.003157143714],
[0.000328244379],
[0.004316798877],
[0.001980246045],
]
)

# if you need to change the hardcoded values:
torch.set_printoptions(precision=12)
print(output["mtt::U0"].block().values)
# torch.set_printoptions(precision=12)
# print(output["mtt::U0"].block().values)

torch.testing.assert_close(output["mtt::U0"].block().values, expected_output)

Expand Down Expand Up @@ -115,16 +115,16 @@ def test_regression_train():

expected_output = torch.tensor(
[
[2.120110988617],
[0.246357604861],
[0.113200485706],
[0.136439576745],
[0.023953542113],
[0.101170130074],
[0.038209509104],
[0.012803453952],
[0.151425197721],
[0.050753910094],
]
)

# if you need to change the hardcoded values:
torch.set_printoptions(precision=12)
print(output["mtt::U0"].block().values)
# torch.set_printoptions(precision=12)
# print(output["mtt::U0"].block().values)

torch.testing.assert_close(output["mtt::U0"].block().values, expected_output)

0 comments on commit 37ed487

Please sign in to comment.