Skip to content

Commit

Permalink
allow refreshing of backends
Browse files Browse the repository at this point in the history
  • Loading branch information
headtr1ck committed Feb 12, 2023
1 parent 5f766b1 commit 576692b
Show file tree
Hide file tree
Showing 10 changed files with 71 additions and 26 deletions.
5 changes: 3 additions & 2 deletions xarray/backends/cfgrib_.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
BACKEND_ENTRYPOINTS,
AbstractDataStore,
BackendArray,
BackendEntrypoint,
_InternalBackendEntrypoint,
_normalize_path,
)
from xarray.backends.locks import SerializableLock, ensure_lock
Expand Down Expand Up @@ -90,7 +90,8 @@ def get_encoding(self):
return {"unlimited_dims": {k for k, v in dims.items() if v is None}}


class CfgribfBackendEntrypoint(BackendEntrypoint):
class CfgribfBackendEntrypoint(_InternalBackendEntrypoint):
_module_name = "cfgrib"
available = module_available("cfgrib")

def guess_can_open(self, filename_or_obj):
Expand Down
29 changes: 27 additions & 2 deletions xarray/backends/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,12 @@
from xarray.conventions import cf_encoder
from xarray.core import indexing
from xarray.core.pycompat import is_duck_dask_array
from xarray.core.utils import FrozenDict, NdimSizeLenMixin, is_remote_uri
from xarray.core.utils import (
FrozenDict,
NdimSizeLenMixin,
is_remote_uri,
module_available,
)

if TYPE_CHECKING:
from io import BufferedIOBase
Expand Down Expand Up @@ -428,4 +433,24 @@ def guess_can_open(
return False


BACKEND_ENTRYPOINTS: dict[str, type[BackendEntrypoint]] = {}
class _InternalBackendEntrypoint:
"""
Wrapper class for BackendEntrypoints that ship with xarray.
Additional attributes
----------
_module_name : str
Name of the module that is required to enable the backend.
"""

_module_name: ClassVar[str]

@classmethod
def _set_availability(cls) -> None:
"""Resets the backends availability."""
cls.available = module_available(cls._module_name)


BACKEND_ENTRYPOINTS: dict[str, type[_InternalBackendEntrypoint]] = {}
5 changes: 3 additions & 2 deletions xarray/backends/h5netcdf_.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@

from xarray.backends.common import (
BACKEND_ENTRYPOINTS,
BackendEntrypoint,
WritableCFDataStore,
_InternalBackendEntrypoint,
_normalize_path,
find_root_and_group,
)
Expand Down Expand Up @@ -343,7 +343,7 @@ def close(self, **kwargs):
self._manager.close(**kwargs)


class H5netcdfBackendEntrypoint(BackendEntrypoint):
class H5netcdfBackendEntrypoint(_InternalBackendEntrypoint):
"""
Backend for netCDF files based on the h5netcdf package.
Expand All @@ -365,6 +365,7 @@ class H5netcdfBackendEntrypoint(BackendEntrypoint):
backends.ScipyBackendEntrypoint
"""

_module_name = "h5netcdf"
available = module_available("h5netcdf")
description = (
"Open netCDF (.nc, .nc4 and .cdf) and most HDF5 files using h5netcdf in Xarray"
Expand Down
5 changes: 3 additions & 2 deletions xarray/backends/netCDF4_.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@
from xarray.backends.common import (
BACKEND_ENTRYPOINTS,
BackendArray,
BackendEntrypoint,
WritableCFDataStore,
_InternalBackendEntrypoint,
_normalize_path,
find_root_and_group,
robust_getitem,
Expand Down Expand Up @@ -513,7 +513,7 @@ def close(self, **kwargs):
self._manager.close(**kwargs)


class NetCDF4BackendEntrypoint(BackendEntrypoint):
class NetCDF4BackendEntrypoint(_InternalBackendEntrypoint):
"""
Backend for netCDF files based on the netCDF4 package.
Expand All @@ -535,6 +535,7 @@ class NetCDF4BackendEntrypoint(BackendEntrypoint):
backends.ScipyBackendEntrypoint
"""

_module_name = "netCDF4"
available = module_available("netCDF4")
description = (
"Open netCDF (.nc, .nc4 and .cdf) and most HDF5 files using netCDF4 in Xarray"
Expand Down
28 changes: 20 additions & 8 deletions xarray/backends/plugins.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,20 +6,21 @@
import sys
import warnings
from importlib.metadata import entry_points
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Any, Callable

from xarray.backends.common import BACKEND_ENTRYPOINTS, BackendEntrypoint

if TYPE_CHECKING:
import os
from importlib.metadata import EntryPoint, EntryPoints
from io import BufferedIOBase

from xarray.backends.common import AbstractDataStore

STANDARD_BACKENDS_ORDER = ["netcdf4", "h5netcdf", "scipy"]


def remove_duplicates(entrypoints):
def remove_duplicates(entrypoints: EntryPoints) -> list[EntryPoint]:
# sort and group entrypoints by name
entrypoints = sorted(entrypoints, key=lambda ep: ep.name)
entrypoints_grouped = itertools.groupby(entrypoints, key=lambda ep: ep.name)
Expand All @@ -42,7 +43,7 @@ def remove_duplicates(entrypoints):
return unique_entrypoints


def detect_parameters(open_dataset):
def detect_parameters(open_dataset: Callable) -> tuple[str, ...]:
signature = inspect.signature(open_dataset)
parameters = signature.parameters
parameters_list = []
Expand All @@ -60,7 +61,9 @@ def detect_parameters(open_dataset):
return tuple(parameters_list)


def backends_dict_from_pkg(entrypoints):
def backends_dict_from_pkg(
entrypoints: list[EntryPoint],
) -> dict[str, BackendEntrypoint]:
backend_entrypoints = {}
for entrypoint in entrypoints:
name = entrypoint.name
Expand All @@ -72,14 +75,16 @@ def backends_dict_from_pkg(entrypoints):
return backend_entrypoints


def set_missing_parameters(backend_entrypoints):
for name, backend in backend_entrypoints.items():
def set_missing_parameters(backend_entrypoints: dict[str, BackendEntrypoint]):
for _, backend in backend_entrypoints.items():
if backend.open_dataset_parameters is None:
open_dataset = backend.open_dataset
backend.open_dataset_parameters = detect_parameters(open_dataset)


def sort_backends(backend_entrypoints):
def sort_backends(
backend_entrypoints: dict[str, BackendEntrypoint]
) -> dict[str, BackendEntrypoint]:
ordered_backends_entrypoints = {}
for be_name in STANDARD_BACKENDS_ORDER:
if backend_entrypoints.get(be_name, None) is not None:
Expand All @@ -90,7 +95,7 @@ def sort_backends(backend_entrypoints):
return ordered_backends_entrypoints


def build_engines(entrypoints) -> dict[str, BackendEntrypoint]:
def build_engines(entrypoints: EntryPoints) -> dict[str, BackendEntrypoint]:
backend_entrypoints = {}
for backend_name, backend in BACKEND_ENTRYPOINTS.items():
if backend.available:
Expand Down Expand Up @@ -126,6 +131,13 @@ def list_engines() -> dict[str, BackendEntrypoint]:
return build_engines(entrypoints)


def refresh_engines() -> None:
"""Refreshes the backend engines based on installed packages."""
list_engines.cache_clear()
for backend_entrypoint in BACKEND_ENTRYPOINTS.values():
backend_entrypoint._set_availability()


def guess_engine(
store_spec: str | os.PathLike[Any] | BufferedIOBase | AbstractDataStore,
):
Expand Down
5 changes: 3 additions & 2 deletions xarray/backends/pseudonetcdf_.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
BACKEND_ENTRYPOINTS,
AbstractDataStore,
BackendArray,
BackendEntrypoint,
_InternalBackendEntrypoint,
_normalize_path,
)
from xarray.backends.file_manager import CachingFileManager
Expand Down Expand Up @@ -96,7 +96,7 @@ def close(self):
self._manager.close()


class PseudoNetCDFBackendEntrypoint(BackendEntrypoint):
class PseudoNetCDFBackendEntrypoint(_InternalBackendEntrypoint):
"""
Backend for netCDF-like data formats in the air quality field
based on the PseudoNetCDF package.
Expand All @@ -121,6 +121,7 @@ class PseudoNetCDFBackendEntrypoint(BackendEntrypoint):
backends.PseudoNetCDFDataStore
"""

_module_name = "PseudoNetCDF"
available = module_available("PseudoNetCDF")
description = (
"Open many atmospheric science data formats using PseudoNetCDF in Xarray"
Expand Down
5 changes: 3 additions & 2 deletions xarray/backends/pydap_.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
BACKEND_ENTRYPOINTS,
AbstractDataStore,
BackendArray,
BackendEntrypoint,
_InternalBackendEntrypoint,
robust_getitem,
)
from xarray.backends.store import StoreBackendEntrypoint
Expand Down Expand Up @@ -138,7 +138,7 @@ def get_dimensions(self):
return Frozen(self.ds.dimensions)


class PydapBackendEntrypoint(BackendEntrypoint):
class PydapBackendEntrypoint(_InternalBackendEntrypoint):
"""
Backend for steaming datasets over the internet using
the Data Access Protocol, also known as DODS or OPeNDAP
Expand All @@ -154,6 +154,7 @@ class PydapBackendEntrypoint(BackendEntrypoint):
backends.PydapDataStore
"""

_module_name = "pydap"
available = module_available("pydap")
description = "Open remote datasets via OPeNDAP using pydap in Xarray"
url = "https://docs.xarray.dev/en/stable/generated/xarray.backends.PydapBackendEntrypoint.html"
Expand Down
5 changes: 3 additions & 2 deletions xarray/backends/pynio_.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
BACKEND_ENTRYPOINTS,
AbstractDataStore,
BackendArray,
BackendEntrypoint,
_InternalBackendEntrypoint,
_normalize_path,
)
from xarray.backends.file_manager import CachingFileManager
Expand Down Expand Up @@ -107,7 +107,7 @@ def close(self):
self._manager.close()


class PynioBackendEntrypoint(BackendEntrypoint):
class PynioBackendEntrypoint(_InternalBackendEntrypoint):
"""
PyNIO backend
Expand All @@ -117,6 +117,7 @@ class PynioBackendEntrypoint(BackendEntrypoint):
https://github.com/pydata/xarray/issues/4491 for more information
"""

_module_name = "Nio"
available = module_available("Nio")

def open_dataset(
Expand Down
5 changes: 3 additions & 2 deletions xarray/backends/scipy_.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@
from xarray.backends.common import (
BACKEND_ENTRYPOINTS,
BackendArray,
BackendEntrypoint,
WritableCFDataStore,
_InternalBackendEntrypoint,
_normalize_path,
)
from xarray.backends.file_manager import CachingFileManager, DummyFileManager
Expand Down Expand Up @@ -240,7 +240,7 @@ def close(self):
self._manager.close()


class ScipyBackendEntrypoint(BackendEntrypoint):
class ScipyBackendEntrypoint(_InternalBackendEntrypoint):
"""
Backend for netCDF files based on the scipy package.
Expand All @@ -261,6 +261,7 @@ class ScipyBackendEntrypoint(BackendEntrypoint):
backends.H5netcdfBackendEntrypoint
"""

_module_name = "scipy"
available = module_available("scipy")
description = "Open netCDF files (.nc, .nc4, .cdf and .gz) using scipy in Xarray"
url = "https://docs.xarray.dev/en/stable/generated/xarray.backends.ScipyBackendEntrypoint.html"
Expand Down
5 changes: 3 additions & 2 deletions xarray/backends/zarr.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@
BACKEND_ENTRYPOINTS,
AbstractWritableDataStore,
BackendArray,
BackendEntrypoint,
_encode_variable_name,
_InternalBackendEntrypoint,
_normalize_path,
)
from xarray.backends.store import StoreBackendEntrypoint
Expand Down Expand Up @@ -845,7 +845,7 @@ def open_zarr(
return ds


class ZarrBackendEntrypoint(BackendEntrypoint):
class ZarrBackendEntrypoint(_InternalBackendEntrypoint):
"""
Backend for ".zarr" files based on the zarr package.
Expand All @@ -857,6 +857,7 @@ class ZarrBackendEntrypoint(BackendEntrypoint):
backends.ZarrStore
"""

_module_name = "zarr"
available = module_available("zarr")
description = "Open zarr files (.zarr) using zarr in Xarray"
url = "https://docs.xarray.dev/en/stable/generated/xarray.backends.ZarrBackendEntrypoint.html"
Expand Down

0 comments on commit 576692b

Please sign in to comment.