(OTF) Normalization and element references #715

Merged
merged 94 commits on Aug 5, 2024

Changes from 92 commits
94 commits
4f0d91a
denorm targets in _forward only
lbluque May 20, 2024
c5e997b
linear reference class
lbluque May 20, 2024
03a3f66
atomref in normalizer
lbluque May 21, 2024
57174cc
raise input error
lbluque May 21, 2024
80d71c4
clean up normalizer interface
lbluque May 21, 2024
2219d2c
add element refs
lbluque May 22, 2024
390a19e
add element refs correctly
lbluque May 22, 2024
fb99a52
ruff
lbluque May 22, 2024
bc6b864
fix save_checkpoint
lbluque May 22, 2024
2a7804f
reference and dereference
lbluque May 23, 2024
c2914f4
2xnorm linref trainer add
lbluque May 23, 2024
8e4f491
clean-up
lbluque May 23, 2024
578a73f
otf linear reference fit
lbluque May 24, 2024
7607591
fix tensor device
lbluque May 24, 2024
64eb32d
otf element references and normalizers
lbluque May 24, 2024
8de9a30
use only present elements when fitting
lbluque May 24, 2024
caad844
lint
lbluque May 24, 2024
75e72b1
_forward norm and derefd values
lbluque May 24, 2024
e944a06
Merge branch 'main' into norms-and-refs
lbluque Jun 20, 2024
ad36406
fix list of paths in src
lbluque Jun 25, 2024
27b9e7f
total mean and std
lbluque Jun 25, 2024
0ca227a
fitted flag to avoid refitting normalizers/references on rerun
lbluque Jun 25, 2024
d2af7c9
allow passing lstsq driver
lbluque Jun 25, 2024
2913e12
Merge branch 'main' into norms-and-refs
lbluque Jun 25, 2024
4295330
element ref unit tests
lbluque Jun 26, 2024
75f3a51
remove superfluous type
lbluque Jun 26, 2024
d7b4a98
lint fix
lbluque Jun 26, 2024
c26362f
Merge branch 'main' of https://github.com/FAIR-Chem/fairchem into nor…
lbluque Jun 26, 2024
029e1db
allow setting batch_size explicitly
lbluque Jun 26, 2024
143d7a6
test applying element refs
lbluque Jun 26, 2024
d6b5925
normalizer tests
lbluque Jun 26, 2024
c6fbf80
increase distributed timeout
lbluque Jun 27, 2024
6c1d3c3
save normalizers and linear refs in otf_fit
lbluque Jun 27, 2024
8b39de1
remove debug code
lbluque Jun 27, 2024
e82898b
fix removing refs
lbluque Jun 28, 2024
afe944f
swap otf_fit for fit, and save all normalizers in one file
lbluque Jun 28, 2024
8d21eb2
log loading and saving normalizers
lbluque Jun 28, 2024
e5af2de
fit references and normalizer scripts
lbluque Jun 28, 2024
9f4d32d
Merge branch 'main' into norms-and-refs
lbluque Jun 28, 2024
072c8a1
Merge branch 'norms-and-refs' of https://github.com/FAIR-Chem/fairche…
lbluque Jun 28, 2024
8072bc6
lint fixes
lbluque Jun 28, 2024
90c2cc7
allow absent optim key in config
lbluque Jul 8, 2024
c4aaceb
Merge branch 'main' into norms-and-refs
lbluque Jul 8, 2024
1bad0d2
Merge branch 'norms-and-refs' of https://github.com/FAIR-Chem/fairche…
lbluque Jul 8, 2024
410f711
Merge branch 'main' into norms-and-refs
lbluque Jul 10, 2024
d27b555
Merge branch 'norms-and-refs' of https://github.com/FAIR-Chem/fairche…
lbluque Jul 11, 2024
ca4f9ce
lin-ref description
lbluque Jul 11, 2024
62a4ff2
read files based on extension
lbluque Jul 11, 2024
ea17989
pass seed
lbluque Jul 11, 2024
1fcc354
rename dataset fixture
lbluque Jul 11, 2024
ebc5c87
check if file is none
lbluque Jul 11, 2024
1c457f7
pass generator correctly
lbluque Jul 11, 2024
200f62a
separate method for norms and refs
lbluque Jul 11, 2024
0c7b2e6
add normalizer code back
lbluque Jul 11, 2024
ec1b25e
fix Generator construction
lbluque Jul 11, 2024
2e7fa0a
import order
lbluque Jul 11, 2024
2e71534
log warnings if multiple inputs are passed
lbluque Jul 11, 2024
8ec9e54
raise Error if duplicate references or norms are set
lbluque Jul 11, 2024
2b2bd30
use len batch
lbluque Jul 11, 2024
92ac259
assert element reference targets are scalar
lbluque Jul 11, 2024
e0f921c
fix name and rename method
lbluque Jul 11, 2024
4ec71e5
load and save norms and refs using same logic
lbluque Jul 11, 2024
6356d6b
fix creating normalizer
lbluque Jul 12, 2024
efb6c1a
Merge branch 'main' of https://github.com/FAIR-Chem/fairchem into nor…
lbluque Jul 12, 2024
30db452
remove print statements
lbluque Jul 12, 2024
5743a59
adding new notebook for using fairchem models with NEBs without CatTS…
brookwander Jul 16, 2024
18a5fe8
warn instead of error when duplicate norm/ref target names
lbluque Jul 16, 2024
7a2f3c9
allow timeout to be read from config
lbluque Jul 16, 2024
661efea
test seed noseed ref fits
lbluque Jul 16, 2024
079d042
merge upstream
lbluque Jul 16, 2024
462257e
lotsa refactoring
lbluque Jul 16, 2024
8b1288d
lotsa fixing
lbluque Jul 17, 2024
fe555a8
more fixing...
lbluque Jul 17, 2024
074521f
num_workers zero to prevent mp issues
lbluque Jul 17, 2024
17ae426
add otf norms smoke test and fixes
lbluque Jul 18, 2024
6c3b20f
allow overriding normalization fit values
lbluque Jul 19, 2024
8d6e306
update tests
lbluque Jul 19, 2024
0d976d0
fix normalizer loading
lbluque Jul 19, 2024
f27aa3e
Merge branch 'main' into norms-and-refs
mshuaibii Jul 19, 2024
7b8af9e
use rmsd instead of only stdev
lbluque Jul 19, 2024
91364a3
fix tests
lbluque Jul 19, 2024
e6c2252
correct rmsd calc and fix loading
lbluque Jul 19, 2024
02ebac2
clean up norm loading and log values
lbluque Jul 19, 2024
72a5f50
logg linear reference metrics
lbluque Jul 19, 2024
b82b0dc
load element references state dict
lbluque Jul 19, 2024
c2896a5
fix loading and tests
lbluque Jul 20, 2024
f7bae74
fix imports in scripts
lbluque Jul 22, 2024
a8ebdbd
fix test?
lbluque Jul 22, 2024
71d19da
fix test
lbluque Jul 23, 2024
6875b93
use numpy as default to fit references
lbluque Jul 23, 2024
998c401
minor fixes
lbluque Aug 1, 2024
86b65e9
Merge branch 'main' into norms-and-refs
lbluque Aug 2, 2024
b4f3096
merge upstream
lbluque Aug 2, 2024
f35ce82
rm torch_tempdir fixture
lbluque Aug 2, 2024
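Several of the commits above (linear reference class, otf linear reference fit, allow passing lstsq driver, use rmsd instead of only stdev) revolve around fitting per-element linear references by least squares and then normalizing the referenced targets. The snippet below is only an illustrative sketch of that idea with made-up numbers, not the code added in this PR:

import numpy as np

# Toy composition matrix: each row counts the atoms of each element in one structure.
composition = np.array(
    [
        [2, 1, 0],
        [0, 1, 2],
        [1, 1, 1],
        [4, 0, 2],
    ],
    dtype=float,
)
energies = np.array([-14.2, -21.7, -18.9, -35.4])  # made-up total energies

# Linear element references: per-element energies from a least-squares fit.
ref_coeffs, *_ = np.linalg.lstsq(composition, energies, rcond=None)

# Referenced target = raw energy minus the composition-weighted reference.
referenced = energies - composition @ ref_coeffs

# Normalization of the referenced values, using the RMSD rather than only the stdev.
mean = referenced.mean()
rmsd = np.sqrt(np.mean(referenced**2))
normalized = (referenced - mean) / rmsd
print(ref_coeffs, mean, rmsd, normalized)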
16 changes: 14 additions & 2 deletions src/fairchem/core/common/distutils.py
@@ -10,7 +10,8 @@
 import logging
 import os
 import subprocess
-from typing import TypeVar
+from datetime import timedelta
+from typing import Any, TypeVar
 
 import torch
 import torch.distributed as dist
@@ -27,6 +28,7 @@ def os_environ_get_or_throw(x: str) -> str:
 
 
 def setup(config) -> None:
+    timeout = timedelta(minutes=config.get("timeout", 30))
     if config["submit"]:
         node_list = os.environ.get("SLURM_STEP_NODELIST")
         if node_list is None:
@@ -72,6 +74,7 @@ def setup(config) -> None:
                 init_method=config["init_method"],
                 world_size=config["world_size"],
                 rank=config["rank"],
+                timeout=timeout,
             )
         except subprocess.CalledProcessError as e: # scontrol failed
             raise e
@@ -95,10 +98,11 @@ def setup(config) -> None:
             rank=world_rank,
             world_size=world_size,
             init_method="env://",
+            timeout=timeout,
         )
     else:
         config["local_rank"] = int(os.environ.get("LOCAL_RANK", config["local_rank"]))
-        dist.init_process_group(backend="nccl")
+        dist.init_process_group(backend="nccl", timeout=timeout)
 
 
 def cleanup() -> None:
@@ -135,6 +139,14 @@ def broadcast(
     dist.broadcast(tensor, src, group, async_op)
 
 
+def broadcast_object_list(
+    object_list: list[Any], src: int, group=dist.group.WORLD, device: str | None = None
+) -> None:
+    if get_world_size() == 1:
+        return
+    dist.broadcast_object_list(object_list, src, group, device)
+
+
 def all_reduce(
     data, group=dist.group.WORLD, average: bool = False, device=None
 ) -> torch.Tensor:
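The distributed setup now threads a configurable timeout (in minutes, defaulting to 30) into every init_process_group call and adds a broadcast_object_list helper that no-ops for single-process runs. Below is a minimal, self-contained sketch of how those pieces fit together; the gloo backend and the hard-coded single-process environment variables are stand-ins for the launcher-provided NCCL setup, and the config dict is illustrative:

import os
from datetime import timedelta

import torch.distributed as dist

# Illustrative config; in the trainer this value comes from the YAML config.
config = {"timeout": 60}  # minutes; config.get("timeout", 30) falls back to 30
timeout = timedelta(minutes=config.get("timeout", 30))

# Single-process setup so the example runs without a distributed launcher.
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group(
    backend="gloo", rank=0, world_size=1, init_method="env://", timeout=timeout
)

# The broadcast_object_list wrapper above exists so rank 0 can share fitted
# normalizers/element references; with a world size of 1 it simply returns.
payload = [{"energy_mean": 0.0}]  # hypothetical fitted statistics
dist.broadcast_object_list(payload, src=0)

dist.destroy_process_group()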
17 changes: 9 additions & 8 deletions src/fairchem/core/datasets/ase_datasets.py
@@ -13,7 +13,7 @@
 import os
 import warnings
 from abc import ABC, abstractmethod
-from functools import cache, reduce
+from functools import cache
 from glob import glob
 from pathlib import Path
 from typing import Any, Callable
@@ -469,13 +469,14 @@ class AseDBDataset(AseAtomsDataset):
 
     def _load_dataset_get_ids(self, config: dict) -> list[int]:
         if isinstance(config["src"], list):
-            if os.path.isdir(config["src"][0]):
-                filepaths = reduce(
-                    lambda x, y: x + y,
-                    (glob(f"{path}/*") for path in config["src"]),
-                )
-            else:
-                filepaths = config["src"]
+            filepaths = []
+            for path in config["src"]:
+                if os.path.isdir(path):
+                    filepaths.extend(glob(f"{path}/*"))
+                elif os.path.isfile(path):
+                    filepaths.append(path)
+                else:
+                    raise RuntimeError(f"Error reading dataset in {path}!")
         elif os.path.isfile(config["src"]):
             filepaths = [config["src"]]
         elif os.path.isdir(config["src"]):
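With this change each entry of a list-valued src is inspected on its own, so directories and individual database files can be mixed and a bad path fails loudly. A hypothetical config illustrating the accepted shapes; the paths, and the assumption that AseDBDataset is importable from fairchem.core.datasets, are made up for the example:

from fairchem.core.datasets import AseDBDataset  # import path assumed

config = {
    "src": [
        "/data/asedbs/train",            # a directory: expanded via glob("<dir>/*")
        "/data/asedbs/extra_frames.db",  # a single file: appended as-is
    ],
}
dataset = AseDBDataset(config)  # raises RuntimeError if an entry is neither a file nor a directory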
Empty file.
113 changes: 113 additions & 0 deletions src/fairchem/core/modules/normalization/_load_utils.py
@@ -0,0 +1,113 @@
"""
Copyright (c) Meta, Inc. and its affiliates.

This source code is licensed under the MIT license found in the
LICENSE file in the root directory of this source tree.
"""

from __future__ import annotations

import logging
from typing import TYPE_CHECKING, Any, Callable

import torch

from fairchem.core.common.utils import save_checkpoint

if TYPE_CHECKING:
from pathlib import Path

from torch.nn import Module
from torch.utils.data import Dataset


def _load_check_duplicates(config: dict, name: str) -> dict[str, torch.nn.Module]:
"""Attempt to load a single file with normalizers/element references and check config for duplicate targets.

Args:
config: configuration dictionary
name: Name of module to use for logging

Returns:
dictionary of normalizer or element reference modules
"""
modules = {}
if "file" in config:
modules = torch.load(config["file"])
logging.info(f"Loaded {name} for the following targets: {list(modules.keys())}")
# make sure that element-refs are not specified both as fit and file
fit_targets = config["fit"]["targets"] if "fit" in config else []
duplicates = list(
filter(
lambda x: x in fit_targets,
list(config) + list(modules.keys()),
)
)
if len(duplicates) > 0:
logging.warning(
f"{name} values for the following targets {duplicates} have been specified to be fit and also read"
f" from a file. The files read from file will be used instead of fitting."
)
duplicates = list(filter(lambda x: x in modules, config))
if len(duplicates) > 0:
logging.warning(
f"Duplicate {name} values for the following targets {duplicates} where specified in the file "
f"{config['file']} and an explicitly set file. The normalization values read from "
f"{config['file']} will be used."
)
return modules


def _load_from_config(
config: dict,
name: str,
fit_fun: Callable[[list[str], Dataset, Any, ...], dict[str, Module]],
create_fun: Callable[[str | Path], Module],
dataset: Dataset,
checkpoint_dir: str | Path | None = None,
**fit_kwargs,
) -> dict[str, torch.nn.Module]:
"""Load or fit normalizers or element references from config

If a fit is done, a fitted key with value true is added to the config to avoid re-fitting
once a checkpoint has been saved.

Args:
config: configuration dictionary
name: Name of module to use for logging
fit_fun: Function to fit modules
create_fun: Function to create a module from file
checkpoint_dir: directory to save modules. If not given, modules won't be saved.

Returns:
dictionary of normalizer or element reference modules

"""
modules = _load_check_duplicates(config, name)
for target in config:
if target == "fit" and not config["fit"].get("fitted", False):
# remove values for output targets that have already been read from files
targets = [
target for target in config["fit"]["targets"] if target not in modules
]
fit_kwargs.update(
{k: v for k, v in config["fit"].items() if k != "targets"}
)
modules.update(fit_fun(targets=targets, dataset=dataset, **fit_kwargs))
config["fit"]["fitted"] = True
# if a single file for all outputs is not provided,
# then check if a single file is provided for a specific output
elif target != "file":
modules[target] = create_fun(**config[target])
# save the linear references for possible subsequent use
if checkpoint_dir is not None:
path = save_checkpoint(
modules,
checkpoint_dir,
f"{name}.pt",
)
logging.info(
f"{name} checkpoint for targets {list(modules.keys())} have been saved to: {path}"
)

return modules
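For orientation, here is a hypothetical config fragment shaped the way _load_from_config iterates over it; the target names, file names, and the batch_size fit option are invented for illustration:

# Sketch of a normalizer/element-reference config, mirroring the three branches above:
normalizer_config = {
    # a single file holding saved modules for several targets ("file" branch)
    "file": "normalizers.pt",
    # a per-target entry; its keys are forwarded to create_fun(**config[target])
    "forces": {"file": "forces_norm.pt"},
    # targets to fit on the fly; "fitted" is set to True after the first fit
    "fit": {"targets": ["energy"], "batch_size": 64},
}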