Skip to content

Commit

Permalink
Merge branch 'devel' into 4108
Browse files Browse the repository at this point in the history
  • Loading branch information
Chengqian-Zhang authored Sep 20, 2024
2 parents b2f62b7 + c084b20 commit 3fbc039
Show file tree
Hide file tree
Showing 16 changed files with 508 additions and 361 deletions.
5 changes: 5 additions & 0 deletions deepmd/pt/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,11 @@
from deepmd.pt.cxx_op import (
ENABLE_CUSTOMIZED_OP,
)
from deepmd.utils.entry_point import (
load_entry_point,
)

# Executed at import time of this package.
# NOTE(review): presumably loads plugins registered under the "deepmd.pt"
# entry-point group — confirm against deepmd.utils.entry_point.
load_entry_point("deepmd.pt")

__all__ = [
"ENABLE_CUSTOMIZED_OP",
Expand Down
5 changes: 0 additions & 5 deletions deepmd/pt/model/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from deepmd.utils.entry_point import (
load_entry_point,
)

# Executed at import time of this package.
# NOTE(review): presumably loads plugins registered under the "deepmd.pt"
# entry-point group — confirm against deepmd.utils.entry_point.
load_entry_point("deepmd.pt")
8 changes: 0 additions & 8 deletions deepmd/pt/model/model/transform_output.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,14 +108,6 @@ def get_leading_dims(
return list(vshape[: (len(vshape) - len(vdef.shape))])


def get_atom_axis(
    vdef: torch.Tensor,
):
    """Return the negative axis index located just before the dims of ``vdef.shape``.

    For a ``vdef`` with ``k`` dimensions the result is ``-(k + 1)``.
    """
    return -1 - len(vdef.shape)


def take_deriv(
vv: torch.Tensor,
svv: torch.Tensor,
Expand Down
287 changes: 0 additions & 287 deletions deepmd/pt/utils/preprocess.py
Original file line number Diff line number Diff line change
@@ -1,231 +1,11 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import logging
from typing import (
Union,
)

import torch

from deepmd.pt.utils import (
env,
)

log = logging.getLogger(__name__)


class Region3D:
    """Simulation box with conversions between physical and internal coordinates.

    The box matrix is stored with one cell vector per row, so a row-vector
    coordinate ``c`` in internal units maps to ``c @ boxt`` in physical units.
    """

    def __init__(self, boxt):
        """Construct a simulation box.

        Parameters
        ----------
        boxt
            Tensor with the 9 box components; it is reshaped to ``[3, 3]``.
        """
        boxt = boxt.reshape([3, 3])
        self.boxt = boxt  # converts internal coordinates to physical ones
        # converts physical coordinates to internal ones
        self.rec_boxt = torch.linalg.inv(self.boxt)

        self.volume = torch.linalg.det(self.boxt)  # determinant of the box

        # Distance between opposite faces = volume / area of that face.
        # ``dim`` is passed explicitly: calling torch.cross without it is
        # deprecated, and this matches ``b_to_face_distance``.
        c_yz = torch.cross(boxt[1], boxt[2], dim=-1)
        self._h2yz = self.volume / torch.linalg.norm(c_yz)
        c_zx = torch.cross(boxt[2], boxt[0], dim=-1)
        self._h2zx = self.volume / torch.linalg.norm(c_zx)
        c_xy = torch.cross(boxt[0], boxt[1], dim=-1)
        self._h2xy = self.volume / torch.linalg.norm(c_xy)

    def phys2inter(self, coord):
        """Convert physical coordinates to internal ones."""
        return coord @ self.rec_boxt

    def inter2phys(self, coord):
        """Convert internal coordinates to physical ones."""
        return coord @ self.boxt

    def get_face_distance(self):
        """Return face distances to each surface of YZ, ZX, XY as a stacked tensor."""
        return torch.stack([self._h2yz, self._h2zx, self._h2xy])


def normalize_coord(coord, region: Region3D, nloc: int):
    """Wrap atoms lying outside the region back into it.

    Coordinates are converted to internal units, reduced modulo 1, and
    converted back, i.e. atoms are moved by whole box vectors. The input
    tensor is not modified.

    Parameters
    ----------
    coord
        Physical coordinates; the last dimension must be 3 because it is
        multiplied with the 3x3 box matrix.
    region : Region3D
        The simulation box.
    nloc : int
        Unused; kept so existing callers keep working.

    Returns
    -------
    Tensor of the same shape as ``coord`` with wrapped physical coordinates.
    """
    # No defensive clone is needed: every step below returns a new tensor.
    inter_coord = torch.remainder(region.phys2inter(coord), 1.0)
    return region.inter2phys(inter_coord)


def compute_serial_cid(cell_offset, ncell):
    """Return the sequential (flattened) cell ID of 3D cell indices.

    The ID is ``ix * ncell[1] * ncell[2] + iy * ncell[2] + iz``. Unlike the
    previous implementation, ``cell_offset`` is left unmodified.

    Parameters
    ----------
    cell_offset
        Integer cell indices, shape ``[n, 3]``.
    ncell
        Number of cells along each direction, shape ``[3]``.

    Returns
    -------
    Tensor of shape ``[n]`` with the flattened cell IDs.
    """
    return (
        cell_offset[:, 0] * (ncell[1] * ncell[2])
        + cell_offset[:, 1] * ncell[2]
        + cell_offset[:, 2]
    )


def compute_pbc_shift(cell_offset, ncell):
    """Return how many box images to add so each cell index falls in ``[0, ncell)``.

    ``cell_offset + shift * ncell`` is checked to lie inside the local cell
    range before the shift is returned.
    """
    below = cell_offset < 0
    above = cell_offset >= ncell
    # Correction for indices under the lower bound.
    shift_below = -torch.div(cell_offset, ncell, rounding_mode="floor")
    # Correction for indices at or past the upper bound.
    shift_above = -(
        torch.div(cell_offset - ncell, ncell, rounding_mode="floor") + 1
    )
    shift = torch.zeros_like(cell_offset) + below * shift_below + above * shift_above
    wrapped = cell_offset + shift * ncell
    assert torch.all(wrapped >= 0)
    assert torch.all(wrapped < ncell)
    return shift


def build_inside_clist(coord, region: Region3D, ncell):
    """Build a cell list for the atoms inside the region.

    Parameters
    ----------
    coord
        Physical coordinates with ``nloc * 3`` elements; viewed as ``[nloc, 3]``.
    region : Region3D
        The simulation box.
    ncell
        Number of cells along each direction, shape ``[3]``.

    Returns
    -------
    a2c
        Flattened cell ID of every atom, shape ``[nloc]``.
    lst
        List with one tensor per cell holding the indices of the atoms in
        that cell, in ascending atom order.
    """
    loc_ncell = int(torch.prod(ncell))  # num of local cells
    inter_cell_size = 1.0 / ncell

    inter_cood = region.phys2inter(coord.view(-1, 3))
    cell_offset = torch.floor(inter_cood / inter_cell_size).to(torch.long)
    # numerical error brought by conversion from phys to inter back and forth
    # may lead to negative value
    cell_offset[cell_offset < 0] = 0
    a2c = compute_serial_cid(cell_offset, ncell)  # cell id of atoms

    # Group atom indices by cell. A stable sort keeps ascending atom order
    # inside each cell; it replaces a dense [loc_ncell, nloc] one-hot matrix
    # and works on whatever device ``a2c`` lives on.
    counts = torch.bincount(a2c, minlength=loc_ncell)[:loc_ncell]
    order = torch.sort(a2c, stable=True).indices
    # IDs >= loc_ncell (if any) sort last and are dropped, as before.
    lst = list(torch.split(order[: int(counts.sum())], counts.tolist()))
    return a2c, lst


def append_neighbors(coord, region: Region3D, atype, rcut: float):
    """Make ghost atoms that are valid neighbors of the local atoms.

    Parameters
    ----------
    coord
        Coordinates of local atoms; indexed per atom (``coord[aid]``,
        ``coord.shape[0]`` is the atom count), i.e. shape ``[nloc, 3]``.
    region : Region3D
        The simulation box.
    atype
        Types of local atoms, shape ``[nloc]``.
    rcut : float
        Cutoff radius used to size the cells and the ghost layer.

    Returns
    -------
    merged_coord_shift
        Zeros for the local atoms followed by the shift vector of every ghost
        atom; a ghost coordinate is ``coord[mapping] - shift``.
    merged_atype
        Types of local atoms followed by types of ghost atoms.
    merged_mapping
        For every local and ghost atom, the index of its local atom.
    """
    to_face = region.get_face_distance()

    # compute num and size of local cells
    ncell = torch.floor(to_face / rcut).to(torch.long)
    ncell[ncell == 0] = 1
    cell_size = to_face / ncell
    ngcell = (
        torch.floor(rcut / cell_size).to(torch.long) + 1
    )  # num of cells out of local, which contain ghost atoms

    # add ghost atoms
    _, c2a = build_inside_clist(coord, region, ncell)
    xi = torch.arange(-ngcell[0], ncell[0] + ngcell[0], 1)  # pylint: disable=no-explicit-dtype,no-explicit-device
    yi = torch.arange(-ngcell[1], ncell[1] + ngcell[1], 1)  # pylint: disable=no-explicit-dtype,no-explicit-device
    zi = torch.arange(-ngcell[2], ncell[2] + ngcell[2], 1)  # pylint: disable=no-explicit-dtype,no-explicit-device
    xyz = xi.view(-1, 1, 1, 1) * torch.tensor([1, 0, 0], dtype=torch.long)  # pylint: disable=no-explicit-device
    xyz = xyz + yi.view(1, -1, 1, 1) * torch.tensor([0, 1, 0], dtype=torch.long)  # pylint: disable=no-explicit-device
    xyz = xyz + zi.view(1, 1, -1, 1) * torch.tensor([0, 0, 1], dtype=torch.long)  # pylint: disable=no-explicit-device
    xyz = xyz.view(-1, 3)
    # keep only the cells outside the local range, i.e. the ghost cells
    mask_a = (xyz >= 0).all(dim=-1)
    mask_b = (xyz < ncell).all(dim=-1)
    mask = ~torch.logical_and(mask_a, mask_b)
    xyz = xyz[mask]  # cell coord
    shift = compute_pbc_shift(xyz, ncell)
    coord_shift = region.inter2phys(shift.to(env.GLOBAL_PT_FLOAT_PRECISION))
    mirrored = shift * ncell + xyz  # local cell that each ghost cell images
    cid = compute_serial_cid(mirrored, ncell)

    n_atoms = coord.shape[0]
    # Encode (ghost cell index, local atom index) in one integer, then decode.
    aid = [c2a[ci] + i * n_atoms for i, ci in enumerate(cid)]
    aid = torch.cat(aid)
    ghost_cell = torch.div(aid, n_atoms, rounding_mode="trunc")
    aid = aid % n_atoms
    ghost_atype = atype[aid]

    # merge local and ghost atoms
    merged_coord_shift = torch.cat([torch.zeros_like(coord), coord_shift[ghost_cell]])
    merged_atype = torch.cat([atype, ghost_atype])
    merged_mapping = torch.cat([torch.arange(atype.numel()), aid])  # pylint: disable=no-explicit-dtype,no-explicit-device
    return merged_coord_shift, merged_atype, merged_mapping


def build_neighbor_list(
    nloc: int, coord, atype, rcut: float, sec, mapping, type_split=True, min_check=False
):
    """For each atom inside region, build its neighbor list.

    Parameters
    ----------
    nloc : int
        Number of local atoms; they are the first ``nloc`` entries of ``coord``.
    coord
        Coordinates of all (local + ghost) atoms, ``nall * 3`` elements.
    atype
        Types of all atoms, shape ``[nall]``.
    rcut : float
        Cutoff radius; candidates at distance >= ``rcut`` become padding.
    sec
        Cumulative neighbor counts per type (indexable, entries support
        ``.item()``); ``sec[-1]`` is the total neighbor-list width.
    mapping
        Index of the local atom for every atom, shape ``[nall]``.
    type_split : bool
        If True, neighbors are selected per type into the slots given by
        ``sec``; otherwise the ``sec[-1]`` nearest atoms are selected.
    min_check : bool
        If True, raise when any pair distance is below ``1e-6``.

    Returns
    -------
    nlist
        Neighbor indices into the ``nall`` atoms, ``[nloc, sec[-1]]``, -1 padded.
    nlist_loc
        Same neighbors expressed through ``mapping``, -1 padded.
    nlist_type
        Types of the neighbors, -1 padded.

    Raises
    ------
    RuntimeError
        If ``min_check`` is set and two atoms are closer than ``1e-6``.
    """
    coord = coord.float()
    coord_l = coord.view(-1, 1, 3)[:nloc]
    coord_r = coord.view(1, -1, 3)
    distance = torch.linalg.norm(coord_l - coord_r, dim=-1)
    # Larger than every real distance; used to push excluded pairs out of range.
    distance_inf = distance.max().detach() + rcut
    # Exclude each local atom from its own neighbor list.
    distance[:nloc, :nloc] += torch.eye(nloc, dtype=torch.bool) * distance_inf  # pylint: disable=no-explicit-device
    if min_check and distance.min().abs() < 1e-6:
        raise RuntimeError("Atom dist too close!")
    if not type_split:
        sec = sec[-1:]
    nlist = torch.zeros((nloc, sec[-1].item())).long() - 1  # pylint: disable=no-explicit-dtype,no-explicit-device
    nlist_loc = torch.zeros((nloc, sec[-1].item())).long() - 1  # pylint: disable=no-explicit-dtype,no-explicit-device
    nlist_type = torch.zeros((nloc, sec[-1].item())).long() - 1  # pylint: disable=no-explicit-dtype,no-explicit-device
    for i, nnei in enumerate(sec):
        if i > 0:
            # ``sec`` is cumulative: recover the slot count of this type.
            nnei = nnei - sec[i - 1]
        if not type_split:
            tmp = distance
        else:
            # Push atoms of other types out of range.
            mask = atype.unsqueeze(0) == i
            tmp = distance + (~mask) * distance_inf
        if tmp.shape[1] >= nnei:
            _sorted, indices = torch.topk(tmp, nnei, dim=1, largest=False)
        else:
            # when nnei > nall
            indices = torch.zeros((nloc, nnei)).long() - 1  # pylint: disable=no-explicit-dtype,no-explicit-device
            _sorted = torch.ones((nloc, nnei)).long() * distance_inf  # pylint: disable=no-explicit-dtype,no-explicit-device
            _sorted_nnei, indices_nnei = torch.topk(
                tmp, tmp.shape[1], dim=1, largest=False
            )
            _sorted[:, : tmp.shape[1]] = _sorted_nnei
            indices[:, : tmp.shape[1]] = indices_nnei
        mask = (_sorted < rcut).to(torch.long)
        indices_loc = mapping[indices]
        indices = indices * mask + -1 * (1 - mask)  # -1 for padding
        indices_loc = indices_loc * mask + -1 * (1 - mask)  # -1 for padding
        if i == 0:
            start = 0
        else:
            start = sec[i - 1]
        end = min(sec[i], start + indices.shape[1])
        nlist[:, start:end] = indices[:, :nnei]
        nlist_loc[:, start:end] = indices_loc[:, :nnei]
        # ``indices`` may hold -1 here; the looked-up type is masked to -1.
        nlist_type[:, start:end] = atype[indices[:, :nnei]] * mask + -1 * (1 - mask)
    return nlist, nlist_loc, nlist_type


def compute_smooth_weight(distance, rmin: float, rmax: float):
"""Compute smooth weight for descriptor elements."""
if rmin >= rmax:
Expand All @@ -236,70 +16,3 @@ def compute_smooth_weight(distance, rmin: float, rmax: float):
uu = (distance - rmin) / (rmax - rmin)
vv = uu * uu * uu * (-6 * uu * uu + 15 * uu - 10) + 1
return vv * mid_mask + min_mask


def make_env_mat(
    coord,
    atype,
    region,
    rcut: Union[float, list],
    sec,
    pbc=True,
    type_split=True,
    min_check=False,
):
    """Build neighbor lists (and ghost-atom bookkeeping) from atom coordinates.

    When ``rcut`` is a list, one neighbor list is built per cutoff using
    ``sec[ii]`` for the ii-th cutoff, and the three neighbor-list outputs are
    lists; ghost atoms are generated with the largest cutoff.

    Returns
    -------
    nlist: nlist, [nloc, nnei]
    nlist_loc: nlist with indices mapped to local atoms
    nlist_type: types of the neighbors
    merged_coord_shift: shift on nall atoms, [nall, 3]
    merged_mapping: mapping from nall index to nloc index, [nall]
    """
    is_hybrid = isinstance(rcut, list)
    ghost_rcut = max(rcut) if is_hybrid else rcut

    if not pbc:
        # Without periodic images the extended system is the local system.
        merged_coord_shift = torch.zeros_like(coord)
        merged_atype = atype.clone()
        merged_mapping = torch.arange(atype.numel())  # pylint: disable=no-explicit-dtype,no-explicit-device
        merged_coord = coord.clone()
    else:
        merged_coord_shift, merged_atype, merged_mapping = append_neighbors(
            coord, region, atype, ghost_rcut
        )
        merged_coord = coord[merged_mapping] - merged_coord_shift
        if merged_coord.shape[0] <= coord.shape[0]:
            log.warning("No ghost atom is added for system ")

    def _nlist_for(cutoff, sections):
        # One neighbor-list build over the extended system for a single cutoff.
        return build_neighbor_list(
            coord.shape[0],
            merged_coord,
            merged_atype,
            cutoff,
            sections,
            merged_mapping,
            type_split=type_split,
            min_check=min_check,
        )

    if is_hybrid:
        nlist, nlist_loc, nlist_type = [], [], []
        for ii, single_rcut in enumerate(rcut):
            one_nlist, one_loc, one_type = _nlist_for(single_rcut, sec[ii])
            nlist.append(one_nlist)
            nlist_loc.append(one_loc)
            nlist_type.append(one_type)
    else:
        nlist, nlist_loc, nlist_type = _nlist_for(rcut, sec)
    return nlist, nlist_loc, nlist_type, merged_coord_shift, merged_mapping
11 changes: 0 additions & 11 deletions deepmd/pt/utils/region.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,17 +68,6 @@ def to_face_distance(
return dist.view(list(cshape[:-2]) + [3]) # noqa:RUF005


def _to_face_distance(cell):
volume = torch.linalg.det(cell)
c_yz = torch.cross(cell[1], cell[2])
_h2yz = volume / torch.linalg.norm(c_yz)
c_zx = torch.cross(cell[2], cell[0])
_h2zx = volume / torch.linalg.norm(c_zx)
c_xy = torch.cross(cell[0], cell[1])
_h2xy = volume / torch.linalg.norm(c_xy)
return torch.stack([_h2yz, _h2zx, _h2xy])


def b_to_face_distance(cell):
volume = torch.linalg.det(cell)
c_yz = torch.cross(cell[:, 1], cell[:, 2], dim=-1)
Expand Down
2 changes: 1 addition & 1 deletion doc/model/dpa2.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

The DPA-2 model implementation. See https://arxiv.org/abs/2312.15492 for more details.

Training example: `examples/water/dpa2/input_torch.json`.
Training example: `examples/water/dpa2/input_torch_medium.json`, see [README](../../examples/water/dpa2/README.md) for inputs in different levels.

## Data format

Expand Down
15 changes: 15 additions & 0 deletions examples/water/dpa2/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
## Inputs for DPA-2 model

This directory contains the input files for training the DPA-2 model (currently supported only by the PyTorch backend). Depending on your precision/efficiency requirements, we provide three input files with different levels of model complexity:

- `input_torch_small.json`: Our smallest DPA-2 model, optimized for speed.
- `input_torch_medium.json` (Recommended): Our well-performing DPA-2 model, balancing efficiency and precision. This is a good starting point for most users.
- `input_torch_large.json`: Our most complex model with the highest precision, suitable for very intricate data structures.

For detailed differences in their configurations, please refer to the table below:

| Input | Repformer layers | Three-body embedding in Repinit | Pair-wise attention in Repformer | Tuned sub-structures in [#4089](https://github.com/deepmodeling/deepmd-kit/pull/4089) | Description |
| ------------------------- | ---------------- | ------------------------------- | -------------------------------- | ------------------------------------------------------------------------------------- | ---------------------------------------------------------------------------- |
| `input_torch_small.json` | 3 |||| Smallest DPA-2 model, optimized for speed. |
| `input_torch_medium.json` | 6 |||| Recommended well-performing DPA-2 model, balancing efficiency and precision. |
| `input_torch_large.json` | 12 |||| Most complex model with the highest precision. |
Loading

0 comments on commit 3fbc039

Please sign in to comment.