Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat: numpy pairtab model #3212

Merged
merged 58 commits into from
Feb 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
58 commits
Select commit Hold shift + click to select a range
47cff4b
feat: add pair table model to pytorch
Jan 28, 2024
04b6f57
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 28, 2024
eb59d87
fix: typo
Jan 28, 2024
b7cbbd5
fix: typo
Jan 28, 2024
a1a76bb
Merge branch 'devel' into devel
anyangml Jan 28, 2024
84767f3
fix: update ruct extrapolation
Jan 28, 2024
8fee8fa
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 28, 2024
ff08515
fix: update allclose precision
Jan 28, 2024
f4b3720
Merge branch 'devel' into devel
anyangml Jan 29, 2024
451916e
Merge branch 'devel' into devel
anyangml Jan 29, 2024
0968eaa
Merge branch 'devel' into devel
anyangml Jan 29, 2024
6b0559e
Merge branch 'devel' into devel
anyangml Jan 29, 2024
8cbb98c
chore: refactor common method to PairTab
Jan 29, 2024
a08092c
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 29, 2024
d3090b9
fix: update unit tests
Jan 29, 2024
daf2fc6
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 29, 2024
399c278
fix: revert padding zero mask change
Jan 29, 2024
59abe43
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 29, 2024
8f1cdc8
Merge branch 'devel' into devel
anyangml Jan 30, 2024
88936cc
Merge branch 'devel' into devel
anyangml Jan 30, 2024
1c4ee0d
feat: redo extrapolation with cubic spline for smoothness
Jan 30, 2024
5793828
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 30, 2024
27f3559
Merge branch 'devel' into devel
anyangml Jan 30, 2024
92dec18
chore: refactor _make_data in PairTab
Jan 30, 2024
bc04359
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 30, 2024
4433035
chore: move file
Jan 30, 2024
2ba0318
Merge branch 'devel' into devel
anyangml Jan 30, 2024
f2c40e6
Merge branch 'devel' into devel
anyangml Jan 31, 2024
4851a0a
chore: refactor extrapolation code
Jan 31, 2024
ddbe7db
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 31, 2024
29d95db
Merge branch 'devel' into devel
anyangml Jan 31, 2024
365c20d
feat: add zbl weighted model
Feb 1, 2024
fb4ae7d
Merge branch 'deepmodeling:devel' into devel
anyangml Feb 1, 2024
e423e68
feat: add serialize and deserialize to pt pairtab
Feb 1, 2024
39da7ff
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 1, 2024
74002ed
chore: remove irrelevant files
Feb 1, 2024
0ce23f4
feat: add numpy version
Feb 1, 2024
5190903
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 1, 2024
f5541d1
fix: redo pairtab pt serialization
Feb 2, 2024
de3296a
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 2, 2024
43fcea7
feat: test pairtabmodel numpy version
Feb 2, 2024
9716144
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 2, 2024
b09f43b
fix: precommit
Feb 2, 2024
05d2750
fix: precommit
Feb 2, 2024
263b6dc
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 2, 2024
35232c2
chore: refactor code
Feb 2, 2024
c72bdb9
fix: at @variables key to serialize
Feb 2, 2024
f1e5d72
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 2, 2024
6950e98
fix: import
Feb 2, 2024
5083865
Merge remote-tracking branch 'upstream/devel' into feat/numpy_pairtab…
Feb 2, 2024
97ccfcc
Merge branch 'devel' into feat/numpy_pairtab_model
anyangml Feb 2, 2024
39da832
fix: rename method
Feb 2, 2024
1a59917
fix: import
anyangml Feb 2, 2024
a59a930
Merge branch 'devel' into feat/numpy_pairtab_model
anyangml Feb 3, 2024
97c6a88
Merge branch 'devel' into feat/numpy_pairtab_model
anyangml Feb 4, 2024
a6c637a
fix: change input output shape and move UTs
anyangml Feb 4, 2024
3b0faa4
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 4, 2024
ec04b5d
fix: array type
anyangml Feb 4, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
296 changes: 296 additions & 0 deletions deepmd/dpmodel/model/pair_tab_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,296 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from typing import (
Dict,
List,
Optional,
Union,
)

import numpy as np

from deepmd.dpmodel.output_def import (
FittingOutputDef,
OutputVariableDef,
)
from deepmd.utils.pair_tab import (
PairTab,
)

from .base_atomic_model import (
BaseAtomicModel,
)


class PairTabModel(BaseAtomicModel):
"""Pairwise tabulation energy model.

This model can be used to tabulate the pairwise energy between atoms for either
short-range or long-range interactions, such as D3, LJ, ZBL, etc. It should not
be used alone, but rather as one submodel of a linear (sum) model, such as
DP+D3.

Do not put the model on the first model of a linear model, since the linear
model fetches the type map from the first model.

At this moment, the model does not smooth the energy at the cutoff radius, so
one needs to make sure the energy has been smoothed to zero.

Parameters
----------
tab_file : str
The path to the tabulation file.
rcut : float
The cutoff radius.
sel : int or list[int]
The maxmum number of atoms in the cut-off radius.
"""

def __init__(
self, tab_file: str, rcut: float, sel: Union[int, List[int]], **kwargs
):
super().__init__()
self.tab_file = tab_file
self.rcut = rcut

self.tab = PairTab(self.tab_file, rcut=rcut)

if self.tab_file is not None:
self.tab_info, self.tab_data = self.tab.get()
else:
self.tab_info, self.tab_data = None, None

if isinstance(sel, int):
self.sel = sel
elif isinstance(sel, list):
self.sel = sum(sel)

Check warning on line 65 in deepmd/dpmodel/model/pair_tab_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/model/pair_tab_model.py#L64-L65

Added lines #L64 - L65 were not covered by tests
else:
raise TypeError("sel must be int or list[int]")

Check warning on line 67 in deepmd/dpmodel/model/pair_tab_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/model/pair_tab_model.py#L67

Added line #L67 was not covered by tests

def fitting_output_def(self) -> FittingOutputDef:
return FittingOutputDef(

Check warning on line 70 in deepmd/dpmodel/model/pair_tab_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/model/pair_tab_model.py#L70

Added line #L70 was not covered by tests
[
OutputVariableDef(
name="energy", shape=[1], reduciable=True, differentiable=True
)
]
)

def get_rcut(self) -> float:
return self.rcut

Check warning on line 79 in deepmd/dpmodel/model/pair_tab_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/model/pair_tab_model.py#L79

Added line #L79 was not covered by tests

def get_sel(self) -> int:
return self.sel

Check warning on line 82 in deepmd/dpmodel/model/pair_tab_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/model/pair_tab_model.py#L82

Added line #L82 was not covered by tests

def distinguish_types(self) -> bool:
# to match DPA1 and DPA2.
return False

Check warning on line 86 in deepmd/dpmodel/model/pair_tab_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/model/pair_tab_model.py#L86

Added line #L86 was not covered by tests

def serialize(self) -> dict:
return {"tab": self.tab.serialize(), "rcut": self.rcut, "sel": self.sel}

@classmethod
def deserialize(cls, data) -> "PairTabModel":
rcut = data["rcut"]
sel = data["sel"]
tab = PairTab.deserialize(data["tab"])
tab_model = cls(None, rcut, sel)
tab_model.tab = tab
tab_model.tab_info = tab_model.tab.tab_info
tab_model.tab_data = tab_model.tab.tab_data
return tab_model

def forward_atomic(
self,
extended_coord,
extended_atype,
nlist,
mapping: Optional[np.ndarray] = None,
do_atomic_virial: bool = False,
) -> Dict[str, np.ndarray]:
self.nframes, self.nloc, self.nnei = nlist.shape
extended_coord = extended_coord.reshape(self.nframes, -1, 3)

# this will mask all -1 in the nlist
masked_nlist = np.clip(nlist, 0, None)

atype = extended_atype[:, : self.nloc] # (nframes, nloc)
pairwise_dr = self._get_pairwise_dist(
extended_coord
) # (nframes, nall, nall, 3)
pairwise_rr = np.sqrt(
np.sum(np.power(pairwise_dr, 2), axis=-1)
) # (nframes, nall, nall)
self.tab_data = self.tab_data.reshape(
self.tab.ntypes, self.tab.ntypes, self.tab.nspline, 4
)

# (nframes, nloc, nnei)
j_type = extended_atype[
np.arange(extended_atype.shape[0])[:, None, None], masked_nlist
]

# slice rr to get (nframes, nloc, nnei)
rr = np.take_along_axis(pairwise_rr[:, : self.nloc, :], masked_nlist, 2)
raw_atomic_energy = self._pair_tabulated_inter(nlist, atype, j_type, rr)
atomic_energy = 0.5 * np.sum(
np.where(nlist != -1, raw_atomic_energy, np.zeros_like(raw_atomic_energy)),
axis=-1,
).reshape(self.nframes, self.nloc, 1)

return {"energy": atomic_energy}

def _pair_tabulated_inter(
self,
nlist: np.ndarray,
i_type: np.ndarray,
j_type: np.ndarray,
rr: np.ndarray,
) -> np.ndarray:
"""Pairwise tabulated energy.

Parameters
----------
nlist : np.ndarray
The unmasked neighbour list. (nframes, nloc)
i_type : np.ndarray
The integer representation of atom type for all local atoms for all frames. (nframes, nloc)
j_type : np.ndarray
The integer representation of atom type for all neighbour atoms of all local atoms for all frames. (nframes, nloc, nnei)
rr : np.ndarray
The salar distance vector between two atoms. (nframes, nloc, nnei)

Returns
-------
np.ndarray
The masked atomic energy for all local atoms for all frames. (nframes, nloc, nnei)

Raises
------
Exception
If the distance is beyond the table.

Notes
-----
This function is used to calculate the pairwise energy between two atoms.
It uses a table containing cubic spline coefficients calculated in PairTab.
"""
rmin = self.tab_info[0]
hh = self.tab_info[1]
hi = 1.0 / hh

self.nspline = int(self.tab_info[2] + 0.1)

uu = (rr - rmin) * hi # this is broadcasted to (nframes,nloc,nnei)

# if nnei of atom 0 has -1 in the nlist, uu would be 0.
# this is to handle the nlist where the mask is set to 0, so that we don't raise exception for those atoms.
uu = np.where(nlist != -1, uu, self.nspline + 1)

if np.any(uu < 0):
raise Exception("coord go beyond table lower boundary")

Check warning on line 190 in deepmd/dpmodel/model/pair_tab_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/model/pair_tab_model.py#L190

Added line #L190 was not covered by tests

idx = uu.astype(int)

uu -= idx
table_coef = self._extract_spline_coefficient(
i_type, j_type, idx, self.tab_data, self.nspline
)
table_coef = table_coef.reshape(self.nframes, self.nloc, self.nnei, 4)
ener = self._calcualte_ener(table_coef, uu)
# here we need to overwrite energy to zero at rcut and beyond.
mask_beyond_rcut = rr >= self.rcut
# also overwrite values beyond extrapolation to zero
extrapolation_mask = rr >= self.tab.rmin + self.nspline * self.tab.hh
ener[mask_beyond_rcut] = 0
ener[extrapolation_mask] = 0

return ener

@staticmethod
def _get_pairwise_dist(coords: np.ndarray) -> np.ndarray:
"""Get pairwise distance `dr`.

Parameters
----------
coords : np.ndarray
The coordinate of the atoms shape of (nframes, nall, 3).

Returns
-------
np.ndarray
The pairwise distance between the atoms (nframes, nall, nall, 3).
"""
return np.expand_dims(coords, 2) - np.expand_dims(coords, 1)

@staticmethod
def _extract_spline_coefficient(
i_type: np.ndarray,
j_type: np.ndarray,
idx: np.ndarray,
tab_data: np.ndarray,
nspline: int,
) -> np.ndarray:
"""Extract the spline coefficient from the table.

Parameters
----------
i_type : np.ndarray
The integer representation of atom type for all local atoms for all frames. (nframes, nloc)
j_type : np.ndarray
The integer representation of atom type for all neighbour atoms of all local atoms for all frames. (nframes, nloc, nnei)
idx : np.ndarray
The index of the spline coefficient. (nframes, nloc, nnei)
tab_data : np.ndarray
The table storing all the spline coefficient. (ntype, ntype, nspline, 4)
nspline : int
The number of splines in the table.

Returns
-------
np.ndarray
The spline coefficient. (nframes, nloc, nnei, 4), shape may be squeezed.
"""
# (nframes, nloc, nnei)
expanded_i_type = np.broadcast_to(
i_type[:, :, np.newaxis],
(i_type.shape[0], i_type.shape[1], j_type.shape[-1]),
)

# (nframes, nloc, nnei, nspline, 4)
expanded_tab_data = tab_data[expanded_i_type, j_type]

# (nframes, nloc, nnei, 1, 4)
expanded_idx = np.broadcast_to(
idx[..., np.newaxis, np.newaxis], (*idx.shape, 1, 4)
)
clipped_indices = np.clip(expanded_idx, 0, nspline - 1).astype(int)

# (nframes, nloc, nnei, 4)
final_coef = np.squeeze(
np.take_along_axis(expanded_tab_data, clipped_indices, 3)
)

# when the spline idx is beyond the table, all spline coefficients are set to `0`, and the resulting ener corresponding to the idx is also `0`.
final_coef[expanded_idx.squeeze() > nspline] = 0
return final_coef

@staticmethod
def _calcualte_ener(coef: np.ndarray, uu: np.ndarray) -> np.ndarray:
"""Calculate energy using spline coeeficients.

Parameters
----------
coef : np.ndarray
The spline coefficients. (nframes, nloc, nnei, 4)
uu : np.ndarray
The atom displancemnt used in interpolation and extrapolation (nframes, nloc, nnei)

Returns
-------
np.ndarray
The atomic energy for all local atoms for all frames. (nframes, nloc, nnei)
"""
a3, a2, a1, a0 = coef[..., 0], coef[..., 1], coef[..., 2], coef[..., 3]
etmp = (a3 * uu + a2) * uu + a1 # this should be elementwise operations.
ener = etmp * uu + a0 # this energy has the extrapolated value when rcut > rmax
return ener
Original file line number Diff line number Diff line change
Expand Up @@ -54,13 +54,19 @@ def __init__(
super().__init__()
self.tab_file = tab_file
self.rcut = rcut

self.tab = PairTab(self.tab_file, rcut=rcut)
self.ntypes = self.tab.ntypes

tab_info, tab_data = self.tab.get() # this returns -> Tuple[np.array, np.array]
self.tab_info = torch.from_numpy(tab_info)
self.tab_data = torch.from_numpy(tab_data)
# handle deserialization with no input file
if self.tab_file is not None:
(
tab_info,
tab_data,
) = self.tab.get() # this returns -> Tuple[np.array, np.array]
self.tab_info = torch.from_numpy(tab_info)
self.tab_data = torch.from_numpy(tab_data)
else:
self.tab_info = None
self.tab_data = None

# self.model_type = "ener"
# self.model_version = MODEL_VERSION ## this shoud be in the parent class
Expand Down Expand Up @@ -92,12 +98,18 @@ def distinguish_types(self) -> bool:
return False

def serialize(self) -> dict:
# place holder, implemantated in future PR
raise NotImplementedError

def deserialize(cls):
# place holder, implemantated in future PR
raise NotImplementedError
return {"tab": self.tab.serialize(), "rcut": self.rcut, "sel": self.sel}

@classmethod
def deserialize(cls, data) -> "PairTabModel":
rcut = data["rcut"]
sel = data["sel"]
tab = PairTab.deserialize(data["tab"])
tab_model = cls(None, rcut, sel)
tab_model.tab = tab
tab_model.tab_info = torch.from_numpy(tab_model.tab.tab_info)
tab_model.tab_data = torch.from_numpy(tab_model.tab.tab_data)
return tab_model

def forward_atomic(
self,
Expand All @@ -108,6 +120,7 @@ def forward_atomic(
do_atomic_virial: bool = False,
) -> Dict[str, torch.Tensor]:
self.nframes, self.nloc, self.nnei = nlist.shape
extended_coord = extended_coord.view(self.nframes, -1, 3)

# this will mask all -1 in the nlist
masked_nlist = torch.clamp(nlist, 0)
Expand All @@ -118,7 +131,7 @@ def forward_atomic(
) # (nframes, nall, nall, 3)
pairwise_rr = pairwise_dr.pow(2).sum(-1).sqrt() # (nframes, nall, nall)

self.tab_data = self.tab_data.reshape(
self.tab_data = self.tab_data.view(
self.tab.ntypes, self.tab.ntypes, self.tab.nspline, 4
)

Expand All @@ -139,7 +152,7 @@ def forward_atomic(
nlist != -1, raw_atomic_energy, torch.zeros_like(raw_atomic_energy)
),
dim=-1,
)
).unsqueeze(-1)

return {"energy": atomic_energy}

Expand Down Expand Up @@ -200,7 +213,7 @@ def _pair_tabulated_inter(
table_coef = self._extract_spline_coefficient(
i_type, j_type, idx, self.tab_data, self.nspline
)
table_coef = table_coef.reshape(self.nframes, self.nloc, self.nnei, 4)
table_coef = table_coef.view(self.nframes, self.nloc, self.nnei, 4)
ener = self._calcualte_ener(table_coef, uu)

# here we need to overwrite energy to zero at rcut and beyond.
Expand All @@ -219,12 +232,12 @@ def _get_pairwise_dist(coords: torch.Tensor) -> torch.Tensor:
Parameters
----------
coords : torch.Tensor
The coordinate of the atoms shape of (nframes * nall * 3).
The coordinate of the atoms shape of (nframes, nall, 3).

Returns
-------
torch.Tensor
The pairwise distance between the atoms (nframes * nall * nall * 3).
The pairwise distance between the atoms (nframes, nall, nall, 3).

Examples
--------
Expand Down
Loading
Loading