Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

APIv2: move all _autodetect_engine logic to the plugins #4709

Merged
merged 11 commits into from
Dec 22, 2020
15 changes: 3 additions & 12 deletions xarray/backends/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
)
from ..core.dataarray import DataArray
from ..core.dataset import Dataset, _get_chunk, _maybe_chunk
from ..core.utils import close_on_error, is_grib_path, is_remote_uri
from ..core.utils import close_on_error, is_grib_path, is_remote_uri, read_magic_number
from .common import AbstractDataStore, ArrayWriter
from .locks import _get_scheduler

Expand Down Expand Up @@ -120,24 +120,15 @@ def _get_default_engine_netcdf():


def _get_engine_from_magic_number(filename_or_obj):
# check byte header to determine file type
if isinstance(filename_or_obj, bytes):
magic_number = filename_or_obj[:8]
else:
if filename_or_obj.tell() != 0:
raise ValueError(
"file-like object read/write pointer not at zero "
"please close and reopen, or use a context manager"
)
magic_number = filename_or_obj.read(8)
filename_or_obj.seek(0)
magic_number = read_magic_number(filename_or_obj)

if magic_number.startswith(b"CDF"):
engine = "scipy"
elif magic_number.startswith(b"\211HDF\r\n\032\n"):
engine = "h5netcdf"
else:
raise ValueError(
"cannot guess the engine, "
f"{magic_number} is not the signature of any supported file format "
"did you mean to pass a string for a path instead?"
)
Expand Down
9 changes: 2 additions & 7 deletions xarray/backends/apiv2.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,7 @@
from ..core.dataset import _get_chunk, _maybe_chunk
from ..core.utils import is_remote_uri
from . import plugins
from .api import (
_autodetect_engine,
_get_backend_cls,
_normalize_path,
_protect_dataset_variables_inplace,
)
from .api import _get_backend_cls, _normalize_path, _protect_dataset_variables_inplace


def _get_mtime(filename_or_obj):
Expand Down Expand Up @@ -248,7 +243,7 @@ def open_dataset(
filename_or_obj = _normalize_path(filename_or_obj)

if engine is None:
engine = _autodetect_engine(filename_or_obj)
engine = plugins.guess_engine(filename_or_obj)

engines = plugins.list_engines()
backend = _get_backend_cls(engine, engines=engines)
Expand Down
14 changes: 13 additions & 1 deletion xarray/backends/cfgrib_.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import os

import numpy as np

from ..core import indexing
Expand Down Expand Up @@ -73,6 +75,14 @@ def get_encoding(self):
return encoding


def guess_can_open_cfgrib(store_spec):
try:
_, ext = os.path.splitext(store_spec)
except TypeError:
return False
return ext in {".grib", ".grib2", ".grb", ".grb2"}


def open_backend_dataset_cfgrib(
filename_or_obj,
*,
Expand Down Expand Up @@ -116,4 +126,6 @@ def open_backend_dataset_cfgrib(
return ds


cfgrib_backend = BackendEntrypoint(open_dataset=open_backend_dataset_cfgrib)
cfgrib_backend = BackendEntrypoint(
open_dataset=open_backend_dataset_cfgrib, guess_can_open=guess_can_open_cfgrib
)
37 changes: 24 additions & 13 deletions xarray/backends/h5netcdf_.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
import functools
import io
import os
from distutils.version import LooseVersion

import numpy as np

from ..core import indexing
from ..core.utils import FrozenDict, is_remote_uri
from ..core.utils import FrozenDict, is_remote_uri, read_magic_number
from ..core.variable import Variable
from .common import WritableCFDataStore, find_root_and_group
from .file_manager import CachingFileManager, DummyFileManager
Expand Down Expand Up @@ -128,19 +130,12 @@ def open(
"can't open netCDF4/HDF5 as bytes "
"try passing a path or file-like object"
)
elif hasattr(filename, "tell"):
if filename.tell() != 0:
elif isinstance(filename, io.IOBase):
magic_number = read_magic_number(filename)
if not magic_number.startswith(b"\211HDF\r\n\032\n"):
raise ValueError(
"file-like object read/write pointer not at zero "
"please close and reopen, or use a context manager"
f"{magic_number} is not the signature of a valid netCDF file"
)
else:
magic_number = filename.read(8)
filename.seek(0)
if not magic_number.startswith(b"\211HDF\r\n\032\n"):
raise ValueError(
f"{magic_number} is not the signature of a valid netCDF file"
)

if format not in [None, "NETCDF4"]:
raise ValueError("invalid format for h5netcdf backend")
Expand Down Expand Up @@ -325,6 +320,20 @@ def close(self, **kwargs):
self._manager.close(**kwargs)


def guess_can_open_h5netcdf(store_spec):
try:
return read_magic_number(store_spec).startswith(b"\211HDF\r\n\032\n")
except TypeError:
pass

try:
_, ext = os.path.splitext(store_spec)
except TypeError:
return False

return ext in {".nc", ".nc4", ".cdf"}


def open_backend_dataset_h5netcdf(
filename_or_obj,
*,
Expand Down Expand Up @@ -364,4 +373,6 @@ def open_backend_dataset_h5netcdf(
return ds


h5netcdf_backend = BackendEntrypoint(open_dataset=open_backend_dataset_h5netcdf)
h5netcdf_backend = BackendEntrypoint(
open_dataset=open_backend_dataset_h5netcdf, guess_can_open=guess_can_open_h5netcdf
)
14 changes: 13 additions & 1 deletion xarray/backends/netCDF4_.py
Original file line number Diff line number Diff line change
Expand Up @@ -505,6 +505,16 @@ def close(self, **kwargs):
self._manager.close(**kwargs)


def guess_can_open_netcdf4(store_spec):
if isinstance(store_spec, str) and is_remote_uri(store_spec):
return True
try:
_, ext = os.path.splitext(store_spec)
except TypeError:
return False
return ext in {".nc", ".nc4", ".cdf"}


def open_backend_dataset_netcdf4(
filename_or_obj,
mask_and_scale=True,
Expand Down Expand Up @@ -549,4 +559,6 @@ def open_backend_dataset_netcdf4(
return ds


netcdf4_backend = BackendEntrypoint(open_dataset=open_backend_dataset_netcdf4)
netcdf4_backend = BackendEntrypoint(
open_dataset=open_backend_dataset_netcdf4, guess_can_open=guess_can_open_netcdf4
)
24 changes: 22 additions & 2 deletions xarray/backends/plugins.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,19 @@
import inspect
import itertools
import logging
import warnings
from functools import lru_cache

import pkg_resources


class BackendEntrypoint:
__slots__ = ("open_dataset", "open_dataset_parameters")
__slots__ = ("guess_can_open", "open_dataset", "open_dataset_parameters")

def __init__(self, open_dataset, open_dataset_parameters=None):
def __init__(self, open_dataset, open_dataset_parameters=None, guess_can_open=None):
self.open_dataset = open_dataset
self.open_dataset_parameters = open_dataset_parameters
self.guess_can_open = guess_can_open


def remove_duplicates(backend_entrypoints):
Expand Down Expand Up @@ -76,3 +78,21 @@ def list_engines():
engines = create_engines_dict(backend_entrypoints)
set_missing_parameters(engines)
return engines


def guess_engine(store_spec):
engines = list_engines()

# use the pre-defined selection order for netCDF files
for engine in ["netcdf4", "h5netcdf", "scipy"]:
if engine in engines and engines[engine].guess_can_open(store_spec):
return engine

for engine, backend in engines.items():
try:
if backend.guess_can_open and backend.guess_can_open(store_spec):
return engine
except Exception:
logging.exception(f"{engine!r} fails while guessing")

raise ValueError("cannot guess the engine, try passing one explicitly")
10 changes: 8 additions & 2 deletions xarray/backends/pydap_.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from ..core import indexing
from ..core.pycompat import integer_types
from ..core.utils import Frozen, FrozenDict, is_dict_like
from ..core.utils import Frozen, FrozenDict, is_dict_like, is_remote_uri
from ..core.variable import Variable
from .common import AbstractDataStore, BackendArray, robust_getitem
from .plugins import BackendEntrypoint
Expand Down Expand Up @@ -96,6 +96,10 @@ def get_dimensions(self):
return Frozen(self.ds.dimensions)


def guess_can_open_pydap(store_spec):
return isinstance(store_spec, str) and is_remote_uri(store_spec)


def open_backend_dataset_pydap(
filename_or_obj,
mask_and_scale=True,
Expand Down Expand Up @@ -126,4 +130,6 @@ def open_backend_dataset_pydap(
return ds


pydap_backend = BackendEntrypoint(open_dataset=open_backend_dataset_pydap)
pydap_backend = BackendEntrypoint(
open_dataset=open_backend_dataset_pydap, guess_can_open=guess_can_open_pydap
)
24 changes: 20 additions & 4 deletions xarray/backends/scipy_.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
from io import BytesIO
import io
import os

import numpy as np

from ..core.indexing import NumpyIndexingAdapter
from ..core.utils import Frozen, FrozenDict
from ..core.utils import Frozen, FrozenDict, read_magic_number
from ..core.variable import Variable
from .common import BackendArray, WritableCFDataStore
from .file_manager import CachingFileManager, DummyFileManager
Expand Down Expand Up @@ -78,7 +79,7 @@ def _open_scipy_netcdf(filename, mode, mmap, version):

if isinstance(filename, bytes) and filename.startswith(b"CDF"):
# it's a NetCDF3 bytestring
filename = BytesIO(filename)
filename = io.BytesIO(filename)

try:
return scipy.io.netcdf_file(filename, mode=mode, mmap=mmap, version=version)
Expand Down Expand Up @@ -222,6 +223,19 @@ def close(self):
self._manager.close()


def guess_can_open_scipy(store_spec):
try:
return read_magic_number(store_spec).startswith(b"CDF")
except TypeError:
pass

try:
_, ext = os.path.splitext(store_spec)
except TypeError:
return False
return ext in {".nc", ".nc4", ".cdf", ".gz"}


def open_backend_dataset_scipy(
filename_or_obj,
mask_and_scale=True,
Expand Down Expand Up @@ -255,4 +269,6 @@ def open_backend_dataset_scipy(
return ds


scipy_backend = BackendEntrypoint(open_dataset=open_backend_dataset_scipy)
scipy_backend = BackendEntrypoint(
open_dataset=open_backend_dataset_scipy, guess_can_open=guess_can_open_scipy
)
9 changes: 8 additions & 1 deletion xarray/backends/store.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,14 @@
from .. import conventions
from ..core.dataset import Dataset
from ..core.utils import close_on_error
from .common import AbstractDataStore
from .plugins import BackendEntrypoint


def guess_can_open_store(store_spec):
return isinstance(store_spec, AbstractDataStore)


def open_backend_dataset_store(
store,
*,
Expand Down Expand Up @@ -40,4 +45,6 @@ def open_backend_dataset_store(
return ds


store_backend = BackendEntrypoint(open_dataset=open_backend_dataset_store)
store_backend = BackendEntrypoint(
open_dataset=open_backend_dataset_store, guess_can_open=guess_can_open_store
)
18 changes: 18 additions & 0 deletions xarray/core/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
"""
import contextlib
import functools
import io
import itertools
import os.path
import re
Expand Down Expand Up @@ -603,6 +604,23 @@ def is_remote_uri(path: str) -> bool:
return bool(re.search(r"^https?\://", path))


def read_magic_number(filename_or_obj, count=8):
# check byte header to determine file type
if isinstance(filename_or_obj, bytes):
magic_number = filename_or_obj[:count]
elif isinstance(filename_or_obj, io.IOBase):
if filename_or_obj.tell() != 0:
raise ValueError(
"file-like object read/write pointer not at the start of the file, "
"please close and reopen, or use a context manager"
)
magic_number = filename_or_obj.read(count)
filename_or_obj.seek(0)
else:
raise TypeError(f"cannot read the magic number form {type(filename_or_obj)}")
return magic_number


def is_grib_path(path: str) -> bool:
_, ext = os.path.splitext(path)
return ext in [".grib", ".grb", ".grib2", ".grb2"]
Expand Down
6 changes: 3 additions & 3 deletions xarray/tests/test_backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -2623,7 +2623,7 @@ def test_open_badbytes(self):
with raises_regex(ValueError, "HDF5 as bytes"):
with open_dataset(b"\211HDF\r\n\032\n", engine="h5netcdf"):
pass
with raises_regex(ValueError, "not the signature of any supported file"):
with raises_regex(ValueError, "cannot guess the engine"):
with open_dataset(b"garbage"):
pass
with raises_regex(ValueError, "can only read bytes"):
Expand All @@ -2636,7 +2636,7 @@ def test_open_badbytes(self):
def test_open_twice(self):
expected = create_test_data()
expected.attrs["foo"] = "bar"
with raises_regex(ValueError, "read/write pointer not at zero"):
with raises_regex(ValueError, "read/write pointer not at the start"):
with create_tmp_file() as tmp_file:
expected.to_netcdf(tmp_file, engine="h5netcdf")
with open(tmp_file, "rb") as f:
Expand Down Expand Up @@ -2669,7 +2669,7 @@ def test_open_fileobj(self):
open_dataset(f, engine="scipy")

f.seek(8)
with raises_regex(ValueError, "read/write pointer not at zero"):
with raises_regex(ValueError, "read/write pointer not at the start"):
open_dataset(f)


Expand Down