Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Updated loss and eval metrics #896

Merged
merged 16 commits into from
Oct 30, 2024
34 changes: 34 additions & 0 deletions src/fairchem/core/common/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ class Registry:
# Mappings to respective classes.
"task_name_mapping": {},
"dataset_name_mapping": {},
"loss_name_mapping": {},
"model_name_mapping": {},
"logger_name_mapping": {},
"trainer_name_mapping": {},
Expand Down Expand Up @@ -109,6 +110,35 @@ def wrap(func: Callable[..., R]) -> Callable[..., R]:

return wrap

@classmethod
def register_loss(cls, name):
r"""Register a loss to registry with key 'name'

Args:
name: Key with which the loss will be registered.

Usage::

from fairchem.core.common.registry import registry
from torch import nn

@registry.register_loss("mae")
class MAELoss(nn.Module):
...

"""

def wrap(func):
from torch import nn

assert issubclass(
func, nn.Module
), "All loss must inherit torch.nn.Module class"
cls.mapping["loss_name_mapping"][name] = func
return func

return wrap

@classmethod
def register_model(cls, name: str):
r"""Register a model to registry with key 'name'
Expand Down Expand Up @@ -255,6 +285,10 @@ def get_task_class(cls, name: str):
def get_dataset_class(cls, name: str):
return cls.get_class(name, "dataset_name_mapping")

@classmethod
def get_loss_class(cls, name):
return cls.get_class(name, "loss_name_mapping")

@classmethod
def get_model_class(cls, name: str):
return cls.get_class(name, "model_name_mapping")
Expand Down
16 changes: 0 additions & 16 deletions src/fairchem/core/common/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,6 @@

import fairchem.core
from fairchem.core.common.registry import registry
from fairchem.core.modules.loss import AtomwiseL2Loss, L2MAELoss

if TYPE_CHECKING:
from collections.abc import Mapping
Expand Down Expand Up @@ -1433,21 +1432,6 @@ def update_config(base_config):
return config


def get_loss_module(loss_name):
if loss_name in ["l1", "mae"]:
loss_fn = nn.L1Loss()
elif loss_name == "mse":
loss_fn = nn.MSELoss()
elif loss_name == "l2mae":
loss_fn = L2MAELoss()
elif loss_name == "atomwisel2":
loss_fn = AtomwiseL2Loss()
else:
raise NotImplementedError(f"Unknown loss function name: {loss_name}")

return loss_fn


def load_model_and_weights_from_checkpoint(checkpoint_path: str) -> nn.Module:
if not os.path.isfile(checkpoint_path):
raise FileNotFoundError(
Expand Down
145 changes: 90 additions & 55 deletions src/fairchem/core/modules/evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@

from __future__ import annotations

from typing import TYPE_CHECKING, ClassVar
from functools import wraps
from typing import TYPE_CHECKING, Callable, ClassVar

import numpy as np
import torch
Expand All @@ -34,7 +35,7 @@
with the relevant metrics computed.
"""

NONE = slice(None)
NONE_SLICE = slice(None)


class Evaluator:
Expand Down Expand Up @@ -88,10 +89,9 @@ def eval(
self,
prediction: dict[str, torch.Tensor],
target: dict[str, torch.Tensor],
prev_metrics=None,
prev_metrics: dict | None = None,
):
if prev_metrics is None:
prev_metrics = {}
prev_metrics = prev_metrics or {}
metrics = prev_metrics

for target_property in self.target_metrics:
Expand Down Expand Up @@ -130,18 +130,98 @@ def update(self, key, stat, metrics):
return metrics


def metrics_dict(metric_fun: Callable) -> Callable:
"""Wrap up the return of a metrics function"""

@wraps(metric_fun)
def wrapped_metrics(
prediction: dict[str, torch.Tensor],
target: dict[str, torch.Tensor],
key: Hashable = None,
**kwargs,
) -> dict[str, torch.Tensor]:
error = metric_fun(prediction, target, key, **kwargs)
return {
"metric": torch.mean(error).item(),
"total": torch.sum(error).item(),
"numel": error.numel(),
}

return wrapped_metrics


@metrics_dict
def cosine_similarity(
prediction: dict[str, torch.Tensor],
target: dict[str, torch.Tensor],
key: Hashable = NONE_SLICE,
):
# cast to float 32 to avoid 0/nan issues in fp16
# https://github.com/pytorch/pytorch/issues/69512
return torch.cosine_similarity(prediction[key].float(), target[key].float())


@metrics_dict
def mae(
prediction: dict[str, torch.Tensor],
target: dict[str, torch.Tensor],
key: Hashable = NONE_SLICE,
) -> torch.Tensor:
return torch.abs(target[key] - prediction[key])


@metrics_dict
def mse(
prediction: dict[str, torch.Tensor],
target: dict[str, torch.Tensor],
key: Hashable = NONE_SLICE,
) -> torch.Tensor:
return (target[key] - prediction[key]) ** 2


@metrics_dict
def per_atom_mae(
prediction: dict[str, torch.Tensor],
target: dict[str, torch.Tensor],
key: Hashable = NONE_SLICE,
) -> torch.Tensor:
return torch.abs(target[key] - prediction[key]) / target["natoms"].unsqueeze(1)


@metrics_dict
def per_atom_mse(
prediction: dict[str, torch.Tensor],
target: dict[str, torch.Tensor],
key: Hashable = NONE_SLICE,
) -> torch.Tensor:
return ((target[key] - prediction[key]) / target["natoms"].unsqueeze(1)) ** 2


@metrics_dict
def magnitude_error(
prediction: dict[str, torch.Tensor],
target: dict[str, torch.Tensor],
key: Hashable = NONE_SLICE,
p: int = 2,
) -> torch.Tensor:
assert prediction[key].shape[1] > 1
return torch.abs(
torch.norm(prediction[key], p=p, dim=-1) - torch.norm(target[key], p=p, dim=-1)
)


def forcesx_mae(
prediction: dict[str, torch.Tensor],
target: dict[str, torch.Tensor],
key: Hashable = NONE,
key: Hashable = NONE_SLICE,
):
return mae(prediction["forces"][:, 0], target["forces"][:, 0])


def forcesx_mse(
prediction: dict[str, torch.Tensor],
target: dict[str, torch.Tensor],
key: Hashable = NONE,
key: Hashable = NONE_SLICE,
):
return mse(prediction["forces"][:, 0], target["forces"][:, 0])

Expand Down Expand Up @@ -289,57 +369,12 @@ def min_diff(
return np.matmul(fractional, cell)


def cosine_similarity(
def rmse(
prediction: dict[str, torch.Tensor],
target: dict[str, torch.Tensor],
key: Hashable = NONE,
):
# cast to float 32 to avoid 0/nan issues in fp16
# https://github.com/pytorch/pytorch/issues/69512
error = torch.cosine_similarity(prediction[key].float(), target[key].float())
return {
"metric": torch.mean(error).item(),
"total": torch.sum(error).item(),
"numel": error.numel(),
}


def mae(
prediction: dict[str, torch.Tensor],
target: dict[str, torch.Tensor],
key: Hashable = NONE,
) -> dict[str, float | int]:
error = torch.abs(target[key] - prediction[key])
return {
"metric": torch.mean(error).item(),
"total": torch.sum(error).item(),
"numel": error.numel(),
}


def mse(
prediction: dict[str, torch.Tensor],
target: dict[str, torch.Tensor],
key: Hashable = NONE,
) -> dict[str, float | int]:
error = (target[key] - prediction[key]) ** 2
return {
"metric": torch.mean(error).item(),
"total": torch.sum(error).item(),
"numel": error.numel(),
}


def magnitude_error(
prediction: dict[str, torch.Tensor],
target: dict[str, torch.Tensor],
key: Hashable = NONE,
p: int = 2,
key: Hashable = None,
) -> dict[str, float | int]:
assert prediction[key].shape[1] > 1
error = torch.abs(
torch.norm(prediction[key], p=p, dim=-1) - torch.norm(target[key], p=p, dim=-1)
)
error = torch.sqrt(((target[key] - prediction[key]) ** 2).sum(dim=-1))
return {
"metric": torch.mean(error).item(),
"total": torch.sum(error).item(),
Expand Down
Loading