Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(jax): reformat nlist in the TF model #4336

Merged
merged 1 commit into from
Nov 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 71 additions & 0 deletions deepmd/jax/jax2tf/format_nlist.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import tensorflow as tf
import tensorflow.experimental.numpy as tnp


@tf.function(autograph=True)
def format_nlist(
extended_coord: tnp.ndarray,
nlist: tnp.ndarray,
nsel: int,
rcut: float,
):
"""Format neighbor list.
If nnei == nsel, do nothing;
If nnei < nsel, pad -1;
If nnei > nsel, sort by distance and truncate.
Parameters
----------
extended_coord
The extended coordinates of the atoms.
shape: nf x nall x 3
nlist
The neighbor list.
shape: nf x nloc x nnei
nsel
The number of selected neighbors.
rcut
The cutoff radius.
Returns
-------
nlist
The formatted neighbor list.
shape: nf x nloc x nsel
"""
nlist_shape = tf.shape(nlist)
n_nf, n_nloc, n_nsel = nlist_shape[0], nlist_shape[1], nlist_shape[2]
extended_coord = extended_coord.reshape([n_nf, -1, 3])

Check warning on line 40 in deepmd/jax/jax2tf/format_nlist.py

View check run for this annotation

Codecov / codecov/patch

deepmd/jax/jax2tf/format_nlist.py#L38-L40

Added lines #L38 - L40 were not covered by tests
njzjz marked this conversation as resolved.
Show resolved Hide resolved

if n_nsel < nsel:

Check warning on line 42 in deepmd/jax/jax2tf/format_nlist.py

View check run for this annotation

Codecov / codecov/patch

deepmd/jax/jax2tf/format_nlist.py#L42

Added line #L42 was not covered by tests
# make a copy before revise
ret = tnp.concatenate(

Check warning on line 44 in deepmd/jax/jax2tf/format_nlist.py

View check run for this annotation

Codecov / codecov/patch

deepmd/jax/jax2tf/format_nlist.py#L44

Added line #L44 was not covered by tests
[
nlist,
tnp.full([n_nf, n_nloc, nsel - n_nsel], -1, dtype=nlist.dtype),
],
axis=-1,
)

elif n_nsel > nsel:

Check warning on line 52 in deepmd/jax/jax2tf/format_nlist.py

View check run for this annotation

Codecov / codecov/patch

deepmd/jax/jax2tf/format_nlist.py#L52

Added line #L52 was not covered by tests
# make a copy before revise
m_real_nei = nlist >= 0
ret = tnp.where(m_real_nei, nlist, 0)
coord0 = extended_coord[:, :n_nloc, :]
index = ret.reshape(n_nf, n_nloc * n_nsel, 1)
index = tnp.repeat(index, 3, axis=2)
coord1 = tnp.take_along_axis(extended_coord, index, axis=1)
coord1 = coord1.reshape(n_nf, n_nloc, n_nsel, 3)
rr2 = tnp.sum(tnp.square(coord0[:, :, None, :] - coord1), axis=-1)
rr2 = tnp.where(m_real_nei, rr2, float("inf"))
rr2, ret_mapping = tnp.sort(rr2, axis=-1), tnp.argsort(rr2, axis=-1)
ret = tnp.take_along_axis(ret, ret_mapping, axis=2)
ret = tnp.where(rr2 > rcut * rcut, -1, ret)
ret = ret[..., :nsel]

Check warning on line 66 in deepmd/jax/jax2tf/format_nlist.py

View check run for this annotation

Codecov / codecov/patch

deepmd/jax/jax2tf/format_nlist.py#L54-L66

Added lines #L54 - L66 were not covered by tests
else: # n_nsel == nsel:
ret = nlist

Check warning on line 68 in deepmd/jax/jax2tf/format_nlist.py

View check run for this annotation

Codecov / codecov/patch

deepmd/jax/jax2tf/format_nlist.py#L68

Added line #L68 was not covered by tests
# do a reshape any way; this will tell the xla the shape without any dynamic shape
ret = tnp.reshape(ret, [n_nf, n_nloc, nsel])
return ret

Check warning on line 71 in deepmd/jax/jax2tf/format_nlist.py

View check run for this annotation

Codecov / codecov/patch

deepmd/jax/jax2tf/format_nlist.py#L70-L71

Added lines #L70 - L71 were not covered by tests
9 changes: 7 additions & 2 deletions deepmd/jax/jax2tf/serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@
jax2tf,
)

from deepmd.jax.jax2tf.format_nlist import (
format_nlist,
)
from deepmd.jax.jax2tf.make_model import (
model_call_from_call_lower,
)
Expand Down Expand Up @@ -76,7 +79,7 @@ def call_lower_with_fixed_do_atomic_virial(
input_signature=[
tf.TensorSpec([None, None, 3], tf.float64),
tf.TensorSpec([None, None], tf.int32),
tf.TensorSpec([None, None, model.get_nnei()], tf.int64),
tf.TensorSpec([None, None, None], tf.int64),
tf.TensorSpec([None, None], tf.int64),
tf.TensorSpec([None, model.get_dim_fparam()], tf.float64),
tf.TensorSpec([None, None, model.get_dim_aparam()], tf.float64),
Expand All @@ -85,6 +88,7 @@ def call_lower_with_fixed_do_atomic_virial(
def call_lower_without_atomic_virial(
coord, atype, nlist, mapping, fparam, aparam
):
nlist = format_nlist(coord, nlist, model.get_nnei(), model.get_rcut())
return tf.cond(
tf.shape(coord)[1] == tf.shape(nlist)[1],
lambda: exported_whether_do_atomic_virial(
Expand All @@ -102,13 +106,14 @@ def call_lower_without_atomic_virial(
input_signature=[
tf.TensorSpec([None, None, 3], tf.float64),
tf.TensorSpec([None, None], tf.int32),
tf.TensorSpec([None, None, model.get_nnei()], tf.int64),
tf.TensorSpec([None, None, None], tf.int64),
tf.TensorSpec([None, None], tf.int64),
tf.TensorSpec([None, model.get_dim_fparam()], tf.float64),
tf.TensorSpec([None, None, model.get_dim_aparam()], tf.float64),
],
)
def call_lower_with_atomic_virial(coord, atype, nlist, mapping, fparam, aparam):
nlist = format_nlist(coord, nlist, model.get_nnei(), model.get_rcut())
return tf.cond(
tf.shape(coord)[1] == tf.shape(nlist)[1],
lambda: exported_whether_do_atomic_virial(
Expand Down
91 changes: 91 additions & 0 deletions source/jax2tf_tests/test_format_nlist.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import tensorflow as tf
import tensorflow.experimental.numpy as tnp

from deepmd.jax.jax2tf.format_nlist import (
format_nlist,
)
from deepmd.jax.jax2tf.nlist import (
build_neighbor_list,
extend_coord_with_ghosts,
)

GLOBAL_SEED = 20241110


class TestFormatNlist(tf.test.TestCase):
def setUp(self):
self.nf = 3
self.nloc = 3
self.ns = 5 * 5 * 3
self.nall = self.ns * self.nloc
self.cell = tnp.array(
[[[1, 0, 0], [0.4, 0.8, 0], [0.1, 0.3, 2.1]]], dtype=tnp.float64
)
self.icoord = tnp.array(
[[[0.035, 0.062, 0.064], [0.085, 0.058, 0.021], [0.537, 0.553, 0.124]]],
dtype=tnp.float64,
)
self.atype = tnp.array([[1, 0, 1]], dtype=tnp.int32)
self.nsel = [10, 10]
self.rcut = 1.01

self.ecoord, self.eatype, mapping = extend_coord_with_ghosts(

Check notice

Code scanning / CodeQL

Unused local variable Note test

Variable mapping is not used.
self.icoord, self.atype, self.cell, self.rcut
)
self.nlist = build_neighbor_list(
self.ecoord,
self.eatype,
self.nloc,
self.rcut,
sum(self.nsel),
distinguish_types=False,
)

def test_format_nlist_equal(self):
nlist = format_nlist(self.ecoord, self.nlist, sum(self.nsel), self.rcut)
self.assertAllEqual(nlist, self.nlist)

def test_format_nlist_less(self):
nlist = build_neighbor_list(
self.ecoord,
self.eatype,
self.nloc,
self.rcut,
sum(self.nsel) - 5,
distinguish_types=False,
)
nlist = format_nlist(self.ecoord, nlist, sum(self.nsel), self.rcut)
self.assertAllEqual(nlist, self.nlist)

def test_format_nlist_large(self):
nlist = build_neighbor_list(
self.ecoord,
self.eatype,
self.nloc,
self.rcut,
sum(self.nsel) + 5,
distinguish_types=False,
)
# random shuffle
shuffle_idx = tf.random.shuffle(tf.range(nlist.shape[2]))
nlist = tnp.take(nlist, shuffle_idx, axis=2)
njzjz marked this conversation as resolved.
Show resolved Hide resolved
nlist = format_nlist(self.ecoord, nlist, sum(self.nsel), self.rcut)
# we only need to ensure the result is correct, no need to check the order
self.assertAllEqual(tnp.sort(nlist, axis=-1), tnp.sort(self.nlist, axis=-1))

def test_format_nlist_larger_rcut(self):
nlist = build_neighbor_list(
self.ecoord,
self.eatype,
self.nloc,
self.rcut * 2,
40,
distinguish_types=False,
)
# random shuffle
shuffle_idx = tf.random.shuffle(tf.range(nlist.shape[2]))
nlist = tnp.take(nlist, shuffle_idx, axis=2)
nlist = format_nlist(self.ecoord, nlist, sum(self.nsel), self.rcut)
# we only need to ensure the result is correct, no need to check the order
self.assertAllEqual(tnp.sort(nlist, axis=-1), tnp.sort(self.nlist, axis=-1))