Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

pt: support dpa2 model parallel inference #3657

Merged
merged 104 commits into from
Apr 30, 2024
Merged
Show file tree
Hide file tree
Changes from 15 commits
Commits
Show all changes
104 commits
Select commit Hold shift + click to select a range
ae0f799
init
CaRoLZhangxy Apr 7, 2024
96c9309
Merge branch 'devel' of https://github.com/deepmodeling/deepmd-kit in…
CaRoLZhangxy Apr 7, 2024
bd1927f
init
CaRoLZhangxy Apr 8, 2024
8350372
fix
CaRoLZhangxy Apr 8, 2024
28ae599
finish
CaRoLZhangxy Apr 8, 2024
1afd8fc
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 8, 2024
7f6632a
Merge branch 'devel' of https://github.com/deepmodeling/deepmd-kit in…
CaRoLZhangxy Apr 15, 2024
29d1bec
Merge branch 'devel' of https://github.com/deepmodeling/deepmd-kit in…
CaRoLZhangxy Apr 17, 2024
2a7db1e
use google cuda define
CaRoLZhangxy Apr 17, 2024
6af0d63
update forward api
CaRoLZhangxy Apr 17, 2024
3020781
remove frozen model
CaRoLZhangxy Apr 17, 2024
420868f
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 17, 2024
c779828
Merge branch 'dis' of https://github.com/CaRoLZhangxy/deepmd-kit into…
CaRoLZhangxy Apr 17, 2024
7591dd3
be able to compile without mpi
CaRoLZhangxy Apr 17, 2024
3d0f14d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 17, 2024
14a5fed
type to fix mpich
CaRoLZhangxy Apr 17, 2024
313a4b1
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 17, 2024
a39aff3
remove unused code
CaRoLZhangxy Apr 17, 2024
31a4f0d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 17, 2024
cbb2916
update model
CaRoLZhangxy Apr 18, 2024
05686e8
upload smaller model
CaRoLZhangxy Apr 18, 2024
5dcf5f0
hack to resolve border_op problem
njzjz Apr 18, 2024
fd9177a
update dpa model
CaRoLZhangxy Apr 19, 2024
37989c7
use gpu memcpy
CaRoLZhangxy Apr 19, 2024
a13934b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 19, 2024
06208d2
update ut data
CaRoLZhangxy Apr 19, 2024
48b9833
Merge branch 'dis' of https://github.com/CaRoLZhangxy/deepmd-kit into…
CaRoLZhangxy Apr 19, 2024
761c1c8
update dpa model
CaRoLZhangxy Apr 19, 2024
0ed6116
update ut data
CaRoLZhangxy Apr 19, 2024
6df987c
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 19, 2024
305245c
update ut data
CaRoLZhangxy Apr 19, 2024
8e5e41c
rollback ut and only apply new api to dpa2 model
CaRoLZhangxy Apr 19, 2024
44e0e6a
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 19, 2024
7bc66be
update ut data
CaRoLZhangxy Apr 19, 2024
3da55f3
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 19, 2024
50c0f46
add comments
CaRoLZhangxy Apr 19, 2024
38bcdd6
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 19, 2024
b49d91d
add ut file
CaRoLZhangxy Apr 19, 2024
46911c1
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 19, 2024
66303a0
fix bug
CaRoLZhangxy Apr 19, 2024
c9bc208
Merge branch 'dis' of https://github.com/CaRoLZhangxy/deepmd-kit into…
CaRoLZhangxy Apr 19, 2024
dca1202
fix type bug
CaRoLZhangxy Apr 19, 2024
60605a9
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 19, 2024
048c2af
try to fix mpich compile error
CaRoLZhangxy Apr 19, 2024
1ff60b2
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 19, 2024
b7a19cf
fix ut data
CaRoLZhangxy Apr 19, 2024
6eb03f4
low requirement at float
CaRoLZhangxy Apr 19, 2024
3dcb4ba
Merge branch 'devel' of https://github.com/deepmodeling/deepmd-kit in…
CaRoLZhangxy Apr 19, 2024
0b485f9
skip no balance test
CaRoLZhangxy Apr 19, 2024
7dc5815
update ut data
CaRoLZhangxy Apr 21, 2024
303644f
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 21, 2024
7867e13
update lmp test data
CaRoLZhangxy Apr 21, 2024
e0a08f3
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 21, 2024
993f05a
Update source/op/pt/comm.cc
CaRoLZhangxy Apr 22, 2024
535ade4
Update source/op/pt/comm.cc
CaRoLZhangxy Apr 22, 2024
f4b4481
Update source/op/pt/comm.cc
CaRoLZhangxy Apr 22, 2024
ffbc4db
Update source/op/pt/comm.cc
CaRoLZhangxy Apr 22, 2024
acf841d
Update source/op/pt/comm.cc
CaRoLZhangxy Apr 22, 2024
9473606
throw error when compiled with mpi without cuda support
CaRoLZhangxy Apr 22, 2024
bc02345
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 22, 2024
1952784
support mpich
CaRoLZhangxy Apr 22, 2024
fc2d61b
include errors.h
CaRoLZhangxy Apr 22, 2024
67b68aa
Merge branch 'devel' of https://github.com/deepmodeling/deepmd-kit in…
CaRoLZhangxy Apr 22, 2024
bc5f092
apply memcpy when cuda-aware = 0
CaRoLZhangxy Apr 23, 2024
e534e99
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 23, 2024
05cdd92
Merge branch 'devel' of https://github.com/deepmodeling/deepmd-kit in…
CaRoLZhangxy Apr 23, 2024
ff17514
Merge branch 'dis' of https://github.com/CaRoLZhangxy/deepmd-kit into…
CaRoLZhangxy Apr 23, 2024
8b23ebb
Merge branch 'devel' of https://github.com/deepmodeling/deepmd-kit in…
CaRoLZhangxy Apr 24, 2024
14b43aa
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 24, 2024
ffd89cd
fix no cuda error
CaRoLZhangxy Apr 24, 2024
09cc940
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 24, 2024
e406e6a
fix compile error
CaRoLZhangxy Apr 24, 2024
3959f19
print log.lammps to screen in test_cuda if failed
njzjz Apr 25, 2024
521aa28
skip dpa test on cuda ,add todo and fix codeql
CaRoLZhangxy Apr 25, 2024
1704230
Merge branch 'devel' of https://github.com/deepmodeling/deepmd-kit in…
CaRoLZhangxy Apr 25, 2024
a06a49c
make pre-commit.ci pass
njzjz Apr 25, 2024
63123de
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 25, 2024
f356ca5
add doc
CaRoLZhangxy Apr 25, 2024
5baefcc
Merge branch 'devel' of https://github.com/deepmodeling/deepmd-kit in…
CaRoLZhangxy Apr 26, 2024
51125b1
add doc
CaRoLZhangxy Apr 26, 2024
b09e857
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 26, 2024
32ba778
try runs-on t4
CaRoLZhangxy Apr 26, 2024
f3b55b4
Update deepmd/pt/model/descriptor/repformers.py
CaRoLZhangxy Apr 26, 2024
92ffb35
rename
CaRoLZhangxy Apr 26, 2024
8ed31c8
Merge branch 'dis' of https://github.com/CaRoLZhangxy/deepmd-kit into…
CaRoLZhangxy Apr 26, 2024
0269b81
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 26, 2024
124c2d6
fix
CaRoLZhangxy Apr 26, 2024
189fc6b
run c++ test only
CaRoLZhangxy Apr 26, 2024
0adc34c
deal with mpi not init
CaRoLZhangxy Apr 26, 2024
d368ccc
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 26, 2024
26b63e0
fix doc format
CaRoLZhangxy Apr 26, 2024
0c76246
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 26, 2024
86757bf
try to fix
CaRoLZhangxy Apr 26, 2024
ecdef3e
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 26, 2024
47242af
fix
CaRoLZhangxy Apr 26, 2024
d051e4b
Merge branch 'dis' of https://github.com/CaRoLZhangxy/deepmd-kit into…
CaRoLZhangxy Apr 26, 2024
8de1785
init mpi_init = 0
CaRoLZhangxy Apr 26, 2024
17b35e0
add world_size
CaRoLZhangxy Apr 26, 2024
74b21e4
add low version support
CaRoLZhangxy Apr 26, 2024
f784505
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 26, 2024
5fe1fd1
fix error
CaRoLZhangxy Apr 26, 2024
33c2798
add &
CaRoLZhangxy Apr 26, 2024
273a446
reset test.yml
CaRoLZhangxy Apr 26, 2024
beba142
add doc str in python
CaRoLZhangxy Apr 28, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion deepmd/pt/entrypoints/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,10 +286,14 @@ def train(FLAGS):

def freeze(FLAGS):
model = torch.jit.script(inference.Tester(FLAGS.model, head=FLAGS.head).model)
if '"type": "dpa2"' in model.model_def_script:
CaRoLZhangxy marked this conversation as resolved.
Show resolved Hide resolved
extra_files = {"type": "dpa2"}
else:
extra_files = {"type": "else"}
torch.jit.save(
model,
FLAGS.output,
{},
extra_files,
)


Expand Down
2 changes: 2 additions & 0 deletions deepmd/pt/model/atomic_model/base_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,7 @@ def forward_common_atomic(
mapping: Optional[torch.Tensor] = None,
fparam: Optional[torch.Tensor] = None,
aparam: Optional[torch.Tensor] = None,
comm_dict: Optional[Dict[str, torch.Tensor]] = None,
CaRoLZhangxy marked this conversation as resolved.
Show resolved Hide resolved
) -> Dict[str, torch.Tensor]:
"""Common interface for atomic inference.

Expand Down Expand Up @@ -234,6 +235,7 @@ def forward_common_atomic(
mapping=mapping,
fparam=fparam,
aparam=aparam,
comm_dict=comm_dict,
)
ret_dict = self.apply_out_stat(ret_dict, atype)

Expand Down
2 changes: 2 additions & 0 deletions deepmd/pt/model/atomic_model/dp_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,7 @@ def forward_atomic(
mapping: Optional[torch.Tensor] = None,
fparam: Optional[torch.Tensor] = None,
aparam: Optional[torch.Tensor] = None,
comm_dict: Optional[Dict[str, torch.Tensor]] = None,
) -> Dict[str, torch.Tensor]:
"""Return atomic prediction.

Expand Down Expand Up @@ -163,6 +164,7 @@ def forward_atomic(
extended_atype,
nlist,
mapping=mapping,
comm_dict=comm_dict,
)
assert descriptor is not None
# energy, force
Expand Down
1 change: 1 addition & 0 deletions deepmd/pt/model/atomic_model/linear_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,7 @@ def forward_atomic(
mapping: Optional[torch.Tensor] = None,
fparam: Optional[torch.Tensor] = None,
aparam: Optional[torch.Tensor] = None,
comm_dict: Optional[Dict[str, torch.Tensor]] = None,
) -> Dict[str, torch.Tensor]:
"""Return atomic prediction.

Expand Down
1 change: 1 addition & 0 deletions deepmd/pt/model/atomic_model/pairtab_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,7 @@ def forward_atomic(
fparam: Optional[torch.Tensor] = None,
aparam: Optional[torch.Tensor] = None,
do_atomic_virial: bool = False,
comm_dict: Optional[Dict[str, torch.Tensor]] = None,
) -> Dict[str, torch.Tensor]:
nframes, nloc, nnei = nlist.shape
extended_coord = extended_coord.view(nframes, -1, 3)
Expand Down
2 changes: 2 additions & 0 deletions deepmd/pt/model/descriptor/dpa1.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from typing import (
Callable,
Dict,
List,
Optional,
Tuple,
Expand Down Expand Up @@ -219,6 +220,7 @@ def forward(
extended_atype: torch.Tensor,
nlist: torch.Tensor,
mapping: Optional[torch.Tensor] = None,
comm_dict: Optional[Dict[str, torch.Tensor]] = None,
):
"""Compute the descriptor.

Expand Down
17 changes: 11 additions & 6 deletions deepmd/pt/model/descriptor/dpa2.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from typing import (
Callable,
Dict,
List,
Optional,
Tuple,
Expand Down Expand Up @@ -395,6 +396,7 @@ def forward(
extended_atype: torch.Tensor,
nlist: torch.Tensor,
mapping: Optional[torch.Tensor] = None,
comm_dict: Optional[Dict[str, torch.Tensor]] = None,
):
"""Compute the descriptor.

Expand Down Expand Up @@ -450,11 +452,13 @@ def forward(
# linear to change shape
g1 = self.g1_shape_tranform(g1)
# mapping g1
assert mapping is not None
mapping_ext = (
mapping.view(nframes, nall).unsqueeze(-1).expand(-1, -1, g1.shape[-1])
)
g1_ext = torch.gather(g1, 1, mapping_ext)
if comm_dict is None:
assert mapping is not None
mapping_ext = (
mapping.view(nframes, nall).unsqueeze(-1).expand(-1, -1, g1.shape[-1])
)
g1_ext = torch.gather(g1, 1, mapping_ext)
g1 = g1_ext
CaRoLZhangxy marked this conversation as resolved.
Show resolved Hide resolved
# repformer
g1, g2, h2, rot_mat, sw = self.repformers(
nlist_dict[
Expand All @@ -464,8 +468,9 @@ def forward(
],
extended_coord,
extended_atype,
g1_ext,
g1,
mapping,
comm_dict,
)
if self.concat_output_tebd:
g1 = torch.cat([g1, g1_inp], dim=-1)
Expand Down
1 change: 1 addition & 0 deletions deepmd/pt/model/descriptor/hybrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -443,6 +443,7 @@ def forward(
extended_atype: torch.Tensor,
extended_atype_embd: Optional[torch.Tensor] = None,
mapping: Optional[torch.Tensor] = None,
comm_dict: Optional[Dict[str, torch.Tensor]] = None,
):
"""Calculate decoded embedding for each atom.

Expand Down
48 changes: 41 additions & 7 deletions deepmd/pt/model/descriptor/repformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,9 +234,10 @@ def forward(
extended_atype: torch.Tensor,
extended_atype_embd: Optional[torch.Tensor] = None,
mapping: Optional[torch.Tensor] = None,
comm_dict: Optional[Dict[str, torch.Tensor]] = None,
):
assert mapping is not None
assert extended_atype_embd is not None
if comm_dict is None:
assert extended_atype_embd is not None
CaRoLZhangxy marked this conversation as resolved.
Show resolved Hide resolved
nframes, nloc, nnei = nlist.shape
nall = extended_coord.view(nframes, -1).shape[1] // 3
atype = extended_atype[:, :nloc]
Expand All @@ -257,9 +258,13 @@ def forward(
sw = sw.masked_fill(~nlist_mask, 0.0)

# [nframes, nloc, tebd_dim]
atype_embd = extended_atype_embd[:, :nloc, :]
assert list(atype_embd.shape) == [nframes, nloc, self.g1_dim]

if comm_dict is None:
assert isinstance(extended_atype_embd, torch.Tensor) # for jit
atype_embd = extended_atype_embd[:, :nloc, :]
assert list(atype_embd.shape) == [nframes, nloc, self.g1_dim]
else:
atype_embd = extended_atype_embd
assert isinstance(atype_embd, torch.Tensor) # for jit
g1 = self.act(atype_embd)
# nb x nloc x nnei x 1, nb x nloc x nnei x 3
if not self.direct_dist:
Expand All @@ -275,11 +280,40 @@ def forward(
# if the a neighbor is real or not is indicated by nlist_mask
nlist[nlist == -1] = 0
# nb x nall x ng1
mapping = mapping.view(nframes, nall).unsqueeze(-1).expand(-1, -1, self.g1_dim)
if comm_dict is None:
assert mapping is not None
mapping = (
mapping.view(nframes, nall).unsqueeze(-1).expand(-1, -1, self.g1_dim)
)
for idx, ll in enumerate(self.layers):
# g1: nb x nloc x ng1
# g1_ext: nb x nall x ng1
g1_ext = torch.gather(g1, 1, mapping)
if comm_dict is None:
assert mapping is not None
g1_ext = torch.gather(g1, 1, mapping)
else:
n_padding = nall - nloc
g1 = torch.nn.functional.pad(
g1.squeeze(0), (0, 0, 0, n_padding), value=0.0
)
assert "send_list" in comm_dict
assert "send_proc" in comm_dict
assert "recv_proc" in comm_dict
assert "send_num" in comm_dict
assert "recv_num" in comm_dict
assert "communicator" in comm_dict
ret = torch.ops.deepmd.border_op(
comm_dict["send_list"],
comm_dict["send_proc"],
comm_dict["recv_proc"],
comm_dict["send_num"],
comm_dict["recv_num"],
g1,
comm_dict["communicator"],
torch.tensor(nloc),
torch.tensor(nall - nloc),
CaRoLZhangxy marked this conversation as resolved.
Show resolved Hide resolved
)
g1_ext = ret[0].unsqueeze(0)
g1, g2, h2 = ll.forward(
g1_ext,
g2,
Expand Down
1 change: 1 addition & 0 deletions deepmd/pt/model/descriptor/se_a.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,7 @@ def forward(
atype_ext: torch.Tensor,
nlist: torch.Tensor,
mapping: Optional[torch.Tensor] = None,
comm_dict: Optional[Dict[str, torch.Tensor]] = None,
):
"""Compute the descriptor.

Expand Down
2 changes: 2 additions & 0 deletions deepmd/pt/model/model/ener_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ def forward_lower(
fparam: Optional[torch.Tensor] = None,
aparam: Optional[torch.Tensor] = None,
do_atomic_virial: bool = False,
comm_dict: Optional[Dict[str, torch.Tensor]] = None,
):
model_ret = self.forward_common_lower(
extended_coord,
Expand All @@ -92,6 +93,7 @@ def forward_lower(
fparam=fparam,
aparam=aparam,
do_atomic_virial=do_atomic_virial,
comm_dict=comm_dict,
)
if self.get_fitting_net() is not None:
model_predict = {}
Expand Down
2 changes: 2 additions & 0 deletions deepmd/pt/model/model/make_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,7 @@ def forward_common_lower(
fparam: Optional[torch.Tensor] = None,
aparam: Optional[torch.Tensor] = None,
do_atomic_virial: bool = False,
comm_dict: Optional[Dict[str, torch.Tensor]] = None,
):
"""Return model prediction. Lower interface that takes
extended atomic coordinates and types, nlist, and mapping
Expand Down Expand Up @@ -254,6 +255,7 @@ def forward_common_lower(
mapping=mapping,
fparam=fp,
aparam=ap,
comm_dict=comm_dict,
)
model_predict = fit_output_to_model_output(
atomic_ret,
Expand Down
13 changes: 13 additions & 0 deletions source/api_c/include/c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,19 @@ extern DP_Nlist* DP_NewNlist(int inum_,
int* numneigh_,
int** firstneigh_);

extern DP_Nlist* DP_NewNlist_comm(int inum_,
int* ilist_,
int* numneigh_,
int** firstneigh_,
int nswap,
int* sendnum,
int* recvnum,
int* firstrecv,
int** sendlist,
int* sendproc,
int* recvproc,
void* world);

CaRoLZhangxy marked this conversation as resolved.
Show resolved Hide resolved
/**
* @brief Delete a neighbor list.
*
Expand Down
30 changes: 30 additions & 0 deletions source/api_c/include/deepmd.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -572,6 +572,36 @@
nl(DP_NewNlist(inum_, ilist_, numneigh_, firstneigh_)) {
DP_CHECK_OK(DP_NlistCheckOK, nl);
};
InputNlist(int inum_,
int *ilist_,
int *numneigh_,
int **firstneigh_,
int nswap,
int *sendnum,
int *recvnum,
int *firstrecv,
int **sendlist,
int *sendproc,
int *recvproc,
void *world)
: inum(inum_),
ilist(ilist_),
numneigh(numneigh_),
firstneigh(firstneigh_),
nl(DP_NewNlist_comm(inum_,
ilist_,
numneigh_,
firstneigh_,
nswap,
sendnum,
recvnum,
firstrecv,
sendlist,
sendproc,
recvproc,
world)){
// DP_CHECK_OK(DP_NlistCheckOK, nl);
Fixed Show fixed Hide fixed
};
~InputNlist() { DP_DeleteNlist(nl); };
/// @brief C API neighbor list.
DP_Nlist *nl;
Expand Down
19 changes: 18 additions & 1 deletion source/api_c/src/c_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,24 @@ DP_Nlist* DP_NewNlist(int inum_,
deepmd::InputNlist nl(inum_, ilist_, numneigh_, firstneigh_);
DP_Nlist* new_nl = new DP_Nlist(nl); return new_nl;)
}

DP_Nlist* DP_NewNlist_comm(int inum_,
int* ilist_,
int* numneigh_,
int** firstneigh_,
int nswap,
int* sendnum,
int* recvnum,
int* firstrecv,
int** sendlist,
int* sendproc,
int* recvproc,
void* world) {
deepmd::InputNlist nl(inum_, ilist_, numneigh_, firstneigh_, nswap, sendnum,
recvnum, firstrecv, sendlist, sendproc, recvproc,
world);
DP_Nlist* new_nl = new DP_Nlist(nl);
return new_nl;
}
void DP_DeleteNlist(DP_Nlist* nl) { delete nl; }

DP_DeepPot::DP_DeepPot() {}
Expand Down
2 changes: 2 additions & 0 deletions source/api_cc/include/DeepPotPT.h
Original file line number Diff line number Diff line change
Expand Up @@ -325,8 +325,10 @@ class DeepPotPT : public DeepPotBase {
NeighborListData nlist_data;
int max_num_neighbors;
int gpu_id;
int model_type; // 1:dpa2 model 0:others
bool gpu_enabled;
at::Tensor firstneigh_tensor;
torch::Dict<std::string, torch::Tensor> comm_dict;
};

} // namespace deepmd
Loading
Loading