Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat: pairtab model pytorch #174

Closed
wants to merge 18 commits into from
262 changes: 262 additions & 0 deletions deepmd_pt/model/model/pair_tab.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,262 @@
from atomic_model import AtomicModel
from deepmd_utils.pair_tab import (
PairTab,
)
import torch
from torch import nn
from typing import Dict, List, Optional, Union

from deepmd_utils.model_format import FittingOutputDef, OutputVariableDef
from deepmd_pt.model.task import Fitting

class PairTabModel(nn.Module, AtomicModel):
"""Pairwise tabulation energy model.

This model can be used to tabulate the pairwise energy between atoms for either
short-range or long-range interactions, such as D3, LJ, ZBL, etc. It should not
be used alone, but rather as one submodel of a linear (sum) model, such as
DP+D3.

Do not put the model on the first model of a linear model, since the linear
model fetches the type map from the first model.

At this moment, the model does not smooth the energy at the cutoff radius, so
one needs to make sure the energy has been smoothed to zero.

Parameters
----------
tab_file : str
The path to the tabulation file.
rcut : float
The cutoff radius.
sel : int or list[int]
The maxmum number of atoms in the cut-off radius.
"""

def __init__(
self, tab_file: str, rcut: float, sel: Union[int, List[int]], **kwargs
):
super().__init__()
self.tab_file = tab_file
self.tab = PairTab(self.tab_file)
self.ntypes = self.tab.ntypes
self.rcut = rcut

tab_info, tab_data = self.tab.get() # this returns -> Tuple[np.array, np.array]
self.tab_info = torch.from_numpy(tab_info)
self.tab_data = torch.from_numpy(tab_data)

# self.model_type = "ener"
# self.model_version = MODEL_VERSION ## this shoud be in the parent class


if isinstance(sel, int):
self.sel = sel
elif isinstance(sel, list):
self.sel = sum(sel)
else:
raise TypeError("sel must be int or list[int]")

def get_fitting_net(self)->Fitting:
# this model has no fitting_net.
return

def get_fitting_output_def(self)->FittingOutputDef:
return FittingOutputDef(
[OutputVariableDef(
name="energy",
shape=[1],
reduciable=True,
differentiable=True
)]
)

def get_rcut(self)->float:
return self.rcut

def get_sel(self)->int:
return self.sel

def distinguish_types(self)->bool:
# to match DPA1 and DPA2.
return False

def forward_atomic(
self,
extended_coord,
extended_atype,
nlist,
mapping: Optional[torch.Tensor] = None,
do_atomic_virial: bool = False,
) -> Dict[str, torch.Tensor]:


nframes, nloc, nnei = nlist.shape

#this will mask all -1 in the nlist
masked_nlist = torch.clamp(nlist,0)

atype = extended_atype[:, :nloc] #(nframes, nloc)
pairwise_dr = self._get_pairwise_dist(extended_coord) # (nframes, nall, nall, 3)
pairwise_rr = pairwise_dr.pow(2).sum(-1).sqrt() # (nframes, nall, nall)

self.tab_data = self.tab_data.reshape(self.tab.ntypes,self.tab.ntypes,self.tab.nspline,4)

#to calculate the atomic_energy, we need 3 tensors, i_type, j_type, rr
#i_type : (nframes, nloc), this is atype.
#j_type : (nframes, nloc, nnei)
j_type = extended_atype[torch.arange(extended_atype.size(0))[:, None, None], masked_nlist]

#slice rr to get (nframes, nloc, nnei)
rr = torch.gather(pairwise_rr[:, :nloc, :],2, masked_nlist)

raw_atomic_energy = self._pair_tabulated_inter(atype, j_type, rr)

atomic_energy = 0.5 * torch.sum(torch.where(nlist != -1, raw_atomic_energy, torch.zeros_like(raw_atomic_energy)) ,dim=-1)

return {"atomic_energy": atomic_energy}
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the key should be "energy", please check your output def.
you may want to use this decorator to ensure the correctness of your atomic model output

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Updated the key. Not sure about the decorator, this atomic model has no Fitting


def _pair_tabulated_inter(self, nlist: torch.Tensor,i_type: torch.Tensor, j_type: torch.Tensor, rr: torch.Tensor) -> torch.Tensor:
"""Pairwise tabulated energy.

Parameters
----------
nlist : torch.Tensor
The unmasked neighbour list. (nframes, nloc)

i_type : torch.Tensor
The integer representation of atom type for all local atoms for all frames. (nframes, nloc)

j_type : torch.Tensor
The integer representation of atom type for all neighbour atoms of all local atoms for all frames. (nframes, nloc, nnei)

rr : torch.Tensor
The salar distance vector between two atoms. (nframes, nloc, nnei)

Returns
-------
torch.Tensor
The masked atomic energy for all local atoms for all frames. (nframes, nloc, nnei)

Raises
------
Exception
If the distance is beyond the table.

Notes
-----
This function is used to calculate the pairwise energy between two atoms.
It uses a table containing cubic spline coefficients calculated in PairTab.
"""

rmin = self.tab_info[0]
hh = self.tab_info[1]
hi = 1. / hh

self.nspline = int(self.tab_info[2] + 0.1)

uu = (rr - rmin) * hi # this is broadcasted to (nframes,nloc,nnei)


# if nnei of atom 0 has -1 in the nlist, uu would be 0.
# this is to handel the nlist where the mask is set to 0.
# by replacing the values wiht nspline + 1, the energy contribution will be 0
uu = torch.where(nlist != -1, uu, self.nspline+1)

if torch.any(uu < 0):
raise Exception("coord go beyond table lower boundary")

idx = uu.to(torch.int)

uu -= idx


final_coef = self._extract_spline_coefficient(i_type, j_type, idx)

a3, a2, a1, a0 = torch.unbind(final_coef, dim=-1) # 4 * (nframes, nloc, nnei)

etmp = (a3 * uu + a2) * uu + a1 # this should be elementwise operations.
ener = etmp * uu + a0
return ener

@staticmethod
def _get_pairwise_dist(coords: torch.Tensor) -> torch.Tensor:
"""Get pairwise distance `dr`.

Parameters
----------
coords : torch.Tensor
The coordinate of the atoms shape of (nframes * nall * 3).

Returns
-------
torch.Tensor
The pairwise distance between the atoms (nframes * nall * nall * 3).

Examples
--------
coords = torch.tensor([[
[0,0,0],
[1,3,5],
[2,4,6]
]])

dist = tensor([[
[[ 0, 0, 0],
[-1, -3, -5],
[-2, -4, -6]],

[[ 1, 3, 5],
[ 0, 0, 0],
[-1, -1, -1]],

[[ 2, 4, 6],
[ 1, 1, 1],
[ 0, 0, 0]]
]])
"""
return coords.unsqueeze(2) - coords.unsqueeze(1)

def _extract_spline_coefficient(self, i_type: torch.Tensor, j_type: torch.Tensor, idx: torch.Tensor) -> torch.Tensor:
"""Extract the spline coefficient from the table.

Parameters
----------
i_type : torch.Tensor
The integer representation of atom type for all local atoms for all frames. (nframes, nloc)

j_type : torch.Tensor
The integer representation of atom type for all neighbour atoms of all local atoms for all frames. (nframes, nloc, nnei)

idx : torch.Tensor
The index of the spline coefficient. (nframes, nloc, nnei)

Returns
-------
torch.Tensor
The spline coefficient. (nframes, nloc, nnei, 4)

Example
-------

"""

# (nframes, nloc, nnei)
expanded_i_type = i_type.unsqueeze(-1).expand(-1, -1, j_type.shape[-1])

# (nframes, nloc, nnei, nspline, 4)
expanded_tab_data = self.tab_data[expanded_i_type, j_type]

# (nframes, nloc, nnei, 1, 4)
expanded_idx = idx.unsqueeze(-1).unsqueeze(-1).expand(-1, -1,-1, -1, 4)

#handle the case where idx is beyond the number of splines
clipped_indices = torch.clamp(expanded_idx, 0, self.nspline - 1).to(torch.int64)

# (nframes, nloc, nnei, 4)
final_coef = torch.gather(expanded_tab_data, 3, clipped_indices).squeeze()

# when the spline idx is beyond the table, all spline coefficients are set to `0`, and the resulting ener corresponding to the idx is also `0`.
final_coef[expanded_idx.squeeze() >= self.nspline] = 0

return final_coef
74 changes: 74 additions & 0 deletions tests/test_pairtab.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
import unittest
import torch
import numpy as np
from unittest.mock import patch
from deepmd_pt.model.model.pair_tab import PairTabModel

class TestPairTab(unittest.TestCase):

@patch('numpy.loadtxt')
def setUp(self, mock_loadtxt) -> None:

file_path = 'dummy_path'
mock_loadtxt.return_value = np.array([
[0.005, 1. , 2. , 3. ],
[0.01 , 0.8 , 1.6 , 2.4 ],
[0.015, 0.5 , 1. , 1.5 ],
[0.02 , 0.25 , 0.4 , 0.75 ]])

self.model = PairTabModel(
tab_file = file_path,
rcut = 0.1,
sel = 2
)

self.extended_coord = torch.tensor([
[[0.01,0.01,0.01],
[0.01,0.02,0.01],
[0.01,0.01,0.02],
[0.02,0.01,0.01]],

[[0.01,0.01,0.01],
[0.01,0.02,0.01],
[0.01,0.01,0.02],
[0.05,0.01,0.01]],
])

# nframes=2, nall=4
self.extended_atype = torch.tensor([
[0,1,0,1],
[0,0,1,1]
])

# nframes=2, nloc=2, nnei=2
self.nlist = torch.tensor([
[[1,2],[0,2]],
[[1,2],[0,3]]
])

def test_without_mask(self):

result = self.model.forward_atomic(self.extended_coord, self.extended_atype,self.nlist)
expected_result = torch.tensor([[2.4000, 2.7085],
[2.4000, 0.8000]])

torch.testing.assert_allclose(result,expected_result)

def test_with_mask(self):

self.nlist = torch.tensor([
[[1,-1],[0,2]],
[[1,2],[0,3]]
])

result = self.model.forward_atomic(self.extended_coord, self.extended_atype,self.nlist)
expected_result = torch.tensor([[1.6000, 2.7085],
[2.4000, 0.8000]])

torch.testing.assert_allclose(result,expected_result)

def test_jit(self):
model = torch.jit.script(self.model)

if __name__ == '__main__':
unittest.main()
Loading