From 2d95075ee55c97c963729ed036b43a250ee871df Mon Sep 17 00:00:00 2001 From: Till Hoffmann Date: Fri, 23 Feb 2024 09:24:09 -0500 Subject: [PATCH] Promote `isclose` arguments to inexact dtype unless extended (fixes #19935). --- jax/_src/numpy/lax_numpy.py | 64 +++++++++++++++++++------------------ tests/lax_numpy_test.py | 4 +++ 2 files changed, 37 insertions(+), 31 deletions(-) diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index 1bbce5b7273c..8069f896b8d5 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -940,39 +940,41 @@ def isclose(a: ArrayLike, b: ArrayLike, rtol: ArrayLike = 1e-05, atol: ArrayLike equal_nan: bool = False) -> Array: a, b = util.promote_args("isclose", a, b) dtype = _dtype(a) - if issubdtype(dtype, inexact): - if issubdtype(dtype, complexfloating): - dtype = util._complex_elem_type(dtype) - rtol = lax.convert_element_type(rtol, dtype) - atol = lax.convert_element_type(atol, dtype) - out = lax.le( - lax.abs(lax.sub(a, b)), - lax.add(atol, lax.mul(rtol, lax.abs(b)))) - # This corrects the comparisons for infinite and nan values - a_inf = ufuncs.isinf(a) - b_inf = ufuncs.isinf(b) - any_inf = ufuncs.logical_or(a_inf, b_inf) - both_inf = ufuncs.logical_and(a_inf, b_inf) - # Make all elements where either a or b are infinite to False - out = ufuncs.logical_and(out, ufuncs.logical_not(any_inf)) - # Make all elements where both a or b are the same inf to True - same_value = lax.eq(a, b) - same_inf = ufuncs.logical_and(both_inf, same_value) - out = ufuncs.logical_or(out, same_inf) - - # Make all elements where either a or b is NaN to False - a_nan = ufuncs.isnan(a) - b_nan = ufuncs.isnan(b) - any_nan = ufuncs.logical_or(a_nan, b_nan) - out = ufuncs.logical_and(out, ufuncs.logical_not(any_nan)) - if equal_nan: - # Make all elements where both a and b is NaN to True - both_nan = ufuncs.logical_and(a_nan, b_nan) - out = ufuncs.logical_or(out, both_nan) - return out - else: + if dtypes.issubdtype(dtype, dtypes.extended): return lax.eq(a, b) + a, b = util.promote_args_inexact("isclose", a, b) + dtype = _dtype(a) + if issubdtype(dtype, complexfloating): + dtype = util._complex_elem_type(dtype) + rtol = lax.convert_element_type(rtol, dtype) + atol = lax.convert_element_type(atol, dtype) + out = lax.le( + lax.abs(lax.sub(a, b)), + lax.add(atol, lax.mul(rtol, lax.abs(b)))) + # This corrects the comparisons for infinite and nan values + a_inf = ufuncs.isinf(a) + b_inf = ufuncs.isinf(b) + any_inf = ufuncs.logical_or(a_inf, b_inf) + both_inf = ufuncs.logical_and(a_inf, b_inf) + # Make all elements where either a or b are infinite to False + out = ufuncs.logical_and(out, ufuncs.logical_not(any_inf)) + # Make all elements where both a or b are the same inf to True + same_value = lax.eq(a, b) + same_inf = ufuncs.logical_and(both_inf, same_value) + out = ufuncs.logical_or(out, same_inf) + + # Make all elements where either a or b is NaN to False + a_nan = ufuncs.isnan(a) + b_nan = ufuncs.isnan(b) + any_nan = ufuncs.logical_or(a_nan, b_nan) + out = ufuncs.logical_and(out, ufuncs.logical_not(any_nan)) + if equal_nan: + # Make all elements where both a and b is NaN to True + both_nan = ufuncs.logical_and(a_nan, b_nan) + out = ufuncs.logical_or(out, both_nan) + return out + def _interp(x: ArrayLike, xp: ArrayLike, fp: ArrayLike, left: ArrayLike | str | None = None, diff --git a/tests/lax_numpy_test.py b/tests/lax_numpy_test.py index 586467db8502..1fd10ddc9749 100644 --- a/tests/lax_numpy_test.py +++ b/tests/lax_numpy_test.py @@ -3532,6 +3532,10 @@ def testIsClose(self): self.assertTrue(jnp.all(jnp.equal(result_np, result_jax))) self.assertTrue(jnp.all(jnp.equal(result_np, result_jit))) + self.assertEqual(np.isclose(6, 10, rtol=0.5), jnp.isclose(6, 10, rtol=0.5)) + key = jax.random.key(0) + self.assertTrue(jnp.isclose(key, key)) + @jtu.sample_product( x=[1, [1], [1, 1 + 1E-4], [1, np.nan]], y=[1, [1], [1, 1 + 1E-4], [1, np.nan]],