Precompute cdvae #84

Open
wants to merge 13 commits into base: main
3 changes: 2 additions & 1 deletion CONTRIBUTORS.md
@@ -10,5 +10,6 @@ Here we list, in alphabetical order, people who have contributed (not just code)
- Krzysztof Sadowski (@ksadowski13)
- Jonathan Schmidt (@JonathanSchmidt1)
- Matthew Spellings (@klarh)

- Bin Mu (@bmuaz)

If you have submitted a pull request with code contributions, please add your name above as part of your PR!
316 changes: 316 additions & 0 deletions examples/model_demos/cdvae/cdvae_precompute.py
@@ -0,0 +1,316 @@
from argparse import ArgumentParser, Namespace
from pathlib import Path
import os
import pickle
import sys
import time
import warnings

import lmdb
import numpy
import torch
from torch_geometric.data import Data
from tqdm import tqdm

from pymatgen.analysis.graphs import StructureGraph
from pymatgen.core.lattice import Lattice
from pymatgen.core.structure import Structure
from matsciml.datasets.utils import atomic_number_map
from pymatgen.analysis import local_env

dir_path = os.path.dirname(os.path.realpath(__file__))
sys.path.append("{}/../".format(dir_path))

from matsciml.datasets.utils import connect_db_read, write_lmdb_data

MAX_ATOMS = 25

CrystalNN = local_env.CrystalNN(
    distance_cutoffs=None, x_diff_weight=-1, porous_adjustment=False
)

Atomic_num_map_global = atomic_number_map()

def get_atomic_num(elements):
    return [Atomic_num_map_global[element] for element in elements]

def get_atoms_from_atomic_numbers(atomic_numbers):
    map_reversed = {
        atomic_num: element for element, atomic_num in Atomic_num_map_global.items()
    }
    return [map_reversed[atomic_num] for atomic_num in atomic_numbers]

def get_distance_matrix(coords, lattice_vectors):
    # Pairwise differences between all sites, shape (N, N, 3).
    # `coords` are fractional coordinates; the dot product with the lattice
    # matrix maps the differences to Cartesian displacements.
    delta = coords[:, numpy.newaxis, :] - coords[numpy.newaxis, :, :]
    # Vectorized norms; note that the minimum-image convention is not applied,
    # so distances across periodic boundaries are not folded back into the cell.
    distance_matrix = numpy.linalg.norm(numpy.dot(delta, lattice_vectors), axis=2)
    return distance_matrix

def get_jimage(structure):
    try:
        crystal_graph = StructureGraph.with_local_env_strategy(structure, CrystalNN)
    except ValueError:
        return None
    edge_indices, to_jimages = [], []
    for i, j, to_jimage in crystal_graph.graph.edges(data="to_jimage"):
        edge_indices.append([j, i])
        to_jimages.append(to_jimage)
        edge_indices.append([i, j])
        to_jimages.append(tuple(-tj for tj in to_jimage))
    return to_jimages, edge_indices

def get_lattice(lattice):
    lattice_params = torch.FloatTensor(lattice.abc + tuple(lattice.angles))
    lattice_features = {
        "lattice_params": lattice_params,
    }
    return lattice_features

def processing_data(structure, return_dict, y):
    jimage_result = get_jimage(structure)
    if jimage_result is None:
        # CrystalNN failed to build a graph for this structure; skip the entry.
        return None
    to_jimages, edge_indices = jimage_result
    return_dict["to_jimages"] = torch.LongTensor(to_jimages)
    return_dict["edge_index"] = torch.LongTensor(edge_indices).T
    edge_index = return_dict["edge_index"]  # shape (2, num_edges)
    return_dict["lattice_features"] = get_lattice(structure.lattice)
    lattice_params = return_dict["lattice_features"]["lattice_params"]
    prop = torch.Tensor([y])

    # atom_coords are fractional coordinates
    # edge_index is incremented during batching
    # https://pytorch-geometric.readthedocs.io/en/latest/notes/batching.html
    data = Data(
        frac_coords=torch.Tensor(return_dict["frac_coords"]),
        atom_types=torch.LongTensor(return_dict["atomic_numbers"]),
        lengths=lattice_params[:3].view(1, -1),
        angles=lattice_params[3:].view(1, -1),
        edge_index=edge_index,  # shape (2, num_edges)
        to_jimages=return_dict["to_jimages"],
        num_atoms=len(return_dict["atomic_numbers"]),
        num_bonds=edge_index.shape[1],
        # num_nodes is a special attribute used for batching in PyTorch Geometric
        num_nodes=len(return_dict["atomic_numbers"]),
        y=prop.view(1, -1),
    )
    return data

def parse_structure_MP(item) -> Data:
    """
    The same as the original MP parser, with the addition of the jimages field.
    """
    return_dict = {}
    structure = item.get("structure", None)
    if structure is None:
        raise ValueError(
            "Structure not found in data - workflow needs a structure to use!"
        )
Collaborator:
Many blocks in the parse_structure_XXX functions are repeated, like this check and some dictionary value assignments - those can also be simplified.

Author:
Yes, removed the redundant sections. Done.

    coords = torch.from_numpy(structure.cart_coords).float()
    return_dict["pos"] = coords[None, :] - coords[:, None]
    return_dict["coords"] = coords
    return_dict["frac_coords"] = structure.frac_coords
    atom_numbers = torch.LongTensor(structure.atomic_numbers)
    # keep atomic numbers for graph featurization
    return_dict["atomic_numbers"] = atom_numbers
    return_dict["num_particles"] = len(atom_numbers)
    return_dict["distance_matrix"] = torch.from_numpy(structure.distance_matrix).float()
    # fall back to a placeholder of 1 when the label is missing
    # (note: this also triggers on a literal 0.0 value)
    y = item.get("formation_energy_per_atom") or 1
    data = processing_data(structure, return_dict, y)
    return data

def parse_structure_NOMAD(item) -> Data:
    """
    The same as the original NOMAD parser, with the addition of the jimages field.
    """
    return_dict = {}
    structure = item["properties"]["structures"].get("structure_conventional", None)
    if structure is None:
        raise ValueError(
            "Structure not found in data - workflow needs a structure to use!"
        )
    # NOMAD stores positions in meters and energies in Joules;
    # convert to angstroms and eV, respectively.
    cartesian_coords = numpy.array(structure["cartesian_site_positions"]) * 1e10
    lattice_vectors = Lattice(numpy.array(structure["lattice_vectors"]) * 1e10)
    species = structure["species_at_sites"]
    frac_coords = lattice_vectors.get_fractional_coords(cart_coords=cartesian_coords)
    coords = torch.from_numpy(cartesian_coords)
    return_dict["pos"] = coords[None, :] - coords[:, None]
    return_dict["coords"] = coords
    return_dict["frac_coords"] = frac_coords
    num_particles = len(species)
    atom_numbers = get_atomic_num(species)
    # keep atomic numbers for graph featurization
    return_dict["atomic_numbers"] = torch.LongTensor(atom_numbers)
    return_dict["num_particles"] = num_particles
    distance_matrix = get_distance_matrix(
        frac_coords, numpy.array(structure["lattice_vectors"]) * 1e10
    )
    return_dict["distance_matrix"] = torch.from_numpy(distance_matrix)
    # energy per atom in eV (total energy converted from Joules)
    y = (item["energies"]["total"]["value"] * 6.241509074461e18) / num_particles
Collaborator:
What is the reason for multiplying by such large numbers? When training a subsequent energy regression model, large numbers could lead to numerical instability.

Collaborator:
@migalkin this is a necessary conversion: from Joules to eV and from meters to angstroms.

Collaborator:
Can @bmuaz print some indicative values stored in the data before and after the conversion?

Author:
Sure, as I mentioned yesterday, NOMAD uses meters and Joules for its data. You can check their original lattice and total-energy entries, for example:

"lattice_vectors": [[2.91807e-10, 0, 0], [0, 2.91807e-10, 0], [0, 0, 2.91807e-10]],
"cartesian_site_positions": [[0, 0, 0], [1.45904e-10, 1.45904e-10, 1.45904e-10]],
"total": {"value": -9.17469e-15}
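To make the magnitudes concrete, here is a quick sanity check of those constants applied to the sample entry above (a standalone sketch, not part of the PR):

METER_TO_ANGSTROM = 1e10
JOULE_TO_EV = 6.241509074461e18  # 1 J is about 6.2415e18 eV

lattice_a = 2.91807e-10 * METER_TO_ANGSTROM    # 2.91807 angstroms
total_energy_ev = -9.17469e-15 * JOULE_TO_EV   # about -5.7264e4 eV
energy_per_atom = total_energy_ev / 2          # two sites: about -2.8632e4 eV/atom
print(lattice_a, total_energy_ev, energy_per_atom)

The lattice constant lands in a familiar range after conversion, but the per-atom total energy stays on the order of 1e4 eV, which is exactly the scale mismatch discussed next.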

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see, do we need to keep the exponent then? Let's check the absolute energy values in four datasets - if we are to train a single model on MP and NOMAD, and MP data has, for example, energy values as floats in range [-10, 10] and NOMAD in range [-1e5, 1e5], then the standard regressor scaler will get confused and treat MP values as almost-zeros

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also important to remember that we have a normalize_kwargs to help with these scaling issues. Here is an example. I often forget about this but it helps tremendously in stabilizing training.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The normalize_kwargs are per-task, so you can re-scale based on each task (and by extension, dataset).

I think the design question is, do we apply the unit conversion in the data_from_keys step, or do we save them (like is done here) converted? Personally I feel like I prefer the former, and just document what is being done as opposed to having no metadata associated with the precomputed sets.
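As a rough illustration of the kind of per-task statistics normalize_kwargs would consume (a sketch; it assumes keys follow a "<task_key>_mean" / "<task_key>_std" pattern, which should be checked against the matsciml task API, and the energy values here are placeholders):

import torch

# Hypothetical energy targets per dataset, in eV/atom; in practice, gather
# these from the training split of each converted LMDB.
mp_energies = torch.tensor([-3.2, -1.5, 0.4, -6.8])
nomad_energies = torch.tensor([-2.9e4, -2.7e4, -3.1e4])

def normalization_stats(energies: torch.Tensor) -> dict:
    # Statistics of this shape could then be passed to a task as, e.g.,
    # normalize_kwargs={"energy_mean": ..., "energy_std": ...}
    return {
        "energy_mean": energies.mean().item(),
        "energy_std": energies.std().item(),
    }

print(normalization_stats(mp_energies))
print(normalization_stats(nomad_energies))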

    structure = Structure(lattice_vectors, species, frac_coords)
    data = processing_data(structure, return_dict, y)
    return data

def parse_structure_OQMD(item) -> Data:
    """
    The same as the original OQMD parser, with the addition of the jimages field.
    """
    return_dict = {}
    structure = item.get("cart_coords", None)
    if structure is None:
        raise ValueError(
            "Structure not found in data - workflow needs a structure to use!"
        )
    cartesian_coords = numpy.array(structure)
    lattice_vectors = Lattice(numpy.array(item["unit_cell"]))
    species = [site.split(" ")[0] for site in item["sites"]]
    frac_coords = lattice_vectors.get_fractional_coords(cart_coords=cartesian_coords)
    coords = torch.from_numpy(cartesian_coords)
    return_dict["pos"] = coords[None, :] - coords[:, None]
    return_dict["coords"] = coords
    return_dict["frac_coords"] = frac_coords
    # keep atomic numbers for graph featurization
    return_dict["atomic_numbers"] = torch.LongTensor(item["atomic_numbers"])
    return_dict["num_particles"] = item["natoms"]
    distance_matrix = get_distance_matrix(frac_coords, numpy.array(item["unit_cell"]))
    return_dict["distance_matrix"] = torch.from_numpy(distance_matrix).float()
    y = item["delta_e"]
    structure = Structure(lattice_vectors, species, frac_coords)
    data = processing_data(structure, return_dict, y)
    return data

def parse_structure_Carolina(item) -> Data:
    """
    The same as the original Carolina parser, with the addition of the jimages field.
    """
    return_dict = {}
    structure = item.get("cart_coords", None)
    if structure is None:
        raise ValueError(
            "Structure not found in data - workflow needs a structure to use!"
        )
    cartesian_coords = numpy.array(structure)
    a, b, c, alpha, beta, gamma = [
        float(item["_cell_length_a"]),
        float(item["_cell_length_b"]),
        float(item["_cell_length_c"]),
        float(item["_cell_angle_alpha"]),
        float(item["_cell_angle_beta"]),
        float(item["_cell_angle_gamma"]),
    ]
    lattice_vectors = Lattice.from_parameters(a, b, c, alpha, beta, gamma)
    species = get_atoms_from_atomic_numbers(item["atomic_numbers"])
    frac_coords = lattice_vectors.get_fractional_coords(cart_coords=cartesian_coords)
    coords = torch.from_numpy(cartesian_coords)
    return_dict["pos"] = coords[None, :] - coords[:, None]
    return_dict["coords"] = coords
    return_dict["frac_coords"] = frac_coords
    atom_numbers = torch.LongTensor(item["atomic_numbers"])
    # keep atomic numbers for graph featurization
    return_dict["atomic_numbers"] = atom_numbers
    return_dict["num_particles"] = len(item["atomic_numbers"])
    distance_matrix = get_distance_matrix(frac_coords, lattice_vectors.matrix)
    return_dict["distance_matrix"] = torch.from_numpy(distance_matrix)
    y = item["energy"]
    structure = Structure(lattice_vectors, species, frac_coords)
    data = processing_data(structure, return_dict, y)
    return data


def check_num_atoms(num_atoms):
    # Returns True only for entries the workflow can use: a known atom
    # count no larger than MAX_ATOMS.
    if num_atoms is None:
        warnings.warn(
            "One entry is skipped due to missing the number of atoms, "
            "which is needed by the workflow!"
        )
        return False
    return num_atoms <= MAX_ATOMS

def data_to_cdvae_MP(item):
    num_atoms = len(item["structure"].atomic_numbers)
    if check_num_atoms(num_atoms):
        return parse_structure_MP(item)
    return None

def data_to_cdvae_NOMAD(item):
    num_atoms = len(
        item["properties"]["structures"]["structure_conventional"]["species_at_sites"]
    )
    if check_num_atoms(num_atoms):
        return parse_structure_NOMAD(item)
    return None

def data_to_cdvae_OQMD(item):
    num_atoms = item["natoms"]
    if check_num_atoms(num_atoms):
        return parse_structure_OQMD(item)
    return None

def data_to_cdvae_Carolina(item):
    num_atoms = len(item["atomic_numbers"])
    if check_num_atoms(num_atoms):
        return parse_structure_Carolina(item)
    return None
Collaborator:
There is some significant code repetition in those four functions, which
(1) accept the same argument (item),
(2) return the same object (pyg_data), and
(3) execute the same sequence of actions (check_num_atoms() and the if block),
making them perfect candidates for shrinking into one function, with the main parse_structure_ function obtained from the function argument.

Author:
Yes, condensed them to one function, data_to_cdvae().
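For reference, a minimal sketch of what such a condensed dispatcher could look like (hypothetical; the actual commit may differ in naming and details):

PARSERS = {
    "MP": parse_structure_MP,
    "NOMAD": parse_structure_NOMAD,
    "OQMD": parse_structure_OQMD,
    "Carolina": parse_structure_Carolina,
}

def data_to_cdvae(item, dataset_name, num_atoms):
    # Shared guard, then dispatch to the dataset-specific parser.
    if not check_num_atoms(num_atoms):
        return None
    return PARSERS[dataset_name](item)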


#############################################################################
def main(args: Namespace):
    start_time = time.time()
    print("Start")
    input_path = Path(args.src_lmdb)
    output_path = Path(args.output_folder)

    if not input_path.exists():
        raise FileNotFoundError(f"{input_path} could not be found.")
    if args.dataset is None:
        raise ValueError("Dataset name was not provided.")
    if output_path.exists():
        raise ValueError(
            f"{output_path} already exists, please check its contents and remove the folder!"
        )

    os.makedirs(output_path, exist_ok=True)
    db_paths = sorted(input_path.glob("*.lmdb"))
    dataset_functions = {
        "MP": data_to_cdvae_MP,
        "NOMAD": data_to_cdvae_NOMAD,
        "OQMD": data_to_cdvae_OQMD,
        "Carolina": data_to_cdvae_Carolina,
    }
    convert_fn = dataset_functions[args.dataset]

    # loop over individual LMDB files within the input set, if there is more than one
    for path in db_paths:
        target = output_path / path.name
        pyg_env = connect_db_read(path)
        # open the output file for writing
        target_env = lmdb.open(
            str(target),
            subdir=False,
            # The map_size setup is different for Windows vs. Linux:
            # https://github.com/tensorpack/tensorpack/issues/1209
            # https://stackoverflow.com/questions/33508305/lmdb-maximum-size-of-the-database-for-windows
            map_size=1099511627776 * 2,  # 1099511627776 bytes = 1 TiB
            meminit=False,
            map_async=True,
        )
        with pyg_env.begin() as txn:
            print(f"There are {txn.stat()['entries']} entries.")
            keys = [key for key in txn.cursor().iternext(values=False)]
            for key in tqdm(keys):
                # numeric keys hold structure entries; all other keys are metadata
                if key.decode("utf-8").isdigit():
                    crystal_data = pickle.loads(txn.get(key))
                    pyg_data = convert_fn(crystal_data)
                    if pyg_data is not None:
                        write_lmdb_data(key.decode("utf-8"), pyg_data, target_env)
                else:
                    metadata = pickle.loads(txn.get(key))
                    write_lmdb_data(key.decode("utf-8"), metadata, target_env)

    end_time = time.time()
    minutes, seconds = divmod(end_time - start_time, 60)
    print(f"Done! Program executed in {int(minutes)} minutes and {int(seconds)} seconds")


if __name__ == "__main__":
    parser = ArgumentParser()
    datasets = ["MP", "Carolina", "OQMD", "NOMAD"]
    parser.add_argument(
        "--src_lmdb",
        "-i",
        type=str,
        help="Folder containing the source LMDB files to be converted.",
    )
    parser.add_argument(
        "--dataset",
        "-d",
        type=str,
        choices=datasets,
        help="Select one of the datasets.",
    )
    parser.add_argument(
        "--output_folder",
        "-o",
        type=str,
        help="Path to a non-existing folder to save processed data to.",
    )
    args = parser.parse_args()
    print(args)
    main(args)
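A typical invocation of the script, with illustrative paths:

python cdvae_precompute.py -i /path/to/source_lmdb_folder -d MP -o /path/to/output_folder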
Binary file modified matsciml/datasets/oqmd/devset/data.lmdb
Binary file not shown.
Binary file modified matsciml/datasets/oqmd/devset/data.lmdb-lock
Binary file not shown.