Skip to content

Commit

Permalink
Suggestions from review
Browse files Browse the repository at this point in the history
  • Loading branch information
frostedoyster committed Feb 2, 2025
1 parent da82cc8 commit 63d6ac3
Show file tree
Hide file tree
Showing 7 changed files with 19 additions and 17 deletions.
19 changes: 11 additions & 8 deletions src/metatrain/cli/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -422,14 +422,17 @@ def train_model(
###########################

logger.info("Calling trainer")
trainer.train(
model=model,
dtype=dtype,
devices=devices,
train_datasets=train_datasets,
val_datasets=val_datasets,
checkpoint_dir=str(checkpoint_dir),
)
try:
trainer.train(
model=model,
dtype=dtype,
devices=devices,
train_datasets=train_datasets,
val_datasets=val_datasets,
checkpoint_dir=str(checkpoint_dir),
)
except Exception as e:
raise ArchitectureError(e)

if not is_main_process():
return # only save and evaluate on the main process
Expand Down
4 changes: 2 additions & 2 deletions src/metatrain/experimental/phace/default-hypers.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,14 @@ architecture:
name: experimental.phace

model:
nu_max: 3
max_correlation_order_per_layer: 3
num_message_passing_layers: 2
cutoff: 5.0
cutoff_width: 1.0
num_element_channels: 64
radial_basis:
mlp: true
E_max: 50.0
max_eigenvalue: 50.0
scale: 0.7
optimizable_lengthscales: false
nu_scaling: 0.1
Expand Down
2 changes: 1 addition & 1 deletion src/metatrain/experimental/phace/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def __init__(self, model_hypers: Dict, dataset_info: DatasetInfo) -> None:
)
self.register_buffer("species_to_species_index", species_to_species_index)

self.nu_max = model_hypers["nu_max"]
self.nu_max = model_hypers["max_correlation_order_per_layer"]
self.num_message_passing_layers = model_hypers["num_message_passing_layers"]
if self.num_message_passing_layers < 1:
raise ValueError("Number of message-passing layers must be at least 1")
Expand Down
2 changes: 1 addition & 1 deletion src/metatrain/experimental/phace/modules/physical_basis.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def get_physical_basis_spliner(E_max, r_cut, normalize):
for l in range(l_max + 1): # noqa: E741
n_max_l.append(np.where(E_nl[:, l] <= E_max)[0][-1] + 1)
if n_max_l[0] > n_max:
raise ValueError("n_max too large, try decreasing E_max")
raise ValueError("n_max too large, try decreasing max_eigenvalue")

def function_for_splining(n, l, x): # noqa: E741
ret = physical_basis.compute(n, l, x)
Expand Down
2 changes: 1 addition & 1 deletion src/metatrain/experimental/phace/modules/radial_basis.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ def __init__(self, hypers, all_species) -> None:
for species in all_species:
lengthscales[species] = np.log(hypers["scale"] * covalent_radii[species])
self.n_max_l, self.spliner = get_physical_basis_spliner(
hypers["E_max"], hypers["cutoff"], normalize=True
hypers["max_eigenvalue"], hypers["cutoff"], normalize=True
)
if hypers["optimizable_lengthscales"]:
self.lengthscales = torch.nn.Parameter(lengthscales)
Expand Down
4 changes: 2 additions & 2 deletions src/metatrain/experimental/phace/schema-hypers.json
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
"model": {
"type": "object",
"properties": {
"nu_max": {
"max_correlation_order_per_layer": {
"type": "integer"
},
"num_message_passing_layers": {
Expand All @@ -30,7 +30,7 @@
"mlp": {
"type": "boolean"
},
"E_max": {
"max_eigenvalue": {
"type": "number"
},
"scale": {
Expand Down
3 changes: 1 addition & 2 deletions tox.ini
Original file line number Diff line number Diff line change
Expand Up @@ -137,11 +137,10 @@ commands =
pytest {posargs}

[testenv:phace-tests]
description = Run NanoPET tests with pytest
description = Run PhACE tests with pytest
passenv = *
deps =
pytest
spherical # for nanoPET spherical target
extras = phace
changedir = src/metatrain/experimental/phace/tests/
commands =
Expand Down

0 comments on commit 63d6ac3

Please sign in to comment.