Expose metrics to choose the best model along a training run (#455)
frostedoyster authored Jan 29, 2025
1 parent 6419eec · commit d358315
Showing 10 changed files with 103 additions and 19 deletions.
5 changes: 5 additions & 0 deletions docs/src/architectures/nanopet.rst
@@ -114,3 +114,8 @@ The hyperparameters for training are
to the RMSE
:param loss: The loss function to use, with the subfields described in the previous
section
+:param best_model_metric: specifies the validation set metric to use to select the best
+    model, i.e. the model that will be saved as ``model.ckpt`` and ``model.pt`` both in
+    the current directory and in the checkpoint directory. The default is ``rmse_prod``,
+    i.e., the product of the RMSEs for each target. Other options are ``mae_prod`` and
+    ``loss``.
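
In practice this is set in the training options file. A minimal, hypothetical sketch (the ``options.yaml`` layout, the ``name`` value, and the ``mtt train`` entry point follow metatrain's usual conventions and are not part of this diff):

    # hypothetical options.yaml: choose which validation metric decides
    # the best model saved as model.ckpt / model.pt
    architecture:
      name: experimental.nanopet
      training:
        best_model_metric: mae_prod  # default is rmse_prod; "loss" is also accepted
    # ... dataset and other options omitted ...

All three choices are minimized (lower is better): ``loss`` follows whatever weighted objective is configured, while the two product metrics aggregate across targets without weighting.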
5 changes: 5 additions & 0 deletions docs/src/architectures/soap-bpnn.rst
@@ -175,6 +175,11 @@ The parameters for training are
:param loss_weights: specifies the weights to be used in the loss for each target. The
weights should be a dictionary of floats, one for each target. All missing targets
are assigned a weight of 1.0.
+:param best_model_metric: specifies the validation set metric to use to select the best
+    model, i.e. the model that will be saved as ``model.ckpt`` and ``model.pt`` both in
+    the current directory and in the checkpoint directory. The default is ``rmse_prod``,
+    i.e., the product of the RMSEs for each target. Other options are ``mae_prod`` and
+    ``loss``.
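
The ``options.yaml`` example shown above for nanoPET applies here as well, presumably with ``name: experimental.soap_bpnn``.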


References
1 change: 1 addition & 0 deletions src/metatrain/experimental/nanopet/default-hypers.yaml
@@ -26,6 +26,7 @@ architecture:
fixed_composition_weights: {}
per_structure_targets: []
log_mae: False
+best_model_metric: rmse_prod
loss:
type: mse
weights: {}
4 changes: 4 additions & 0 deletions src/metatrain/experimental/nanopet/schema-hypers.json
@@ -100,6 +100,10 @@
"log_mae": {
"type": "boolean"
},
"best_model_metric": {
"type": "string",
"enum": ["rmse_prod", "mae_prod", "loss"]
},
"loss": {
"type": "object",
"properties": {
21 changes: 12 additions & 9 deletions src/metatrain/experimental/nanopet/trainer.py
@@ -18,7 +18,7 @@
from ...utils.io import check_file_extension
from ...utils.logging import MetricLogger
from ...utils.loss import TensorMapDictLoss
-from ...utils.metrics import MAEAccumulator, RMSEAccumulator
+from ...utils.metrics import MAEAccumulator, RMSEAccumulator, get_selected_metric
from ...utils.neighbor_lists import (
get_requested_neighbor_lists,
get_system_with_neighbor_lists,
@@ -38,7 +38,7 @@ def __init__(self, train_hypers):
self.optimizer_state_dict = None
self.scheduler_state_dict = None
self.epoch = None
-self.best_loss = None
+self.best_metric = None
self.best_model_state_dict = None
self.best_optimizer_state_dict = None

@@ -258,8 +258,8 @@ def systems_and_targets_to_dtype(
rotational_augmenter = RotationalAugmenter(train_targets)

# Train the model:
-if self.best_loss is None:
-    self.best_loss = float("inf")
+if self.best_metric is None:
+    self.best_metric = float("inf")
logger.info("Starting training")
epoch = start_epoch
for epoch in range(start_epoch, start_epoch + self.hypers["num_epochs"]):
@@ -444,8 +444,11 @@ def systems_and_targets_to_dtype(
patience=self.hypers["scheduler_patience"],
)

-if val_loss < self.best_loss:
-    self.best_loss = val_loss
+val_metric = get_selected_metric(
+    finalized_val_info, self.hypers["best_model_metric"]
+)
+if val_metric < self.best_metric:
+    self.best_metric = val_metric
self.best_model_state_dict = copy.deepcopy(
(model.module if is_distributed else model).state_dict()
)
@@ -480,7 +483,7 @@ def save_checkpoint(self, model, path: Union[str, Path]):
"epoch": self.epoch,
"optimizer_state_dict": self.optimizer_state_dict,
"scheduler_state_dict": self.scheduler_state_dict,
"best_loss": self.best_loss,
"best_metric": self.best_metric,
"best_model_state_dict": self.best_model_state_dict,
"best_optimizer_state_dict": self.best_optimizer_state_dict,
}
@@ -497,7 +500,7 @@ def load_checkpoint(cls, path: Union[str, Path], train_hypers) -> "Trainer":
epoch = checkpoint["epoch"]
optimizer_state_dict = checkpoint["optimizer_state_dict"]
scheduler_state_dict = checkpoint["scheduler_state_dict"]
-best_loss = checkpoint["best_loss"]
+best_metric = checkpoint["best_metric"]
best_model_state_dict = checkpoint["best_model_state_dict"]
best_optimizer_state_dict = checkpoint["best_optimizer_state_dict"]

@@ -506,7 +509,7 @@ def load_checkpoint(cls, path: Union[str, Path], train_hypers) -> "Trainer":
trainer.optimizer_state_dict = optimizer_state_dict
trainer.scheduler_state_dict = scheduler_state_dict
trainer.epoch = epoch
-trainer.best_loss = best_loss
+trainer.best_metric = best_metric
trainer.best_model_state_dict = best_model_state_dict
trainer.best_optimizer_state_dict = best_optimizer_state_dict
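
One behavioral note implied by the checkpoint fields above: ``best_metric`` is persisted and restored on resume, so if a run is restarted with a different ``best_model_metric``, the restored value was computed on a different scale than the new one, and the best-model comparison immediately after the resume may not behave as intended. This is an inference from the code shown here, not a documented limitation.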

1 change: 1 addition & 0 deletions src/metatrain/experimental/soap_bpnn/default-hypers.yaml
@@ -38,6 +38,7 @@ architecture:
fixed_composition_weights: {}
per_structure_targets: []
log_mae: False
+best_model_metric: rmse_prod
loss:
type: mse
weights: {}
4 changes: 4 additions & 0 deletions src/metatrain/experimental/soap_bpnn/schema-hypers.json
@@ -157,6 +157,10 @@
"log_mae": {
"type": "boolean"
},
"best_model_metric": {
"type": "string",
"enum": ["rmse_prod", "mae_prod", "loss"]
},
"loss": {
"type": "object",
"properties": {
21 changes: 12 additions & 9 deletions src/metatrain/experimental/soap_bpnn/trainer.py
@@ -17,7 +17,7 @@
from ...utils.io import check_file_extension
from ...utils.logging import MetricLogger
from ...utils.loss import TensorMapDictLoss
-from ...utils.metrics import MAEAccumulator, RMSEAccumulator
+from ...utils.metrics import MAEAccumulator, RMSEAccumulator, get_selected_metric
from ...utils.neighbor_lists import (
get_requested_neighbor_lists,
get_system_with_neighbor_lists,
@@ -40,7 +40,7 @@ def __init__(self, train_hypers):
self.optimizer_state_dict = None
self.scheduler_state_dict = None
self.epoch = None
-self.best_loss = None
+self.best_metric = None
self.best_model_state_dict = None
self.best_optimizer_state_dict = None

@@ -252,8 +252,8 @@ def train(
start_epoch = 0 if self.epoch is None else self.epoch + 1

# Train the model:
-if self.best_loss is None:
-    self.best_loss = float("inf")
+if self.best_metric is None:
+    self.best_metric = float("inf")
logger.info("Starting training")
epoch = start_epoch
for epoch in range(start_epoch, start_epoch + self.hypers["num_epochs"]):
@@ -430,8 +430,11 @@ def train(
patience=self.hypers["scheduler_patience"],
)

-if val_loss < self.best_loss:
-    self.best_loss = val_loss
+val_metric = get_selected_metric(
+    finalized_val_info, self.hypers["best_model_metric"]
+)
+if val_metric < self.best_metric:
+    self.best_metric = val_metric
self.best_model_state_dict = copy.deepcopy(
(model.module if is_distributed else model).state_dict()
)
@@ -466,7 +469,7 @@ def save_checkpoint(self, model, path: Union[str, Path]):
"epoch": self.epoch,
"optimizer_state_dict": self.optimizer_state_dict,
"scheduler_state_dict": self.scheduler_state_dict,
"best_loss": self.best_loss,
"best_metric": self.best_metric,
"best_model_state_dict": self.best_model_state_dict,
"best_optimizer_state_dict": self.best_optimizer_state_dict,
}
@@ -483,7 +486,7 @@ def load_checkpoint(cls, path: Union[str, Path], train_hypers) -> "Trainer":
epoch = checkpoint["epoch"]
optimizer_state_dict = checkpoint["optimizer_state_dict"]
scheduler_state_dict = checkpoint["scheduler_state_dict"]
-best_loss = checkpoint["best_loss"]
+best_metric = checkpoint["best_metric"]
best_model_state_dict = checkpoint["best_model_state_dict"]
best_optimizer_state_dict = checkpoint["best_optimizer_state_dict"]

@@ -492,7 +495,7 @@ def load_checkpoint(cls, path: Union[str, Path], train_hypers) -> "Trainer":
trainer.optimizer_state_dict = optimizer_state_dict
trainer.scheduler_state_dict = scheduler_state_dict
trainer.epoch = epoch
-trainer.best_loss = best_loss
+trainer.best_metric = best_metric
trainer.best_model_state_dict = best_model_state_dict
trainer.best_optimizer_state_dict = best_optimizer_state_dict

33 changes: 33 additions & 0 deletions src/metatrain/utils/metrics.py
@@ -178,3 +178,36 @@ def finalize(
finalized_info[out_key] = value[0] / value[1]

return finalized_info


+def get_selected_metric(metric_dict: Dict[str, float], selected_metric: str) -> float:
+    """
+    Selects and/or calculates a (user-)selected metric from a dictionary of metrics.
+
+    This is useful when choosing the best model from a training run.
+
+    :param metric_dict: A dictionary of metrics, where the keys are the names of the
+        metrics and the values are the corresponding values.
+    :param selected_metric: The metric to return. This can be one of the following:
+        - "loss": return the loss value
+        - "rmse_prod": return the product of all RMSEs
+        - "mae_prod": return the product of all MAEs
+    """
+    if selected_metric == "loss":
+        metric = metric_dict["loss"]
+    elif selected_metric == "rmse_prod":
+        metric = 1
+        for key in metric_dict:
+            if "RMSE" in key:
+                metric *= metric_dict[key]
+    elif selected_metric == "mae_prod":
+        metric = 1
+        for key in metric_dict:
+            if "MAE" in key:
+                metric *= metric_dict[key]
+    else:
+        raise ValueError(
+            f"Selected metric {selected_metric} not recognized. "
+            "Please select from 'loss', 'rmse_prod', or 'mae_prod'."
+        )
+    return metric
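
Because the selection is a plain substring match on the metric names, every entry whose key contains ``RMSE`` (or ``MAE``) enters the product, including per-atom variants. A quick sketch with hypothetical key names (modeled on the accumulator outputs exercised in the tests below):

    from metatrain.utils.metrics import get_selected_metric

    # hypothetical finalized validation metrics for a two-target run
    finalized_val_info = {
        "loss": 1.0,
        "energy RMSE (per atom)": 2.0,
        "mtt::dipole RMSE": 4.0,
        "energy MAE (per atom)": 3.0,
    }

    # "rmse_prod" multiplies every entry whose key contains "RMSE":
    # 2.0 * 4.0 = 8.0; the loss and MAE entries are ignored
    assert get_selected_metric(finalized_val_info, "rmse_prod") == 8.0

One consequence worth noting: the product starts at 1, so if no key matches (e.g. ``rmse_prod`` when only MAEs were logged), the function returns 1 rather than raising.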
27 changes: 26 additions & 1 deletion tests/utils/test_metrics.py
@@ -2,7 +2,7 @@
import torch
from metatensor.torch import Labels, TensorBlock, TensorMap

-from metatrain.utils.metrics import MAEAccumulator, RMSEAccumulator
+from metatrain.utils.metrics import MAEAccumulator, RMSEAccumulator, get_selected_metric


@pytest.fixture
@@ -81,3 +81,28 @@ def test_mae_accumulator(tensor_map_with_grad_1, tensor_map_with_grad_2):

assert "energy MAE (per atom)" in maes
assert "energy_gradient_gradients MAE" in maes


+def test_get_selected_metric():
+    """Tests the get_selected_metric function."""
+
+    metrics = {
+        "loss": 1,
+        "energy RMSE": 2,
+        "energy MAE": 3,
+        "mtt::target RMSE": 4,
+        "mtt::target MAE": 5,
+    }
+
+    selected_metric = "foo"
+    with pytest.raises(ValueError, match="Please select from"):
+        get_selected_metric(metrics, selected_metric)
+
+    selected_metric = "rmse_prod"
+    assert get_selected_metric(metrics, selected_metric) == 8
+
+    selected_metric = "mae_prod"
+    assert get_selected_metric(metrics, selected_metric) == 15
+
+    selected_metric = "loss"
+    assert get_selected_metric(metrics, selected_metric) == 1
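
From the repository root, this test should be runnable in isolation with the standard pytest node-ID syntax: ``pytest tests/utils/test_metrics.py::test_get_selected_metric``.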
