Skip to content

Commit

Permalink
clean up
Browse files Browse the repository at this point in the history
Signed-off-by: Jinzhe Zeng <[email protected]>
  • Loading branch information
njzjz committed Nov 9, 2024
1 parent 215efc6 commit e4bac35
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 33 deletions.
24 changes: 7 additions & 17 deletions source/jax2tf_tests/test_nlist.py
Original file line number Diff line number Diff line change
@@ -1,30 +1,20 @@
# SPDX-License-Identifier: LGPL-3.0-or-later

import unittest

import tensorflow as tf
import tensorflow.experimental.numpy as tnp

from ...utils import (
DP_TEST_TF2_ONLY,
from deepmd.jax.jax2tf.nlist import (
build_neighbor_list,
extend_coord_with_ghosts,
)
from deepmd.jax.jax2tf.region import (
inter2phys,
)

if DP_TEST_TF2_ONLY:
from deepmd.jax.jax2tf.nlist import (
build_neighbor_list,
extend_coord_with_ghosts,
)
from deepmd.jax.jax2tf.region import (
inter2phys,
)

dtype = tnp.float64
dtype = tnp.float64


@unittest.skipIf(
not DP_TEST_TF2_ONLY,
reason="TF2 conflicts with TF1",
)
class TestNeighList(tf.test.TestCase):
def setUp(self):
self.nf = 3
Expand Down
20 changes: 4 additions & 16 deletions source/jax2tf_tests/test_region.py
Original file line number Diff line number Diff line change
@@ -1,29 +1,17 @@
# SPDX-License-Identifier: LGPL-3.0-or-later


import unittest

import tensorflow as tf
import tensorflow.experimental.numpy as tnp

from ...seed import (
GLOBAL_SEED,
)
from ...utils import (
DP_TEST_TF2_ONLY,
from deepmd.jax.jax2tf.region import (
inter2phys,
to_face_distance,
)

if DP_TEST_TF2_ONLY:
from deepmd.jax.jax2tf.region import (
inter2phys,
to_face_distance,
)
GLOBAL_SEED = 20241109


@unittest.skipIf(
not DP_TEST_TF2_ONLY,
reason="TF2 conflicts with TF1",
)
class TestRegion(tf.test.TestCase):
def setUp(self):
self.cell = tnp.array(
Expand Down

0 comments on commit e4bac35

Please sign in to comment.