diff --git a/source/tests/jax/jax2tf/test_nlist.py b/source/tests/jax/jax2tf/test_nlist.py index 049d948f84..28c5a0dfc8 100644 --- a/source/tests/jax/jax2tf/test_nlist.py +++ b/source/tests/jax/jax2tf/test_nlist.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: LGPL-3.0-or-later -import pytest +import unittest + import tensorflow as tf import tensorflow.experimental.numpy as tnp @@ -20,7 +21,7 @@ dtype = tnp.float64 -@pytest.mark.skipif( +@unittest.skipIf( not DP_TEST_TF2_ONLY, reason="TF2 conflicts with TF1", ) diff --git a/source/tests/jax/jax2tf/test_region.py b/source/tests/jax/jax2tf/test_region.py index f38b661401..a6baffcb33 100644 --- a/source/tests/jax/jax2tf/test_region.py +++ b/source/tests/jax/jax2tf/test_region.py @@ -1,7 +1,8 @@ # SPDX-License-Identifier: LGPL-3.0-or-later -import pytest +import unittest + import tensorflow as tf import tensorflow.experimental.numpy as tnp @@ -19,7 +20,7 @@ ) -@pytest.mark.skipif( +@unittest.skipIf( not DP_TEST_TF2_ONLY, reason="TF2 conflicts with TF1", )