
Commit

Add top-level open_datatree function (TODO: deduplicate and clean up code)
jthielen committed Jan 13, 2023
1 parent 39ba056 commit 77129f1
Showing 4 changed files with 231 additions and 6 deletions.
2 changes: 2 additions & 0 deletions xarray/__init__.py
@@ -4,6 +4,7 @@
load_dataset,
open_dataarray,
open_dataset,
open_datatree,
open_mfdataset,
save_mfdataset,
)
@@ -84,6 +85,7 @@
"ones_like",
"open_dataarray",
"open_dataset",
"open_datatree",
"open_mfdataset",
"open_rasterio",
"open_zarr",
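With ``open_datatree`` re-exported from the top-level namespace, it becomes callable directly on ``xarray``. A minimal usage sketch (the file name and group layout here are hypothetical; the returned object is a ``datatree.DataTree``):

import xarray as xr

# Open every group of a (hypothetical) netCDF file as one tree of Datasets.
dt = xr.open_datatree("example_groups.nc", engine="netcdf4")

print(dt)              # one node per group
root_ds = dt.ds        # Dataset stored at the root node
child = dt["group1"]   # child node, if the file defines a "group1" group
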
221 changes: 221 additions & 0 deletions xarray/backends/api.py
@@ -336,6 +336,45 @@ def _chunk_ds(
return backend_ds._replace(variables)


def _datatree_from_backend_datatree(
backend_dt,
filename_or_obj,
engine,
chunks,
cache,
overwrite_encoded_chunks,
inline_array,
**extra_tokens,
):
# TODO: deduplicate with _dataset_from_backend_dataset
if not isinstance(chunks, (int, dict)) and chunks not in {None, "auto"}:
raise ValueError(
f"chunks must be an int, dict, 'auto', or None. Instead found {chunks}."
)

backend_dt.map_over_subtree_inplace(_protect_dataset_variables_inplace, cache=cache)
if chunks is None:
dt = backend_dt
else:
dt = backend_dt.map_over_subtree(
_chunk_ds,
filename_or_obj=filename_or_obj,
engine=engine,
chunks=chunks,
overwrite_encoded_chunks=overwrite_encoded_chunks,
inline_array=inline_array,
**extra_tokens,
)

# Re-attach the close callback on each node dataset (mirrors the
# ds.set_close(backend_ds._close) step in _dataset_from_backend_dataset).
dt.map_over_subtree_inplace(lambda ds: ds.set_close(backend_dt._close))

# Ensure source filename always stored in dataset object
if "source" not in dt.encoding and isinstance(filename_or_obj, (str, os.PathLike)):
dt.encoding["source"] = _normalize_path(filename_or_obj)

return dt


def _dataset_from_backend_dataset(
backend_ds,
filename_or_obj,
@@ -374,6 +413,188 @@ def _dataset_from_backend_dataset(
return ds


def open_datatree(
filename_or_obj: str | os.PathLike[Any] | BufferedIOBase | AbstractDataStore,
*,
engine: T_Engine = None,
chunks: T_Chunks = None,
cache: bool | None = None,
decode_cf: bool | None = None,
mask_and_scale: bool | None = None,
decode_times: bool | None = None,
decode_timedelta: bool | None = None,
use_cftime: bool | None = None,
concat_characters: bool | None = None,
decode_coords: Literal["coordinates", "all"] | bool | None = None,
drop_variables: str | Iterable[str] | None = None,
inline_array: bool = False,
backend_kwargs: dict[str, Any] | None = None,
**kwargs,
) -> DataTree:
"""Open and decode a dataset from a file or file-like object.
Parameters
----------
filename_or_obj : str, Path, file-like or DataStore
Strings and Path objects are interpreted as a path to a netCDF file
or an OpenDAP URL and opened with python-netCDF4, unless the filename
ends with .gz, in which case the file is gunzipped and opened with
scipy.io.netcdf (only netCDF3 supported). Byte-strings or file-like
objects are opened by scipy.io.netcdf (netCDF3) or h5py (netCDF4/HDF).
engine : {"netcdf4", "scipy", "pydap", "h5netcdf", "pynio", "cfgrib", \
"pseudonetcdf", "zarr", None}, installed backend \
or subclass of xarray.backends.BackendEntrypoint, optional
Engine to use when reading files. If not provided, the default engine
is chosen based on available dependencies, with a preference for
"netcdf4". A custom backend class (a subclass of ``BackendEntrypoint``)
can also be used.
chunks : int, dict, 'auto' or None, optional
If chunks is provided, it is used to load the new dataset into dask
arrays. ``chunks=-1`` loads the dataset with dask using a single
chunk for all arrays. ``chunks={}`` loads the dataset with dask using
engine preferred chunks if exposed by the backend, otherwise with
a single chunk for all arrays.
``chunks='auto'`` will use dask ``auto`` chunking taking into account the
engine preferred chunks. See dask chunking for more details.
cache : bool, optional
If True, cache data loaded from the underlying datastore in memory as
NumPy arrays when accessed to avoid reading from the underlying data-
store multiple times. Defaults to True unless you specify the `chunks`
argument to use dask, in which case it defaults to False. Does not
change the behavior of coordinates corresponding to dimensions, which
always load their data from disk into a ``pandas.Index``.
decode_cf : bool, optional
Whether to decode these variables, assuming they were saved according
to CF conventions.
mask_and_scale : bool, optional
If True, replace array values equal to `_FillValue` with NA and scale
values according to the formula `original_values * scale_factor +
add_offset`, where `_FillValue`, `scale_factor` and `add_offset` are
taken from variable attributes (if they exist). If the `_FillValue` or
`missing_value` attribute contains multiple values a warning will be
issued and all array values matching one of the multiple values will
be replaced by NA. mask_and_scale defaults to True except for the
pseudonetcdf backend. This keyword may not be supported by all the backends.
decode_times : bool, optional
If True, decode times encoded in the standard NetCDF datetime format
into datetime objects. Otherwise, leave them encoded as numbers.
This keyword may not be supported by all the backends.
decode_timedelta : bool, optional
If True, decode variables and coordinates with time units in
{"days", "hours", "minutes", "seconds", "milliseconds", "microseconds"}
into timedelta objects. If False, leave them encoded as numbers.
If None (default), assume the same value as ``decode_times``.
This keyword may not be supported by all the backends.
use_cftime: bool, optional
Only relevant if encoded dates come from a standard calendar
(e.g. "gregorian", "proleptic_gregorian", "standard", or not
specified). If None (default), attempt to decode times to
``np.datetime64[ns]`` objects; if this is not possible, decode times to
``cftime.datetime`` objects. If True, always decode times to
``cftime.datetime`` objects, regardless of whether or not they can be
represented using ``np.datetime64[ns]`` objects. If False, always
decode times to ``np.datetime64[ns]`` objects; if this is not possible
raise an error. This keyword may not be supported by all the backends.
concat_characters : bool, optional
If True, concatenate along the last dimension of character arrays to
form string arrays. Dimensions will only be concatenated over (and
removed) if they have no corresponding variable and if they are only
used as the last dimension of character arrays.
This keyword may not be supported by all the backends.
decode_coords : bool or {"coordinates", "all"}, optional
Controls which variables are set as coordinate variables:
- "coordinates" or True: Set variables referred to in the
``'coordinates'`` attribute of the datasets or individual variables
as coordinate variables.
- "all": Set variables referred to in ``'grid_mapping'``, ``'bounds'`` and
other attributes as coordinate variables.
drop_variables: str or iterable of str, optional
A variable or list of variables to exclude from being parsed from the
dataset. This may be useful to drop variables with problems or
inconsistent values.
inline_array: bool, default: False
How to include the array in the dask task graph.
By default (``inline_array=False``) the array is included in a task by
itself, and each chunk refers to that task by its key. With
``inline_array=True``, Dask will instead inline the array directly
in the values of the task graph. See :py:func:`dask.array.from_array`.
backend_kwargs: dict
Additional keyword arguments passed on to the engine open function,
equivalent to `**kwargs`.
**kwargs: dict
Additional keyword arguments passed on to the engine open function.
For example:
- 'group': path to the netCDF4 group in the given file to open given as
a str, supported by "netcdf4", "h5netcdf", "zarr".
- 'lock': resource lock to use when reading data from disk. Only
relevant when using dask or another form of parallelism. By default,
appropriate locks are chosen to safely read and write files with the
currently active dask scheduler. Supported by "netcdf4", "h5netcdf",
"scipy", "pynio", "pseudonetcdf", "cfgrib".
See engine open function for kwargs accepted by each specific engine.

Returns
-------
datatree : datatree.DataTree
The newly created datatree.

Notes
-----
``open_datatree`` opens the file with read-only access. When you modify
values of a DataTree, even one linked to files on disk, only the in-memory
copy you are manipulating in xarray is modified: the original file on disk
is never touched.
"""
# TODO deduplicate with open_dataset

if cache is None:
cache = chunks is None

if backend_kwargs is not None:
kwargs.update(backend_kwargs)

if engine is None:
engine = plugins.guess_engine(filename_or_obj)

backend = plugins.get_backend(engine)

decoders = _resolve_decoders_kwargs(
decode_cf,
open_backend_dataset_parameters=backend.open_dataset_parameters,
mask_and_scale=mask_and_scale,
decode_times=decode_times,
decode_timedelta=decode_timedelta,
concat_characters=concat_characters,
use_cftime=use_cftime,
decode_coords=decode_coords,
)

overwrite_encoded_chunks = kwargs.pop("overwrite_encoded_chunks", None)
backend_dt = backend.open_datatree(
filename_or_obj,
drop_variables=drop_variables,
**decoders,
**kwargs,
)
dt = _datatree_from_backend_datatree(
backend_dt,
filename_or_obj,
engine,
chunks,
cache,
overwrite_encoded_chunks,
inline_array,
drop_variables=drop_variables,
**decoders,
**kwargs,
)
return dt


def open_dataset(
filename_or_obj: str | os.PathLike[Any] | BufferedIOBase | AbstractDataStore,
*,
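Because ``chunks`` is forwarded to every node through ``map_over_subtree``, each group's dataset can be opened lazily with dask rather than loaded eagerly. A hedged sketch of that pattern (hypothetical file name; requires dask to be installed):

import xarray as xr

# chunks={} requests each backend's preferred chunking (falling back to a
# single chunk per array); with chunks set, cache defaults to False.
dt = xr.open_datatree("example_groups.nc", chunks={})

# Variables in every group are now dask-backed and computed on demand.
dt_mean = dt.map_over_subtree(lambda ds: ds.mean())
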
7 changes: 4 additions & 3 deletions xarray/backends/netCDF4_.py
@@ -443,8 +443,9 @@ def get_group_stores(self):
def select_group(self, group):
"""Return new NetCDF4DataStore for specified group of this NetCDF4DataStore."""
if group in self.ds.groups:
- return self.__init__(
- manager=self._manager, group=group, mode=self._mode, lock=self.lock, autoclose=self.autoclose
+ parent_group = self._group if self._group is not None else ''
+ return self.__class__(
+ manager=self._manager, group=f"{parent_group}{group}/", mode=self._mode, lock=self.lock, autoclose=self.autoclose
)
else:
raise KeyError(group)
@@ -654,7 +655,7 @@ def open_dataset(
autoclose=autoclose,
)

- def open_dataset(
+ def open_datatree(
self,
filename_or_obj,
mask_and_scale=True,
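The change above constructs a new store with ``self.__class__`` instead of calling ``self.__init__`` (which returns ``None``), and it accumulates the full group path so nested groups resolve correctly. A small sketch of the path composition being relied on (the helper name is illustrative, not part of the commit):

# Illustration only: how slash-terminated group paths compose as the tree is
# walked, assuming _group is None at the root and "<parent>/" below it.
def child_group_path(parent_group, group):
    parent = parent_group if parent_group is not None else ""
    return f"{parent}{group}/"

assert child_group_path(None, "group1") == "group1/"
assert child_group_path("group1/", "subgroup") == "group1/subgroup/"
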
7 changes: 4 additions & 3 deletions xarray/backends/store.py
@@ -102,16 +102,17 @@ def _add_node(store, path, datasets):
use_cftime=use_cftime,
decode_timedelta=decode_timedelta,
)
+ ds.set_close(store.close) # TODO should this be on datatree? if so, need to add to datatree API
datasets[path] = ds

# Recursively add children to collector
- for child_name, child_store in store.get_group_stores():
+ for child_name, child_store in store.get_group_stores().items():
datasets = _add_node(child_store, f"{path}{child_name}/", datasets)

return datasets

- dt = DataTree.from_dict(_add_node(store, "/", {}))
- dt.set_close(store.close)
+ datasets = _add_node(store, "/", {})
+ dt = DataTree.from_dict(datasets)

return dt

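``_add_node`` recursively collects one decoded Dataset per group into a mapping keyed by slash-terminated paths, which ``DataTree.from_dict`` then assembles into a tree. A sketch of the resulting structure (toy data; assumes the standalone ``xarray-datatree`` package, whose ``from_dict`` accepts the same slash-terminated paths used above):

import xarray as xr
from datatree import DataTree  # xarray-datatree package (assumption)

# Shape of the mapping that _add_node builds, with toy Datasets standing in
# for the decoded group contents.
datasets = {
    "/": xr.Dataset({"a": ("x", [1, 2])}),
    "/group1/": xr.Dataset({"b": ("y", [3.0])}),
    "/group1/subgroup/": xr.Dataset(),
}
dt = DataTree.from_dict(datasets)
print(dt)
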
