diff --git a/tests/core/core_lift_test.py b/tests/core/core_lift_test.py index 276d78e988..5ff7e3e696 100644 --- a/tests/core/core_lift_test.py +++ b/tests/core/core_lift_test.py @@ -12,8 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import operator - import jax import numpy as np from absl.testing import absltest @@ -170,12 +168,11 @@ def body_fn(scope, c): ) self.assertEqual(vars['state']['acc'], x) self.assertEqual(c, 2 * x) - np.testing.assert_array_equal( + self.assertEqual( vars['state']['rng_params'][0], vars['state']['rng_params'][1] ) with jax.debug_key_reuse(False): - np.testing.assert_array_compare( - operator.__ne__, + self.assertNotEqual( vars['state']['rng_loop'][0], vars['state']['rng_loop'][1], )