Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(jax): build nlist in the SavedModel & fix nopbc for StableHLO and SavedModel #4318

Merged
merged 32 commits into from
Nov 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
d924013
checkpoint
njzjz Nov 6, 2024
373ea65
bugfix
njzjz Nov 6, 2024
933e4df
bugfix
njzjz Nov 6, 2024
94d2054
nopbc
njzjz Nov 6, 2024
84cb819
clean up default values
njzjz Nov 6, 2024
bd27d4f
add tests
njzjz Nov 7, 2024
a00aae8
skip the whole module
njzjz Nov 7, 2024
95c476b
fix tests on ci
njzjz Nov 7, 2024
8bef185
fix
njzjz Nov 7, 2024
5e10621
fix skip testing. I am still confused why it doesn't work
njzjz Nov 7, 2024
fc7b6b7
try to resolve OOM issue
njzjz Nov 7, 2024
9f9c174
Merge branch 'devel' into tf-call
njzjz Nov 8, 2024
31f9e87
limit threads during tests
njzjz Nov 8, 2024
050c20c
set xla flags before any imports
njzjz Nov 8, 2024
19f798c
set NPROC
njzjz Nov 8, 2024
befc0c7
I don;t understand why the tests fail randomly agter I add some tests...
njzjz Nov 8, 2024
89e8371
typo
njzjz Nov 8, 2024
8570079
release memory?
njzjz Nov 8, 2024
89041ed
--cov-append
njzjz Nov 8, 2024
b9ff755
try scope module
njzjz Nov 9, 2024
b1b8dd9
Revert "try scope module"
njzjz Nov 9, 2024
48dca4b
Revert "release memory?"
njzjz Nov 9, 2024
8d7a230
Revert "typo"
njzjz Nov 9, 2024
5762a31
Revert "I don;t understand why the tests fail randomly agter I add so…
njzjz Nov 9, 2024
ee831e0
Revert "set NPROC"
njzjz Nov 9, 2024
e6953c9
Revert "set xla flags before any imports"
njzjz Nov 9, 2024
c365cb6
Revert "limit threads during tests"
njzjz Nov 9, 2024
c94302d
Revert "try to resolve OOM issue"
njzjz Nov 9, 2024
b0a496c
move the time-comsuming test out of main test
njzjz Nov 9, 2024
4e42105
try
njzjz Nov 9, 2024
215efc6
try to move jax2tf_tests to a different directory
njzjz Nov 9, 2024
e4bac35
clean up
njzjz Nov 9, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/test_python.yml
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ jobs:
env:
NUM_WORKERS: 0
- name: Test TF2 eager mode
run: pytest --cov=deepmd source/tests/consistent/io/test_io.py --durations=0
run: pytest --cov=deepmd --cov-append source/tests/consistent/io/test_io.py source/jax2tf_tests --durations=0
env:
NUM_WORKERS: 0
DP_TEST_TF2_ONLY: 1
Expand Down
6 changes: 6 additions & 0 deletions deepmd/jax/infer/deep_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,12 @@ def __init__(
stablehlo_atomic_virial=model_data["@variables"][
"stablehlo_atomic_virial"
].tobytes(),
stablehlo_no_ghost=model_data["@variables"][
"stablehlo_no_ghost"
].tobytes(),
stablehlo_atomic_virial_no_ghost=model_data["@variables"][
"stablehlo_atomic_virial_no_ghost"
].tobytes(),
model_def_script=model_data["model_def_script"],
**model_data["constants"],
)
Expand Down
3 changes: 3 additions & 0 deletions deepmd/jax/jax2tf/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import tensorflow as tf
import tensorflow.experimental.numpy as tnp

if not tf.executing_eagerly():
# TF disallow temporary eager execution
Expand All @@ -9,3 +10,5 @@
"If you are converting a model between different backends, "
"considering converting to the `.dp` format first."
)

tnp.experimental_enable_numpy_behavior()
110 changes: 110 additions & 0 deletions deepmd/jax/jax2tf/make_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from typing import (
Callable,
)

import tensorflow as tf
import tensorflow.experimental.numpy as tnp

from deepmd.dpmodel.output_def import (
ModelOutputDef,
)
from deepmd.jax.jax2tf.nlist import (
build_neighbor_list,
extend_coord_with_ghosts,
)
from deepmd.jax.jax2tf.region import (
normalize_coord,
)
from deepmd.jax.jax2tf.transform_output import (
communicate_extended_output,
)


def model_call_from_call_lower(
    *,  # enforce keyword-only arguments
    call_lower: Callable[
        [
            tnp.ndarray,
            tnp.ndarray,
            tnp.ndarray,
            tnp.ndarray,
            tnp.ndarray,
            bool,
        ],
        dict[str, tnp.ndarray],
    ],
    rcut: float,
    sel: list[int],
    mixed_types: bool,
    model_output_def: ModelOutputDef,
    coord: tnp.ndarray,
    atype: tnp.ndarray,
    box: tnp.ndarray,
    fparam: tnp.ndarray,
    aparam: tnp.ndarray,
    do_atomic_virial: bool = False,
):
    """Evaluate a model through its lower-level (extended-system) interface.

    The local system is extended with ghost atoms, a neighbor list is built
    for the local atoms, ``call_lower`` is evaluated on the extended system
    and its outputs are communicated back onto the local atoms.

    Parameters
    ----------
    call_lower
        The lower interface. It is called as
        ``call_lower(extended_coord, extended_atype, nlist, mapping,
        fparam=..., aparam=...)``.
    rcut
        The cut-off radius used for extending the system and for building
        the neighbor list.
    sel
        The maximal number of neighbors, forwarded to the neighbor-list
        builder.
    mixed_types
        If False, the neighbor list distinguishes atom types.
    model_output_def
        The output definition used to communicate the extended outputs.
    coord
        The coordinates of the atoms.
        shape: nf x (nloc x 3)
    atype
        The type of atoms. shape: nf x nloc
    box
        The simulation box. shape: nf x 9
        When the last dimension is empty, the coordinates are used without
        normalization (presumably the non-periodic case).
    fparam
        frame parameter. nf x ndf
    aparam
        atomic parameter. nf x nloc x nda
    do_atomic_virial
        If calculate the atomic virial. Only forwarded to the output
        communication step; it is not passed to ``call_lower``.

    Returns
    -------
    ret_dict
        The result dict of type dict[str,tnp.ndarray].
        The keys are defined by the `ModelOutputDef`.

    """
    shape_of_atype = tf.shape(atype)
    nframes = shape_of_atype[0]
    nloc = shape_of_atype[1]

    # Only normalize the coordinates when a box is actually supplied.
    if tf.shape(box)[-1] != 0:
        coord_normalized = normalize_coord(
            tnp.reshape(coord, (nframes, nloc, 3)),
            tnp.reshape(box, (nframes, 3, 3)),
        )
    else:
        coord_normalized = coord

    extended_coord, extended_atype, mapping = extend_coord_with_ghosts(
        coord_normalized, atype, box, rcut
    )
    nlist = build_neighbor_list(
        extended_coord,
        extended_atype,
        nloc,
        rcut,
        sel,
        distinguish_types=not mixed_types,
    )

    lower_prediction = call_lower(
        tnp.reshape(extended_coord, (nframes, -1, 3)),
        extended_atype,
        nlist,
        mapping,
        fparam=fparam,
        aparam=aparam,
    )
    return communicate_extended_output(
        lower_prediction,
        model_output_def,
        mapping,
        do_atomic_virial=do_atomic_virial,
    )

Check warning on line 110 in deepmd/jax/jax2tf/make_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/jax/jax2tf/make_model.py#L110

Added line #L110 was not covered by tests
217 changes: 217 additions & 0 deletions deepmd/jax/jax2tf/nlist.py
wanghan-iapcm marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
@@ -0,0 +1,217 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from typing import (
Union,
)

import tensorflow as tf
import tensorflow.experimental.numpy as tnp

from .region import (
to_face_distance,
)


## translated from torch implementation by chatgpt
def build_neighbor_list(
    coord: tnp.ndarray,
    atype: tnp.ndarray,
    nloc: int,
    rcut: float,
    sel: Union[int, list[int]],
    distinguish_types: bool = True,
) -> tnp.ndarray:
    """Build the neighbor list of the local atoms for a batch of frames.

    At most ``nsel = sum(sel)`` neighbors are kept per local atom.

    Parameters
    ----------
    coord : tnp.ndarray
        extended coordinates of shape [batch_size, nall x 3]
    atype : tnp.ndarray
        extended atomic types of shape [batch_size, nall]
        atoms with type < 0 are treated as virtual atoms.
    nloc : int
        number of local atoms.
    rcut : float
        cut-off radius
    sel : int or list[int]
        maximal number of neighbors (of each type).
        if distinguish_types==True, sel should be a list and
        its length should be equal to the number of types.
    distinguish_types : bool
        distinguish different types.

    Returns
    -------
    neighbor_list : tnp.ndarray
        Neighbor list of shape [batch_size, nloc, nsel]. The neighbors
        are stored in ascending order of distance. If the number of
        neighbors is less than nsel, the remaining positions are masked
        with -1. The neighbor list of an atom looks like
        |------ nsel ------|
        xx xx xx xx -1 -1 -1
        if distinguish_types==True and we have two types
        |---- nsel[0] -----| |---- nsel[1] -----|
        xx xx xx xx -1 -1 -1 xx xx xx -1 -1 -1 -1
        For virtual atoms all neighboring positions are filled with -1.

    """
    sel = [sel] if isinstance(sel, int) else sel
    nsel = sum(sel)

    batch_size = tf.shape(coord)[0]
    coord = tnp.reshape(coord, (batch_size, -1))
    nall = tf.shape(coord)[1] // 3

    # Virtual atoms are moved far away so that they can never be within
    # rcut of a real atom.
    if tf.size(coord) > 0:
        xmax = tnp.max(coord) + 2.0 * rcut
    else:
        xmax = tf.cast(2.0 * rcut, coord.dtype)
    # nf x nall
    is_vir = atype < 0
    all_xyz = tnp.where(
        is_vir[:, :, None], xmax, tnp.reshape(coord, (batch_size, nall, 3))
    )
    local_xyz = all_xyz[:, :nloc]

    # nf x nloc x nall x 3: vector from each local atom to each extended atom
    diff = all_xyz[:, None, :, :] - local_xyz[:, :, None, :]
    # Subtracting the identity makes the self-distance strictly the smallest
    # (-1), so the sort always puts the central atom first even when another
    # atom sits at zero distance.
    rr = tf.linalg.norm(diff, axis=-1) - tf.eye(nloc, nall, dtype=diff.dtype)[
        tnp.newaxis, :, :
    ]
    # Sort by distance and drop the first entry, i.e. the central atom itself.
    nlist = tnp.argsort(rr, axis=-1)[:, :, 1:]
    rr = tnp.sort(rr, axis=-1)[:, :, 1:]

    nnei = tf.shape(rr)[2]
    if nsel <= nnei:
        rr = rr[:, :, :nsel]
        nlist = nlist[:, :, :nsel]
    else:
        # Not enough candidates: pad with distances beyond rcut so that the
        # padded slots are masked to -1 below.
        rr = tnp.concatenate(
            [rr, tnp.ones([batch_size, nloc, nsel - nnei], dtype=rr.dtype) + rcut],
            axis=-1,
        )
        nlist = tnp.concatenate(
            [nlist, tnp.ones([batch_size, nloc, nsel - nnei], dtype=nlist.dtype)],
            axis=-1,
        )

    # Mask neighbors outside the cut-off and every neighbor of a virtual atom.
    nlist = tnp.where(
        tnp.logical_or((rr > rcut), is_vir[:, :nloc, None]),
        tnp.full_like(nlist, -1),
        nlist,
    )

    if not distinguish_types:
        return nlist
    return nlist_distinguish_types(nlist, atype, sel)


def nlist_distinguish_types(
    nlist: tnp.ndarray,
    atype: tnp.ndarray,
    sel: list[int],
):
    """Split a type-agnostic neighbor list into per-type sections.

    Parameters
    ----------
    nlist : tnp.ndarray
        Neighbor list that does not distinguish atom types; empty slots
        are marked with -1.
    atype : tnp.ndarray
        Atom types indexed by the entries of ``nlist``.
    sel : list[int]
        Number of slots kept for each type; the position in the list is
        the type index.

    Returns
    -------
    tnp.ndarray
        The concatenation along the last axis of one section per type,
        section ``i`` holding ``sel[i]`` slots; unused slots are -1.

    """
    nloc = tf.shape(nlist)[1]
    is_empty = nlist == -1
    # Point empty slots at index 0 so that the gather below is valid; their
    # types are reset to -1 right after.
    safe_nlist = tnp.where(is_empty, tnp.zeros_like(nlist), nlist)
    atype_per_center = tnp.tile(atype[:, None, :], (1, nloc, 1))
    nei_type = tnp.take_along_axis(atype_per_center, safe_nlist, axis=2)
    nei_type = tnp.where(is_empty, tnp.full_like(nei_type, -1), nei_type)

    sections = []
    for type_idx, n_keep in enumerate(sel):
        is_type = tf.cast(nei_type == type_idx, tnp.int32)
        # A stable sort moves the matching neighbors to the front while
        # keeping their relative order in the input list.
        order = tnp.argsort(-is_type, kind="stable", axis=-1)
        is_type_sorted = -tnp.sort(-is_type, axis=-1)
        picked = tnp.take_along_axis(nlist, order, axis=2)
        picked = tnp.where(
            tf.cast(is_type_sorted, tf.bool), picked, tnp.full_like(picked, -1)
        )
        sections.append(picked[..., :n_keep])
    return tf.concat(sections, axis=-1)


def tf_outer(a, b):
    """Return the outer product ``out[i, j] = a[i] * b[j]`` of two rank-1 tensors."""
    outer_product = tf.einsum("i,j->ij", a, b)
    return outer_product


## translated from torch implementation by chatgpt
def extend_coord_with_ghosts(
    coord: tnp.ndarray,
    atype: tnp.ndarray,
    cell: tnp.ndarray,
    rcut: float,
):
    """Extend the coordinates of the atoms by appending periodic images.

    The number of images is large enough to ensure that all the neighbors
    within rcut are appended. When ``cell`` has an empty last dimension no
    image is appended and the local atoms are returned as they are.

    Parameters
    ----------
    coord : tnp.ndarray
        original coordinates of shape [-1, nloc*3].
    atype : tnp.ndarray
        atom type of shape [-1, nloc].
    cell : tnp.ndarray
        simulation cell tensor of shape [-1, 9].
    rcut : float
        the cutoff radius

    Returns
    -------
    extended_coord: tnp.ndarray
        extended coordinates of shape [-1, nall*3].
    extended_atype: tnp.ndarray
        extended atom type of shape [-1, nall].
    index_mapping: tnp.ndarray
        mapping extended index to the local index

    """
    shape_of_atype = tf.shape(atype)
    nf = shape_of_atype[0]
    nloc = shape_of_atype[1]
    # int64 for index
    aidx = tnp.tile(tf.range(nloc, dtype=tnp.int64)[tnp.newaxis, :], (nf, 1))

    if tf.shape(cell)[-1] == 0:
        # No cell given: the extended system is the local system itself.
        nall = nloc
        extend_coord = coord
        extend_atype = atype
        extend_aidx = aidx
    else:
        coord = tnp.reshape(coord, (nf, nloc, 3))
        cell = tnp.reshape(cell, (nf, 3, 3))
        # Number of cell replicas needed along each axis to cover rcut,
        # taking the maximum over all frames.
        nbuff = tnp.max(
            tf.cast(tnp.ceil(rcut / to_face_distance(cell)), tnp.int64), axis=0
        )
        xi, yi, zi = (
            tf.range(-nbuff[dim], nbuff[dim] + 1, 1, dtype=tnp.int64)
            for dim in range(3)
        )
        # Integer grid of all shift combinations, one row per periodic image.
        grid_x = tf_outer(xi, tnp.asarray([1, 0, 0]))[:, tnp.newaxis, tnp.newaxis, :]
        grid_y = tf_outer(yi, tnp.asarray([0, 1, 0]))[tnp.newaxis, :, tnp.newaxis, :]
        grid_z = tf_outer(zi, tnp.asarray([0, 0, 1]))[tnp.newaxis, tnp.newaxis, :, :]
        xyz = tf.cast(tnp.reshape(grid_x + grid_y + grid_z, (-1, 3)), coord.dtype)
        # Order the shifts by length so that the zero shift (the original
        # atoms) comes first in the extended system.
        shift_order = tnp.argsort(tf.linalg.norm(xyz, axis=1))
        shift_idx = tnp.take(xyz, shift_order, axis=0)
        ns = tf.shape(shift_idx)[0]
        nall = ns * nloc
        # nf x ns x 3: Cartesian shift vector of every image in every frame
        shift_vec = tnp.einsum("sd,fdk->fsk", shift_idx, cell)
        extend_coord = coord[:, None, :, :] + shift_vec[:, :, None, :]
        extend_atype = tnp.tile(atype[:, :, tnp.newaxis], (1, ns, 1))
        extend_aidx = tnp.tile(aidx[:, :, tnp.newaxis], (1, ns, 1))

    flat_coord = tnp.reshape(extend_coord, (nf, nall * 3))
    flat_atype = tnp.reshape(extend_atype, (nf, nall))
    flat_aidx = tnp.reshape(extend_aidx, (nf, nall))
    return (flat_coord, flat_atype, flat_aidx)
Loading