diff --git a/src/zarr/core/buffer/core.py b/src/zarr/core/buffer/core.py index 0c6d966db9..ba629befa1 100644 --- a/src/zarr/core/buffer/core.py +++ b/src/zarr/core/buffer/core.py @@ -462,8 +462,12 @@ def __len__(self) -> int: def __repr__(self) -> str: return f"" - def all_equal(self, other: Any) -> bool: - return bool((self._data == other).all()) + def all_equal(self, other: Any, equal_nan: bool = True) -> bool: + """Compare to `other` using np.array_equal.""" + # use array_equal to obtain equal_nan=True functionality + data, other = np.broadcast_arrays(self._data, other) + result = np.array_equal(self._data, other, equal_nan=equal_nan) + return result def fill(self, value: Any) -> None: self._data.fill(value) diff --git a/tests/v3/test_array.py b/tests/v3/test_array.py index 11be51682c..b7beb63b1c 100644 --- a/tests/v3/test_array.py +++ b/tests/v3/test_array.py @@ -138,6 +138,25 @@ def test_array_v3_fill_value(store: MemoryStore, fill_value: int, dtype_str: str assert arr.fill_value.dtype == arr.dtype +@pytest.mark.parametrize("store", ["memory"], indirect=True) +async def test_array_v3_nan_fill_value(store: MemoryStore) -> None: + shape = (10,) + arr = Array.create( + store=store, + shape=shape, + dtype=np.float64, + zarr_format=3, + chunk_shape=shape, + fill_value=np.nan, + ) + arr[:] = np.nan + + assert np.isnan(arr.fill_value) + assert arr.fill_value.dtype == arr.dtype + # all fill value chunk is an empty chunk, and should not be written + assert len([a async for a in store.list_prefix("/")]) == 0 + + @pytest.mark.parametrize("store", ("local",), indirect=["store"]) @pytest.mark.parametrize("zarr_format", (2, 3)) async def test_serializable_async_array( diff --git a/tests/v3/test_properties.py b/tests/v3/test_properties.py index 7a085c03b7..a78e9207bd 100644 --- a/tests/v3/test_properties.py +++ b/tests/v3/test_properties.py @@ -18,11 +18,6 @@ def test_roundtrip(data: st.DataObject) -> None: @given(data=st.data()) -# The filter warning here is to silence an occasional warning in NDBuffer.all_equal -# See https://github.com/zarr-developers/zarr-python/pull/2118#issuecomment-2310280899 -# Uncomment the next line to reproduce the original failure. -# @reproduce_failure('6.111.2', b'AXicY2FgZGRAB/8/ndR2z7nkDZEDADWpBL4=') -@pytest.mark.filterwarnings("ignore::RuntimeWarning") def test_basic_indexing(data: st.DataObject) -> None: zarray = data.draw(arrays()) nparray = zarray[:] @@ -37,11 +32,6 @@ def test_basic_indexing(data: st.DataObject) -> None: @given(data=st.data()) -# The filter warning here is to silence an occasional warning in NDBuffer.all_equal -# See https://github.com/zarr-developers/zarr-python/pull/2118#issuecomment-2310280899 -# Uncomment the next line to reproduce the original failure. -# @reproduce_failure('6.111.2', b'AXicY2FgZGRAB/8/eLmF7qr/C5EDADZUBRM=') -@pytest.mark.filterwarnings("ignore::RuntimeWarning") def test_vindex(data: st.DataObject) -> None: zarray = data.draw(arrays()) nparray = zarray[:]