def broadcast_object_list(
    object_list: list[Any], src: int, group=dist.group.WORLD, device: str | None = None
) -> None:
    """Broadcast a list of objects from rank ``src`` to the other ranks.

    Thin wrapper around ``torch.distributed.broadcast_object_list`` that is a
    no-op when only a single process is running.

    Args:
        object_list: objects to broadcast, forwarded unchanged to torch.distributed
        src: rank that owns the objects
        group: process group to broadcast in
        device: optional device forwarded to torch.distributed
    """
    # with a single process there is nobody to broadcast to
    if get_world_size() != 1:
        dist.broadcast_object_list(object_list, src, group, device)
def _load_check_duplicates(config: dict, name: str) -> dict[str, torch.nn.Module]:
    """Attempt to load a single file with normalizers/element references and check config for duplicate targets.

    Args:
        config: configuration dictionary. May contain a "file" entry (path to a file holding modules
            for several targets), a "fit" entry (with a "targets" collection) and per-target entries.
        name: Name of module to use for logging

    Returns:
        dictionary of normalizer or element reference modules read from config["file"],
        or an empty dictionary when no file is given.
    """
    modules = {}
    if "file" in config:
        modules = torch.load(config["file"])
        logging.info(f"Loaded {name} for the following targets: {list(modules.keys())}")

    # make sure that targets are not specified both to be fit and given explicitly / in the file
    fit_targets = config["fit"]["targets"] if "fit" in config else []
    duplicates = [
        target
        for target in list(config) + list(modules.keys())
        if target in fit_targets
    ]
    if duplicates:
        logging.warning(
            f"{name} values for the following targets {duplicates} have been specified to be fit and also read"
            f" from a file. The files read from file will be used instead of fitting."
        )

    # targets present both in the shared file and as explicit entries in the config
    duplicates = [target for target in config if target in modules]
    if duplicates:
        logging.warning(
            f"Duplicate {name} values for the following targets {duplicates} where specified in the file "
            f"{config['file']} and an explicitly set file. The normalization values read from "
            f"{config['file']} will be used."
        )
    return modules


def _load_from_config(
    config: dict,
    name: str,
    fit_fun: Callable[..., dict[str, Module]],
    create_fun: Callable[..., Module],
    dataset: Dataset,
    checkpoint_dir: str | Path | None = None,
    **fit_kwargs,
) -> dict[str, torch.nn.Module]:
    """Load or fit normalizers or element references from config

    Priority for each target is: values read from the single shared file (config["file"]),
    then values given explicitly for that target, and only then fitting. This matches the
    warnings logged by _load_check_duplicates and does not depend on the key order of config.

    If a fit is done, a fitted key with value true is added to the config to avoid re-fitting
    once a checkpoint has been saved.

    Args:
        config: configuration dictionary
        name: Name of module to use for logging
        fit_fun: Function to fit modules. Called as fit_fun(targets=..., dataset=..., **kwargs)
        create_fun: Function to create a module, called with the per-target config entry as keyword arguments
        dataset: dataset passed to fit_fun
        checkpoint_dir: directory to save modules. If not given, modules won't be saved.
        **fit_kwargs: extra keyword arguments for fit_fun. Entries of config["fit"] other than
            "targets" and "fitted" are added to (and override) these.

    Returns:
        dictionary of normalizer or element reference modules

    """
    modules = _load_check_duplicates(config, name)

    # explicit per-target entries are only used when the shared file did not provide the target.
    # "fit" and "file" are control entries, never targets.
    for target, target_config in config.items():
        if target in ("fit", "file") or target in modules:
            continue
        modules[target] = create_fun(**target_config)

    fit_config = config.get("fit")
    if fit_config is not None and not fit_config.get("fitted", False):
        # remove values for output targets that have already been read from files
        targets = [target for target in fit_config["targets"] if target not in modules]
        # "fitted" is bookkeeping added below and must not reach the fit function
        fit_kwargs.update(
            {k: v for k, v in fit_config.items() if k not in ("targets", "fitted")}
        )
        if targets:
            modules.update(fit_fun(targets=targets, dataset=dataset, **fit_kwargs))
        fit_config["fitted"] = True

    # save the modules for possible subsequent use
    if checkpoint_dir is not None:
        path = save_checkpoint(
            modules,
            checkpoint_dir,
            f"{name}.pt",
        )
        logging.info(
            f"{name} checkpoint for targets {list(modules.keys())} have been saved to: {path}"
        )

    return modules
class LinearReferences(nn.Module):
    """Represents an elemental linear references model for a target property.

    An elemental reference associates a value with each chemical element present in the dataset.
    Elemental references define a chemical composition model, i.e. a rough approximation of a target
    property (energy) using elemental references is done by summing the elemental references multiplied
    by the number of times the corresponding element is present.

    Elemental references energies can be taken as:
     - the energy of a chemical species in its elemental state
       (i.e. lowest energy polymorph of single element crystal structures for solids)
     - fitting a linear model to a dataset, where the features are the counts of each element in each data point.
       see the function fit_linear_references below for details

    Training GNNs to predict the difference between DFT and the predictions of a chemical composition
    model represent a useful normalization scheme that can improve model accuracy. See for example the
    "Alternative reference scheme" section of the OC22 manuscript: https://arxiv.org/pdf/2206.08917
    """

    def __init__(
        self,
        element_references: torch.Tensor | None = None,
        max_num_elements: int = 118,
    ):
        """
        Args:
            element_references (Tensor): tensor with linear reference values, indexed by atomic number.
                If not given, a zero tensor of size max_num_elements + 1 is used.
            max_num_elements (int): max number of elements - 118 is a stretch
        """
        super().__init__()
        self.register_buffer(
            name="element_references",
            tensor=element_references
            if element_references is not None
            else torch.zeros(max_num_elements + 1),
        )

    def _apply_refs(
        self, target: torch.Tensor, batch: Batch, sign: int, reshaped: bool = True
    ) -> torch.Tensor:
        """Apply references batch-wise

        The reference of every atom is looked up by atomic number and accumulated into the
        row of target given by batch.batch, scaled by sign (1 adds, -1 removes).
        """
        indices = batch.atomic_numbers.to(
            dtype=torch.int, device=self.element_references.device
        )
        elemrefs = self.element_references[indices].to(dtype=target.dtype)
        # this option should not exist, all tensors should have compatible shapes in dataset and trainer outputs
        if reshaped:
            elemrefs = elemrefs.view(batch.natoms.sum(), -1)

        return target.index_add(0, batch.batch, elemrefs, alpha=sign)

    @torch.autocast(device_type="cuda", enabled=False)
    def dereference(
        self, target: torch.Tensor, batch: Batch, reshaped: bool = True
    ) -> torch.Tensor:
        """Remove linear references"""
        return self._apply_refs(target, batch, -1, reshaped=reshaped)

    @torch.autocast(device_type="cuda", enabled=False)
    def forward(
        self, target: torch.Tensor, batch: Batch, reshaped: bool = True
    ) -> torch.Tensor:
        """Add linear references"""
        return self._apply_refs(target, batch, 1, reshaped=reshaped)


def create_element_references(
    file: str | Path | None = None,
    state_dict: dict | None = None,
) -> LinearReferences:
    """Create an element reference module.

    Args:
        file (str or Path): path to pt or npz file
        state_dict (dict): a state dict of a element reference module

    Returns:
        LinearReference

    Raises:
        RuntimeError: if the file extension is not supported, or if no element references
            could be obtained from the inputs (including when neither input is given).
    """
    if file is not None and state_dict is not None:
        logging.warning(
            "Both a file and a state_dict for element references was given."
            "The references will be read from the file and the provided state_dict will be ignored."
        )

    # path takes priority if given
    if file is not None:
        extension = Path(file).suffix
        if extension == ".pt":
            # try to load a pt file
            state_dict = torch.load(file)
        elif extension == ".npz":
            state_dict = {}
            with np.load(file) as values:
                # legacy linref files store the values under "coeff"
                key = "coeff" if "coeff" in values else "element_references"
                state_dict["element_references"] = torch.tensor(values[key])
        else:
            raise RuntimeError(
                f"Element references file with extension '{extension}' is not supported."
            )

    # state_dict is still None when neither a file nor a state_dict was given
    if state_dict is None or "element_references" not in state_dict:
        raise RuntimeError("Unable to load linear element references!")

    return LinearReferences(element_references=state_dict["element_references"])


@torch.no_grad()
def fit_linear_references(
    targets: list[str],
    dataset: Dataset,
    batch_size: int,
    num_batches: int | None = None,
    num_workers: int = 0,
    max_num_elements: int = 118,
    log_metrics: bool = True,
    use_numpy: bool = True,
    driver: str | None = None,
    shuffle: bool = True,
    seed: int = 0,
) -> dict[str, LinearReferences]:
    """Fit a set linear references for a list of targets using a given number of batches.

    Args:
        targets: list of target names
        dataset: data set to fit linear references with
        batch_size: size of batch
        num_batches: number of batches to use in fit. If not given will use all batches
        num_workers: number of workers to use in data loader.
            Note setting num_workers > 1 leads to finicky multiprocessing issues when using this function
            in distributed mode. The issue has to do with pickling the functions in load_references_from_config
            see function below...
        max_num_elements: max number of elements in dataset. If not given will use an ambitious value of 118
        log_metrics: if true will compute MAE, RMSE and R2 score of fit and log.
        use_numpy: use numpy.linalg.lstsq instead of torch. This tends to give better solutions.
        driver: backend used to solve linear system. See torch.linalg.lstsq docs. Ignored if use_numpy=True
        shuffle: whether to shuffle when loading the dataset
        seed: random seed used to shuffle the sampler if shuffle=True

    Returns:
        dict of fitted LinearReferences objects
    """
    data_loader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        collate_fn=partial(data_list_collater, otf_graph=True),
        num_workers=num_workers,
        persistent_workers=num_workers > 0,
        generator=torch.Generator().manual_seed(seed),
    )

    num_batches = num_batches if num_batches is not None else len(data_loader)
    if num_batches > len(data_loader):
        logging.warning(
            f"The given num_batches {num_batches} is larger than total batches of size {batch_size} in dataset. "
            f"num_batches will be ignored and the whole dataset will be used."
        )
        num_batches = len(data_loader)

    max_num_elements += 1  # + 1 since H starts at index 1
    # solving linear system happens on CPU, which allows handling poorly conditioned and
    # rank deficient matrices, unlike torch lstsq on GPU.
    # The buffers are float64 so that the float64 values written below are not silently downcast.
    composition_matrix = torch.zeros(
        num_batches * batch_size, max_num_elements, dtype=torch.float64
    )
    target_vectors = {
        target: torch.zeros(num_batches * batch_size, dtype=torch.float64)
        for target in targets
    }

    logging.info(
        f"Fitting linear references using {num_batches * batch_size} samples in {num_batches} "
        f"batches of size {batch_size}."
    )
    num_samples = 0  # rows actually filled; the last batch may be smaller than batch_size
    for i, batch in tqdm(
        enumerate(data_loader), total=num_batches, desc="Fitting linear references"
    ):
        if i == num_batches:
            break
        if i == 0:
            assert all(
                len(batch[target].squeeze().shape) == 1 for target in targets
            ), "element references can only be used for scalar targets"

        start = i * batch_size
        next_batch_size = len(batch) if i == len(data_loader) - 1 else batch_size
        for target in targets:
            # flatten so that (B, 1) shaped scalar targets fit the 1D buffer
            target_vectors[target][start : start + next_batch_size] = (
                batch[target].to(torch.float64).reshape(-1)
            )
        for j, data in enumerate(batch.to_data_list()):
            composition_matrix[start + j] = torch.bincount(
                data.atomic_numbers.int(),
                minlength=max_num_elements,
            ).to(torch.float64)
        num_samples += next_batch_size

    # drop rows never filled (short last batch). All-zero rows do not change the least squares
    # solution, but they would be counted as samples in the accuracy metrics.
    composition_matrix = composition_matrix[:num_samples]
    target_vectors = {
        target: vector[:num_samples] for target, vector in target_vectors.items()
    }

    # reduce the composition matrix to only features that are non-zero to improve rank
    mask = composition_matrix.sum(axis=0) != 0.0
    reduced_composition_matrix = composition_matrix[:, mask]
    elementrefs = {}

    for target in targets:
        coeffs = torch.zeros(max_num_elements)

        if use_numpy:
            solution = torch.tensor(
                np.linalg.lstsq(
                    reduced_composition_matrix.numpy(),
                    target_vectors[target].numpy(),
                    rcond=None,
                )[0]
            )
        else:
            lstsq = torch.linalg.lstsq(
                reduced_composition_matrix, target_vectors[target], driver=driver
            )
            solution = lstsq.solution

        # masked assignment requires matching dtypes; references are stored in the default dtype
        coeffs[mask] = solution.to(coeffs.dtype)
        elementrefs[target] = LinearReferences(coeffs)

        if log_metrics is True:
            y = target_vectors[target]
            y_pred = torch.matmul(reduced_composition_matrix, solution)
            y_mean = y.mean()
            N = len(y)
            ss_res = ((y - y_pred) ** 2).sum()
            ss_tot = ((y - y_mean) ** 2).sum()
            mae = (abs(y - y_pred)).sum() / N
            rmse = (((y - y_pred) ** 2).sum() / N).sqrt()
            r2 = 1 - (ss_res / ss_tot)
            logging.info(
                f"Training accuracy metrics for fitted linear element references: mae={mae}, rmse={rmse}, r2 score={r2}"
            )

    return elementrefs


def load_references_from_config(
    config: dict[str, Any],
    dataset: Dataset,
    seed: int = 0,
    checkpoint_dir: str | Path | None = None,
) -> dict[str, LinearReferences]:
    """Create a dictionary with element references from a config.

    Args:
        config: element references configuration, forwarded to _load_from_config
        dataset: dataset used if references need to be fit
        seed: random seed forwarded to fit_linear_references
        checkpoint_dir: directory to save the references to. If not given they are not saved.

    Returns:
        dict of LinearReferences objects
    """
    return _load_from_config(
        config,
        "element_references",
        fit_linear_references,
        create_element_references,
        dataset,
        checkpoint_dir,
        seed=seed,
    )
checkpoint_dir: str | Path | None = None, +) -> dict[str, LinearReferences]: + """Create a dictionary with element references from a config.""" + return _load_from_config( + config, + "element_references", + fit_linear_references, + create_element_references, + dataset, + checkpoint_dir, + seed=seed, + ) diff --git a/src/fairchem/core/modules/normalization/normalizer.py b/src/fairchem/core/modules/normalization/normalizer.py new file mode 100644 index 0000000000..f16db7d398 --- /dev/null +++ b/src/fairchem/core/modules/normalization/normalizer.py @@ -0,0 +1,290 @@ +""" +Copyright (c) Meta, Inc. and its affiliates. + +This source code is licensed under the MIT license found in the +LICENSE file in the root directory of this source tree. +""" + +from __future__ import annotations + +import logging +import warnings +from collections import defaultdict +from functools import partial +from pathlib import Path +from typing import TYPE_CHECKING, Any + +import numpy as np +import torch +from torch import nn +from torch.utils.data import DataLoader, Dataset +from tqdm import tqdm + +from fairchem.core.datasets import data_list_collater + +from ._load_utils import _load_from_config + +if TYPE_CHECKING: + from collections.abc import Mapping + + from fairchem.core.modules.normalization.element_references import LinearReferences + + +class Normalizer(nn.Module): + """Normalize/denormalize a tensor and optionally add a atom reference offset.""" + + def __init__( + self, + mean: float | torch.Tensor = 0.0, + rmsd: float | torch.Tensor = 1.0, + ): + """tensor is taken as a sample to calculate the mean and rmsd""" + super().__init__() + + if isinstance(mean, float): + mean = torch.tensor(mean) + if isinstance(rmsd, float): + rmsd = torch.tensor(rmsd) + + self.register_buffer(name="mean", tensor=mean) + self.register_buffer(name="rmsd", tensor=rmsd) + + @torch.autocast(device_type="cuda", enabled=False) + def norm(self, tensor: torch.Tensor) -> torch.Tensor: + return (tensor - 
self.mean) / self.rmsd + + @torch.autocast(device_type="cuda", enabled=False) + def denorm(self, normed_tensor: torch.Tensor) -> torch.Tensor: + return normed_tensor * self.rmsd + self.mean + + def forward(self, normed_tensor: torch.Tensor) -> torch.Tensor: + return self.denorm(normed_tensor) + + def load_state_dict( + self, state_dict: Mapping[str, Any], strict: bool = True, assign: bool = False + ): + # check if state dict is legacy state dicts + if "std" in state_dict: + state_dict = { + "mean": torch.tensor(state_dict["mean"]), + "rmsd": state_dict["std"], + } + + return super().load_state_dict(state_dict, strict=strict, assign=assign) + + +def create_normalizer( + file: str | Path | None = None, + state_dict: dict | None = None, + tensor: torch.Tensor | None = None, + mean: float | torch.Tensor | None = None, + rmsd: float | torch.Tensor | None = None, + stdev: float | torch.Tensor | None = None, +) -> Normalizer: + """Build a target data normalizers with optional atom ref + + Only one of file, state_dict, tensor, or (mean and rmsd) will be used to create a normalizer. + If more than one set of inputs are given priority will be given following the order in which they are listed above. + + Args: + file (str or Path): path to pt or npz file. + state_dict (dict): a state dict for Normalizer module + tensor (Tensor): a tensor with target values used to compute mean and std + mean (float | Tensor): mean of target data + rmsd (float | Tensor): rmsd of target data, rmsd from mean = stdev, rmsd from 0 = rms + stdev: standard deviation (deprecated, use rmsd instead) + + Returns: + Normalizer + """ + if stdev is not None: + warnings.warn( + "Use of 'stdev' is deprecated, use 'rmsd' instead", DeprecationWarning + ) + if rmsd is not None: + logging.warning( + "Both 'stdev' and 'rmsd' values where given to create a normalizer, rmsd values will be used." 
+ ) + + # old configs called it stdev, using this in the function signature reduces overhead code elsewhere + if stdev is not None and rmsd is None: + rmsd = stdev + + # path takes priority if given + if file is not None: + if state_dict is not None or tensor is not None or mean is not None: + logging.warning( + "A file to a normalizer has been given. Normalization values will be read from it, and all other inputs" + " will be ignored." + ) + extension = Path(file).suffix + if extension == ".pt": + # try to load a pt file + state_dict = torch.load(file) + elif extension == ".npz": + # try to load an NPZ file + values = np.load(file) + mean = values.get("mean") + rmsd = values.get("rmsd") or values.get("std") # legacy files + tensor = None # set to None since values read from file are prioritized + else: + raise RuntimeError( + f"Normalizer file with extension '{extension}' is not supported." + ) + + # state dict is second priority + if state_dict is not None: + if tensor is not None or mean is not None: + logging.warning( + "The state_dict provided will be used to set normalization values. All other inputs will be ignored." + ) + normalizer = Normalizer() + normalizer.load_state_dict(state_dict) + return normalizer + + # if not then read target value tensor + if tensor is not None: + if mean is not None: + logging.warning( + "Normalization values will be computed from input tensor, all other inputs will be ignored." + ) + mean = torch.mean(tensor) + rmsd = torch.std(tensor) + elif mean is not None and rmsd is not None: + if not isinstance(mean, torch.Tensor): + mean = torch.tensor(mean) + if not isinstance(rmsd, torch.Tensor): + rmsd = torch.tensor(rmsd) + + # if mean and rmsd are still None than raise an error + if mean is None or rmsd is None: + raise ValueError( + "Incorrect inputs. 
One of the following sets of inputs must be given: ", + "a file path to a .pt or .npz file, or mean and rmsd values, or a tensor of target values", + ) + + return Normalizer(mean=mean, rmsd=rmsd) + + +@torch.no_grad() +def fit_normalizers( + targets: list[str], + dataset: Dataset, + batch_size: int, + override_values: dict[str, dict[str, float]] | None = None, + rmsd_correction: int | None = None, + element_references: dict | None = None, + num_batches: int | None = None, + num_workers: int = 0, + shuffle: bool = True, + seed: int = 0, +) -> dict[str, Normalizer]: + """Estimate mean and rmsd from data to create normalizers + + Args: + targets: list of target names + dataset: data set to fit linear references with + batch_size: size of batch + override_values: dictionary with target names and values to override. i.e. {"forces": {"mean": 0.0}} will set + the forces mean to zero. + rmsd_correction: correction to use when computing mean in std/rmsd. See docs for torch.std. + If not given, will always use 0 when mean == 0, and 1 otherwise. + element_references: + num_batches: number of batches to use in fit. If not given will use all batches + num_workers: number of workers to use in data loader + Note setting num_workers > 1 leads to finicky multiprocessing issues when using this function + in distributed mode. The issue has to do with pickling the functions in load_normalizers_from_config + see function below... 
+ shuffle: whether to shuffle when loading the dataset + seed: random seed used to shuffle the sampler if shuffle=True + + Returns: + dict of normalizer objects + """ + data_loader = DataLoader( + dataset, + batch_size=batch_size, + shuffle=shuffle, + collate_fn=partial(data_list_collater, otf_graph=True), + num_workers=num_workers, + persistent_workers=num_workers > 0, + generator=torch.Generator().manual_seed(seed), + ) + + num_batches = num_batches if num_batches is not None else len(data_loader) + if num_batches > len(data_loader): + logging.warning( + f"The given num_batches {num_batches} is larger than total batches of size {batch_size} in dataset. " + f"num_batches will be ignored and the whole dataset will be used." + ) + num_batches = len(data_loader) + + element_references = element_references or {} + target_vectors = defaultdict(list) + + logging.info( + f"Estimating mean and rmsd for normalization using {num_batches * batch_size} samples in {num_batches} batches " + f"of size {batch_size}." 
+ ) + for i, batch in tqdm( + enumerate(data_loader), total=num_batches, desc="Estimating mean and rmsd" + ): + if i == num_batches: + break + + for target in targets: + target_vector = batch[target] + if target in element_references: + target_vector = element_references[target].dereference( + target_vector, batch, reshaped=False + ) + target_vectors[target].append(target_vector) + + normalizers = {} + for target in targets: + target_vector = torch.cat(target_vectors[target], dim=0) + values = {"mean": target_vector.mean()} + if target in override_values: + for name, val in override_values[target].items(): + values[name] = torch.tensor(val) + # calculate root mean square deviation + if "rmsd" not in values: + if rmsd_correction is None: + rmsd_correction = 0 if values["mean"] == 0.0 else 1 + values["rmsd"] = ( + ((target_vector - values["mean"]) ** 2).sum() + / max(len(target_vector) - rmsd_correction, 1) + ).sqrt() + normalizers[target] = create_normalizer(**values) + + return normalizers + + +def load_normalizers_from_config( + config: dict[str, Any], + dataset: Dataset, + seed: int = 0, + checkpoint_dir: str | Path | None = None, + element_references: dict[str, LinearReferences] | None = None, +) -> dict[str, Normalizer]: + """Create a dictionary with element references from a config.""" + # edit the config slightly to extract override args + if "fit" in config: + override_values = { + target: vals + for target, vals in config["fit"]["targets"].items() + if isinstance(vals, dict) + } + config["fit"]["override_values"] = override_values + config["fit"]["targets"] = list(config["fit"]["targets"].keys()) + + return _load_from_config( + config, + "normalizers", + fit_normalizers, + create_normalizer, + dataset, + checkpoint_dir, + seed=seed, + element_references=element_references, + ) diff --git a/src/fairchem/core/modules/normalizer.py b/src/fairchem/core/modules/normalizer.py deleted file mode 100644 index 75f34e83f4..0000000000 --- 
a/src/fairchem/core/modules/normalizer.py +++ /dev/null @@ -1,56 +0,0 @@ -""" -Copyright (c) Meta, Inc. and its affiliates. - -This source code is licensed under the MIT license found in the -LICENSE file in the root directory of this source tree. -""" - -from __future__ import annotations - -import torch - - -class Normalizer: - """Normalize a Tensor and restore it later.""" - - def __init__( - self, - tensor: torch.Tensor | None = None, - mean=None, - std=None, - device=None, - ) -> None: - """tensor is taken as a sample to calculate the mean and std""" - if tensor is None and mean is None: - return - - if device is None: - device = "cpu" - - self.mean: torch.Tensor - self.std: torch.Tensor - if tensor is not None: - self.mean = torch.mean(tensor, dim=0).to(device) - self.std = torch.std(tensor, dim=0).to(device) - return - - if mean is not None and std is not None: - self.mean = torch.tensor(mean).to(device) - self.std = torch.tensor(std).to(device) - - def to(self, device) -> None: - self.mean = self.mean.to(device) - self.std = self.std.to(device) - - def norm(self, tensor: torch.Tensor) -> torch.Tensor: - return (tensor - self.mean) / self.std - - def denorm(self, normed_tensor: torch.Tensor) -> torch.Tensor: - return normed_tensor * self.std + self.mean - - def state_dict(self): - return {"mean": self.mean, "std": self.std} - - def load_state_dict(self, state_dict) -> None: - self.mean = state_dict["mean"].to(self.mean.device) - self.std = state_dict["std"].to(self.mean.device) diff --git a/src/fairchem/core/modules/transforms.py b/src/fairchem/core/modules/transforms.py index 3a86be468c..52675fd28f 100644 --- a/src/fairchem/core/modules/transforms.py +++ b/src/fairchem/core/modules/transforms.py @@ -19,10 +19,12 @@ def __call__(self, data_object): return data_object for transform_fn in self.config: - # TODO: Normalization information used in the trainers. Ignore here - # for now. 
def fit_norms(
    config: dict,
    output_path: str | Path,
    linref_file: str | Path | None = None,
    linref_target: str = "energy",
) -> None:
    """Fit dataset mean and std using the standard config

    Args:
        config: config
        output_path: output path
        linref_file: path to fitted linear references. IF these are used in training they must be used to compute mean/std
        linref_target: target using linear references, basically always energy.

    Raises:
        ValueError: if the config has no train dataset or no 'fit' block for the normalizer
    """
    output_path = Path(output_path).resolve()

    # only the config lookups are guarded, so a KeyError raised while building the dataset
    # is not misreported as a missing config entry
    try:
        train_config = config["dataset"]["train"]
    except KeyError as err:
        raise ValueError("Train dataset is not specified in config!") from err

    try:
        norm_config = train_config["transforms"]["normalizer"]["fit"]
    except KeyError as err:
        raise ValueError(
            "The provided config does not specify a 'fit' block for 'normalizer'!"
        ) from err

    elementrefs = (
        {linref_target: create_element_references(linref_file)}
        if linref_file is not None
        else {}
    )

    # load the training dataset
    train_dataset = registry.get_dataset_class(train_config.get("format", "lmdb"))(
        train_config
    )

    targets = list(norm_config["targets"].keys())
    # targets whose config entry is a dict carry values that override the estimated ones
    override_values = {
        target: vals
        for target, vals in norm_config["targets"].items()
        if isinstance(vals, dict)
    }

    normalizers = fit_normalizers(
        targets=targets,
        override_values=override_values,
        element_references=elementrefs,
        dataset=train_dataset,
        batch_size=norm_config.get("batch_size", 32),
        num_batches=norm_config.get("num_batches"),
        num_workers=config.get("optim", {}).get("num_workers", 16),
    )
    path = save_checkpoint(
        normalizers,
        output_path,
        "normalizers.pt",
    )
    logging.info(f"normalizers have been saved to {path}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--config",
        required=True,
        type=Path,
        help="Path to configuration yaml file",
    )
    parser.add_argument(
        "--out-path",
        default=".",
        type=str,
        help="Output path to save normalizers",
    )
    parser.add_argument(
        "--linref-path",
        type=str,
        help="Path to linear references used.",
    )
    parser.add_argument(
        "--linref-target",
        default="energy",
        type=str,
        help="target for which linear references are used.",
    )
    args = parser.parse_args()
    config, dup_warning, dup_error = load_config(args.config)

    if len(dup_warning) > 0:
        logging.warning(
            f"The following keys in the given config have duplicates: {dup_warning}."
        )
    if len(dup_error) > 0:
        raise RuntimeError(
            f"The following include entries in the config have duplicates: {dup_error}"
        )

    # forward the parsed target as well, otherwise --linref-target has no effect
    fit_norms(config, args.out_path, args.linref_path, args.linref_target)
{dup_warning}." + ) + if len(dup_error) > 0: + raise RuntimeError( + f"The following include entries in the config have duplicates: {dup_error}" + ) + + fit_norms(config, args.out_path, args.linref_path) diff --git a/src/fairchem/core/scripts/fit_references.py b/src/fairchem/core/scripts/fit_references.py new file mode 100644 index 0000000000..f7f0c84dd7 --- /dev/null +++ b/src/fairchem/core/scripts/fit_references.py @@ -0,0 +1,91 @@ +""" +Copyright (c) Meta, Inc. and its affiliates. + +This source code is licensed under the MIT license found in the +LICENSE file in the root directory of this source tree. +""" + +from __future__ import annotations + +import argparse +import logging +from pathlib import Path + +from fairchem.core.common.registry import registry +from fairchem.core.common.utils import load_config, save_checkpoint +from fairchem.core.modules.normalization.element_references import fit_linear_references + + +def fit_linref(config: dict, output_path: str | Path) -> None: + """Fit linear references using the standard config + + Args: + config: config + output_path: output path + """ + # load the training dataset + output_path = Path(output_path).resolve() + + try: + # load the training dataset + train_dataset = registry.get_dataset_class( + config["dataset"]["train"].get("format", "lmdb") + )(config["dataset"]["train"]) + except KeyError as err: + raise ValueError("Train dataset is not specified in config!") from err + + try: + elementref_config = config["dataset"]["train"]["transforms"][ + "element_references" + ]["fit"] + except KeyError as err: + raise ValueError( + "The provided config does not specify a 'fit' block for 'element_refereces'!" 
+ ) from err + + element_refs = fit_linear_references( + targets=elementref_config["targets"], + dataset=train_dataset, + batch_size=elementref_config.get("batch_size", 32), + num_batches=elementref_config.get("num_batches"), + num_workers=config.get("optim", {}).get("num_workers", 16), + max_num_elements=elementref_config.get("max_num_elements", 118), + driver=elementref_config.get("driver", None), + ) + + for target, references in element_refs.items(): + path = save_checkpoint( + references.state_dict(), + output_path, + f"{target}_linref.pt", + ) + logging.info(f"{target} linear references have been saved to: {path}") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--config", + required=True, + type=Path, + help="Path to configuration yaml file", + ) + parser.add_argument( + "--out-path", + default=".", + type=str, + help="Output path to save linear references", + ) + args = parser.parse_args() + config, dup_warning, dup_error = load_config(args.config) + + if len(dup_warning) > 0: + logging.warning( + f"The following keys in the given config have duplicates: {dup_warning}." 
+ ) + if len(dup_error) > 0: + raise RuntimeError( + f"The following include entries in the config have duplicates: {dup_error}" + ) + + fit_linref(config, args.out_path) diff --git a/src/fairchem/core/trainers/base_trainer.py b/src/fairchem/core/trainers/base_trainer.py index c21409863e..40c7e65de6 100644 --- a/src/fairchem/core/trainers/base_trainer.py +++ b/src/fairchem/core/trainers/base_trainer.py @@ -43,7 +43,11 @@ from fairchem.core.modules.evaluator import Evaluator from fairchem.core.modules.exponential_moving_average import ExponentialMovingAverage from fairchem.core.modules.loss import DDPLoss -from fairchem.core.modules.normalizer import Normalizer +from fairchem.core.modules.normalization.element_references import ( + LinearReferences, + load_references_from_config, +) +from fairchem.core.modules.normalization.normalizer import load_normalizers_from_config from fairchem.core.modules.scaling.compat import load_scales_compat from fairchem.core.modules.scaling.util import ensure_fitted from fairchem.core.modules.scheduler import LRScheduler @@ -185,6 +189,11 @@ def __init__( if distutils.is_master(): logging.info(yaml.dump(self.config, default_flow_style=False)) + self.elementrefs = {} + self.normalizers = {} + self.train_dataset = None + self.val_dataset = None + self.test_dataset = None self.load() @abstractmethod @@ -208,6 +217,7 @@ def load(self) -> None: self.load_seed_from_config() self.load_logger() self.load_datasets() + self.load_references_and_normalizers() self.load_task() self.load_model() self.load_loss() @@ -395,20 +405,68 @@ def convert_settings_to_split_settings(config, split_name): self.relax_sampler, ) - def load_task(self): - # Normalizer for the dataset. - + def load_references_and_normalizers(self): + """Load or create element references and normalizers from config""" # Is it troublesome that we assume any normalizer info is in train? What if there is no # training dataset? 
What happens if we just specify a test - normalizer = self.config["dataset"].get("transforms", {}).get("normalizer", {}) - self.normalizers = {} - if normalizer: - for target in normalizer: - self.normalizers[target] = Normalizer( - mean=normalizer[target].get("mean", 0), - std=normalizer[target].get("stdev", 1), + + elementref_config = ( + self.config["dataset"].get("transforms", {}).get("element_references") + ) + norms_config = self.config["dataset"].get("transforms", {}).get("normalizer") + elementrefs, normalizers = {}, {} + if distutils.is_master(): + if elementref_config is not None: + # put them in a list to allow broadcasting python objects + elementrefs = load_references_from_config( + elementref_config, + dataset=self.train_dataset, + seed=self.config["cmd"]["seed"], + checkpoint_dir=self.config["cmd"]["checkpoint_dir"] + if not self.is_debug + else None, + ) + + if norms_config is not None: + normalizers = load_normalizers_from_config( + norms_config, + dataset=self.train_dataset, + seed=self.config["cmd"]["seed"], + checkpoint_dir=self.config["cmd"]["checkpoint_dir"] + if not self.is_debug + else None, + element_references=elementrefs, ) + # log out the values that will be used. + for output, normalizer in normalizers.items(): + logging.info( + f"Normalization values for output {output}: mean={normalizer.mean.item()}, rmsd={normalizer.rmsd.item()}." 
+ ) + + # put them in a list to broadcast them + elementrefs, normalizers = [elementrefs], [normalizers] + distutils.broadcast_object_list( + object_list=elementrefs, src=0, device=self.device + ) + distutils.broadcast_object_list( + object_list=normalizers, src=0, device=self.device + ) + # make sure element refs and normalizers are on this device + self.elementrefs.update( + { + output: elementref.to(self.device) + for output, elementref in elementrefs[0].items() + } + ) + self.normalizers.update( + { + output: normalizer.to(self.device) + for output, normalizer in normalizers[0].items() + } + ) + + def load_task(self): self.output_targets = {} for target_name in self.config["outputs"]: self.output_targets[target_name] = self.config["outputs"][target_name] @@ -425,15 +483,15 @@ def load_task(self): ][target_name].get("level", "system") if "train_on_free_atoms" not in self.output_targets[subtarget]: self.output_targets[subtarget]["train_on_free_atoms"] = ( - self.config["outputs"][target_name].get( - "train_on_free_atoms", True - ) + self.config[ + "outputs" + ][target_name].get("train_on_free_atoms", True) ) if "eval_on_free_atoms" not in self.output_targets[subtarget]: self.output_targets[subtarget]["eval_on_free_atoms"] = ( - self.config["outputs"][target_name].get( - "eval_on_free_atoms", True - ) + self.config[ + "outputs" + ][target_name].get("eval_on_free_atoms", True) ) # TODO: Assert that all targets, loss fn, metrics defined are consistent @@ -550,9 +608,20 @@ def load_checkpoint( target_key = key if target_key in self.normalizers: - self.normalizers[target_key].load_state_dict( + mkeys = self.normalizers[target_key].load_state_dict( checkpoint["normalizers"][key] ) + assert len(mkeys.missing_keys) == 0 + assert len(mkeys.unexpected_keys) == 0 + + for key, state_dict in checkpoint.get("elementrefs", {}).items(): + elementrefs = LinearReferences( + max_num_elements=len(state_dict["element_references"]) - 1 + ) + mkeys = 
elementrefs.load_state_dict(state_dict) + self.elementrefs[key] = elementrefs + assert len(mkeys.missing_keys) == 0 + assert len(mkeys.unexpected_keys) == 0 if self.scaler and checkpoint["amp"]: self.scaler.load_state_dict(checkpoint["amp"]) @@ -649,32 +718,40 @@ def save( training_state: bool = True, ) -> str | None: if not self.is_debug and distutils.is_master(): + state = { + "state_dict": self.model.state_dict(), + "normalizers": { + key: value.state_dict() for key, value in self.normalizers.items() + }, + "elementrefs": { + key: value.state_dict() for key, value in self.elementrefs.items() + }, + "config": self.config, + "val_metrics": metrics, + "amp": self.scaler.state_dict() if self.scaler else None, + } if training_state: - return save_checkpoint( + state.update( { "epoch": self.epoch, "step": self.step, - "state_dict": self.model.state_dict(), "optimizer": self.optimizer.state_dict(), "scheduler": ( self.scheduler.scheduler.state_dict() if self.scheduler.scheduler_type != "Null" else None ), - "normalizers": { - key: value.state_dict() - for key, value in self.normalizers.items() - }, "config": self.config, - "val_metrics": metrics, "ema": self.ema.state_dict() if self.ema else None, - "amp": self.scaler.state_dict() if self.scaler else None, "best_val_metric": self.best_val_metric, "primary_metric": self.evaluation_metrics.get( "primary_metric", self.evaluator.task_primary_metric[self.name], ), }, + ) + ckpt_path = save_checkpoint( + state, checkpoint_dir=self.config["cmd"]["checkpoint_dir"], checkpoint_file=checkpoint_file, ) @@ -683,22 +760,13 @@ def save( self.ema.store() self.ema.copy_to() ckpt_path = save_checkpoint( - { - "state_dict": self.model.state_dict(), - "normalizers": { - key: value.state_dict() - for key, value in self.normalizers.items() - }, - "config": self.config, - "val_metrics": metrics, - "amp": self.scaler.state_dict() if self.scaler else None, - }, + state, checkpoint_dir=self.config["cmd"]["checkpoint_dir"], 
checkpoint_file=checkpoint_file, ) if self.ema: self.ema.restore() - return ckpt_path + return ckpt_path return None def update_best( diff --git a/src/fairchem/core/trainers/ocp_trainer.py b/src/fairchem/core/trainers/ocp_trainer.py index 72c005893d..26269c6da4 100644 --- a/src/fairchem/core/trainers/ocp_trainer.py +++ b/src/fairchem/core/trainers/ocp_trainer.py @@ -11,6 +11,7 @@ import os from collections import defaultdict from itertools import chain +from typing import TYPE_CHECKING import numpy as np import torch @@ -25,6 +26,9 @@ from fairchem.core.modules.scaling.util import ensure_fitted from fairchem.core.trainers.base_trainer import BaseTrainer +if TYPE_CHECKING: + from torch_geometric.data import Batch + @registry.register_trainer("ocp") @registry.register_trainer("energy") @@ -148,7 +152,6 @@ def train(self, disable_eval_tqdm: bool = False) -> None: # Get a batch. batch = next(train_loader_iter) - # Forward, loss, backward. with torch.cuda.amp.autocast(enabled=self.scaler is not None): out = self._forward(batch) @@ -227,10 +230,21 @@ def train(self, disable_eval_tqdm: bool = False) -> None: if checkpoint_every == -1: self.save(checkpoint_file="checkpoint.pt", training_state=True) + def _denorm_preds(self, target_key: str, prediction: torch.Tensor, batch: Batch): + """Convert model output from a batch into raw prediction by denormalizing and adding references""" + # denorm the outputs + if target_key in self.normalizers: + prediction = self.normalizers[target_key](prediction) + + # add element references + if target_key in self.elementrefs: + prediction = self.elementrefs[target_key](prediction, batch) + + return prediction + def _forward(self, batch): out = self.model(batch.to(self.device)) - ### TODO: Move into BaseModel in OCP 2.0 outputs = {} batch_size = batch.natoms.numel() num_atoms_in_batch = batch.natoms.sum() @@ -254,10 +268,7 @@ def _forward(self, batch): for subtarget_key in self.output_targets[target_key]["decomposition"]: irreps = 
self.output_targets[subtarget_key]["irrep_dim"] - _pred = out[subtarget_key] - - if self.normalizers.get(subtarget_key, False): - _pred = self.normalizers[subtarget_key].denorm(_pred) + _pred = self._denorm_preds(subtarget_key, out[subtarget_key], batch) ## Fill in the corresponding irreps prediction ## Reshape irrep prediction to (batch_size, irrep_dim) @@ -278,7 +289,6 @@ def _forward(self, batch): pred = pred.view(num_atoms_in_batch, -1) else: pred = pred.view(batch_size, -1) - outputs[target_key] = pred return outputs @@ -307,8 +317,6 @@ def _compute_loss(self, out, batch): natoms = natoms[mask] num_atoms_in_batch = natoms.numel() - if self.normalizers.get(target_name, False): - target = self.normalizers[target_name].norm(target) ### reshape accordingly: num_atoms_in_batch, -1 or num_systems_in_batch, -1 if self.output_targets[target_name]["level"] == "atom": @@ -316,6 +324,14 @@ def _compute_loss(self, out, batch): else: target = target.view(batch_size, -1) + # to keep the loss coefficient weights balanced we remove linear references + # subtract element references from target data + if target_name in self.elementrefs: + target = self.elementrefs[target_name].dereference(target, batch) + # normalize the targets data + if target_name in self.normalizers: + target = self.normalizers[target_name].norm(target) + mult = loss_info["coefficient"] loss.append( mult @@ -373,11 +389,8 @@ def _compute_metrics(self, out, batch, evaluator, metrics=None): else: target = target.view(batch_size, -1) + out[target_name] = self._denorm_preds(target_name, out[target_name], batch) targets[target_name] = target - if self.normalizers.get(target_name, False): - out[target_name] = self.normalizers[target_name].denorm( - out[target_name] - ) targets["natoms"] = natoms out["natoms"] = natoms @@ -385,7 +398,7 @@ def _compute_metrics(self, out, batch, evaluator, metrics=None): return evaluator.eval(out, targets, prev_metrics=metrics) # Takes in a new data source and generates predictions 
on it. - @torch.no_grad() + @torch.no_grad def predict( self, data_loader, @@ -419,7 +432,7 @@ def predict( predictions = defaultdict(list) - for _i, batch in tqdm( + for _, batch in tqdm( enumerate(data_loader), total=len(data_loader), position=rank, @@ -430,9 +443,7 @@ def predict( out = self._forward(batch) for target_key in self.config["outputs"]: - pred = out[target_key] - if self.normalizers.get(target_key, False): - pred = self.normalizers[target_key].denorm(pred) + pred = self._denorm_preds(target_key, out[target_key], batch) if per_image: ### Save outputs in desired precision, default float16 @@ -449,7 +460,8 @@ def predict( else: dtype = torch.float16 - pred = pred.cpu().detach().to(dtype) + pred = pred.detach().cpu().to(dtype) + ### Split predictions into per-image predictions if self.config["outputs"][target_key]["level"] == "atom": batch_natoms = batch.natoms @@ -510,6 +522,7 @@ def predict( return predictions + @torch.no_grad def run_relaxations(self, split="val"): ensure_fitted(self._unwrapped_model) @@ -642,9 +655,7 @@ def run_relaxations(self, split="val"): ) gather_results["chunk_idx"] = np.cumsum( [gather_results["chunk_idx"][i] for i in idx] - )[ - :-1 - ] # np.split does not need last idx, assumes n-1:end + )[:-1] # np.split does not need last idx, assumes n-1:end full_path = os.path.join( self.config["cmd"]["results_dir"], "relaxed_positions.npz" diff --git a/tests/core/e2e/test_s2ef.py b/tests/core/e2e/test_s2ef.py index 1584becd45..54055d0c3b 100644 --- a/tests/core/e2e/test_s2ef.py +++ b/tests/core/e2e/test_s2ef.py @@ -7,19 +7,20 @@ from pathlib import Path import numpy as np +import numpy.testing as npt import pytest import yaml +from tensorboard.backend.event_processing.event_accumulator import EventAccumulator + +from fairchem.core._cli import Runner +from fairchem.core.common.flags import flags from fairchem.core.common.test_utils import ( PGConfig, init_env_rank_and_launch_test, spawn_multi_process, ) -from 
fairchem.core.scripts.make_lmdb_sizes import get_lmdb_sizes_parser, make_lmdb_sizes -from tensorboard.backend.event_processing.event_accumulator import EventAccumulator - -from fairchem.core._cli import Runner -from fairchem.core.common.flags import flags from fairchem.core.common.utils import build_config, setup_logging +from fairchem.core.scripts.make_lmdb_sizes import get_lmdb_sizes_parser, make_lmdb_sizes setup_logging() @@ -66,21 +67,56 @@ def tutorial_val_src(tutorial_dataset_path): return tutorial_dataset_path / "s2ef/val_20" -def oc20_lmdb_train_and_val_from_paths(train_src, val_src, test_src=None): +def oc20_lmdb_train_and_val_from_paths( + train_src, val_src, test_src=None, otf_norms=False +): datasets = {} if train_src is not None: datasets["train"] = { "src": train_src, - "normalize_labels": True, - "target_mean": -0.7554450631141663, - "target_std": 2.887317180633545, - "grad_target_mean": 0.0, - "grad_target_std": 2.887317180633545, + "format": "lmdb", + "key_mapping": {"y": "energy", "force": "forces"}, } + if otf_norms is True: + datasets["train"].update( + { + "transforms": { + "element_references": { + "fit": { + "targets": ["energy"], + "batch_size": 4, + "num_batches": 10, + "driver": "gelsd", + } + }, + "normalizer": { + "fit": { + "targets": {"energy": None, "forces": {"mean": 0.0}}, + "batch_size": 4, + "num_batches": 10, + } + }, + } + } + ) + else: + datasets["train"].update( + { + "transforms": { + "normalizer": { + "energy": { + "mean": -0.7554450631141663, + "stdev": 2.887317180633545, + }, + "forces": {"mean": 0.0, "stdev": 2.887317180633545}, + } + } + } + ) if val_src is not None: - datasets["val"] = {"src": val_src} + datasets["val"] = {"src": val_src, "format": "lmdb"} if test_src is not None: - datasets["test"] = {"src": test_src} + datasets["test"] = {"src": test_src, "format": "lmdb"} return datasets @@ -124,7 +160,6 @@ def _run_main( yaml_config["backend"] = "gloo" with open(str(config_yaml), "w") as yaml_file: 
yaml.dump(yaml_config, yaml_file) - run_args = { "run_dir": rundir, "logdir": f"{rundir}/logs", @@ -168,11 +203,6 @@ def _run_main( ) -@pytest.fixture(scope="class") -def torch_tempdir(tmpdir_factory): - return tmpdir_factory.mktemp("torch_tempdir") - - """ These tests are intended to be as quick as possible and test only that the network is runnable and outputs training+validation to tensorboard output These should catch errors such as shape mismatches or otherways to code wise break a network @@ -180,12 +210,7 @@ def torch_tempdir(tmpdir_factory): class TestSmoke: - def smoke_test_train( - self, - model_name, - input_yaml, - tutorial_val_src, - ): + def smoke_test_train(self, input_yaml, tutorial_val_src, otf_norms=False): with tempfile.TemporaryDirectory() as tempdirname: # first train a very simple model, checkpoint train_rundir = Path(tempdirname) / "train" @@ -201,6 +226,7 @@ def smoke_test_train( train_src=str(tutorial_val_src), val_src=str(tutorial_val_src), test_src=str(tutorial_val_src), + otf_norms=otf_norms, ), }, save_checkpoint_to=checkpoint_path, @@ -222,6 +248,7 @@ def smoke_test_train( train_src=str(tutorial_val_src), val_src=str(tutorial_val_src), test_src=str(tutorial_val_src), + otf_norms=otf_norms, ), }, update_run_args_with={ @@ -231,42 +258,65 @@ def smoke_test_train( save_predictions_to=predictions_filename, ) + if otf_norms is True: + norm_path = glob.glob( + str(train_rundir / "checkpoints" / "*" / "normalizers.pt") + ) + assert len(norm_path) == 1 + assert os.path.isfile(norm_path[0]) + ref_path = glob.glob( + str(train_rundir / "checkpoints" / "*" / "element_references.pt") + ) + assert len(ref_path) == 1 + assert os.path.isfile(ref_path[0]) + # verify predictions from train and predict are identical energy_from_train = np.load(training_predictions_filename)["energy"] energy_from_checkpoint = np.load(predictions_filename)["energy"] - assert np.isclose(energy_from_train, energy_from_checkpoint).all() + npt.assert_allclose( + 
energy_from_train, energy_from_checkpoint, rtol=1e-6, atol=1e-6 + ) + # not all models are tested with otf normalization estimation + # only gemnet_oc, escn, equiformer, and their hydra versions @pytest.mark.parametrize( - "model_name", + ("model_name", "otf_norms"), [ - pytest.param("schnet", id="schnet"), - pytest.param("scn", id="scn"), - pytest.param("gemnet_dt", id="gemnet_dt"), - pytest.param("gemnet_dt_hydra", id="gemnet_dt_hydra"), - pytest.param("gemnet_dt_hydra_grad", id="gemnet_dt_hydra_grad"), - pytest.param("gemnet_oc", id="gemnet_oc"), - pytest.param("gemnet_oc_hydra", id="gemnet_oc_hydra"), - pytest.param("gemnet_oc_hydra_grad", id="gemnet_oc_hydra_grad"), - pytest.param("dimenet++", id="dimenet++"), - pytest.param("dimenet++_hydra", id="dimenet++_hydra"), - pytest.param("painn", id="painn"), - pytest.param("painn_hydra", id="painn_hydra"), - pytest.param("escn", id="escn"), - pytest.param("escn_hydra", id="escn_hydra"), - pytest.param("equiformer_v2", id="equiformer_v2"), - pytest.param("equiformer_v2_hydra", id="equiformer_v2_hydra"), + ("schnet", False), + ("scn", False), + ("gemnet_dt", False), + ("gemnet_dt_hydra", False), + ("gemnet_dt_hydra_grad", False), + ("gemnet_oc", False), + ("gemnet_oc", True), + ("gemnet_oc_hydra", False), + ("gemnet_oc_hydra", True), + ("gemnet_oc_hydra_grad", False), + ("dimenet++", False), + ("dimenet++_hydra", False), + ("painn", False), + ("painn_hydra", False), + ("escn", False), + ("escn", True), + ("escn_hydra", False), + ("escn_hydra", True), + ("equiformer_v2", False), + ("equiformer_v2", True), + ("equiformer_v2_hydra", False), + ("equiformer_v2_hydra", True), ], ) def test_train_and_predict( self, model_name, + otf_norms, configs, tutorial_val_src, ): self.smoke_test_train( - model_name=model_name, input_yaml=configs[model_name], tutorial_val_src=tutorial_val_src, + otf_norms=otf_norms, ) @pytest.mark.parametrize( @@ -307,7 +357,6 @@ def test_ddp(self, world_size, ddp, configs, tutorial_val_src, 
torch_determinist def test_balanced_batch_sampler_ddp( self, world_size, ddp, configs, tutorial_val_src, torch_deterministic ): - # make dataset metadata parser = get_lmdb_sizes_parser() args, override_args = parser.parse_known_args( diff --git a/tests/core/models/test_configs/test_equiformerv2.yml b/tests/core/models/test_configs/test_equiformerv2.yml index 54d5e61c95..8c5c200fdf 100644 --- a/tests/core/models/test_configs/test_equiformerv2.yml +++ b/tests/core/models/test_configs/test_equiformerv2.yml @@ -1,6 +1,53 @@ +trainer: forces + +logger: + name: tensorboard +outputs: + energy: + shape: 1 + level: system + forces: + irrep_dim: 1 + level: atom + train_on_free_atoms: True + eval_on_free_atoms: True + +loss_functions: + - energy: + fn: mae + coefficient: 2 + - forces: + fn: l2mae + coefficient: 100 + +evaluation_metrics: + metrics: + energy: + - mae + forces: + - mae + - cosine_similarity + - magnitude_error + misc: + - energy_forces_within_threshold + primary_metric: forces_mae -trainer: forces +optim: + batch_size: 5 + eval_batch_size: 2 + num_workers: 0 + lr_initial: 0.0025 + optimizer: AdamW + optimizer_params: {"amsgrad": True,weight_decay: 0.0} + eval_every: 190 + max_epochs: 50 + force_coefficient: 20 + scheduler: "Null" + energy_coefficient: 1 + clip_grad_norm: 20 + loss_energy: mae + loss_force: l2mae model: name: equiformer_v2 @@ -45,47 +92,3 @@ model: proj_drop: 0.0 weight_init: 'normal' # ['uniform', 'normal'] - -dataset: - train: - src: tutorial_dset/s2ef/train_100/ - normalize_labels: True - target_mean: -0.7554450631141663 - target_std: 2.887317180633545 - grad_target_mean: 0.0 - grad_target_std: 2.887317180633545 - val: - format: lmdb - src: tutorial_dset/s2ef/val_20/ - -logger: - name: tensorboard - -task: - dataset: lmdb - type: regression - metric: mae - primary_metric: forces_mae - labels: - - potential energy - grad_input: atomic forces - train_on_free_atoms: True - eval_on_free_atoms: True - prediction_dtype: float32 - - -optim: - 
batch_size: 5 - eval_batch_size: 2 - num_workers: 0 - lr_initial: 0.0025 - optimizer: AdamW - optimizer_params: {"amsgrad": True,weight_decay: 0.0} - eval_every: 190 - max_epochs: 50 - force_coefficient: 20 - scheduler: "Null" - energy_coefficient: 1 - clip_grad_norm: 20 - loss_energy: mae - loss_force: l2mae diff --git a/tests/core/models/test_configs/test_escn.yml b/tests/core/models/test_configs/test_escn.yml index 5148e409e5..5848587cdd 100644 --- a/tests/core/models/test_configs/test_escn.yml +++ b/tests/core/models/test_configs/test_escn.yml @@ -1,31 +1,37 @@ trainer: forces -dataset: - train: - src: tutorial_dset/s2ef/train_100/ - normalize_labels: True - target_mean: -0.7554450631141663 - target_std: 2.887317180633545 - grad_target_mean: 0.0 - grad_target_std: 2.887317180633545 - val: - format: lmdb - src: tutorial_dset/s2ef/val_20/ - logger: name: tensorboard -task: - dataset: lmdb - type: regression - metric: mae +outputs: + energy: + shape: 1 + level: system + forces: + irrep_dim: 1 + level: atom + train_on_free_atoms: True + eval_on_free_atoms: True + +loss_functions: + - energy: + fn: mae + coefficient: 2 + - forces: + fn: l2mae + coefficient: 100 + +evaluation_metrics: + metrics: + energy: + - mae + forces: + - mae + - cosine_similarity + - magnitude_error + misc: + - energy_forces_within_threshold primary_metric: forces_mae - labels: - - potential energy - grad_input: atomic forces - train_on_free_atoms: True - eval_on_free_atoms: True - prediction_dtype: float32 model: name: escn diff --git a/tests/core/models/test_configs/test_gemnet_oc.yml b/tests/core/models/test_configs/test_gemnet_oc.yml index a720583608..f1c0d01c3a 100644 --- a/tests/core/models/test_configs/test_gemnet_oc.yml +++ b/tests/core/models/test_configs/test_gemnet_oc.yml @@ -1,34 +1,37 @@ - - - trainer: forces -dataset: - train: - src: tutorial_dset/s2ef/train_100/ - normalize_labels: True - target_mean: -0.7554450631141663 - target_std: 2.887317180633545 - grad_target_mean: 0.0 - 
grad_target_std: 2.887317180633545 - val: - format: lmdb - src: tutorial_dset/s2ef/val_20/ - logger: name: tensorboard -task: - dataset: lmdb - type: regression - metric: mae +outputs: + energy: + shape: 1 + level: system + forces: + irrep_dim: 1 + level: atom + train_on_free_atoms: True + eval_on_free_atoms: True + +loss_functions: + - energy: + fn: mae + coefficient: 2 + - forces: + fn: l2mae + coefficient: 100 + +evaluation_metrics: + metrics: + energy: + - mae + forces: + - mae + - cosine_similarity + - magnitude_error + misc: + - energy_forces_within_threshold primary_metric: forces_mae - labels: - - potential energy - grad_input: atomic forces - train_on_free_atoms: True - eval_on_free_atoms: True - prediction_dtype: float32 model: name: gemnet_oc diff --git a/tests/core/modules/conftest.py b/tests/core/modules/conftest.py new file mode 100644 index 0000000000..1b1e4ab7e6 --- /dev/null +++ b/tests/core/modules/conftest.py @@ -0,0 +1,48 @@ +from itertools import product +from random import choice +import pytest +import numpy as np +from pymatgen.core.periodic_table import Element +from pymatgen.core import Structure + +from fairchem.core.datasets import LMDBDatabase, AseDBDataset + + +@pytest.fixture(scope="session") +def dummy_element_refs(): + # create some dummy elemental energies from ionic radii (ignore deuterium and tritium included in pmg) + return np.concatenate( + [[0], [e.average_ionic_radius for e in Element if e.name not in ("D", "T")]] + ) + + +@pytest.fixture(scope="session") +def max_num_elements(dummy_element_refs): + return len(dummy_element_refs) - 1 + + +@pytest.fixture(scope="session") +def dummy_binary_dataset(tmpdir_factory, dummy_element_refs): + # a dummy dataset with binaries with energy that depends on composition only plus noise + all_binaries = list(product(list(Element), repeat=2)) + rng = np.random.default_rng(seed=0) + + tmpdir = tmpdir_factory.mktemp("dataset") + with LMDBDatabase(tmpdir / "dummy.aselmdb") as db: + for _ in 
range(1000): + elements = choice(all_binaries) + structure = Structure.from_prototype("cscl", species=elements, a=2.0) + energy = ( + sum(e.average_ionic_radius for e in elements) + + 0.05 * rng.random() * dummy_element_refs.mean() + ) + atoms = structure.to_ase_atoms() + db.write(atoms, data={"energy": energy, "forces": rng.random((2, 3))}) + + dataset = AseDBDataset( + config={ + "src": str(tmpdir / "dummy.aselmdb"), + "a2g_args": {"r_data_keys": ["energy", "forces"]}, + } + ) + return dataset diff --git a/tests/core/modules/test_element_references.py b/tests/core/modules/test_element_references.py new file mode 100644 index 0000000000..62928b623c --- /dev/null +++ b/tests/core/modules/test_element_references.py @@ -0,0 +1,159 @@ +from __future__ import annotations + +import numpy as np +import numpy.testing as npt +import pytest +import torch + +from fairchem.core.datasets import data_list_collater +from fairchem.core.modules.normalization.element_references import ( + LinearReferences, + create_element_references, + fit_linear_references, +) + + +@pytest.fixture(scope="session", params=(True, False)) +def element_refs(dummy_binary_dataset, max_num_elements, request): + return fit_linear_references( + ["energy"], + dataset=dummy_binary_dataset, + batch_size=16, + shuffle=False, + max_num_elements=max_num_elements, + seed=0, + use_numpy=request.param, + ) + + +def test_apply_linear_references( + element_refs, dummy_binary_dataset, dummy_element_refs +): + max_noise = 0.05 * dummy_element_refs.mean() + + # check that removing element refs keeps only values within max noise + batch = data_list_collater(list(dummy_binary_dataset), otf_graph=True) + energy = batch.energy.clone().view(len(batch), -1) + deref_energy = element_refs["energy"].dereference(energy, batch) + assert all(deref_energy <= max_noise) + + # and check that we recover the total energy from applying references + ref_energy = element_refs["energy"](deref_energy, batch) + assert 
torch.allclose(ref_energy, energy) + + +def test_create_element_references(element_refs, tmp_path): + # test from state dict + sdict = element_refs["energy"].state_dict() + + refs = create_element_references(state_dict=sdict) + assert isinstance(refs, LinearReferences) + assert torch.allclose( + element_refs["energy"].element_references, refs.element_references + ) + + # test from saved stated dict + torch.save(sdict, tmp_path / "linref.pt") + refs = create_element_references(file=tmp_path / "linref.pt") + assert isinstance(refs, LinearReferences) + assert torch.allclose( + element_refs["energy"].element_references, refs.element_references + ) + + # from a legacy numpy npz file + np.savez( + tmp_path / "linref.npz", coeff=element_refs["energy"].element_references.numpy() + ) + refs = create_element_references(file=tmp_path / "linref.npz") + assert isinstance(refs, LinearReferences) + assert torch.allclose( + element_refs["energy"].element_references, refs.element_references + ) + + # from a numpy npz file + np.savez( + tmp_path / "linref.npz", + element_references=element_refs["energy"].element_references.numpy(), + ) + + refs = create_element_references(file=tmp_path / "linref.npz") + assert isinstance(refs, LinearReferences) + assert torch.allclose( + element_refs["energy"].element_references, refs.element_references + ) + + +def test_fit_linear_references( + element_refs, dummy_binary_dataset, max_num_elements, dummy_element_refs +): + # create the composition matrix + energy = np.array([d.energy for d in dummy_binary_dataset]) + cmatrix = np.vstack( + [ + np.bincount(d.atomic_numbers.int().numpy(), minlength=max_num_elements + 1) + for d in dummy_binary_dataset + ] + ) + mask = cmatrix.sum(axis=0) != 0.0 + + # fit using numpy + element_refs_np = np.zeros(max_num_elements + 1) + element_refs_np[mask] = np.linalg.lstsq(cmatrix[:, mask], energy, rcond=None)[0] + + # length is max_num_elements + 1, since H starts at 1 + assert 
len(element_refs["energy"].element_references) == max_num_elements + 1 + # first element is dummy, should always be zero + assert element_refs["energy"].element_references[0] == 0.0 + # elements not present should be zero + npt.assert_allclose(element_refs["energy"].element_references.numpy()[~mask], 0.0) + # torch fit vs numpy fit + npt.assert_allclose( + element_refs_np, element_refs["energy"].element_references.numpy(), atol=1e-5 + ) + # close enough to ground truth w/out noise + npt.assert_allclose( + dummy_element_refs[mask], + element_refs["energy"].element_references.numpy()[mask], + atol=5e-2, + ) + + +def test_fit_seed_no_seed(dummy_binary_dataset, max_num_elements): + refs_seed = fit_linear_references( + ["energy"], + dataset=dummy_binary_dataset, + batch_size=16, + num_batches=len(dummy_binary_dataset) // 16 - 2, + shuffle=True, + max_num_elements=max_num_elements, + seed=0, + ) + refs_seed1 = fit_linear_references( + ["energy"], + dataset=dummy_binary_dataset, + batch_size=16, + num_batches=len(dummy_binary_dataset) // 16 - 2, + shuffle=True, + max_num_elements=max_num_elements, + seed=0, + ) + refs_noseed = fit_linear_references( + ["energy"], + dataset=dummy_binary_dataset, + batch_size=16, + num_batches=len(dummy_binary_dataset) // 16 - 2, + shuffle=True, + max_num_elements=max_num_elements, + seed=1, + ) + + assert torch.allclose( + refs_seed["energy"].element_references, + refs_seed1["energy"].element_references, + atol=1e-6, + ) + assert not torch.allclose( + refs_seed["energy"].element_references, + refs_noseed["energy"].element_references, + atol=1e-6, + ) diff --git a/tests/core/modules/test_normalizer.py b/tests/core/modules/test_normalizer.py new file mode 100644 index 0000000000..b0d4a44040 --- /dev/null +++ b/tests/core/modules/test_normalizer.py @@ -0,0 +1,98 @@ +from __future__ import annotations + +import numpy as np +import pytest +import torch + +from fairchem.core.datasets import data_list_collater +from 
@pytest.fixture(scope="session")
def normalizers(dummy_binary_dataset):
    """Fit energy and forces normalizers once per session on the dummy dataset.

    The forces mean is overridden to ``0.0`` rather than fitted.
    """
    fit_options = {
        "override_values": {"forces": {"mean": 0.0}},
        "dataset": dummy_binary_dataset,
        "batch_size": 16,
        "shuffle": False,
    }
    return fit_normalizers(["energy", "forces"], **fit_options)


def test_norm_denorm(normalizers, dummy_binary_dataset, dummy_element_refs):
    """Check ``norm`` and calling the normalizer against the explicit formulas."""
    batch = data_list_collater(list(dummy_binary_dataset), otf_graph=True)
    # test norm and denorm
    for target, normalizer in normalizers.items():
        raw = batch[target]
        normed = normalizer.norm(raw)

        # norm: subtract the mean, divide by the rmsd
        expected_normed = (raw - normalizer.mean) / normalizer.rmsd
        assert torch.allclose(expected_normed, normed)

        # calling the normalizer is expected to invert the normalization
        expected_denormed = normalizer.rmsd * normed + normalizer.mean
        assert torch.allclose(expected_denormed, normalizer(normed))


def test_create_normalizers(normalizers, dummy_binary_dataset, tmp_path):
    """Build a ``Normalizer`` from each supported source and compare.

    Sources exercised, in order: a state dict, a saved state dict, a legacy
    ``.npz`` file (``mean``/``std``), a current ``.npz`` file
    (``mean``/``rmsd``), a raw tensor, and explicit ``mean``/``rmsd`` values.
    Passing only ``mean`` must raise ``ValueError``.
    """
    # test that forces mean was overridden
    assert normalizers["forces"].mean.item() == 0.0

    energy_norm = normalizers["energy"]
    sdict = energy_norm.state_dict()

    def _assert_exact(candidate):
        # Loaded normalizers must reproduce the reference state dict exactly.
        assert isinstance(candidate, Normalizer)
        assert candidate.state_dict() == sdict

    def _assert_close(candidate):
        # Exact state-dict equality was observed to fail for normalizers
        # computed from data (cause not yet understood), so compare each
        # entry with a tolerance instead.
        assert isinstance(candidate, Normalizer)
        new_sdict = candidate.state_dict()
        for key in sdict:
            assert torch.allclose(new_sdict[key], sdict[key])

    # test from state dict
    _assert_exact(create_normalizer(state_dict=sdict))

    # test from saved state dict
    pt_file = tmp_path / "norm.pt"
    torch.save(sdict, pt_file)
    _assert_exact(create_normalizer(file=pt_file))

    # from a legacy numpy npz file (rmsd stored under "std")
    npz_file = tmp_path / "norm.npz"
    np.savez(
        npz_file,
        mean=energy_norm.mean.numpy(),
        std=energy_norm.rmsd.numpy(),
    )
    _assert_exact(create_normalizer(file=npz_file))

    # from a new npz file; this overwrites the legacy file written above
    np.savez(
        npz_file,
        mean=energy_norm.mean.numpy(),
        rmsd=energy_norm.rmsd.numpy(),
    )
    _assert_exact(create_normalizer(file=npz_file))

    # from tensor directly
    batch = data_list_collater(list(dummy_binary_dataset), otf_graph=True)
    _assert_close(create_normalizer(tensor=batch.energy))

    # passing values directly
    _assert_close(
        create_normalizer(
            mean=batch.energy.mean().item(), rmsd=batch.energy.std().item()
        )
    )

    # bad construction
    with pytest.raises(ValueError):
        create_normalizer(mean=1.0)