Generalize lazy backend indexing a little more #10078

Merged (6 commits) on Feb 28, 2025


@dcherian (Contributor) commented Feb 26, 2025

This helps us support reading from Zarr to GPU memory: https://zarr.readthedocs.io/en/stable/user-guide/gpu.html

Confirmed working by @TomAugspurger @weiji14 @negin513 @kafitzgerald !

@dcherian dcherian changed the title [WIP] cupy zarr fixes Generalize lazy indexing a little more. Feb 26, 2025
@dcherian dcherian marked this pull request as ready for review February 26, 2025 23:02
@dcherian dcherian changed the title Generalize lazy indexing a little more. Generalize lazy backend indexing a little more Feb 26, 2025
@TomAugspurger (Contributor)

Thanks @dcherian! Here's an example running this:

$ uv run --with pooch --with scipy --with ipython --with zarr --with cupy-cuda12x --with "xarray @ git+https://github.com/dcherian/xarray@fix-cupy" ipython
Installed 41 packages in 43ms
Python 3.12.8 (main, Jan 14 2025, 22:49:14) [Clang 19.1.6 ]
Type 'copyright', 'credits' or 'license' for more information
IPython 8.32.0 -- An enhanced Interactive Python. Type '?' for help.

In [1]: import xarray as xr
   ...: import zarr
   ...: 
   ...: xr.tutorial.open_dataset("air_temperature").to_zarr(
   ...:     "test.zarr", zarr_format=3, mode="w"
   ...: )
   ...: with zarr.config.enable_gpu():
   ...:     ds = xr.open_dataset("test.zarr", engine="zarr", decode_times=False)
   ...:     print(type(ds.air.data))
   ...:     ds.air.data.mean()
   ...: 
<ipython-input-1-870251b7bb27>:4: SerializationWarning: saving variable None with floating point data as an integer dtype without any _FillValue to use for NaNs
  xr.tutorial.open_dataset("air_temperature").to_zarr(
/raid/toaugspurger/uv/archive-v0/yduzEAGBqcU7fURbvFkfK/lib/python3.12/site-packages/zarr/api/asynchronous.py:203: UserWarning: Consolidated metadata is currently not part in the Zarr format 3 specification. It may not be supported by other zarr implementations and may change in the future.
  warnings.warn(
<class 'cupy.ndarray'>

Note that with the default decode_times=True, opening the dataset still raises a TypeError, because decoding the time coordinate implicitly converts a CuPy array to NumPy:

In [4]: with zarr.config.enable_gpu():
   ...:     ds = xr.open_dataset("test.zarr", engine="zarr")
   ...:     print(type(ds.air.data))
   ...:     ds.air.data.mean()
   ...: 
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
File /raid/toaugspurger/uv/archive-v0/yduzEAGBqcU7fURbvFkfK/lib/python3.12/site-packages/xarray/conventions.py:401, in decode_cf_variables(variables, attributes, concat_characters, mask_and_scale, decode_times, decode_coords, drop_variables, use_cftime, decode_timedelta)
    400 try:
--> 401     new_vars[k] = decode_cf_variable(
    402         k,
    403         v,
    404         concat_characters=_item_or_default(concat_characters, k, True),
    405         mask_and_scale=_item_or_default(mask_and_scale, k, True),
    406         decode_times=_item_or_default(decode_times, k, True),
    407         stack_char_dim=stack_char_dim,
    408         use_cftime=_item_or_default(use_cftime, k, None),
    409         decode_timedelta=_item_or_default(decode_timedelta, k, None),
    410     )
    411 except Exception as e:

File /raid/toaugspurger/uv/archive-v0/yduzEAGBqcU7fURbvFkfK/lib/python3.12/site-packages/xarray/conventions.py:234, in decode_cf_variable(name, var, concat_characters, mask_and_scale, decode_times, decode_endianness, stack_char_dim, use_cftime, decode_timedelta)
    225             raise TypeError(
    226                 "Usage of 'use_cftime' as a kwarg is not allowed "
    227                 "if a 'CFDatetimeCoder' instance is passed to "
   (...)
    232                 "    ds = xr.open_dataset(decode_times=time_coder)\n",
    233             )
--> 234     var = decode_times.decode(var, name=name)
    236 if decode_endianness and not var.dtype.isnative:

File /raid/toaugspurger/uv/archive-v0/yduzEAGBqcU7fURbvFkfK/lib/python3.12/site-packages/xarray/coding/times.py:1335, in CFDatetimeCoder.decode(self, variable, name)
   1334 calendar = pop_to(attrs, encoding, "calendar")
-> 1335 dtype = _decode_cf_datetime_dtype(
   1336     data, units, calendar, self.use_cftime, self.time_unit
   1337 )
   1338 transform = partial(
   1339     decode_cf_datetime,
   1340     units=units,
   (...)
   1343     time_unit=self.time_unit,
   1344 )

File /raid/toaugspurger/uv/archive-v0/yduzEAGBqcU7fURbvFkfK/lib/python3.12/site-packages/xarray/coding/times.py:313, in _decode_cf_datetime_dtype(data, units, calendar, use_cftime, time_unit)
    311 values = indexing.ImplicitToExplicitIndexingAdapter(indexing.as_indexable(data))
    312 example_value = np.concatenate(
--> 313     [first_n_items(values, 1) or [0], last_item(values) or [0]]
    314 )
    316 try:

File /raid/toaugspurger/uv/archive-v0/yduzEAGBqcU7fURbvFkfK/lib/python3.12/site-packages/xarray/core/formatting.py:97, in first_n_items(array, n_desired)
     96     array = array._data
---> 97 return np.ravel(to_duck_array(array))[:n_desired]

File /raid/toaugspurger/uv/archive-v0/yduzEAGBqcU7fURbvFkfK/lib/python3.12/site-packages/xarray/namedarray/pycompat.py:138, in to_duck_array(data, **kwargs)
    137 else:
--> 138     return np.asarray(data)

File /raid/toaugspurger/uv/archive-v0/yduzEAGBqcU7fURbvFkfK/lib/python3.12/site-packages/xarray/core/indexing.py:574, in ImplicitToExplicitIndexingAdapter.__array__(self, dtype, copy)
    573 if Version(np.__version__) >= Version("2.0.0"):
--> 574     return np.asarray(self.get_duck_array(), dtype=dtype, copy=copy)
    575 else:

File cupy/_core/core.pyx:1481, in cupy._core.core._ndarray_base.__array__()

TypeError: Implicit conversion to a NumPy array is not allowed. Please use `.get()` to construct a NumPy array explicitly.

The above exception was the direct cause of the following exception:

TypeError                                 Traceback (most recent call last)
Cell In[4], line 2
      1 with zarr.config.enable_gpu():
----> 2     ds = xr.open_dataset("test.zarr", engine="zarr")
      3     print(type(ds.air.data))
      4     ds.air.data.mean()

File /raid/toaugspurger/uv/archive-v0/yduzEAGBqcU7fURbvFkfK/lib/python3.12/site-packages/xarray/backends/api.py:685, in open_dataset(filename_or_obj, engine, chunks, cache, decode_cf, mask_and_scale, decode_times, decode_timedelta, use_cftime, concat_characters, decode_coords, drop_variables, inline_array, chunked_array_type, from_array_kwargs, backend_kwargs, **kwargs)
    673 decoders = _resolve_decoders_kwargs(
    674     decode_cf,
    675     open_backend_dataset_parameters=backend.open_dataset_parameters,
   (...)
    681     decode_coords=decode_coords,
    682 )
    684 overwrite_encoded_chunks = kwargs.pop("overwrite_encoded_chunks", None)
--> 685 backend_ds = backend.open_dataset(
    686     filename_or_obj,
    687     drop_variables=drop_variables,
    688     **decoders,
    689     **kwargs,
    690 )
    691 ds = _dataset_from_backend_dataset(
    692     backend_ds,
    693     filename_or_obj,
   (...)
    703     **kwargs,
    704 )
    705 return ds

File /raid/toaugspurger/uv/archive-v0/yduzEAGBqcU7fURbvFkfK/lib/python3.12/site-packages/xarray/backends/zarr.py:1580, in ZarrBackendEntrypoint.open_dataset(self, filename_or_obj, mask_and_scale, decode_times, concat_characters, decode_coords, drop_variables, use_cftime, decode_timedelta, group, mode, synchronizer, consolidated, chunk_store, storage_options, zarr_version, zarr_format, store, engine, use_zarr_fill_value_as_mask, cache_members)
   1578 store_entrypoint = StoreBackendEntrypoint()
   1579 with close_on_error(store):
-> 1580     ds = store_entrypoint.open_dataset(
   1581         store,
   1582         mask_and_scale=mask_and_scale,
   1583         decode_times=decode_times,
   1584         concat_characters=concat_characters,
   1585         decode_coords=decode_coords,
   1586         drop_variables=drop_variables,
   1587         use_cftime=use_cftime,
   1588         decode_timedelta=decode_timedelta,
   1589     )
   1590 return ds

File /raid/toaugspurger/uv/archive-v0/yduzEAGBqcU7fURbvFkfK/lib/python3.12/site-packages/xarray/backends/store.py:47, in StoreBackendEntrypoint.open_dataset(self, filename_or_obj, mask_and_scale, decode_times, concat_characters, decode_coords, drop_variables, use_cftime, decode_timedelta)
     44 vars, attrs = filename_or_obj.load()
     45 encoding = filename_or_obj.get_encoding()
---> 47 vars, attrs, coord_names = conventions.decode_cf_variables(
     48     vars,
     49     attrs,
     50     mask_and_scale=mask_and_scale,
     51     decode_times=decode_times,
     52     concat_characters=concat_characters,
     53     decode_coords=decode_coords,
     54     drop_variables=drop_variables,
     55     use_cftime=use_cftime,
     56     decode_timedelta=decode_timedelta,
     57 )
     59 ds = Dataset(vars, attrs=attrs)
     60 ds = ds.set_coords(coord_names.intersection(vars))

File /raid/toaugspurger/uv/archive-v0/yduzEAGBqcU7fURbvFkfK/lib/python3.12/site-packages/xarray/conventions.py:412, in decode_cf_variables(variables, attributes, concat_characters, mask_and_scale, decode_times, decode_coords, drop_variables, use_cftime, decode_timedelta)
    401     new_vars[k] = decode_cf_variable(
    402         k,
    403         v,
   (...)
    409         decode_timedelta=_item_or_default(decode_timedelta, k, None),
    410     )
    411 except Exception as e:
--> 412     raise type(e)(f"Failed to decode variable {k!r}: {e}") from e
    413 if decode_coords in [True, "coordinates", "all"]:
    414     var_attrs = new_vars[k].attrs

TypeError: Failed to decode variable 'time': Implicit conversion to a NumPy array is not allowed. Please use `.get()` to construct a NumPy array explicitly.
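
The failure mode in the traceback can be illustrated without a GPU. This is a minimal sketch, not xarray or CuPy code: FakeDeviceArray is a hypothetical stand-in for cupy.ndarray whose __array__ raises the same TypeError, implicit_to_host is a stand-in for np.asarray (which consults __array__), and explicit_to_host is a hypothetical helper that performs the explicit .get() transfer the error message asks for.

```python
class FakeDeviceArray:
    """Hypothetical stand-in for a GPU array such as cupy.ndarray."""

    def __init__(self, values):
        self._values = list(values)

    def __array__(self, dtype=None, copy=None):
        # Mirrors the CuPy behaviour seen in the traceback: implicit
        # device-to-host copies are refused.
        raise TypeError(
            "Implicit conversion to a NumPy array is not allowed. "
            "Please use `.get()` to construct a NumPy array explicitly."
        )

    def get(self):
        # Explicit device-to-host transfer (here it is just a list copy).
        return list(self._values)


def implicit_to_host(data):
    # Stand-in for np.asarray(data), which consults data.__array__().
    return data.__array__()


def explicit_to_host(data):
    # Hypothetical helper: call .get() when the object provides it,
    # otherwise treat the data as already living in host memory.
    if hasattr(data, "get"):
        return data.get()
    return list(data)


time_values = FakeDeviceArray([0.0, 6.0, 12.0])

try:
    implicit_to_host(time_values)
    implicit_failed = False
except TypeError:
    implicit_failed = True

# Taking the first item works once the transfer is explicit.
first_item = explicit_to_host(time_values)[:1]
```

In the traceback, the implicit path is reached through first_n_items, which calls np.asarray on the time values while inferring the decoded dtype.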

I believe something like #10079 could fix that by instructing zarr to always read coordinates into host memory.
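
The idea of keeping coordinates in host memory can be sketched as a placement rule. This is only an illustration under stated assumptions, not the implementation in #10079 (which is not shown in this thread): choose_memory and is_dimension_coordinate are hypothetical names, and "coordinate" is assumed here to mean a 1-D variable named after its own dimension.

```python
def is_dimension_coordinate(name, dims):
    # Assumption: a coordinate is a 1-D variable whose name equals
    # its only dimension, e.g. "time" with dims ("time",).
    return dims == (name,)


def choose_memory(name, dims, gpu_enabled=True):
    # Hypothetical rule: data variables go to the GPU when GPU reads
    # are enabled; coordinates always stay in host memory so that
    # NumPy-based decoding (such as time decoding) keeps working.
    if gpu_enabled and not is_dimension_coordinate(name, dims):
        return "device"
    return "host"


# Variable layout of the air_temperature tutorial dataset used above.
variables = {
    "lat": ("lat",),
    "lon": ("lon",),
    "time": ("time",),
    "air": ("time", "lat", "lon"),
}

placement = {name: choose_memory(name, dims) for name, dims in variables.items()}
```

Under this rule only air would be read into GPU memory, while time stays on the host and can be decoded by the existing NumPy code path.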

@dcherian dcherian merged commit 0184702 into pydata:main Feb 28, 2025
37 checks passed
Labels: plan to merge (Final call for comments)