Skip to content

Commit

Permalink
Feat: numpy pairtab model (#3212)
Browse files Browse the repository at this point in the history
This PR provides a backend-independent implementation of PairTabModel
in `numpy`. Cross-framework `serialization` and `deserialization` are
also added.

---------

Co-authored-by: Anyang Peng <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
3 people authored Feb 4, 2024
1 parent ab2c551 commit 7db1fde
Show file tree
Hide file tree
Showing 7 changed files with 630 additions and 23 deletions.
296 changes: 296 additions & 0 deletions deepmd/dpmodel/model/pair_tab_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,296 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from typing import (
Dict,
List,
Optional,
Union,
)

import numpy as np

from deepmd.dpmodel.output_def import (
FittingOutputDef,
OutputVariableDef,
)
from deepmd.utils.pair_tab import (
PairTab,
)

from .base_atomic_model import (
BaseAtomicModel,
)


class PairTabModel(BaseAtomicModel):
    """Pairwise tabulation energy model.

    This model can be used to tabulate the pairwise energy between atoms for either
    short-range or long-range interactions, such as D3, LJ, ZBL, etc. It should not
    be used alone, but rather as one submodel of a linear (sum) model, such as
    DP+D3.

    Do not put the model on the first model of a linear model, since the linear
    model fetches the type map from the first model.

    At this moment, the model does not smooth the energy at the cutoff radius, so
    one needs to make sure the energy has been smoothed to zero.

    Parameters
    ----------
    tab_file : str or None
        The path to the tabulation file. When ``None``, ``tab_info`` and
        ``tab_data`` are left as ``None`` and must be filled in afterwards
        (this is what `deserialize` does).
    rcut : float
        The cutoff radius.
    sel : int or list[int]
        The maximum number of atoms in the cut-off radius. A list is summed
        into a single integer.
    **kwargs
        Accepted and ignored.
    """

    def __init__(
        self,
        tab_file: Optional[str],
        rcut: float,
        sel: Union[int, List[int]],
        **kwargs,
    ):
        super().__init__()
        self.tab_file = tab_file
        self.rcut = rcut

        self.tab = PairTab(self.tab_file, rcut=rcut)

        if self.tab_file is not None:
            self.tab_info, self.tab_data = self.tab.get()
        else:
            # No table file: used by `deserialize`, which assigns the table
            # attributes itself once the instance exists.
            self.tab_info, self.tab_data = None, None

        if isinstance(sel, int):
            self.sel = sel
        elif isinstance(sel, list):
            self.sel = sum(sel)
        else:
            raise TypeError("sel must be int or list[int]")

    def fitting_output_def(self) -> FittingOutputDef:
        """Return the output definition: a single variable named ``energy`` of shape ``[1]``."""
        return FittingOutputDef(
            [
                OutputVariableDef(
                    name="energy", shape=[1], reduciable=True, differentiable=True
                )
            ]
        )

    def get_rcut(self) -> float:
        """Return the cutoff radius passed at construction."""
        return self.rcut

    def get_sel(self) -> int:
        """Return the (summed) maximum number of neighbours."""
        return self.sel

    def distinguish_types(self) -> bool:
        """Return ``False``; the model does not distinguish types."""
        # to match DPA1 and DPA2.
        return False

    def serialize(self) -> dict:
        """Return a dict with the serialized table (``tab``), ``rcut`` and ``sel``."""
        return {"tab": self.tab.serialize(), "rcut": self.rcut, "sel": self.sel}

    @classmethod
    def deserialize(cls, data) -> "PairTabModel":
        """Rebuild a model from the dict produced by `serialize`.

        Parameters
        ----------
        data : dict
            Must contain the keys ``rcut``, ``sel`` and ``tab``.

        Returns
        -------
        PairTabModel
            A model whose table is restored from ``data["tab"]`` instead of
            being read from a file.
        """
        rcut = data["rcut"]
        sel = data["sel"]
        tab = PairTab.deserialize(data["tab"])
        # Construct without a table file, then attach the deserialized table.
        tab_model = cls(None, rcut, sel)
        tab_model.tab = tab
        tab_model.tab_info = tab_model.tab.tab_info
        tab_model.tab_data = tab_model.tab.tab_data
        return tab_model

    def forward_atomic(
        self,
        extended_coord,
        extended_atype,
        nlist,
        mapping: Optional[np.ndarray] = None,
        do_atomic_virial: bool = False,
    ) -> Dict[str, np.ndarray]:
        """Compute the tabulated pairwise energy assigned to each local atom.

        Parameters
        ----------
        extended_coord : np.ndarray
            Coordinates of all (extended) atoms; reshaped to (nframes, nall, 3).
        extended_atype : np.ndarray
            Integer atom types of all (extended) atoms. (nframes, nall)
        nlist : np.ndarray
            Neighbour list holding indices into the extended atoms, with ``-1``
            marking empty slots. (nframes, nloc, nnei)
        mapping : np.ndarray, optional
            Not used by this model.
        do_atomic_virial : bool
            Not used by this model.

        Returns
        -------
        dict[str, np.ndarray]
            ``{"energy": array}`` where the array has shape (nframes, nloc, 1).
        """
        self.nframes, self.nloc, self.nnei = nlist.shape
        extended_coord = extended_coord.reshape(self.nframes, -1, 3)

        # Replace the -1 padding by index 0 so the array can be used for
        # gathering; padded slots are masked out of the energy again below.
        masked_nlist = np.clip(nlist, 0, None)

        atype = extended_atype[:, : self.nloc]  # (nframes, nloc)

        self.tab_data = self.tab_data.reshape(
            self.tab.ntypes, self.tab.ntypes, self.tab.nspline, 4
        )

        # (nframes, nloc, nnei)
        j_type = extended_atype[
            np.arange(extended_atype.shape[0])[:, None, None], masked_nlist
        ]

        # Gather only the neighbour coordinates that are actually needed rather
        # than materialising the full (nframes, nall, nall, 3) displacement
        # tensor and slicing it afterwards.
        frame_idx = np.arange(self.nframes)[:, None, None]
        coord_i = extended_coord[:, : self.nloc, np.newaxis, :]  # (nframes, nloc, 1, 3)
        coord_j = extended_coord[frame_idx, masked_nlist]  # (nframes, nloc, nnei, 3)
        # (nframes, nloc, nnei)
        rr = np.sqrt(np.sum(np.power(coord_i - coord_j, 2), axis=-1))

        raw_atomic_energy = self._pair_tabulated_inter(nlist, atype, j_type, rr)
        # Padded neighbours contribute nothing. The pair energy is halved
        # before being summed per local atom — presumably because every pair
        # appears in the neighbour lists of both of its atoms.
        atomic_energy = 0.5 * np.sum(
            np.where(nlist != -1, raw_atomic_energy, np.zeros_like(raw_atomic_energy)),
            axis=-1,
        ).reshape(self.nframes, self.nloc, 1)

        return {"energy": atomic_energy}

    def _pair_tabulated_inter(
        self,
        nlist: np.ndarray,
        i_type: np.ndarray,
        j_type: np.ndarray,
        rr: np.ndarray,
    ) -> np.ndarray:
        """Pairwise tabulated energy.

        Parameters
        ----------
        nlist : np.ndarray
            The unmasked neighbour list. (nframes, nloc, nnei)
        i_type : np.ndarray
            The integer representation of atom type for all local atoms for all frames. (nframes, nloc)
        j_type : np.ndarray
            The integer representation of atom type for all neighbour atoms of all local atoms for all frames. (nframes, nloc, nnei)
        rr : np.ndarray
            The scalar distance between two atoms. (nframes, nloc, nnei)

        Returns
        -------
        np.ndarray
            The masked atomic energy for all local atoms for all frames. (nframes, nloc, nnei)

        Raises
        ------
        Exception
            If the distance is beyond the table.

        Notes
        -----
        This function is used to calculate the pairwise energy between two atoms.
        It uses a table containing cubic spline coefficients calculated in PairTab.
        ``self.tab_data`` is expected to already have the 4-dimensional layout
        that `forward_atomic` gives it.
        """
        rmin = self.tab_info[0]
        hh = self.tab_info[1]
        hi = 1.0 / hh

        # The +0.1 presumably protects the int() truncation against a stored
        # value that sits slightly below the intended integer.
        self.nspline = int(self.tab_info[2] + 0.1)

        uu = (rr - rmin) * hi  # this is broadcasted to (nframes,nloc,nnei)

        # Padded neighbour slots (-1 in nlist) are pushed past the end of the
        # table (nspline + 1) so that they neither trigger the lower-boundary
        # check below nor pick up real coefficients.
        uu = np.where(nlist != -1, uu, self.nspline + 1)

        if np.any(uu < 0):
            raise Exception("coord go beyond table lower boundary")

        # Integer part selects the spline segment, fractional part is the
        # position inside that segment.
        idx = uu.astype(int)
        uu -= idx

        table_coef = self._extract_spline_coefficient(
            i_type, j_type, idx, self.tab_data, self.nspline
        )
        # `_extract_spline_coefficient` may return a squeezed array; restore the
        # full shape from the inputs of this call instead of relying on
        # attributes set by `forward_atomic`.
        table_coef = table_coef.reshape((*uu.shape, 4))
        ener = self._calcualte_ener(table_coef, uu)

        # here we need to overwrite energy to zero at rcut and beyond.
        mask_beyond_rcut = rr >= self.rcut
        # also overwrite values beyond extrapolation to zero
        extrapolation_mask = rr >= self.tab.rmin + self.nspline * self.tab.hh
        ener[mask_beyond_rcut] = 0
        ener[extrapolation_mask] = 0

        return ener

    @staticmethod
    def _get_pairwise_dist(coords: np.ndarray) -> np.ndarray:
        """Get pairwise distance `dr`.

        Parameters
        ----------
        coords : np.ndarray
            The coordinate of the atoms shape of (nframes, nall, 3).

        Returns
        -------
        np.ndarray
            The pairwise displacement between the atoms (nframes, nall, nall, 3);
            entry ``[f, i, j]`` is ``coords[f, i] - coords[f, j]``.

        Notes
        -----
        The result grows quadratically with ``nall``. `forward_atomic` does not
        call this helper; it is kept for callers that need every pair.
        """
        return np.expand_dims(coords, 2) - np.expand_dims(coords, 1)

    @staticmethod
    def _extract_spline_coefficient(
        i_type: np.ndarray,
        j_type: np.ndarray,
        idx: np.ndarray,
        tab_data: np.ndarray,
        nspline: int,
    ) -> np.ndarray:
        """Extract the spline coefficient from the table.

        Parameters
        ----------
        i_type : np.ndarray
            The integer representation of atom type for all local atoms for all frames. (nframes, nloc)
        j_type : np.ndarray
            The integer representation of atom type for all neighbour atoms of all local atoms for all frames. (nframes, nloc, nnei)
        idx : np.ndarray
            The index of the spline coefficient. (nframes, nloc, nnei)
        tab_data : np.ndarray
            The table storing all the spline coefficient. (ntype, ntype, nspline, 4)
        nspline : int
            The number of splines in the table.

        Returns
        -------
        np.ndarray
            The spline coefficient. (nframes, nloc, nnei, 4), shape may be squeezed.
        """
        # (nframes, nloc, nnei)
        expanded_i_type = np.broadcast_to(
            i_type[:, :, np.newaxis],
            (i_type.shape[0], i_type.shape[1], j_type.shape[-1]),
        )

        # (nframes, nloc, nnei, nspline, 4)
        expanded_tab_data = tab_data[expanded_i_type, j_type]

        # (nframes, nloc, nnei, 1, 4)
        expanded_idx = np.broadcast_to(
            idx[..., np.newaxis, np.newaxis], (*idx.shape, 1, 4)
        )
        # Clip so the gather below never reads outside the table; entries that
        # were out of range are zeroed afterwards.
        clipped_indices = np.clip(expanded_idx, 0, nspline - 1).astype(int)

        # (nframes, nloc, nnei, 4)
        final_coef = np.squeeze(
            np.take_along_axis(expanded_tab_data, clipped_indices, 3)
        )

        # when the spline idx is beyond the table, all spline coefficients are set to `0`, and the resulting ener corresponding to the idx is also `0`.
        final_coef[expanded_idx.squeeze() > nspline] = 0
        return final_coef

    @staticmethod
    def _calcualte_ener(coef: np.ndarray, uu: np.ndarray) -> np.ndarray:
        """Calculate energy using spline coefficients.

        Parameters
        ----------
        coef : np.ndarray
            The spline coefficients. (nframes, nloc, nnei, 4)
        uu : np.ndarray
            The atom displacement used in interpolation and extrapolation (nframes, nloc, nnei)

        Returns
        -------
        np.ndarray
            The atomic energy for all local atoms for all frames. (nframes, nloc, nnei)
        """
        a3, a2, a1, a0 = coef[..., 0], coef[..., 1], coef[..., 2], coef[..., 3]
        # Horner evaluation of a3*uu**3 + a2*uu**2 + a1*uu + a0, elementwise.
        etmp = (a3 * uu + a2) * uu + a1
        ener = etmp * uu + a0  # this energy has the extrapolated value when rcut > rmax
        return ener
Original file line number Diff line number Diff line change
Expand Up @@ -54,13 +54,19 @@ def __init__(
super().__init__()
self.tab_file = tab_file
self.rcut = rcut

self.tab = PairTab(self.tab_file, rcut=rcut)
self.ntypes = self.tab.ntypes

tab_info, tab_data = self.tab.get() # this returns -> Tuple[np.array, np.array]
self.tab_info = torch.from_numpy(tab_info)
self.tab_data = torch.from_numpy(tab_data)
# handle deserialization with no input file
if self.tab_file is not None:
(
tab_info,
tab_data,
) = self.tab.get() # this returns -> Tuple[np.array, np.array]
self.tab_info = torch.from_numpy(tab_info)
self.tab_data = torch.from_numpy(tab_data)
else:
self.tab_info = None
self.tab_data = None

# self.model_type = "ener"
# self.model_version = MODEL_VERSION ## this should be in the parent class
Expand Down Expand Up @@ -92,12 +98,18 @@ def distinguish_types(self) -> bool:
return False

def serialize(self) -> dict:
# placeholder, to be implemented in a future PR
raise NotImplementedError

def deserialize(cls):
# placeholder, to be implemented in a future PR
raise NotImplementedError
return {"tab": self.tab.serialize(), "rcut": self.rcut, "sel": self.sel}

@classmethod
def deserialize(cls, data) -> "PairTabModel":
rcut = data["rcut"]
sel = data["sel"]
tab = PairTab.deserialize(data["tab"])
tab_model = cls(None, rcut, sel)
tab_model.tab = tab
tab_model.tab_info = torch.from_numpy(tab_model.tab.tab_info)
tab_model.tab_data = torch.from_numpy(tab_model.tab.tab_data)
return tab_model

def forward_atomic(
self,
Expand All @@ -108,6 +120,7 @@ def forward_atomic(
do_atomic_virial: bool = False,
) -> Dict[str, torch.Tensor]:
self.nframes, self.nloc, self.nnei = nlist.shape
extended_coord = extended_coord.view(self.nframes, -1, 3)

# this will mask all -1 in the nlist
masked_nlist = torch.clamp(nlist, 0)
Expand All @@ -118,7 +131,7 @@ def forward_atomic(
) # (nframes, nall, nall, 3)
pairwise_rr = pairwise_dr.pow(2).sum(-1).sqrt() # (nframes, nall, nall)

self.tab_data = self.tab_data.reshape(
self.tab_data = self.tab_data.view(
self.tab.ntypes, self.tab.ntypes, self.tab.nspline, 4
)

Expand All @@ -139,7 +152,7 @@ def forward_atomic(
nlist != -1, raw_atomic_energy, torch.zeros_like(raw_atomic_energy)
),
dim=-1,
)
).unsqueeze(-1)

return {"energy": atomic_energy}

Expand Down Expand Up @@ -200,7 +213,7 @@ def _pair_tabulated_inter(
table_coef = self._extract_spline_coefficient(
i_type, j_type, idx, self.tab_data, self.nspline
)
table_coef = table_coef.reshape(self.nframes, self.nloc, self.nnei, 4)
table_coef = table_coef.view(self.nframes, self.nloc, self.nnei, 4)
ener = self._calcualte_ener(table_coef, uu)

# here we need to overwrite energy to zero at rcut and beyond.
Expand All @@ -219,12 +232,12 @@ def _get_pairwise_dist(coords: torch.Tensor) -> torch.Tensor:
Parameters
----------
coords : torch.Tensor
The coordinate of the atoms shape of (nframes * nall * 3).
The coordinate of the atoms shape of (nframes, nall, 3).
Returns
-------
torch.Tensor
The pairwise distance between the atoms (nframes * nall * nall * 3).
The pairwise distance between the atoms (nframes, nall, nall, 3).
Examples
--------
Expand Down
Loading

0 comments on commit 7db1fde

Please sign in to comment.