Skip to content

Commit

Permalink
feat(jax): reformat nlist in the TF model
Browse files Browse the repository at this point in the history
Format the neighbor list in the TF model to convert the dynamic shape to the determined shape, so the TF model can accept the neighbor list with a dynamic shape.

Signed-off-by: Jinzhe Zeng <[email protected]>
  • Loading branch information
njzjz committed Nov 11, 2024
1 parent dcbf607 commit 1eecf10
Show file tree
Hide file tree
Showing 3 changed files with 169 additions and 2 deletions.
71 changes: 71 additions & 0 deletions deepmd/jax/jax2tf/format_nlist.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import tensorflow as tf
import tensorflow.experimental.numpy as tnp


@tf.function(autograph=True)
def format_nlist(
extended_coord: tnp.ndarray,
nlist: tnp.ndarray,
nsel: int,
rcut: float,
):
"""Format neighbor list.
If nnei == nsel, do nothing;
If nnei < nsel, pad -1;
If nnei > nsel, sort by distance and truncate.
Parameters
----------
extended_coord
The extended coordinates of the atoms.
shape: nf x nall x 3
nlist
The neighbor list.
shape: nf x nloc x nnei
nsel
The number of selected neighbors.
rcut
The cutoff radius.
Returns
-------
nlist
The formatted neighbor list.
shape: nf x nloc x nsel
"""
nlist_shape = tf.shape(nlist)
n_nf, n_nloc, n_nsel = nlist_shape[0], nlist_shape[1], nlist_shape[2]
extended_coord = extended_coord.reshape([n_nf, -1, 3])

Check warning on line 40 in deepmd/jax/jax2tf/format_nlist.py

View check run for this annotation

Codecov / codecov/patch

deepmd/jax/jax2tf/format_nlist.py#L38-L40

Added lines #L38 - L40 were not covered by tests

if n_nsel < nsel:

Check warning on line 42 in deepmd/jax/jax2tf/format_nlist.py

View check run for this annotation

Codecov / codecov/patch

deepmd/jax/jax2tf/format_nlist.py#L42

Added line #L42 was not covered by tests
# make a copy before revise
ret = tnp.concatenate(

Check warning on line 44 in deepmd/jax/jax2tf/format_nlist.py

View check run for this annotation

Codecov / codecov/patch

deepmd/jax/jax2tf/format_nlist.py#L44

Added line #L44 was not covered by tests
[
nlist,
tnp.full([n_nf, n_nloc, nsel - n_nsel], -1, dtype=nlist.dtype),
],
axis=-1,
)

elif n_nsel > nsel:

Check warning on line 52 in deepmd/jax/jax2tf/format_nlist.py

View check run for this annotation

Codecov / codecov/patch

deepmd/jax/jax2tf/format_nlist.py#L52

Added line #L52 was not covered by tests
# make a copy before revise
m_real_nei = nlist >= 0
ret = tnp.where(m_real_nei, nlist, 0)
coord0 = extended_coord[:, :n_nloc, :]
index = ret.reshape(n_nf, n_nloc * n_nsel, 1)
index = tnp.repeat(index, 3, axis=2)
coord1 = tnp.take_along_axis(extended_coord, index, axis=1)
coord1 = coord1.reshape(n_nf, n_nloc, n_nsel, 3)
rr2 = tnp.sum(tnp.square(coord0[:, :, None, :] - coord1), axis=-1)
rr2 = tnp.where(m_real_nei, rr2, float("inf"))
rr2, ret_mapping = tnp.sort(rr2, axis=-1), tnp.argsort(rr2, axis=-1)
ret = tnp.take_along_axis(ret, ret_mapping, axis=2)
ret = tnp.where(rr2 > rcut * rcut, -1, ret)
ret = ret[..., :nsel]

Check warning on line 66 in deepmd/jax/jax2tf/format_nlist.py

View check run for this annotation

Codecov / codecov/patch

deepmd/jax/jax2tf/format_nlist.py#L54-L66

Added lines #L54 - L66 were not covered by tests
else: # n_nsel == nsel:
ret = nlist

Check warning on line 68 in deepmd/jax/jax2tf/format_nlist.py

View check run for this annotation

Codecov / codecov/patch

deepmd/jax/jax2tf/format_nlist.py#L68

Added line #L68 was not covered by tests
# do a reshape any way; this will tell the xla the shape without any dynamic shape
ret = tnp.reshape(ret, [n_nf, n_nloc, nsel])
return ret

Check warning on line 71 in deepmd/jax/jax2tf/format_nlist.py

View check run for this annotation

Codecov / codecov/patch

deepmd/jax/jax2tf/format_nlist.py#L70-L71

Added lines #L70 - L71 were not covered by tests
9 changes: 7 additions & 2 deletions deepmd/jax/jax2tf/serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@
jax2tf,
)

from deepmd.jax.jax2tf.format_nlist import (
format_nlist,
)
from deepmd.jax.jax2tf.make_model import (
model_call_from_call_lower,
)
Expand Down Expand Up @@ -76,7 +79,7 @@ def call_lower_with_fixed_do_atomic_virial(
input_signature=[
tf.TensorSpec([None, None, 3], tf.float64),
tf.TensorSpec([None, None], tf.int32),
tf.TensorSpec([None, None, model.get_nnei()], tf.int64),
tf.TensorSpec([None, None, None], tf.int64),
tf.TensorSpec([None, None], tf.int64),
tf.TensorSpec([None, model.get_dim_fparam()], tf.float64),
tf.TensorSpec([None, None, model.get_dim_aparam()], tf.float64),
Expand All @@ -85,6 +88,7 @@ def call_lower_with_fixed_do_atomic_virial(
def call_lower_without_atomic_virial(
coord, atype, nlist, mapping, fparam, aparam
):
nlist = format_nlist(coord, nlist, model.get_nnei(), model.get_rcut())
return tf.cond(
tf.shape(coord)[1] == tf.shape(nlist)[1],
lambda: exported_whether_do_atomic_virial(
Expand All @@ -102,13 +106,14 @@ def call_lower_without_atomic_virial(
input_signature=[
tf.TensorSpec([None, None, 3], tf.float64),
tf.TensorSpec([None, None], tf.int32),
tf.TensorSpec([None, None, model.get_nnei()], tf.int64),
tf.TensorSpec([None, None, None], tf.int64),
tf.TensorSpec([None, None], tf.int64),
tf.TensorSpec([None, model.get_dim_fparam()], tf.float64),
tf.TensorSpec([None, None, model.get_dim_aparam()], tf.float64),
],
)
def call_lower_with_atomic_virial(coord, atype, nlist, mapping, fparam, aparam):
nlist = format_nlist(coord, nlist, model.get_nnei(), model.get_rcut())
return tf.cond(
tf.shape(coord)[1] == tf.shape(nlist)[1],
lambda: exported_whether_do_atomic_virial(
Expand Down
91 changes: 91 additions & 0 deletions source/jax2tf_tests/test_format_nlist.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import tensorflow as tf
import tensorflow.experimental.numpy as tnp

from deepmd.jax.jax2tf.format_nlist import (
format_nlist,
)
from deepmd.jax.jax2tf.nlist import (
build_neighbor_list,
extend_coord_with_ghosts,
)

GLOBAL_SEED = 20241110


class TestFormatNlist(tf.test.TestCase):
def setUp(self):
self.nf = 3
self.nloc = 3
self.ns = 5 * 5 * 3
self.nall = self.ns * self.nloc
self.cell = tnp.array(
[[[1, 0, 0], [0.4, 0.8, 0], [0.1, 0.3, 2.1]]], dtype=tnp.float64
)
self.icoord = tnp.array(
[[[0.035, 0.062, 0.064], [0.085, 0.058, 0.021], [0.537, 0.553, 0.124]]],
dtype=tnp.float64,
)
self.atype = tnp.array([[1, 0, 1]], dtype=tnp.int32)
self.nsel = [10, 10]
self.rcut = 1.01

self.ecoord, self.eatype, mapping = extend_coord_with_ghosts(

Check notice

Code scanning / CodeQL

Unused local variable Note test

Variable mapping is not used.
self.icoord, self.atype, self.cell, self.rcut
)
self.nlist = build_neighbor_list(
self.ecoord,
self.eatype,
self.nloc,
self.rcut,
sum(self.nsel),
distinguish_types=False,
)

def test_format_nlist_equal(self):
nlist = format_nlist(self.ecoord, self.nlist, sum(self.nsel), self.rcut)
self.assertAllEqual(nlist, self.nlist)

def test_format_nlist_less(self):
nlist = build_neighbor_list(
self.ecoord,
self.eatype,
self.nloc,
self.rcut,
sum(self.nsel) - 5,
distinguish_types=False,
)
nlist = format_nlist(self.ecoord, nlist, sum(self.nsel), self.rcut)
self.assertAllEqual(nlist, self.nlist)

def test_format_nlist_large(self):
nlist = build_neighbor_list(
self.ecoord,
self.eatype,
self.nloc,
self.rcut,
sum(self.nsel) + 5,
distinguish_types=False,
)
# random shuffle
shuffle_idx = tf.random.shuffle(tf.range(nlist.shape[2]))
nlist = tnp.take(nlist, shuffle_idx, axis=2)
nlist = format_nlist(self.ecoord, nlist, sum(self.nsel), self.rcut)
# we only need to ensure the result is correct, no need to check the order
self.assertAllEqual(tnp.sort(nlist, axis=-1), tnp.sort(self.nlist, axis=-1))

def test_format_nlist_larger_rcut(self):
nlist = build_neighbor_list(
self.ecoord,
self.eatype,
self.nloc,
self.rcut * 2,
40,
distinguish_types=False,
)
# random shuffle
shuffle_idx = tf.random.shuffle(tf.range(nlist.shape[2]))
nlist = tnp.take(nlist, shuffle_idx, axis=2)
nlist = format_nlist(self.ecoord, nlist, sum(self.nsel), self.rcut)
# we only need to ensure the result is correct, no need to check the order
self.assertAllEqual(tnp.sort(nlist, axis=-1), tnp.sort(self.nlist, axis=-1))

0 comments on commit 1eecf10

Please sign in to comment.