Skip to content

Commit

Permalink
Merge pull request #695 from ACEsuit/fix-extract-invariant-bug
Browse files Browse the repository at this point in the history
extract_invariant now correctly extracts first layer when num_layers=1
  • Loading branch information
ilyes319 authored Nov 15, 2024
2 parents a65cd4f + 813c33b commit fbc62fa
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 7 deletions.
4 changes: 2 additions & 2 deletions mace/modules/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,7 +237,8 @@ def _check_non_zero(std):

def extract_invariant(x: torch.Tensor, num_layers: int, num_features: int, l_max: int):
out = []
for i in range(num_layers - 1):
out.append(x[:, :num_features])
for i in range(1, num_layers):
out.append(
x[
:,
Expand All @@ -247,7 +248,6 @@ def extract_invariant(x: torch.Tensor, num_layers: int, num_features: int, l_max
* num_features,
]
)
out.append(x[:, -num_features:])
return torch.cat(out, dim=-1)


Expand Down
24 changes: 19 additions & 5 deletions tests/test_calculator.py
Original file line number Diff line number Diff line change
Expand Up @@ -468,11 +468,18 @@ def test_calculator_energy_dipole(fitting_configs, trained_energy_dipole_model):

def test_calculator_descriptor(fitting_configs, trained_equivariant_model):
at = fitting_configs[2].copy()
at.calc = trained_equivariant_model

desc_invariant = at.calc.get_descriptors(at, invariants_only=True)
desc_single_layer = at.calc.get_descriptors(at, invariants_only=True, num_layers=1)
desc = at.calc.get_descriptors(at, invariants_only=False)
at_rotated = fitting_configs[2].copy()
at_rotated.rotate(90, "x")
calc = trained_equivariant_model

desc_invariant = calc.get_descriptors(at, invariants_only=True)
desc_invariant_rotated = calc.get_descriptors(at_rotated, invariants_only=True)
desc_single_layer = calc.get_descriptors(at, invariants_only=True, num_layers=1)
desc_single_layer_rotated = calc.get_descriptors(
at_rotated, invariants_only=True, num_layers=1
)
desc = calc.get_descriptors(at, invariants_only=False)
desc_rotated = calc.get_descriptors(at_rotated, invariants_only=False)

assert desc_invariant.shape[0] == 3
assert desc_invariant.shape[1] == 32
Expand All @@ -481,6 +488,13 @@ def test_calculator_descriptor(fitting_configs, trained_equivariant_model):
assert desc.shape[0] == 3
assert desc.shape[1] == 80

np.testing.assert_allclose(desc_invariant, desc_invariant_rotated, atol=1e-6)
np.testing.assert_allclose(desc_single_layer, desc_invariant[:, :16], atol=1e-6)
np.testing.assert_allclose(
desc_single_layer_rotated, desc_invariant[:, :16], atol=1e-6
)
assert not np.allclose(desc, desc_rotated, atol=1e-6)


def test_mace_mp(capsys: pytest.CaptureFixture):
mp_mace = mace_mp()
Expand Down

0 comments on commit fbc62fa

Please sign in to comment.