diff --git a/intake_xarray/netcdf.py b/intake_xarray/netcdf.py
index 36c64bf..3786881 100644
--- a/intake_xarray/netcdf.py
+++ b/intake_xarray/netcdf.py
@@ -1,5 +1,10 @@
 # -*- coding: utf-8 -*-
-import xarray as xr
+from distutils.version import LooseVersion
+try:
+    import xarray as xr
+    XARRAY_VERSION = LooseVersion(xr.__version__)
+except ImportError:
+    XARRAY_VERSION = None
 from intake.source.base import PatternMixin
 from intake.source.utils import reverse_format
 from .base import DataSourceMixin
@@ -10,38 +15,47 @@ class NetCDFSource(DataSourceMixin, PatternMixin):
 
     Parameters
     ----------
-    urlpath: str
+    urlpath : str
         Path to source file. May include glob "*" characters, format
         pattern strings, or list.
         Some examples:
-            - ``{{ CATALOG_DIR }}data/air.nc``
-            - ``{{ CATALOG_DIR }}data/*.nc``
-            - ``{{ CATALOG_DIR }}data/air_{year}.nc``
-    chunks: int or dict
+            - ``{{ CATALOG_DIR }}/data/air.nc``
+            - ``{{ CATALOG_DIR }}/data/*.nc``
+            - ``{{ CATALOG_DIR }}/data/air_{year}.nc``
+    chunks : int or dict, optional
         Chunks is used to load the new dataset into
         dask arrays. ``chunks={}`` loads the dataset with dask using a single
         chunk for all arrays.
-    path_as_pattern: bool or str, optional
+    concat_dim : str, optional
+        Name of dimension along which to concatenate the files. Can
+        be new or pre-existing. Default is 'concat_dim'.
+    path_as_pattern : bool or str, optional
         Whether to treat the path as a pattern (i.e. ``data_{field}.nc``)
         and create new coordinates in the output corresponding to pattern
         fields. If str, is treated as pattern to match on. Default is True.
     """
     name = 'netcdf'
 
-    def __init__(self, urlpath, chunks, xarray_kwargs=None, metadata=None,
+    def __init__(self, urlpath, chunks=None, concat_dim='concat_dim',
+                 xarray_kwargs=None, metadata=None,
                  path_as_pattern=True, **kwargs):
         self.path_as_pattern = path_as_pattern
         self.urlpath = urlpath
         self.chunks = chunks
+        self.concat_dim = concat_dim
         self._kwargs = xarray_kwargs or kwargs
         self._ds = None
         super(NetCDFSource, self).__init__(metadata=metadata)
 
     def _open_dataset(self):
+        if not XARRAY_VERSION:
+            raise ImportError("xarray not available")
         url = self.urlpath
         kwargs = self._kwargs
         if "*" in url or isinstance(url, list):
             _open_dataset = xr.open_mfdataset
+            if 'concat_dim' not in kwargs.keys():
+                kwargs.update(concat_dim=self.concat_dim)
             if self.pattern:
                 kwargs.update(preprocess=self._add_path_to_ds)
         else:
@@ -52,6 +66,12 @@ def _open_dataset(self):
     def _add_path_to_ds(self, ds):
         """Adding path info to a coord for a particular file
         """
+        if not (XARRAY_VERSION > '0.11.1'):
+            raise ImportError("Your version of xarray is '{}'. "
+                              "The guarantee that the source path is available on "
+                              "the output of open_dataset was added in 0.11.2, so "
+                              "pattern urlpaths are not supported.".format(XARRAY_VERSION))
+
         var = next(var for var in ds)
         new_coords = reverse_format(self.pattern, ds[var].encoding['source'])
         return ds.assign_coords(**new_coords)
diff --git a/tests/util.py b/tests/conftest.py
similarity index 87%
rename from tests/util.py
rename to tests/conftest.py
index b59c906..9d044a9 100644
--- a/tests/util.py
+++ b/tests/conftest.py
@@ -1,6 +1,6 @@
 # -*- coding: utf-8 -*-
 
-import os
+import posixpath
 import pytest
 import shutil
 import tempfile
@@ -11,11 +11,11 @@
 TEST_DATA_DIR = 'tests/data'
 TEST_DATA = 'example_1.nc'
-TEST_URLPATH = os.path.join(TEST_DATA_DIR, TEST_DATA)
+TEST_URLPATH = posixpath.join(TEST_DATA_DIR, TEST_DATA)
 
 
 @pytest.fixture
-def cdf_source():
+def netcdf_source():
     return NetCDFSource(TEST_URLPATH, {})
diff --git a/tests/data/example_2.nc b/tests/data/example_2.nc
new file mode 100644
index 0000000..5775622
Binary files /dev/null and b/tests/data/example_2.nc differ
diff --git a/tests/test_catalog.py b/tests/test_catalog.py
index 5a76f4a..38b031e 100644
--- a/tests/test_catalog.py
+++ b/tests/test_catalog.py
@@ -4,7 +4,6 @@
 import pytest
 
 from intake import open_catalog
-from .util import dataset  # noqa
 
 
 @pytest.fixture
diff --git a/tests/test_intake_xarray.py b/tests/test_intake_xarray.py
index a0d9f44..277f9bf 100644
--- a/tests/test_intake_xarray.py
+++ b/tests/test_intake_xarray.py
@@ -7,12 +7,10 @@
 
 here = os.path.dirname(__file__)
 
-from .util import TEST_URLPATH, cdf_source, zarr_source, dataset  # noqa
-
 
-@pytest.mark.parametrize('source', ['cdf', 'zarr'])
-def test_discover(source, cdf_source, zarr_source, dataset):
-    source = {'cdf': cdf_source, 'zarr': zarr_source}[source]
+@pytest.mark.parametrize('source', ['netcdf', 'zarr'])
+def test_discover(source, netcdf_source, zarr_source, dataset):
+    source = {'netcdf': netcdf_source, 'zarr': zarr_source}[source]
     r = source.discover()
     assert r['datashape'] is None
@@ -25,9 +23,9 @@ def test_discover(source, cdf_source, zarr_source, dataset):
     assert set(source.metadata['coords']) == set(dataset.coords.keys())
 
 
-@pytest.mark.parametrize('source', ['cdf', 'zarr'])
-def test_read(source, cdf_source, zarr_source, dataset):
-    source = {'cdf': cdf_source, 'zarr': zarr_source}[source]
+@pytest.mark.parametrize('source', ['netcdf', 'zarr'])
+def test_read(source, netcdf_source, zarr_source, dataset):
+    source = {'netcdf': netcdf_source, 'zarr': zarr_source}[source]
     ds = source.read_chunked()
     assert ds.temp.chunks
@@ -38,8 +36,8 @@ def test_read(source, cdf_source, zarr_source, dataset):
     assert np.all(ds.rh == dataset.rh)
 
 
-def test_read_partition_cdf(cdf_source):
-    source = cdf_source
+def test_read_partition_netcdf(netcdf_source):
+    source = netcdf_source
     with pytest.raises(TypeError):
         source.read_partition(None)
     out = source.read_partition(('temp', 0, 0, 0, 0))
@@ -48,6 +46,28 @@
     assert np.all(out == expected)
 
 
+def test_read_list_of_netcdf_files():
+    from intake_xarray.netcdf import NetCDFSource
+    source = NetCDFSource([
+        os.path.join(here, 'data', 'example_1.nc'),
+        os.path.join(here, 'data', 'example_2.nc'),
+    ])
+    d = source.to_dask()
+    assert d.dims == {'lat': 5, 'lon': 10, 'level': 4, 'time': 1,
+                      'concat_dim': 2}
+
+
+def test_read_glob_pattern_of_netcdf_files():
+    from intake_xarray.netcdf import NetCDFSource
+
+    source = NetCDFSource(os.path.join(here, 'data', 'example_{num: d}.nc'),
+                          concat_dim='num')
+    d = source.to_dask()
+    assert d.dims == {'lat': 5, 'lon': 10, 'level': 4, 'time': 1,
+                      'num': 2}
+    assert (d.num.data == np.array([1, 2])).all()
+
+
 def test_read_partition_zarr(zarr_source):
     source = zarr_source
     with pytest.raises(TypeError):
@@ -57,9 +77,9 @@
     assert np.all(out == expected)
 
 
-@pytest.mark.parametrize('source', ['cdf', 'zarr'])
-def test_to_dask(source, cdf_source, zarr_source, dataset):
-    source = {'cdf': cdf_source, 'zarr': zarr_source}[source]
+@pytest.mark.parametrize('source', ['netcdf', 'zarr'])
+def test_to_dask(source, netcdf_source, zarr_source, dataset):
+    source = {'netcdf': netcdf_source, 'zarr': zarr_source}[source]
     ds = source.to_dask()
     assert ds.dims == dataset.dims
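
Usage sketch (illustrative, not part of the patch): the snippet below exercises the new ``concat_dim`` parameter the same way the added tests do. The file paths and the ``here`` variable are assumptions mirroring the test data layout; the expected dimension sizes are the ones asserted in the tests above.

    import os
    import numpy as np
    from intake_xarray.netcdf import NetCDFSource

    here = 'tests'  # assumed location of the data/ directory from the tests

    # An explicit list of files is opened with xr.open_mfdataset and
    # concatenated along the default 'concat_dim' dimension.
    source = NetCDFSource([
        os.path.join(here, 'data', 'example_1.nc'),
        os.path.join(here, 'data', 'example_2.nc'),
    ])
    ds = source.to_dask()
    # Per the new test:
    # {'lat': 5, 'lon': 10, 'level': 4, 'time': 1, 'concat_dim': 2}
    print(dict(ds.dims))

    # A pattern urlpath parses the '{num: d}' field out of each file name
    # and attaches it as a coordinate (requires xarray >= 0.11.2, per the
    # version guard added above).
    pattern_source = NetCDFSource(os.path.join(here, 'data', 'example_{num: d}.nc'),
                                  concat_dim='num')
    assert (pattern_source.to_dask().num.data == np.array([1, 2])).all()

Note that a ``concat_dim`` passed through ``xarray_kwargs`` takes precedence: ``_open_dataset`` only injects ``self.concat_dim`` when that key is absent from the kwargs.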