From c6977f18341e3872ba69360045ee3ef8212eb1bf Mon Sep 17 00:00:00 2001 From: crusaderky Date: Tue, 8 May 2018 03:25:39 +0100 Subject: [PATCH 01/61] h5netcdf new API support (#1915) * Ignore dask scratch area * Public API for HDF5 support * Remove save_mfdataset_hdf5 * Replace to_hdf5 with to_netcdf(engine='h5netcdf-ng') * h5netcdf-ng -> h5netcdf-new * Trivial fixes * Functional implementation * stickler fixes * Reimplement as extra params for h5netcdf * Cosmetic tweaks * Bugfixes * More robust mixed-style encoding handling * Crash on mismatched encoding if check_encoding=True * Test check_encoding * stickler fix * Use parentheses instead of explicit continuation with \ --- .gitignore | 1 + doc/whats-new.rst | 5 +++ xarray/backends/api.py | 1 + xarray/backends/h5netcdf_.py | 73 ++++++++++++++++++++++---------- xarray/backends/netCDF4_.py | 7 +++- xarray/core/dataarray.py | 3 +- xarray/core/dataset.py | 7 ++++ xarray/tests/test_backends.py | 78 +++++++++++++++++++++++++++++++++++ 8 files changed, 149 insertions(+), 26 deletions(-) diff --git a/.gitignore b/.gitignore index 70458f00648..92e488ed616 100644 --- a/.gitignore +++ b/.gitignore @@ -39,6 +39,7 @@ nosetests.xml .tags* .testmon* .pytest_cache +dask-worker-space/ # asv environments .asv diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 8006c658e01..d614a23d0fc 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -36,6 +36,11 @@ Enhancements - Support writing lists of strings as netCDF attributes (:issue:`2044`). By `Dan Nowacki `_. +- :py:meth:`~xarray.Dataset.to_netcdf(engine='h5netcdf')` now accepts h5py + encoding settings ``compression`` and ``compression_opts``, along with the + NetCDF4-Python style settings ``gzip=True`` and ``complevel``. + This allows using any compression plugin installed in hdf5, e.g. LZF + (:issue:`1536`). By `Guido Imperiale `_. - :py:meth:`~xarray.dot` on dask-backed data will now call :func:`dask.array.einsum`. This greatly boosts speed and allows chunking on the core dims. The function now requires dask >= 0.17.3 to work on dask-backed data diff --git a/xarray/backends/api.py b/xarray/backends/api.py index da4ef537a3a..b8cfa3c926a 100644 --- a/xarray/backends/api.py +++ b/xarray/backends/api.py @@ -741,6 +741,7 @@ def save_mfdataset(datasets, paths, mode='w', format=None, groups=None, Engine to use when writing netCDF files. If not provided, the default engine is chosen based on available dependencies, with a preference for 'netcdf4' if writing to a file on disk. + See `Dataset.to_netcdf` for additional information. 
Examples -------- diff --git a/xarray/backends/h5netcdf_.py b/xarray/backends/h5netcdf_.py index 7beda03308e..d34fa2d9267 100644 --- a/xarray/backends/h5netcdf_.py +++ b/xarray/backends/h5netcdf_.py @@ -45,21 +45,21 @@ def _read_attributes(h5netcdf_var): # to ensure conventions decoding works properly on Python 3, decode all # bytes attributes to strings attrs = OrderedDict() - for k in h5netcdf_var.ncattrs(): - v = h5netcdf_var.getncattr(k) + for k, v in h5netcdf_var.attrs.items(): if k not in ['_FillValue', 'missing_value']: v = maybe_decode_bytes(v) attrs[k] = v return attrs -_extract_h5nc_encoding = functools.partial(_extract_nc4_variable_encoding, - lsd_okay=False, backend='h5netcdf') +_extract_h5nc_encoding = functools.partial( + _extract_nc4_variable_encoding, + lsd_okay=False, h5py_okay=True, backend='h5netcdf') def _open_h5netcdf_group(filename, mode, group): - import h5netcdf.legacyapi - ds = h5netcdf.legacyapi.Dataset(filename, mode=mode) + import h5netcdf + ds = h5netcdf.File(filename, mode=mode) with close_on_error(ds): return _nc4_group(ds, group, mode) @@ -96,10 +96,19 @@ def open_store_variable(self, name, var): attrs = _read_attributes(var) # netCDF4 specific encoding - encoding = dict(var.filters()) - chunking = var.chunking() - encoding['chunksizes'] = chunking \ - if chunking != 'contiguous' else None + encoding = { + 'chunksizes': var.chunks, + 'fletcher32': var.fletcher32, + 'shuffle': var.shuffle, + } + # Convert h5py-style compression options to NetCDF4-Python + # style, if possible + if var.compression == 'gzip': + encoding['zlib'] = True + encoding['complevel'] = var.compression_opts + elif var.compression is not None: + encoding['compression'] = var.compression + encoding['compression_opts'] = var.compression_opts # save source so __repr__ can detect if it's local or not encoding['source'] = self._filename @@ -130,14 +139,14 @@ def get_encoding(self): def set_dimension(self, name, length, is_unlimited=False): with self.ensure_open(autoclose=False): if is_unlimited: - self.ds.createDimension(name, size=None) + self.ds.dimensions[name] = None self.ds.resize_dimension(name, length) else: - self.ds.createDimension(name, size=length) + self.ds.dimensions[name] = length def set_attribute(self, key, value): with self.ensure_open(autoclose=False): - self.ds.setncattr(key, value) + self.ds.attrs[key] = value def encode_variable(self, variable): return _encode_nc4_variable(variable) @@ -149,8 +158,8 @@ def prepare_variable(self, name, variable, check_encoding=False, attrs = variable.attrs.copy() dtype = _get_datatype(variable) - fill_value = attrs.pop('_FillValue', None) - if dtype is str and fill_value is not None: + fillvalue = attrs.pop('_FillValue', None) + if dtype is str and fillvalue is not None: raise NotImplementedError( 'h5netcdf does not yet support setting a fill value for ' 'variable-length strings ' @@ -166,18 +175,38 @@ def prepare_variable(self, name, variable, check_encoding=False, raise_on_invalid=check_encoding) kwargs = {} - for key in ['zlib', 'complevel', 'shuffle', - 'chunksizes', 'fletcher32']: + # Convert from NetCDF4-Python style compression settings to h5py style + # If both styles are used together, h5py takes precedence + # If set_encoding=True, raise ValueError in case of mismatch + if encoding.pop('zlib', False): + if (check_encoding and encoding.get('compression') + not in (None, 'gzip')): + raise ValueError("'zlib' and 'compression' encodings mismatch") + encoding.setdefault('compression', 'gzip') + + if (check_encoding and 
encoding.get('complevel') not in + (None, encoding.get('compression_opts'))): + raise ValueError("'complevel' and 'compression_opts' encodings " + "mismatch") + complevel = encoding.pop('complevel', 0) + if complevel != 0: + encoding.setdefault('compression_opts', complevel) + + encoding['chunks'] = encoding.pop('chunksizes', None) + + for key in ['compression', 'compression_opts', 'shuffle', + 'chunks', 'fletcher32']: if key in encoding: kwargs[key] = encoding[key] - if name not in self.ds.variables: - nc4_var = self.ds.createVariable(name, dtype, variable.dims, - fill_value=fill_value, **kwargs) + if name not in self.ds: + nc4_var = self.ds.create_variable( + name, dtype=dtype, dimensions=variable.dims, + fillvalue=fillvalue, **kwargs) else: - nc4_var = self.ds.variables[name] + nc4_var = self.ds[name] for k, v in iteritems(attrs): - nc4_var.setncattr(k, v) + nc4_var.attrs[k] = v target = H5NetCDFArrayWrapper(name, self) diff --git a/xarray/backends/netCDF4_.py b/xarray/backends/netCDF4_.py index 1195301825b..a0f6cbcdd33 100644 --- a/xarray/backends/netCDF4_.py +++ b/xarray/backends/netCDF4_.py @@ -159,8 +159,8 @@ def _force_native_endianness(var): def _extract_nc4_variable_encoding(variable, raise_on_invalid=False, - lsd_okay=True, backend='netCDF4', - unlimited_dims=None): + lsd_okay=True, h5py_okay=False, + backend='netCDF4', unlimited_dims=None): if unlimited_dims is None: unlimited_dims = () @@ -171,6 +171,9 @@ def _extract_nc4_variable_encoding(variable, raise_on_invalid=False, 'chunksizes', 'shuffle', '_FillValue']) if lsd_okay: valid_encodings.add('least_significant_digit') + if h5py_okay: + valid_encodings.add('compression') + valid_encodings.add('compression_opts') if not raise_on_invalid and encoding.get('chunksizes') is not None: # It's possible to get encoded chunksizes larger than a dimension size diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index 9ff631e7cfc..1ceaced5961 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -1443,8 +1443,7 @@ def to_masked_array(self, copy=True): return np.ma.MaskedArray(data=self.values, mask=isnull, copy=copy) def to_netcdf(self, *args, **kwargs): - """ - Write DataArray contents to a netCDF file. + """Write DataArray contents to a netCDF file. Parameters ---------- diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index f28e7980b34..32913127636 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -1123,6 +1123,13 @@ def to_netcdf(self, path=None, mode='w', format=None, group=None, variable specific encodings as values, e.g., ``{'my_variable': {'dtype': 'int16', 'scale_factor': 0.1, 'zlib': True}, ...}`` + + The `h5netcdf` engine supports both the NetCDF4-style compression + encoding parameters ``{'zlib': True, 'complevel': 9}`` and the h5py + ones ``{'compression': 'gzip', 'compression_opts': 9}``. + This allows using any compression plugin installed in the HDF5 + library, e.g. LZF. + unlimited_dims : sequence of str, optional Dimension(s) that should be serialized as unlimited dimensions. By default, no dimensions are treated as unlimited dimensions. 
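For reference, a minimal usage sketch of the two encoding styles described in the docstring above (illustrative only; it assumes an in-memory dataset ``ds``, the ``h5netcdf`` backend installed with this patch applied, and an HDF5/h5py build that provides the LZF filter; the file names are placeholders)::

    import numpy as np
    import xarray as xr

    ds = xr.Dataset({'var2': (('x', 'y'), np.random.rand(8, 8))})

    # NetCDF4-Python style keys; per this patch the h5netcdf backend
    # translates them to 'compression'/'compression_opts' internally
    ds.to_netcdf('out_zlib.nc', engine='h5netcdf',
                 encoding={'var2': {'zlib': True, 'complevel': 9}})

    # h5py style keys; exposes filters with no NetCDF4-Python
    # equivalent, e.g. LZF
    ds.to_netcdf('out_lzf.nc', engine='h5netcdf',
                 encoding={'var2': {'compression': 'lzf'}})

Mixing the two styles in one encoding dict is accepted as long as they agree; per the checks added in this patch, a mismatch raises ``ValueError`` when the encoding is supplied explicitly.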
diff --git a/xarray/tests/test_backends.py b/xarray/tests/test_backends.py index 7f8a440ba5d..632145007e2 100644 --- a/xarray/tests/test_backends.py +++ b/xarray/tests/test_backends.py @@ -1668,6 +1668,84 @@ def test_encoding_unlimited_dims(self): self.assertEqual(actual.encoding['unlimited_dims'], set('y')) assert_equal(ds, actual) + def test_compression_encoding_h5py(self): + ENCODINGS = ( + # h5py style compression with gzip codec will be converted to + # NetCDF4-Python style on round-trip + ({'compression': 'gzip', 'compression_opts': 9}, + {'zlib': True, 'complevel': 9}), + # What can't be expressed in NetCDF4-Python style is + # round-tripped unaltered + ({'compression': 'lzf', 'compression_opts': None}, + {'compression': 'lzf', 'compression_opts': None}), + # If both styles are used together, h5py format takes precedence + ({'compression': 'lzf', 'compression_opts': None, + 'zlib': True, 'complevel': 9}, + {'compression': 'lzf', 'compression_opts': None})) + + for compr_in, compr_out in ENCODINGS: + data = create_test_data() + compr_common = { + 'chunksizes': (5, 5), + 'fletcher32': True, + 'shuffle': True, + 'original_shape': data.var2.shape + } + data['var2'].encoding.update(compr_in) + data['var2'].encoding.update(compr_common) + compr_out.update(compr_common) + with self.roundtrip(data) as actual: + for k, v in compr_out.items(): + self.assertEqual(v, actual['var2'].encoding[k]) + + def test_compression_check_encoding_h5py(self): + """When mismatched h5py and NetCDF4-Python encodings are expressed + in to_netcdf(encoding=...), must raise ValueError + """ + data = Dataset({'x': ('y', np.arange(10.0))}) + # Compatible encodings are graciously supported + with create_tmp_file() as tmp_file: + data.to_netcdf( + tmp_file, engine='h5netcdf', + encoding={'x': {'compression': 'gzip', 'zlib': True, + 'compression_opts': 6, 'complevel': 6}}) + with open_dataset(tmp_file, engine='h5netcdf') as actual: + assert actual.x.encoding['zlib'] is True + assert actual.x.encoding['complevel'] == 6 + + # Incompatible encodings cause a crash + with create_tmp_file() as tmp_file: + with raises_regex(ValueError, + "'zlib' and 'compression' encodings mismatch"): + data.to_netcdf( + tmp_file, engine='h5netcdf', + encoding={'x': {'compression': 'lzf', 'zlib': True}}) + + with create_tmp_file() as tmp_file: + with raises_regex( + ValueError, + "'complevel' and 'compression_opts' encodings mismatch"): + data.to_netcdf( + tmp_file, engine='h5netcdf', + encoding={'x': {'compression': 'gzip', + 'compression_opts': 5, 'complevel': 6}}) + + def test_dump_encodings_h5py(self): + # regression test for #709 + ds = Dataset({'x': ('y', np.arange(10.0))}) + + kwargs = {'encoding': {'x': { + 'compression': 'gzip', 'compression_opts': 9}}} + with self.roundtrip(ds, save_kwargs=kwargs) as actual: + self.assertEqual(actual.x.encoding['zlib'], True) + self.assertEqual(actual.x.encoding['complevel'], 9) + + kwargs = {'encoding': {'x': { + 'compression': 'lzf', 'compression_opts': None}}} + with self.roundtrip(ds, save_kwargs=kwargs) as actual: + self.assertEqual(actual.x.encoding['compression'], 'lzf') + self.assertEqual(actual.x.encoding['compression_opts'], None) + # tests pending h5netcdf fix @unittest.skip From aeae80b21ed659e14d4378a513f2351452eed460 Mon Sep 17 00:00:00 2001 From: Ray Bell Date: Tue, 8 May 2018 00:23:02 -0400 Subject: [PATCH 02/61] DOC: Add resample e.g. Edit rolling e.g. Add groupby e.g. (#2101) * DOC: Add resample e.g. Edit rolling e.g. Add groupby e.g. 
* DOC: Add 2d resample example * DOC: Add upsample example in resample * DOC: drop sentence is resample docstring * extend resample DeprecationWarning. Drop n-d resample example. * change resample DeprecationWarning * don't display how twice in the warning * DOC: add assign_coords example * DOC: remove parameters to resample --- xarray/core/common.py | 95 ++++++++++++++++++++++++++++++++++++------- 1 file changed, 80 insertions(+), 15 deletions(-) diff --git a/xarray/core/common.py b/xarray/core/common.py index 5beb5234d4c..0c6e0fccd8e 100644 --- a/xarray/core/common.py +++ b/xarray/core/common.py @@ -308,6 +308,25 @@ def assign_coords(self, **kwargs): assigned : same type as caller A new object with the new coordinates in addition to the existing data. + + Examples + -------- + + Convert longitude coordinates from 0-359 to -180-179: + + >>> da = xr.DataArray(np.random.rand(4), + ... coords=[np.array([358, 359, 0, 1])], + ... dims='lon') + >>> da + + array([0.28298 , 0.667347, 0.657938, 0.177683]) + Coordinates: + * lon (lon) int64 358 359 0 1 + >>> da.assign_coords(lon=(((da.lon + 180) % 360) - 180)) + + array([0.28298 , 0.667347, 0.657938, 0.177683]) + Coordinates: + * lon (lon) int64 -2 -1 0 1 Notes ----- @@ -426,7 +445,27 @@ def groupby(self, group, squeeze=True): grouped : GroupBy A `GroupBy` object patterned after `pandas.GroupBy` that can be iterated over in the form of `(unique_value, grouped_array)` pairs. - + + Examples + -------- + Calculate daily anomalies for daily data: + + >>> da = xr.DataArray(np.linspace(0, 1826, num=1827), + ... coords=[pd.date_range('1/1/2000', '31/12/2004', + ... freq='D')], + ... dims='time') + >>> da + + array([0.000e+00, 1.000e+00, 2.000e+00, ..., 1.824e+03, 1.825e+03, 1.826e+03]) + Coordinates: + * time (time) datetime64[ns] 2000-01-01 2000-01-02 2000-01-03 ... + >>> da.groupby('time.dayofyear') - da.groupby('time.dayofyear').mean('time') + + array([-730.8, -730.8, -730.8, ..., 730.2, 730.2, 730.5]) + Coordinates: + * time (time) datetime64[ns] 2000-01-01 2000-01-02 2000-01-03 ... + dayofyear (time) int64 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 ... + See Also -------- core.groupby.DataArrayGroupBy @@ -514,7 +553,7 @@ def rolling(self, min_periods=None, center=False, **windows): -------- Create rolling seasonal average of monthly data e.g. DJF, JFM, ..., SON: - >>> da = xr.DataArray(np.linspace(0,11,num=12), + >>> da = xr.DataArray(np.linspace(0, 11, num=12), ... coords=[pd.date_range('15/12/1999', ... periods=12, freq=pd.DateOffset(months=1))], ... dims='time') @@ -523,19 +562,19 @@ def rolling(self, min_periods=None, center=False, **windows): array([ 0., 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11.]) Coordinates: * time (time) datetime64[ns] 1999-12-15 2000-01-15 2000-02-15 ... - >>> da.rolling(time=3).mean() + >>> da.rolling(time=3, center=True).mean() - array([ nan, nan, 1., 2., 3., 4., 5., 6., 7., 8., 9., 10.]) + array([nan, 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., nan]) Coordinates: * time (time) datetime64[ns] 1999-12-15 2000-01-15 2000-02-15 ... Remove the NaNs using ``dropna()``: - >>> da.rolling(time=3).mean().dropna('time') + >>> da.rolling(time=3, center=True).mean().dropna('time') - array([ 1., 2., 3., 4., 5., 6., 7., 8., 9., 10.]) + array([ 1., 2., 3., 4., 5., 6., 7., 8., 9., 10.]) Coordinates: - * time (time) datetime64[ns] 2000-02-15 2000-03-15 2000-04-15 ... + * time (time) datetime64[ns] 2000-01-15 2000-02-15 2000-03-15 ... 
See Also -------- @@ -550,9 +589,8 @@ def resample(self, freq=None, dim=None, how=None, skipna=None, closed=None, label=None, base=0, keep_attrs=False, **indexer): """Returns a Resample object for performing resampling operations. - Handles both downsampling and upsampling. Upsampling with filling is - not supported; if any intervals contain no values from the original - object, they will be given the value ``NaN``. + Handles both downsampling and upsampling. If any intervals contain no + values from the original object, they will be given the value ``NaN``. Parameters ---------- @@ -578,7 +616,34 @@ def resample(self, freq=None, dim=None, how=None, skipna=None, ------- resampled : same type as caller This object resampled. - + + Examples + -------- + Downsample monthly time-series data to seasonal data: + + >>> da = xr.DataArray(np.linspace(0, 11, num=12), + ... coords=[pd.date_range('15/12/1999', + ... periods=12, freq=pd.DateOffset(months=1))], + ... dims='time') + >>> da + + array([ 0., 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11.]) + Coordinates: + * time (time) datetime64[ns] 1999-12-15 2000-01-15 2000-02-15 ... + >>> da.resample(time="Q-DEC").mean() + + array([ 1., 4., 7., 10.]) + Coordinates: + * time (time) datetime64[ns] 2000-02-29 2000-05-31 2000-08-31 2000-11-30 + + Upsample monthly time-series data to daily data: + + >>> da.resample(time='1D').interpolate('linear') + + array([ 0. , 0.032258, 0.064516, ..., 10.935484, 10.967742, 11. ]) + Coordinates: + * time (time) datetime64[ns] 1999-12-15 1999-12-16 1999-12-17 ... + References ---------- @@ -628,10 +693,10 @@ def _resample_immediately(self, freq, dim, how, skipna, warnings.warn("\n.resample() has been modified to defer " "calculations. Instead of passing 'dim' and " - "'how=\"{how}\", instead consider using " - ".resample({dim}=\"{freq}\").{how}() ".format( - dim=dim, freq=freq, how=how - ), DeprecationWarning, stacklevel=3) + "how=\"{how}\", instead consider using " + ".resample({dim}=\"{freq}\").{how}('{dim}') ".format( + dim=dim, freq=freq, how=how), + DeprecationWarning, stacklevel=3) if isinstance(dim, basestring): dim = self[dim] From c046528522a6d4cf18c81a19aeae82f5f7d63d34 Mon Sep 17 00:00:00 2001 From: Henk Griffioen Date: Wed, 9 May 2018 17:28:32 +0200 Subject: [PATCH 03/61] DOC: Update link to documentation of Rasterio (#2110) --- doc/io.rst | 2 +- doc/whats-new.rst | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/doc/io.rst b/doc/io.rst index c14e1516b38..668416e714d 100644 --- a/doc/io.rst +++ b/doc/io.rst @@ -534,7 +534,7 @@ longitudes and latitudes. considered as being experimental. Please report any bug you may find on xarray's github repository. -.. _rasterio: https://mapbox.github.io/rasterio/ +.. _rasterio: https://rasterio.readthedocs.io/en/latest/ .. _test files: https://github.com/mapbox/rasterio/blob/master/tests/data/RGB.byte.tif .. _pyproj: https://github.com/jswhit/pyproj diff --git a/doc/whats-new.rst b/doc/whats-new.rst index d614a23d0fc..fdbe6831a24 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -785,7 +785,7 @@ Enhancements By `Stephan Hoyer `_. - New function :py:func:`~xarray.open_rasterio` for opening raster files with - the `rasterio `_ library. + the `rasterio `_ library. See :ref:`the docs ` for details. 
By `Joe Hamman `_, `Nic Wayand `_ and From 70e2eb539d2fe33ee1b5efbd5d2476649dea898b Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Wed, 9 May 2018 17:45:40 -0700 Subject: [PATCH 04/61] Plotting upgrades (#2092) * Support xincrease, yincrease * Better tick label rotation in case of dateticks. Avoid autofmt_xdate because it deletes all x-axis ticklabels except for the last subplot. * Tests * docs. * review. * Prevent unclosed file ResourceWarning. --- doc/plotting.rst | 10 ++++++++++ doc/whats-new.rst | 4 ++++ xarray/plot/plot.py | 21 ++++++++++++++++++--- xarray/plot/utils.py | 1 + xarray/tests/test_plot.py | 7 +++++++ 5 files changed, 40 insertions(+), 3 deletions(-) diff --git a/doc/plotting.rst b/doc/plotting.rst index c85a54d783b..28fbe7062a6 100644 --- a/doc/plotting.rst +++ b/doc/plotting.rst @@ -207,6 +207,16 @@ It is also possible to make line plots such that the data are on the x-axis and @savefig plotting_example_xy_kwarg.png air.isel(time=10, lon=[10, 11]).plot.line(y='lat', hue='lon') +Changing Axes Direction +----------------------- + +The keyword arguments ``xincrease`` and ``yincrease`` let you control the axes direction. + +.. ipython:: python + + @savefig plotting_example_xincrease_yincrease_kwarg.png + air.isel(time=10, lon=[10, 11]).plot.line(y='lat', hue='lon', xincrease=False, yincrease=False) + Two Dimensions -------------- diff --git a/doc/whats-new.rst b/doc/whats-new.rst index fdbe6831a24..3c2e143bec3 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -55,6 +55,10 @@ Bug fixes By `Keisuke Fujii `_. - Better error handling in ``open_mfdataset`` (:issue:`2077`). By `Stephan Hoyer `_. +- ``plot.line()`` does not call ``autofmt_xdate()`` anymore. Instead it changes the rotation and horizontal alignment of labels without removing the x-axes of any other subplots in the figure (if any). + By `Deepak Cherian `_. +- ``plot.line()`` learned new kwargs: ``xincrease``, ``yincrease`` that change the direction of the respective axes. + By `Deepak Cherian `_. .. _whats-new.0.10.3: diff --git a/xarray/plot/plot.py b/xarray/plot/plot.py index 94ddc8c0535..6a3bed08f72 100644 --- a/xarray/plot/plot.py +++ b/xarray/plot/plot.py @@ -182,6 +182,12 @@ def line(darray, *args, **kwargs): Coordinates for x, y axis. Only one of these may be specified. The other coordinate plots values from the DataArray on which this plot method is called. + xincrease : None, True, or False, optional + Should the values on the x axes be increasing from left to right? + if None, use the default for the matplotlib function. + yincrease : None, True, or False, optional + Should the values on the y axes be increasing from top to bottom? + if None, use the default for the matplotlib function. add_legend : boolean, optional Add legend with y axis coordinates (2D inputs only). *args, **kwargs : optional @@ -203,6 +209,8 @@ def line(darray, *args, **kwargs): hue = kwargs.pop('hue', None) x = kwargs.pop('x', None) y = kwargs.pop('y', None) + xincrease = kwargs.pop('xincrease', True) + yincrease = kwargs.pop('yincrease', True) add_legend = kwargs.pop('add_legend', True) ax = get_axis(figsize, size, aspect, ax) @@ -269,8 +277,15 @@ def line(darray, *args, **kwargs): title=huelabel) # Rotate dates on xlabels + # Do this without calling autofmt_xdate so that x-axes ticks + # on other subplots (if any) are not deleted. 
+ # https://stackoverflow.com/questions/17430105/autofmt-xdate-deletes-x-axis-labels-of-all-subplots if np.issubdtype(xplt.dtype, np.datetime64): - ax.get_figure().autofmt_xdate() + for xlabels in ax.get_xticklabels(): + xlabels.set_rotation(30) + xlabels.set_ha('right') + + _update_axes_limits(ax, xincrease, yincrease) return primitive @@ -429,10 +444,10 @@ def _plot2d(plotfunc): Use together with ``col`` to wrap faceted plots xincrease : None, True, or False, optional Should the values on the x axes be increasing from left to right? - if None, use the default for the matplotlib function + if None, use the default for the matplotlib function. yincrease : None, True, or False, optional Should the values on the y axes be increasing from top to bottom? - if None, use the default for the matplotlib function + if None, use the default for the matplotlib function. add_colorbar : Boolean, optional Adds colorbar to axis add_labels : Boolean, optional diff --git a/xarray/plot/utils.py b/xarray/plot/utils.py index 59d67ed79f1..3db8bcab3a7 100644 --- a/xarray/plot/utils.py +++ b/xarray/plot/utils.py @@ -21,6 +21,7 @@ def _load_default_cmap(fname='default_colormap.csv'): # Not sure what the first arg here should be f = pkg_resources.resource_stream(__name__, fname) cm_data = pd.read_csv(f, header=None).values + f.close() return LinearSegmentedColormap.from_list('viridis', cm_data) diff --git a/xarray/tests/test_plot.py b/xarray/tests/test_plot.py index 2a5eeb86bdd..a3446fe240b 100644 --- a/xarray/tests/test_plot.py +++ b/xarray/tests/test_plot.py @@ -352,6 +352,13 @@ def test_x_ticks_are_rotated_for_time(self): rotation = plt.gca().get_xticklabels()[0].get_rotation() assert rotation != 0 + def test_xyincrease_false_changes_axes(self): + self.darray.plot.line(xincrease=False, yincrease=False) + xlim = plt.gca().get_xlim() + ylim = plt.gca().get_ylim() + diffs = xlim[1] - xlim[0], ylim[1] - ylim[0] + assert all(x < 0 for x in diffs) + def test_slice_in_title(self): self.darray.coords['d'] = 10 self.darray.plot.line() From 6d8ac11ca0a785a6fe176eeca9b735c321a35527 Mon Sep 17 00:00:00 2001 From: Ryan May Date: Thu, 10 May 2018 11:49:59 -0600 Subject: [PATCH 05/61] Fix docstring formatting for load(). (#2115) Need '::' to introduce a code literal block. This was causing MetPy's doc build to warn (since we inherit AbstractDataStore). --- xarray/backends/common.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xarray/backends/common.py b/xarray/backends/common.py index c46f9d5b552..7d8aa8446a2 100644 --- a/xarray/backends/common.py +++ b/xarray/backends/common.py @@ -198,7 +198,7 @@ def load(self): A centralized loading function makes it easier to create data stores that do automatic encoding/decoding. 
- For example: + For example:: class SuffixAppendingDataStore(AbstractDataStore): From d63001cdbc3bd84f4d6d90bd570a2215ea9e5c2e Mon Sep 17 00:00:00 2001 From: Keisuke Fujii Date: Sat, 12 May 2018 07:54:43 +0900 Subject: [PATCH 06/61] Support keep_attrs for apply_ufunc for xr.Variable (#2119) * Support keep_attrs for apply_ufunc for xr.Dataset, xr.Variable * whats new * whats new again * improve doc --- doc/whats-new.rst | 3 +++ xarray/core/computation.py | 21 +++++++++++---------- xarray/tests/test_computation.py | 7 +++++++ 3 files changed, 21 insertions(+), 10 deletions(-) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 3c2e143bec3..b177fc702c0 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -49,6 +49,9 @@ Enhancements Bug fixes ~~~~~~~~~ +- Fixed a bug where `keep_attrs=True` flag was neglected if + :py:func:`apply_func` was used with :py:class:`Variable`. (:issue:`2114`) + By `Keisuke Fujii `_. - When assigning a :py:class:`DataArray` to :py:class:`Dataset`, any conflicted non-dimensional coordinates of the DataArray are now dropped. (:issue:`2068`) diff --git a/xarray/core/computation.py b/xarray/core/computation.py index f06e90b583b..77a52ac055d 100644 --- a/xarray/core/computation.py +++ b/xarray/core/computation.py @@ -195,7 +195,6 @@ def apply_dataarray_ufunc(func, *args, **kwargs): signature = kwargs.pop('signature') join = kwargs.pop('join', 'inner') exclude_dims = kwargs.pop('exclude_dims', _DEFAULT_FROZEN_SET) - keep_attrs = kwargs.pop('keep_attrs', False) if kwargs: raise TypeError('apply_dataarray_ufunc() got unexpected keyword ' 'arguments: %s' % list(kwargs)) @@ -217,11 +216,6 @@ def apply_dataarray_ufunc(func, *args, **kwargs): coords, = result_coords out = DataArray(result_var, coords, name=name, fastpath=True) - if keep_attrs and isinstance(args[0], DataArray): - if isinstance(out, tuple): - out = tuple(ds._copy_attrs_from(args[0]) for ds in out) - else: - out._copy_attrs_from(args[0]) return out @@ -526,6 +520,7 @@ def apply_variable_ufunc(func, *args, **kwargs): dask = kwargs.pop('dask', 'forbidden') output_dtypes = kwargs.pop('output_dtypes', None) output_sizes = kwargs.pop('output_sizes', None) + keep_attrs = kwargs.pop('keep_attrs', False) if kwargs: raise TypeError('apply_variable_ufunc() got unexpected keyword ' 'arguments: %s' % list(kwargs)) @@ -567,11 +562,17 @@ def func(*arrays): if signature.num_outputs > 1: output = [] for dims, data in zip(output_dims, result_data): - output.append(Variable(dims, data)) + var = Variable(dims, data) + if keep_attrs and isinstance(args[0], Variable): + var.attrs.update(args[0].attrs) + output.append(var) return tuple(output) else: dims, = output_dims - return Variable(dims, result_data) + var = Variable(dims, result_data) + if keep_attrs and isinstance(args[0], Variable): + var.attrs.update(args[0].attrs) + return var def _apply_with_dask_atop(func, args, input_dims, output_dims, signature, @@ -902,6 +903,7 @@ def earth_mover_distance(first_samples, variables_ufunc = functools.partial(apply_variable_ufunc, func, signature=signature, exclude_dims=exclude_dims, + keep_attrs=keep_attrs, dask=dask, output_dtypes=output_dtypes, output_sizes=output_sizes) @@ -930,8 +932,7 @@ def earth_mover_distance(first_samples, return apply_dataarray_ufunc(variables_ufunc, *args, signature=signature, join=join, - exclude_dims=exclude_dims, - keep_attrs=keep_attrs) + exclude_dims=exclude_dims) elif any(isinstance(a, Variable) for a in args): return variables_ufunc(*args) else: diff --git a/xarray/tests/test_computation.py 
b/xarray/tests/test_computation.py index db10ee3e820..c84ed17bfd3 100644
--- a/xarray/tests/test_computation.py
+++ b/xarray/tests/test_computation.py
@@ -480,12 +480,19 @@ def add(a, b, keep_attrs):
 a = xr.DataArray([0, 1], [('x', [0, 1])])
 a.attrs['attr'] = 'da'
+ a['x'].attrs['attr'] = 'da_coord'
 b = xr.DataArray([1, 2], [('x', [0, 1])])
 actual = add(a, b, keep_attrs=False)
 assert not actual.attrs
 actual = add(a, b, keep_attrs=True)
 assert_identical(actual.attrs, a.attrs)
+ assert_identical(actual['x'].attrs, a['x'].attrs)
+
+ actual = add(a.variable, b.variable, keep_attrs=False)
+ assert not actual.attrs
+ actual = add(a.variable, b.variable, keep_attrs=True)
+ assert_identical(actual.attrs, a.attrs)
 a = xr.Dataset({'x': [0, 1]})
 a.attrs['attr'] = 'ds'

From a52540505f606bd7619536d82d43f19f2cbe58b5 Mon Sep 17 00:00:00 2001
From: Keisuke Fujii
Date: Sat, 12 May 2018 15:15:54 +0900
Subject: [PATCH 07/61] Fixes centerized rolling with bottleneck. Also, fixed rolling with an integer dask array. (#2122)

---
 doc/whats-new.rst | 3 +++
 xarray/core/dask_array_ops.py | 5 ++++-
 xarray/core/rolling.py | 16 ++++++++++++----
 xarray/tests/test_dataarray.py | 30 +++++++++++++++++++++++++-----
 4 files changed, 44 insertions(+), 10 deletions(-)

diff --git a/doc/whats-new.rst b/doc/whats-new.rst
index b177fc702c0..cc16991ccf1 100644
--- a/doc/whats-new.rst
+++ b/doc/whats-new.rst
@@ -49,6 +49,9 @@ Enhancements
 Bug fixes
 ~~~~~~~~~
+- Fixed a bug in `rolling` with bottleneck. Also, fixed a bug in rolling an
+ integer dask array. (:issue:`2113`)
+ By `Keisuke Fujii `_.
 - Fixed a bug where `keep_attrs=True` flag was neglected if
 :py:func:`apply_func` was used with :py:class:`Variable`. (:issue:`2114`)
 By `Keisuke Fujii `_.
diff --git a/xarray/core/dask_array_ops.py b/xarray/core/dask_array_ops.py
index 4bd3766ced9..ee87c3564cc 100644
--- a/xarray/core/dask_array_ops.py
+++ b/xarray/core/dask_array_ops.py
@@ -3,6 +3,7 @@
 import numpy as np
 from . import nputils
+from . import dtypes
 try:
 import dask.array as da
@@ -12,12 +13,14 @@ def dask_rolling_wrapper(moving_func, a, window, min_count=None, axis=-1):
 '''wrapper to apply bottleneck moving window funcs on dask arrays'''
+ dtype, fill_value = dtypes.maybe_promote(a.dtype)
+ a = a.astype(dtype)
 # inputs for ghost
 if axis < 0:
 axis = a.ndim + axis
 depth = {d: 0 for d in range(a.ndim)}
 depth[axis] = window - 1
- boundary = {d: np.nan for d in range(a.ndim)}
+ boundary = {d: fill_value for d in range(a.ndim)}
 # create ghosted arrays
 ag = da.ghost.ghost(a, depth=depth, boundary=boundary)
 # apply rolling func
diff --git a/xarray/core/rolling.py b/xarray/core/rolling.py
index 079c60f35a7..f54a4c36631 100644
--- a/xarray/core/rolling.py
+++ b/xarray/core/rolling.py
@@ -285,18 +285,26 @@ def wrapped_func(self, **kwargs):
 padded = self.obj.variable
 if self.center:
- shift = (-self.window // 2) + 1
-
 if (LooseVersion(np.__version__) < LooseVersion('1.13') and
 self.obj.dtype.kind == 'b'):
 # with numpy < 1.13 bottleneck cannot handle np.nan-Boolean
 # mixed array correctly. We cast boolean array to float.
padded = padded.astype(float) + + if isinstance(padded.data, dask_array_type): + # Workaround to make the padded chunk size is larger than + # self.window-1 + shift = - (self.window - 1) + offset = -shift - self.window // 2 + valid = (slice(None), ) * axis + ( + slice(offset, offset + self.obj.shape[axis]), ) + else: + shift = (-self.window // 2) + 1 + valid = (slice(None), ) * axis + (slice(-shift, None), ) padded = padded.pad_with_fill_value(**{self.dim: (0, -shift)}) - valid = (slice(None), ) * axis + (slice(-shift, None), ) if isinstance(padded.data, dask_array_type): - values = dask_rolling_wrapper(func, self.obj.data, + values = dask_rolling_wrapper(func, padded, window=self.window, min_count=min_count, axis=axis) diff --git a/xarray/tests/test_dataarray.py b/xarray/tests/test_dataarray.py index 32ab3a634cb..e9a2babfa2e 100644 --- a/xarray/tests/test_dataarray.py +++ b/xarray/tests/test_dataarray.py @@ -3439,23 +3439,43 @@ def test_rolling_wrapped_bottleneck(da, name, center, min_periods): assert_equal(actual, da['time']) -@pytest.mark.parametrize('name', ('sum', 'mean', 'std', 'min', 'max', - 'median')) +@pytest.mark.parametrize('name', ('mean', 'count')) @pytest.mark.parametrize('center', (True, False, None)) @pytest.mark.parametrize('min_periods', (1, None)) -def test_rolling_wrapped_bottleneck_dask(da_dask, name, center, min_periods): +@pytest.mark.parametrize('window', (7, 8)) +def test_rolling_wrapped_dask(da_dask, name, center, min_periods, window): pytest.importorskip('dask.array') # dask version - rolling_obj = da_dask.rolling(time=7, min_periods=min_periods) + rolling_obj = da_dask.rolling(time=window, min_periods=min_periods, + center=center) actual = getattr(rolling_obj, name)().load() # numpy version - rolling_obj = da_dask.load().rolling(time=7, min_periods=min_periods) + rolling_obj = da_dask.load().rolling(time=window, min_periods=min_periods, + center=center) expected = getattr(rolling_obj, name)() # using all-close because rolling over ghost cells introduces some # precision errors assert_allclose(actual, expected) + # with zero chunked array GH:2113 + rolling_obj = da_dask.chunk().rolling(time=window, min_periods=min_periods, + center=center) + actual = getattr(rolling_obj, name)().load() + assert_allclose(actual, expected) + + +@pytest.mark.parametrize('center', (True, None)) +def test_rolling_wrapped_dask_nochunk(center): + # GH:2113 + pytest.importorskip('dask.array') + + da_day_clim = xr.DataArray(np.arange(1, 367), + coords=[np.arange(1, 367)], dims='dayofyear') + expected = da_day_clim.rolling(dayofyear=31, center=center).mean() + actual = da_day_clim.chunk().rolling(dayofyear=31, center=center).mean() + assert_allclose(actual, expected) + @pytest.mark.parametrize('center', (True, False)) @pytest.mark.parametrize('min_periods', (None, 1, 2, 3)) From 2c6bd2d1b09a84488ab1f1ebffa9cd359d0437ce Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Fri, 11 May 2018 23:36:36 -0700 Subject: [PATCH 08/61] Prevent Inf from screwing colorbar scale. (#2120) pd.isnull([np.inf]) is True while np.isfinite([np.inf]) is False. Let's use the latter. --- doc/whats-new.rst | 2 ++ xarray/plot/utils.py | 2 +- xarray/tests/test_plot.py | 9 +++++++++ 3 files changed, 12 insertions(+), 1 deletion(-) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index cc16991ccf1..cc5506b553c 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -65,6 +65,8 @@ Bug fixes By `Deepak Cherian `_. 
- ``plot.line()`` learned new kwargs: ``xincrease``, ``yincrease`` that change the direction of the respective axes. By `Deepak Cherian `_. +- Colorbar limits are now determined by excluding ±Infs too. + By `Deepak Cherian `_. .. _whats-new.0.10.3: diff --git a/xarray/plot/utils.py b/xarray/plot/utils.py index 3db8bcab3a7..7ba48819518 100644 --- a/xarray/plot/utils.py +++ b/xarray/plot/utils.py @@ -160,7 +160,7 @@ def _determine_cmap_params(plot_data, vmin=None, vmax=None, cmap=None, """ import matplotlib as mpl - calc_data = np.ravel(plot_data[~pd.isnull(plot_data)]) + calc_data = np.ravel(plot_data[np.isfinite(plot_data)]) # Handle all-NaN input data gracefully if calc_data.size == 0: diff --git a/xarray/tests/test_plot.py b/xarray/tests/test_plot.py index a3446fe240b..aadc452b8a7 100644 --- a/xarray/tests/test_plot.py +++ b/xarray/tests/test_plot.py @@ -426,6 +426,15 @@ def test_center(self): assert cmap_params['levels'] is None assert cmap_params['norm'] is None + def test_nan_inf_are_ignored(self): + cmap_params1 = _determine_cmap_params(self.data) + data = self.data + data[50:55] = np.nan + data[56:60] = np.inf + cmap_params2 = _determine_cmap_params(data) + assert cmap_params1['vmin'] == cmap_params2['vmin'] + assert cmap_params1['vmax'] == cmap_params2['vmax'] + @pytest.mark.slow def test_integer_levels(self): data = self.data + 1 From ebe0dd03187a5c3138ea12ca4beb13643679fe21 Mon Sep 17 00:00:00 2001 From: Spencer Clark Date: Sun, 13 May 2018 01:19:09 -0400 Subject: [PATCH 09/61] CFTimeIndex (#1252) * Start on implementing and testing NetCDFTimeIndex * TST Move to using pytest fixtures to structure tests * Address initial review comments * Address second round of review comments * Fix failing python3 tests * Match test method name to method name * First attempts at integrating NetCDFTimeIndex into xarray This is a first pass at the following: - Resetting the logic for decoding datetimes such that `np.datetime64` objects are never used for non-standard calendars - Adding logic to use a `NetCDFTimeIndex` whenever `netcdftime.datetime` objects are used in an array being cast as an index (so if one reads in a Dataset from a netCDF file or creates one in Python, which is indexed by a time coordinate that uses `netcdftime.datetime` objects a NetCDFTimeIndex will be used rather than a generic object-based index) - Adding logic to encode `netcdftime.datetime` objects when saving out to netCDF files * Cleanup * Fix DataFrame and Series test failures for NetCDFTimeIndex These were related to a recent minor upstream change in pandas: https://github.com/pandas-dev/pandas/blame/master/pandas/core/indexing.py#L1433 * First pass at making NetCDFTimeIndex compatible with #1356 * Address initial review comments * Restore test_conventions.py * Fix failing test in test_utils.py * flake8 * Update for standalone netcdftime * Address stickler-ci comments * Skip test_format_netcdftime_datetime if netcdftime not installed * A start on documentation * Fix failing zarr tests related to netcdftime encoding * Simplify test_decode_standard_calendar_single_element_non_ns_range * Address a couple review comments * Use else clause in _maybe_cast_to_netcdftimeindex * Start on adding enable_netcdftimeindex option * Continue parametrizing tests in test_coding_times.py * Update time-series.rst for enable_netcdftimeindex option * Use :py:func: in rst for xarray.set_options * Add a what's new entry and test that resample raises a TypeError * Move what's new entry to the version 0.10.3 section * Add 
version-dependent pathway for importing netcdftime.datetime * Make NetCDFTimeIndex and date decoding/encoding compatible with datetime.datetime * Remove logic to make NetCDFTimeIndex compatible with datetime.datetime * Documentation edits * Ensure proper enable_netcdftimeindex option is used under lazy decoding Prior to this, opening a dataset with enable_netcdftimeindex set to True and then accessing one of its variables outside the context manager would lead to it being decoded with the default enable_netcdftimeindex (which is False). This makes sure that lazy decoding takes into account the context under which it was called. * Add fix and test for concatenating variables with a NetCDFTimeIndex Previously when concatenating variables indexed by a NetCDFTimeIndex the index would be wrongly converted to a generic pd.Index * Further namespace changes due to netcdftime/cftime renaming * NetCDFTimeIndex -> CFTimeIndex * Documentation updates * Only allow use of CFTimeIndex when using the standalone cftime Also only allow for serialization of cftime.datetime objects when using the standalone cftime package. * Fix errant what's new changes * flake8 * Fix skip logic in test_cftimeindex.py * Use only_use_cftime_datetimes option in num2date * Require standalone cftime library for all new functionality Add tests/fixes for dt accessor with cftime datetimes * Improve skipping logic in test_cftimeindex.py * Fix skipping logic in test_cftimeindex.py for when cftime or netcdftime are not available. Use existing requires_cftime decorator where possible (i.e. only on tests that are not parametrized via pytest.mark.parametrize) * Fix skip logic in Python 3.4 build for test_cftimeindex.py * Improve error messages when for when the standalone cftime is not installed * Tweak skip logic in test_accessors.py * flake8 * Address review comments * Temporarily remove cftime from py27 build environment on windows * flake8 * Install cftime via pip for Python 2.7 on Windows * flake8 * Remove unnecessary new lines; simplify _maybe_cast_to_cftimeindex * Restore test case for #2002 in test_coding_times.py I must have inadvertently removed it during a merge. * Tweak dates out of range warning logic slightly to preserve current default * Address review comments --- doc/time-series.rst | 96 ++- doc/whats-new.rst | 10 + xarray/coding/cftimeindex.py | 252 +++++++ xarray/coding/times.py | 127 ++-- xarray/core/accessors.py | 33 +- xarray/core/common.py | 31 + xarray/core/dataset.py | 5 +- xarray/core/options.py | 4 + xarray/core/utils.py | 18 +- xarray/plot/plot.py | 8 +- xarray/tests/test_accessors.py | 118 +++- xarray/tests/test_backends.py | 76 ++- xarray/tests/test_cftimeindex.py | 555 ++++++++++++++++ xarray/tests/test_coding_times.py | 1019 ++++++++++++++++++++--------- xarray/tests/test_dataarray.py | 19 +- xarray/tests/test_plot.py | 24 +- xarray/tests/test_utils.py | 67 +- 17 files changed, 2095 insertions(+), 367 deletions(-) create mode 100644 xarray/coding/cftimeindex.py create mode 100644 xarray/tests/test_cftimeindex.py diff --git a/doc/time-series.rst b/doc/time-series.rst index afd9f087bfe..5b857789629 100644 --- a/doc/time-series.rst +++ b/doc/time-series.rst @@ -70,7 +70,11 @@ You can manual decode arrays in this form by passing a dataset to One unfortunate limitation of using ``datetime64[ns]`` is that it limits the native representation of dates to those that fall between the years 1678 and 2262. 
When a netCDF file contains dates outside of these bounds, dates will be -returned as arrays of ``netcdftime.datetime`` objects. +returned as arrays of ``cftime.datetime`` objects and a ``CFTimeIndex`` +can be used for indexing. The ``CFTimeIndex`` enables only a subset of +the indexing functionality of a ``pandas.DatetimeIndex`` and is only enabled +when using standalone version of ``cftime`` (not the version packaged with +earlier versions ``netCDF4``). See :ref:`CFTimeIndex` for more information. Datetime indexing ----------------- @@ -207,3 +211,93 @@ Dataset and DataArray objects with an arbitrary number of dimensions. For more examples of using grouped operations on a time dimension, see :ref:`toy weather data`. + + +.. _CFTimeIndex: + +Non-standard calendars and dates outside the Timestamp-valid range +------------------------------------------------------------------ + +Through the standalone ``cftime`` library and a custom subclass of +``pandas.Index``, xarray supports a subset of the indexing functionality enabled +through the standard ``pandas.DatetimeIndex`` for dates from non-standard +calendars or dates using a standard calendar, but outside the +`Timestamp-valid range`_ (approximately between years 1678 and 2262). This +behavior has not yet been turned on by default; to take advantage of this +functionality, you must have the ``enable_cftimeindex`` option set to +``True`` within your context (see :py:func:`~xarray.set_options` for more +information). It is expected that this will become the default behavior in +xarray version 0.11. + +For instance, you can create a DataArray indexed by a time +coordinate with a no-leap calendar within a context manager setting the +``enable_cftimeindex`` option, and the time index will be cast to a +``CFTimeIndex``: + +.. ipython:: python + + from itertools import product + from cftime import DatetimeNoLeap + + dates = [DatetimeNoLeap(year, month, 1) for year, month in + product(range(1, 3), range(1, 13))] + with xr.set_options(enable_cftimeindex=True): + da = xr.DataArray(np.arange(24), coords=[dates], dims=['time'], + name='foo') + +.. note:: + + With the ``enable_cftimeindex`` option activated, a ``CFTimeIndex`` + will be used for time indexing if any of the following are true: + + - The dates are from a non-standard calendar + - Any dates are outside the Timestamp-valid range + + Otherwise a ``pandas.DatetimeIndex`` will be used. In addition, if any + variable (not just an index variable) is encoded using a non-standard + calendar, its times will be decoded into ``cftime.datetime`` objects, + regardless of whether or not they can be represented using + ``np.datetime64[ns]`` objects. + +For data indexed by a ``CFTimeIndex`` xarray currently supports: + +- `Partial datetime string indexing`_ using strictly `ISO 8601-format`_ partial + datetime strings: + +.. ipython:: python + + da.sel(time='0001') + da.sel(time=slice('0001-05', '0002-02')) + +- Access of basic datetime components via the ``dt`` accessor (in this case + just "year", "month", "day", "hour", "minute", "second", "microsecond", and + "season"): + +.. ipython:: python + + da.time.dt.year + da.time.dt.month + da.time.dt.season + +- Group-by operations based on datetime accessor attributes (e.g. by month of + the year): + +.. ipython:: python + + da.groupby('time.month').sum() + +- And serialization: + +.. ipython:: python + + da.to_netcdf('example.nc') + xr.open_dataset('example.nc') + +.. 
note:: + + Currently resampling along the time dimension for data indexed by a + ``CFTimeIndex`` is not supported. + +.. _Timestamp-valid range: https://pandas.pydata.org/pandas-docs/stable/timeseries.html#timestamp-limitations +.. _ISO 8601-format: https://en.wikipedia.org/wiki/ISO_8601 +.. _partial datetime string indexing: https://pandas.pydata.org/pandas-docs/stable/timeseries.html#partial-string-indexing diff --git a/doc/whats-new.rst b/doc/whats-new.rst index cc5506b553c..fc5f8bf3266 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -34,6 +34,16 @@ v0.10.4 (unreleased) Enhancements ~~~~~~~~~~~~ +- Add an option for using a ``CFTimeIndex`` for indexing times with + non-standard calendars and/or outside the Timestamp-valid range; this index + enables a subset of the functionality of a standard + ``pandas.DatetimeIndex`` (:issue:`789`, :issue:`1084`, :issue:`1252`). + By `Spencer Clark `_ with help from + `Stephan Hoyer `_. +- Allow for serialization of ``cftime.datetime`` objects (:issue:`789`, + :issue:`1084`, :issue:`2008`, :issue:`1252`) using the standalone ``cftime`` + library. By `Spencer Clark + `_. - Support writing lists of strings as netCDF attributes (:issue:`2044`). By `Dan Nowacki `_. - :py:meth:`~xarray.Dataset.to_netcdf(engine='h5netcdf')` now accepts h5py diff --git a/xarray/coding/cftimeindex.py b/xarray/coding/cftimeindex.py new file mode 100644 index 00000000000..fb51ace5d69 --- /dev/null +++ b/xarray/coding/cftimeindex.py @@ -0,0 +1,252 @@ +from __future__ import absolute_import +import re +from datetime import timedelta + +import numpy as np +import pandas as pd + +from xarray.core import pycompat +from xarray.core.utils import is_scalar + + +def named(name, pattern): + return '(?P<' + name + '>' + pattern + ')' + + +def optional(x): + return '(?:' + x + ')?' + + +def trailing_optional(xs): + if not xs: + return '' + return xs[0] + optional(trailing_optional(xs[1:])) + + +def build_pattern(date_sep='\-', datetime_sep='T', time_sep='\:'): + pieces = [(None, 'year', '\d{4}'), + (date_sep, 'month', '\d{2}'), + (date_sep, 'day', '\d{2}'), + (datetime_sep, 'hour', '\d{2}'), + (time_sep, 'minute', '\d{2}'), + (time_sep, 'second', '\d{2}')] + pattern_list = [] + for sep, name, sub_pattern in pieces: + pattern_list.append((sep if sep else '') + named(name, sub_pattern)) + # TODO: allow timezone offsets? + return '^' + trailing_optional(pattern_list) + '$' + + +_BASIC_PATTERN = build_pattern(date_sep='', time_sep='') +_EXTENDED_PATTERN = build_pattern() +_PATTERNS = [_BASIC_PATTERN, _EXTENDED_PATTERN] + + +def parse_iso8601(datetime_string): + for pattern in _PATTERNS: + match = re.match(pattern, datetime_string) + if match: + return match.groupdict() + raise ValueError('no ISO-8601 match for string: %s' % datetime_string) + + +def _parse_iso8601_with_reso(date_type, timestr): + default = date_type(1, 1, 1) + result = parse_iso8601(timestr) + replace = {} + + for attr in ['year', 'month', 'day', 'hour', 'minute', 'second']: + value = result.get(attr, None) + if value is not None: + # Note ISO8601 conventions allow for fractional seconds. + # TODO: Consider adding support for sub-second resolution? + replace[attr] = int(value) + resolution = attr + + return default.replace(**replace), resolution + + +def _parsed_string_to_bounds(date_type, resolution, parsed): + """Generalization of + pandas.tseries.index.DatetimeIndex._parsed_string_to_bounds + for use with non-standard calendars and cftime.datetime + objects. 
+ """ + if resolution == 'year': + return (date_type(parsed.year, 1, 1), + date_type(parsed.year + 1, 1, 1) - timedelta(microseconds=1)) + elif resolution == 'month': + if parsed.month == 12: + end = date_type(parsed.year + 1, 1, 1) - timedelta(microseconds=1) + else: + end = (date_type(parsed.year, parsed.month + 1, 1) - + timedelta(microseconds=1)) + return date_type(parsed.year, parsed.month, 1), end + elif resolution == 'day': + start = date_type(parsed.year, parsed.month, parsed.day) + return start, start + timedelta(days=1, microseconds=-1) + elif resolution == 'hour': + start = date_type(parsed.year, parsed.month, parsed.day, parsed.hour) + return start, start + timedelta(hours=1, microseconds=-1) + elif resolution == 'minute': + start = date_type(parsed.year, parsed.month, parsed.day, parsed.hour, + parsed.minute) + return start, start + timedelta(minutes=1, microseconds=-1) + elif resolution == 'second': + start = date_type(parsed.year, parsed.month, parsed.day, parsed.hour, + parsed.minute, parsed.second) + return start, start + timedelta(seconds=1, microseconds=-1) + else: + raise KeyError + + +def get_date_field(datetimes, field): + """Adapted from pandas.tslib.get_date_field""" + return np.array([getattr(date, field) for date in datetimes]) + + +def _field_accessor(name, docstring=None): + """Adapted from pandas.tseries.index._field_accessor""" + def f(self): + return get_date_field(self._data, name) + + f.__name__ = name + f.__doc__ = docstring + return property(f) + + +def get_date_type(self): + return type(self._data[0]) + + +def assert_all_valid_date_type(data): + import cftime + + sample = data[0] + date_type = type(sample) + if not isinstance(sample, cftime.datetime): + raise TypeError( + 'CFTimeIndex requires cftime.datetime ' + 'objects. Got object of {}.'.format(date_type)) + if not all(isinstance(value, date_type) for value in data): + raise TypeError( + 'CFTimeIndex requires using datetime ' + 'objects of all the same type. Got\n{}.'.format(data)) + + +class CFTimeIndex(pd.Index): + year = _field_accessor('year', 'The year of the datetime') + month = _field_accessor('month', 'The month of the datetime') + day = _field_accessor('day', 'The days of the datetime') + hour = _field_accessor('hour', 'The hours of the datetime') + minute = _field_accessor('minute', 'The minutes of the datetime') + second = _field_accessor('second', 'The seconds of the datetime') + microsecond = _field_accessor('microsecond', + 'The microseconds of the datetime') + date_type = property(get_date_type) + + def __new__(cls, data): + result = object.__new__(cls) + assert_all_valid_date_type(data) + result._data = np.array(data) + return result + + def _partial_date_slice(self, resolution, parsed): + """Adapted from + pandas.tseries.index.DatetimeIndex._partial_date_slice + + Note that when using a CFTimeIndex, if a partial-date selection + returns a single element, it will never be converted to a scalar + coordinate; this is in slight contrast to the behavior when using + a DatetimeIndex, which sometimes will return a DataArray with a scalar + coordinate depending on the resolution of the datetimes used in + defining the index. 
For example: + + >>> from cftime import DatetimeNoLeap + >>> import pandas as pd + >>> import xarray as xr + >>> da = xr.DataArray([1, 2], + coords=[[DatetimeNoLeap(2001, 1, 1), + DatetimeNoLeap(2001, 2, 1)]], + dims=['time']) + >>> da.sel(time='2001-01-01') + + array([1]) + Coordinates: + * time (time) object 2001-01-01 00:00:00 + >>> da = xr.DataArray([1, 2], + coords=[[pd.Timestamp(2001, 1, 1), + pd.Timestamp(2001, 2, 1)]], + dims=['time']) + >>> da.sel(time='2001-01-01') + + array(1) + Coordinates: + time datetime64[ns] 2001-01-01 + >>> da = xr.DataArray([1, 2], + coords=[[pd.Timestamp(2001, 1, 1, 1), + pd.Timestamp(2001, 2, 1)]], + dims=['time']) + >>> da.sel(time='2001-01-01') + + array([1]) + Coordinates: + * time (time) datetime64[ns] 2001-01-01T01:00:00 + """ + start, end = _parsed_string_to_bounds(self.date_type, resolution, + parsed) + lhs_mask = (self._data >= start) + rhs_mask = (self._data <= end) + return (lhs_mask & rhs_mask).nonzero()[0] + + def _get_string_slice(self, key): + """Adapted from pandas.tseries.index.DatetimeIndex._get_string_slice""" + parsed, resolution = _parse_iso8601_with_reso(self.date_type, key) + loc = self._partial_date_slice(resolution, parsed) + return loc + + def get_loc(self, key, method=None, tolerance=None): + """Adapted from pandas.tseries.index.DatetimeIndex.get_loc""" + if isinstance(key, pycompat.basestring): + return self._get_string_slice(key) + else: + return pd.Index.get_loc(self, key, method=method, + tolerance=tolerance) + + def _maybe_cast_slice_bound(self, label, side, kind): + """Adapted from + pandas.tseries.index.DatetimeIndex._maybe_cast_slice_bound""" + if isinstance(label, pycompat.basestring): + parsed, resolution = _parse_iso8601_with_reso(self.date_type, + label) + start, end = _parsed_string_to_bounds(self.date_type, resolution, + parsed) + if self.is_monotonic_decreasing and len(self): + return end if side == 'left' else start + return start if side == 'left' else end + else: + return label + + # TODO: Add ability to use integer range outside of iloc? + # e.g. series[1:5]. 
+ def get_value(self, series, key): + """Adapted from pandas.tseries.index.DatetimeIndex.get_value""" + if not isinstance(key, slice): + return series.iloc[self.get_loc(key)] + else: + return series.iloc[self.slice_indexer( + key.start, key.stop, key.step)] + + def __contains__(self, key): + """Adapted from + pandas.tseries.base.DatetimeIndexOpsMixin.__contains__""" + try: + result = self.get_loc(key) + return (is_scalar(result) or type(result) == slice or + (isinstance(result, np.ndarray) and result.size)) + except (KeyError, TypeError, ValueError): + return False + + def contains(self, key): + """Needed for .loc based partial-string indexing""" + return self.__contains__(key) diff --git a/xarray/coding/times.py b/xarray/coding/times.py index 0a48b62986e..61314d9cbe6 100644 --- a/xarray/coding/times.py +++ b/xarray/coding/times.py @@ -9,8 +9,10 @@ import numpy as np import pandas as pd +from ..core.common import contains_cftime_datetimes from ..core import indexing from ..core.formatting import first_n_items, format_timestamp, last_item +from ..core.options import OPTIONS from ..core.pycompat import PY3 from ..core.variable import Variable from .variables import ( @@ -24,7 +26,7 @@ from pandas.tslib import OutOfBoundsDatetime -# standard calendars recognized by netcdftime +# standard calendars recognized by cftime _STANDARD_CALENDARS = set(['standard', 'gregorian', 'proleptic_gregorian']) _NS_PER_TIME_DELTA = {'us': int(1e3), @@ -54,6 +56,15 @@ def _import_cftime(): return cftime +def _require_standalone_cftime(): + """Raises an ImportError if the standalone cftime is not found""" + try: + import cftime # noqa: F401 + except ImportError: + raise ImportError('Using a CFTimeIndex requires the standalone ' + 'version of the cftime library.') + + def _netcdf_to_numpy_timeunit(units): units = units.lower() if not units.endswith('s'): @@ -73,28 +84,41 @@ def _unpack_netcdf_time_units(units): return delta_units, ref_date -def _decode_datetime_with_netcdftime(num_dates, units, calendar): +def _decode_datetime_with_cftime(num_dates, units, calendar, + enable_cftimeindex): cftime = _import_cftime() + if enable_cftimeindex: + _require_standalone_cftime() + dates = np.asarray(cftime.num2date(num_dates, units, calendar, + only_use_cftime_datetimes=True)) + else: + dates = np.asarray(cftime.num2date(num_dates, units, calendar)) - dates = np.asarray(cftime.num2date(num_dates, units, calendar)) if (dates[np.nanargmin(num_dates)].year < 1678 or dates[np.nanargmax(num_dates)].year >= 2262): - warnings.warn('Unable to decode time axis into full ' - 'numpy.datetime64 objects, continuing using dummy ' - 'netcdftime.datetime objects instead, reason: dates out' - ' of range', SerializationWarning, stacklevel=3) + if not enable_cftimeindex or calendar in _STANDARD_CALENDARS: + warnings.warn( + 'Unable to decode time axis into full ' + 'numpy.datetime64 objects, continuing using dummy ' + 'cftime.datetime objects instead, reason: dates out ' + 'of range', SerializationWarning, stacklevel=3) else: - try: - dates = cftime_to_nptime(dates) - except ValueError as e: - warnings.warn('Unable to decode time axis into full ' - 'numpy.datetime64 objects, continuing using ' - 'dummy netcdftime.datetime objects instead, reason:' - '{0}'.format(e), SerializationWarning, stacklevel=3) + if enable_cftimeindex: + if calendar in _STANDARD_CALENDARS: + dates = cftime_to_nptime(dates) + else: + try: + dates = cftime_to_nptime(dates) + except ValueError as e: + warnings.warn( + 'Unable to decode time axis into full ' + 
'numpy.datetime64 objects, continuing using ' + 'dummy cftime.datetime objects instead, reason:' + '{0}'.format(e), SerializationWarning, stacklevel=3) return dates -def _decode_cf_datetime_dtype(data, units, calendar): +def _decode_cf_datetime_dtype(data, units, calendar, enable_cftimeindex): # Verify that at least the first and last date can be decoded # successfully. Otherwise, tracebacks end up swallowed by # Dataset.__repr__ when users try to view their lazily decoded array. @@ -104,7 +128,8 @@ def _decode_cf_datetime_dtype(data, units, calendar): last_item(values) or [0]]) try: - result = decode_cf_datetime(example_value, units, calendar) + result = decode_cf_datetime(example_value, units, calendar, + enable_cftimeindex) except Exception: calendar_msg = ('the default calendar' if calendar is None else 'calendar %r' % calendar) @@ -120,12 +145,13 @@ def _decode_cf_datetime_dtype(data, units, calendar): return dtype -def decode_cf_datetime(num_dates, units, calendar=None): +def decode_cf_datetime(num_dates, units, calendar=None, + enable_cftimeindex=False): """Given an array of numeric dates in netCDF format, convert it into a numpy array of date time objects. For standard (Gregorian) calendars, this function uses vectorized - operations, which makes it much faster than netcdftime.num2date. In such a + operations, which makes it much faster than cftime.num2date. In such a case, the returned array will be of type np.datetime64. Note that time unit in `units` must not be smaller than microseconds and @@ -133,7 +159,7 @@ def decode_cf_datetime(num_dates, units, calendar=None): See also -------- - netcdftime.num2date + cftime.num2date """ num_dates = np.asarray(num_dates) flat_num_dates = num_dates.ravel() @@ -151,7 +177,7 @@ def decode_cf_datetime(num_dates, units, calendar=None): ref_date = pd.Timestamp(ref_date) except ValueError: # ValueError is raised by pd.Timestamp for non-ISO timestamp - # strings, in which case we fall back to using netcdftime + # strings, in which case we fall back to using cftime raise OutOfBoundsDatetime # fixes: https://github.com/pydata/pandas/issues/14068 @@ -170,8 +196,9 @@ def decode_cf_datetime(num_dates, units, calendar=None): ref_date).values except (OutOfBoundsDatetime, OverflowError): - dates = _decode_datetime_with_netcdftime( - flat_num_dates.astype(np.float), units, calendar) + dates = _decode_datetime_with_cftime( + flat_num_dates.astype(np.float), units, calendar, + enable_cftimeindex) return dates.reshape(num_dates.shape) @@ -203,18 +230,41 @@ def _infer_time_units_from_diff(unique_timedeltas): return 'seconds' +def infer_calendar_name(dates): + """Given an array of datetimes, infer the CF calendar name""" + if np.asarray(dates).dtype == 'datetime64[ns]': + return 'proleptic_gregorian' + else: + return np.asarray(dates).ravel()[0].calendar + + def infer_datetime_units(dates): """Given an array of datetimes, returns a CF compatible time-unit string of the form "{time_unit} since {date[0]}", where `time_unit` is 'days', 'hours', 'minutes' or 'seconds' (the first one that can evenly divide all unique time deltas in `dates`) """ - dates = pd.to_datetime(np.asarray(dates).ravel(), box=False) - dates = dates[pd.notnull(dates)] - unique_timedeltas = np.unique(np.diff(dates)) + dates = np.asarray(dates).ravel() + if np.asarray(dates).dtype == 'datetime64[ns]': + dates = pd.to_datetime(dates, box=False) + dates = dates[pd.notnull(dates)] + reference_date = dates[0] if len(dates) > 0 else '1970-01-01' + reference_date = 
pd.Timestamp(reference_date) + else: + reference_date = dates[0] if len(dates) > 0 else '1970-01-01' + reference_date = format_cftime_datetime(reference_date) + unique_timedeltas = np.unique(np.diff(dates)).astype('timedelta64[ns]') units = _infer_time_units_from_diff(unique_timedeltas) - reference_date = dates[0] if len(dates) > 0 else '1970-01-01' - return '%s since %s' % (units, pd.Timestamp(reference_date)) + return '%s since %s' % (units, reference_date) + + +def format_cftime_datetime(date): + """Converts a cftime.datetime object to a string with the format: + YYYY-MM-DD HH:MM:SS.UUUUUU + """ + return '{:04d}-{:02d}-{:02d} {:02d}:{:02d}:{:02d}.{:06d}'.format( + date.year, date.month, date.day, date.hour, date.minute, date.second, + date.microsecond) def infer_timedelta_units(deltas): @@ -249,8 +299,8 @@ def _cleanup_netcdf_time_units(units): return units -def _encode_datetime_with_netcdftime(dates, units, calendar): - """Fallback method for encoding dates using netcdftime. +def _encode_datetime_with_cftime(dates, units, calendar): + """Fallback method for encoding dates using cftime. This method is more flexible than xarray's parsing using datetime64[ns] arrays but also slower because it loops over each element. @@ -282,7 +332,7 @@ def encode_cf_datetime(dates, units=None, calendar=None): See also -------- - netcdftime.date2num + cftime.date2num """ dates = np.asarray(dates) @@ -292,12 +342,12 @@ def encode_cf_datetime(dates, units=None, calendar=None): units = _cleanup_netcdf_time_units(units) if calendar is None: - calendar = 'proleptic_gregorian' + calendar = infer_calendar_name(dates) delta, ref_date = _unpack_netcdf_time_units(units) try: if calendar not in _STANDARD_CALENDARS or dates.dtype.kind == 'O': - # parse with netcdftime instead + # parse with cftime instead raise OutOfBoundsDatetime assert dates.dtype == 'datetime64[ns]' @@ -307,7 +357,7 @@ def encode_cf_datetime(dates, units=None, calendar=None): num = (dates - ref_date) / time_delta except (OutOfBoundsDatetime, OverflowError): - num = _encode_datetime_with_netcdftime(dates, units, calendar) + num = _encode_datetime_with_cftime(dates, units, calendar) num = cast_to_int_if_safe(num) return (num, units, calendar) @@ -328,8 +378,8 @@ class CFDatetimeCoder(VariableCoder): def encode(self, variable, name=None): dims, data, attrs, encoding = unpack_for_encoding(variable) - - if np.issubdtype(data.dtype, np.datetime64): + if (np.issubdtype(data.dtype, np.datetime64) or + contains_cftime_datetimes(variable)): (data, units, calendar) = encode_cf_datetime( data, encoding.pop('units', None), @@ -342,12 +392,15 @@ def encode(self, variable, name=None): def decode(self, variable, name=None): dims, data, attrs, encoding = unpack_for_decoding(variable) + enable_cftimeindex = OPTIONS['enable_cftimeindex'] if 'units' in attrs and 'since' in attrs['units']: units = pop_to(attrs, encoding, 'units') calendar = pop_to(attrs, encoding, 'calendar') - dtype = _decode_cf_datetime_dtype(data, units, calendar) + dtype = _decode_cf_datetime_dtype( + data, units, calendar, enable_cftimeindex) transform = partial( - decode_cf_datetime, units=units, calendar=calendar) + decode_cf_datetime, units=units, calendar=calendar, + enable_cftimeindex=enable_cftimeindex) data = lazy_elemwise_func(data, transform, dtype) return Variable(dims, data, attrs, encoding) diff --git a/xarray/core/accessors.py b/xarray/core/accessors.py index 52d9e6db408..81af0532d93 100644 --- a/xarray/core/accessors.py +++ b/xarray/core/accessors.py @@ -3,7 +3,7 @@ import numpy as 
np import pandas as pd -from .dtypes import is_datetime_like +from .common import is_np_datetime_like, _contains_datetime_like_objects from .pycompat import dask_array_type @@ -16,6 +16,20 @@ def _season_from_months(months): return seasons[(months // 3) % 4] +def _access_through_cftimeindex(values, name): + """Coerce an array of datetime-like values to a CFTimeIndex + and access requested datetime component + """ + from ..coding.cftimeindex import CFTimeIndex + values_as_cftimeindex = CFTimeIndex(values.ravel()) + if name == 'season': + months = values_as_cftimeindex.month + field_values = _season_from_months(months) + else: + field_values = getattr(values_as_cftimeindex, name) + return field_values.reshape(values.shape) + + def _access_through_series(values, name): """Coerce an array of datetime-like values to a pandas Series and access requested datetime component @@ -48,12 +62,17 @@ def _get_date_field(values, name, dtype): Array-like of datetime fields accessed for each element in values """ + if is_np_datetime_like(values.dtype): + access_method = _access_through_series + else: + access_method = _access_through_cftimeindex + if isinstance(values, dask_array_type): from dask.array import map_blocks - return map_blocks(_access_through_series, + return map_blocks(access_method, values, name, dtype=dtype) else: - return _access_through_series(values, name) + return access_method(values, name) def _round_series(values, name, freq): @@ -111,15 +130,17 @@ class DatetimeAccessor(object): All of the pandas fields are accessible here. Note that these fields are not calendar-aware; if your datetimes are encoded with a non-Gregorian - calendar (e.g. a 360-day calendar) using netcdftime, then some fields like + calendar (e.g. a 360-day calendar) using cftime, then some fields like `dayofyear` may not be accurate. 
""" def __init__(self, xarray_obj): - if not is_datetime_like(xarray_obj.dtype): + if not _contains_datetime_like_objects(xarray_obj): raise TypeError("'dt' accessor only available for " - "DataArray with datetime64 or timedelta64 dtype") + "DataArray with datetime64 timedelta64 dtype or " + "for arrays containing cftime datetime " + "objects.") self._obj = xarray_obj def _tslib_field_accessor(name, docstring=None, dtype=None): diff --git a/xarray/core/common.py b/xarray/core/common.py index 0c6e0fccd8e..f623091ebdb 100644 --- a/xarray/core/common.py +++ b/xarray/core/common.py @@ -933,3 +933,34 @@ def ones_like(other, dtype=None): """Shorthand for full_like(other, 1, dtype) """ return full_like(other, 1, dtype) + + +def is_np_datetime_like(dtype): + """Check if a dtype is a subclass of the numpy datetime types + """ + return (np.issubdtype(dtype, np.datetime64) or + np.issubdtype(dtype, np.timedelta64)) + + +def contains_cftime_datetimes(var): + """Check if a variable contains cftime datetime objects""" + try: + from cftime import datetime as cftime_datetime + except ImportError: + return False + else: + if var.dtype == np.dtype('O') and var.data.size > 0: + sample = var.data.ravel()[0] + if isinstance(sample, dask_array_type): + sample = sample.compute() + if isinstance(sample, np.ndarray): + sample = sample.item() + return isinstance(sample, cftime_datetime) + else: + return False + + +def _contains_datetime_like_objects(var): + """Check if a variable contains datetime like objects (either + np.datetime64, np.timedelta64, or cftime.datetime)""" + return is_np_datetime_like(var.dtype) or contains_cftime_datetimes(var) diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 32913127636..bdb2bf86990 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -17,7 +17,8 @@ rolling, utils) from .. import conventions from .alignment import align -from .common import DataWithCoords, ImplementsDatasetReduce +from .common import (DataWithCoords, ImplementsDatasetReduce, + _contains_datetime_like_objects) from .coordinates import ( DatasetCoordinates, Indexes, LevelCoordinatesSource, assert_coordinate_consistent, remap_label_indexers) @@ -75,7 +76,7 @@ def _get_virtual_variable(variables, key, level_vars=None, dim_sizes=None): virtual_var = ref_var var_name = key else: - if is_datetime_like(ref_var.dtype): + if _contains_datetime_like_objects(ref_var): ref_var = xr.DataArray(ref_var) data = getattr(ref_var.dt, var_name).data else: diff --git a/xarray/core/options.py b/xarray/core/options.py index b2968a2a02f..48d4567fc99 100644 --- a/xarray/core/options.py +++ b/xarray/core/options.py @@ -3,6 +3,7 @@ OPTIONS = { 'display_width': 80, 'arithmetic_join': 'inner', + 'enable_cftimeindex': False } @@ -15,6 +16,9 @@ class set_options(object): Default: ``80``. - ``arithmetic_join``: DataArray/Dataset alignment in binary operations. Default: ``'inner'``. + - ``enable_cftimeindex``: flag to enable using a ``CFTimeIndex`` + for time indexes with non-standard calendars or dates outside the + Timestamp-valid range. Default: ``False``. 
You can use ``set_options`` either as a context manager: diff --git a/xarray/core/utils.py b/xarray/core/utils.py index 25a60b87266..06bb3ede393 100644 --- a/xarray/core/utils.py +++ b/xarray/core/utils.py @@ -12,6 +12,7 @@ import numpy as np import pandas as pd +from .options import OPTIONS from .pycompat import ( OrderedDict, basestring, bytes_type, dask_array_type, iteritems) @@ -36,6 +37,21 @@ def wrapper(*args, **kwargs): return wrapper +def _maybe_cast_to_cftimeindex(index): + from ..coding.cftimeindex import CFTimeIndex + + if not OPTIONS['enable_cftimeindex']: + return index + else: + if index.dtype == 'O': + try: + return CFTimeIndex(index) + except (ImportError, TypeError): + return index + else: + return index + + def safe_cast_to_index(array): """Given an array, safely cast it to a pandas.Index. @@ -54,7 +70,7 @@ def safe_cast_to_index(array): if hasattr(array, 'dtype') and array.dtype.kind == 'O': kwargs['dtype'] = object index = pd.Index(np.asarray(array), **kwargs) - return index + return _maybe_cast_to_cftimeindex(index) def multiindex_from_product_levels(levels, names=None): diff --git a/xarray/plot/plot.py b/xarray/plot/plot.py index 6a3bed08f72..ee1df611d3b 100644 --- a/xarray/plot/plot.py +++ b/xarray/plot/plot.py @@ -14,6 +14,7 @@ import numpy as np import pandas as pd +from xarray.core.common import contains_cftime_datetimes from xarray.core.pycompat import basestring from .facetgrid import FacetGrid @@ -53,7 +54,8 @@ def _ensure_plottable(*args): if not (_valid_numpy_subdtype(np.array(x), numpy_types) or _valid_other_type(np.array(x), other_types)): raise TypeError('Plotting requires coordinates to be numeric ' - 'or dates.') + 'or dates of type np.datetime64 or ' + 'datetime.datetime.') def _easy_facetgrid(darray, plotfunc, x, y, row=None, col=None, @@ -120,6 +122,10 @@ def plot(darray, row=None, col=None, col_wrap=None, ax=None, rtol=0.01, """ darray = darray.squeeze() + if contains_cftime_datetimes(darray): + raise NotImplementedError('Plotting arrays of cftime.datetime objects ' + 'is currently not possible.') + plot_dims = set(darray.dims) plot_dims.discard(row) plot_dims.discard(col) diff --git a/xarray/tests/test_accessors.py b/xarray/tests/test_accessors.py index ad521546d2e..e1b3a95b942 100644 --- a/xarray/tests/test_accessors.py +++ b/xarray/tests/test_accessors.py @@ -2,11 +2,13 @@ import numpy as np import pandas as pd +import pytest import xarray as xr from . 
import ( - TestCase, assert_array_equal, assert_equal, raises_regex, requires_dask) + TestCase, assert_array_equal, assert_equal, raises_regex, requires_dask, + has_cftime, has_dask, has_cftime_or_netCDF4) class TestDatetimeAccessor(TestCase): @@ -114,3 +116,117 @@ def test_rounders(self): xdates.time.dt.ceil('D').values) assert_array_equal(dates.round('D').values, xdates.time.dt.round('D').values) + + +_CFTIME_CALENDARS = ['365_day', '360_day', 'julian', 'all_leap', + '366_day', 'gregorian', 'proleptic_gregorian'] +_NT = 100 + + +@pytest.fixture(params=_CFTIME_CALENDARS) +def calendar(request): + return request.param + + +@pytest.fixture() +def times(calendar): + import cftime + + return cftime.num2date( + np.arange(_NT), units='hours since 2000-01-01', calendar=calendar, + only_use_cftime_datetimes=True) + + +@pytest.fixture() +def data(times): + data = np.random.rand(10, 10, _NT) + lons = np.linspace(0, 11, 10) + lats = np.linspace(0, 20, 10) + return xr.DataArray(data, coords=[lons, lats, times], + dims=['lon', 'lat', 'time'], name='data') + + +@pytest.fixture() +def times_3d(times): + lons = np.linspace(0, 11, 10) + lats = np.linspace(0, 20, 10) + times_arr = np.random.choice(times, size=(10, 10, _NT)) + return xr.DataArray(times_arr, coords=[lons, lats, times], + dims=['lon', 'lat', 'time'], + name='data') + + +@pytest.mark.skipif(not has_cftime, reason='cftime not installed') +@pytest.mark.parametrize('field', ['year', 'month', 'day', 'hour']) +def test_field_access(data, field): + result = getattr(data.time.dt, field) + expected = xr.DataArray( + getattr(xr.coding.cftimeindex.CFTimeIndex(data.time.values), field), + name=field, coords=data.time.coords, dims=data.time.dims) + + assert_equal(result, expected) + + +@pytest.mark.skipif(not has_dask, reason='dask not installed') +@pytest.mark.skipif(not has_cftime, reason='cftime not installed') +@pytest.mark.parametrize('field', ['year', 'month', 'day', 'hour']) +def test_dask_field_access_1d(data, field): + import dask.array as da + + expected = xr.DataArray( + getattr(xr.coding.cftimeindex.CFTimeIndex(data.time.values), field), + name=field, dims=['time']) + times = xr.DataArray(data.time.values, dims=['time']).chunk({'time': 50}) + result = getattr(times.dt, field) + assert isinstance(result.data, da.Array) + assert result.chunks == times.chunks + assert_equal(result.compute(), expected) + + +@pytest.mark.skipif(not has_dask, reason='dask not installed') +@pytest.mark.skipif(not has_cftime, reason='cftime not installed') +@pytest.mark.parametrize('field', ['year', 'month', 'day', 'hour']) +def test_dask_field_access(times_3d, data, field): + import dask.array as da + + expected = xr.DataArray( + getattr(xr.coding.cftimeindex.CFTimeIndex(times_3d.values.ravel()), + field).reshape(times_3d.shape), + name=field, coords=times_3d.coords, dims=times_3d.dims) + times_3d = times_3d.chunk({'lon': 5, 'lat': 5, 'time': 50}) + result = getattr(times_3d.dt, field) + assert isinstance(result.data, da.Array) + assert result.chunks == times_3d.chunks + assert_equal(result.compute(), expected) + + +@pytest.fixture() +def cftime_date_type(calendar): + from .test_coding_times import _all_cftime_date_types + + return _all_cftime_date_types()[calendar] + + +@pytest.mark.skipif(not has_cftime, reason='cftime not installed') +def test_seasons(cftime_date_type): + dates = np.array([cftime_date_type(2000, month, 15) + for month in range(1, 13)]) + dates = xr.DataArray(dates) + seasons = ['DJF', 'DJF', 'MAM', 'MAM', 'MAM', 'JJA', + 'JJA', 'JJA', 'SON', 
'SON', 'SON', 'DJF'] + seasons = xr.DataArray(seasons) + + assert_array_equal(seasons.values, dates.dt.season.values) + + +@pytest.mark.skipif(not has_cftime_or_netCDF4, + reason='cftime or netCDF4 not installed') +def test_dt_accessor_error_netCDF4(cftime_date_type): + da = xr.DataArray( + [cftime_date_type(1, 1, 1), cftime_date_type(2, 1, 1)], + dims=['time']) + if not has_cftime: + with pytest.raises(TypeError): + da.dt.month + else: + da.dt.month diff --git a/xarray/tests/test_backends.py b/xarray/tests/test_backends.py index 632145007e2..2d4e5c0f261 100644 --- a/xarray/tests/test_backends.py +++ b/xarray/tests/test_backends.py @@ -32,7 +32,8 @@ assert_identical, has_dask, has_netCDF4, has_scipy, network, raises_regex, requires_dask, requires_h5netcdf, requires_netCDF4, requires_pathlib, requires_pydap, requires_pynio, requires_rasterio, requires_scipy, - requires_scipy_or_netCDF4, requires_zarr) + requires_scipy_or_netCDF4, requires_zarr, + requires_cftime) from .test_dataset import create_test_data try: @@ -341,7 +342,7 @@ def test_roundtrip_string_encoded_characters(self): assert_identical(expected, actual) self.assertEqual(actual['x'].encoding['_Encoding'], 'ascii') - def test_roundtrip_datetime_data(self): + def test_roundtrip_numpy_datetime_data(self): times = pd.to_datetime(['2000-01-01', '2000-01-02', 'NaT']) expected = Dataset({'t': ('t', times), 't0': times[0]}) kwds = {'encoding': {'t0': {'units': 'days since 1950-01-01'}}} @@ -349,6 +350,35 @@ def test_roundtrip_datetime_data(self): assert_identical(expected, actual) assert actual.t0.encoding['units'] == 'days since 1950-01-01' + @requires_cftime + def test_roundtrip_cftime_datetime_data_enable_cftimeindex(self): + from .test_coding_times import _all_cftime_date_types + + date_types = _all_cftime_date_types() + for date_type in date_types.values(): + times = [date_type(1, 1, 1), date_type(1, 1, 2)] + expected = Dataset({'t': ('t', times), 't0': times[0]}) + kwds = {'encoding': {'t0': {'units': 'days since 0001-01-01'}}} + expected_decoded_t = np.array(times) + expected_decoded_t0 = np.array([date_type(1, 1, 1)]) + expected_calendar = times[0].calendar + + with xr.set_options(enable_cftimeindex=True): + with self.roundtrip(expected, save_kwargs=kwds) as actual: + abs_diff = abs(actual.t.values - expected_decoded_t) + assert (abs_diff <= np.timedelta64(1, 's')).all() + assert (actual.t.encoding['units'] == + 'days since 0001-01-01 00:00:00.000000') + assert (actual.t.encoding['calendar'] == + expected_calendar) + + abs_diff = abs(actual.t0.values - expected_decoded_t0) + assert (abs_diff <= np.timedelta64(1, 's')).all() + assert (actual.t0.encoding['units'] == + 'days since 0001-01-01') + assert (actual.t.encoding['calendar'] == + expected_calendar) + def test_roundtrip_timedelta_data(self): time_deltas = pd.to_timedelta(['1h', '2h', 'NaT']) expected = Dataset({'td': ('td', time_deltas), 'td0': time_deltas[0]}) @@ -1949,7 +1979,7 @@ def test_roundtrip_string_encoded_characters(self): def test_roundtrip_coordinates_with_space(self): pass - def test_roundtrip_datetime_data(self): + def test_roundtrip_numpy_datetime_data(self): # Override method in DatasetIOTestCases - remove not applicable # save_kwds times = pd.to_datetime(['2000-01-01', '2000-01-02', 'NaT']) @@ -1957,6 +1987,46 @@ def test_roundtrip_datetime_data(self): with self.roundtrip(expected) as actual: assert_identical(expected, actual) + def test_roundtrip_cftime_datetime_data_enable_cftimeindex(self): + # Override method in DatasetIOTestCases - remove not applicable 
+ # save_kwds + from .test_coding_times import _all_cftime_date_types + + date_types = _all_cftime_date_types() + for date_type in date_types.values(): + times = [date_type(1, 1, 1), date_type(1, 1, 2)] + expected = Dataset({'t': ('t', times), 't0': times[0]}) + expected_decoded_t = np.array(times) + expected_decoded_t0 = np.array([date_type(1, 1, 1)]) + + with xr.set_options(enable_cftimeindex=True): + with self.roundtrip(expected) as actual: + abs_diff = abs(actual.t.values - expected_decoded_t) + self.assertTrue((abs_diff <= np.timedelta64(1, 's')).all()) + + abs_diff = abs(actual.t0.values - expected_decoded_t0) + self.assertTrue((abs_diff <= np.timedelta64(1, 's')).all()) + + def test_roundtrip_cftime_datetime_data_disable_cftimeindex(self): + # Override method in DatasetIOTestCases - remove not applicable + # save_kwds + from .test_coding_times import _all_cftime_date_types + + date_types = _all_cftime_date_types() + for date_type in date_types.values(): + times = [date_type(1, 1, 1), date_type(1, 1, 2)] + expected = Dataset({'t': ('t', times), 't0': times[0]}) + expected_decoded_t = np.array(times) + expected_decoded_t0 = np.array([date_type(1, 1, 1)]) + + with xr.set_options(enable_cftimeindex=False): + with self.roundtrip(expected) as actual: + abs_diff = abs(actual.t.values - expected_decoded_t) + self.assertTrue((abs_diff <= np.timedelta64(1, 's')).all()) + + abs_diff = abs(actual.t0.values - expected_decoded_t0) + self.assertTrue((abs_diff <= np.timedelta64(1, 's')).all()) + def test_write_store(self): # Override method in DatasetIOTestCases - not applicable to dask pass diff --git a/xarray/tests/test_cftimeindex.py b/xarray/tests/test_cftimeindex.py new file mode 100644 index 00000000000..c78ac038bd5 --- /dev/null +++ b/xarray/tests/test_cftimeindex.py @@ -0,0 +1,555 @@ +from __future__ import absolute_import + +import pytest + +import pandas as pd +import xarray as xr + +from datetime import timedelta +from xarray.coding.cftimeindex import ( + parse_iso8601, CFTimeIndex, assert_all_valid_date_type, + _parsed_string_to_bounds, _parse_iso8601_with_reso) +from xarray.tests import assert_array_equal, assert_identical + +from . 
import has_cftime, has_cftime_or_netCDF4 +from .test_coding_times import _all_cftime_date_types + + +def date_dict(year=None, month=None, day=None, + hour=None, minute=None, second=None): + return dict(year=year, month=month, day=day, hour=hour, + minute=minute, second=second) + + +ISO8601_STRING_TESTS = { + 'year': ('1999', date_dict(year='1999')), + 'month': ('199901', date_dict(year='1999', month='01')), + 'month-dash': ('1999-01', date_dict(year='1999', month='01')), + 'day': ('19990101', date_dict(year='1999', month='01', day='01')), + 'day-dash': ('1999-01-01', date_dict(year='1999', month='01', day='01')), + 'hour': ('19990101T12', date_dict( + year='1999', month='01', day='01', hour='12')), + 'hour-dash': ('1999-01-01T12', date_dict( + year='1999', month='01', day='01', hour='12')), + 'minute': ('19990101T1234', date_dict( + year='1999', month='01', day='01', hour='12', minute='34')), + 'minute-dash': ('1999-01-01T12:34', date_dict( + year='1999', month='01', day='01', hour='12', minute='34')), + 'second': ('19990101T123456', date_dict( + year='1999', month='01', day='01', hour='12', minute='34', + second='56')), + 'second-dash': ('1999-01-01T12:34:56', date_dict( + year='1999', month='01', day='01', hour='12', minute='34', + second='56')) +} + + +@pytest.mark.parametrize(('string', 'expected'), + list(ISO8601_STRING_TESTS.values()), + ids=list(ISO8601_STRING_TESTS.keys())) +def test_parse_iso8601(string, expected): + result = parse_iso8601(string) + assert result == expected + + with pytest.raises(ValueError): + parse_iso8601(string + '3') + parse_iso8601(string + '.3') + + +_CFTIME_CALENDARS = ['365_day', '360_day', 'julian', 'all_leap', + '366_day', 'gregorian', 'proleptic_gregorian'] + + +@pytest.fixture(params=_CFTIME_CALENDARS) +def date_type(request): + return _all_cftime_date_types()[request.param] + + +@pytest.fixture +def index(date_type): + dates = [date_type(1, 1, 1), date_type(1, 2, 1), + date_type(2, 1, 1), date_type(2, 2, 1)] + return CFTimeIndex(dates) + + +@pytest.fixture +def monotonic_decreasing_index(date_type): + dates = [date_type(2, 2, 1), date_type(2, 1, 1), + date_type(1, 2, 1), date_type(1, 1, 1)] + return CFTimeIndex(dates) + + +@pytest.fixture +def da(index): + return xr.DataArray([1, 2, 3, 4], coords=[index], + dims=['time']) + + +@pytest.fixture +def series(index): + return pd.Series([1, 2, 3, 4], index=index) + + +@pytest.fixture +def df(index): + return pd.DataFrame([1, 2, 3, 4], index=index) + + +@pytest.fixture +def feb_days(date_type): + import cftime + if date_type is cftime.DatetimeAllLeap: + return 29 + elif date_type is cftime.Datetime360Day: + return 30 + else: + return 28 + + +@pytest.fixture +def dec_days(date_type): + import cftime + if date_type is cftime.Datetime360Day: + return 30 + else: + return 31 + + +@pytest.mark.skipif(not has_cftime, reason='cftime not installed') +def test_assert_all_valid_date_type(date_type, index): + import cftime + if date_type is cftime.DatetimeNoLeap: + mixed_date_types = [date_type(1, 1, 1), + cftime.DatetimeAllLeap(1, 2, 1)] + else: + mixed_date_types = [date_type(1, 1, 1), + cftime.DatetimeNoLeap(1, 2, 1)] + with pytest.raises(TypeError): + assert_all_valid_date_type(mixed_date_types) + + with pytest.raises(TypeError): + assert_all_valid_date_type([1, date_type(1, 1, 1)]) + + assert_all_valid_date_type([date_type(1, 1, 1), date_type(1, 2, 1)]) + + +@pytest.mark.skipif(not has_cftime, reason='cftime not installed') +@pytest.mark.parametrize(('field', 'expected'), [ + ('year', [1, 1, 2, 2]), + ('month', 
[1, 2, 1, 2]), + ('day', [1, 1, 1, 1]), + ('hour', [0, 0, 0, 0]), + ('minute', [0, 0, 0, 0]), + ('second', [0, 0, 0, 0]), + ('microsecond', [0, 0, 0, 0])]) +def test_cftimeindex_field_accessors(index, field, expected): + result = getattr(index, field) + assert_array_equal(result, expected) + + +@pytest.mark.skipif(not has_cftime, reason='cftime not installed') +@pytest.mark.parametrize(('string', 'date_args', 'reso'), [ + ('1999', (1999, 1, 1), 'year'), + ('199902', (1999, 2, 1), 'month'), + ('19990202', (1999, 2, 2), 'day'), + ('19990202T01', (1999, 2, 2, 1), 'hour'), + ('19990202T0101', (1999, 2, 2, 1, 1), 'minute'), + ('19990202T010156', (1999, 2, 2, 1, 1, 56), 'second')]) +def test_parse_iso8601_with_reso(date_type, string, date_args, reso): + expected_date = date_type(*date_args) + expected_reso = reso + result_date, result_reso = _parse_iso8601_with_reso(date_type, string) + assert result_date == expected_date + assert result_reso == expected_reso + + +@pytest.mark.skipif(not has_cftime, reason='cftime not installed') +def test_parse_string_to_bounds_year(date_type, dec_days): + parsed = date_type(2, 2, 10, 6, 2, 8, 1) + expected_start = date_type(2, 1, 1) + expected_end = date_type(2, 12, dec_days, 23, 59, 59, 999999) + result_start, result_end = _parsed_string_to_bounds( + date_type, 'year', parsed) + assert result_start == expected_start + assert result_end == expected_end + + +@pytest.mark.skipif(not has_cftime, reason='cftime not installed') +def test_parse_string_to_bounds_month_feb(date_type, feb_days): + parsed = date_type(2, 2, 10, 6, 2, 8, 1) + expected_start = date_type(2, 2, 1) + expected_end = date_type(2, 2, feb_days, 23, 59, 59, 999999) + result_start, result_end = _parsed_string_to_bounds( + date_type, 'month', parsed) + assert result_start == expected_start + assert result_end == expected_end + + +@pytest.mark.skipif(not has_cftime, reason='cftime not installed') +def test_parse_string_to_bounds_month_dec(date_type, dec_days): + parsed = date_type(2, 12, 1) + expected_start = date_type(2, 12, 1) + expected_end = date_type(2, 12, dec_days, 23, 59, 59, 999999) + result_start, result_end = _parsed_string_to_bounds( + date_type, 'month', parsed) + assert result_start == expected_start + assert result_end == expected_end + + +@pytest.mark.skipif(not has_cftime, reason='cftime not installed') +@pytest.mark.parametrize(('reso', 'ex_start_args', 'ex_end_args'), [ + ('day', (2, 2, 10), (2, 2, 10, 23, 59, 59, 999999)), + ('hour', (2, 2, 10, 6), (2, 2, 10, 6, 59, 59, 999999)), + ('minute', (2, 2, 10, 6, 2), (2, 2, 10, 6, 2, 59, 999999)), + ('second', (2, 2, 10, 6, 2, 8), (2, 2, 10, 6, 2, 8, 999999))]) +def test_parsed_string_to_bounds_sub_monthly(date_type, reso, + ex_start_args, ex_end_args): + parsed = date_type(2, 2, 10, 6, 2, 8, 123456) + expected_start = date_type(*ex_start_args) + expected_end = date_type(*ex_end_args) + + result_start, result_end = _parsed_string_to_bounds( + date_type, reso, parsed) + assert result_start == expected_start + assert result_end == expected_end + + +@pytest.mark.skipif(not has_cftime, reason='cftime not installed') +def test_parsed_string_to_bounds_raises(date_type): + with pytest.raises(KeyError): + _parsed_string_to_bounds(date_type, 'a', date_type(1, 1, 1)) + + +@pytest.mark.skipif(not has_cftime, reason='cftime not installed') +def test_get_loc(date_type, index): + result = index.get_loc('0001') + expected = [0, 1] + assert_array_equal(result, expected) + + result = index.get_loc(date_type(1, 2, 1)) + expected = 1 + assert result == 
expected + + result = index.get_loc('0001-02-01') + expected = 1 + assert result == expected + + +@pytest.mark.skipif(not has_cftime, reason='cftime not installed') +@pytest.mark.parametrize('kind', ['loc', 'getitem']) +def test_get_slice_bound(date_type, index, kind): + result = index.get_slice_bound('0001', 'left', kind) + expected = 0 + assert result == expected + + result = index.get_slice_bound('0001', 'right', kind) + expected = 2 + assert result == expected + + result = index.get_slice_bound( + date_type(1, 3, 1), 'left', kind) + expected = 2 + assert result == expected + + result = index.get_slice_bound( + date_type(1, 3, 1), 'right', kind) + expected = 2 + assert result == expected + + +@pytest.mark.skipif(not has_cftime, reason='cftime not installed') +@pytest.mark.parametrize('kind', ['loc', 'getitem']) +def test_get_slice_bound_decreasing_index( + date_type, monotonic_decreasing_index, kind): + result = monotonic_decreasing_index.get_slice_bound('0001', 'left', kind) + expected = 2 + assert result == expected + + result = monotonic_decreasing_index.get_slice_bound('0001', 'right', kind) + expected = 4 + assert result == expected + + result = monotonic_decreasing_index.get_slice_bound( + date_type(1, 3, 1), 'left', kind) + expected = 2 + assert result == expected + + result = monotonic_decreasing_index.get_slice_bound( + date_type(1, 3, 1), 'right', kind) + expected = 2 + assert result == expected + + +@pytest.mark.skipif(not has_cftime, reason='cftime not installed') +def test_date_type_property(date_type, index): + assert index.date_type is date_type + + +@pytest.mark.skipif(not has_cftime, reason='cftime not installed') +def test_contains(date_type, index): + assert '0001-01-01' in index + assert '0001' in index + assert '0003' not in index + assert date_type(1, 1, 1) in index + assert date_type(3, 1, 1) not in index + + +@pytest.mark.skipif(not has_cftime, reason='cftime not installed') +def test_groupby(da): + result = da.groupby('time.month').sum('time') + expected = xr.DataArray([4, 6], coords=[[1, 2]], dims=['month']) + assert_identical(result, expected) + + +@pytest.mark.skipif(not has_cftime, reason='cftime not installed') +def test_resample_error(da): + with pytest.raises(TypeError): + da.resample(time='Y') + + +SEL_STRING_OR_LIST_TESTS = { + 'string': '0001', + 'string-slice': slice('0001-01-01', '0001-12-30'), + 'bool-list': [True, True, False, False] +} + + +@pytest.mark.skipif(not has_cftime, reason='cftime not installed') +@pytest.mark.parametrize('sel_arg', list(SEL_STRING_OR_LIST_TESTS.values()), + ids=list(SEL_STRING_OR_LIST_TESTS.keys())) +def test_sel_string_or_list(da, index, sel_arg): + expected = xr.DataArray([1, 2], coords=[index[:2]], dims=['time']) + result = da.sel(time=sel_arg) + assert_identical(result, expected) + + +@pytest.mark.skipif(not has_cftime, reason='cftime not installed') +def test_sel_date_slice_or_list(da, index, date_type): + expected = xr.DataArray([1, 2], coords=[index[:2]], dims=['time']) + result = da.sel(time=slice(date_type(1, 1, 1), date_type(1, 12, 30))) + assert_identical(result, expected) + + result = da.sel(time=[date_type(1, 1, 1), date_type(1, 2, 1)]) + assert_identical(result, expected) + + +@pytest.mark.skipif(not has_cftime, reason='cftime not installed') +def test_sel_date_scalar(da, date_type, index): + expected = xr.DataArray(1).assign_coords(time=index[0]) + result = da.sel(time=date_type(1, 1, 1)) + assert_identical(result, expected) + + +@pytest.mark.skipif(not has_cftime, reason='cftime not installed') 
+@pytest.mark.parametrize('sel_kwargs', [ + {'method': 'nearest'}, + {'method': 'nearest', 'tolerance': timedelta(days=70)} +]) +def test_sel_date_scalar_nearest(da, date_type, index, sel_kwargs): + expected = xr.DataArray(2).assign_coords(time=index[1]) + result = da.sel(time=date_type(1, 4, 1), **sel_kwargs) + assert_identical(result, expected) + + expected = xr.DataArray(3).assign_coords(time=index[2]) + result = da.sel(time=date_type(1, 11, 1), **sel_kwargs) + assert_identical(result, expected) + + +@pytest.mark.skipif(not has_cftime, reason='cftime not installed') +@pytest.mark.parametrize('sel_kwargs', [ + {'method': 'pad'}, + {'method': 'pad', 'tolerance': timedelta(days=365)} +]) +def test_sel_date_scalar_pad(da, date_type, index, sel_kwargs): + expected = xr.DataArray(2).assign_coords(time=index[1]) + result = da.sel(time=date_type(1, 4, 1), **sel_kwargs) + assert_identical(result, expected) + + expected = xr.DataArray(2).assign_coords(time=index[1]) + result = da.sel(time=date_type(1, 11, 1), **sel_kwargs) + assert_identical(result, expected) + + +@pytest.mark.skipif(not has_cftime, reason='cftime not installed') +@pytest.mark.parametrize('sel_kwargs', [ + {'method': 'backfill'}, + {'method': 'backfill', 'tolerance': timedelta(days=365)} +]) +def test_sel_date_scalar_backfill(da, date_type, index, sel_kwargs): + expected = xr.DataArray(3).assign_coords(time=index[2]) + result = da.sel(time=date_type(1, 4, 1), **sel_kwargs) + assert_identical(result, expected) + + expected = xr.DataArray(3).assign_coords(time=index[2]) + result = da.sel(time=date_type(1, 11, 1), **sel_kwargs) + assert_identical(result, expected) + + +@pytest.mark.skipif(not has_cftime, reason='cftime not installed') +@pytest.mark.parametrize('sel_kwargs', [ + {'method': 'pad', 'tolerance': timedelta(days=20)}, + {'method': 'backfill', 'tolerance': timedelta(days=20)}, + {'method': 'nearest', 'tolerance': timedelta(days=20)}, +]) +def test_sel_date_scalar_tolerance_raises(da, date_type, sel_kwargs): + with pytest.raises(KeyError): + da.sel(time=date_type(1, 5, 1), **sel_kwargs) + + +@pytest.mark.skipif(not has_cftime, reason='cftime not installed') +@pytest.mark.parametrize('sel_kwargs', [ + {'method': 'nearest'}, + {'method': 'nearest', 'tolerance': timedelta(days=70)} +]) +def test_sel_date_list_nearest(da, date_type, index, sel_kwargs): + expected = xr.DataArray( + [2, 2], coords=[[index[1], index[1]]], dims=['time']) + result = da.sel( + time=[date_type(1, 3, 1), date_type(1, 4, 1)], **sel_kwargs) + assert_identical(result, expected) + + expected = xr.DataArray( + [2, 3], coords=[[index[1], index[2]]], dims=['time']) + result = da.sel( + time=[date_type(1, 3, 1), date_type(1, 12, 1)], **sel_kwargs) + assert_identical(result, expected) + + expected = xr.DataArray( + [3, 3], coords=[[index[2], index[2]]], dims=['time']) + result = da.sel( + time=[date_type(1, 11, 1), date_type(1, 12, 1)], **sel_kwargs) + assert_identical(result, expected) + + +@pytest.mark.skipif(not has_cftime, reason='cftime not installed') +@pytest.mark.parametrize('sel_kwargs', [ + {'method': 'pad'}, + {'method': 'pad', 'tolerance': timedelta(days=365)} +]) +def test_sel_date_list_pad(da, date_type, index, sel_kwargs): + expected = xr.DataArray( + [2, 2], coords=[[index[1], index[1]]], dims=['time']) + result = da.sel( + time=[date_type(1, 3, 1), date_type(1, 4, 1)], **sel_kwargs) + assert_identical(result, expected) + + +@pytest.mark.skipif(not has_cftime, reason='cftime not installed') +@pytest.mark.parametrize('sel_kwargs', [ + {'method': 
'backfill'}, + {'method': 'backfill', 'tolerance': timedelta(days=365)} +]) +def test_sel_date_list_backfill(da, date_type, index, sel_kwargs): + expected = xr.DataArray( + [3, 3], coords=[[index[2], index[2]]], dims=['time']) + result = da.sel( + time=[date_type(1, 3, 1), date_type(1, 4, 1)], **sel_kwargs) + assert_identical(result, expected) + + +@pytest.mark.skipif(not has_cftime, reason='cftime not installed') +@pytest.mark.parametrize('sel_kwargs', [ + {'method': 'pad', 'tolerance': timedelta(days=20)}, + {'method': 'backfill', 'tolerance': timedelta(days=20)}, + {'method': 'nearest', 'tolerance': timedelta(days=20)}, +]) +def test_sel_date_list_tolerance_raises(da, date_type, sel_kwargs): + with pytest.raises(KeyError): + da.sel(time=[date_type(1, 2, 1), date_type(1, 5, 1)], **sel_kwargs) + + +@pytest.mark.skipif(not has_cftime, reason='cftime not installed') +def test_isel(da, index): + expected = xr.DataArray(1).assign_coords(time=index[0]) + result = da.isel(time=0) + assert_identical(result, expected) + + expected = xr.DataArray([1, 2], coords=[index[:2]], dims=['time']) + result = da.isel(time=[0, 1]) + assert_identical(result, expected) + + +@pytest.fixture +def scalar_args(date_type): + return [date_type(1, 1, 1)] + + +@pytest.fixture +def range_args(date_type): + return ['0001', slice('0001-01-01', '0001-12-30'), + slice(None, '0001-12-30'), + slice(date_type(1, 1, 1), date_type(1, 12, 30)), + slice(None, date_type(1, 12, 30))] + + +@pytest.mark.skipif(not has_cftime, reason='cftime not installed') +def test_indexing_in_series_getitem(series, index, scalar_args, range_args): + for arg in scalar_args: + assert series[arg] == 1 + + expected = pd.Series([1, 2], index=index[:2]) + for arg in range_args: + assert series[arg].equals(expected) + + +@pytest.mark.skipif(not has_cftime, reason='cftime not installed') +def test_indexing_in_series_loc(series, index, scalar_args, range_args): + for arg in scalar_args: + assert series.loc[arg] == 1 + + expected = pd.Series([1, 2], index=index[:2]) + for arg in range_args: + assert series.loc[arg].equals(expected) + + +@pytest.mark.skipif(not has_cftime, reason='cftime not installed') +def test_indexing_in_series_iloc(series, index): + expected = 1 + assert series.iloc[0] == expected + + expected = pd.Series([1, 2], index=index[:2]) + assert series.iloc[:2].equals(expected) + + +@pytest.mark.skipif(not has_cftime, reason='cftime not installed') +def test_indexing_in_dataframe_loc(df, index, scalar_args, range_args): + expected = pd.Series([1], name=index[0]) + for arg in scalar_args: + result = df.loc[arg] + assert result.equals(expected) + + expected = pd.DataFrame([1, 2], index=index[:2]) + for arg in range_args: + result = df.loc[arg] + assert result.equals(expected) + + +@pytest.mark.skipif(not has_cftime, reason='cftime not installed') +def test_indexing_in_dataframe_iloc(df, index): + expected = pd.Series([1], name=index[0]) + result = df.iloc[0] + assert result.equals(expected) + assert result.equals(expected) + + expected = pd.DataFrame([1, 2], index=index[:2]) + result = df.iloc[:2] + assert result.equals(expected) + + +@pytest.mark.skipif(not has_cftime_or_netCDF4, reason='cftime not installed') +@pytest.mark.parametrize('enable_cftimeindex', [False, True]) +def test_concat_cftimeindex(date_type, enable_cftimeindex): + with xr.set_options(enable_cftimeindex=enable_cftimeindex): + da1 = xr.DataArray( + [1., 2.], coords=[[date_type(1, 1, 1), date_type(1, 2, 1)]], + dims=['time']) + da2 = xr.DataArray( + [3., 4.], 
coords=[[date_type(1, 3, 1), date_type(1, 4, 1)]], + dims=['time']) + da = xr.concat([da1, da2], dim='time') + + if enable_cftimeindex and has_cftime: + assert isinstance(da.indexes['time'], CFTimeIndex) + else: + assert isinstance(da.indexes['time'], pd.Index) + assert not isinstance(da.indexes['time'], CFTimeIndex) diff --git a/xarray/tests/test_coding_times.py b/xarray/tests/test_coding_times.py index 7e69d4b3ff2..7c1e869f772 100644 --- a/xarray/tests/test_coding_times.py +++ b/xarray/tests/test_coding_times.py @@ -1,15 +1,55 @@ from __future__ import absolute_import, division, print_function +from itertools import product import warnings import numpy as np import pandas as pd import pytest -from xarray import Variable, coding +from xarray import Variable, coding, set_options, DataArray, decode_cf from xarray.coding.times import _import_cftime - -from . import TestCase, assert_array_equal, requires_cftime_or_netCDF4 +from xarray.coding.variables import SerializationWarning +from xarray.core.common import contains_cftime_datetimes + +from . import (assert_array_equal, has_cftime_or_netCDF4, + requires_cftime_or_netCDF4, has_cftime, has_dask) + + +_NON_STANDARD_CALENDARS = {'noleap', '365_day', '360_day', + 'julian', 'all_leap', '366_day'} +_ALL_CALENDARS = _NON_STANDARD_CALENDARS.union( + coding.times._STANDARD_CALENDARS) +_CF_DATETIME_NUM_DATES_UNITS = [ + (np.arange(10), 'days since 2000-01-01'), + (np.arange(10).astype('float64'), 'days since 2000-01-01'), + (np.arange(10).astype('float32'), 'days since 2000-01-01'), + (np.arange(10).reshape(2, 5), 'days since 2000-01-01'), + (12300 + np.arange(5), 'hours since 1680-01-01 00:00:00'), + # here we add a couple minor formatting errors to test + # the robustness of the parsing algorithm. + (12300 + np.arange(5), 'hour since 1680-01-01 00:00:00'), + (12300 + np.arange(5), u'Hour since 1680-01-01 00:00:00'), + (12300 + np.arange(5), ' Hour since 1680-01-01 00:00:00 '), + (10, 'days since 2000-01-01'), + ([10], 'daYs since 2000-01-01'), + ([[10]], 'days since 2000-01-01'), + ([10, 10], 'days since 2000-01-01'), + (np.array(10), 'days since 2000-01-01'), + (0, 'days since 1000-01-01'), + ([0], 'days since 1000-01-01'), + ([[0]], 'days since 1000-01-01'), + (np.arange(2), 'days since 1000-01-01'), + (np.arange(0, 100000, 20000), 'days since 1900-01-01'), + (17093352.0, 'hours since 1-1-1 00:00:0.0'), + ([0.5, 1.5], 'hours since 1900-01-01T00:00:00'), + (0, 'milliseconds since 2000-01-01T00:00:00'), + (0, 'microseconds since 2000-01-01T00:00:00'), + (np.int32(788961600), 'seconds since 1981-01-01') # GH2002 +] +_CF_DATETIME_TESTS = [num_dates_units + (calendar,) for num_dates_units, + calendar in product(_CF_DATETIME_NUM_DATES_UNITS, + coding.times._STANDARD_CALENDARS)] @np.vectorize @@ -20,309 +60,698 @@ def _ensure_naive_tz(dt): return dt -class TestDatetime(TestCase): - @requires_cftime_or_netCDF4 - def test_cf_datetime(self): - cftime = _import_cftime() - for num_dates, units in [ - (np.arange(10), 'days since 2000-01-01'), - (np.arange(10).astype('float64'), 'days since 2000-01-01'), - (np.arange(10).astype('float32'), 'days since 2000-01-01'), - (np.arange(10).reshape(2, 5), 'days since 2000-01-01'), - (12300 + np.arange(5), 'hours since 1680-01-01 00:00:00'), - # here we add a couple minor formatting errors to test - # the robustness of the parsing algorithm. 
- (12300 + np.arange(5), 'hour since 1680-01-01 00:00:00'), - (12300 + np.arange(5), u'Hour since 1680-01-01 00:00:00'), - (12300 + np.arange(5), ' Hour since 1680-01-01 00:00:00 '), - (10, 'days since 2000-01-01'), - ([10], 'daYs since 2000-01-01'), - ([[10]], 'days since 2000-01-01'), - ([10, 10], 'days since 2000-01-01'), - (np.array(10), 'days since 2000-01-01'), - (0, 'days since 1000-01-01'), - ([0], 'days since 1000-01-01'), - ([[0]], 'days since 1000-01-01'), - (np.arange(2), 'days since 1000-01-01'), - (np.arange(0, 100000, 20000), 'days since 1900-01-01'), - (17093352.0, 'hours since 1-1-1 00:00:0.0'), - ([0.5, 1.5], 'hours since 1900-01-01T00:00:00'), - (0, 'milliseconds since 2000-01-01T00:00:00'), - (0, 'microseconds since 2000-01-01T00:00:00'), - (np.int32(788961600), 'seconds since 1981-01-01'), # GH2002 - ]: - for calendar in ['standard', 'gregorian', 'proleptic_gregorian']: - expected = _ensure_naive_tz( - cftime.num2date(num_dates, units, calendar)) - print(num_dates, units, calendar) - with warnings.catch_warnings(): - warnings.filterwarnings('ignore', - 'Unable to decode time axis') - actual = coding.times.decode_cf_datetime(num_dates, units, - calendar) - if (isinstance(actual, np.ndarray) and - np.issubdtype(actual.dtype, np.datetime64)): - # self.assertEqual(actual.dtype.kind, 'M') - # For some reason, numpy 1.8 does not compare ns precision - # datetime64 arrays as equal to arrays of datetime objects, - # but it works for us precision. Thus, convert to us - # precision for the actual array equal comparison... - actual_cmp = actual.astype('M8[us]') - else: - actual_cmp = actual - assert_array_equal(expected, actual_cmp) - encoded, _, _ = coding.times.encode_cf_datetime(actual, units, - calendar) - if '1-1-1' not in units: - # pandas parses this date very strangely, so the original - # units/encoding cannot be preserved in this case: - # (Pdb) pd.to_datetime('1-1-1 00:00:0.0') - # Timestamp('2001-01-01 00:00:00') - assert_array_equal(num_dates, np.around(encoded, 1)) - if (hasattr(num_dates, 'ndim') and num_dates.ndim == 1 and - '1000' not in units): - # verify that wrapping with a pandas.Index works - # note that it *does not* currently work to even put - # non-datetime64 compatible dates into a pandas.Index - encoded, _, _ = coding.times.encode_cf_datetime( - pd.Index(actual), units, calendar) - assert_array_equal(num_dates, np.around(encoded, 1)) - - @requires_cftime_or_netCDF4 - def test_decode_cf_datetime_overflow(self): - # checks for - # https://github.com/pydata/pandas/issues/14068 - # https://github.com/pydata/xarray/issues/975 - - from datetime import datetime - units = 'days since 2000-01-01 00:00:00' - - # date after 2262 and before 1678 - days = (-117608, 95795) - expected = (datetime(1677, 12, 31), datetime(2262, 4, 12)) - - for i, day in enumerate(days): - result = coding.times.decode_cf_datetime(day, units) - assert result == expected[i] - - def test_decode_cf_datetime_non_standard_units(self): - expected = pd.date_range(periods=100, start='1970-01-01', freq='h') - # netCDFs from madis.noaa.gov use this format for their time units - # they cannot be parsed by netcdftime, but pd.Timestamp works - units = 'hours since 1-1-1970' - actual = coding.times.decode_cf_datetime(np.arange(100), units) +def _all_cftime_date_types(): + try: + import cftime + except ImportError: + import netcdftime as cftime + return {'noleap': cftime.DatetimeNoLeap, + '365_day': cftime.DatetimeNoLeap, + '360_day': cftime.Datetime360Day, + 'julian': 
cftime.DatetimeJulian, + 'all_leap': cftime.DatetimeAllLeap, + '366_day': cftime.DatetimeAllLeap, + 'gregorian': cftime.DatetimeGregorian, + 'proleptic_gregorian': cftime.DatetimeProlepticGregorian} + + +@pytest.mark.skipif(not has_cftime_or_netCDF4, reason='cftime not installed') +@pytest.mark.parametrize(['num_dates', 'units', 'calendar'], + _CF_DATETIME_TESTS) +def test_cf_datetime(num_dates, units, calendar): + cftime = _import_cftime() + expected = _ensure_naive_tz( + cftime.num2date(num_dates, units, calendar)) + with warnings.catch_warnings(): + warnings.filterwarnings('ignore', + 'Unable to decode time axis') + actual = coding.times.decode_cf_datetime(num_dates, units, + calendar) + if (isinstance(actual, np.ndarray) and + np.issubdtype(actual.dtype, np.datetime64)): + # self.assertEqual(actual.dtype.kind, 'M') + # For some reason, numpy 1.8 does not compare ns precision + # datetime64 arrays as equal to arrays of datetime objects, + # but it works for us precision. Thus, convert to us + # precision for the actual array equal comparison... + actual_cmp = actual.astype('M8[us]') + else: + actual_cmp = actual + assert_array_equal(expected, actual_cmp) + encoded, _, _ = coding.times.encode_cf_datetime(actual, units, + calendar) + if '1-1-1' not in units: + # pandas parses this date very strangely, so the original + # units/encoding cannot be preserved in this case: + # (Pdb) pd.to_datetime('1-1-1 00:00:0.0') + # Timestamp('2001-01-01 00:00:00') + assert_array_equal(num_dates, np.around(encoded, 1)) + if (hasattr(num_dates, 'ndim') and num_dates.ndim == 1 and + '1000' not in units): + # verify that wrapping with a pandas.Index works + # note that it *does not* currently work to even put + # non-datetime64 compatible dates into a pandas.Index + encoded, _, _ = coding.times.encode_cf_datetime( + pd.Index(actual), units, calendar) + assert_array_equal(num_dates, np.around(encoded, 1)) + + +@requires_cftime_or_netCDF4 +def test_decode_cf_datetime_overflow(): + # checks for + # https://github.com/pydata/pandas/issues/14068 + # https://github.com/pydata/xarray/issues/975 + + from datetime import datetime + units = 'days since 2000-01-01 00:00:00' + + # date after 2262 and before 1678 + days = (-117608, 95795) + expected = (datetime(1677, 12, 31), datetime(2262, 4, 12)) + + for i, day in enumerate(days): + result = coding.times.decode_cf_datetime(day, units) + assert result == expected[i] + + +def test_decode_cf_datetime_non_standard_units(): + expected = pd.date_range(periods=100, start='1970-01-01', freq='h') + # netCDFs from madis.noaa.gov use this format for their time units + # they cannot be parsed by cftime, but pd.Timestamp works + units = 'hours since 1-1-1970' + actual = coding.times.decode_cf_datetime(np.arange(100), units) + assert_array_equal(actual, expected) + + +@requires_cftime_or_netCDF4 +def test_decode_cf_datetime_non_iso_strings(): + # datetime strings that are _almost_ ISO compliant but not quite, + # but which netCDF4.num2date can still parse correctly + expected = pd.date_range(periods=100, start='2000-01-01', freq='h') + cases = [(np.arange(100), 'hours since 2000-01-01 0'), + (np.arange(100), 'hours since 2000-1-1 0'), + (np.arange(100), 'hours since 2000-01-01 0:00')] + for num_dates, units in cases: + actual = coding.times.decode_cf_datetime(num_dates, units) assert_array_equal(actual, expected) - @requires_cftime_or_netCDF4 - def test_decode_cf_datetime_non_iso_strings(self): - # datetime strings that are _almost_ ISO compliant but not quite, - # 
but which netCDF4.num2date can still parse correctly - expected = pd.date_range(periods=100, start='2000-01-01', freq='h') - cases = [(np.arange(100), 'hours since 2000-01-01 0'), - (np.arange(100), 'hours since 2000-1-1 0'), - (np.arange(100), 'hours since 2000-01-01 0:00')] - for num_dates, units in cases: - actual = coding.times.decode_cf_datetime(num_dates, units) - assert_array_equal(actual, expected) - - @requires_cftime_or_netCDF4 - def test_decode_non_standard_calendar(self): - cftime = _import_cftime() - - for calendar in ['noleap', '365_day', '360_day', 'julian', 'all_leap', - '366_day']: - units = 'days since 0001-01-01' - times = pd.date_range('2001-04-01-00', end='2001-04-30-23', - freq='H') - noleap_time = cftime.date2num(times.to_pydatetime(), units, - calendar=calendar) - expected = times.values - with warnings.catch_warnings(): - warnings.filterwarnings('ignore', 'Unable to decode time axis') - actual = coding.times.decode_cf_datetime(noleap_time, units, - calendar=calendar) - assert actual.dtype == np.dtype('M8[ns]') - abs_diff = abs(actual - expected) - # once we no longer support versions of netCDF4 older than 1.1.5, - # we could do this check with near microsecond accuracy: - # https://github.com/Unidata/netcdf4-python/issues/355 - assert (abs_diff <= np.timedelta64(1, 's')).all() - - @requires_cftime_or_netCDF4 - def test_decode_non_standard_calendar_single_element(self): - units = 'days since 0001-01-01' - for calendar in ['noleap', '365_day', '360_day', 'julian', 'all_leap', - '366_day']: - for num_time in [735368, [735368], [[735368]]]: - with warnings.catch_warnings(): - warnings.filterwarnings('ignore', - 'Unable to decode time axis') - actual = coding.times.decode_cf_datetime(num_time, units, - calendar=calendar) - assert actual.dtype == np.dtype('M8[ns]') - - @requires_cftime_or_netCDF4 - def test_decode_non_standard_calendar_single_element_fallback(self): - cftime = _import_cftime() - - units = 'days since 0001-01-01' - try: - dt = cftime.netcdftime.datetime(2001, 2, 29) - except AttributeError: - # Must be using standalone netcdftime library - dt = cftime.datetime(2001, 2, 29) - for calendar in ['360_day', 'all_leap', '366_day']: - num_time = cftime.date2num(dt, units, calendar) - with pytest.warns(Warning, match='Unable to decode time axis'): - actual = coding.times.decode_cf_datetime(num_time, units, - calendar=calendar) - expected = np.asarray(cftime.num2date(num_time, units, calendar)) - assert actual.dtype == np.dtype('O') - assert expected == actual - - @requires_cftime_or_netCDF4 - def test_decode_non_standard_calendar_multidim_time(self): - cftime = _import_cftime() - - calendar = 'noleap' - units = 'days since 0001-01-01' - times1 = pd.date_range('2001-04-01', end='2001-04-05', freq='D') - times2 = pd.date_range('2001-05-01', end='2001-05-05', freq='D') - noleap_time1 = cftime.date2num(times1.to_pydatetime(), units, - calendar=calendar) - noleap_time2 = cftime.date2num(times2.to_pydatetime(), units, - calendar=calendar) - mdim_time = np.empty((len(noleap_time1), 2), ) - mdim_time[:, 0] = noleap_time1 - mdim_time[:, 1] = noleap_time2 - expected1 = times1.values - expected2 = times2.values +@pytest.mark.skipif(not has_cftime_or_netCDF4, reason='cftime not installed') +@pytest.mark.parametrize( + ['calendar', 'enable_cftimeindex'], + product(coding.times._STANDARD_CALENDARS, [False, True])) +def test_decode_standard_calendar_inside_timestamp_range( + calendar, enable_cftimeindex): + if enable_cftimeindex: + pytest.importorskip('cftime') + + 
cftime = _import_cftime() + units = 'days since 0001-01-01' + times = pd.date_range('2001-04-01-00', end='2001-04-30-23', + freq='H') + noleap_time = cftime.date2num(times.to_pydatetime(), units, + calendar=calendar) + expected = times.values + expected_dtype = np.dtype('M8[ns]') + + with warnings.catch_warnings(): + warnings.filterwarnings('ignore', 'Unable to decode time axis') + actual = coding.times.decode_cf_datetime( + noleap_time, units, calendar=calendar, + enable_cftimeindex=enable_cftimeindex) + assert actual.dtype == expected_dtype + abs_diff = abs(actual - expected) + # once we no longer support versions of netCDF4 older than 1.1.5, + # we could do this check with near microsecond accuracy: + # https://github.com/Unidata/netcdf4-python/issues/355 + assert (abs_diff <= np.timedelta64(1, 's')).all() + + +@pytest.mark.skipif(not has_cftime_or_netCDF4, reason='cftime not installed') +@pytest.mark.parametrize( + ['calendar', 'enable_cftimeindex'], + product(_NON_STANDARD_CALENDARS, [False, True])) +def test_decode_non_standard_calendar_inside_timestamp_range( + calendar, enable_cftimeindex): + if enable_cftimeindex: + pytest.importorskip('cftime') + + cftime = _import_cftime() + units = 'days since 0001-01-01' + times = pd.date_range('2001-04-01-00', end='2001-04-30-23', + freq='H') + noleap_time = cftime.date2num(times.to_pydatetime(), units, + calendar=calendar) + if enable_cftimeindex: + expected = cftime.num2date(noleap_time, units, calendar=calendar) + expected_dtype = np.dtype('O') + else: + expected = times.values + expected_dtype = np.dtype('M8[ns]') + + with warnings.catch_warnings(): + warnings.filterwarnings('ignore', 'Unable to decode time axis') + actual = coding.times.decode_cf_datetime( + noleap_time, units, calendar=calendar, + enable_cftimeindex=enable_cftimeindex) + assert actual.dtype == expected_dtype + abs_diff = abs(actual - expected) + # once we no longer support versions of netCDF4 older than 1.1.5, + # we could do this check with near microsecond accuracy: + # https://github.com/Unidata/netcdf4-python/issues/355 + assert (abs_diff <= np.timedelta64(1, 's')).all() + + +@pytest.mark.skipif(not has_cftime_or_netCDF4, reason='cftime not installed') +@pytest.mark.parametrize( + ['calendar', 'enable_cftimeindex'], + product(_ALL_CALENDARS, [False, True])) +def test_decode_dates_outside_timestamp_range( + calendar, enable_cftimeindex): + from datetime import datetime + + if enable_cftimeindex: + pytest.importorskip('cftime') + + cftime = _import_cftime() + + units = 'days since 0001-01-01' + times = [datetime(1, 4, 1, h) for h in range(1, 5)] + noleap_time = cftime.date2num(times, units, calendar=calendar) + if enable_cftimeindex: + expected = cftime.num2date(noleap_time, units, calendar=calendar, + only_use_cftime_datetimes=True) + else: + expected = cftime.num2date(noleap_time, units, calendar=calendar) + expected_date_type = type(expected[0]) + + with warnings.catch_warnings(): + warnings.filterwarnings('ignore', 'Unable to decode time axis') + actual = coding.times.decode_cf_datetime( + noleap_time, units, calendar=calendar, + enable_cftimeindex=enable_cftimeindex) + assert all(isinstance(value, expected_date_type) for value in actual) + abs_diff = abs(actual - expected) + # once we no longer support versions of netCDF4 older than 1.1.5, + # we could do this check with near microsecond accuracy: + # https://github.com/Unidata/netcdf4-python/issues/355 + assert (abs_diff <= np.timedelta64(1, 's')).all() + + +@pytest.mark.skipif(not 
has_cftime_or_netCDF4, reason='cftime not installed') +@pytest.mark.parametrize( + ['calendar', 'enable_cftimeindex'], + product(coding.times._STANDARD_CALENDARS, [False, True])) +def test_decode_standard_calendar_single_element_inside_timestamp_range( + calendar, enable_cftimeindex): + if enable_cftimeindex: + pytest.importorskip('cftime') + + units = 'days since 0001-01-01' + for num_time in [735368, [735368], [[735368]]]: with warnings.catch_warnings(): - warnings.filterwarnings('ignore', 'Unable to decode time axis') - actual = coding.times.decode_cf_datetime(mdim_time, units, - calendar=calendar) + warnings.filterwarnings('ignore', + 'Unable to decode time axis') + actual = coding.times.decode_cf_datetime( + num_time, units, calendar=calendar, + enable_cftimeindex=enable_cftimeindex) assert actual.dtype == np.dtype('M8[ns]') - assert_array_equal(actual[:, 0], expected1) - assert_array_equal(actual[:, 1], expected2) - - @requires_cftime_or_netCDF4 - def test_decode_non_standard_calendar_fallback(self): - cftime = _import_cftime() - # ensure leap year doesn't matter - for year in [2010, 2011, 2012, 2013, 2014]: - for calendar in ['360_day', '366_day', 'all_leap']: - calendar = '360_day' - units = 'days since {0}-01-01'.format(year) - num_times = np.arange(100) - expected = cftime.num2date(num_times, units, calendar) - - with warnings.catch_warnings(record=True) as w: - warnings.simplefilter('always') - actual = coding.times.decode_cf_datetime(num_times, units, - calendar=calendar) - assert len(w) == 1 - assert 'Unable to decode time axis' in \ - str(w[0].message) - - assert actual.dtype == np.dtype('O') - assert_array_equal(actual, expected) - - @requires_cftime_or_netCDF4 - def test_cf_datetime_nan(self): - for num_dates, units, expected_list in [ - ([np.nan], 'days since 2000-01-01', ['NaT']), - ([np.nan, 0], 'days since 2000-01-01', - ['NaT', '2000-01-01T00:00:00Z']), - ([np.nan, 0, 1], 'days since 2000-01-01', - ['NaT', '2000-01-01T00:00:00Z', '2000-01-02T00:00:00Z']), - ]: + + +@pytest.mark.skipif(not has_cftime_or_netCDF4, reason='cftime not installed') +@pytest.mark.parametrize( + ['calendar', 'enable_cftimeindex'], + product(_NON_STANDARD_CALENDARS, [False, True])) +def test_decode_non_standard_calendar_single_element_inside_timestamp_range( + calendar, enable_cftimeindex): + if enable_cftimeindex: + pytest.importorskip('cftime') + + units = 'days since 0001-01-01' + for num_time in [735368, [735368], [[735368]]]: + with warnings.catch_warnings(): + warnings.filterwarnings('ignore', + 'Unable to decode time axis') + actual = coding.times.decode_cf_datetime( + num_time, units, calendar=calendar, + enable_cftimeindex=enable_cftimeindex) + if enable_cftimeindex: + assert actual.dtype == np.dtype('O') + else: + assert actual.dtype == np.dtype('M8[ns]') + + +@pytest.mark.skipif(not has_cftime_or_netCDF4, reason='cftime not installed') +@pytest.mark.parametrize( + ['calendar', 'enable_cftimeindex'], + product(_NON_STANDARD_CALENDARS, [False, True])) +def test_decode_single_element_outside_timestamp_range( + calendar, enable_cftimeindex): + if enable_cftimeindex: + pytest.importorskip('cftime') + + cftime = _import_cftime() + units = 'days since 0001-01-01' + for days in [1, 1470376]: + for num_time in [days, [days], [[days]]]: with warnings.catch_warnings(): - warnings.filterwarnings('ignore', 'All-NaN') - actual = coding.times.decode_cf_datetime(num_dates, units) - expected = np.array(expected_list, dtype='datetime64[ns]') - assert_array_equal(expected, actual) - - 
@requires_cftime_or_netCDF4 - def test_decoded_cf_datetime_array_2d(self): - # regression test for GH1229 - variable = Variable(('x', 'y'), np.array([[0, 1], [2, 3]]), - {'units': 'days since 2000-01-01'}) - result = coding.times.CFDatetimeCoder().decode(variable) - assert result.dtype == 'datetime64[ns]' - expected = pd.date_range('2000-01-01', periods=4).values.reshape(2, 2) - assert_array_equal(np.asarray(result), expected) - - def test_infer_datetime_units(self): - for dates, expected in [(pd.date_range('1900-01-01', periods=5), - 'days since 1900-01-01 00:00:00'), - (pd.date_range('1900-01-01 12:00:00', freq='H', - periods=2), - 'hours since 1900-01-01 12:00:00'), - (['1900-01-01', '1900-01-02', - '1900-01-02 00:00:01'], - 'seconds since 1900-01-01 00:00:00'), - (pd.to_datetime( - ['1900-01-01', '1900-01-02', 'NaT']), - 'days since 1900-01-01 00:00:00'), - (pd.to_datetime(['1900-01-01', - '1900-01-02T00:00:00.005']), - 'seconds since 1900-01-01 00:00:00'), - (pd.to_datetime(['NaT', '1900-01-01']), - 'days since 1900-01-01 00:00:00'), - (pd.to_datetime(['NaT']), - 'days since 1970-01-01 00:00:00'), - ]: - assert expected == coding.times.infer_datetime_units(dates) + warnings.filterwarnings('ignore', + 'Unable to decode time axis') + actual = coding.times.decode_cf_datetime( + num_time, units, calendar=calendar, + enable_cftimeindex=enable_cftimeindex) + expected = cftime.num2date(days, units, calendar) + assert isinstance(actual.item(), type(expected)) + + +@pytest.mark.skipif(not has_cftime_or_netCDF4, reason='cftime not installed') +@pytest.mark.parametrize( + ['calendar', 'enable_cftimeindex'], + product(coding.times._STANDARD_CALENDARS, [False, True])) +def test_decode_standard_calendar_multidim_time_inside_timestamp_range( + calendar, enable_cftimeindex): + if enable_cftimeindex: + pytest.importorskip('cftime') + + cftime = _import_cftime() + + units = 'days since 0001-01-01' + times1 = pd.date_range('2001-04-01', end='2001-04-05', freq='D') + times2 = pd.date_range('2001-05-01', end='2001-05-05', freq='D') + noleap_time1 = cftime.date2num(times1.to_pydatetime(), + units, calendar=calendar) + noleap_time2 = cftime.date2num(times2.to_pydatetime(), + units, calendar=calendar) + mdim_time = np.empty((len(noleap_time1), 2), ) + mdim_time[:, 0] = noleap_time1 + mdim_time[:, 1] = noleap_time2 + + expected1 = times1.values + expected2 = times2.values + + actual = coding.times.decode_cf_datetime( + mdim_time, units, calendar=calendar, + enable_cftimeindex=enable_cftimeindex) + assert actual.dtype == np.dtype('M8[ns]') + + abs_diff1 = abs(actual[:, 0] - expected1) + abs_diff2 = abs(actual[:, 1] - expected2) + # once we no longer support versions of netCDF4 older than 1.1.5, + # we could do this check with near microsecond accuracy: + # https://github.com/Unidata/netcdf4-python/issues/355 + assert (abs_diff1 <= np.timedelta64(1, 's')).all() + assert (abs_diff2 <= np.timedelta64(1, 's')).all() + + +@pytest.mark.skipif(not has_cftime_or_netCDF4, reason='cftime not installed') +@pytest.mark.parametrize( + ['calendar', 'enable_cftimeindex'], + product(_NON_STANDARD_CALENDARS, [False, True])) +def test_decode_nonstandard_calendar_multidim_time_inside_timestamp_range( + calendar, enable_cftimeindex): + if enable_cftimeindex: + pytest.importorskip('cftime') + + cftime = _import_cftime() + + units = 'days since 0001-01-01' + times1 = pd.date_range('2001-04-01', end='2001-04-05', freq='D') + times2 = pd.date_range('2001-05-01', end='2001-05-05', freq='D') + noleap_time1 = 
cftime.date2num(times1.to_pydatetime(), + units, calendar=calendar) + noleap_time2 = cftime.date2num(times2.to_pydatetime(), + units, calendar=calendar) + mdim_time = np.empty((len(noleap_time1), 2), ) + mdim_time[:, 0] = noleap_time1 + mdim_time[:, 1] = noleap_time2 + + if enable_cftimeindex: + expected1 = cftime.num2date(noleap_time1, units, calendar) + expected2 = cftime.num2date(noleap_time2, units, calendar) + expected_dtype = np.dtype('O') + else: + expected1 = times1.values + expected2 = times2.values + expected_dtype = np.dtype('M8[ns]') + + actual = coding.times.decode_cf_datetime( + mdim_time, units, calendar=calendar, + enable_cftimeindex=enable_cftimeindex) + + assert actual.dtype == expected_dtype + abs_diff1 = abs(actual[:, 0] - expected1) + abs_diff2 = abs(actual[:, 1] - expected2) + # once we no longer support versions of netCDF4 older than 1.1.5, + # we could do this check with near microsecond accuracy: + # https://github.com/Unidata/netcdf4-python/issues/355 + assert (abs_diff1 <= np.timedelta64(1, 's')).all() + assert (abs_diff2 <= np.timedelta64(1, 's')).all() + + +@pytest.mark.skipif(not has_cftime_or_netCDF4, reason='cftime not installed') +@pytest.mark.parametrize( + ['calendar', 'enable_cftimeindex'], + product(_ALL_CALENDARS, [False, True])) +def test_decode_multidim_time_outside_timestamp_range( + calendar, enable_cftimeindex): + from datetime import datetime + + if enable_cftimeindex: + pytest.importorskip('cftime') + + cftime = _import_cftime() + + units = 'days since 0001-01-01' + times1 = [datetime(1, 4, day) for day in range(1, 6)] + times2 = [datetime(1, 5, day) for day in range(1, 6)] + noleap_time1 = cftime.date2num(times1, units, calendar=calendar) + noleap_time2 = cftime.date2num(times2, units, calendar=calendar) + mdim_time = np.empty((len(noleap_time1), 2), ) + mdim_time[:, 0] = noleap_time1 + mdim_time[:, 1] = noleap_time2 + + if enable_cftimeindex: + expected1 = cftime.num2date(noleap_time1, units, calendar, + only_use_cftime_datetimes=True) + expected2 = cftime.num2date(noleap_time2, units, calendar, + only_use_cftime_datetimes=True) + else: + expected1 = cftime.num2date(noleap_time1, units, calendar) + expected2 = cftime.num2date(noleap_time2, units, calendar) + + with warnings.catch_warnings(): + warnings.filterwarnings('ignore', 'Unable to decode time axis') + actual = coding.times.decode_cf_datetime( + mdim_time, units, calendar=calendar, + enable_cftimeindex=enable_cftimeindex) + + assert actual.dtype == np.dtype('O') + + abs_diff1 = abs(actual[:, 0] - expected1) + abs_diff2 = abs(actual[:, 1] - expected2) + # once we no longer support versions of netCDF4 older than 1.1.5, + # we could do this check with near microsecond accuracy: + # https://github.com/Unidata/netcdf4-python/issues/355 + assert (abs_diff1 <= np.timedelta64(1, 's')).all() + assert (abs_diff2 <= np.timedelta64(1, 's')).all() + + +@pytest.mark.skipif(not has_cftime_or_netCDF4, reason='cftime not installed') +@pytest.mark.parametrize( + ['calendar', 'enable_cftimeindex'], + product(['360_day', 'all_leap', '366_day'], [False, True])) +def test_decode_non_standard_calendar_single_element_fallback( + calendar, enable_cftimeindex): + if enable_cftimeindex: + pytest.importorskip('cftime') + + cftime = _import_cftime() + + units = 'days since 0001-01-01' + try: + dt = cftime.netcdftime.datetime(2001, 2, 29) + except AttributeError: + # Must be using standalone netcdftime library + dt = cftime.datetime(2001, 2, 29) + + num_time = cftime.date2num(dt, units, calendar) + if 
enable_cftimeindex: + actual = coding.times.decode_cf_datetime( + num_time, units, calendar=calendar, + enable_cftimeindex=enable_cftimeindex) + else: + with pytest.warns(SerializationWarning, + match='Unable to decode time axis'): + actual = coding.times.decode_cf_datetime( + num_time, units, calendar=calendar, + enable_cftimeindex=enable_cftimeindex) + + expected = np.asarray(cftime.num2date(num_time, units, calendar)) + assert actual.dtype == np.dtype('O') + assert expected == actual + + +@pytest.mark.skipif(not has_cftime_or_netCDF4, reason='cftime not installed') +@pytest.mark.parametrize( + ['calendar', 'enable_cftimeindex'], + product(['360_day'], [False, True])) +def test_decode_non_standard_calendar_fallback( + calendar, enable_cftimeindex): + if enable_cftimeindex: + pytest.importorskip('cftime') + + cftime = _import_cftime() + # ensure leap year doesn't matter + for year in [2010, 2011, 2012, 2013, 2014]: + units = 'days since {0}-01-01'.format(year) + num_times = np.arange(100) + expected = cftime.num2date(num_times, units, calendar) + + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter('always') + actual = coding.times.decode_cf_datetime( + num_times, units, calendar=calendar, + enable_cftimeindex=enable_cftimeindex) + if enable_cftimeindex: + assert len(w) == 0 + else: + assert len(w) == 1 + assert 'Unable to decode time axis' in str(w[0].message) + + assert actual.dtype == np.dtype('O') + assert_array_equal(actual, expected) - def test_cf_timedelta(self): - examples = [ - ('1D', 'days', np.int64(1)), - (['1D', '2D', '3D'], 'days', np.array([1, 2, 3], 'int64')), - ('1h', 'hours', np.int64(1)), - ('1ms', 'milliseconds', np.int64(1)), - ('1us', 'microseconds', np.int64(1)), - (['NaT', '0s', '1s'], None, [np.nan, 0, 1]), - (['30m', '60m'], 'hours', [0.5, 1.0]), - (np.timedelta64('NaT', 'ns'), 'days', np.nan), - (['NaT', 'NaT'], 'days', [np.nan, np.nan]), - ] - - for timedeltas, units, numbers in examples: - timedeltas = pd.to_timedelta(timedeltas, box=False) - numbers = np.array(numbers) - - expected = numbers - actual, _ = coding.times.encode_cf_timedelta(timedeltas, units) - assert_array_equal(expected, actual) - assert expected.dtype == actual.dtype - - if units is not None: - expected = timedeltas - actual = coding.times.decode_cf_timedelta(numbers, units) - assert_array_equal(expected, actual) - assert expected.dtype == actual.dtype - - expected = np.timedelta64('NaT', 'ns') - actual = coding.times.decode_cf_timedelta(np.array(np.nan), 'days') - assert_array_equal(expected, actual) - def test_cf_timedelta_2d(self): - timedeltas = ['1D', '2D', '3D'] - units = 'days' - numbers = np.atleast_2d([1, 2, 3]) +@pytest.mark.skipif(not has_cftime_or_netCDF4, reason='cftime not installed') +@pytest.mark.parametrize( + ['num_dates', 'units', 'expected_list'], + [([np.nan], 'days since 2000-01-01', ['NaT']), + ([np.nan, 0], 'days since 2000-01-01', + ['NaT', '2000-01-01T00:00:00Z']), + ([np.nan, 0, 1], 'days since 2000-01-01', + ['NaT', '2000-01-01T00:00:00Z', '2000-01-02T00:00:00Z'])]) +def test_cf_datetime_nan(num_dates, units, expected_list): + with warnings.catch_warnings(): + warnings.filterwarnings('ignore', 'All-NaN') + actual = coding.times.decode_cf_datetime(num_dates, units) + expected = np.array(expected_list, dtype='datetime64[ns]') + assert_array_equal(expected, actual) + + +@requires_cftime_or_netCDF4 +def test_decoded_cf_datetime_array_2d(): + # regression test for GH1229 + variable = Variable(('x', 'y'), np.array([[0, 1], [2, 3]]), + {'units': 'days 
since 2000-01-01'}) + result = coding.times.CFDatetimeCoder().decode(variable) + assert result.dtype == 'datetime64[ns]' + expected = pd.date_range('2000-01-01', periods=4).values.reshape(2, 2) + assert_array_equal(np.asarray(result), expected) + + +@pytest.mark.parametrize( + ['dates', 'expected'], + [(pd.date_range('1900-01-01', periods=5), + 'days since 1900-01-01 00:00:00'), + (pd.date_range('1900-01-01 12:00:00', freq='H', + periods=2), + 'hours since 1900-01-01 12:00:00'), + (pd.to_datetime( + ['1900-01-01', '1900-01-02', 'NaT']), + 'days since 1900-01-01 00:00:00'), + (pd.to_datetime(['1900-01-01', + '1900-01-02T00:00:00.005']), + 'seconds since 1900-01-01 00:00:00'), + (pd.to_datetime(['NaT', '1900-01-01']), + 'days since 1900-01-01 00:00:00'), + (pd.to_datetime(['NaT']), + 'days since 1970-01-01 00:00:00')]) +def test_infer_datetime_units(dates, expected): + assert expected == coding.times.infer_datetime_units(dates) + + +@pytest.mark.skipif(not has_cftime_or_netCDF4, reason='cftime not installed') +def test_infer_cftime_datetime_units(): + date_types = _all_cftime_date_types() + for date_type in date_types.values(): + for dates, expected in [ + ([date_type(1900, 1, 1), + date_type(1900, 1, 2)], + 'days since 1900-01-01 00:00:00.000000'), + ([date_type(1900, 1, 1, 12), + date_type(1900, 1, 1, 13)], + 'seconds since 1900-01-01 12:00:00.000000'), + ([date_type(1900, 1, 1), + date_type(1900, 1, 2), + date_type(1900, 1, 2, 0, 0, 1)], + 'seconds since 1900-01-01 00:00:00.000000'), + ([date_type(1900, 1, 1), + date_type(1900, 1, 2, 0, 0, 0, 5)], + 'days since 1900-01-01 00:00:00.000000')]: + assert expected == coding.times.infer_datetime_units(dates) - timedeltas = np.atleast_2d(pd.to_timedelta(timedeltas, box=False)) - expected = timedeltas +@pytest.mark.parametrize( + ['timedeltas', 'units', 'numbers'], + [('1D', 'days', np.int64(1)), + (['1D', '2D', '3D'], 'days', np.array([1, 2, 3], 'int64')), + ('1h', 'hours', np.int64(1)), + ('1ms', 'milliseconds', np.int64(1)), + ('1us', 'microseconds', np.int64(1)), + (['NaT', '0s', '1s'], None, [np.nan, 0, 1]), + (['30m', '60m'], 'hours', [0.5, 1.0]), + (np.timedelta64('NaT', 'ns'), 'days', np.nan), + (['NaT', 'NaT'], 'days', [np.nan, np.nan])]) +def test_cf_timedelta(timedeltas, units, numbers): + timedeltas = pd.to_timedelta(timedeltas, box=False) + numbers = np.array(numbers) + + expected = numbers + actual, _ = coding.times.encode_cf_timedelta(timedeltas, units) + assert_array_equal(expected, actual) + assert expected.dtype == actual.dtype + + if units is not None: + expected = timedeltas actual = coding.times.decode_cf_timedelta(numbers, units) assert_array_equal(expected, actual) assert expected.dtype == actual.dtype - def test_infer_timedelta_units(self): - for deltas, expected in [ - (pd.to_timedelta(['1 day', '2 days']), 'days'), - (pd.to_timedelta(['1h', '1 day 1 hour']), 'hours'), - (pd.to_timedelta(['1m', '2m', np.nan]), 'minutes'), - (pd.to_timedelta(['1m3s', '1m4s']), 'seconds')]: - assert expected == coding.times.infer_timedelta_units(deltas) + expected = np.timedelta64('NaT', 'ns') + actual = coding.times.decode_cf_timedelta(np.array(np.nan), 'days') + assert_array_equal(expected, actual) + + +def test_cf_timedelta_2d(): + timedeltas = ['1D', '2D', '3D'] + units = 'days' + numbers = np.atleast_2d([1, 2, 3]) + + timedeltas = np.atleast_2d(pd.to_timedelta(timedeltas, box=False)) + expected = timedeltas + + actual = coding.times.decode_cf_timedelta(numbers, units) + assert_array_equal(expected, actual) + assert expected.dtype == 
actual.dtype + + +@pytest.mark.parametrize( + ['deltas', 'expected'], + [(pd.to_timedelta(['1 day', '2 days']), 'days'), + (pd.to_timedelta(['1h', '1 day 1 hour']), 'hours'), + (pd.to_timedelta(['1m', '2m', np.nan]), 'minutes'), + (pd.to_timedelta(['1m3s', '1m4s']), 'seconds')]) +def test_infer_timedelta_units(deltas, expected): + assert expected == coding.times.infer_timedelta_units(deltas) + + +@pytest.mark.skipif(not has_cftime_or_netCDF4, reason='cftime not installed') +@pytest.mark.parametrize(['date_args', 'expected'], + [((1, 2, 3, 4, 5, 6), + '0001-02-03 04:05:06.000000'), + ((10, 2, 3, 4, 5, 6), + '0010-02-03 04:05:06.000000'), + ((100, 2, 3, 4, 5, 6), + '0100-02-03 04:05:06.000000'), + ((1000, 2, 3, 4, 5, 6), + '1000-02-03 04:05:06.000000')]) +def test_format_cftime_datetime(date_args, expected): + date_types = _all_cftime_date_types() + for date_type in date_types.values(): + result = coding.times.format_cftime_datetime(date_type(*date_args)) + assert result == expected + + +@pytest.mark.skipif(not has_cftime_or_netCDF4, reason='cftime not installed') +@pytest.mark.parametrize( + ['calendar', 'enable_cftimeindex'], + product(_ALL_CALENDARS, [False, True])) +def test_decode_cf_enable_cftimeindex(calendar, enable_cftimeindex): + days = [1., 2., 3.] + da = DataArray(days, coords=[days], dims=['time'], name='test') + ds = da.to_dataset() + + for v in ['test', 'time']: + ds[v].attrs['units'] = 'days since 2001-01-01' + ds[v].attrs['calendar'] = calendar + + if (not has_cftime and enable_cftimeindex and + calendar not in coding.times._STANDARD_CALENDARS): + with pytest.raises(ValueError): + with set_options(enable_cftimeindex=enable_cftimeindex): + ds = decode_cf(ds) + else: + with set_options(enable_cftimeindex=enable_cftimeindex): + ds = decode_cf(ds) + + if (enable_cftimeindex and + calendar not in coding.times._STANDARD_CALENDARS): + assert ds.test.dtype == np.dtype('O') + else: + assert ds.test.dtype == np.dtype('M8[ns]') + + +@pytest.fixture(params=_ALL_CALENDARS) +def calendar(request): + return request.param + + +@pytest.fixture() +def times(calendar): + cftime = _import_cftime() + + return cftime.num2date( + np.arange(4), units='hours since 2000-01-01', calendar=calendar, + only_use_cftime_datetimes=True) + + +@pytest.fixture() +def data(times): + data = np.random.rand(2, 2, 4) + lons = np.linspace(0, 11, 2) + lats = np.linspace(0, 20, 2) + return DataArray(data, coords=[lons, lats, times], + dims=['lon', 'lat', 'time'], name='data') + + +@pytest.fixture() +def times_3d(times): + lons = np.linspace(0, 11, 2) + lats = np.linspace(0, 20, 2) + times_arr = np.random.choice(times, size=(2, 2, 4)) + return DataArray(times_arr, coords=[lons, lats, times], + dims=['lon', 'lat', 'time'], + name='data') + + +@pytest.mark.skipif(not has_cftime, reason='cftime not installed') +def test_contains_cftime_datetimes_1d(data): + assert contains_cftime_datetimes(data.time) + + +@pytest.mark.skipif(not has_dask, reason='dask not installed') +@pytest.mark.skipif(not has_cftime, reason='cftime not installed') +def test_contains_cftime_datetimes_dask_1d(data): + assert contains_cftime_datetimes(data.time.chunk()) + + +@pytest.mark.skipif(not has_cftime, reason='cftime not installed') +def test_contains_cftime_datetimes_3d(times_3d): + assert contains_cftime_datetimes(times_3d) + + +@pytest.mark.skipif(not has_dask, reason='dask not installed') +@pytest.mark.skipif(not has_cftime, reason='cftime not installed') +def test_contains_cftime_datetimes_dask_3d(times_3d): + assert 
contains_cftime_datetimes(times_3d.chunk()) + + +@pytest.mark.parametrize('non_cftime_data', [DataArray([]), DataArray([1, 2])]) +def test_contains_cftime_datetimes_non_cftimes(non_cftime_data): + assert not contains_cftime_datetimes(non_cftime_data) + + +@pytest.mark.skipif(not has_dask, reason='dask not installed') +@pytest.mark.parametrize('non_cftime_data', [DataArray([]), DataArray([1, 2])]) +def test_contains_cftime_datetimes_non_cftimes_dask(non_cftime_data): + assert not contains_cftime_datetimes(non_cftime_data.chunk()) diff --git a/xarray/tests/test_dataarray.py b/xarray/tests/test_dataarray.py index e9a2babfa2e..b16cb8ddcea 100644 --- a/xarray/tests/test_dataarray.py +++ b/xarray/tests/test_dataarray.py @@ -11,14 +11,14 @@ import xarray as xr from xarray import ( - DataArray, Dataset, IndexVariable, Variable, align, broadcast) -from xarray.coding.times import CFDatetimeCoder + DataArray, Dataset, IndexVariable, Variable, align, broadcast, set_options) +from xarray.coding.times import CFDatetimeCoder, _import_cftime from xarray.core.common import full_like from xarray.core.pycompat import OrderedDict, iteritems from xarray.tests import ( ReturnItem, TestCase, assert_allclose, assert_array_equal, assert_equal, assert_identical, raises_regex, requires_bottleneck, requires_dask, - requires_scipy, source_ndarray, unittest) + requires_scipy, source_ndarray, unittest, requires_cftime) class TestDataArray(TestCase): @@ -2208,6 +2208,19 @@ def test_resample(self): with raises_regex(ValueError, 'index must be monotonic'): array[[2, 0, 1]].resample(time='1D') + @requires_cftime + def test_resample_cftimeindex(self): + cftime = _import_cftime() + times = cftime.num2date(np.arange(12), units='hours since 0001-01-01', + calendar='noleap') + with set_options(enable_cftimeindex=True): + array = DataArray(np.arange(12), [('time', times)]) + + with raises_regex(TypeError, + 'Only valid with DatetimeIndex, ' + 'TimedeltaIndex or PeriodIndex'): + array.resample(time='6H').mean() + def test_resample_first(self): times = pd.date_range('2000-01-01', freq='6H', periods=10) array = DataArray(np.arange(10), [('time', times)]) diff --git a/xarray/tests/test_plot.py b/xarray/tests/test_plot.py index aadc452b8a7..70ed1156643 100644 --- a/xarray/tests/test_plot.py +++ b/xarray/tests/test_plot.py @@ -9,6 +9,7 @@ import xarray.plot as xplt from xarray import DataArray +from xarray.coding.times import _import_cftime from xarray.plot.plot import _infer_interval_breaks from xarray.plot.utils import ( _build_discrete_cmap, _color_palette, _determine_cmap_params, @@ -16,7 +17,7 @@ from . 
import ( TestCase, assert_array_equal, assert_equal, raises_regex, - requires_matplotlib, requires_seaborn) + requires_matplotlib, requires_seaborn, requires_cftime) # import mpl and change the backend before other mpl imports try: @@ -1504,3 +1505,24 @@ def test_plot_seaborn_no_import_warning(): with pytest.warns(None) as record: _color_palette('Blues', 4) assert len(record) == 0 + + +@requires_cftime +def test_plot_cftime_coordinate_error(): + cftime = _import_cftime() + time = cftime.num2date(np.arange(5), units='days since 0001-01-01', + calendar='noleap') + data = DataArray(np.arange(5), coords=[time], dims=['time']) + with raises_regex(TypeError, + 'requires coordinates to be numeric or dates'): + data.plot() + + +@requires_cftime +def test_plot_cftime_data_error(): + cftime = _import_cftime() + data = cftime.num2date(np.arange(5), units='days since 0001-01-01', + calendar='noleap') + data = DataArray(data, coords=[np.arange(5)], dims=['x']) + with raises_regex(NotImplementedError, 'cftime.datetime'): + data.plot() diff --git a/xarray/tests/test_utils.py b/xarray/tests/test_utils.py index 3a76b6e8c92..0b3b0ee7dd6 100644 --- a/xarray/tests/test_utils.py +++ b/xarray/tests/test_utils.py @@ -4,10 +4,14 @@ import pandas as pd import pytest +from datetime import datetime +from xarray.coding.cftimeindex import CFTimeIndex from xarray.core import duck_array_ops, utils +from xarray.core.options import set_options from xarray.core.pycompat import OrderedDict - -from . import TestCase, assert_array_equal, requires_dask +from .test_coding_times import _all_cftime_date_types +from . import (TestCase, requires_dask, assert_array_equal, + has_cftime_or_netCDF4, has_cftime) class TestAlias(TestCase): @@ -20,20 +24,51 @@ def new_method(): old_method() -class TestSafeCastToIndex(TestCase): - def test(self): - dates = pd.date_range('2000-01-01', periods=10) - x = np.arange(5) - td = x * np.timedelta64(1, 'D') - for expected, array in [ - (dates, dates.values), - (pd.Index(x, dtype=object), x.astype(object)), - (pd.Index(td), td), - (pd.Index(td, dtype=object), td.astype(object)), - ]: - actual = utils.safe_cast_to_index(array) - assert_array_equal(expected, actual) - assert expected.dtype == actual.dtype +def test_safe_cast_to_index(): + dates = pd.date_range('2000-01-01', periods=10) + x = np.arange(5) + td = x * np.timedelta64(1, 'D') + for expected, array in [ + (dates, dates.values), + (pd.Index(x, dtype=object), x.astype(object)), + (pd.Index(td), td), + (pd.Index(td, dtype=object), td.astype(object)), + ]: + actual = utils.safe_cast_to_index(array) + assert_array_equal(expected, actual) + assert expected.dtype == actual.dtype + + +@pytest.mark.skipif(not has_cftime_or_netCDF4, reason='cftime not installed') +@pytest.mark.parametrize('enable_cftimeindex', [False, True]) +def test_safe_cast_to_index_cftimeindex(enable_cftimeindex): + date_types = _all_cftime_date_types() + for date_type in date_types.values(): + dates = [date_type(1, 1, day) for day in range(1, 20)] + + if enable_cftimeindex and has_cftime: + expected = CFTimeIndex(dates) + else: + expected = pd.Index(dates) + + with set_options(enable_cftimeindex=enable_cftimeindex): + actual = utils.safe_cast_to_index(np.array(dates)) + assert_array_equal(expected, actual) + assert expected.dtype == actual.dtype + assert isinstance(actual, type(expected)) + + +# Test that datetime.datetime objects are never used in a CFTimeIndex +@pytest.mark.skipif(not has_cftime_or_netCDF4, reason='cftime not installed') 
+@pytest.mark.parametrize('enable_cftimeindex', [False, True]) +def test_safe_cast_to_index_datetime_datetime(enable_cftimeindex): + dates = [datetime(1, 1, day) for day in range(1, 20)] + + expected = pd.Index(dates) + with set_options(enable_cftimeindex=enable_cftimeindex): + actual = utils.safe_cast_to_index(np.array(dates)) + assert_array_equal(expected, actual) + assert isinstance(actual, pd.Index) def test_multiindex_from_product_levels(): From 91ac573e00538e0372cf9e5f2fdc1528a4ee8cb8 Mon Sep 17 00:00:00 2001 From: Spencer Clark Date: Sun, 13 May 2018 07:56:54 -0400 Subject: [PATCH 10/61] Add cftime to doc/environment.yml (#2126) --- doc/environment.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/doc/environment.yml b/doc/environment.yml index 880151ab2d9..a7683ff1824 100644 --- a/doc/environment.yml +++ b/doc/environment.yml @@ -21,3 +21,4 @@ dependencies: - zarr - iris - flake8 + - cftime From f861186cbd11bdbfb2aab8289118a59283a2d7af Mon Sep 17 00:00:00 2001 From: Keisuke Fujii Date: Mon, 14 May 2018 07:37:47 +0900 Subject: [PATCH 11/61] Reduce pad size in rolling (#2125) --- asv_bench/benchmarks/rolling.py | 9 +++++---- doc/whats-new.rst | 5 ++++- xarray/core/dask_array_ops.py | 2 +- xarray/core/rolling.py | 4 ++-- 4 files changed, 12 insertions(+), 8 deletions(-) diff --git a/asv_bench/benchmarks/rolling.py b/asv_bench/benchmarks/rolling.py index 3f2a38104de..5ba7406f6e0 100644 --- a/asv_bench/benchmarks/rolling.py +++ b/asv_bench/benchmarks/rolling.py @@ -35,7 +35,7 @@ def setup(self, *args, **kwargs): @parameterized(['func', 'center'], (['mean', 'count'], [True, False])) def time_rolling(self, func, center): - getattr(self.ds.rolling(x=window, center=center), func)() + getattr(self.ds.rolling(x=window, center=center), func)().load() @parameterized(['func', 'pandas'], (['mean', 'count'], [True, False])) @@ -44,19 +44,20 @@ def time_rolling_long(self, func, pandas): se = self.da_long.to_series() getattr(se.rolling(window=window), func)() else: - getattr(self.da_long.rolling(x=window), func)() + getattr(self.da_long.rolling(x=window), func)().load() @parameterized(['window_', 'min_periods'], ([20, 40], [5, None])) def time_rolling_np(self, window_, min_periods): self.ds.rolling(x=window_, center=False, - min_periods=min_periods).reduce(getattr(np, 'nanmean')) + min_periods=min_periods).reduce( + getattr(np, 'nanmean')).load() @parameterized(['center', 'stride'], ([True, False], [1, 200])) def time_rolling_construct(self, center, stride): self.ds.rolling(x=window, center=center).construct( - 'window_dim', stride=stride).mean(dim='window_dim') + 'window_dim', stride=stride).mean(dim='window_dim').load() class RollingDask(Rolling): diff --git a/doc/whats-new.rst b/doc/whats-new.rst index fc5f8bf3266..49d39bacbf8 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -34,6 +34,9 @@ v0.10.4 (unreleased) Enhancements ~~~~~~~~~~~~ +- Slight modification in `rolling` with dask.array and bottleneck. Also, fixed a bug in rolling an + integer dask array. + By `Keisuke Fujii `_. - Add an option for using a ``CFTimeIndex`` for indexing times with non-standard calendars and/or outside the Timestamp-valid range; this index enables a subset of the functionality of a standard @@ -43,7 +46,7 @@ Enhancements - Allow for serialization of ``cftime.datetime`` objects (:issue:`789`, :issue:`1084`, :issue:`2008`, :issue:`1252`) using the standalone ``cftime`` library. By `Spencer Clark - `_. + `_. - Support writing lists of strings as netCDF attributes (:issue:`2044`). By `Dan Nowacki `_. 
- :py:meth:`~xarray.Dataset.to_netcdf(engine='h5netcdf')` now accepts h5py diff --git a/xarray/core/dask_array_ops.py b/xarray/core/dask_array_ops.py index ee87c3564cc..55ba1c1cbc6 100644 --- a/xarray/core/dask_array_ops.py +++ b/xarray/core/dask_array_ops.py @@ -19,7 +19,7 @@ def dask_rolling_wrapper(moving_func, a, window, min_count=None, axis=-1): if axis < 0: axis = a.ndim + axis depth = {d: 0 for d in range(a.ndim)} - depth[axis] = window - 1 + depth[axis] = (window + 1) // 2 boundary = {d: fill_value for d in range(a.ndim)} # create ghosted arrays ag = da.ghost.ghost(a, depth=depth, boundary=boundary) diff --git a/xarray/core/rolling.py b/xarray/core/rolling.py index f54a4c36631..24ed280b19e 100644 --- a/xarray/core/rolling.py +++ b/xarray/core/rolling.py @@ -294,8 +294,8 @@ def wrapped_func(self, **kwargs): if isinstance(padded.data, dask_array_type): # Workaround to make the padded chunk size is larger than # self.window-1 - shift = - (self.window - 1) - offset = -shift - self.window // 2 + shift = - (self.window + 1) // 2 + offset = (self.window - 1) // 2 valid = (slice(None), ) * axis + ( slice(offset, offset + self.obj.shape[axis]), ) else: From d1b669ec7a1e9a0b9296855f71de72c975ec78e5 Mon Sep 17 00:00:00 2001 From: Edward Betts Date: Mon, 14 May 2018 17:09:03 +0100 Subject: [PATCH 12/61] Correct two github URLs. (#2130) --- doc/examples/multidimensional-coords.rst | 2 +- examples/xarray_multidimensional_coords.ipynb | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/doc/examples/multidimensional-coords.rst b/doc/examples/multidimensional-coords.rst index a54e6058921..eed818ba064 100644 --- a/doc/examples/multidimensional-coords.rst +++ b/doc/examples/multidimensional-coords.rst @@ -3,7 +3,7 @@ Working with Multidimensional Coordinates ========================================= -Author: `Ryan Abernathey `__ +Author: `Ryan Abernathey `__ Many datasets have *physical coordinates* which differ from their *logical coordinates*. Xarray provides several ways to plot and analyze diff --git a/examples/xarray_multidimensional_coords.ipynb b/examples/xarray_multidimensional_coords.ipynb index bed7e8b962f..6bd942c5ba2 100644 --- a/examples/xarray_multidimensional_coords.ipynb +++ b/examples/xarray_multidimensional_coords.ipynb @@ -6,7 +6,7 @@ "source": [ "# Working with Multidimensional Coordinates\n", "\n", - "Author: [Ryan Abernathey](http://github.org/rabernat)\n", + "Author: [Ryan Abernathey](https://github.com/rabernat)\n", "\n", "Many datasets have _physical coordinates_ which differ from their _logical coordinates_. Xarray provides several ways to plot and analyze such datasets." 
] From 188141fe97a5effacf32f2508fd05b644c720e5d Mon Sep 17 00:00:00 2001 From: Spencer Clark Date: Mon, 14 May 2018 15:17:36 -0400 Subject: [PATCH 13/61] Fix datetime.timedelta casting bug in coding.times.infer_datetime_units (#2128) * Fix #2127 * Fix typo in time-series.rst * Use pd.to_timedelta to convert to np.timedelta64 objects * Install cftime through netcdf4 through pip * box=False --- ci/requirements-py27-windows.yml | 3 +-- doc/time-series.rst | 2 +- xarray/coding/times.py | 6 +++++- xarray/tests/test_coding_times.py | 3 +++ 4 files changed, 10 insertions(+), 4 deletions(-) diff --git a/ci/requirements-py27-windows.yml b/ci/requirements-py27-windows.yml index 7562874785b..967b7c584b9 100644 --- a/ci/requirements-py27-windows.yml +++ b/ci/requirements-py27-windows.yml @@ -8,7 +8,6 @@ dependencies: - h5py - h5netcdf - matplotlib - - netcdf4 - pathlib2 - pytest - flake8 @@ -21,4 +20,4 @@ dependencies: - rasterio - zarr - pip: - - cftime + - netcdf4 diff --git a/doc/time-series.rst b/doc/time-series.rst index 5b857789629..a7ce9226d4d 100644 --- a/doc/time-series.rst +++ b/doc/time-series.rst @@ -73,7 +73,7 @@ native representation of dates to those that fall between the years 1678 and returned as arrays of ``cftime.datetime`` objects and a ``CFTimeIndex`` can be used for indexing. The ``CFTimeIndex`` enables only a subset of the indexing functionality of a ``pandas.DatetimeIndex`` and is only enabled -when using standalone version of ``cftime`` (not the version packaged with +when using the standalone version of ``cftime`` (not the version packaged with earlier versions ``netCDF4``). See :ref:`CFTimeIndex` for more information. Datetime indexing diff --git a/xarray/coding/times.py b/xarray/coding/times.py index 61314d9cbe6..d946e2ed378 100644 --- a/xarray/coding/times.py +++ b/xarray/coding/times.py @@ -253,7 +253,11 @@ def infer_datetime_units(dates): else: reference_date = dates[0] if len(dates) > 0 else '1970-01-01' reference_date = format_cftime_datetime(reference_date) - unique_timedeltas = np.unique(np.diff(dates)).astype('timedelta64[ns]') + unique_timedeltas = np.unique(np.diff(dates)) + if unique_timedeltas.dtype == np.dtype('O'): + # Convert to np.timedelta64 objects using pandas to work around a + # NumPy casting bug: https://github.com/numpy/numpy/issues/11096 + unique_timedeltas = pd.to_timedelta(unique_timedeltas, box=False) units = _infer_time_units_from_diff(unique_timedeltas) return '%s since %s' % (units, reference_date) diff --git a/xarray/tests/test_coding_times.py b/xarray/tests/test_coding_times.py index 7c1e869f772..6329e91ac78 100644 --- a/xarray/tests/test_coding_times.py +++ b/xarray/tests/test_coding_times.py @@ -587,6 +587,9 @@ def test_infer_cftime_datetime_units(): 'seconds since 1900-01-01 00:00:00.000000'), ([date_type(1900, 1, 1), date_type(1900, 1, 2, 0, 0, 0, 5)], + 'days since 1900-01-01 00:00:00.000000'), + ([date_type(1900, 1, 1), date_type(1900, 1, 8), + date_type(1900, 1, 16)], 'days since 1900-01-01 00:00:00.000000')]: assert expected == coding.times.infer_datetime_units(dates) From 29d608af6694b37feac48cf369fa547d9fe2d00a Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Mon, 14 May 2018 15:04:30 -0600 Subject: [PATCH 14/61] Add "awesome xarray" list to faq. (#2118) * Add "awesome xarray" list to faq. * Add whats new entry + bugfix earlier entry. * Fixes + add xrft, xmitgcm. * Add link to xarray github topic. * Add more links plus add some organization. * Remove "points of reference" sentence * Remove subheadings under geosciences. 
* whats-new bugfix.
---
 doc/faq.rst       | 63 +++++++++++++++++++++++++++++++++++++++++++
 doc/internals.rst | 17 -------------
 doc/whats-new.rst | 10 ++++++--
 3 files changed, 71 insertions(+), 19 deletions(-)

diff --git a/doc/faq.rst b/doc/faq.rst
index 68670d0f5a4..46f1e20f4e8 100644
--- a/doc/faq.rst
+++ b/doc/faq.rst
@@ -157,6 +157,69 @@ and CDAT have some great domain specific functionality, and we would love to
 have support for converting their native objects to and from xarray (see
 :issue:`37` and :issue:`133`)
 
+
+What other projects leverage xarray?
+------------------------------------
+
+Here are several existing libraries that build functionality upon xarray.
+
+Geosciences
+~~~~~~~~~~~
+
+- `aospy `_: Automated analysis and management of gridded climate data.
+- `infinite-diff `_: xarray-based finite-differencing, focused on gridded climate/meteorology data
+- `marc_analysis `_: Analysis package for CESM/MARC experiments and output.
+- `MPAS-Analysis `_: Analysis for simulations produced with Model for Prediction Across Scales (MPAS) components and the Accelerated Climate Model for Energy (ACME).
+- `OGGM `_: Open Global Glacier Model
+- `Oocgcm `_: Analysis of large gridded geophysical datasets
+- `Open Data Cube `_: Analysis toolkit of continental scale Earth Observation data from satellites.
+- `Pangaea `_: xarray extension for gridded land surface & weather model output.
+- `Pangeo `_: A community effort for big data geoscience in the cloud.
+- `PyGDX `_: Python 3 package for
+  accessing data stored in GAMS Data eXchange (GDX) files. Also uses a custom
+  subclass.
+- `Regionmask `_: plotting and creation of masks of spatial regions
+- `salem `_: Adds geolocalised subsetting, masking, and plotting operations to xarray's data structures via accessors.
+- `Spyfit `_: FTIR spectroscopy of the atmosphere
+- `windspharm `_: Spherical
+  harmonic wind analysis in Python.
+- `wrf-python `_: A collection of diagnostic and interpolation routines for use with output of the Weather Research and Forecasting (WRF-ARW) Model.
+- `xarray-simlab `_: xarray extension for computer model simulations.
+- `xarray-topo `_: xarray extension for topographic analysis and modelling.
+- `xbpch `_: xarray interface for bpch files.
+- `xESMF `_: Universal Regridder for Geospatial Data.
+- `xgcm `_: Extends the xarray data model to understand finite volume grid cells (common in General Circulation Models) and provides interpolation and difference operations for such grids.
+- `xmitgcm `_: a python package for reading `MITgcm `_ binary MDS files into xarray data structures.
+- `xshape `_: Tools for working with shapefiles, topographies, and polygons in xarray.
+
+Machine Learning
+~~~~~~~~~~~~~~~~
+- `cesium `_: machine learning for time series analysis
+- `Elm `_: Parallel machine learning on xarray data structures
+- `sklearn-xarray (1) `_: Combines scikit-learn and xarray (1).
+- `sklearn-xarray (2) `_: Combines scikit-learn and xarray (2).
+
+Extend xarray capabilities
+~~~~~~~~~~~~~~~~~~~~~~~~~~
+- `Collocate `_: Collocate xarray trajectories in arbitrary physical dimensions
+- `eofs `_: EOF analysis in Python.
+- `xarray_extras `_: Advanced algorithms for xarray objects (e.g. integrations/interpolations).
+- `xrft `_: Fourier transforms for xarray data.
+- `xr-scipy `_: A lightweight scipy wrapper for xarray.
+- `X-regression `_: Multiple linear regression from Statsmodels library coupled with Xarray library.
+ +Visualization +~~~~~~~~~~~~~ +- `Datashader `_, `geoviews `_, `holoviews `_, : visualization packages for large data +- `psyplot `_: Interactive data visualization with python. + +Other +~~~~~ +- `ptsa `_: EEG Time Series Analysis +- `pycalphad `_: Computational Thermodynamics in Python + +More projects can be found at the `"xarray" Github topic `_. + How should I cite xarray? ------------------------- diff --git a/doc/internals.rst b/doc/internals.rst index e5e14896472..170e2d0b0cc 100644 --- a/doc/internals.rst +++ b/doc/internals.rst @@ -130,20 +130,3 @@ To help users keep things straight, please `let us know `_ if you plan to write a new accessor for an open source library. In the future, we will maintain a list of accessors and the libraries that implement them on this page. - -Here are several existing libraries that build functionality upon xarray. -They may be useful points of reference for your work: - -- `xgcm `_: General Circulation Model - Postprocessing. Uses subclassing and custom xarray backends. -- `PyGDX `_: Python 3 package for - accessing data stored in GAMS Data eXchange (GDX) files. Also uses a custom - subclass. -- `windspharm `_: Spherical - harmonic wind analysis in Python. -- `eofs `_: EOF analysis in Python. -- `salem `_: Adds geolocalised subsetting, - masking, and plotting operations to xarray's data structures via accessors. - -.. TODO: consider adding references to these projects somewhere more prominent -.. in the documentation? maybe the FAQ page? diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 49d39bacbf8..218df9b9707 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -31,6 +31,12 @@ What's New v0.10.4 (unreleased) -------------------- +Documentation +~~~~~~~~~~~~~ +- `FAQ `_ now lists projects that leverage xarray. + By `Deepak Cherian `_. + + Enhancements ~~~~~~~~~~~~ @@ -58,6 +64,8 @@ Enhancements This greatly boosts speed and allows chunking on the core dims. The function now requires dask >= 0.17.3 to work on dask-backed data (:issue:`2074`). By `Guido Imperiale `_. +- ``plot.line()`` learned new kwargs: ``xincrease``, ``yincrease`` that change the direction of the respective axes. + By `Deepak Cherian `_. Bug fixes ~~~~~~~~~ @@ -76,8 +84,6 @@ Bug fixes By `Stephan Hoyer `_. - ``plot.line()`` does not call ``autofmt_xdate()`` anymore. Instead it changes the rotation and horizontal alignment of labels without removing the x-axes of any other subplots in the figure (if any). By `Deepak Cherian `_. -- ``plot.line()`` learned new kwargs: ``xincrease``, ``yincrease`` that change the direction of the respective axes. - By `Deepak Cherian `_. - Colorbar limits are now determined by excluding ±Infs too. By `Deepak Cherian `_. From 218ad549a25fb30b836aabdfdda412450fdc9585 Mon Sep 17 00:00:00 2001 From: crusaderky Date: Mon, 14 May 2018 22:06:37 +0100 Subject: [PATCH 15/61] xarray.dot to pass **kwargs to einsum (#2106) * Support for optimize, split_every, etc. * Avoid einsum params that aren't ubiquitously supported * Fix tests for einsum params * Stickler fix * Reinstate test for invalid parameters --- xarray/core/computation.py | 8 ++++---- xarray/tests/test_computation.py | 25 +++++++++++++++++++++++-- 2 files changed, 27 insertions(+), 6 deletions(-) diff --git a/xarray/core/computation.py b/xarray/core/computation.py index 77a52ac055d..f6e22dfe6c1 100644 --- a/xarray/core/computation.py +++ b/xarray/core/computation.py @@ -952,6 +952,9 @@ def dot(*arrays, **kwargs): dims: str or tuple of strings, optional Which dimensions to sum over. 
If not speciified, then all the common dimensions are summed over. + **kwargs: dict + Additional keyword arguments passed to numpy.einsum or + dask.array.einsum Returns ------- @@ -976,9 +979,6 @@ def dot(*arrays, **kwargs): from .variable import Variable dims = kwargs.pop('dims', None) - if len(kwargs) > 0: - raise TypeError('Invalid keyward arguments {} are given'.format( - list(kwargs.keys()))) if any(not isinstance(arr, (Variable, DataArray)) for arr in arrays): raise TypeError('Only xr.DataArray and xr.Variable are supported.' @@ -1024,7 +1024,7 @@ def dot(*arrays, **kwargs): # subscripts should be passed to np.einsum as arg, not as kwargs. We need # to construct a partial function for apply_ufunc to work. - func = functools.partial(duck_array_ops.einsum, subscripts) + func = functools.partial(duck_array_ops.einsum, subscripts, **kwargs) result = apply_ufunc(func, *arrays, input_core_dims=input_core_dims, output_core_dims=output_core_dims, diff --git a/xarray/tests/test_computation.py b/xarray/tests/test_computation.py index c84ed17bfd3..c829453cc9d 100644 --- a/xarray/tests/test_computation.py +++ b/xarray/tests/test_computation.py @@ -1,5 +1,6 @@ import functools import operator +import pickle from collections import OrderedDict from distutils.version import LooseVersion @@ -842,13 +843,33 @@ def test_dot(use_dask): assert actual.dims == ('b', ) assert (actual.data == np.zeros(actual.shape)).all() - with pytest.raises(TypeError): - xr.dot(da_a, dims='a', invalid=None) + # Invalid cases + if not use_dask or LooseVersion(dask.__version__) > LooseVersion('0.17.4'): + with pytest.raises(TypeError): + xr.dot(da_a, dims='a', invalid=None) with pytest.raises(TypeError): xr.dot(da_a.to_dataset(name='da'), dims='a') with pytest.raises(TypeError): xr.dot(dims='a') + # einsum parameters + actual = xr.dot(da_a, da_b, dims=['b'], order='C') + assert (actual.data == np.einsum('ij,ijk->ik', a, b)).all() + assert actual.values.flags['C_CONTIGUOUS'] + assert not actual.values.flags['F_CONTIGUOUS'] + actual = xr.dot(da_a, da_b, dims=['b'], order='F') + assert (actual.data == np.einsum('ij,ijk->ik', a, b)).all() + # dask converts Fortran arrays to C order when merging the final array + if not use_dask: + assert not actual.values.flags['C_CONTIGUOUS'] + assert actual.values.flags['F_CONTIGUOUS'] + + # einsum has a constant string as of the first parameter, which makes + # it hard to pass to xarray.apply_ufunc. + # make sure dot() uses functools.partial(einsum, subscripts), which + # can be pickled, and not a lambda, which can't. + pickle.loads(pickle.dumps(xr.dot(da_a))) + def test_where(): cond = xr.DataArray([True, False], dims='x') From f4ef34f00902a55f65b82c998d29a4ab8f5b6bf0 Mon Sep 17 00:00:00 2001 From: Alex Hilson Date: Mon, 14 May 2018 23:53:21 +0100 Subject: [PATCH 16/61] Fix to_iris conversion issues (#2111) * TST: assert lazy array maintained by to_iris (#2046) * Add masked_invalid array op, resolves to_iris rechunking issue (#2046) * Fix dask_module in duck_array_ops.masked_invalid * Really fix it * Resolving to_iris dask array issues --- doc/whats-new.rst | 2 ++ xarray/convert.py | 10 ++-------- xarray/core/duck_array_ops.py | 4 ++++ xarray/tests/test_dataarray.py | 9 ++++----- 4 files changed, 12 insertions(+), 13 deletions(-) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 218df9b9707..520e38bd80f 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -86,6 +86,8 @@ Bug fixes By `Deepak Cherian `_. - Colorbar limits are now determined by excluding ±Infs too. 
By `Deepak Cherian `_. +- Fixed ``to_iris`` to maintain lazy dask array after conversion (:issue:`2046`). + By `Alex Hilson `_ and `Stephan Hoyer `_. .. _whats-new.0.10.3: diff --git a/xarray/convert.py b/xarray/convert.py index a6defd083bf..a3c99119306 100644 --- a/xarray/convert.py +++ b/xarray/convert.py @@ -6,6 +6,7 @@ from .coding.times import CFDatetimeCoder, CFTimedeltaCoder from .conventions import decode_cf +from .core import duck_array_ops from .core.dataarray import DataArray from .core.dtypes import get_fill_value from .core.pycompat import OrderedDict, range @@ -94,7 +95,6 @@ def to_iris(dataarray): # Iris not a hard dependency import iris from iris.fileformats.netcdf import parse_cell_methods - from xarray.core.pycompat import dask_array_type dim_coords = [] aux_coords = [] @@ -121,13 +121,7 @@ def to_iris(dataarray): args['cell_methods'] = \ parse_cell_methods(dataarray.attrs['cell_methods']) - # Create the right type of masked array (should be easier after #1769) - if isinstance(dataarray.data, dask_array_type): - from dask.array import ma as dask_ma - masked_data = dask_ma.masked_invalid(dataarray) - else: - masked_data = np.ma.masked_invalid(dataarray) - + masked_data = duck_array_ops.masked_invalid(dataarray.data) cube = iris.cube.Cube(masked_data, **args) return cube diff --git a/xarray/core/duck_array_ops.py b/xarray/core/duck_array_ops.py index 3a2c123f87e..ef52b4890ef 100644 --- a/xarray/core/duck_array_ops.py +++ b/xarray/core/duck_array_ops.py @@ -101,6 +101,10 @@ def isnull(data): einsum = _dask_or_eager_func('einsum', array_args=slice(1, None), requires_dask='0.17.3') +masked_invalid = _dask_or_eager_func( + 'masked_invalid', eager_module=np.ma, + dask_module=getattr(dask_array, 'ma', None)) + def asarray(data): return data if isinstance(data, dask_array_type) else np.asarray(data) diff --git a/xarray/tests/test_dataarray.py b/xarray/tests/test_dataarray.py index b16cb8ddcea..22bfecebe3c 100644 --- a/xarray/tests/test_dataarray.py +++ b/xarray/tests/test_dataarray.py @@ -3038,12 +3038,11 @@ def test_to_and_from_iris_dask(self): roundtripped = DataArray.from_iris(actual) assert_identical(original, roundtripped) - # If the Iris version supports it then we should get a dask array back + # If the Iris version supports it then we should have a dask array + # at each stage of the conversion if hasattr(actual, 'core_data'): - pass - # TODO This currently fails due to the decoding loading - # the data (#1372) - # self.assertEqual(type(original.data), type(roundtripped.data)) + self.assertEqual(type(original.data), type(actual.core_data())) + self.assertEqual(type(original.data), type(roundtripped.data)) actual.remove_coord('time') auto_time_dimension = DataArray.from_iris(actual) From 9a48157b525d9e346e73f358a99ceb52717fd3ea Mon Sep 17 00:00:00 2001 From: Keisuke Fujii Date: Wed, 16 May 2018 01:39:22 +0900 Subject: [PATCH 17/61] Raise an Error if a coordinate with wrong size is assigned to a dataarray (#2124) * fix * Fix DataArrayCoordinates._update_coords * Update misleading comments --- doc/whats-new.rst | 5 ++++- xarray/core/coordinates.py | 11 +++++++++-- xarray/tests/test_dataarray.py | 7 +++++++ 3 files changed, 20 insertions(+), 3 deletions(-) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 520e38bd80f..98116aa2a95 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -70,8 +70,11 @@ Enhancements Bug fixes ~~~~~~~~~ +- Now raises an Error if a coordinate with wrong size is assigned to a + :py:class:`~xarray.DataArray`. 
(:issue:`2112`) + By `Keisuke Fujii `_. - Fixed a bug in `rolling` with bottleneck. Also, fixed a bug in rolling an - integer dask array. (:issue:`21133`) + integer dask array. (:issue:`2113`) By `Keisuke Fujii `_. - Fixed a bug where `keep_attrs=True` flag was neglected if :py:func:`apply_func` was used with :py:class:`Variable`. (:issue:`2114`) diff --git a/xarray/core/coordinates.py b/xarray/core/coordinates.py index 92d717b9f62..cb22c0b687b 100644 --- a/xarray/core/coordinates.py +++ b/xarray/core/coordinates.py @@ -9,10 +9,15 @@ from .merge import ( expand_and_merge_variables, merge_coords, merge_coords_for_inplace_math) from .pycompat import OrderedDict -from .utils import Frozen +from .utils import Frozen, ReprObject from .variable import Variable +# Used as the key corresponding to a DataArray's variable when converting +# arbitrary DataArray objects to datasets +_THIS_ARRAY = ReprObject('') + + class AbstractCoordinates(Mapping, formatting.ReprMixin): def __getitem__(self, key): raise NotImplementedError @@ -225,7 +230,9 @@ def __getitem__(self, key): def _update_coords(self, coords): from .dataset import calculate_dimensions - dims = calculate_dimensions(coords) + coords_plus_data = coords.copy() + coords_plus_data[_THIS_ARRAY] = self._data.variable + dims = calculate_dimensions(coords_plus_data) if not set(dims) <= set(self.dims): raise ValueError('cannot add coordinates with new dimensions to ' 'a DataArray') diff --git a/xarray/tests/test_dataarray.py b/xarray/tests/test_dataarray.py index 22bfecebe3c..f2e076db78a 100644 --- a/xarray/tests/test_dataarray.py +++ b/xarray/tests/test_dataarray.py @@ -1180,6 +1180,13 @@ def test_assign_coords(self): with raises_regex(ValueError, 'conflicting MultiIndex'): self.mda.assign_coords(level_1=range(4)) + # GH: 2112 + da = xr.DataArray([0, 1, 2], dims='x') + with pytest.raises(ValueError): + da['x'] = [0, 1, 2, 3] # size conflict + with pytest.raises(ValueError): + da.coords['x'] = [0, 1, 2, 3] # size conflict + def test_coords_alignment(self): lhs = DataArray([1, 2, 3], [('x', [0, 1, 2])]) rhs = DataArray([2, 3, 4], [('x', [1, 2, 3])]) From 3df3023c10ede416054bc8282ded858ba736424e Mon Sep 17 00:00:00 2001 From: chiaral Date: Tue, 15 May 2018 22:13:25 -0400 Subject: [PATCH 18/61] DOC: Added text to Assign values with Indexing (#2133) * DOC: Added text to Assign values with Indexing * DOC: Added Warning to xarray.DataArray.sel * DOC: Added Warning to xarray.DataArray.sel fixed length * DOC: Added info on whats-new --- doc/indexing.rst | 32 +++++++++++++++++++++++++++++++- doc/whats-new.rst | 2 ++ xarray/core/dataarray.py | 13 +++++++++++++ 3 files changed, 46 insertions(+), 1 deletion(-) diff --git a/doc/indexing.rst b/doc/indexing.rst index 1f6ae006cf7..cec438dd2e4 100644 --- a/doc/indexing.rst +++ b/doc/indexing.rst @@ -398,7 +398,37 @@ These methods may and also be applied to ``Dataset`` objects Assigning values with indexing ------------------------------ -Vectorized indexing can be used to assign values to xarray object. +To select and assign values to a portion of a :py:meth:`~xarray.DataArray` you +can use indexing with ``.loc`` : + +.. 
ipython:: python + + ds = xr.tutorial.load_dataset('air_temperature') + + #add an empty 2D dataarray + ds['empty']= xr.full_like(ds.air.mean('time'),fill_value=0) + + #modify one grid point using loc() + ds['empty'].loc[dict(lon=260, lat=30)] = 100 + + #modify a 2D region using loc() + lc = ds.coords['lon'] + la = ds.coords['lat'] + ds['empty'].loc[dict(lon=lc[(lc>220)&(lc<260)], lat=la[(la>20)&(la<60)])] = 100 + +or :py:meth:`~xarray.where`: + +.. ipython:: python + + #modify one grid point using xr.where() + ds['empty'] = xr.where((ds.coords['lat']==20)&(ds.coords['lon']==260), 100, ds['empty']) + + #or modify a 2D region using xr.where() + mask = (ds.coords['lat']>20)&(ds.coords['lat']<60)&(ds.coords['lon']>220)&(ds.coords['lon']<260) + ds['empty'] = xr.where(mask, 100, ds['empty']) + + +Vectorized indexing can also be used to assign values to xarray object. .. ipython:: python diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 98116aa2a95..0d9e75ba940 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -35,6 +35,8 @@ Documentation ~~~~~~~~~~~~~ - `FAQ `_ now lists projects that leverage xarray. By `Deepak Cherian `_. +- `Assigning values with indexing `_ now includes examples on how to select and assign values to a :py:class:`~xarray.DataArray`. + By `Chiara Lepore `_. Enhancements diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index 1ceaced5961..fc7091dad85 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -758,10 +758,23 @@ def sel(self, method=None, tolerance=None, drop=False, **indexers): """Return a new DataArray whose dataset is given by selecting index labels along the specified dimension(s). + .. warning:: + + Do not try to assign values when using any of the indexing methods + ``isel`` or ``sel``:: + + da = xr.DataArray([0, 1, 2, 3], dims=['x']) + # DO NOT do this + da.isel(x=[0, 1, 2])[1] = -1 + + Assigning values with the chained indexing using ``.sel`` or + ``.isel`` fails silently. + See Also -------- Dataset.sel DataArray.isel + """ ds = self._to_temp_dataset().sel(drop=drop, method=method, tolerance=tolerance, **indexers) From 9f58d509a432f18d9ceb69bdc0808f2cb9b77f6c Mon Sep 17 00:00:00 2001 From: Stephan Hoyer Date: Tue, 15 May 2018 21:02:46 -0700 Subject: [PATCH 19/61] Fix test suite with pandas 0.23 (#2136) * Fix test suite with pandas 0.23 * Disable -OO check --- .travis.yml | 4 +++- xarray/tests/test_dataset.py | 4 +++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/.travis.yml b/.travis.yml index e375f6fb063..bd53edb0029 100644 --- a/.travis.yml +++ b/.travis.yml @@ -101,7 +101,9 @@ install: - python xarray/util/print_versions.py script: - - python -OO -c "import xarray" + # TODO: restore this check once the upstream pandas issue is fixed: + # https://github.com/pandas-dev/pandas/issues/21071 + # - python -OO -c "import xarray" - if [[ "$CONDA_ENV" == "docs" ]]; then conda install -c conda-forge sphinx sphinx_rtd_theme sphinx-gallery numpydoc; sphinx-build -n -j auto -b html -d _build/doctrees doc _build/html; diff --git a/xarray/tests/test_dataset.py b/xarray/tests/test_dataset.py index b99f7ea1eec..3335a55e4ab 100644 --- a/xarray/tests/test_dataset.py +++ b/xarray/tests/test_dataset.py @@ -1848,7 +1848,9 @@ def test_drop_index_labels(self): expected = data.isel(x=slice(0, 0)) assert_identical(expected, actual) - with pytest.raises(ValueError): + # This exception raised by pandas changed from ValueError -> KeyError + # in pandas 0.23. 
+ with pytest.raises((ValueError, KeyError)): # not contained in axis data.drop(['c'], dim='x') From 8ef194f2e6f2e68f1f818606d6362ddfe801df1e Mon Sep 17 00:00:00 2001 From: Joe Hamman Date: Wed, 16 May 2018 11:05:02 -0400 Subject: [PATCH 20/61] WIP: Compute==False for to_zarr and to_netcdf (#1811) * move backend append logic to the prepare_variable methods * deprecate variables/dimensions/attrs properties on AbstractWritableDataStore * warnings instead of errors for backend properties * use attrs.update when setting zarr attributes * more performance improvements to attributes in zarr backend * fix typo * new set_dimensions method for writable data stores * more fixes for zarr * more tests for zarr and remove append logic for zarr * more tests for zarr and remove append logic for zarr * a few more tweaks to zarr attrs * Add encode methods to writable data stores, fixes for Zarr tests * fix for InMemoryDataStore * fix for unlimited dimensions Scipy Datastores * another patch for scipy * whatsnew * initial commit returning dask futures from to_netcdf and to_zarr methods * ordereddict * address some of rabernats comments, in particular, this commit removes the _DIMENSION_KEY from the zarr_group.attrs * stop skipping zero-dim zarr tests * update minimum zarr version for tests * cleanup a bit before adding tests * tempoary checkin * cleanup implementation of compute=False for to_foo functions, still needs additional tests * docs and more tests, failing tests on h5netcdf backend only * skip h5netcdf/netcdf4 tests in certain places * remove spurious returns * finalize stores when compute=False * more docs, skip h5netcdf netcdf tests, raise informative error for h5netcdf and scipy * cleanup whats-new * reorg dask task graph when using compute=False and save_mfdataset * move compute_false tests to DaskTests class * small doc/style fixes * save api.py --- doc/dask.rst | 15 ++++++++++ doc/whats-new.rst | 19 ++++++++---- xarray/backends/api.py | 42 ++++++++++++++++++++------ xarray/backends/common.py | 12 +++++--- xarray/backends/h5netcdf_.py | 7 +++-- xarray/backends/netCDF4_.py | 4 +-- xarray/backends/scipy_.py | 7 +++-- xarray/backends/zarr.py | 4 +-- xarray/core/dataset.py | 20 +++++++++---- xarray/tests/test_backends.py | 56 +++++++++++++++++++++++++++++++---- 10 files changed, 147 insertions(+), 39 deletions(-) diff --git a/doc/dask.rst b/doc/dask.rst index 8fc0f655023..2d4beea4f70 100644 --- a/doc/dask.rst +++ b/doc/dask.rst @@ -100,6 +100,21 @@ Once you've manipulated a dask array, you can still write a dataset too big to fit into memory back to disk by using :py:meth:`~xarray.Dataset.to_netcdf` in the usual way. +.. ipython:: python + + ds.to_netcdf('manipulated-example-data.nc') + +By setting the ``compute`` argument to ``False``, :py:meth:`~xarray.Dataset.to_netcdf` +will return a dask delayed object that can be computed later. + +.. ipython:: python + + from dask.diagnostics import ProgressBar + # or distributed.progress when using the distributed scheduler + delayed_obj = ds.to_netcdf('manipulated-example-data.nc', compute=False) + with ProgressBar(): + results = delayed_obj.compute() + .. note:: When using dask's distributed scheduler to write NETCDF4 files, diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 0d9e75ba940..1b696c4486d 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -69,6 +69,19 @@ Enhancements - ``plot.line()`` learned new kwargs: ``xincrease``, ``yincrease`` that change the direction of the respective axes. By `Deepak Cherian `_. 
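For instance, the new axis-direction keywords could be used along these lines
(a minimal sketch; the tutorial dataset and variable names are illustrative only)::

    import xarray as xr

    ds = xr.tutorial.load_dataset('air_temperature')
    # plot a single time series with the y axis flipped
    ds.air.isel(lon=10, lat=10).plot.line(yincrease=False)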
+- Added the ``parallel`` option to :py:func:`open_mfdataset`. This option uses + ``dask.delayed`` to parallelize the open and preprocessing steps within + ``open_mfdataset``. This is expected to provide performance improvements when + opening many files, particularly when used in conjunction with dask's + multiprocessing or distributed schedulers (:issue:`1981`). + By `Joe Hamman `_. + +- New ``compute`` option in :py:meth:`~xarray.Dataset.to_netcdf`, + :py:meth:`~xarray.Dataset.to_zarr`, and :py:func:`~xarray.save_mfdataset` to + allow for the lazy computation of netCDF and zarr stores. This feature is + currently only supported by the netCDF4 and zarr backends. (:issue:`1784`). + By `Joe Hamman `_. + Bug fixes ~~~~~~~~~ @@ -104,12 +117,6 @@ The minor release includes a number of bug-fixes and backwards compatible enhanc Enhancements ~~~~~~~~~~~~ -- Added the ``parallel`` option to :py:func:`open_mfdataset`. This option uses - ``dask.delayed`` to parallelize the open and preprocessing steps within - ``open_mfdataset``. This is expected to provide performance improvements when - opening many files, particularly when used in conjunction with dask's - multiprocessing or distributed schedulers (:issue:`1981`). - By `Joe Hamman `_. - :py:meth:`~xarray.DataArray.isin` and :py:meth:`~xarray.Dataset.isin` methods, which test each value in the array for whether it is contained in the supplied list, returning a bool array. See :ref:`selecting values with isin` diff --git a/xarray/backends/api.py b/xarray/backends/api.py index b8cfa3c926a..dec63a85d6e 100644 --- a/xarray/backends/api.py +++ b/xarray/backends/api.py @@ -144,6 +144,13 @@ def _get_lock(engine, scheduler, format, path_or_file): return lock +def _finalize_store(write, store): + """ Finalize this store by explicitly syncing and closing""" + del write # ensure writing is done first + store.sync() + store.close() + + def open_dataset(filename_or_obj, group=None, decode_cf=True, mask_and_scale=True, decode_times=True, autoclose=False, concat_characters=True, decode_coords=True, engine=None, @@ -620,7 +627,8 @@ def open_mfdataset(paths, chunks=None, concat_dim=_CONCAT_DIM_DEFAULT, def to_netcdf(dataset, path_or_file=None, mode='w', format=None, group=None, - engine=None, writer=None, encoding=None, unlimited_dims=None): + engine=None, writer=None, encoding=None, unlimited_dims=None, + compute=True): """This function creates an appropriate datastore for writing a dataset to disk as a netCDF file @@ -680,19 +688,22 @@ def to_netcdf(dataset, path_or_file=None, mode='w', format=None, group=None, unlimited_dims = dataset.encoding.get('unlimited_dims', None) try: dataset.dump_to_store(store, sync=sync, encoding=encoding, - unlimited_dims=unlimited_dims) + unlimited_dims=unlimited_dims, compute=compute) if path_or_file is None: return target.getvalue() finally: if sync and isinstance(path_or_file, basestring): store.close() + if not compute: + import dask + return dask.delayed(_finalize_store)(store.delayed_store, store) + if not sync: return store - def save_mfdataset(datasets, paths, mode='w', format=None, groups=None, - engine=None): + engine=None, compute=True): """Write multiple datasets to disk as netCDF files simultaneously. This function is intended for use with datasets consisting of dask.array @@ -742,6 +753,9 @@ def save_mfdataset(datasets, paths, mode='w', format=None, groups=None, default engine is chosen based on available dependencies, with a preference for 'netcdf4' if writing to a file on disk. 
See `Dataset.to_netcdf` for additional information. + compute: boolean + If true compute immediately, otherwise return a + ``dask.delayed.Delayed`` object that can be computed later. Examples -------- @@ -769,11 +783,17 @@ def save_mfdataset(datasets, paths, mode='w', format=None, groups=None, 'datasets, paths and groups arguments to ' 'save_mfdataset') - writer = ArrayWriter() - stores = [to_netcdf(ds, path, mode, format, group, engine, writer) + writer = ArrayWriter() if compute else None + stores = [to_netcdf(ds, path, mode, format, group, engine, writer, + compute=compute) for ds, path, group in zip(datasets, paths, groups)] + + if not compute: + import dask + return dask.delayed(stores) + try: - writer.sync() + delayed = writer.sync(compute=compute) for store in stores: store.sync() finally: @@ -782,7 +802,7 @@ def save_mfdataset(datasets, paths, mode='w', format=None, groups=None, def to_zarr(dataset, store=None, mode='w-', synchronizer=None, group=None, - encoding=None): + encoding=None, compute=True): """This function creates an appropriate datastore for writing a dataset to a zarr ztore @@ -803,5 +823,9 @@ def to_zarr(dataset, store=None, mode='w-', synchronizer=None, group=None, # I think zarr stores should always be sync'd immediately # TODO: figure out how to properly handle unlimited_dims - dataset.dump_to_store(store, sync=True, encoding=encoding) + dataset.dump_to_store(store, sync=True, encoding=encoding, compute=compute) + + if not compute: + import dask + return dask.delayed(_finalize_store)(store.delayed_store, store) return store diff --git a/xarray/backends/common.py b/xarray/backends/common.py index 7d8aa8446a2..2961838e85f 100644 --- a/xarray/backends/common.py +++ b/xarray/backends/common.py @@ -264,12 +264,15 @@ def add(self, source, target): else: target[...] 
= source - def sync(self): + def sync(self, compute=True): if self.sources: import dask.array as da - da.store(self.sources, self.targets, lock=self.lock) + delayed_store = da.store(self.sources, self.targets, + lock=self.lock, compute=compute, + flush=True) self.sources = [] self.targets = [] + return delayed_store class AbstractWritableDataStore(AbstractDataStore): @@ -277,6 +280,7 @@ def __init__(self, writer=None, lock=HDF5_LOCK): if writer is None: writer = ArrayWriter(lock=lock) self.writer = writer + self.delayed_store = None def encode(self, variables, attributes): """ @@ -318,11 +322,11 @@ def set_attribute(self, k, v): # pragma: no cover def set_variable(self, k, v): # pragma: no cover raise NotImplementedError - def sync(self): + def sync(self, compute=True): if self._isopen and self._autoclose: # datastore will be reopened during write self.close() - self.writer.sync() + self.delayed_store = self.writer.sync(compute=compute) def store_dataset(self, dataset): """ diff --git a/xarray/backends/h5netcdf_.py b/xarray/backends/h5netcdf_.py index d34fa2d9267..f9e2b3dece1 100644 --- a/xarray/backends/h5netcdf_.py +++ b/xarray/backends/h5netcdf_.py @@ -212,9 +212,12 @@ def prepare_variable(self, name, variable, check_encoding=False, return target, variable.data - def sync(self): + def sync(self, compute=True): + if not compute: + raise NotImplementedError( + 'compute=False is not supported for the h5netcdf backend yet') with self.ensure_open(autoclose=True): - super(H5NetCDFStore, self).sync() + super(H5NetCDFStore, self).sync(compute=compute) self.ds.sync() def close(self): diff --git a/xarray/backends/netCDF4_.py b/xarray/backends/netCDF4_.py index a0f6cbcdd33..14061a0fb08 100644 --- a/xarray/backends/netCDF4_.py +++ b/xarray/backends/netCDF4_.py @@ -439,9 +439,9 @@ def prepare_variable(self, name, variable, check_encoding=False, return target, variable.data - def sync(self): + def sync(self, compute=True): with self.ensure_open(autoclose=True): - super(NetCDF4DataStore, self).sync() + super(NetCDF4DataStore, self).sync(compute=compute) self.ds.sync() def close(self): diff --git a/xarray/backends/scipy_.py b/xarray/backends/scipy_.py index ee2c0fbf106..cd84431f6b7 100644 --- a/xarray/backends/scipy_.py +++ b/xarray/backends/scipy_.py @@ -219,9 +219,12 @@ def prepare_variable(self, name, variable, check_encoding=False, return target, data - def sync(self): + def sync(self, compute=True): + if not compute: + raise NotImplementedError( + 'compute=False is not supported for the scipy backend yet') with self.ensure_open(autoclose=True): - super(ScipyDataStore, self).sync() + super(ScipyDataStore, self).sync(compute=compute) self.ds.flush() def close(self): diff --git a/xarray/backends/zarr.py b/xarray/backends/zarr.py index 83dcbd9a172..343690eaabd 100644 --- a/xarray/backends/zarr.py +++ b/xarray/backends/zarr.py @@ -342,8 +342,8 @@ def store(self, variables, attributes, *args, **kwargs): AbstractWritableDataStore.store(self, variables, attributes, *args, **kwargs) - def sync(self): - self.writer.sync() + def sync(self, compute=True): + self.delayed_store = self.writer.sync(compute=compute) def open_zarr(store, group=None, synchronizer=None, auto_chunk=True, diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index bdb2bf86990..a9ec8c16866 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -1055,7 +1055,7 @@ def reset_coords(self, names=None, drop=False, inplace=False): return obj def dump_to_store(self, store, encoder=None, sync=True, encoding=None, - 
unlimited_dims=None): + unlimited_dims=None, compute=True): """Store dataset contents to a backends.*DataStore object.""" if encoding is None: encoding = {} @@ -1074,10 +1074,11 @@ def dump_to_store(self, store, encoder=None, sync=True, encoding=None, store.store(variables, attrs, check_encoding, unlimited_dims=unlimited_dims) if sync: - store.sync() + store.sync(compute=compute) def to_netcdf(self, path=None, mode='w', format=None, group=None, - engine=None, encoding=None, unlimited_dims=None): + engine=None, encoding=None, unlimited_dims=None, + compute=True): """Write dataset contents to a netCDF file. Parameters @@ -1136,16 +1137,20 @@ def to_netcdf(self, path=None, mode='w', format=None, group=None, By default, no dimensions are treated as unlimited dimensions. Note that unlimited_dims may also be set via ``dataset.encoding['unlimited_dims']``. + compute: boolean + If true compute immediately, otherwise return a + ``dask.delayed.Delayed`` object that can be computed later. """ if encoding is None: encoding = {} from ..backends.api import to_netcdf return to_netcdf(self, path, mode, format=format, group=group, engine=engine, encoding=encoding, - unlimited_dims=unlimited_dims) + unlimited_dims=unlimited_dims, + compute=compute) def to_zarr(self, store=None, mode='w-', synchronizer=None, group=None, - encoding=None): + encoding=None, compute=True): """Write dataset contents to a zarr group. .. note:: Experimental @@ -1167,6 +1172,9 @@ def to_zarr(self, store=None, mode='w-', synchronizer=None, group=None, Nested dictionary with variable names as keys and dictionaries of variable specific encodings as values, e.g., ``{'my_variable': {'dtype': 'int16', 'scale_factor': 0.1,}, ...}`` + compute: boolean + If true compute immediately, otherwise return a + ``dask.delayed.Delayed`` object that can be computed later. """ if encoding is None: encoding = {} @@ -1176,7 +1184,7 @@ def to_zarr(self, store=None, mode='w-', synchronizer=None, group=None, "and 'w-'.") from ..backends.api import to_zarr return to_zarr(self, store=store, mode=mode, synchronizer=synchronizer, - group=group, encoding=encoding) + group=group, encoding=encoding, compute=compute) def __unicode__(self): return formatting.dataset_repr(self) diff --git a/xarray/tests/test_backends.py b/xarray/tests/test_backends.py index 2d4e5c0f261..95d92cd8b8a 100644 --- a/xarray/tests/test_backends.py +++ b/xarray/tests/test_backends.py @@ -159,8 +159,8 @@ def roundtrip_append(self, data, save_kwargs={}, open_kwargs={}, # The save/open methods may be overwritten below def save(self, dataset, path, **kwargs): - dataset.to_netcdf(path, engine=self.engine, format=self.file_format, - **kwargs) + return dataset.to_netcdf(path, engine=self.engine, + format=self.file_format, **kwargs) @contextlib.contextmanager def open(self, path, **kwargs): @@ -709,7 +709,7 @@ def test_roundtrip_endian(self): # should still pass though. 
assert_identical(ds, actual) - if isinstance(self, NetCDF4DataTest): + if self.engine == 'netcdf4': ds['z'].encoding['endian'] = 'big' with pytest.raises(NotImplementedError): with self.roundtrip(ds) as actual: @@ -902,7 +902,8 @@ def test_open_group(self): open_dataset(tmp_file, group=(1, 2, 3)) def test_open_subgroup(self): - # Create a netCDF file with a dataset within a group within a group + # Create a netCDF file with a dataset stored within a group within a + # group with create_tmp_file() as tmp_file: rootgrp = nc4.Dataset(tmp_file, 'w') foogrp = rootgrp.createGroup('foo') @@ -1232,7 +1233,7 @@ def create_store(self): yield backends.ZarrStore.open_group(store_target, mode='w') def save(self, dataset, store_target, **kwargs): - dataset.to_zarr(store=store_target, **kwargs) + return dataset.to_zarr(store=store_target, **kwargs) @contextlib.contextmanager def open(self, store_target, **kwargs): @@ -1419,6 +1420,19 @@ def test_append_overwrite_values(self): def test_append_with_invalid_dim_raises(self): super(CFEncodedDataTest, self).test_append_with_invalid_dim_raises() + def test_to_zarr_compute_false_roundtrip(self): + from dask.delayed import Delayed + + original = create_test_data().chunk() + + with self.create_zarr_target() as store: + delayed_obj = self.save(original, store, compute=False) + assert isinstance(delayed_obj, Delayed) + delayed_obj.compute() + + with self.open(store) as actual: + assert_identical(original, actual) + @requires_zarr class ZarrDictStoreTest(BaseZarrTest, TestCase): @@ -2227,6 +2241,36 @@ def test_dataarray_compute(self): self.assertTrue(computed._in_memory) assert_allclose(actual, computed, decode_bytes=False) + def test_to_netcdf_compute_false_roundtrip(self): + from dask.delayed import Delayed + + original = create_test_data().chunk() + + with create_tmp_file() as tmp_file: + # dataset, path, **kwargs): + delayed_obj = self.save(original, tmp_file, compute=False) + assert isinstance(delayed_obj, Delayed) + delayed_obj.compute() + + with self.open(tmp_file) as actual: + assert_identical(original, actual) + + def test_save_mfdataset_compute_false_roundtrip(self): + from dask.delayed import Delayed + + original = Dataset({'foo': ('x', np.random.randn(10))}).chunk() + datasets = [original.isel(x=slice(5)), + original.isel(x=slice(5, 10))] + with create_tmp_file() as tmp1: + with create_tmp_file() as tmp2: + delayed_obj = save_mfdataset(datasets, [tmp1, tmp2], + engine=self.engine, compute=False) + assert isinstance(delayed_obj, Delayed) + delayed_obj.compute() + with open_mfdataset([tmp1, tmp2], + autoclose=self.autoclose) as actual: + assert_identical(actual, original) + class DaskTestAutocloseTrue(DaskTest): autoclose = True @@ -2348,7 +2392,7 @@ def open(self, path, **kwargs): yield ds def save(self, dataset, path, **kwargs): - dataset.to_netcdf(path, engine='scipy', **kwargs) + return dataset.to_netcdf(path, engine='scipy', **kwargs) def test_weakrefs(self): example = Dataset({'foo': ('x', np.arange(5.0))}) From 4972dfd84d4e7ed31875b4257492ca84939eda4a Mon Sep 17 00:00:00 2001 From: Joe Hamman Date: Wed, 16 May 2018 15:47:59 -0400 Subject: [PATCH 21/61] expose CFTimeIndex to public API (#2141) * expose CFTimeIndex to public API * more docs --- doc/api.rst | 7 +++++++ xarray/__init__.py | 2 ++ xarray/coding/cftimeindex.py | 4 ++++ 3 files changed, 13 insertions(+) diff --git a/doc/api.rst b/doc/api.rst index bce4e0d1c8e..ff708dc4c1b 100644 --- a/doc/api.rst +++ b/doc/api.rst @@ -522,6 +522,13 @@ GroupByObjects core.groupby.DatasetGroupBy.apply 
core.groupby.DatasetGroupBy.reduce +Custom Indexes +============== +.. autosummary:: + :toctree: generated/ + + CFTimeIndex + Plotting ======== diff --git a/xarray/__init__.py b/xarray/__init__.py index 1a2bf3fe283..94e8029edbb 100644 --- a/xarray/__init__.py +++ b/xarray/__init__.py @@ -22,6 +22,8 @@ from .conventions import decode_cf, SerializationWarning +from .coding.cftimeindex import CFTimeIndex + try: from .version import version as __version__ except ImportError: # pragma: no cover diff --git a/xarray/coding/cftimeindex.py b/xarray/coding/cftimeindex.py index fb51ace5d69..5fca14ddbb1 100644 --- a/xarray/coding/cftimeindex.py +++ b/xarray/coding/cftimeindex.py @@ -135,6 +135,10 @@ def assert_all_valid_date_type(data): class CFTimeIndex(pd.Index): + """Custom Index for working with CF calendars and dates + + All elements of a CFTimeIndex must be cftime.datetime objects. + """ year = _field_accessor('year', 'The year of the datetime') month = _field_accessor('month', 'The month of the datetime') day = _field_accessor('day', 'The days of the datetime') From 954b8d0ce72eadba812821a2e64ae0ef4ceb2767 Mon Sep 17 00:00:00 2001 From: Stephan Hoyer Date: Wed, 16 May 2018 18:11:00 -0700 Subject: [PATCH 22/61] Doc updates for 0.10.4 release (#2138) * Doc updates for 0.10.4 release * Fix to_netcdf() with engine=h5netcdf entry in whatsnew --- doc/api.rst | 38 ++++++++++++++++++++++++++++++-------- doc/faq.rst | 14 +++++++++----- doc/whats-new.rst | 41 ++++++++++++++++++++++++----------------- 3 files changed, 63 insertions(+), 30 deletions(-) diff --git a/doc/api.rst b/doc/api.rst index ff708dc4c1b..a528496bb6a 100644 --- a/doc/api.rst +++ b/doc/api.rst @@ -496,6 +496,19 @@ DataArray methods DataArray.load DataArray.chunk +GroupBy objects +=============== + +.. autosummary:: + :toctree: generated/ + + core.groupby.DataArrayGroupBy + core.groupby.DataArrayGroupBy.apply + core.groupby.DataArrayGroupBy.reduce + core.groupby.DatasetGroupBy + core.groupby.DatasetGroupBy.apply + core.groupby.DatasetGroupBy.reduce + Rolling objects =============== @@ -509,18 +522,27 @@ Rolling objects core.rolling.DatasetRolling.construct core.rolling.DatasetRolling.reduce -GroupByObjects -============== +Resample objects +================ + +Resample objects also implement the GroupBy interface +(methods like ``apply()``, ``reduce()``, ``mean()``, ``sum()``, etc.). .. autosummary:: :toctree: generated/ - core.groupby.DataArrayGroupBy - core.groupby.DataArrayGroupBy.apply - core.groupby.DataArrayGroupBy.reduce - core.groupby.DatasetGroupBy - core.groupby.DatasetGroupBy.apply - core.groupby.DatasetGroupBy.reduce + core.resample.DataArrayResample + core.resample.DataArrayResample.asfreq + core.resample.DataArrayResample.backfill + core.resample.DataArrayResample.interpolate + core.resample.DataArrayResample.nearest + core.resample.DataArrayResample.pad + core.resample.DatasetResample + core.resample.DatasetResample.asfreq + core.resample.DatasetResample.backfill + core.resample.DatasetResample.interpolate + core.resample.DatasetResample.nearest + core.resample.DatasetResample.pad Custom Indexes ============== diff --git a/doc/faq.rst b/doc/faq.rst index 46f1e20f4e8..360cdb50791 100644 --- a/doc/faq.rst +++ b/doc/faq.rst @@ -1,3 +1,5 @@ +.. _faq: + Frequently Asked Questions ========================== @@ -129,8 +131,8 @@ What other netCDF related Python libraries should I know about? `netCDF4-python`__ provides a lower level interface for working with netCDF and OpenDAP datasets in Python. 
We use netCDF4-python internally in xarray, and have contributed a number of improvements and fixes upstream. xarray -does not yet support all of netCDF4-python's features, such as writing to -netCDF groups or modifying files on-disk. +does not yet support all of netCDF4-python's features, such as modifying files +on-disk. __ https://github.com/Unidata/netcdf4-python @@ -153,10 +155,12 @@ __ http://drclimate.wordpress.com/2014/01/02/a-beginners-guide-to-scripting-with We think the design decisions we have made for xarray (namely, basing it on pandas) make it a faster and more flexible data analysis tool. That said, Iris -and CDAT have some great domain specific functionality, and we would love to -have support for converting their native objects to and from xarray (see -:issue:`37` and :issue:`133`) +and CDAT have some great domain specific functionality, and xarray includes +methods for converting back and forth between xarray and these libraries. See +:py:meth:`~xarray.DataArray.to_iris` and :py:meth:`~xarray.DataArray.to_cdms2` +for more details. +.. _faq.other_projects: What other projects leverage xarray? ------------------------------------ diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 1b696c4486d..fbe7fc5edca 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -31,33 +31,37 @@ What's New v0.10.4 (unreleased) -------------------- +The minor release includes a number of bug-fixes and backwards compatible +enhancements. A highlight is ``CFTimeIndex``, which offers support for +non-standard calendars used in climate modeling. + Documentation ~~~~~~~~~~~~~ -- `FAQ `_ now lists projects that leverage xarray. + +- New FAQ entry, :ref:`faq.other_projects`. By `Deepak Cherian `_. -- `Assigning values with indexing `_ now includes examples on how to select and assign values to a :py:class:`~xarray.DataArray`. +- :ref:`assigning_values` now includes examples on how to select and assign + values to a :py:class:`~xarray.DataArray` with ``.loc``. By `Chiara Lepore `_. - Enhancements ~~~~~~~~~~~~ -- Slight modification in `rolling` with dask.array and bottleneck. Also, fixed a bug in rolling an - integer dask array. - By `Keisuke Fujii `_. - Add an option for using a ``CFTimeIndex`` for indexing times with non-standard calendars and/or outside the Timestamp-valid range; this index enables a subset of the functionality of a standard - ``pandas.DatetimeIndex`` (:issue:`789`, :issue:`1084`, :issue:`1252`). + ``pandas.DatetimeIndex``. + See :ref:`CFTimeIndex` for full details. + (:issue:`789`, :issue:`1084`, :issue:`1252`) By `Spencer Clark `_ with help from `Stephan Hoyer `_. - Allow for serialization of ``cftime.datetime`` objects (:issue:`789`, :issue:`1084`, :issue:`2008`, :issue:`1252`) using the standalone ``cftime`` - library. By `Spencer Clark - `_. + library. + By `Spencer Clark `_. - Support writing lists of strings as netCDF attributes (:issue:`2044`). By `Dan Nowacki `_. -- :py:meth:`~xarray.Dataset.to_netcdf(engine='h5netcdf')` now accepts h5py +- :py:meth:`~xarray.Dataset.to_netcdf` with ``engine='h5netcdf'`` now accepts h5py encoding settings ``compression`` and ``compression_opts``, along with the NetCDF4-Python style settings ``gzip=True`` and ``complevel``. This allows using any compression plugin installed in hdf5, e.g. LZF @@ -66,7 +70,8 @@ Enhancements This greatly boosts speed and allows chunking on the core dims. The function now requires dask >= 0.17.3 to work on dask-backed data (:issue:`2074`). By `Guido Imperiale `_. 
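For example, the dask-backed ``xarray.dot`` described above can now contract over a
chunked core dimension; a minimal sketch (array names are illustrative, not part of
this patch)::

    import numpy as np
    import xarray as xr

    a = xr.DataArray(np.random.randn(30, 40), dims=['x', 'y']).chunk({'y': 10})
    b = xr.DataArray(np.random.randn(40, 50), dims=['y', 'z']).chunk({'y': 10})
    result = xr.dot(a, b, dims='y')  # lazy; dispatches to dask.array.einsum
    result.compute()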
-- ``plot.line()`` learned new kwargs: ``xincrease``, ``yincrease`` that change the direction of the respective axes. +- ``plot.line()`` learned new kwargs: ``xincrease``, ``yincrease`` that change + the direction of the respective axes. By `Deepak Cherian `_. - Added the ``parallel`` option to :py:func:`open_mfdataset`. This option uses @@ -85,14 +90,14 @@ Enhancements Bug fixes ~~~~~~~~~ -- Now raises an Error if a coordinate with wrong size is assigned to a - :py:class:`~xarray.DataArray`. (:issue:`2112`) +- ``ValueError`` is raised when coordinates with the wrong size are assigned to + a :py:class:`DataArray`. (:issue:`2112`) By `Keisuke Fujii `_. -- Fixed a bug in `rolling` with bottleneck. Also, fixed a bug in rolling an - integer dask array. (:issue:`2113`) +- Fixed a bug in :py:meth:`~xarary.DatasArray.rolling` with bottleneck. Also, + fixed a bug in rolling an integer dask array. (:issue:`2113`) By `Keisuke Fujii `_. - Fixed a bug where `keep_attrs=True` flag was neglected if - :py:func:`apply_func` was used with :py:class:`Variable`. (:issue:`2114`) + :py:func:`apply_ufunc` was used with :py:class:`Variable`. (:issue:`2114`) By `Keisuke Fujii `_. - When assigning a :py:class:`DataArray` to :py:class:`Dataset`, any conflicted non-dimensional coordinates of the DataArray are now dropped. @@ -100,7 +105,9 @@ Bug fixes By `Keisuke Fujii `_. - Better error handling in ``open_mfdataset`` (:issue:`2077`). By `Stephan Hoyer `_. -- ``plot.line()`` does not call ``autofmt_xdate()`` anymore. Instead it changes the rotation and horizontal alignment of labels without removing the x-axes of any other subplots in the figure (if any). +- ``plot.line()`` does not call ``autofmt_xdate()`` anymore. Instead it changes + the rotation and horizontal alignment of labels without removing the x-axes of + any other subplots in the figure (if any). By `Deepak Cherian `_. - Colorbar limits are now determined by excluding ±Infs too. By `Deepak Cherian `_. From 5d7304ea49dc04d7ce0d11947437fb0ad1fbd001 Mon Sep 17 00:00:00 2001 From: Stephan Hoyer Date: Wed, 16 May 2018 18:12:15 -0700 Subject: [PATCH 23/61] Release v0.10.4 --- doc/whats-new.rst | 4 ++-- setup.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index fbe7fc5edca..9a12f07a914 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -28,8 +28,8 @@ What's New .. _whats-new.0.10.4: -v0.10.4 (unreleased) --------------------- +v0.10.4 (May 16, 2018) +---------------------- The minor release includes a number of bug-fixes and backwards compatible enhancements. A highlight is ``CFTimeIndex``, which offers support for diff --git a/setup.py b/setup.py index c7c02c90e2f..c5e2ba831b7 100644 --- a/setup.py +++ b/setup.py @@ -8,8 +8,8 @@ MAJOR = 0 MINOR = 10 -MICRO = 3 -ISRELEASED = False +MICRO = 4 +ISRELEASED = True VERSION = '%d.%d.%d' % (MAJOR, MINOR, MICRO) QUALIFIER = '' From 008c2c8e7544b0d8ea4e2fecde5625afabe6ea63 Mon Sep 17 00:00:00 2001 From: Stephan Hoyer Date: Wed, 16 May 2018 18:17:27 -0700 Subject: [PATCH 24/61] Revert to dev version for 0.10.5 --- doc/whats-new.rst | 15 +++++++++++++++ setup.py | 2 +- 2 files changed, 16 insertions(+), 1 deletion(-) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 9a12f07a914..48abb892350 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -26,6 +26,21 @@ What's New - `Tips on porting to Python 3 `__ +.. 
_whats-new.0.10.5: + +v0.10.5 (unreleased) +-------------------- + +Documentation +~~~~~~~~~~~~~ + +Enhancements +~~~~~~~~~~~~ + +Bug fixes +~~~~~~~~~ + + .. _whats-new.0.10.4: v0.10.4 (May 16, 2018) diff --git a/setup.py b/setup.py index c5e2ba831b7..b5130958c00 100644 --- a/setup.py +++ b/setup.py @@ -9,7 +9,7 @@ MAJOR = 0 MINOR = 10 MICRO = 4 -ISRELEASED = True +ISRELEASED = False VERSION = '%d.%d.%d' % (MAJOR, MINOR, MICRO) QUALIFIER = '' From 0a766b38de1d11f4c3110b267db72cb73e238d07 Mon Sep 17 00:00:00 2001 From: Katrin Leinweber <9948149+katrinleinweber@users.noreply.github.com> Date: Thu, 17 May 2018 15:16:47 +0200 Subject: [PATCH 25/61] Hyperlink DOI against preferred resolver (#2147) --- doc/faq.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/doc/faq.rst b/doc/faq.rst index 360cdb50791..9d763f1c15f 100644 --- a/doc/faq.rst +++ b/doc/faq.rst @@ -266,5 +266,5 @@ would certainly appreciate it. We recommend two citations. month = aug, year = 2016, doi = {10.5281/zenodo.59499}, - url = {http://dx.doi.org/10.5281/zenodo.59499} + url = {https://doi.org/10.5281/zenodo.59499} } From 7bab27cc637a60bff2b510d4f4a419c9754eeaa3 Mon Sep 17 00:00:00 2001 From: Fabien Maussion Date: Thu, 17 May 2018 18:55:30 +0200 Subject: [PATCH 26/61] Add favicon to docs? (#2146) --- doc/_static/favicon.ico | Bin 0 -> 4286 bytes doc/conf.py | 2 +- 2 files changed, 1 insertion(+), 1 deletion(-) create mode 100644 doc/_static/favicon.ico diff --git a/doc/_static/favicon.ico b/doc/_static/favicon.ico new file mode 100644 index 0000000000000000000000000000000000000000..a1536e3ef76bf14b1cf5357b3cc3b9e63de94b80 GIT binary patch literal 4286 zcmds4dsLHG62CDV)UCRp*0!rqtV%$vJQVVPibl}tYFoE<-DoMvtx8;nO2NVV2nB7Ko^^fgZ%%m*CeaiMrjk>2PD2UA2VjaZ)X=du%>?& z9JpH$8no9dHY#KOk-Zu6gWsgdIYy<+%=9n!u4#Gjg>o(LGYP&BwJhaBMW%|Y)yrl z*?LYRE_?g$=FD=deJP;Y5BTq!g_ zou(VAFowF4Zq(hy?b^cUP*Zpx0(=zUFR6g_;!=zu?aJ8SO%X9QSg_BlUa-(7z_vaG z_W4CMn_z1K5nUJq z#^HDQ!dsIgR?Tv~*6t}afK2ifB;pR3F(WU@#x`Pn%2AClEwS7{^=&BY;T+e?x_K;= z2g*6&wPNfCLR@!A-2p+qO7N4EgTJ%_W2lC89>riiL9&3YyVT7>9s%3djRRxY+jjR4cJZa@xQT`_Er;5dvkE5P)?EkX7|rzwYrz%=vQ`_6%5fr!j`$d%CeM`}VrK z*0so_ov8l_cwrq}T`ub!r(Y1AM1G$CTx0&LDF0Y{;y300N`ao|fqH`KP@dZc8+`LI z?^@)q;$ywItj~G^^83|3XInaeZSTbx48Vfhb>m!P_PD#&w|L47AQe9X>H}Bj>SmGs zX;JF2GGThs#jd)NUTCcN5o+<=HI~1C+G6ZSZ7*szp2s>(4>V|BU@u~gO7Q&6&x8PJ zCC|SxhRVLN6C|j=vYYFt9@vI4Y-!V3yT$CC?NW>RC;d-BhIMdtxzg<9lp_k?a8S5? 
zQ(Wh%0~K)cP#K)~s|HRVDu?g(Yv9x&EhK$g0;djY(JqD4vFMK}gX1wOSo>Z&eBq&i z4W5^H3>ZtF70-W_frUOc=AP;h(xJ0h~Bu=N(ihkQSPxF#|D4n(X}z!x$Vte2GX7_b-nULWn?!^E5A z%)kC$EaExVpa-S~&6^cm*Alk+=HKQCIcbe>?zl6lKN{$Tm> zuLF}|N5C0~-k6S<0=qVx?Z-6O8JNoLZ-P?cE5DQQZIBisR%K%h1$Yih=?wIZeQk=I z&NQ37VG#>#`h}l?;CuI&B-Dr8_S24s&YpGS??(NXnjIX{L`r3nP^CKYqNAe}9zVGV z_wE6CE&x)Kf_mP16C81@9^K-i~QutuVU@uTcb{6;e0 zL!YoUg1$!P$&|Fd3i~O_V!}9f4>}BTh)kKxQiO^-77=R ze2#5x8bmGTqZq#bz8DT1@PWicF?_Yn6T;RU`{_^KJ8p~|HcrABrkLr3?2VaX=J25> z`7JH^$39kZ&YwRh%+5Y+Xl~8{olb>&XK_x-gHD$ZO-&aN^Dr;;JRkFt z1NHR^*uP&6@$nKkaY6=>JM7yREn|+f@t5;c`2P3~HX`1Ij14mxmX`eWXEIr4W}Gl5 zC*D9hTU(2G-U*4brA3YLWb?WZH|2)bH8kYGp+i0p7bk{<1SuRp?g^2R&iWA}EGC&Z z?w<(QFe@DI{Mv7_v$NxcO64&_Yil92x0fQXiX-(vTU!Z_hmg3b2DGlJsQ~rQ=lYKx zl>&{VBu}ul9ksx;?@Krq=k_nzgQh0^EiEcu^V_#`IhyoR9gMnFh(%m~V`Kgx{geZ& zMW2n$sKNbbWFmh4Jzx-Vzmb)71obDN{sPpk;dv*$)ce$fRBuAkO}$5Xxp6}UF){PG zZqiQtM~@EhHOaE@<7&n?+pzkt1R} z12WD}`X^3&Z9aC};09xCNdDm9Rl;rCKCe(JcSBazu0Dk#3bM1KXbuWRG~&*FBzrex zX70r2yW#7v=Wt%qPBCB}DF4VWutV`vahR?l`Nxg36ukTHo34?OOKUSSg5c~~zrM7z z07y&o#}NSO=>vm&WD};Od;vRm%FOzSf6^pge>`qXgu_(xA^A-~u3P8r92>jps#+Ze z`T4%Q~yc(u>Vc{Hyrn+o~Jy#);~x1?UkGy zXm+k$J9p-R1E1Cw6@?RLUtyt<$A}?7Nb%4x`Ocr;z+*^EH0w9*M==n;6vt$f^HmrX zwcPApy?TywOw3=dQ65O|tB|_+*RH!@Bz1_6FIAI86*|LWNw zGcxicv%PAS*HnLhdFYlc3nMmfUKnl~X!H4|O}^nFA&bH{ZCVf>9BjlDD?Gvk0%rbA zX5Nem=;PO7!2a*E$jA@*Lu>2N0xvHIiw{1SWwB(5o5j+lgsx^>yx7I!qmSGz7AsGE90KnuHp2TM8|D2Y8{_>m JuEW8Z{{ueS$#MVy literal 0 HcmV?d00001 diff --git a/doc/conf.py b/doc/conf.py index 0fd5eaf05d7..36c0d42b808 100644 --- a/doc/conf.py +++ b/doc/conf.py @@ -175,7 +175,7 @@ # The name of an image file (within the static path) to use as favicon of the # docs. This file should be a Windows icon file (.ico) being 16x16 or 32x32 # pixels large. -#html_favicon = None +html_favicon = '_static/favicon.ico' # Add any paths that contain custom static files (such as style sheets) here, # relative to this directory. They are copied after the builtin static files, From ecb10e347bbe0f0e4bab8a358f406923e5468dcf Mon Sep 17 00:00:00 2001 From: Joe Hamman Date: Fri, 18 May 2018 07:48:10 -0700 Subject: [PATCH 27/61] fix unlimited dims bug (#2154) --- doc/whats-new.rst | 4 +++- xarray/backends/api.py | 3 +++ xarray/tests/test_backends.py | 12 ++++++++++++ 3 files changed, 18 insertions(+), 1 deletion(-) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 48abb892350..fe75507e59e 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -39,7 +39,9 @@ Enhancements Bug fixes ~~~~~~~~~ - +- Fixed a bug where `to_netcdf(..., unlimited_dims='bar'` yielded NetCDF files + with spurious 0-length dimensions (i.e. `b`, `a`, and `r`) (:issue:`2134`). + By `Joe Hamman `_. .. 
_whats-new.0.10.4: diff --git a/xarray/backends/api.py b/xarray/backends/api.py index dec63a85d6e..c3b2aa59fcd 100644 --- a/xarray/backends/api.py +++ b/xarray/backends/api.py @@ -686,6 +686,9 @@ def to_netcdf(dataset, path_or_file=None, mode='w', format=None, group=None, if unlimited_dims is None: unlimited_dims = dataset.encoding.get('unlimited_dims', None) + if isinstance(unlimited_dims, basestring): + unlimited_dims = [unlimited_dims] + try: dataset.dump_to_store(store, sync=sync, encoding=encoding, unlimited_dims=unlimited_dims, compute=compute) diff --git a/xarray/tests/test_backends.py b/xarray/tests/test_backends.py index 95d92cd8b8a..513f5f0834e 100644 --- a/xarray/tests/test_backends.py +++ b/xarray/tests/test_backends.py @@ -1653,11 +1653,23 @@ def test_encoding_unlimited_dims(self): self.assertEqual(actual.encoding['unlimited_dims'], set('y')) assert_equal(ds, actual) + # Regression test for https://github.com/pydata/xarray/issues/2134 + with self.roundtrip(ds, + save_kwargs=dict(unlimited_dims='y')) as actual: + self.assertEqual(actual.encoding['unlimited_dims'], set('y')) + assert_equal(ds, actual) + ds.encoding = {'unlimited_dims': ['y']} with self.roundtrip(ds) as actual: self.assertEqual(actual.encoding['unlimited_dims'], set('y')) assert_equal(ds, actual) + # Regression test for https://github.com/pydata/xarray/issues/2134 + ds.encoding = {'unlimited_dims': 'y'} + with self.roundtrip(ds) as actual: + self.assertEqual(actual.encoding['unlimited_dims'], set('y')) + assert_equal(ds, actual) + class GenericNetCDFDataTestAutocloseTrue(GenericNetCDFDataTest): autoclose = True From c346d3b7bcdbd6073cf96fdeb0710467a284a611 Mon Sep 17 00:00:00 2001 From: Stephan Hoyer Date: Fri, 18 May 2018 09:30:42 -1000 Subject: [PATCH 28/61] Bug fixes for Dataset.reduce() and n-dimensional cumsum/cumprod (#2156) * Bug fixes for Dataset.reduce() and n-dimensional cumsum/cumprod Fixes GH1470, "Dataset.mean drops coordinates" Fixes a bug where non-scalar data-variables that did not include the aggregated dimension were not properly reduced: Previously:: >>> ds = Dataset({'x': ('a', [2, 2]), 'y': 2, 'z': ('b', [2])}) >>> ds.var('a') Dimensions: (b: 1) Dimensions without coordinates: b Data variables: x float64 0.0 y float64 0.0 z (b) int64 2 Now:: >>> ds.var('a') Dimensions: (b: 1) Dimensions without coordinates: b Data variables: x int64 0 y int64 0 z (b) int64 0 Finally, adds support for n-dimensional cumsum() and cumprod(), reducing over all dimensions of an array. (This was necessary as part of the above fix.) * Lint fixup * remove confusing comments --- doc/whats-new.rst | 12 ++++++++ xarray/core/dataset.py | 38 +++++++++++------------ xarray/core/duck_array_ops.py | 36 +++++++++++++++++----- xarray/core/variable.py | 5 --- xarray/tests/test_dataarray.py | 5 +++ xarray/tests/test_dataset.py | 34 +++++++++++++++----- xarray/tests/test_duck_array_ops.py | 48 +++++++++++++++++++++++++++++ 7 files changed, 140 insertions(+), 38 deletions(-) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index fe75507e59e..7df47488e21 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -37,12 +37,24 @@ Documentation Enhancements ~~~~~~~~~~~~ +- :py:meth:`~DataArray.cumsum` and :py:meth:`~DataArray.cumprod` now support + aggregation over multiple dimensions at the same time. This is the default + behavior when dimensions are not specified (previously this raised an error). 
+ By `Stephan Hoyer `_ + Bug fixes ~~~~~~~~~ + - Fixed a bug where `to_netcdf(..., unlimited_dims='bar'` yielded NetCDF files with spurious 0-length dimensions (i.e. `b`, `a`, and `r`) (:issue:`2134`). By `Joe Hamman `_. +- Aggregations with :py:meth:`Dataset.reduce` (including ``mean``, ``sum``, + etc) no longer drop unrelated coordinates (:issue:`1470`). Also fixed a + bug where non-scalar data-variables that did not include the aggregation + dimension were improperly skipped. + By `Stephan Hoyer `_ + .. _whats-new.0.10.4: v0.10.4 (May 16, 2018) diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index a9ec8c16866..fff11dedb01 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -2594,26 +2594,26 @@ def reduce(self, func, dim=None, keep_attrs=False, numeric_only=False, variables = OrderedDict() for name, var in iteritems(self._variables): reduce_dims = [dim for dim in var.dims if dim in dims] - if reduce_dims or not var.dims: - if name not in self.coords: - if (not numeric_only or - np.issubdtype(var.dtype, np.number) or - (var.dtype == np.bool_)): - if len(reduce_dims) == 1: - # unpack dimensions for the benefit of functions - # like np.argmin which can't handle tuple arguments - reduce_dims, = reduce_dims - elif len(reduce_dims) == var.ndim: - # prefer to aggregate over axis=None rather than - # axis=(0, 1) if they will be equivalent, because - # the former is often more efficient - reduce_dims = None - variables[name] = var.reduce(func, dim=reduce_dims, - keep_attrs=keep_attrs, - allow_lazy=allow_lazy, - **kwargs) + if name in self.coords: + if not reduce_dims: + variables[name] = var else: - variables[name] = var + if (not numeric_only or + np.issubdtype(var.dtype, np.number) or + (var.dtype == np.bool_)): + if len(reduce_dims) == 1: + # unpack dimensions for the benefit of functions + # like np.argmin which can't handle tuple arguments + reduce_dims, = reduce_dims + elif len(reduce_dims) == var.ndim: + # prefer to aggregate over axis=None rather than + # axis=(0, 1) if they will be equivalent, because + # the former is often more efficient + reduce_dims = None + variables[name] = var.reduce(func, dim=reduce_dims, + keep_attrs=keep_attrs, + allow_lazy=allow_lazy, + **kwargs) coord_names = set(k for k in self.coords if k in variables) attrs = self.attrs if keep_attrs else None diff --git a/xarray/core/duck_array_ops.py b/xarray/core/duck_array_ops.py index ef52b4890ef..69b0d0825be 100644 --- a/xarray/core/duck_array_ops.py +++ b/xarray/core/duck_array_ops.py @@ -281,8 +281,7 @@ def _nanvar_object(value, axis=None, **kwargs): def _create_nan_agg_method(name, numeric_only=False, np_compat=False, - no_bottleneck=False, coerce_strings=False, - keep_dims=False): + no_bottleneck=False, coerce_strings=False): def f(values, axis=None, skipna=None, **kwargs): if kwargs.pop('out', None) is not None: raise TypeError('`out` is not valid for {}'.format(name)) @@ -343,7 +342,6 @@ def f(values, axis=None, skipna=None, **kwargs): 'or newer to use skipna=True or skipna=None' % name) raise NotImplementedError(msg) f.numeric_only = numeric_only - f.keep_dims = keep_dims f.__name__ = name return f @@ -358,10 +356,34 @@ def f(values, axis=None, skipna=None, **kwargs): var = _create_nan_agg_method('var', numeric_only=True) median = _create_nan_agg_method('median', numeric_only=True) prod = _create_nan_agg_method('prod', numeric_only=True, no_bottleneck=True) -cumprod = _create_nan_agg_method('cumprod', numeric_only=True, np_compat=True, - no_bottleneck=True, keep_dims=True) -cumsum 
= _create_nan_agg_method('cumsum', numeric_only=True, np_compat=True, - no_bottleneck=True, keep_dims=True) +cumprod_1d = _create_nan_agg_method( + 'cumprod', numeric_only=True, np_compat=True, no_bottleneck=True) +cumsum_1d = _create_nan_agg_method( + 'cumsum', numeric_only=True, np_compat=True, no_bottleneck=True) + + +def _nd_cum_func(cum_func, array, axis, **kwargs): + array = asarray(array) + if axis is None: + axis = tuple(range(array.ndim)) + if isinstance(axis, int): + axis = (axis,) + + out = array + for ax in axis: + out = cum_func(out, axis=ax, **kwargs) + return out + + +def cumprod(array, axis=None, **kwargs): + """N-dimensional version of cumprod.""" + return _nd_cum_func(cumprod_1d, array, axis, **kwargs) + + +def cumsum(array, axis=None, **kwargs): + """N-dimensional version of cumsum.""" + return _nd_cum_func(cumsum_1d, array, axis, **kwargs) + _fail_on_dask_array_input_skipna = partial( fail_on_dask_array_input, diff --git a/xarray/core/variable.py b/xarray/core/variable.py index 622ac60d7f6..9dcb99459d4 100644 --- a/xarray/core/variable.py +++ b/xarray/core/variable.py @@ -1256,11 +1256,6 @@ def reduce(self, func, dim=None, axis=None, keep_attrs=False, if dim is not None and axis is not None: raise ValueError("cannot supply both 'axis' and 'dim' arguments") - if getattr(func, 'keep_dims', False): - if dim is None and axis is None: - raise ValueError("must supply either single 'dim' or 'axis' " - "argument to %s" % (func.__name__)) - if dim is not None: axis = self.get_axis_num(dim) data = func(self.data if allow_lazy else self.values, diff --git a/xarray/tests/test_dataarray.py b/xarray/tests/test_dataarray.py index f2e076db78a..35e270f0db7 100644 --- a/xarray/tests/test_dataarray.py +++ b/xarray/tests/test_dataarray.py @@ -1767,6 +1767,11 @@ def test_cumops(self): orig = DataArray([[-1, 0, 1], [-3, 0, 3]], coords, dims=['x', 'y']) + actual = orig.cumsum() + expected = DataArray([[-1, -1, 0], [-4, -4, 0]], coords, + dims=['x', 'y']) + assert_identical(expected, actual) + actual = orig.cumsum('x') expected = DataArray([[-1, 0, 1], [-4, 0, 4]], coords, dims=['x', 'y']) diff --git a/xarray/tests/test_dataset.py b/xarray/tests/test_dataset.py index 3335a55e4ab..76e41c43c6d 100644 --- a/xarray/tests/test_dataset.py +++ b/xarray/tests/test_dataset.py @@ -3331,7 +3331,18 @@ def test_reduce(self): assert_equal(data.mean(dim=[]), data) - # uint support + def test_reduce_coords(self): + # regression test for GH1470 + data = xr.Dataset({'a': ('x', [1, 2, 3])}, coords={'b': 4}) + expected = xr.Dataset({'a': 2}, coords={'b': 4}) + actual = data.mean('x') + assert_identical(actual, expected) + + # should be consistent + actual = data['a'].mean('x').to_dataset() + assert_identical(actual, expected) + + def test_mean_uint_dtype(self): data = xr.Dataset({'a': (('x', 'y'), np.arange(6).reshape(3, 2).astype('uint')), 'b': (('x', ), np.array([0.1, 0.2, np.nan]))}) @@ -3345,15 +3356,20 @@ def test_reduce_bad_dim(self): with raises_regex(ValueError, 'Dataset does not contain'): data.mean(dim='bad_dim') + def test_reduce_cumsum(self): + data = xr.Dataset({'a': 1, + 'b': ('x', [1, 2]), + 'c': (('x', 'y'), [[np.nan, 3], [0, 4]])}) + assert_identical(data.fillna(0), data.cumsum('y')) + + expected = xr.Dataset({'a': 1, + 'b': ('x', [1, 3]), + 'c': (('x', 'y'), [[0, 3], [0, 7]])}) + assert_identical(expected, data.cumsum()) + def test_reduce_cumsum_test_dims(self): data = create_test_data() for cumfunc in ['cumsum', 'cumprod']: - with raises_regex(ValueError, - "must supply either single 'dim' or 
'axis'"): - getattr(data, cumfunc)() - with raises_regex(ValueError, - "must supply either single 'dim' or 'axis'"): - getattr(data, cumfunc)(dim=['dim1', 'dim2']) with raises_regex(ValueError, 'Dataset does not contain'): getattr(data, cumfunc)(dim='bad_dim') @@ -3460,6 +3476,10 @@ def test_reduce_scalars(self): actual = ds.var() assert_identical(expected, actual) + expected = Dataset({'x': 0, 'y': 0, 'z': ('b', [0])}) + actual = ds.var('a') + assert_identical(expected, actual) + def test_reduce_only_one_axis(self): def mean_only_one_axis(x, axis): diff --git a/xarray/tests/test_duck_array_ops.py b/xarray/tests/test_duck_array_ops.py index 2983e1991f1..3f4adee6713 100644 --- a/xarray/tests/test_duck_array_ops.py +++ b/xarray/tests/test_duck_array_ops.py @@ -8,6 +8,7 @@ import warnings from xarray import DataArray, concat +from xarray.core import duck_array_ops from xarray.core.duck_array_ops import ( array_notnull_equiv, concatenate, count, first, last, mean, rolling_window, stack, where) @@ -103,6 +104,53 @@ def test_all_nan_arrays(self): assert np.isnan(mean([np.nan, np.nan])) +def test_cumsum_1d(): + inputs = np.array([0, 1, 2, 3]) + expected = np.array([0, 1, 3, 6]) + actual = duck_array_ops.cumsum(inputs) + assert_array_equal(expected, actual) + + actual = duck_array_ops.cumsum(inputs, axis=0) + assert_array_equal(expected, actual) + + actual = duck_array_ops.cumsum(inputs, axis=-1) + assert_array_equal(expected, actual) + + actual = duck_array_ops.cumsum(inputs, axis=(0,)) + assert_array_equal(expected, actual) + + actual = duck_array_ops.cumsum(inputs, axis=()) + assert_array_equal(inputs, actual) + + +def test_cumsum_2d(): + inputs = np.array([[1, 2], [3, 4]]) + + expected = np.array([[1, 3], [4, 10]]) + actual = duck_array_ops.cumsum(inputs) + assert_array_equal(expected, actual) + + actual = duck_array_ops.cumsum(inputs, axis=(0, 1)) + assert_array_equal(expected, actual) + + actual = duck_array_ops.cumsum(inputs, axis=()) + assert_array_equal(inputs, actual) + + +def test_cumprod_2d(): + inputs = np.array([[1, 2], [3, 4]]) + + expected = np.array([[1, 2], [3, 2 * 3 * 4]]) + actual = duck_array_ops.cumprod(inputs) + assert_array_equal(expected, actual) + + actual = duck_array_ops.cumprod(inputs, axis=(0, 1)) + assert_array_equal(expected, actual) + + actual = duck_array_ops.cumprod(inputs, axis=()) + assert_array_equal(inputs, actual) + + class TestArrayNotNullEquiv(): @pytest.mark.parametrize("arr1, arr2", [ (np.array([1, 2, 3]), np.array([1, 2, 3])), From 585b9a7913d98e26c28b4f1da599c1c6db551362 Mon Sep 17 00:00:00 2001 From: Joe Hamman Date: Sun, 20 May 2018 16:14:02 -0700 Subject: [PATCH 29/61] Versioneer (#2163) * add versioneer to simplify/standardize package versioning * reorg __init__.py for version import * fix for docs * what is new --- .gitattributes | 1 + HOW_TO_RELEASE | 28 +- MANIFEST.in | 2 + doc/conf.py | 2 +- doc/whats-new.rst | 4 + setup.cfg | 11 + setup.py | 87 +-- versioneer.py | 1822 ++++++++++++++++++++++++++++++++++++++++++++ xarray/__init__.py | 11 +- xarray/_version.py | 520 +++++++++++++ 10 files changed, 2380 insertions(+), 108 deletions(-) create mode 100644 versioneer.py create mode 100644 xarray/_version.py diff --git a/.gitattributes b/.gitattributes index a52f4ca283a..daa5b82874e 100644 --- a/.gitattributes +++ b/.gitattributes @@ -1,2 +1,3 @@ # reduce the number of merge conflicts doc/whats-new.rst merge=union +xarray/_version.py export-subst diff --git a/HOW_TO_RELEASE b/HOW_TO_RELEASE index f1fee59e177..cdfcace809a 100644 --- a/HOW_TO_RELEASE 
+++ b/HOW_TO_RELEASE @@ -7,21 +7,20 @@ Time required: about an hour. 2. Look over whats-new.rst and the docs. Make sure "What's New" is complete (check the date!) and add a brief summary note describing the release at the top. - 3. Update the version in setup.py and switch to `ISRELEASED = True`. - 4. If you have any doubts, run the full test suite one final time! + 3. If you have any doubts, run the full test suite one final time! py.test - 5. On the master branch, commit the release in git: + 4. On the master branch, commit the release in git: git commit -a -m 'Release v0.X.Y' - 6. Tag the release: + 5. Tag the release: git tag -a v0.X.Y -m 'v0.X.Y' - 7. Build source and binary wheels for pypi: + 6. Build source and binary wheels for pypi: python setup.py bdist_wheel sdist - 8. Use twine to register and upload the release on pypi. Be careful, you can't + 7. Use twine to register and upload the release on pypi. Be careful, you can't take this back! twine upload dist/xarray-0.X.Y* You will need to be listed as a package owner at https://pypi.python.org/pypi/xarray for this to work. - 9. Push your changes to master: + 8. Push your changes to master: git push upstream master git push upstream --tags 9. Update the stable branch (used by ReadTheDocs) and switch back to master: @@ -32,25 +31,22 @@ Time required: about an hour. It's OK to force push to 'stable' if necessary. We also update the stable branch with `git cherrypick` for documentation only fixes that apply the current released version. -10. Revert ISRELEASED in setup.py back to False. Don't change the version - number: in normal development, we keep the version number in setup.py as the - last released version. -11. Add a section for the next release (v.X.(Y+1)) to doc/whats-new.rst. -12. Commit your changes and push to master again: +10. Add a section for the next release (v.X.(Y+1)) to doc/whats-new.rst. +11. Commit your changes and push to master again: git commit -a -m 'Revert to dev version' git push upstream master You're done pushing to master! -13. Issue the release on GitHub. Click on "Draft a new release" at +12. Issue the release on GitHub. Click on "Draft a new release" at https://github.com/pydata/xarray/releases and paste in the latest from whats-new.rst. -14. Update the docs. Login to https://readthedocs.org/projects/xray/versions/ +13. Update the docs. Login to https://readthedocs.org/projects/xray/versions/ and switch your new release tag (at the bottom) from "Inactive" to "Active". It should now build automatically. -15. Update conda-forge. Clone https://github.com/conda-forge/xarray-feedstock +14. Update conda-forge. Clone https://github.com/conda-forge/xarray-feedstock and update the version number and sha256 in meta.yaml. (On OS X, you can calculate sha256 with `shasum -a 256 xarray-0.X.Y.tar.gz`). Submit a pull request (and merge it, once CI passes). -16. Issue the release announcement! For bug fix releases, I usually only email +15. Issue the release announcement! For bug fix releases, I usually only email xarray@googlegroups.com. 
For major/feature releases, I will email a broader list (no more than once every 3-6 months): pydata@googlegroups.com, xarray@googlegroups.com, diff --git a/MANIFEST.in b/MANIFEST.in index a49c49cd396..a006660e5fb 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -4,3 +4,5 @@ recursive-include doc * prune doc/_build prune doc/generated global-exclude .DS_Store +include versioneer.py +include xarray/_version.py diff --git a/doc/conf.py b/doc/conf.py index 36c0d42b808..5fd3bece3bd 100644 --- a/doc/conf.py +++ b/doc/conf.py @@ -102,7 +102,7 @@ # built documents. # # The short X.Y version. -version = xarray.version.short_version +version = xarray.__version__.split('+')[0] # The full version, including alpha/beta/rc tags. release = xarray.__version__ diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 7df47488e21..4c9a1415e26 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -42,6 +42,10 @@ Enhancements behavior when dimensions are not specified (previously this raised an error). By `Stephan Hoyer `_ +- Xarray now uses `Versioneer `__ + to manage its version strings. (:issue:`1300`). + By `Joe Hamman `_. + Bug fixes ~~~~~~~~~ diff --git a/setup.cfg b/setup.cfg index ec30a10b242..850551b3579 100644 --- a/setup.cfg +++ b/setup.cfg @@ -15,3 +15,14 @@ exclude= default_section=THIRDPARTY known_first_party=xarray multi_line_output=4 + +[versioneer] +VCS = git +style = pep440 +versionfile_source = xarray/_version.py +versionfile_build = xarray/_version.py +tag_prefix = +parentdir_prefix = xarray- + +[aliases] +test = pytest diff --git a/setup.py b/setup.py index b5130958c00..77c6083f52c 100644 --- a/setup.py +++ b/setup.py @@ -1,17 +1,9 @@ #!/usr/bin/env python -import os -import re import sys -import warnings from setuptools import find_packages, setup -MAJOR = 0 -MINOR = 10 -MICRO = 4 -ISRELEASED = False -VERSION = '%d.%d.%d' % (MAJOR, MINOR, MICRO) -QUALIFIER = '' +import versioneer DISTNAME = 'xarray' @@ -65,83 +57,10 @@ - SciPy2015 talk: https://www.youtube.com/watch?v=X0pAhJgySxk """ # noqa -# Code to extract and write the version copied from pandas. -# Used under the terms of pandas's license, see licenses/PANDAS_LICENSE. -FULLVERSION = VERSION -write_version = True - -if not ISRELEASED: - import subprocess - FULLVERSION += '.dev' - - pipe = None - for cmd in ['git', 'git.cmd']: - try: - pipe = subprocess.Popen( - [cmd, "describe", "--always", "--match", "v[0-9]*"], - stdout=subprocess.PIPE) - (so, serr) = pipe.communicate() - if pipe.returncode == 0: - break - except BaseException: - pass - - if pipe is None or pipe.returncode != 0: - # no git, or not in git dir - if os.path.exists('xarray/version.py'): - warnings.warn( - "WARNING: Couldn't get git revision," - " using existing xarray/version.py") - write_version = False - else: - warnings.warn( - "WARNING: Couldn't get git revision," - " using generic version string") - else: - # have git, in git dir, but may have used a shallow clone (travis does - # this) - rev = so.strip() - # makes distutils blow up on Python 2.7 - if sys.version_info[0] >= 3: - rev = rev.decode('ascii') - - if not rev.startswith('v') and re.match("[a-zA-Z0-9]{7,9}", rev): - # partial clone, manually construct version string - # this is the format before we started using git-describe - # to get an ordering on dev version strings. 
- rev = "v%s+dev.%s" % (VERSION, rev) - - # Strip leading v from tags format "vx.y.z" to get th version string - FULLVERSION = rev.lstrip('v') - - # make sure we respect PEP 440 - FULLVERSION = FULLVERSION.replace("-", "+dev", 1).replace("-", ".") - -else: - FULLVERSION += QUALIFIER - - -def write_version_py(filename=None): - cnt = """\ -version = '%s' -short_version = '%s' -""" - if not filename: - filename = os.path.join( - os.path.dirname(__file__), 'xarray', 'version.py') - - a = open(filename, 'w') - try: - a.write(cnt % (FULLVERSION, VERSION)) - finally: - a.close() - - -if write_version: - write_version_py() setup(name=DISTNAME, - version=FULLVERSION, + version=versioneer.get_version(), + cmdclass=versioneer.get_cmdclass(), license=LICENSE, author=AUTHOR, author_email=AUTHOR_EMAIL, diff --git a/versioneer.py b/versioneer.py new file mode 100644 index 00000000000..64fea1c8927 --- /dev/null +++ b/versioneer.py @@ -0,0 +1,1822 @@ + +# Version: 0.18 + +"""The Versioneer - like a rocketeer, but for versions. + +The Versioneer +============== + +* like a rocketeer, but for versions! +* https://github.com/warner/python-versioneer +* Brian Warner +* License: Public Domain +* Compatible With: python2.6, 2.7, 3.2, 3.3, 3.4, 3.5, 3.6, and pypy +* [![Latest Version] +(https://pypip.in/version/versioneer/badge.svg?style=flat) +](https://pypi.python.org/pypi/versioneer/) +* [![Build Status] +(https://travis-ci.org/warner/python-versioneer.png?branch=master) +](https://travis-ci.org/warner/python-versioneer) + +This is a tool for managing a recorded version number in distutils-based +python projects. The goal is to remove the tedious and error-prone "update +the embedded version string" step from your release process. Making a new +release should be as easy as recording a new tag in your version-control +system, and maybe making new tarballs. + + +## Quick Install + +* `pip install versioneer` to somewhere to your $PATH +* add a `[versioneer]` section to your setup.cfg (see below) +* run `versioneer install` in your source tree, commit the results + +## Version Identifiers + +Source trees come from a variety of places: + +* a version-control system checkout (mostly used by developers) +* a nightly tarball, produced by build automation +* a snapshot tarball, produced by a web-based VCS browser, like github's + "tarball from tag" feature +* a release tarball, produced by "setup.py sdist", distributed through PyPI + +Within each source tree, the version identifier (either a string or a number, +this tool is format-agnostic) can come from a variety of places: + +* ask the VCS tool itself, e.g. "git describe" (for checkouts), which knows + about recent "tags" and an absolute revision-id +* the name of the directory into which the tarball was unpacked +* an expanded VCS keyword ($Id$, etc) +* a `_version.py` created by some earlier build step + +For released software, the version identifier is closely related to a VCS +tag. Some projects use tag names that include more than just the version +string (e.g. "myproject-1.2" instead of just "1.2"), in which case the tool +needs to strip the tag prefix to extract the version identifier. For +unreleased software (between tags), the version identifier should provide +enough information to help developers recreate the same tree, while also +giving them an idea of roughly how old the tree is (after version 1.2, before +version 1.3). 
Many VCS systems can report a description that captures this, +for example `git describe --tags --dirty --always` reports things like +"0.7-1-g574ab98-dirty" to indicate that the checkout is one revision past the +0.7 tag, has a unique revision id of "574ab98", and is "dirty" (it has +uncommitted changes. + +The version identifier is used for multiple purposes: + +* to allow the module to self-identify its version: `myproject.__version__` +* to choose a name and prefix for a 'setup.py sdist' tarball + +## Theory of Operation + +Versioneer works by adding a special `_version.py` file into your source +tree, where your `__init__.py` can import it. This `_version.py` knows how to +dynamically ask the VCS tool for version information at import time. + +`_version.py` also contains `$Revision$` markers, and the installation +process marks `_version.py` to have this marker rewritten with a tag name +during the `git archive` command. As a result, generated tarballs will +contain enough information to get the proper version. + +To allow `setup.py` to compute a version too, a `versioneer.py` is added to +the top level of your source tree, next to `setup.py` and the `setup.cfg` +that configures it. This overrides several distutils/setuptools commands to +compute the version when invoked, and changes `setup.py build` and `setup.py +sdist` to replace `_version.py` with a small static file that contains just +the generated version data. + +## Installation + +See [INSTALL.md](./INSTALL.md) for detailed installation instructions. + +## Version-String Flavors + +Code which uses Versioneer can learn about its version string at runtime by +importing `_version` from your main `__init__.py` file and running the +`get_versions()` function. From the "outside" (e.g. in `setup.py`), you can +import the top-level `versioneer.py` and run `get_versions()`. + +Both functions return a dictionary with different flavors of version +information: + +* `['version']`: A condensed version string, rendered using the selected + style. This is the most commonly used value for the project's version + string. The default "pep440" style yields strings like `0.11`, + `0.11+2.g1076c97`, or `0.11+2.g1076c97.dirty`. See the "Styles" section + below for alternative styles. + +* `['full-revisionid']`: detailed revision identifier. For Git, this is the + full SHA1 commit id, e.g. "1076c978a8d3cfc70f408fe5974aa6c092c949ac". + +* `['date']`: Date and time of the latest `HEAD` commit. For Git, it is the + commit date in ISO 8601 format. This will be None if the date is not + available. + +* `['dirty']`: a boolean, True if the tree has uncommitted changes. Note that + this is only accurate if run in a VCS checkout, otherwise it is likely to + be False or None + +* `['error']`: if the version string could not be computed, this will be set + to a string describing the problem, otherwise it will be None. It may be + useful to throw an exception in setup.py if this is set, to avoid e.g. + creating tarballs with a version string of "unknown". + +Some variants are more useful than others. Including `full-revisionid` in a +bug report should allow developers to reconstruct the exact code being tested +(or indicate the presence of local changes that should be shared with the +developers). `version` is suitable for display in an "about" box or a CLI +`--version` output: it can be easily compared against release notes and lists +of bugs fixed in various releases. 
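+
+As an illustrative sketch (assuming a project already configured for
+Versioneer, and run from the project root next to `setup.py`), a release
+script could consume these flavors directly:
+
+    import versioneer
+
+    info = versioneer.get_versions()
+    if info["error"]:
+        raise SystemExit("cannot determine version: %s" % info["error"])
+    print("releasing %s (commit %s, dirty=%s)"
+          % (info["version"], info["full-revisionid"], info["dirty"]))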
+ +The installer adds the following text to your `__init__.py` to place a basic +version in `YOURPROJECT.__version__`: + + from ._version import get_versions + __version__ = get_versions()['version'] + del get_versions + +## Styles + +The setup.cfg `style=` configuration controls how the VCS information is +rendered into a version string. + +The default style, "pep440", produces a PEP440-compliant string, equal to the +un-prefixed tag name for actual releases, and containing an additional "local +version" section with more detail for in-between builds. For Git, this is +TAG[+DISTANCE.gHEX[.dirty]] , using information from `git describe --tags +--dirty --always`. For example "0.11+2.g1076c97.dirty" indicates that the +tree is like the "1076c97" commit but has uncommitted changes (".dirty"), and +that this commit is two revisions ("+2") beyond the "0.11" tag. For released +software (exactly equal to a known tag), the identifier will only contain the +stripped tag, e.g. "0.11". + +Other styles are available. See [details.md](details.md) in the Versioneer +source tree for descriptions. + +## Debugging + +Versioneer tries to avoid fatal errors: if something goes wrong, it will tend +to return a version of "0+unknown". To investigate the problem, run `setup.py +version`, which will run the version-lookup code in a verbose mode, and will +display the full contents of `get_versions()` (including the `error` string, +which may help identify what went wrong). + +## Known Limitations + +Some situations are known to cause problems for Versioneer. This details the +most significant ones. More can be found on Github +[issues page](https://github.com/warner/python-versioneer/issues). + +### Subprojects + +Versioneer has limited support for source trees in which `setup.py` is not in +the root directory (e.g. `setup.py` and `.git/` are *not* siblings). The are +two common reasons why `setup.py` might not be in the root: + +* Source trees which contain multiple subprojects, such as + [Buildbot](https://github.com/buildbot/buildbot), which contains both + "master" and "slave" subprojects, each with their own `setup.py`, + `setup.cfg`, and `tox.ini`. Projects like these produce multiple PyPI + distributions (and upload multiple independently-installable tarballs). +* Source trees whose main purpose is to contain a C library, but which also + provide bindings to Python (and perhaps other langauges) in subdirectories. + +Versioneer will look for `.git` in parent directories, and most operations +should get the right version string. However `pip` and `setuptools` have bugs +and implementation details which frequently cause `pip install .` from a +subproject directory to fail to find a correct version string (so it usually +defaults to `0+unknown`). + +`pip install --editable .` should work correctly. `setup.py install` might +work too. + +Pip-8.1.1 is known to have this problem, but hopefully it will get fixed in +some later version. + +[Bug #38](https://github.com/warner/python-versioneer/issues/38) is tracking +this issue. The discussion in +[PR #61](https://github.com/warner/python-versioneer/pull/61) describes the +issue from the Versioneer side in more detail. +[pip PR#3176](https://github.com/pypa/pip/pull/3176) and +[pip PR#3615](https://github.com/pypa/pip/pull/3615) contain work to improve +pip to let Versioneer work correctly. 
+ +Versioneer-0.16 and earlier only looked for a `.git` directory next to the +`setup.cfg`, so subprojects were completely unsupported with those releases. + +### Editable installs with setuptools <= 18.5 + +`setup.py develop` and `pip install --editable .` allow you to install a +project into a virtualenv once, then continue editing the source code (and +test) without re-installing after every change. + +"Entry-point scripts" (`setup(entry_points={"console_scripts": ..})`) are a +convenient way to specify executable scripts that should be installed along +with the python package. + +These both work as expected when using modern setuptools. When using +setuptools-18.5 or earlier, however, certain operations will cause +`pkg_resources.DistributionNotFound` errors when running the entrypoint +script, which must be resolved by re-installing the package. This happens +when the install happens with one version, then the egg_info data is +regenerated while a different version is checked out. Many setup.py commands +cause egg_info to be rebuilt (including `sdist`, `wheel`, and installing into +a different virtualenv), so this can be surprising. + +[Bug #83](https://github.com/warner/python-versioneer/issues/83) describes +this one, but upgrading to a newer version of setuptools should probably +resolve it. + +### Unicode version strings + +While Versioneer works (and is continually tested) with both Python 2 and +Python 3, it is not entirely consistent with bytes-vs-unicode distinctions. +Newer releases probably generate unicode version strings on py2. It's not +clear that this is wrong, but it may be surprising for applications when then +write these strings to a network connection or include them in bytes-oriented +APIs like cryptographic checksums. + +[Bug #71](https://github.com/warner/python-versioneer/issues/71) investigates +this question. + + +## Updating Versioneer + +To upgrade your project to a new release of Versioneer, do the following: + +* install the new Versioneer (`pip install -U versioneer` or equivalent) +* edit `setup.cfg`, if necessary, to include any new configuration settings + indicated by the release notes. See [UPGRADING](./UPGRADING.md) for details. +* re-run `versioneer install` in your source tree, to replace + `SRC/_version.py` +* commit any changed files + +## Future Directions + +This tool is designed to make it easily extended to other version-control +systems: all VCS-specific components are in separate directories like +src/git/ . The top-level `versioneer.py` script is assembled from these +components by running make-versioneer.py . In the future, make-versioneer.py +will take a VCS name as an argument, and will construct a version of +`versioneer.py` that is specific to the given VCS. It might also take the +configuration arguments that are currently provided manually during +installation by editing setup.py . Alternatively, it might go the other +direction and include code from all supported VCS systems, reducing the +number of intermediate scripts. + + +## License + +To make Versioneer easier to embed, all its code is dedicated to the public +domain. The `_version.py` that it creates is also in the public domain. +Specifically, both are released under the Creative Commons "Public Domain +Dedication" license (CC0-1.0), as described in +https://creativecommons.org/publicdomain/zero/1.0/ . 
+ +""" + +from __future__ import print_function +try: + import configparser +except ImportError: + import ConfigParser as configparser +import errno +import json +import os +import re +import subprocess +import sys + + +class VersioneerConfig: + """Container for Versioneer configuration parameters.""" + + +def get_root(): + """Get the project root directory. + + We require that all commands are run from the project root, i.e. the + directory that contains setup.py, setup.cfg, and versioneer.py . + """ + root = os.path.realpath(os.path.abspath(os.getcwd())) + setup_py = os.path.join(root, "setup.py") + versioneer_py = os.path.join(root, "versioneer.py") + if not (os.path.exists(setup_py) or os.path.exists(versioneer_py)): + # allow 'python path/to/setup.py COMMAND' + root = os.path.dirname(os.path.realpath(os.path.abspath(sys.argv[0]))) + setup_py = os.path.join(root, "setup.py") + versioneer_py = os.path.join(root, "versioneer.py") + if not (os.path.exists(setup_py) or os.path.exists(versioneer_py)): + err = ("Versioneer was unable to run the project root directory. " + "Versioneer requires setup.py to be executed from " + "its immediate directory (like 'python setup.py COMMAND'), " + "or in a way that lets it use sys.argv[0] to find the root " + "(like 'python path/to/setup.py COMMAND').") + raise VersioneerBadRootError(err) + try: + # Certain runtime workflows (setup.py install/develop in a setuptools + # tree) execute all dependencies in a single python process, so + # "versioneer" may be imported multiple times, and python's shared + # module-import table will cache the first one. So we can't use + # os.path.dirname(__file__), as that will find whichever + # versioneer.py was first imported, even in later projects. + me = os.path.realpath(os.path.abspath(__file__)) + me_dir = os.path.normcase(os.path.splitext(me)[0]) + vsr_dir = os.path.normcase(os.path.splitext(versioneer_py)[0]) + if me_dir != vsr_dir: + print("Warning: build in %s is using versioneer.py from %s" + % (os.path.dirname(me), versioneer_py)) + except NameError: + pass + return root + + +def get_config_from_root(root): + """Read the project setup.cfg file to determine Versioneer config.""" + # This might raise EnvironmentError (if setup.cfg is missing), or + # configparser.NoSectionError (if it lacks a [versioneer] section), or + # configparser.NoOptionError (if it lacks "VCS="). See the docstring at + # the top of versioneer.py for instructions on writing your setup.cfg . 
+ setup_cfg = os.path.join(root, "setup.cfg") + parser = configparser.SafeConfigParser() + with open(setup_cfg, "r") as f: + parser.readfp(f) + VCS = parser.get("versioneer", "VCS") # mandatory + + def get(parser, name): + if parser.has_option("versioneer", name): + return parser.get("versioneer", name) + return None + cfg = VersioneerConfig() + cfg.VCS = VCS + cfg.style = get(parser, "style") or "" + cfg.versionfile_source = get(parser, "versionfile_source") + cfg.versionfile_build = get(parser, "versionfile_build") + cfg.tag_prefix = get(parser, "tag_prefix") + if cfg.tag_prefix in ("''", '""'): + cfg.tag_prefix = "" + cfg.parentdir_prefix = get(parser, "parentdir_prefix") + cfg.verbose = get(parser, "verbose") + return cfg + + +class NotThisMethod(Exception): + """Exception raised if a method is not valid for the current scenario.""" + + +# these dictionaries contain VCS-specific tools +LONG_VERSION_PY = {} +HANDLERS = {} + + +def register_vcs_handler(vcs, method): # decorator + """Decorator to mark a method as the handler for a particular VCS.""" + def decorate(f): + """Store f in HANDLERS[vcs][method].""" + if vcs not in HANDLERS: + HANDLERS[vcs] = {} + HANDLERS[vcs][method] = f + return f + return decorate + + +def run_command(commands, args, cwd=None, verbose=False, hide_stderr=False, + env=None): + """Call the given command(s).""" + assert isinstance(commands, list) + p = None + for c in commands: + try: + dispcmd = str([c] + args) + # remember shell=False, so use git.cmd on windows, not just git + p = subprocess.Popen([c] + args, cwd=cwd, env=env, + stdout=subprocess.PIPE, + stderr=(subprocess.PIPE if hide_stderr + else None)) + break + except EnvironmentError: + e = sys.exc_info()[1] + if e.errno == errno.ENOENT: + continue + if verbose: + print("unable to run %s" % dispcmd) + print(e) + return None, None + else: + if verbose: + print("unable to find command, tried %s" % (commands,)) + return None, None + stdout = p.communicate()[0].strip() + if sys.version_info[0] >= 3: + stdout = stdout.decode() + if p.returncode != 0: + if verbose: + print("unable to run %s (error)" % dispcmd) + print("stdout was %s" % stdout) + return None, p.returncode + return stdout, p.returncode + + +LONG_VERSION_PY['git'] = ''' +# This file helps to compute a version number in source trees obtained from +# git-archive tarball (such as those provided by githubs download-from-tag +# feature). Distribution tarballs (built by setup.py sdist) and build +# directories (produced by setup.py build) will contain a much shorter file +# that just contains the computed version number. + +# This file is released into the public domain. Generated by +# versioneer-0.18 (https://github.com/warner/python-versioneer) + +"""Git implementation of _version.py.""" + +import errno +import os +import re +import subprocess +import sys + + +def get_keywords(): + """Get the keywords needed to look up the version information.""" + # these strings will be replaced by git during git-archive. + # setup.py/versioneer.py will grep for the variable names, so they must + # each be defined on a line of their own. _version.py will just call + # get_keywords(). 
+ git_refnames = "%(DOLLAR)sFormat:%%d%(DOLLAR)s" + git_full = "%(DOLLAR)sFormat:%%H%(DOLLAR)s" + git_date = "%(DOLLAR)sFormat:%%ci%(DOLLAR)s" + keywords = {"refnames": git_refnames, "full": git_full, "date": git_date} + return keywords + + +class VersioneerConfig: + """Container for Versioneer configuration parameters.""" + + +def get_config(): + """Create, populate and return the VersioneerConfig() object.""" + # these strings are filled in when 'setup.py versioneer' creates + # _version.py + cfg = VersioneerConfig() + cfg.VCS = "git" + cfg.style = "%(STYLE)s" + cfg.tag_prefix = "%(TAG_PREFIX)s" + cfg.parentdir_prefix = "%(PARENTDIR_PREFIX)s" + cfg.versionfile_source = "%(VERSIONFILE_SOURCE)s" + cfg.verbose = False + return cfg + + +class NotThisMethod(Exception): + """Exception raised if a method is not valid for the current scenario.""" + + +LONG_VERSION_PY = {} +HANDLERS = {} + + +def register_vcs_handler(vcs, method): # decorator + """Decorator to mark a method as the handler for a particular VCS.""" + def decorate(f): + """Store f in HANDLERS[vcs][method].""" + if vcs not in HANDLERS: + HANDLERS[vcs] = {} + HANDLERS[vcs][method] = f + return f + return decorate + + +def run_command(commands, args, cwd=None, verbose=False, hide_stderr=False, + env=None): + """Call the given command(s).""" + assert isinstance(commands, list) + p = None + for c in commands: + try: + dispcmd = str([c] + args) + # remember shell=False, so use git.cmd on windows, not just git + p = subprocess.Popen([c] + args, cwd=cwd, env=env, + stdout=subprocess.PIPE, + stderr=(subprocess.PIPE if hide_stderr + else None)) + break + except EnvironmentError: + e = sys.exc_info()[1] + if e.errno == errno.ENOENT: + continue + if verbose: + print("unable to run %%s" %% dispcmd) + print(e) + return None, None + else: + if verbose: + print("unable to find command, tried %%s" %% (commands,)) + return None, None + stdout = p.communicate()[0].strip() + if sys.version_info[0] >= 3: + stdout = stdout.decode() + if p.returncode != 0: + if verbose: + print("unable to run %%s (error)" %% dispcmd) + print("stdout was %%s" %% stdout) + return None, p.returncode + return stdout, p.returncode + + +def versions_from_parentdir(parentdir_prefix, root, verbose): + """Try to determine the version from the parent directory name. + + Source tarballs conventionally unpack into a directory that includes both + the project name and a version string. We will also support searching up + two directory levels for an appropriately named parent directory + """ + rootdirs = [] + + for i in range(3): + dirname = os.path.basename(root) + if dirname.startswith(parentdir_prefix): + return {"version": dirname[len(parentdir_prefix):], + "full-revisionid": None, + "dirty": False, "error": None, "date": None} + else: + rootdirs.append(root) + root = os.path.dirname(root) # up a level + + if verbose: + print("Tried directories %%s but none started with prefix %%s" %% + (str(rootdirs), parentdir_prefix)) + raise NotThisMethod("rootdir doesn't start with parentdir_prefix") + + +@register_vcs_handler("git", "get_keywords") +def git_get_keywords(versionfile_abs): + """Extract version information from the given file.""" + # the code embedded in _version.py can just fetch the value of these + # keywords. When used from setup.py, we don't want to import _version.py, + # so we do it with a regexp instead. This function is not used from + # _version.py. 
+ keywords = {} + try: + f = open(versionfile_abs, "r") + for line in f.readlines(): + if line.strip().startswith("git_refnames ="): + mo = re.search(r'=\s*"(.*)"', line) + if mo: + keywords["refnames"] = mo.group(1) + if line.strip().startswith("git_full ="): + mo = re.search(r'=\s*"(.*)"', line) + if mo: + keywords["full"] = mo.group(1) + if line.strip().startswith("git_date ="): + mo = re.search(r'=\s*"(.*)"', line) + if mo: + keywords["date"] = mo.group(1) + f.close() + except EnvironmentError: + pass + return keywords + + +@register_vcs_handler("git", "keywords") +def git_versions_from_keywords(keywords, tag_prefix, verbose): + """Get version information from git keywords.""" + if not keywords: + raise NotThisMethod("no keywords at all, weird") + date = keywords.get("date") + if date is not None: + # git-2.2.0 added "%%cI", which expands to an ISO-8601 -compliant + # datestamp. However we prefer "%%ci" (which expands to an "ISO-8601 + # -like" string, which we must then edit to make compliant), because + # it's been around since git-1.5.3, and it's too difficult to + # discover which version we're using, or to work around using an + # older one. + date = date.strip().replace(" ", "T", 1).replace(" ", "", 1) + refnames = keywords["refnames"].strip() + if refnames.startswith("$Format"): + if verbose: + print("keywords are unexpanded, not using") + raise NotThisMethod("unexpanded keywords, not a git-archive tarball") + refs = set([r.strip() for r in refnames.strip("()").split(",")]) + # starting in git-1.8.3, tags are listed as "tag: foo-1.0" instead of + # just "foo-1.0". If we see a "tag: " prefix, prefer those. + TAG = "tag: " + tags = set([r[len(TAG):] for r in refs if r.startswith(TAG)]) + if not tags: + # Either we're using git < 1.8.3, or there really are no tags. We use + # a heuristic: assume all version tags have a digit. The old git %%d + # expansion behaves like git log --decorate=short and strips out the + # refs/heads/ and refs/tags/ prefixes that would let us distinguish + # between branches and tags. By ignoring refnames without digits, we + # filter out many common branch names like "release" and + # "stabilization", as well as "HEAD" and "master". + tags = set([r for r in refs if re.search(r'\d', r)]) + if verbose: + print("discarding '%%s', no digits" %% ",".join(refs - tags)) + if verbose: + print("likely tags: %%s" %% ",".join(sorted(tags))) + for ref in sorted(tags): + # sorting will prefer e.g. "2.0" over "2.0rc1" + if ref.startswith(tag_prefix): + r = ref[len(tag_prefix):] + if verbose: + print("picking %%s" %% r) + return {"version": r, + "full-revisionid": keywords["full"].strip(), + "dirty": False, "error": None, + "date": date} + # no suitable tags, so version is "0+unknown", but full hex is still there + if verbose: + print("no suitable tags, using unknown + full revision id") + return {"version": "0+unknown", + "full-revisionid": keywords["full"].strip(), + "dirty": False, "error": "no suitable tags", "date": None} + + +@register_vcs_handler("git", "pieces_from_vcs") +def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command): + """Get version from 'git describe' in the root of the source tree. + + This only gets called if the git-archive 'subst' keywords were *not* + expanded, and _version.py hasn't already been rewritten with a short + version string, meaning we're inside a checked out source tree. 
+ """ + GITS = ["git"] + if sys.platform == "win32": + GITS = ["git.cmd", "git.exe"] + + out, rc = run_command(GITS, ["rev-parse", "--git-dir"], cwd=root, + hide_stderr=True) + if rc != 0: + if verbose: + print("Directory %%s not under git control" %% root) + raise NotThisMethod("'git rev-parse --git-dir' returned error") + + # if there is a tag matching tag_prefix, this yields TAG-NUM-gHEX[-dirty] + # if there isn't one, this yields HEX[-dirty] (no NUM) + describe_out, rc = run_command(GITS, ["describe", "--tags", "--dirty", + "--always", "--long", + "--match", "%%s*" %% tag_prefix], + cwd=root) + # --long was added in git-1.5.5 + if describe_out is None: + raise NotThisMethod("'git describe' failed") + describe_out = describe_out.strip() + full_out, rc = run_command(GITS, ["rev-parse", "HEAD"], cwd=root) + if full_out is None: + raise NotThisMethod("'git rev-parse' failed") + full_out = full_out.strip() + + pieces = {} + pieces["long"] = full_out + pieces["short"] = full_out[:7] # maybe improved later + pieces["error"] = None + + # parse describe_out. It will be like TAG-NUM-gHEX[-dirty] or HEX[-dirty] + # TAG might have hyphens. + git_describe = describe_out + + # look for -dirty suffix + dirty = git_describe.endswith("-dirty") + pieces["dirty"] = dirty + if dirty: + git_describe = git_describe[:git_describe.rindex("-dirty")] + + # now we have TAG-NUM-gHEX or HEX + + if "-" in git_describe: + # TAG-NUM-gHEX + mo = re.search(r'^(.+)-(\d+)-g([0-9a-f]+)$', git_describe) + if not mo: + # unparseable. Maybe git-describe is misbehaving? + pieces["error"] = ("unable to parse git-describe output: '%%s'" + %% describe_out) + return pieces + + # tag + full_tag = mo.group(1) + if not full_tag.startswith(tag_prefix): + if verbose: + fmt = "tag '%%s' doesn't start with prefix '%%s'" + print(fmt %% (full_tag, tag_prefix)) + pieces["error"] = ("tag '%%s' doesn't start with prefix '%%s'" + %% (full_tag, tag_prefix)) + return pieces + pieces["closest-tag"] = full_tag[len(tag_prefix):] + + # distance: number of commits since tag + pieces["distance"] = int(mo.group(2)) + + # commit: short hex revision ID + pieces["short"] = mo.group(3) + + else: + # HEX: no tags + pieces["closest-tag"] = None + count_out, rc = run_command(GITS, ["rev-list", "HEAD", "--count"], + cwd=root) + pieces["distance"] = int(count_out) # total number of commits + + # commit date: see ISO-8601 comment in git_versions_from_keywords() + date = run_command(GITS, ["show", "-s", "--format=%%ci", "HEAD"], + cwd=root)[0].strip() + pieces["date"] = date.strip().replace(" ", "T", 1).replace(" ", "", 1) + + return pieces + + +def plus_or_dot(pieces): + """Return a + if we don't already have one, else return a .""" + if "+" in pieces.get("closest-tag", ""): + return "." + return "+" + + +def render_pep440(pieces): + """Build up version string, with post-release "local version identifier". + + Our goal: TAG[+DISTANCE.gHEX[.dirty]] . Note that if you + get a tagged build and then dirty it, you'll get TAG+0.gHEX.dirty + + Exceptions: + 1: no tags. git_describe was just HEX. 
0+untagged.DISTANCE.gHEX[.dirty] + """ + if pieces["closest-tag"]: + rendered = pieces["closest-tag"] + if pieces["distance"] or pieces["dirty"]: + rendered += plus_or_dot(pieces) + rendered += "%%d.g%%s" %% (pieces["distance"], pieces["short"]) + if pieces["dirty"]: + rendered += ".dirty" + else: + # exception #1 + rendered = "0+untagged.%%d.g%%s" %% (pieces["distance"], + pieces["short"]) + if pieces["dirty"]: + rendered += ".dirty" + return rendered + + +def render_pep440_pre(pieces): + """TAG[.post.devDISTANCE] -- No -dirty. + + Exceptions: + 1: no tags. 0.post.devDISTANCE + """ + if pieces["closest-tag"]: + rendered = pieces["closest-tag"] + if pieces["distance"]: + rendered += ".post.dev%%d" %% pieces["distance"] + else: + # exception #1 + rendered = "0.post.dev%%d" %% pieces["distance"] + return rendered + + +def render_pep440_post(pieces): + """TAG[.postDISTANCE[.dev0]+gHEX] . + + The ".dev0" means dirty. Note that .dev0 sorts backwards + (a dirty tree will appear "older" than the corresponding clean one), + but you shouldn't be releasing software with -dirty anyways. + + Exceptions: + 1: no tags. 0.postDISTANCE[.dev0] + """ + if pieces["closest-tag"]: + rendered = pieces["closest-tag"] + if pieces["distance"] or pieces["dirty"]: + rendered += ".post%%d" %% pieces["distance"] + if pieces["dirty"]: + rendered += ".dev0" + rendered += plus_or_dot(pieces) + rendered += "g%%s" %% pieces["short"] + else: + # exception #1 + rendered = "0.post%%d" %% pieces["distance"] + if pieces["dirty"]: + rendered += ".dev0" + rendered += "+g%%s" %% pieces["short"] + return rendered + + +def render_pep440_old(pieces): + """TAG[.postDISTANCE[.dev0]] . + + The ".dev0" means dirty. + + Eexceptions: + 1: no tags. 0.postDISTANCE[.dev0] + """ + if pieces["closest-tag"]: + rendered = pieces["closest-tag"] + if pieces["distance"] or pieces["dirty"]: + rendered += ".post%%d" %% pieces["distance"] + if pieces["dirty"]: + rendered += ".dev0" + else: + # exception #1 + rendered = "0.post%%d" %% pieces["distance"] + if pieces["dirty"]: + rendered += ".dev0" + return rendered + + +def render_git_describe(pieces): + """TAG[-DISTANCE-gHEX][-dirty]. + + Like 'git describe --tags --dirty --always'. + + Exceptions: + 1: no tags. HEX[-dirty] (note: no 'g' prefix) + """ + if pieces["closest-tag"]: + rendered = pieces["closest-tag"] + if pieces["distance"]: + rendered += "-%%d-g%%s" %% (pieces["distance"], pieces["short"]) + else: + # exception #1 + rendered = pieces["short"] + if pieces["dirty"]: + rendered += "-dirty" + return rendered + + +def render_git_describe_long(pieces): + """TAG-DISTANCE-gHEX[-dirty]. + + Like 'git describe --tags --dirty --always -long'. + The distance/hash is unconditional. + + Exceptions: + 1: no tags. 
HEX[-dirty] (note: no 'g' prefix) + """ + if pieces["closest-tag"]: + rendered = pieces["closest-tag"] + rendered += "-%%d-g%%s" %% (pieces["distance"], pieces["short"]) + else: + # exception #1 + rendered = pieces["short"] + if pieces["dirty"]: + rendered += "-dirty" + return rendered + + +def render(pieces, style): + """Render the given version pieces into the requested style.""" + if pieces["error"]: + return {"version": "unknown", + "full-revisionid": pieces.get("long"), + "dirty": None, + "error": pieces["error"], + "date": None} + + if not style or style == "default": + style = "pep440" # the default + + if style == "pep440": + rendered = render_pep440(pieces) + elif style == "pep440-pre": + rendered = render_pep440_pre(pieces) + elif style == "pep440-post": + rendered = render_pep440_post(pieces) + elif style == "pep440-old": + rendered = render_pep440_old(pieces) + elif style == "git-describe": + rendered = render_git_describe(pieces) + elif style == "git-describe-long": + rendered = render_git_describe_long(pieces) + else: + raise ValueError("unknown style '%%s'" %% style) + + return {"version": rendered, "full-revisionid": pieces["long"], + "dirty": pieces["dirty"], "error": None, + "date": pieces.get("date")} + + +def get_versions(): + """Get version information or return default if unable to do so.""" + # I am in _version.py, which lives at ROOT/VERSIONFILE_SOURCE. If we have + # __file__, we can work backwards from there to the root. Some + # py2exe/bbfreeze/non-CPython implementations don't do __file__, in which + # case we can only use expanded keywords. + + cfg = get_config() + verbose = cfg.verbose + + try: + return git_versions_from_keywords(get_keywords(), cfg.tag_prefix, + verbose) + except NotThisMethod: + pass + + try: + root = os.path.realpath(__file__) + # versionfile_source is the relative path from the top of the source + # tree (where the .git directory might live) to this file. Invert + # this to find the root from __file__. + for i in cfg.versionfile_source.split('/'): + root = os.path.dirname(root) + except NameError: + return {"version": "0+unknown", "full-revisionid": None, + "dirty": None, + "error": "unable to find root of source tree", + "date": None} + + try: + pieces = git_pieces_from_vcs(cfg.tag_prefix, root, verbose) + return render(pieces, cfg.style) + except NotThisMethod: + pass + + try: + if cfg.parentdir_prefix: + return versions_from_parentdir(cfg.parentdir_prefix, root, verbose) + except NotThisMethod: + pass + + return {"version": "0+unknown", "full-revisionid": None, + "dirty": None, + "error": "unable to compute version", "date": None} +''' + + +@register_vcs_handler("git", "get_keywords") +def git_get_keywords(versionfile_abs): + """Extract version information from the given file.""" + # the code embedded in _version.py can just fetch the value of these + # keywords. When used from setup.py, we don't want to import _version.py, + # so we do it with a regexp instead. This function is not used from + # _version.py. 
+ keywords = {} + try: + f = open(versionfile_abs, "r") + for line in f.readlines(): + if line.strip().startswith("git_refnames ="): + mo = re.search(r'=\s*"(.*)"', line) + if mo: + keywords["refnames"] = mo.group(1) + if line.strip().startswith("git_full ="): + mo = re.search(r'=\s*"(.*)"', line) + if mo: + keywords["full"] = mo.group(1) + if line.strip().startswith("git_date ="): + mo = re.search(r'=\s*"(.*)"', line) + if mo: + keywords["date"] = mo.group(1) + f.close() + except EnvironmentError: + pass + return keywords + + +@register_vcs_handler("git", "keywords") +def git_versions_from_keywords(keywords, tag_prefix, verbose): + """Get version information from git keywords.""" + if not keywords: + raise NotThisMethod("no keywords at all, weird") + date = keywords.get("date") + if date is not None: + # git-2.2.0 added "%cI", which expands to an ISO-8601 -compliant + # datestamp. However we prefer "%ci" (which expands to an "ISO-8601 + # -like" string, which we must then edit to make compliant), because + # it's been around since git-1.5.3, and it's too difficult to + # discover which version we're using, or to work around using an + # older one. + date = date.strip().replace(" ", "T", 1).replace(" ", "", 1) + refnames = keywords["refnames"].strip() + if refnames.startswith("$Format"): + if verbose: + print("keywords are unexpanded, not using") + raise NotThisMethod("unexpanded keywords, not a git-archive tarball") + refs = set([r.strip() for r in refnames.strip("()").split(",")]) + # starting in git-1.8.3, tags are listed as "tag: foo-1.0" instead of + # just "foo-1.0". If we see a "tag: " prefix, prefer those. + TAG = "tag: " + tags = set([r[len(TAG):] for r in refs if r.startswith(TAG)]) + if not tags: + # Either we're using git < 1.8.3, or there really are no tags. We use + # a heuristic: assume all version tags have a digit. The old git %d + # expansion behaves like git log --decorate=short and strips out the + # refs/heads/ and refs/tags/ prefixes that would let us distinguish + # between branches and tags. By ignoring refnames without digits, we + # filter out many common branch names like "release" and + # "stabilization", as well as "HEAD" and "master". + tags = set([r for r in refs if re.search(r'\d', r)]) + if verbose: + print("discarding '%s', no digits" % ",".join(refs - tags)) + if verbose: + print("likely tags: %s" % ",".join(sorted(tags))) + for ref in sorted(tags): + # sorting will prefer e.g. "2.0" over "2.0rc1" + if ref.startswith(tag_prefix): + r = ref[len(tag_prefix):] + if verbose: + print("picking %s" % r) + return {"version": r, + "full-revisionid": keywords["full"].strip(), + "dirty": False, "error": None, + "date": date} + # no suitable tags, so version is "0+unknown", but full hex is still there + if verbose: + print("no suitable tags, using unknown + full revision id") + return {"version": "0+unknown", + "full-revisionid": keywords["full"].strip(), + "dirty": False, "error": "no suitable tags", "date": None} + + +@register_vcs_handler("git", "pieces_from_vcs") +def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command): + """Get version from 'git describe' in the root of the source tree. + + This only gets called if the git-archive 'subst' keywords were *not* + expanded, and _version.py hasn't already been rewritten with a short + version string, meaning we're inside a checked out source tree. 
+ """ + GITS = ["git"] + if sys.platform == "win32": + GITS = ["git.cmd", "git.exe"] + + out, rc = run_command(GITS, ["rev-parse", "--git-dir"], cwd=root, + hide_stderr=True) + if rc != 0: + if verbose: + print("Directory %s not under git control" % root) + raise NotThisMethod("'git rev-parse --git-dir' returned error") + + # if there is a tag matching tag_prefix, this yields TAG-NUM-gHEX[-dirty] + # if there isn't one, this yields HEX[-dirty] (no NUM) + describe_out, rc = run_command(GITS, ["describe", "--tags", "--dirty", + "--always", "--long", + "--match", "%s*" % tag_prefix], + cwd=root) + # --long was added in git-1.5.5 + if describe_out is None: + raise NotThisMethod("'git describe' failed") + describe_out = describe_out.strip() + full_out, rc = run_command(GITS, ["rev-parse", "HEAD"], cwd=root) + if full_out is None: + raise NotThisMethod("'git rev-parse' failed") + full_out = full_out.strip() + + pieces = {} + pieces["long"] = full_out + pieces["short"] = full_out[:7] # maybe improved later + pieces["error"] = None + + # parse describe_out. It will be like TAG-NUM-gHEX[-dirty] or HEX[-dirty] + # TAG might have hyphens. + git_describe = describe_out + + # look for -dirty suffix + dirty = git_describe.endswith("-dirty") + pieces["dirty"] = dirty + if dirty: + git_describe = git_describe[:git_describe.rindex("-dirty")] + + # now we have TAG-NUM-gHEX or HEX + + if "-" in git_describe: + # TAG-NUM-gHEX + mo = re.search(r'^(.+)-(\d+)-g([0-9a-f]+)$', git_describe) + if not mo: + # unparseable. Maybe git-describe is misbehaving? + pieces["error"] = ("unable to parse git-describe output: '%s'" + % describe_out) + return pieces + + # tag + full_tag = mo.group(1) + if not full_tag.startswith(tag_prefix): + if verbose: + fmt = "tag '%s' doesn't start with prefix '%s'" + print(fmt % (full_tag, tag_prefix)) + pieces["error"] = ("tag '%s' doesn't start with prefix '%s'" + % (full_tag, tag_prefix)) + return pieces + pieces["closest-tag"] = full_tag[len(tag_prefix):] + + # distance: number of commits since tag + pieces["distance"] = int(mo.group(2)) + + # commit: short hex revision ID + pieces["short"] = mo.group(3) + + else: + # HEX: no tags + pieces["closest-tag"] = None + count_out, rc = run_command(GITS, ["rev-list", "HEAD", "--count"], + cwd=root) + pieces["distance"] = int(count_out) # total number of commits + + # commit date: see ISO-8601 comment in git_versions_from_keywords() + date = run_command(GITS, ["show", "-s", "--format=%ci", "HEAD"], + cwd=root)[0].strip() + pieces["date"] = date.strip().replace(" ", "T", 1).replace(" ", "", 1) + + return pieces + + +def do_vcs_install(manifest_in, versionfile_source, ipy): + """Git-specific installation logic for Versioneer. + + For Git, this means creating/changing .gitattributes to mark _version.py + for export-subst keyword substitution. 
+ """ + GITS = ["git"] + if sys.platform == "win32": + GITS = ["git.cmd", "git.exe"] + files = [manifest_in, versionfile_source] + if ipy: + files.append(ipy) + try: + me = __file__ + if me.endswith(".pyc") or me.endswith(".pyo"): + me = os.path.splitext(me)[0] + ".py" + versioneer_file = os.path.relpath(me) + except NameError: + versioneer_file = "versioneer.py" + files.append(versioneer_file) + present = False + try: + f = open(".gitattributes", "r") + for line in f.readlines(): + if line.strip().startswith(versionfile_source): + if "export-subst" in line.strip().split()[1:]: + present = True + f.close() + except EnvironmentError: + pass + if not present: + f = open(".gitattributes", "a+") + f.write("%s export-subst\n" % versionfile_source) + f.close() + files.append(".gitattributes") + run_command(GITS, ["add", "--"] + files) + + +def versions_from_parentdir(parentdir_prefix, root, verbose): + """Try to determine the version from the parent directory name. + + Source tarballs conventionally unpack into a directory that includes both + the project name and a version string. We will also support searching up + two directory levels for an appropriately named parent directory + """ + rootdirs = [] + + for i in range(3): + dirname = os.path.basename(root) + if dirname.startswith(parentdir_prefix): + return {"version": dirname[len(parentdir_prefix):], + "full-revisionid": None, + "dirty": False, "error": None, "date": None} + else: + rootdirs.append(root) + root = os.path.dirname(root) # up a level + + if verbose: + print("Tried directories %s but none started with prefix %s" % + (str(rootdirs), parentdir_prefix)) + raise NotThisMethod("rootdir doesn't start with parentdir_prefix") + + +SHORT_VERSION_PY = """ +# This file was generated by 'versioneer.py' (0.18) from +# revision-control system data, or from the parent directory name of an +# unpacked source archive. Distribution tarballs contain a pre-generated copy +# of this file. + +import json + +version_json = ''' +%s +''' # END VERSION_JSON + + +def get_versions(): + return json.loads(version_json) +""" + + +def versions_from_file(filename): + """Try to determine the version from _version.py if present.""" + try: + with open(filename) as f: + contents = f.read() + except EnvironmentError: + raise NotThisMethod("unable to read _version.py") + mo = re.search(r"version_json = '''\n(.*)''' # END VERSION_JSON", + contents, re.M | re.S) + if not mo: + mo = re.search(r"version_json = '''\r\n(.*)''' # END VERSION_JSON", + contents, re.M | re.S) + if not mo: + raise NotThisMethod("no version_json in _version.py") + return json.loads(mo.group(1)) + + +def write_to_version_file(filename, versions): + """Write the given version number to the given _version.py file.""" + os.unlink(filename) + contents = json.dumps(versions, sort_keys=True, + indent=1, separators=(",", ": ")) + with open(filename, "w") as f: + f.write(SHORT_VERSION_PY % contents) + + print("set %s to '%s'" % (filename, versions["version"])) + + +def plus_or_dot(pieces): + """Return a + if we don't already have one, else return a .""" + if "+" in pieces.get("closest-tag", ""): + return "." + return "+" + + +def render_pep440(pieces): + """Build up version string, with post-release "local version identifier". + + Our goal: TAG[+DISTANCE.gHEX[.dirty]] . Note that if you + get a tagged build and then dirty it, you'll get TAG+0.gHEX.dirty + + Exceptions: + 1: no tags. git_describe was just HEX. 
0+untagged.DISTANCE.gHEX[.dirty] + """ + if pieces["closest-tag"]: + rendered = pieces["closest-tag"] + if pieces["distance"] or pieces["dirty"]: + rendered += plus_or_dot(pieces) + rendered += "%d.g%s" % (pieces["distance"], pieces["short"]) + if pieces["dirty"]: + rendered += ".dirty" + else: + # exception #1 + rendered = "0+untagged.%d.g%s" % (pieces["distance"], + pieces["short"]) + if pieces["dirty"]: + rendered += ".dirty" + return rendered + + +def render_pep440_pre(pieces): + """TAG[.post.devDISTANCE] -- No -dirty. + + Exceptions: + 1: no tags. 0.post.devDISTANCE + """ + if pieces["closest-tag"]: + rendered = pieces["closest-tag"] + if pieces["distance"]: + rendered += ".post.dev%d" % pieces["distance"] + else: + # exception #1 + rendered = "0.post.dev%d" % pieces["distance"] + return rendered + + +def render_pep440_post(pieces): + """TAG[.postDISTANCE[.dev0]+gHEX] . + + The ".dev0" means dirty. Note that .dev0 sorts backwards + (a dirty tree will appear "older" than the corresponding clean one), + but you shouldn't be releasing software with -dirty anyways. + + Exceptions: + 1: no tags. 0.postDISTANCE[.dev0] + """ + if pieces["closest-tag"]: + rendered = pieces["closest-tag"] + if pieces["distance"] or pieces["dirty"]: + rendered += ".post%d" % pieces["distance"] + if pieces["dirty"]: + rendered += ".dev0" + rendered += plus_or_dot(pieces) + rendered += "g%s" % pieces["short"] + else: + # exception #1 + rendered = "0.post%d" % pieces["distance"] + if pieces["dirty"]: + rendered += ".dev0" + rendered += "+g%s" % pieces["short"] + return rendered + + +def render_pep440_old(pieces): + """TAG[.postDISTANCE[.dev0]] . + + The ".dev0" means dirty. + + Eexceptions: + 1: no tags. 0.postDISTANCE[.dev0] + """ + if pieces["closest-tag"]: + rendered = pieces["closest-tag"] + if pieces["distance"] or pieces["dirty"]: + rendered += ".post%d" % pieces["distance"] + if pieces["dirty"]: + rendered += ".dev0" + else: + # exception #1 + rendered = "0.post%d" % pieces["distance"] + if pieces["dirty"]: + rendered += ".dev0" + return rendered + + +def render_git_describe(pieces): + """TAG[-DISTANCE-gHEX][-dirty]. + + Like 'git describe --tags --dirty --always'. + + Exceptions: + 1: no tags. HEX[-dirty] (note: no 'g' prefix) + """ + if pieces["closest-tag"]: + rendered = pieces["closest-tag"] + if pieces["distance"]: + rendered += "-%d-g%s" % (pieces["distance"], pieces["short"]) + else: + # exception #1 + rendered = pieces["short"] + if pieces["dirty"]: + rendered += "-dirty" + return rendered + + +def render_git_describe_long(pieces): + """TAG-DISTANCE-gHEX[-dirty]. + + Like 'git describe --tags --dirty --always -long'. + The distance/hash is unconditional. + + Exceptions: + 1: no tags. 
HEX[-dirty] (note: no 'g' prefix) + """ + if pieces["closest-tag"]: + rendered = pieces["closest-tag"] + rendered += "-%d-g%s" % (pieces["distance"], pieces["short"]) + else: + # exception #1 + rendered = pieces["short"] + if pieces["dirty"]: + rendered += "-dirty" + return rendered + + +def render(pieces, style): + """Render the given version pieces into the requested style.""" + if pieces["error"]: + return {"version": "unknown", + "full-revisionid": pieces.get("long"), + "dirty": None, + "error": pieces["error"], + "date": None} + + if not style or style == "default": + style = "pep440" # the default + + if style == "pep440": + rendered = render_pep440(pieces) + elif style == "pep440-pre": + rendered = render_pep440_pre(pieces) + elif style == "pep440-post": + rendered = render_pep440_post(pieces) + elif style == "pep440-old": + rendered = render_pep440_old(pieces) + elif style == "git-describe": + rendered = render_git_describe(pieces) + elif style == "git-describe-long": + rendered = render_git_describe_long(pieces) + else: + raise ValueError("unknown style '%s'" % style) + + return {"version": rendered, "full-revisionid": pieces["long"], + "dirty": pieces["dirty"], "error": None, + "date": pieces.get("date")} + + +class VersioneerBadRootError(Exception): + """The project root directory is unknown or missing key files.""" + + +def get_versions(verbose=False): + """Get the project version from whatever source is available. + + Returns dict with two keys: 'version' and 'full'. + """ + if "versioneer" in sys.modules: + # see the discussion in cmdclass.py:get_cmdclass() + del sys.modules["versioneer"] + + root = get_root() + cfg = get_config_from_root(root) + + assert cfg.VCS is not None, "please set [versioneer]VCS= in setup.cfg" + handlers = HANDLERS.get(cfg.VCS) + assert handlers, "unrecognized VCS '%s'" % cfg.VCS + verbose = verbose or cfg.verbose + assert cfg.versionfile_source is not None, \ + "please set versioneer.versionfile_source" + assert cfg.tag_prefix is not None, "please set versioneer.tag_prefix" + + versionfile_abs = os.path.join(root, cfg.versionfile_source) + + # extract version from first of: _version.py, VCS command (e.g. 'git + # describe'), parentdir. This is meant to work for developers using a + # source checkout, for users of a tarball created by 'setup.py sdist', + # and for users of a tarball/zipball created by 'git archive' or github's + # download-from-tag feature or the equivalent in other VCSes. 
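+    #
+    # Concretely, the handlers below are tried in order, and the first one
+    # that does not raise NotThisMethod wins:
+    #   1. expanded 'git archive' keywords found in the version file
+    #   2. a static version file written by an earlier build step
+    #   3. the VCS itself (e.g. 'git describe' on a source checkout)
+    #   4. the parent directory name (parentdir_prefix)
+    # If every method fails, a "0+unknown" placeholder is returned.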
+ + get_keywords_f = handlers.get("get_keywords") + from_keywords_f = handlers.get("keywords") + if get_keywords_f and from_keywords_f: + try: + keywords = get_keywords_f(versionfile_abs) + ver = from_keywords_f(keywords, cfg.tag_prefix, verbose) + if verbose: + print("got version from expanded keyword %s" % ver) + return ver + except NotThisMethod: + pass + + try: + ver = versions_from_file(versionfile_abs) + if verbose: + print("got version from file %s %s" % (versionfile_abs, ver)) + return ver + except NotThisMethod: + pass + + from_vcs_f = handlers.get("pieces_from_vcs") + if from_vcs_f: + try: + pieces = from_vcs_f(cfg.tag_prefix, root, verbose) + ver = render(pieces, cfg.style) + if verbose: + print("got version from VCS %s" % ver) + return ver + except NotThisMethod: + pass + + try: + if cfg.parentdir_prefix: + ver = versions_from_parentdir(cfg.parentdir_prefix, root, verbose) + if verbose: + print("got version from parentdir %s" % ver) + return ver + except NotThisMethod: + pass + + if verbose: + print("unable to compute version") + + return {"version": "0+unknown", "full-revisionid": None, + "dirty": None, "error": "unable to compute version", + "date": None} + + +def get_version(): + """Get the short version string for this project.""" + return get_versions()["version"] + + +def get_cmdclass(): + """Get the custom setuptools/distutils subclasses used by Versioneer.""" + if "versioneer" in sys.modules: + del sys.modules["versioneer"] + # this fixes the "python setup.py develop" case (also 'install' and + # 'easy_install .'), in which subdependencies of the main project are + # built (using setup.py bdist_egg) in the same python process. Assume + # a main project A and a dependency B, which use different versions + # of Versioneer. A's setup.py imports A's Versioneer, leaving it in + # sys.modules by the time B's setup.py is executed, causing B to run + # with the wrong versioneer. Setuptools wraps the sub-dep builds in a + # sandbox that restores sys.modules to it's pre-build state, so the + # parent is protected against the child's "import versioneer". By + # removing ourselves from sys.modules here, before the child build + # happens, we protect the child from the parent's versioneer too. + # Also see https://github.com/warner/python-versioneer/issues/52 + + cmds = {} + + # we add "version" to both distutils and setuptools + from distutils.core import Command + + class cmd_version(Command): + description = "report generated version string" + user_options = [] + boolean_options = [] + + def initialize_options(self): + pass + + def finalize_options(self): + pass + + def run(self): + vers = get_versions(verbose=True) + print("Version: %s" % vers["version"]) + print(" full-revisionid: %s" % vers.get("full-revisionid")) + print(" dirty: %s" % vers.get("dirty")) + print(" date: %s" % vers.get("date")) + if vers["error"]: + print(" error: %s" % vers["error"]) + cmds["version"] = cmd_version + + # we override "build_py" in both distutils and setuptools + # + # most invocation pathways end up running build_py: + # distutils/build -> build_py + # distutils/install -> distutils/build ->.. + # setuptools/bdist_wheel -> distutils/install ->.. + # setuptools/bdist_egg -> distutils/install_lib -> build_py + # setuptools/install -> bdist_egg ->.. + # setuptools/develop -> ? 
+ # pip install: + # copies source tree to a tempdir before running egg_info/etc + # if .git isn't copied too, 'git describe' will fail + # then does setup.py bdist_wheel, or sometimes setup.py install + # setup.py egg_info -> ? + + # we override different "build_py" commands for both environments + if "setuptools" in sys.modules: + from setuptools.command.build_py import build_py as _build_py + else: + from distutils.command.build_py import build_py as _build_py + + class cmd_build_py(_build_py): + def run(self): + root = get_root() + cfg = get_config_from_root(root) + versions = get_versions() + _build_py.run(self) + # now locate _version.py in the new build/ directory and replace + # it with an updated value + if cfg.versionfile_build: + target_versionfile = os.path.join(self.build_lib, + cfg.versionfile_build) + print("UPDATING %s" % target_versionfile) + write_to_version_file(target_versionfile, versions) + cmds["build_py"] = cmd_build_py + + if "cx_Freeze" in sys.modules: # cx_freeze enabled? + from cx_Freeze.dist import build_exe as _build_exe + # nczeczulin reports that py2exe won't like the pep440-style string + # as FILEVERSION, but it can be used for PRODUCTVERSION, e.g. + # setup(console=[{ + # "version": versioneer.get_version().split("+", 1)[0], # FILEVERSION + # "product_version": versioneer.get_version(), + # ... + + class cmd_build_exe(_build_exe): + def run(self): + root = get_root() + cfg = get_config_from_root(root) + versions = get_versions() + target_versionfile = cfg.versionfile_source + print("UPDATING %s" % target_versionfile) + write_to_version_file(target_versionfile, versions) + + _build_exe.run(self) + os.unlink(target_versionfile) + with open(cfg.versionfile_source, "w") as f: + LONG = LONG_VERSION_PY[cfg.VCS] + f.write(LONG % + {"DOLLAR": "$", + "STYLE": cfg.style, + "TAG_PREFIX": cfg.tag_prefix, + "PARENTDIR_PREFIX": cfg.parentdir_prefix, + "VERSIONFILE_SOURCE": cfg.versionfile_source, + }) + cmds["build_exe"] = cmd_build_exe + del cmds["build_py"] + + if 'py2exe' in sys.modules: # py2exe enabled? 
+ try: + from py2exe.distutils_buildexe import py2exe as _py2exe # py3 + except ImportError: + from py2exe.build_exe import py2exe as _py2exe # py2 + + class cmd_py2exe(_py2exe): + def run(self): + root = get_root() + cfg = get_config_from_root(root) + versions = get_versions() + target_versionfile = cfg.versionfile_source + print("UPDATING %s" % target_versionfile) + write_to_version_file(target_versionfile, versions) + + _py2exe.run(self) + os.unlink(target_versionfile) + with open(cfg.versionfile_source, "w") as f: + LONG = LONG_VERSION_PY[cfg.VCS] + f.write(LONG % + {"DOLLAR": "$", + "STYLE": cfg.style, + "TAG_PREFIX": cfg.tag_prefix, + "PARENTDIR_PREFIX": cfg.parentdir_prefix, + "VERSIONFILE_SOURCE": cfg.versionfile_source, + }) + cmds["py2exe"] = cmd_py2exe + + # we override different "sdist" commands for both environments + if "setuptools" in sys.modules: + from setuptools.command.sdist import sdist as _sdist + else: + from distutils.command.sdist import sdist as _sdist + + class cmd_sdist(_sdist): + def run(self): + versions = get_versions() + self._versioneer_generated_versions = versions + # unless we update this, the command will keep using the old + # version + self.distribution.metadata.version = versions["version"] + return _sdist.run(self) + + def make_release_tree(self, base_dir, files): + root = get_root() + cfg = get_config_from_root(root) + _sdist.make_release_tree(self, base_dir, files) + # now locate _version.py in the new base_dir directory + # (remembering that it may be a hardlink) and replace it with an + # updated value + target_versionfile = os.path.join(base_dir, cfg.versionfile_source) + print("UPDATING %s" % target_versionfile) + write_to_version_file(target_versionfile, + self._versioneer_generated_versions) + cmds["sdist"] = cmd_sdist + + return cmds + + +CONFIG_ERROR = """ +setup.cfg is missing the necessary Versioneer configuration. You need +a section like: + + [versioneer] + VCS = git + style = pep440 + versionfile_source = src/myproject/_version.py + versionfile_build = myproject/_version.py + tag_prefix = + parentdir_prefix = myproject- + +You will also need to edit your setup.py to use the results: + + import versioneer + setup(version=versioneer.get_version(), + cmdclass=versioneer.get_cmdclass(), ...) + +Please read the docstring in ./versioneer.py for configuration instructions, +edit setup.cfg, and re-run the installer or 'python versioneer.py setup'. +""" + +SAMPLE_CONFIG = """ +# See the docstring in versioneer.py for instructions. Note that you must +# re-run 'versioneer.py setup' after changing this section, and commit the +# resulting files. 
+ +[versioneer] +#VCS = git +#style = pep440 +#versionfile_source = +#versionfile_build = +#tag_prefix = +#parentdir_prefix = + +""" + +INIT_PY_SNIPPET = """ +from ._version import get_versions +__version__ = get_versions()['version'] +del get_versions +""" + + +def do_setup(): + """Main VCS-independent setup function for installing Versioneer.""" + root = get_root() + try: + cfg = get_config_from_root(root) + except (EnvironmentError, configparser.NoSectionError, + configparser.NoOptionError) as e: + if isinstance(e, (EnvironmentError, configparser.NoSectionError)): + print("Adding sample versioneer config to setup.cfg", + file=sys.stderr) + with open(os.path.join(root, "setup.cfg"), "a") as f: + f.write(SAMPLE_CONFIG) + print(CONFIG_ERROR, file=sys.stderr) + return 1 + + print(" creating %s" % cfg.versionfile_source) + with open(cfg.versionfile_source, "w") as f: + LONG = LONG_VERSION_PY[cfg.VCS] + f.write(LONG % {"DOLLAR": "$", + "STYLE": cfg.style, + "TAG_PREFIX": cfg.tag_prefix, + "PARENTDIR_PREFIX": cfg.parentdir_prefix, + "VERSIONFILE_SOURCE": cfg.versionfile_source, + }) + + ipy = os.path.join(os.path.dirname(cfg.versionfile_source), + "__init__.py") + if os.path.exists(ipy): + try: + with open(ipy, "r") as f: + old = f.read() + except EnvironmentError: + old = "" + if INIT_PY_SNIPPET not in old: + print(" appending to %s" % ipy) + with open(ipy, "a") as f: + f.write(INIT_PY_SNIPPET) + else: + print(" %s unmodified" % ipy) + else: + print(" %s doesn't exist, ok" % ipy) + ipy = None + + # Make sure both the top-level "versioneer.py" and versionfile_source + # (PKG/_version.py, used by runtime code) are in MANIFEST.in, so + # they'll be copied into source distributions. Pip won't be able to + # install the package without this. + manifest_in = os.path.join(root, "MANIFEST.in") + simple_includes = set() + try: + with open(manifest_in, "r") as f: + for line in f: + if line.startswith("include "): + for include in line.split()[1:]: + simple_includes.add(include) + except EnvironmentError: + pass + # That doesn't cover everything MANIFEST.in can do + # (http://docs.python.org/2/distutils/sourcedist.html#commands), so + # it might give some false negatives. Appending redundant 'include' + # lines is safe, though. + if "versioneer.py" not in simple_includes: + print(" appending 'versioneer.py' to MANIFEST.in") + with open(manifest_in, "a") as f: + f.write("include versioneer.py\n") + else: + print(" 'versioneer.py' already in MANIFEST.in") + if cfg.versionfile_source not in simple_includes: + print(" appending versionfile_source ('%s') to MANIFEST.in" % + cfg.versionfile_source) + with open(manifest_in, "a") as f: + f.write("include %s\n" % cfg.versionfile_source) + else: + print(" versionfile_source already in MANIFEST.in") + + # Make VCS-specific changes. For git, this means creating/changing + # .gitattributes to mark _version.py for export-subst keyword + # substitution. 
+ do_vcs_install(manifest_in, cfg.versionfile_source, ipy) + return 0 + + +def scan_setup_py(): + """Validate the contents of setup.py against Versioneer's expectations.""" + found = set() + setters = False + errors = 0 + with open("setup.py", "r") as f: + for line in f.readlines(): + if "import versioneer" in line: + found.add("import") + if "versioneer.get_cmdclass()" in line: + found.add("cmdclass") + if "versioneer.get_version()" in line: + found.add("get_version") + if "versioneer.VCS" in line: + setters = True + if "versioneer.versionfile_source" in line: + setters = True + if len(found) != 3: + print("") + print("Your setup.py appears to be missing some important items") + print("(but I might be wrong). Please make sure it has something") + print("roughly like the following:") + print("") + print(" import versioneer") + print(" setup( version=versioneer.get_version(),") + print(" cmdclass=versioneer.get_cmdclass(), ...)") + print("") + errors += 1 + if setters: + print("You should remove lines like 'versioneer.VCS = ' and") + print("'versioneer.versionfile_source = ' . This configuration") + print("now lives in setup.cfg, and should be removed from setup.py") + print("") + errors += 1 + return errors + + +if __name__ == "__main__": + cmd = sys.argv[1] + if cmd == "setup": + errors = do_setup() + errors += scan_setup_py() + if errors: + sys.exit(1) diff --git a/xarray/__init__.py b/xarray/__init__.py index 94e8029edbb..7cc7811b783 100644 --- a/xarray/__init__.py +++ b/xarray/__init__.py @@ -3,6 +3,10 @@ from __future__ import division from __future__ import print_function +from ._version import get_versions +__version__ = get_versions()['version'] +del get_versions + from .core.alignment import align, broadcast, broadcast_arrays from .core.common import full_like, zeros_like, ones_like from .core.combine import concat, auto_combine @@ -24,13 +28,6 @@ from .coding.cftimeindex import CFTimeIndex -try: - from .version import version as __version__ -except ImportError: # pragma: no cover - raise ImportError('xarray not properly installed. If you are running from ' - 'the source directory, please instead create a new ' - 'virtual environment (using conda or virtualenv) and ' - 'then install it in-place by running: pip install -e .') from .util.print_versions import show_versions from . import tutorial diff --git a/xarray/_version.py b/xarray/_version.py new file mode 100644 index 00000000000..2fa32b69798 --- /dev/null +++ b/xarray/_version.py @@ -0,0 +1,520 @@ + +# This file helps to compute a version number in source trees obtained from +# git-archive tarball (such as those provided by githubs download-from-tag +# feature). Distribution tarballs (built by setup.py sdist) and build +# directories (produced by setup.py build) will contain a much shorter file +# that just contains the computed version number. + +# This file is released into the public domain. Generated by +# versioneer-0.18 (https://github.com/warner/python-versioneer) + +"""Git implementation of _version.py.""" + +import errno +import os +import re +import subprocess +import sys + + +def get_keywords(): + """Get the keywords needed to look up the version information.""" + # these strings will be replaced by git during git-archive. + # setup.py/versioneer.py will grep for the variable names, so they must + # each be defined on a line of their own. _version.py will just call + # get_keywords(). 
+ git_refnames = "$Format:%d$" + git_full = "$Format:%H$" + git_date = "$Format:%ci$" + keywords = {"refnames": git_refnames, "full": git_full, "date": git_date} + return keywords + + +class VersioneerConfig: + """Container for Versioneer configuration parameters.""" + + +def get_config(): + """Create, populate and return the VersioneerConfig() object.""" + # these strings are filled in when 'setup.py versioneer' creates + # _version.py + cfg = VersioneerConfig() + cfg.VCS = "git" + cfg.style = "pep440" + cfg.tag_prefix = "" + cfg.parentdir_prefix = "xarray-" + cfg.versionfile_source = "xarray/_version.py" + cfg.verbose = False + return cfg + + +class NotThisMethod(Exception): + """Exception raised if a method is not valid for the current scenario.""" + + +LONG_VERSION_PY = {} +HANDLERS = {} + + +def register_vcs_handler(vcs, method): # decorator + """Decorator to mark a method as the handler for a particular VCS.""" + def decorate(f): + """Store f in HANDLERS[vcs][method].""" + if vcs not in HANDLERS: + HANDLERS[vcs] = {} + HANDLERS[vcs][method] = f + return f + return decorate + + +def run_command(commands, args, cwd=None, verbose=False, hide_stderr=False, + env=None): + """Call the given command(s).""" + assert isinstance(commands, list) + p = None + for c in commands: + try: + dispcmd = str([c] + args) + # remember shell=False, so use git.cmd on windows, not just git + p = subprocess.Popen([c] + args, cwd=cwd, env=env, + stdout=subprocess.PIPE, + stderr=(subprocess.PIPE if hide_stderr + else None)) + break + except EnvironmentError: + e = sys.exc_info()[1] + if e.errno == errno.ENOENT: + continue + if verbose: + print("unable to run %s" % dispcmd) + print(e) + return None, None + else: + if verbose: + print("unable to find command, tried %s" % (commands,)) + return None, None + stdout = p.communicate()[0].strip() + if sys.version_info[0] >= 3: + stdout = stdout.decode() + if p.returncode != 0: + if verbose: + print("unable to run %s (error)" % dispcmd) + print("stdout was %s" % stdout) + return None, p.returncode + return stdout, p.returncode + + +def versions_from_parentdir(parentdir_prefix, root, verbose): + """Try to determine the version from the parent directory name. + + Source tarballs conventionally unpack into a directory that includes both + the project name and a version string. We will also support searching up + two directory levels for an appropriately named parent directory + """ + rootdirs = [] + + for i in range(3): + dirname = os.path.basename(root) + if dirname.startswith(parentdir_prefix): + return {"version": dirname[len(parentdir_prefix):], + "full-revisionid": None, + "dirty": False, "error": None, "date": None} + else: + rootdirs.append(root) + root = os.path.dirname(root) # up a level + + if verbose: + print("Tried directories %s but none started with prefix %s" % + (str(rootdirs), parentdir_prefix)) + raise NotThisMethod("rootdir doesn't start with parentdir_prefix") + + +@register_vcs_handler("git", "get_keywords") +def git_get_keywords(versionfile_abs): + """Extract version information from the given file.""" + # the code embedded in _version.py can just fetch the value of these + # keywords. When used from setup.py, we don't want to import _version.py, + # so we do it with a regexp instead. This function is not used from + # _version.py. 
+ keywords = {} + try: + f = open(versionfile_abs, "r") + for line in f.readlines(): + if line.strip().startswith("git_refnames ="): + mo = re.search(r'=\s*"(.*)"', line) + if mo: + keywords["refnames"] = mo.group(1) + if line.strip().startswith("git_full ="): + mo = re.search(r'=\s*"(.*)"', line) + if mo: + keywords["full"] = mo.group(1) + if line.strip().startswith("git_date ="): + mo = re.search(r'=\s*"(.*)"', line) + if mo: + keywords["date"] = mo.group(1) + f.close() + except EnvironmentError: + pass + return keywords + + +@register_vcs_handler("git", "keywords") +def git_versions_from_keywords(keywords, tag_prefix, verbose): + """Get version information from git keywords.""" + if not keywords: + raise NotThisMethod("no keywords at all, weird") + date = keywords.get("date") + if date is not None: + # git-2.2.0 added "%cI", which expands to an ISO-8601 -compliant + # datestamp. However we prefer "%ci" (which expands to an "ISO-8601 + # -like" string, which we must then edit to make compliant), because + # it's been around since git-1.5.3, and it's too difficult to + # discover which version we're using, or to work around using an + # older one. + date = date.strip().replace(" ", "T", 1).replace(" ", "", 1) + refnames = keywords["refnames"].strip() + if refnames.startswith("$Format"): + if verbose: + print("keywords are unexpanded, not using") + raise NotThisMethod("unexpanded keywords, not a git-archive tarball") + refs = set([r.strip() for r in refnames.strip("()").split(",")]) + # starting in git-1.8.3, tags are listed as "tag: foo-1.0" instead of + # just "foo-1.0". If we see a "tag: " prefix, prefer those. + TAG = "tag: " + tags = set([r[len(TAG):] for r in refs if r.startswith(TAG)]) + if not tags: + # Either we're using git < 1.8.3, or there really are no tags. We use + # a heuristic: assume all version tags have a digit. The old git %d + # expansion behaves like git log --decorate=short and strips out the + # refs/heads/ and refs/tags/ prefixes that would let us distinguish + # between branches and tags. By ignoring refnames without digits, we + # filter out many common branch names like "release" and + # "stabilization", as well as "HEAD" and "master". + tags = set([r for r in refs if re.search(r'\d', r)]) + if verbose: + print("discarding '%s', no digits" % ",".join(refs - tags)) + if verbose: + print("likely tags: %s" % ",".join(sorted(tags))) + for ref in sorted(tags): + # sorting will prefer e.g. "2.0" over "2.0rc1" + if ref.startswith(tag_prefix): + r = ref[len(tag_prefix):] + if verbose: + print("picking %s" % r) + return {"version": r, + "full-revisionid": keywords["full"].strip(), + "dirty": False, "error": None, + "date": date} + # no suitable tags, so version is "0+unknown", but full hex is still there + if verbose: + print("no suitable tags, using unknown + full revision id") + return {"version": "0+unknown", + "full-revisionid": keywords["full"].strip(), + "dirty": False, "error": "no suitable tags", "date": None} + + +@register_vcs_handler("git", "pieces_from_vcs") +def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command): + """Get version from 'git describe' in the root of the source tree. + + This only gets called if the git-archive 'subst' keywords were *not* + expanded, and _version.py hasn't already been rewritten with a short + version string, meaning we're inside a checked out source tree. 
+ """ + GITS = ["git"] + if sys.platform == "win32": + GITS = ["git.cmd", "git.exe"] + + out, rc = run_command(GITS, ["rev-parse", "--git-dir"], cwd=root, + hide_stderr=True) + if rc != 0: + if verbose: + print("Directory %s not under git control" % root) + raise NotThisMethod("'git rev-parse --git-dir' returned error") + + # if there is a tag matching tag_prefix, this yields TAG-NUM-gHEX[-dirty] + # if there isn't one, this yields HEX[-dirty] (no NUM) + describe_out, rc = run_command(GITS, ["describe", "--tags", "--dirty", + "--always", "--long", + "--match", "%s*" % tag_prefix], + cwd=root) + # --long was added in git-1.5.5 + if describe_out is None: + raise NotThisMethod("'git describe' failed") + describe_out = describe_out.strip() + full_out, rc = run_command(GITS, ["rev-parse", "HEAD"], cwd=root) + if full_out is None: + raise NotThisMethod("'git rev-parse' failed") + full_out = full_out.strip() + + pieces = {} + pieces["long"] = full_out + pieces["short"] = full_out[:7] # maybe improved later + pieces["error"] = None + + # parse describe_out. It will be like TAG-NUM-gHEX[-dirty] or HEX[-dirty] + # TAG might have hyphens. + git_describe = describe_out + + # look for -dirty suffix + dirty = git_describe.endswith("-dirty") + pieces["dirty"] = dirty + if dirty: + git_describe = git_describe[:git_describe.rindex("-dirty")] + + # now we have TAG-NUM-gHEX or HEX + + if "-" in git_describe: + # TAG-NUM-gHEX + mo = re.search(r'^(.+)-(\d+)-g([0-9a-f]+)$', git_describe) + if not mo: + # unparseable. Maybe git-describe is misbehaving? + pieces["error"] = ("unable to parse git-describe output: '%s'" + % describe_out) + return pieces + + # tag + full_tag = mo.group(1) + if not full_tag.startswith(tag_prefix): + if verbose: + fmt = "tag '%s' doesn't start with prefix '%s'" + print(fmt % (full_tag, tag_prefix)) + pieces["error"] = ("tag '%s' doesn't start with prefix '%s'" + % (full_tag, tag_prefix)) + return pieces + pieces["closest-tag"] = full_tag[len(tag_prefix):] + + # distance: number of commits since tag + pieces["distance"] = int(mo.group(2)) + + # commit: short hex revision ID + pieces["short"] = mo.group(3) + + else: + # HEX: no tags + pieces["closest-tag"] = None + count_out, rc = run_command(GITS, ["rev-list", "HEAD", "--count"], + cwd=root) + pieces["distance"] = int(count_out) # total number of commits + + # commit date: see ISO-8601 comment in git_versions_from_keywords() + date = run_command(GITS, ["show", "-s", "--format=%ci", "HEAD"], + cwd=root)[0].strip() + pieces["date"] = date.strip().replace(" ", "T", 1).replace(" ", "", 1) + + return pieces + + +def plus_or_dot(pieces): + """Return a + if we don't already have one, else return a .""" + if "+" in pieces.get("closest-tag", ""): + return "." + return "+" + + +def render_pep440(pieces): + """Build up version string, with post-release "local version identifier". + + Our goal: TAG[+DISTANCE.gHEX[.dirty]] . Note that if you + get a tagged build and then dirty it, you'll get TAG+0.gHEX.dirty + + Exceptions: + 1: no tags. git_describe was just HEX. 
0+untagged.DISTANCE.gHEX[.dirty] + """ + if pieces["closest-tag"]: + rendered = pieces["closest-tag"] + if pieces["distance"] or pieces["dirty"]: + rendered += plus_or_dot(pieces) + rendered += "%d.g%s" % (pieces["distance"], pieces["short"]) + if pieces["dirty"]: + rendered += ".dirty" + else: + # exception #1 + rendered = "0+untagged.%d.g%s" % (pieces["distance"], + pieces["short"]) + if pieces["dirty"]: + rendered += ".dirty" + return rendered + + +def render_pep440_pre(pieces): + """TAG[.post.devDISTANCE] -- No -dirty. + + Exceptions: + 1: no tags. 0.post.devDISTANCE + """ + if pieces["closest-tag"]: + rendered = pieces["closest-tag"] + if pieces["distance"]: + rendered += ".post.dev%d" % pieces["distance"] + else: + # exception #1 + rendered = "0.post.dev%d" % pieces["distance"] + return rendered + + +def render_pep440_post(pieces): + """TAG[.postDISTANCE[.dev0]+gHEX] . + + The ".dev0" means dirty. Note that .dev0 sorts backwards + (a dirty tree will appear "older" than the corresponding clean one), + but you shouldn't be releasing software with -dirty anyways. + + Exceptions: + 1: no tags. 0.postDISTANCE[.dev0] + """ + if pieces["closest-tag"]: + rendered = pieces["closest-tag"] + if pieces["distance"] or pieces["dirty"]: + rendered += ".post%d" % pieces["distance"] + if pieces["dirty"]: + rendered += ".dev0" + rendered += plus_or_dot(pieces) + rendered += "g%s" % pieces["short"] + else: + # exception #1 + rendered = "0.post%d" % pieces["distance"] + if pieces["dirty"]: + rendered += ".dev0" + rendered += "+g%s" % pieces["short"] + return rendered + + +def render_pep440_old(pieces): + """TAG[.postDISTANCE[.dev0]] . + + The ".dev0" means dirty. + + Eexceptions: + 1: no tags. 0.postDISTANCE[.dev0] + """ + if pieces["closest-tag"]: + rendered = pieces["closest-tag"] + if pieces["distance"] or pieces["dirty"]: + rendered += ".post%d" % pieces["distance"] + if pieces["dirty"]: + rendered += ".dev0" + else: + # exception #1 + rendered = "0.post%d" % pieces["distance"] + if pieces["dirty"]: + rendered += ".dev0" + return rendered + + +def render_git_describe(pieces): + """TAG[-DISTANCE-gHEX][-dirty]. + + Like 'git describe --tags --dirty --always'. + + Exceptions: + 1: no tags. HEX[-dirty] (note: no 'g' prefix) + """ + if pieces["closest-tag"]: + rendered = pieces["closest-tag"] + if pieces["distance"]: + rendered += "-%d-g%s" % (pieces["distance"], pieces["short"]) + else: + # exception #1 + rendered = pieces["short"] + if pieces["dirty"]: + rendered += "-dirty" + return rendered + + +def render_git_describe_long(pieces): + """TAG-DISTANCE-gHEX[-dirty]. + + Like 'git describe --tags --dirty --always -long'. + The distance/hash is unconditional. + + Exceptions: + 1: no tags. 
HEX[-dirty] (note: no 'g' prefix) + """ + if pieces["closest-tag"]: + rendered = pieces["closest-tag"] + rendered += "-%d-g%s" % (pieces["distance"], pieces["short"]) + else: + # exception #1 + rendered = pieces["short"] + if pieces["dirty"]: + rendered += "-dirty" + return rendered + + +def render(pieces, style): + """Render the given version pieces into the requested style.""" + if pieces["error"]: + return {"version": "unknown", + "full-revisionid": pieces.get("long"), + "dirty": None, + "error": pieces["error"], + "date": None} + + if not style or style == "default": + style = "pep440" # the default + + if style == "pep440": + rendered = render_pep440(pieces) + elif style == "pep440-pre": + rendered = render_pep440_pre(pieces) + elif style == "pep440-post": + rendered = render_pep440_post(pieces) + elif style == "pep440-old": + rendered = render_pep440_old(pieces) + elif style == "git-describe": + rendered = render_git_describe(pieces) + elif style == "git-describe-long": + rendered = render_git_describe_long(pieces) + else: + raise ValueError("unknown style '%s'" % style) + + return {"version": rendered, "full-revisionid": pieces["long"], + "dirty": pieces["dirty"], "error": None, + "date": pieces.get("date")} + + +def get_versions(): + """Get version information or return default if unable to do so.""" + # I am in _version.py, which lives at ROOT/VERSIONFILE_SOURCE. If we have + # __file__, we can work backwards from there to the root. Some + # py2exe/bbfreeze/non-CPython implementations don't do __file__, in which + # case we can only use expanded keywords. + + cfg = get_config() + verbose = cfg.verbose + + try: + return git_versions_from_keywords(get_keywords(), cfg.tag_prefix, + verbose) + except NotThisMethod: + pass + + try: + root = os.path.realpath(__file__) + # versionfile_source is the relative path from the top of the source + # tree (where the .git directory might live) to this file. Invert + # this to find the root from __file__. + for i in cfg.versionfile_source.split('/'): + root = os.path.dirname(root) + except NameError: + return {"version": "0+unknown", "full-revisionid": None, + "dirty": None, + "error": "unable to find root of source tree", + "date": None} + + try: + pieces = git_pieces_from_vcs(cfg.tag_prefix, root, verbose) + return render(pieces, cfg.style) + except NotThisMethod: + pass + + try: + if cfg.parentdir_prefix: + return versions_from_parentdir(cfg.parentdir_prefix, root, verbose) + except NotThisMethod: + pass + + return {"version": "0+unknown", "full-revisionid": None, + "dirty": None, + "error": "unable to compute version", "date": None} From 48d55eea052fec204b843babdc81c258f3ed5ce1 Mon Sep 17 00:00:00 2001 From: Spencer Clark Date: Mon, 21 May 2018 04:02:34 -0400 Subject: [PATCH 30/61] Fix string slice indexing for a length-1 CFTimeIndex (#2166) * Fix string slice indexing for length-1 CFTimeIndex * Skip test if cftime is not installed * Add a what's new entry --- doc/whats-new.rst | 6 ++++++ xarray/coding/cftimeindex.py | 2 +- xarray/tests/test_cftimeindex.py | 36 ++++++++++++++++++++++++++++++++ 3 files changed, 43 insertions(+), 1 deletion(-) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 4c9a1415e26..d9f43fa1868 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -59,6 +59,12 @@ Bug fixes dimension were improperly skipped. By `Stephan Hoyer `_ +- Selecting data indexed by a length-1 ``CFTimeIndex`` with a slice of strings + now behaves as it does when using a length-1 ``DatetimeIndex`` (i.e. 
it no + longer falsely returns an empty array when the slice includes the value in + the index) (:issue:`2165`). + By `Spencer Clark `_. + .. _whats-new.0.10.4: v0.10.4 (May 16, 2018) diff --git a/xarray/coding/cftimeindex.py b/xarray/coding/cftimeindex.py index 5fca14ddbb1..eb8cae2f398 100644 --- a/xarray/coding/cftimeindex.py +++ b/xarray/coding/cftimeindex.py @@ -225,7 +225,7 @@ def _maybe_cast_slice_bound(self, label, side, kind): label) start, end = _parsed_string_to_bounds(self.date_type, resolution, parsed) - if self.is_monotonic_decreasing and len(self): + if self.is_monotonic_decreasing and len(self) > 1: return end if side == 'left' else start return start if side == 'left' else end else: diff --git a/xarray/tests/test_cftimeindex.py b/xarray/tests/test_cftimeindex.py index c78ac038bd5..6f102b60b9d 100644 --- a/xarray/tests/test_cftimeindex.py +++ b/xarray/tests/test_cftimeindex.py @@ -79,6 +79,12 @@ def monotonic_decreasing_index(date_type): return CFTimeIndex(dates) +@pytest.fixture +def length_one_index(date_type): + dates = [date_type(1, 1, 1)] + return CFTimeIndex(dates) + + @pytest.fixture def da(index): return xr.DataArray([1, 2, 3, 4], coords=[index], @@ -280,6 +286,36 @@ def test_get_slice_bound_decreasing_index( assert result == expected +@pytest.mark.skipif(not has_cftime, reason='cftime not installed') +@pytest.mark.parametrize('kind', ['loc', 'getitem']) +def test_get_slice_bound_length_one_index( + date_type, length_one_index, kind): + result = length_one_index.get_slice_bound('0001', 'left', kind) + expected = 0 + assert result == expected + + result = length_one_index.get_slice_bound('0001', 'right', kind) + expected = 1 + assert result == expected + + result = length_one_index.get_slice_bound( + date_type(1, 3, 1), 'left', kind) + expected = 1 + assert result == expected + + result = length_one_index.get_slice_bound( + date_type(1, 3, 1), 'right', kind) + expected = 1 + assert result == expected + + +@pytest.mark.skipif(not has_cftime, reason='cftime not installed') +def test_string_slice_length_one_index(length_one_index): + da = xr.DataArray([1], coords=[length_one_index], dims=['time']) + result = da.sel(time=slice('0001', '0001')) + assert_identical(result, da) + + @pytest.mark.skipif(not has_cftime, reason='cftime not installed') def test_date_type_property(date_type, index): assert index.date_type is date_type From b48e0969670f17857a314b5a755b1a1bf7ee38df Mon Sep 17 00:00:00 2001 From: Stephan Hoyer Date: Thu, 24 May 2018 17:52:06 -0700 Subject: [PATCH 31/61] BUG: fix writing to groups with h5netcdf (#2181) * BUG: fix writing to groups with h5netcdf Fixes GH2177 Our test suite was inadvertently not checking this. * what's new note --- doc/whats-new.rst | 6 +++++- xarray/backends/h5netcdf_.py | 9 +++++++-- xarray/backends/netCDF4_.py | 10 +++++++--- xarray/tests/test_backends.py | 12 ++++++------ 4 files changed, 25 insertions(+), 12 deletions(-) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index d9f43fa1868..4a01065bd70 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -64,7 +64,11 @@ Bug fixes longer falsely returns an empty array when the slice includes the value in the index) (:issue:`2165`). By `Spencer Clark `_. - + +- Fix Dataset.to_netcdf() cannot create group with engine="h5netcdf" + (:issue:`2177`). + By `Stephan Hoyer `_ + .. 
_whats-new.0.10.4: v0.10.4 (May 16, 2018) diff --git a/xarray/backends/h5netcdf_.py b/xarray/backends/h5netcdf_.py index f9e2b3dece1..6b3cd9ebb15 100644 --- a/xarray/backends/h5netcdf_.py +++ b/xarray/backends/h5netcdf_.py @@ -12,7 +12,7 @@ HDF5_LOCK, DataStorePickleMixin, WritableCFDataStore, find_root) from .netCDF4_ import ( BaseNetCDF4Array, _encode_nc4_variable, _extract_nc4_variable_encoding, - _get_datatype, _nc4_group) + _get_datatype, _nc4_require_group) class H5NetCDFArrayWrapper(BaseNetCDF4Array): @@ -57,11 +57,16 @@ def _read_attributes(h5netcdf_var): lsd_okay=False, h5py_okay=True, backend='h5netcdf') +def _h5netcdf_create_group(dataset, name): + return dataset.create_group(name) + + def _open_h5netcdf_group(filename, mode, group): import h5netcdf ds = h5netcdf.File(filename, mode=mode) with close_on_error(ds): - return _nc4_group(ds, group, mode) + return _nc4_require_group( + ds, group, mode, create_group=_h5netcdf_create_group) class H5NetCDFStore(WritableCFDataStore, DataStorePickleMixin): diff --git a/xarray/backends/netCDF4_.py b/xarray/backends/netCDF4_.py index 14061a0fb08..5391a890fb3 100644 --- a/xarray/backends/netCDF4_.py +++ b/xarray/backends/netCDF4_.py @@ -108,7 +108,11 @@ def _nc4_dtype(var): return dtype -def _nc4_group(ds, group, mode): +def _netcdf4_create_group(dataset, name): + return dataset.createGroup(name) + + +def _nc4_require_group(ds, group, mode, create_group=_netcdf4_create_group): if group in set([None, '', '/']): # use the root group return ds @@ -123,7 +127,7 @@ def _nc4_group(ds, group, mode): ds = ds.groups[key] except KeyError as e: if mode != 'r': - ds = ds.createGroup(key) + ds = create_group(ds, key) else: # wrap error to provide slightly more helpful message raise IOError('group not found: %s' % key, e) @@ -210,7 +214,7 @@ def _open_netcdf4_group(filename, mode, group=None, **kwargs): ds = nc4.Dataset(filename, mode=mode, **kwargs) with close_on_error(ds): - ds = _nc4_group(ds, group, mode) + ds = _nc4_require_group(ds, group, mode) _disable_auto_decode_group(ds) diff --git a/xarray/tests/test_backends.py b/xarray/tests/test_backends.py index 513f5f0834e..0768a942a77 100644 --- a/xarray/tests/test_backends.py +++ b/xarray/tests/test_backends.py @@ -892,7 +892,7 @@ def test_open_group(self): # check equivalent ways to specify group for group in 'foo', '/foo', 'foo/', '/foo/': - with open_dataset(tmp_file, group=group) as actual: + with self.open(tmp_file, group=group) as actual: assert_equal(actual['x'], expected['x']) # check that missing group raises appropriate exception @@ -920,18 +920,18 @@ def test_open_subgroup(self): # check equivalent ways to specify group for group in 'foo/bar', '/foo/bar', 'foo/bar/', '/foo/bar/': - with open_dataset(tmp_file, group=group) as actual: + with self.open(tmp_file, group=group) as actual: assert_equal(actual['x'], expected['x']) def test_write_groups(self): data1 = create_test_data() data2 = data1 * 2 with create_tmp_file() as tmp_file: - data1.to_netcdf(tmp_file, group='data/1') - data2.to_netcdf(tmp_file, group='data/2', mode='a') - with open_dataset(tmp_file, group='data/1') as actual1: + self.save(data1, tmp_file, group='data/1') + self.save(data2, tmp_file, group='data/2', mode='a') + with self.open(tmp_file, group='data/1') as actual1: assert_identical(data1, actual1) - with open_dataset(tmp_file, group='data/2') as actual2: + with self.open(tmp_file, group='data/2') as actual2: assert_identical(data2, actual2) def test_roundtrip_string_with_fill_value_vlen(self): From 
04df50efefecaea729133c14082eb5e24491633e Mon Sep 17 00:00:00 2001
From: Keisuke Fujii
Date: Fri, 25 May 2018 19:38:48 +0900
Subject: [PATCH 32/61] weighted rolling mean -> weighted rolling sum (#2185)

An example of weighted rolling mean in doc is actually weighted rolling
*sum*. It is a little bit misleading
([SO](https://stackoverflow.com/questions/50520835/xarray-simple-weighted-rolling-mean-example-using-construct/50524093#50524093)),
so I propose to change `weighted rolling mean` -> `weighted rolling sum`
---
 doc/computation.rst | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/doc/computation.rst b/doc/computation.rst
index 0f22a2ed967..6793e667e06 100644
--- a/doc/computation.rst
+++ b/doc/computation.rst
@@ -185,7 +185,7 @@ windowed rolling, convolution, short-time FFT etc.
 Because the ``DataArray`` given by ``r.construct('window_dim')`` is a view
 of the original array, it is memory efficient.
 
-You can also use ``construct`` to compute a weighted rolling mean:
+You can also use ``construct`` to compute a weighted rolling sum:
 
 .. ipython:: python

From a28aab005b42eabe0b1651d2330ed2f3268bb9f8 Mon Sep 17 00:00:00 2001
From: Stephan Hoyer
Date: Fri, 25 May 2018 20:29:45 -0700
Subject: [PATCH 33/61] Fix DataArray.stack() with non-unique coordinates on pandas 0.23 (#2168)

---
 doc/whats-new.rst              |  4 ++++
 xarray/core/utils.py           | 14 ++++++++------
 xarray/tests/test_dataarray.py |  7 +++++++
 xarray/tests/test_utils.py     | 12 +++++++++++-
 4 files changed, 30 insertions(+), 7 deletions(-)

diff --git a/doc/whats-new.rst b/doc/whats-new.rst
index 4a01065bd70..055369f0352 100644
--- a/doc/whats-new.rst
+++ b/doc/whats-new.rst
@@ -59,6 +59,10 @@ Bug fixes
   dimension were improperly skipped.
   By `Stephan Hoyer `_
 
+- Fix :meth:`~DataArray.stack` with non-unique coordinates on pandas 0.23
+  (:issue:`2160`).
+  By `Stephan Hoyer `_
+
 - Selecting data indexed by a length-1 ``CFTimeIndex`` with a slice of strings
   now behaves as it does when using a length-1 ``DatetimeIndex`` (i.e. it no
   longer falsely returns an empty array when the slice includes the value in
diff --git a/xarray/core/utils.py b/xarray/core/utils.py
index 06bb3ede393..f6c5830cc9e 100644
--- a/xarray/core/utils.py
+++ b/xarray/core/utils.py
@@ -76,13 +76,12 @@ def safe_cast_to_index(array):
 
 def multiindex_from_product_levels(levels, names=None):
     """Creating a MultiIndex from a product without refactorizing levels.
 
-    Keeping levels the same is faster, and also gives back the original labels
-    when we unstack.
+    Keeping levels the same gives back the original labels when we unstack.
 
     Parameters
     ----------
-    levels : sequence of arrays
-        Unique labels for each level.
+    levels : sequence of pd.Index
+        Values for each MultiIndex level.
     names : optional sequence of objects
         Names for each level.
@@ -90,8 +89,11 @@ def multiindex_from_product_levels(levels, names=None): ------- pandas.MultiIndex """ - labels_mesh = np.meshgrid(*[np.arange(len(lev)) for lev in levels], - indexing='ij') + if any(not isinstance(lev, pd.Index) for lev in levels): + raise TypeError('levels must be a list of pd.Index objects') + + split_labels, levels = zip(*[lev.factorize() for lev in levels]) + labels_mesh = np.meshgrid(*split_labels, indexing='ij') labels = [x.ravel() for x in labels_mesh] return pd.MultiIndex(levels, labels, sortorder=0, names=names) diff --git a/xarray/tests/test_dataarray.py b/xarray/tests/test_dataarray.py index 35e270f0db7..a03d265c3e3 100644 --- a/xarray/tests/test_dataarray.py +++ b/xarray/tests/test_dataarray.py @@ -1673,6 +1673,13 @@ def test_unstack_pandas_consistency(self): actual = DataArray(s, dims='z').unstack('z') assert_identical(expected, actual) + def test_stack_nonunique_consistency(self): + orig = DataArray([[0, 1], [2, 3]], dims=['x', 'y'], + coords={'x': [0, 1], 'y': [0, 0]}) + actual = orig.stack(z=['x', 'y']) + expected = DataArray(orig.to_pandas().stack(), dims='z') + assert_identical(expected, actual) + def test_transpose(self): assert_equal(self.dv.variable.transpose(), self.dv.transpose().variable) diff --git a/xarray/tests/test_utils.py b/xarray/tests/test_utils.py index 0b3b0ee7dd6..1f73743d01d 100644 --- a/xarray/tests/test_utils.py +++ b/xarray/tests/test_utils.py @@ -72,7 +72,8 @@ def test_safe_cast_to_index_datetime_datetime(enable_cftimeindex): def test_multiindex_from_product_levels(): - result = utils.multiindex_from_product_levels([['b', 'a'], [1, 3, 2]]) + result = utils.multiindex_from_product_levels( + [pd.Index(['b', 'a']), pd.Index([1, 3, 2])]) np.testing.assert_array_equal( result.labels, [[0, 0, 0, 1, 1, 1], [0, 1, 2, 0, 1, 2]]) np.testing.assert_array_equal(result.levels[0], ['b', 'a']) @@ -82,6 +83,15 @@ def test_multiindex_from_product_levels(): np.testing.assert_array_equal(result.values, other.values) +def test_multiindex_from_product_levels_non_unique(): + result = utils.multiindex_from_product_levels( + [pd.Index(['b', 'a']), pd.Index([1, 1, 2])]) + np.testing.assert_array_equal( + result.labels, [[0, 0, 0, 1, 1, 1], [0, 0, 1, 0, 0, 1]]) + np.testing.assert_array_equal(result.levels[0], ['b', 'a']) + np.testing.assert_array_equal(result.levels[1], [1, 2]) + + class TestArrayEquiv(TestCase): def test_0d(self): # verify our work around for pd.isnull not working for 0-dimensional From a8c1ed2ae3cc15863d37d869a7e1658eb33e01f6 Mon Sep 17 00:00:00 2001 From: Johnnie Gray Date: Sun, 27 May 2018 21:45:37 +0100 Subject: [PATCH 34/61] add xyzpy to projects (#2189) --- doc/faq.rst | 1 + 1 file changed, 1 insertion(+) diff --git a/doc/faq.rst b/doc/faq.rst index 9d763f1c15f..170a1e17bdc 100644 --- a/doc/faq.rst +++ b/doc/faq.rst @@ -211,6 +211,7 @@ Extend xarray capabilities - `xrft `_: Fourier transforms for xarray data. - `xr-scipy `_: A lightweight scipy wrapper for xarray. - `X-regression `_: Multiple linear regression from Statsmodels library coupled with Xarray library. +- `xyzpy `_: Easily generate high dimensional data, including parallelization. 
Visualization ~~~~~~~~~~~~~ From 847050026d45e2817960a37564bd8e909ecbdb05 Mon Sep 17 00:00:00 2001 From: Maximilian Roos <5635139+maxim-lian@users.noreply.github.com> Date: Sun, 27 May 2018 16:48:30 -0400 Subject: [PATCH 35/61] Datasets more robust to non-string keys (#2174) * ds more robust to non-str keys * formatting * time.dayofyear needs cover in dataarray getitem * trial of indexer_dict * feedback from stephan * a few more methods * reindex added * rename to either_dict_or_kwargs * remove assert check * docstring * more docstring * `optional` goes last * last docstring * what's new * artefact * test either_dict_or_kwargs --- doc/whats-new.rst | 8 +++++ setup.cfg | 1 + xarray/core/coordinates.py | 8 +++-- xarray/core/dataarray.py | 42 +++++++++++++++-------- xarray/core/dataset.py | 63 +++++++++++++++++++++------------- xarray/core/utils.py | 2 +- xarray/core/variable.py | 8 +++-- xarray/tests/test_dataarray.py | 4 +-- xarray/tests/test_dataset.py | 8 ++++- xarray/tests/test_utils.py | 24 +++++++++++-- 10 files changed, 116 insertions(+), 52 deletions(-) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 055369f0352..68bf5318bf5 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -46,6 +46,14 @@ Enhancements to manage its version strings. (:issue:`1300`). By `Joe Hamman `_. +- :py:meth:`~DataArray.sel`, :py:meth:`~DataArray.isel` & :py:meth:`~DataArray.reindex`, + (and their :py:class:`Dataset` counterparts) now support supplying a ``dict`` + as a first argument, as an alternative to the existing approach + of supplying a set of `kwargs`. This allows for more robust behavior + of dimension names which conflict with other keyword names, or are + not strings. + By `Maximilian Roos `_. + Bug fixes ~~~~~~~~~ diff --git a/setup.cfg b/setup.cfg index 850551b3579..4dd1bffe043 100644 --- a/setup.cfg +++ b/setup.cfg @@ -8,6 +8,7 @@ testpaths=xarray/tests [flake8] max-line-length=79 ignore= + W503 exclude= doc/ diff --git a/xarray/core/coordinates.py b/xarray/core/coordinates.py index cb22c0b687b..efe8affb2a3 100644 --- a/xarray/core/coordinates.py +++ b/xarray/core/coordinates.py @@ -9,10 +9,9 @@ from .merge import ( expand_and_merge_variables, merge_coords, merge_coords_for_inplace_math) from .pycompat import OrderedDict -from .utils import Frozen, ReprObject +from .utils import Frozen, ReprObject, either_dict_or_kwargs from .variable import Variable - # Used as the key corresponding to a DataArray's variable when converting # arbitrary DataArray objects to datasets _THIS_ARRAY = ReprObject('') @@ -332,7 +331,8 @@ def assert_coordinate_consistent(obj, coords): .format(k, obj[k], coords[k])) -def remap_label_indexers(obj, method=None, tolerance=None, **indexers): +def remap_label_indexers(obj, indexers=None, method=None, tolerance=None, + **indexers_kwargs): """ Remap **indexers from obj.coords. If indexer is an instance of DataArray and it has coordinate, then this @@ -345,6 +345,8 @@ def remap_label_indexers(obj, method=None, tolerance=None, **indexers): new_indexes: mapping of new dimensional-coordinate. 
""" from .dataarray import DataArray + indexers = either_dict_or_kwargs( + indexers, indexers_kwargs, 'remap_label_indexers') v_indexers = {k: v.variable.data if isinstance(v, DataArray) else v for k, v in indexers.items()} diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index fc7091dad85..da9acb48a7a 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -18,7 +18,9 @@ from .formatting import format_item from .options import OPTIONS from .pycompat import OrderedDict, basestring, iteritems, range, zip -from .utils import decode_numpy_dict_values, ensure_us_time_resolution +from .utils import ( + either_dict_or_kwargs, decode_numpy_dict_values, + ensure_us_time_resolution) from .variable import ( IndexVariable, Variable, as_compatible_data, as_variable, assert_unique_multiindex_level_names) @@ -470,7 +472,7 @@ def __getitem__(self, key): return self._getitem_coord(key) else: # xarray-style array indexing - return self.isel(**self._item_key_to_dict(key)) + return self.isel(indexers=self._item_key_to_dict(key)) def __setitem__(self, key, value): if isinstance(key, basestring): @@ -498,7 +500,7 @@ def _attr_sources(self): @property def _item_sources(self): """List of places to look-up items for key-completion""" - return [self.coords, {d: self[d] for d in self.dims}, + return [self.coords, {d: self.coords[d] for d in self.dims}, LevelCoordinatesSource(self)] def __contains__(self, key): @@ -742,7 +744,7 @@ def chunk(self, chunks=None, name_prefix='xarray-', token=None, token=token, lock=lock) return self._from_temp_dataset(ds) - def isel(self, drop=False, **indexers): + def isel(self, indexers=None, drop=False, **indexers_kwargs): """Return a new DataArray whose dataset is given by integer indexing along the specified dimension(s). @@ -751,10 +753,12 @@ def isel(self, drop=False, **indexers): Dataset.isel DataArray.sel """ - ds = self._to_temp_dataset().isel(drop=drop, **indexers) + indexers = either_dict_or_kwargs(indexers, indexers_kwargs, 'isel') + ds = self._to_temp_dataset().isel(drop=drop, indexers=indexers) return self._from_temp_dataset(ds) - def sel(self, method=None, tolerance=None, drop=False, **indexers): + def sel(self, indexers=None, method=None, tolerance=None, drop=False, + **indexers_kwargs): """Return a new DataArray whose dataset is given by selecting index labels along the specified dimension(s). @@ -776,8 +780,9 @@ def sel(self, method=None, tolerance=None, drop=False, **indexers): DataArray.isel """ - ds = self._to_temp_dataset().sel(drop=drop, method=method, - tolerance=tolerance, **indexers) + indexers = either_dict_or_kwargs(indexers, indexers_kwargs, 'sel') + ds = self._to_temp_dataset().sel( + indexers=indexers, drop=drop, method=method, tolerance=tolerance) return self._from_temp_dataset(ds) def isel_points(self, dim='points', **indexers): @@ -851,12 +856,19 @@ def reindex_like(self, other, method=None, tolerance=None, copy=True): return self.reindex(method=method, tolerance=tolerance, copy=copy, **indexers) - def reindex(self, method=None, tolerance=None, copy=True, **indexers): + def reindex(self, indexers=None, method=None, tolerance=None, copy=True, + **indexers_kwargs): """Conform this object onto a new set of indexes, filling in missing values with NaN. Parameters ---------- + indexers : dict, optional + Dictionary with keys given by dimension names and values given by + arrays of coordinates tick labels. 
Any mis-matched coordinate + values will be filled in with NaN, and any mis-matched dimension + names will simply be ignored. + One of indexers or indexers_kwargs must be provided. copy : bool, optional If ``copy=True``, data in the return value is always copied. If ``copy=False`` and reindexing is unnecessary, or can be performed @@ -874,11 +886,9 @@ def reindex(self, method=None, tolerance=None, copy=True, **indexers): Maximum distance between original and new labels for inexact matches. The values of the index at the matching locations most satisfy the equation ``abs(index[indexer] - target) <= tolerance``. - **indexers : dict - Dictionary with keys given by dimension names and values given by - arrays of coordinates tick labels. Any mis-matched coordinate - values will be filled in with NaN, and any mis-matched dimension - names will simply be ignored. + **indexers_kwarg : {dim: indexer, ...}, optional + The keyword arguments form of ``indexers``. + One of indexers or indexers_kwargs must be provided. Returns ------- @@ -891,8 +901,10 @@ def reindex(self, method=None, tolerance=None, copy=True, **indexers): DataArray.reindex_like align """ + indexers = either_dict_or_kwargs( + indexers, indexers_kwargs, 'reindex') ds = self._to_temp_dataset().reindex( - method=method, tolerance=tolerance, copy=copy, **indexers) + indexers=indexers, method=method, tolerance=tolerance, copy=copy) return self._from_temp_dataset(ds) def rename(self, new_name_or_name_dict): diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index fff11dedb01..d6a5ac1c172 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -17,8 +17,8 @@ rolling, utils) from .. import conventions from .alignment import align -from .common import (DataWithCoords, ImplementsDatasetReduce, - _contains_datetime_like_objects) +from .common import ( + DataWithCoords, ImplementsDatasetReduce, _contains_datetime_like_objects) from .coordinates import ( DatasetCoordinates, Indexes, LevelCoordinatesSource, assert_coordinate_consistent, remap_label_indexers) @@ -30,7 +30,7 @@ from .pycompat import ( OrderedDict, basestring, dask_array_type, integer_types, iteritems, range) from .utils import ( - Frozen, SortedKeysDict, decode_numpy_dict_values, + Frozen, SortedKeysDict, either_dict_or_kwargs, decode_numpy_dict_values, ensure_us_time_resolution, hashable, maybe_wrap_array) from .variable import IndexVariable, Variable, as_variable, broadcast_variables @@ -1368,7 +1368,7 @@ def _get_indexers_coordinates(self, indexers): attached_coords[k] = v return attached_coords - def isel(self, drop=False, **indexers): + def isel(self, indexers=None, drop=False, **indexers_kwargs): """Returns a new dataset with each array indexed along the specified dimension(s). @@ -1378,15 +1378,19 @@ def isel(self, drop=False, **indexers): Parameters ---------- - drop : bool, optional - If ``drop=True``, drop coordinates variables indexed by integers - instead of making them scalar. - **indexers : {dim: indexer, ...} - Keyword arguments with names matching dimensions and values given + indexers : dict, optional + A dict with keys matching dimensions and values given by integers, slice objects or arrays. indexer can be a integer, slice, array-like or DataArray. If DataArrays are passed as indexers, xarray-style indexing will be carried out. See :ref:`indexing` for the details. + One of indexers or indexers_kwargs must be provided. + drop : bool, optional + If ``drop=True``, drop coordinates variables indexed by integers + instead of making them scalar. 
+ **indexers_kwarg : {dim: indexer, ...}, optional + The keyword arguments form of ``indexers``. + One of indexers or indexers_kwargs must be provided. Returns ------- @@ -1404,12 +1408,15 @@ def isel(self, drop=False, **indexers): Dataset.sel DataArray.isel """ + + indexers = either_dict_or_kwargs(indexers, indexers_kwargs, 'isel') + indexers_list = self._validate_indexers(indexers) variables = OrderedDict() for name, var in iteritems(self._variables): var_indexers = {k: v for k, v in indexers_list if k in var.dims} - new_var = var.isel(**var_indexers) + new_var = var.isel(indexers=var_indexers) if not (drop and name in var_indexers): variables[name] = new_var @@ -1425,7 +1432,8 @@ def isel(self, drop=False, **indexers): .union(coord_vars)) return self._replace_vars_and_dims(variables, coord_names=coord_names) - def sel(self, method=None, tolerance=None, drop=False, **indexers): + def sel(self, indexers=None, method=None, tolerance=None, drop=False, + **indexers_kwargs): """Returns a new dataset with each array indexed by tick labels along the specified dimension(s). @@ -1444,6 +1452,14 @@ def sel(self, method=None, tolerance=None, drop=False, **indexers): Parameters ---------- + indexers : dict, optional + A dict with keys matching dimensions and values given + by scalars, slices or arrays of tick labels. For dimensions with + multi-index, the indexer may also be a dict-like object with keys + matching index level names. + If DataArrays are passed as indexers, xarray-style indexing will be + carried out. See :ref:`indexing` for the details. + One of indexers or indexers_kwargs must be provided. method : {None, 'nearest', 'pad'/'ffill', 'backfill'/'bfill'}, optional Method to use for inexact matches (requires pandas>=0.16): @@ -1459,13 +1475,9 @@ def sel(self, method=None, tolerance=None, drop=False, **indexers): drop : bool, optional If ``drop=True``, drop coordinates variables in `indexers` instead of making them scalar. - **indexers : {dim: indexer, ...} - Keyword arguments with names matching dimensions and values given - by scalars, slices or arrays of tick labels. For dimensions with - multi-index, the indexer may also be a dict-like object with keys - matching index level names. - If DataArrays are passed as indexers, xarray-style indexing will be - carried out. See :ref:`indexing` for the details. + **indexers_kwarg : {dim: indexer, ...}, optional + The keyword arguments form of ``indexers``. + One of indexers or indexers_kwargs must be provided. Returns ------- @@ -1484,9 +1496,10 @@ def sel(self, method=None, tolerance=None, drop=False, **indexers): Dataset.isel DataArray.sel """ - pos_indexers, new_indexes = remap_label_indexers(self, method, - tolerance, **indexers) - result = self.isel(drop=drop, **pos_indexers) + indexers = either_dict_or_kwargs(indexers, indexers_kwargs, 'sel') + pos_indexers, new_indexes = remap_label_indexers( + self, indexers=indexers, method=method, tolerance=tolerance) + result = self.isel(indexers=pos_indexers, drop=drop) return result._replace_indexes(new_indexes) def isel_points(self, dim='points', **indexers): @@ -1734,7 +1747,7 @@ def reindex_like(self, other, method=None, tolerance=None, copy=True): **indexers) def reindex(self, indexers=None, method=None, tolerance=None, copy=True, - **kw_indexers): + **indexers_kwargs): """Conform this object onto a new set of indexes, filling in missing values with NaN. @@ -1745,6 +1758,7 @@ def reindex(self, indexers=None, method=None, tolerance=None, copy=True, arrays of coordinates tick labels. 
Any mis-matched coordinate values will be filled in with NaN, and any mis-matched dimension names will simply be ignored. + One of indexers or indexers_kwargs must be provided. method : {None, 'nearest', 'pad'/'ffill', 'backfill'/'bfill'}, optional Method to use for filling index values in ``indexers`` not found in this dataset: @@ -1763,8 +1777,9 @@ def reindex(self, indexers=None, method=None, tolerance=None, copy=True, ``copy=False`` and reindexing is unnecessary, or can be performed with only slice operations, then the output may share memory with the input. In either case, a new xarray object is always returned. - **kw_indexers : optional + **indexers_kwarg : {dim: indexer, ...}, optional Keyword arguments in the same form as ``indexers``. + One of indexers or indexers_kwargs must be provided. Returns ------- @@ -1777,7 +1792,7 @@ def reindex(self, indexers=None, method=None, tolerance=None, copy=True, align pandas.Index.get_indexer """ - indexers = utils.combine_pos_and_kw_args(indexers, kw_indexers, + indexers = utils.either_dict_or_kwargs(indexers, indexers_kwargs, 'reindex') bad_dims = [d for d in indexers if d not in self.dims] diff --git a/xarray/core/utils.py b/xarray/core/utils.py index f6c5830cc9e..c3bb747fac5 100644 --- a/xarray/core/utils.py +++ b/xarray/core/utils.py @@ -185,7 +185,7 @@ def is_full_slice(value): return isinstance(value, slice) and value == slice(None) -def combine_pos_and_kw_args(pos_kwargs, kw_kwargs, func_name): +def either_dict_or_kwargs(pos_kwargs, kw_kwargs, func_name): if pos_kwargs is not None: if not is_dict_like(pos_kwargs): raise ValueError('the first argument to .%s must be a dictionary' diff --git a/xarray/core/variable.py b/xarray/core/variable.py index 9dcb99459d4..52d470accfe 100644 --- a/xarray/core/variable.py +++ b/xarray/core/variable.py @@ -11,13 +11,13 @@ import xarray as xr # only for Dataset and DataArray from . import ( - arithmetic, common, dtypes, duck_array_ops, indexing, nputils, ops, utils,) + arithmetic, common, dtypes, duck_array_ops, indexing, nputils, ops, utils) from .indexing import ( BasicIndexer, OuterIndexer, PandasIndexAdapter, VectorizedIndexer, as_indexable) from .pycompat import ( OrderedDict, basestring, dask_array_type, integer_types, zip) -from .utils import OrderedSet +from .utils import OrderedSet, either_dict_or_kwargs try: import dask.array as da @@ -824,7 +824,7 @@ def chunk(self, chunks=None, name=None, lock=False): return type(self)(self.dims, data, self._attrs, self._encoding, fastpath=True) - def isel(self, **indexers): + def isel(self, indexers=None, drop=False, **indexers_kwargs): """Return a new array indexed along the specified dimension(s). Parameters @@ -841,6 +841,8 @@ def isel(self, **indexers): unless numpy fancy indexing was triggered by using an array indexer, in which case the data will be a copy. 
""" + indexers = either_dict_or_kwargs(indexers, indexers_kwargs, 'isel') + invalid = [k for k in indexers if k not in self.dims] if invalid: raise ValueError("dimensions %r do not exist" % invalid) diff --git a/xarray/tests/test_dataarray.py b/xarray/tests/test_dataarray.py index a03d265c3e3..17e02fce7ed 100644 --- a/xarray/tests/test_dataarray.py +++ b/xarray/tests/test_dataarray.py @@ -17,8 +17,8 @@ from xarray.core.pycompat import OrderedDict, iteritems from xarray.tests import ( ReturnItem, TestCase, assert_allclose, assert_array_equal, assert_equal, - assert_identical, raises_regex, requires_bottleneck, requires_dask, - requires_scipy, source_ndarray, unittest, requires_cftime) + assert_identical, raises_regex, requires_bottleneck, requires_cftime, + requires_dask, requires_scipy, source_ndarray, unittest) class TestDataArray(TestCase): diff --git a/xarray/tests/test_dataset.py b/xarray/tests/test_dataset.py index 76e41c43c6d..38e2dce1633 100644 --- a/xarray/tests/test_dataset.py +++ b/xarray/tests/test_dataset.py @@ -1435,7 +1435,7 @@ def test_sel_method(self): with raises_regex(TypeError, '``method``'): # this should not pass silently - data.sel(data) + data.sel(method=data) # cannot pass method if there is no associated coordinate with raises_regex(ValueError, 'cannot supply'): @@ -4181,6 +4181,12 @@ def test_dir_non_string(data_set): result = dir(data_set) assert not (5 in result) + # GH2172 + sample_data = np.random.uniform(size=[2, 2000, 10000]) + x = xr.Dataset({"sample_data": (sample_data.shape, sample_data)}) + x2 = x["sample_data"] + dir(x2) + def test_dir_unicode(data_set): data_set[u'unicode'] = 'uni' diff --git a/xarray/tests/test_utils.py b/xarray/tests/test_utils.py index 1f73743d01d..ed8045b78e4 100644 --- a/xarray/tests/test_utils.py +++ b/xarray/tests/test_utils.py @@ -1,17 +1,21 @@ from __future__ import absolute_import, division, print_function +from datetime import datetime + import numpy as np import pandas as pd import pytest -from datetime import datetime from xarray.coding.cftimeindex import CFTimeIndex from xarray.core import duck_array_ops, utils from xarray.core.options import set_options from xarray.core.pycompat import OrderedDict +from xarray.core.utils import either_dict_or_kwargs + +from . import ( + TestCase, assert_array_equal, has_cftime, has_cftime_or_netCDF4, + requires_dask) from .test_coding_times import _all_cftime_date_types -from . 
import (TestCase, requires_dask, assert_array_equal, has_cftime_or_netCDF4, has_cftime)


 class TestAlias(TestCase):
@@ -245,3 +249,17 @@ def test_hidden_key_dict():
         hkd[hidden_key]
     with pytest.raises(KeyError):
         del hkd[hidden_key]
+
+
+def test_either_dict_or_kwargs():
+
+    result = either_dict_or_kwargs(dict(a=1), None, 'foo')
+    expected = dict(a=1)
+    assert result == expected
+
+    result = either_dict_or_kwargs(None, dict(a=1), 'foo')
+    expected = dict(a=1)
+    assert result == expected
+
+    with pytest.raises(ValueError, match=r'foo'):
+        result = either_dict_or_kwargs(dict(a=1), dict(a=1), 'foo')

From fb7a43ea102c7706ad5d3bc8399264155cb273dd Mon Sep 17 00:00:00 2001
From: Maximilian Roos <5635139+maxim-lian@users.noreply.github.com>
Date: Mon, 28 May 2018 23:05:08 -0400
Subject: [PATCH 36/61] Rename takes kwargs (#2194)

* rename takes kwargs
* tests
* better check for type of rename
---
 doc/whats-new.rst              |  7 ++++++-
 xarray/core/dataarray.py       | 17 +++++++++++------
 xarray/core/dataset.py         |  8 ++++++--
 xarray/tests/test_dataarray.py |  3 +++
 xarray/tests/test_dataset.py   |  3 +++
 5 files changed, 29 insertions(+), 9 deletions(-)

diff --git a/doc/whats-new.rst b/doc/whats-new.rst
index 68bf5318bf5..f3af2b399d8 100644
--- a/doc/whats-new.rst
+++ b/doc/whats-new.rst
@@ -49,11 +49,16 @@ Enhancements
 - :py:meth:`~DataArray.sel`, :py:meth:`~DataArray.isel` & :py:meth:`~DataArray.reindex`,
   (and their :py:class:`Dataset` counterparts) now support supplying a ``dict``
   as a first argument, as an alternative to the existing approach
   of supplying `kwargs`. This allows for more robust behavior
   of dimension names which conflict with other keyword names, or are
   not strings.
   By `Maximilian Roos `_.
 
+- :py:meth:`~DataArray.rename` now supports supplying `kwargs`, as an
+  alternative to the existing approach of supplying a ``dict`` as the
+  first argument.
+  By `Maximilian Roos `_.
+
 Bug fixes
 ~~~~~~~~~
 
diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py
index da9acb48a7a..01f7c91f3a5 100644
--- a/xarray/core/dataarray.py
+++ b/xarray/core/dataarray.py
@@ -19,8 +19,7 @@ from .options import OPTIONS
 from .pycompat import OrderedDict, basestring, iteritems, range, zip
 from .utils import (
-    either_dict_or_kwargs, decode_numpy_dict_values,
-    ensure_us_time_resolution)
+    decode_numpy_dict_values, either_dict_or_kwargs, ensure_us_time_resolution)
 from .variable import (
     IndexVariable, Variable, as_compatible_data, as_variable,
     assert_unique_multiindex_level_names)
@@ -907,16 +906,20 @@ def reindex(self, indexers=None, method=None, tolerance=None, copy=True,
             indexers=indexers, method=method, tolerance=tolerance, copy=copy)
         return self._from_temp_dataset(ds)
 
-    def rename(self, new_name_or_name_dict):
+    def rename(self, new_name_or_name_dict=None, **names):
         """Returns a new DataArray with renamed coordinates or a new name.
 
         Parameters
         ----------
-        new_name_or_name_dict : str or dict-like
+        new_name_or_name_dict : str or dict-like, optional
             If the argument is dict-like, it is used as a mapping from old
             names to new names for coordinates. Otherwise, use the argument
             as the new name for this array.
+        **names, optional
+            The keyword arguments form of a mapping from old names to
+            new names for coordinates.
+            One of new_name_or_name_dict or names must be provided.
Returns @@ -929,8 +932,10 @@ def rename(self, new_name_or_name_dict): Dataset.rename DataArray.swap_dims """ - if utils.is_dict_like(new_name_or_name_dict): - dataset = self._to_temp_dataset().rename(new_name_or_name_dict) + if names or utils.is_dict_like(new_name_or_name_dict): + name_dict = either_dict_or_kwargs( + new_name_or_name_dict, names, 'rename') + dataset = self._to_temp_dataset().rename(name_dict) return self._from_temp_dataset(dataset) else: return self._replace(name=new_name_or_name_dict) diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index d6a5ac1c172..08f5f70d72b 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -1806,17 +1806,20 @@ def reindex(self, indexers=None, method=None, tolerance=None, copy=True, coord_names.update(indexers) return self._replace_vars_and_dims(variables, coord_names) - def rename(self, name_dict, inplace=False): + def rename(self, name_dict=None, inplace=False, **names): """Returns a new object with renamed variables and dimensions. Parameters ---------- - name_dict : dict-like + name_dict : dict-like, optional Dictionary whose keys are current variable or dimension names and whose values are the desired names. inplace : bool, optional If True, rename variables and dimensions in-place. Otherwise, return a new dataset object. + **names, optional + Keyword form of ``name_dict``. + One of name_dict or names must be provided. Returns ------- @@ -1828,6 +1831,7 @@ def rename(self, name_dict, inplace=False): Dataset.swap_dims DataArray.rename """ + name_dict = either_dict_or_kwargs(name_dict, names, 'rename') for k, v in name_dict.items(): if k not in self and k not in self.dims: raise ValueError("cannot rename %r because it is not a " diff --git a/xarray/tests/test_dataarray.py b/xarray/tests/test_dataarray.py index 17e02fce7ed..ef9620e4dc4 100644 --- a/xarray/tests/test_dataarray.py +++ b/xarray/tests/test_dataarray.py @@ -1270,6 +1270,9 @@ def test_rename(self): assert renamed.name == 'z' assert renamed.dims == ('z',) + renamed_kwargs = self.dv.x.rename(x='z').rename('z') + assert_identical(renamed, renamed_kwargs) + def test_swap_dims(self): array = DataArray(np.random.randn(3), {'y': ('x', list('abc'))}, 'x') expected = DataArray(array.values, {'y': list('abc')}, dims='y') diff --git a/xarray/tests/test_dataset.py b/xarray/tests/test_dataset.py index 38e2dce1633..e64f9859c9e 100644 --- a/xarray/tests/test_dataset.py +++ b/xarray/tests/test_dataset.py @@ -1916,6 +1916,9 @@ def test_rename(self): with pytest.raises(UnexpectedDataAccess): renamed['renamed_var1'].values + renamed_kwargs = data.rename(**newnames) + assert_identical(renamed, renamed_kwargs) + def test_rename_old_name(self): # regtest for GH1477 data = create_test_data() From 5ddfee6b194f8bdce8bb42cddfabe3e1c142ef16 Mon Sep 17 00:00:00 2001 From: Stephan Hoyer Date: Mon, 28 May 2018 20:15:07 -0700 Subject: [PATCH 37/61] Fix DataArray.groupby().reduce() mutating input array (#2169) * Fix DataArray.groupby().reduce() mutating input array Fixes GH2153 * Fix test failure --- doc/whats-new.rst | 5 +++++ xarray/core/dataarray.py | 2 +- xarray/tests/test_computation.py | 2 +- xarray/tests/test_groupby.py | 11 +++++++++++ 4 files changed, 18 insertions(+), 2 deletions(-) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index f3af2b399d8..1465b9e68a5 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -82,6 +82,11 @@ Bug fixes the index) (:issue:`2165`). By `Spencer Clark `_. 
+- Fix ``DataArray.groupby().reduce()`` mutating coordinates on the input array + when grouping over dimension coordinates with duplicated entries + (:issue:`2153`). + By `Stephan Hoyer `_ + - Fix Dataset.to_netcdf() cannot create group with engine="h5netcdf" (:issue:`2177`). By `Stephan Hoyer `_ diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index 01f7c91f3a5..fd2b49cc08a 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -252,7 +252,7 @@ def _replace(self, variable=None, coords=None, name=__default): def _replace_maybe_drop_dims(self, variable, name=__default): if variable.dims == self.dims: - coords = None + coords = self._coords.copy() else: allowed_dims = set(variable.dims) coords = OrderedDict((k, v) for k, v in self._coords.items() diff --git a/xarray/tests/test_computation.py b/xarray/tests/test_computation.py index c829453cc9d..bbcc02baf5a 100644 --- a/xarray/tests/test_computation.py +++ b/xarray/tests/test_computation.py @@ -553,7 +553,7 @@ def test_apply_dask(): array = da.ones((2,), chunks=2) variable = xr.Variable('x', array) coords = xr.DataArray(variable).coords.variables - data_array = xr.DataArray(variable, coords, fastpath=True) + data_array = xr.DataArray(variable, dims=['x'], coords=coords) dataset = xr.Dataset({'y': variable}) # encountered dask array, but did not set dask='allowed' diff --git a/xarray/tests/test_groupby.py b/xarray/tests/test_groupby.py index fd53e410583..6dd14f5d6ad 100644 --- a/xarray/tests/test_groupby.py +++ b/xarray/tests/test_groupby.py @@ -5,6 +5,7 @@ import pytest import xarray as xr +from . import assert_identical from xarray.core.groupby import _consolidate_slices @@ -73,4 +74,14 @@ def test_groupby_duplicate_coordinate_labels(): assert expected.equals(actual) +def test_groupby_input_mutation(): + # regression test for GH2153 + array = xr.DataArray([1, 2, 3], [('x', [2, 2, 1])]) + array_copy = array.copy() + expected = xr.DataArray([3, 3], [('x', [1, 2])]) + actual = array.groupby('x').sum() + assert_identical(expected, actual) + assert_identical(array, array_copy) # should not modify inputs + + # TODO: move other groupby tests from test_dataset and test_dataarray over here From 9c8005937556211a8bf28a946744da3768846c5a Mon Sep 17 00:00:00 2001 From: Stephan Hoyer Date: Mon, 28 May 2018 21:34:46 -0700 Subject: [PATCH 38/61] Test suite: explicitly ignore irrelevant warnings (#2162) * Test suite: explicitly ignore irrelevant warnings Also includes a fix for Dataset.update() * Cleaner implementation of Dataset.__setitem__ * more lint * Fix dask version check * Fix warning in Dataset.update() and clean-up logic * Fix whats new * More whats new * DeprecationWarning -> FutureWarning for old resample API --- doc/whats-new.rst | 9 +++++-- xarray/conventions.py | 2 +- xarray/core/common.py | 2 +- xarray/core/duck_array_ops.py | 20 +++++++++----- xarray/core/merge.py | 23 ++++++++++------- xarray/tests/__init__.py | 10 +++++-- xarray/tests/test_backends.py | 42 ++++++++++++++++++------------ xarray/tests/test_coding_times.py | 4 ++- xarray/tests/test_conventions.py | 4 ++- xarray/tests/test_dask.py | 3 +-- xarray/tests/test_dataarray.py | 30 +++++++++++++-------- xarray/tests/test_dataset.py | 43 ++++++++++++++++++++++++------- xarray/tests/test_variable.py | 7 +++-- 13 files changed, 132 insertions(+), 67 deletions(-) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 1465b9e68a5..aef80a2b30a 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -62,10 +62,15 @@ Enhancements Bug fixes ~~~~~~~~~ -- 
Fixed a bug where `to_netcdf(..., unlimited_dims='bar'` yielded NetCDF files - with spurious 0-length dimensions (i.e. `b`, `a`, and `r`) (:issue:`2134`). +- Fixed a bug where ``to_netcdf(..., unlimited_dims='bar')`` yielded NetCDF + files with spurious 0-length dimensions (i.e. ``b``, ``a``, and ``r``) + (:issue:`2134`). By `Joe Hamman `_. +- Removed spurious warnings with ``Dataset.update(Dataset)`` (:issue:`2161`) + and ``array.equals(array)`` when ``array`` contains ``NaT`` (:issue:`2162`). + By `Stephan Hoyer `_. + - Aggregations with :py:meth:`Dataset.reduce` (including ``mean``, ``sum``, etc) no longer drop unrelated coordinates (:issue:`1470`). Also fixed a bug where non-scalar data-variables that did not include the aggregation diff --git a/xarray/conventions.py b/xarray/conventions.py index ed90c34387b..6171c353a0d 100644 --- a/xarray/conventions.py +++ b/xarray/conventions.py @@ -89,7 +89,7 @@ def maybe_encode_nonstring_dtype(var, name=None): warnings.warn('saving variable %s with floating ' 'point data as an integer dtype without ' 'any _FillValue to use for NaNs' % name, - SerializationWarning, stacklevel=3) + SerializationWarning, stacklevel=10) data = duck_array_ops.around(data)[...] data = data.astype(dtype=dtype) var = Variable(dims, data, attrs, encoding) diff --git a/xarray/core/common.py b/xarray/core/common.py index f623091ebdb..d69c60eed56 100644 --- a/xarray/core/common.py +++ b/xarray/core/common.py @@ -696,7 +696,7 @@ def _resample_immediately(self, freq, dim, how, skipna, "how=\"{how}\", instead consider using " ".resample({dim}=\"{freq}\").{how}('{dim}') ".format( dim=dim, freq=freq, how=how), - DeprecationWarning, stacklevel=3) + FutureWarning, stacklevel=3) if isinstance(dim, basestring): dim = self[dim] diff --git a/xarray/core/duck_array_ops.py b/xarray/core/duck_array_ops.py index 69b0d0825be..065ac165a0d 100644 --- a/xarray/core/duck_array_ops.py +++ b/xarray/core/duck_array_ops.py @@ -145,10 +145,13 @@ def array_equiv(arr1, arr2): if arr1.shape != arr2.shape: return False - flag_array = (arr1 == arr2) - flag_array |= (isnull(arr1) & isnull(arr2)) + with warnings.catch_warnings(): + warnings.filterwarnings('ignore', "In the future, 'NAT == x'") - return bool(flag_array.all()) + flag_array = (arr1 == arr2) + flag_array |= (isnull(arr1) & isnull(arr2)) + + return bool(flag_array.all()) def array_notnull_equiv(arr1, arr2): @@ -159,11 +162,14 @@ def array_notnull_equiv(arr1, arr2): if arr1.shape != arr2.shape: return False - flag_array = (arr1 == arr2) - flag_array |= isnull(arr1) - flag_array |= isnull(arr2) + with warnings.catch_warnings(): + warnings.filterwarnings('ignore', "In the future, 'NAT == x'") + + flag_array = (arr1 == arr2) + flag_array |= isnull(arr1) + flag_array |= isnull(arr2) - return bool(flag_array.all()) + return bool(flag_array.all()) def count(data, axis=None): diff --git a/xarray/core/merge.py b/xarray/core/merge.py index d3c9871abef..40a58d6e84e 100644 --- a/xarray/core/merge.py +++ b/xarray/core/merge.py @@ -547,21 +547,24 @@ def dataset_merge_method(dataset, other, overwrite_vars, compat, join): def dataset_update_method(dataset, other): - """Guts of the Dataset.update method + """Guts of the Dataset.update method. - This drops a duplicated coordinates from `other` (GH:2068) + This drops a duplicated coordinates from `other` if `other` is not an + `xarray.Dataset`, e.g., if it's a dict with DataArray values (GH2068, + GH2180). 
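As a rough sketch of the two cases this docstring distinguishes (values are arbitrary; the same cases are exercised by ``test_update_overwrite_coords`` later in this patch):

>>> ds = xr.Dataset({'a': ('x', [1, 2])}, coords={'b': 3})
>>> ds.update(xr.Dataset(coords={'b': 4}))              # a real Dataset: 'b' is overwritten with 4
>>> ds = xr.Dataset({'a': ('x', [1, 2])}, coords={'b': 3})
>>> ds.update({'c': xr.DataArray(5, coords={'b': 4})})  # dict of DataArrays: the conflicting 'b' is dropped, so 'b' stays 3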
""" from .dataset import Dataset from .dataarray import DataArray - other = other.copy() - for k, obj in other.items(): - if isinstance(obj, (Dataset, DataArray)): - # drop duplicated coordinates - coord_names = [c for c in obj.coords - if c not in obj.dims and c in dataset.coords] - if coord_names: - other[k] = obj.drop(coord_names) + if not isinstance(other, Dataset): + other = OrderedDict(other) + for key, value in other.items(): + if isinstance(value, DataArray): + # drop conflicting coordinates + coord_names = [c for c in value.coords + if c not in value.dims and c in dataset.coords] + if coord_names: + other[key] = value.drop(coord_names) return merge_core([dataset, other], priority_arg=1, indexes=dataset.indexes) diff --git a/xarray/tests/__init__.py b/xarray/tests/__init__.py index 7584ed79a06..3acd26235ce 100644 --- a/xarray/tests/__init__.py +++ b/xarray/tests/__init__.py @@ -87,7 +87,10 @@ def _importorskip(modname, minversion=None): has_pathlib, requires_pathlib = _importorskip('pathlib2') if has_dask: import dask - dask.set_options(get=dask.get) + if LooseVersion(dask.__version__) < '0.18': + dask.set_options(get=dask.get) + else: + dask.config.set(scheduler='sync') try: import_seaborn() has_seaborn = True @@ -191,7 +194,10 @@ def source_ndarray(array): """Given an ndarray, return the base object which holds its memory, or the object itself. """ - base = getattr(array, 'base', np.asarray(array).base) + with warnings.catch_warnings(): + warnings.filterwarnings('ignore', 'DatetimeIndex.base') + warnings.filterwarnings('ignore', 'TimedeltaIndex.base') + base = getattr(array, 'base', np.asarray(array).base) if base is None: base = array return base diff --git a/xarray/tests/test_backends.py b/xarray/tests/test_backends.py index 0768a942a77..b80cb18e2be 100644 --- a/xarray/tests/test_backends.py +++ b/xarray/tests/test_backends.py @@ -363,21 +363,26 @@ def test_roundtrip_cftime_datetime_data_enable_cftimeindex(self): expected_decoded_t0 = np.array([date_type(1, 1, 1)]) expected_calendar = times[0].calendar - with xr.set_options(enable_cftimeindex=True): - with self.roundtrip(expected, save_kwargs=kwds) as actual: - abs_diff = abs(actual.t.values - expected_decoded_t) - assert (abs_diff <= np.timedelta64(1, 's')).all() - assert (actual.t.encoding['units'] == - 'days since 0001-01-01 00:00:00.000000') - assert (actual.t.encoding['calendar'] == - expected_calendar) - - abs_diff = abs(actual.t0.values - expected_decoded_t0) - assert (abs_diff <= np.timedelta64(1, 's')).all() - assert (actual.t0.encoding['units'] == - 'days since 0001-01-01') - assert (actual.t.encoding['calendar'] == - expected_calendar) + with warnings.catch_warnings(): + if expected_calendar in {'proleptic_gregorian', 'gregorian'}: + warnings.filterwarnings( + 'ignore', 'Unable to decode time axis') + + with xr.set_options(enable_cftimeindex=True): + with self.roundtrip(expected, save_kwargs=kwds) as actual: + abs_diff = abs(actual.t.values - expected_decoded_t) + assert (abs_diff <= np.timedelta64(1, 's')).all() + assert (actual.t.encoding['units'] == + 'days since 0001-01-01 00:00:00.000000') + assert (actual.t.encoding['calendar'] == + expected_calendar) + + abs_diff = abs(actual.t0.values - expected_decoded_t0) + assert (abs_diff <= np.timedelta64(1, 's')).all() + assert (actual.t0.encoding['units'] == + 'days since 0001-01-01') + assert (actual.t.encoding['calendar'] == + expected_calendar) def test_roundtrip_timedelta_data(self): time_deltas = pd.to_timedelta(['1h', '2h', 'NaT']) @@ -767,8 +772,11 @@ def 
test_default_fill_value(self): # Test default encoding for int: ds = Dataset({'x': ('y', np.arange(10.0))}) kwargs = dict(encoding={'x': {'dtype': 'int16'}}) - with self.roundtrip(ds, save_kwargs=kwargs) as actual: - self.assertTrue('_FillValue' not in actual.x.encoding) + with warnings.catch_warnings(): + warnings.filterwarnings( + 'ignore', '.*floating point data as an integer') + with self.roundtrip(ds, save_kwargs=kwargs) as actual: + self.assertTrue('_FillValue' not in actual.x.encoding) self.assertEqual(ds.x.encoding, {}) # Test default encoding for implicit int: diff --git a/xarray/tests/test_coding_times.py b/xarray/tests/test_coding_times.py index 6329e91ac78..4d6ca731bb2 100644 --- a/xarray/tests/test_coding_times.py +++ b/xarray/tests/test_coding_times.py @@ -130,7 +130,9 @@ def test_decode_cf_datetime_overflow(): expected = (datetime(1677, 12, 31), datetime(2262, 4, 12)) for i, day in enumerate(days): - result = coding.times.decode_cf_datetime(day, units) + with warnings.catch_warnings(): + warnings.filterwarnings('ignore', 'Unable to decode time axis') + result = coding.times.decode_cf_datetime(day, units) assert result == expected[i] diff --git a/xarray/tests/test_conventions.py b/xarray/tests/test_conventions.py index 62ff8d7ee1a..acc1c978579 100644 --- a/xarray/tests/test_conventions.py +++ b/xarray/tests/test_conventions.py @@ -219,7 +219,9 @@ def test_decode_cf_datetime_transition_to_invalid(self): ds = Dataset(coords={'time': [0, 266 * 365]}) units = 'days since 2000-01-01 00:00:00' ds.time.attrs = dict(units=units) - ds_decoded = conventions.decode_cf(ds) + with warnings.catch_warnings(): + warnings.filterwarnings('ignore', 'unable to decode time') + ds_decoded = conventions.decode_cf(ds) expected = [datetime(2000, 1, 1, 0, 0), datetime(2265, 10, 28, 0, 0)] diff --git a/xarray/tests/test_dask.py b/xarray/tests/test_dask.py index 1e4f313897b..ee5b3514348 100644 --- a/xarray/tests/test_dask.py +++ b/xarray/tests/test_dask.py @@ -459,8 +459,7 @@ def counting_get(*args, **kwargs): count[0] += 1 return dask.get(*args, **kwargs) - with dask.set_options(get=counting_get): - ds.load() + ds.load(get=counting_get) assert count[0] == 1 def test_stack(self): diff --git a/xarray/tests/test_dataarray.py b/xarray/tests/test_dataarray.py index ef9620e4dc4..d339e6402b6 100644 --- a/xarray/tests/test_dataarray.py +++ b/xarray/tests/test_dataarray.py @@ -4,6 +4,7 @@ from copy import deepcopy from distutils.version import LooseVersion from textwrap import dedent +import warnings import numpy as np import pandas as pd @@ -321,11 +322,14 @@ def test_constructor_from_self_described(self): actual = DataArray(series) assert_equal(expected[0].reset_coords('x', drop=True), actual) - panel = pd.Panel({0: frame}) - actual = DataArray(panel) - expected = DataArray([data], expected.coords, ['dim_0', 'x', 'y']) - expected['dim_0'] = [0] - assert_identical(expected, actual) + if hasattr(pd, 'Panel'): + with warnings.catch_warnings(): + warnings.filterwarnings('ignore', r'\W*Panel is deprecated') + panel = pd.Panel({0: frame}) + actual = DataArray(panel) + expected = DataArray([data], expected.coords, ['dim_0', 'x', 'y']) + expected['dim_0'] = [0] + assert_identical(expected, actual) expected = DataArray(data, coords={'x': ['a', 'b'], 'y': [-1, -2], @@ -2320,7 +2324,7 @@ def test_resample_old_vs_new_api(self): array = DataArray(np.ones(10), [('time', times)]) # Simple mean - with pytest.warns(DeprecationWarning): + with pytest.warns(FutureWarning): old_mean = array.resample('1D', 'time', how='mean') 
new_mean = array.resample(time='1D').mean() assert_identical(old_mean, new_mean) @@ -2329,7 +2333,7 @@ def test_resample_old_vs_new_api(self): attr_array = array.copy() attr_array.attrs['meta'] = 'data' - with pytest.warns(DeprecationWarning): + with pytest.warns(FutureWarning): old_mean = attr_array.resample('1D', dim='time', how='mean', keep_attrs=True) new_mean = attr_array.resample(time='1D').mean(keep_attrs=True) @@ -2340,7 +2344,7 @@ def test_resample_old_vs_new_api(self): nan_array = array.copy() nan_array[1] = np.nan - with pytest.warns(DeprecationWarning): + with pytest.warns(FutureWarning): old_mean = nan_array.resample('1D', 'time', how='mean', skipna=False) new_mean = nan_array.resample(time='1D').mean(skipna=False) @@ -2354,12 +2358,12 @@ def test_resample_old_vs_new_api(self): # Discard attributes on the call using the new api to match # convention from old api new_api = getattr(resampler, method)(keep_attrs=False) - with pytest.warns(DeprecationWarning): + with pytest.warns(FutureWarning): old_api = array.resample('1D', dim='time', how=method) assert_identical(new_api, old_api) for method in [np.mean, np.sum, np.max, np.min]: new_api = resampler.reduce(method) - with pytest.warns(DeprecationWarning): + with pytest.warns(FutureWarning): old_api = array.resample('1D', dim='time', how=method) assert_identical(new_api, old_api) @@ -2713,9 +2717,13 @@ def test_to_pandas(self): # roundtrips for shape in [(3,), (3, 4), (3, 4, 5)]: + if len(shape) > 2 and not hasattr(pd, 'Panel'): + continue dims = list('abc')[:len(shape)] da = DataArray(np.random.randn(*shape), dims=dims) - roundtripped = DataArray(da.to_pandas()).drop(dims) + with warnings.catch_warnings(): + warnings.filterwarnings('ignore', r'\W*Panel is deprecated') + roundtripped = DataArray(da.to_pandas()).drop(dims) assert_identical(da, roundtripped) with raises_regex(ValueError, 'cannot convert'): diff --git a/xarray/tests/test_dataset.py b/xarray/tests/test_dataset.py index e64f9859c9e..2dad40ae8f6 100644 --- a/xarray/tests/test_dataset.py +++ b/xarray/tests/test_dataset.py @@ -5,6 +5,7 @@ from distutils.version import LooseVersion from io import StringIO from textwrap import dedent +import warnings import numpy as np import pandas as pd @@ -338,14 +339,20 @@ def test_constructor_pandas_single(self): das = [ DataArray(np.random.rand(4), dims=['a']), # series DataArray(np.random.rand(4, 3), dims=['a', 'b']), # df - DataArray(np.random.rand(4, 3, 2), dims=['a', 'b', 'c']), # panel ] - for a in das: - pandas_obj = a.to_pandas() - ds_based_on_pandas = Dataset(pandas_obj) - for dim in ds_based_on_pandas.data_vars: - assert_array_equal(ds_based_on_pandas[dim], pandas_obj[dim]) + if hasattr(pd, 'Panel'): + das.append( + DataArray(np.random.rand(4, 3, 2), dims=['a', 'b', 'c'])) + + with warnings.catch_warnings(): + warnings.filterwarnings('ignore', r'\W*Panel is deprecated') + for a in das: + pandas_obj = a.to_pandas() + ds_based_on_pandas = Dataset(pandas_obj) + for dim in ds_based_on_pandas.data_vars: + assert_array_equal( + ds_based_on_pandas[dim], pandas_obj[dim]) def test_constructor_compat(self): data = OrderedDict([('x', DataArray(0, coords={'y': 1})), @@ -2139,6 +2146,22 @@ def test_update(self): actual.update(other) assert_identical(expected, actual) + def test_update_overwrite_coords(self): + data = Dataset({'a': ('x', [1, 2])}, {'b': 3}) + data.update(Dataset(coords={'b': 4})) + expected = Dataset({'a': ('x', [1, 2])}, {'b': 4}) + assert_identical(data, expected) + + data = Dataset({'a': ('x', [1, 2])}, {'b': 3}) + 
data.update(Dataset({'c': 5}, coords={'b': 4})) + expected = Dataset({'a': ('x', [1, 2]), 'c': 5}, {'b': 4}) + assert_identical(data, expected) + + data = Dataset({'a': ('x', [1, 2])}, {'b': 3}) + data.update({'c': DataArray(5, coords={'b': 4})}) + expected = Dataset({'a': ('x', [1, 2]), 'c': 5}, {'b': 3}) + assert_identical(data, expected) + def test_update_auto_align(self): ds = Dataset({'x': ('t', [3, 4])}, {'t': [0, 1]}) @@ -2343,14 +2366,14 @@ def test_setitem_with_coords(self): actual = ds.copy() actual['var3'] = other assert_identical(expected, actual) - assert 'numbers' in other # should not change other + assert 'numbers' in other.coords # should not change other # with alignment other = ds['var3'].isel(dim3=slice(1, -1)) other['numbers'] = ('dim3', np.arange(8)) actual = ds.copy() actual['var3'] = other - assert 'numbers' in other # should not change other + assert 'numbers' in other.coords # should not change other expected = ds.copy() expected['var3'] = ds['var3'].isel(dim3=slice(1, -1)) assert_identical(expected, actual) @@ -2362,7 +2385,7 @@ def test_setitem_with_coords(self): actual = ds.copy() actual['var3'] = other assert 'position' in actual - assert 'position' in other + assert 'position' in other.coords # assigning a coordinate-only dataarray actual = ds.copy() @@ -2774,7 +2797,7 @@ def test_resample_old_vs_new_api(self): # Discard attributes on the call using the new api to match # convention from old api new_api = getattr(resampler, method)(keep_attrs=False) - with pytest.warns(DeprecationWarning): + with pytest.warns(FutureWarning): old_api = ds.resample('1D', dim='time', how=method) assert_identical(new_api, old_api) diff --git a/xarray/tests/test_variable.py b/xarray/tests/test_variable.py index 722d1af14f7..c486a394ae6 100644 --- a/xarray/tests/test_variable.py +++ b/xarray/tests/test_variable.py @@ -6,6 +6,7 @@ from datetime import datetime, timedelta from distutils.version import LooseVersion from textwrap import dedent +import warnings import numpy as np import pandas as pd @@ -138,8 +139,10 @@ def _assertIndexedLikeNDArray(self, variable, expected_value0, assert variable.equals(variable.copy()) assert variable.identical(variable.copy()) # check value is equal for both ndarray and Variable - assert variable.values[0] == expected_value0 - assert variable[0].values == expected_value0 + with warnings.catch_warnings(): + warnings.filterwarnings('ignore', "In the future, 'NAT == x'") + assert variable.values[0] == expected_value0 + assert variable[0].values == expected_value0 # check type or dtype is consistent for both ndarray and Variable if expected_dtype is None: # check output type instead of array dtype From 7036eb5b629f2112da9aa13538aecb07f0f83f5a Mon Sep 17 00:00:00 2001 From: Maximilian Roos <5635139+maxim-lian@users.noreply.github.com> Date: Wed, 30 May 2018 20:22:58 -0400 Subject: [PATCH 39/61] Align DataArrays based on coords in Dataset constructor (#1826) * Align da based on explicit coords * add tags to gitignore * random formatting spot * initial tests * apply the test to the right degree of freedom - the coords of the variable added in * couple more for gitignore * @stickler-ci doesn't like `range` * more tests * more gitignores * whats new * raise * message= * Add all testmon files to gitignore * cast single dim tuples to indexes * test on different dataset coords types * updated whatsnew * version from Stephan's feedback; works but not clean * I think much cleaner version * formatting --- .gitignore | 1 + doc/whats-new.rst | 5 +++++ 
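A hedged illustration of the constructor behaviour this patch describes (the labels are invented, and the exact result depends on the alignment rules, so no output is shown):

>>> da = xr.DataArray([1, 2, 3], coords=[('x', [0, 1, 2])])
>>> ds = xr.Dataset({'v': da}, coords={'x': [1, 2]})
>>> # 'v' is aligned to the explicitly supplied x labels instead of raising
>>> # a conflicting-coordinate error, as the changelog entry below notes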
xarray/core/merge.py | 12 ++++++++++- xarray/tests/test_dataset.py | 41 +++++++++++++++++++++++++++++++++++- 4 files changed, 57 insertions(+), 2 deletions(-) diff --git a/.gitignore b/.gitignore index 92e488ed616..2a016bb9228 100644 --- a/.gitignore +++ b/.gitignore @@ -38,6 +38,7 @@ nosetests.xml .ropeproject/ .tags* .testmon* +.tmontmp/ .pytest_cache dask-worker-space/ diff --git a/doc/whats-new.rst b/doc/whats-new.rst index aef80a2b30a..37b022f45c8 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -46,6 +46,11 @@ Enhancements to manage its version strings. (:issue:`1300`). By `Joe Hamman `_. +- `:py:class:`Dataset`s align `:py:class:`DataArray`s to coords that are explicitly + passed into the constructor, where previously an error would be raised. + (:issue:`674`) + By `Maximilian Roos Date: Thu, 31 May 2018 08:40:03 -0700 Subject: [PATCH 40/61] Validate output dimension sizes with apply_ufunc (#2155) * Validate output dimension sizes with apply_ufunc Fixes GH1931 Uses of apply_ufunc that change dimension size now raise an explicit error, e.g., >>> xr.apply_ufunc(lambda x: x[:5], xr.Variable('x', np.arange(10))) ValueError: size of dimension 'x' on inputs was unexpectedly changed by applied function from 10 to 5. Only dimensions specified in ``exclude_dims`` with xarray.apply_ufunc are allowed to change size. * lint * More output validation for apply_ufunc --- doc/whats-new.rst | 4 ++ xarray/core/computation.py | 90 ++++++++++++++++++++++++-------- xarray/tests/test_computation.py | 88 +++++++++++++++++++++++++++++++ 3 files changed, 160 insertions(+), 22 deletions(-) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 37b022f45c8..f3db96f99bf 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -67,6 +67,10 @@ Enhancements Bug fixes ~~~~~~~~~ +- :py:func:`apply_ufunc` now directly validates output variables + (:issue:`1931`). + By `Stephan Hoyer `_. + - Fixed a bug where ``to_netcdf(..., unlimited_dims='bar')`` yielded NetCDF files with spurious 0-length dimensions (i.e. ``b``, ``a``, and ``r``) (:issue:`2134`). diff --git a/xarray/core/computation.py b/xarray/core/computation.py index f6e22dfe6c1..ebbce114ec3 100644 --- a/xarray/core/computation.py +++ b/xarray/core/computation.py @@ -513,7 +513,7 @@ def broadcast_compat_data(variable, broadcast_dims, core_dims): def apply_variable_ufunc(func, *args, **kwargs): """apply_variable_ufunc(func, *args, signature, exclude_dims=frozenset()) """ - from .variable import Variable + from .variable import Variable, as_compatible_data signature = kwargs.pop('signature') exclude_dims = kwargs.pop('exclude_dims', _DEFAULT_FROZEN_SET) @@ -559,20 +559,42 @@ def func(*arrays): 'apply_ufunc: {}'.format(dask)) result_data = func(*input_data) - if signature.num_outputs > 1: - output = [] - for dims, data in zip(output_dims, result_data): - var = Variable(dims, data) - if keep_attrs and isinstance(args[0], Variable): - var.attrs.update(args[0].attrs) - output.append(var) - return tuple(output) - else: - dims, = output_dims - var = Variable(dims, result_data) + if signature.num_outputs == 1: + result_data = (result_data,) + elif (not isinstance(result_data, tuple) or + len(result_data) != signature.num_outputs): + raise ValueError('applied function does not have the number of ' + 'outputs specified in the ufunc signature. 
' + 'Result is not a tuple of {} elements: {!r}' + .format(signature.num_outputs, result_data)) + + output = [] + for dims, data in zip(output_dims, result_data): + data = as_compatible_data(data) + if data.ndim != len(dims): + raise ValueError( + 'applied function returned data with unexpected ' + 'number of dimensions: {} vs {}, for dimensions {}' + .format(data.ndim, len(dims), dims)) + + var = Variable(dims, data, fastpath=True) + for dim, new_size in var.sizes.items(): + if dim in dim_sizes and new_size != dim_sizes[dim]: + raise ValueError( + 'size of dimension {!r} on inputs was unexpectedly ' + 'changed by applied function from {} to {}. Only ' + 'dimensions specified in ``exclude_dims`` with ' + 'xarray.apply_ufunc are allowed to change size.' + .format(dim, dim_sizes[dim], new_size)) + if keep_attrs and isinstance(args[0], Variable): var.attrs.update(args[0].attrs) - return var + output.append(var) + + if signature.num_outputs == 1: + return output[0] + else: + return tuple(output) def _apply_with_dask_atop(func, args, input_dims, output_dims, signature, @@ -719,7 +741,8 @@ def apply_ufunc(func, *args, **kwargs): Core dimensions on the inputs to exclude from alignment and broadcasting entirely. Any input coordinates along these dimensions will be dropped. Each excluded dimension must also appear in - ``input_core_dims`` for at least one argument. + ``input_core_dims`` for at least one argument. Only dimensions listed + here are allowed to change size between input and output objects. vectorize : bool, optional If True, then assume ``func`` only takes arrays defined over core dimensions as input and vectorize it automatically with @@ -777,15 +800,38 @@ def apply_ufunc(func, *args, **kwargs): Examples -------- - For illustrative purposes only, here are examples of how you could use - ``apply_ufunc`` to write functions to (very nearly) replicate existing - xarray functionality: - Calculate the vector magnitude of two arguments:: + Calculate the vector magnitude of two arguments: + + >>> def magnitude(a, b): + ... func = lambda x, y: np.sqrt(x ** 2 + y ** 2) + ... 
return xr.apply_ufunc(func, a, b) + + You can now apply ``magnitude()`` to ``xr.DataArray`` and ``xr.Dataset`` + objects, with automatically preserved dimensions and coordinates, e.g., + + >>> array = xr.DataArray([1, 2, 3], coords=[('x', [0.1, 0.2, 0.3])]) + >>> magnitude(array, -array) + + array([1.414214, 2.828427, 4.242641]) + Coordinates: + * x (x) float64 0.1 0.2 0.3 + + Plain scalars, numpy arrays and a mix of these with xarray objects is also + supported: + + >>> magnitude(4, 5) + 5.0 + >>> magnitude(3, np.array([0, 4])) + array([3., 5.]) + >>> magnitude(array, 0) + + array([1., 2., 3.]) + Coordinates: + * x (x) float64 0.1 0.2 0.3 - def magnitude(a, b): - func = lambda x, y: np.sqrt(x ** 2 + y ** 2) - return xr.apply_func(func, a, b) + Other examples of how you could use ``apply_ufunc`` to write functions to + (very nearly) replicate existing xarray functionality: Compute the mean (``.mean``) over one dimension:: @@ -795,7 +841,7 @@ def mean(obj, dim): input_core_dims=[[dim]], kwargs={'axis': -1}) - Inner product over a specific dimension:: + Inner product over a specific dimension (like ``xr.dot``):: def _inner(x, y): result = np.matmul(x[..., np.newaxis, :], y[..., :, np.newaxis]) diff --git a/xarray/tests/test_computation.py b/xarray/tests/test_computation.py index bbcc02baf5a..37f97a81f82 100644 --- a/xarray/tests/test_computation.py +++ b/xarray/tests/test_computation.py @@ -752,6 +752,94 @@ def test_vectorize_dask(): assert_identical(expected, actual) +def test_output_wrong_number(): + variable = xr.Variable('x', np.arange(10)) + + def identity(x): + return x + + def tuple3x(x): + return (x, x, x) + + with raises_regex(ValueError, 'number of outputs'): + apply_ufunc(identity, variable, output_core_dims=[(), ()]) + + with raises_regex(ValueError, 'number of outputs'): + apply_ufunc(tuple3x, variable, output_core_dims=[(), ()]) + + +def test_output_wrong_dims(): + variable = xr.Variable('x', np.arange(10)) + + def add_dim(x): + return x[..., np.newaxis] + + def remove_dim(x): + return x[..., 0] + + with raises_regex(ValueError, 'unexpected number of dimensions'): + apply_ufunc(add_dim, variable, output_core_dims=[('y', 'z')]) + + with raises_regex(ValueError, 'unexpected number of dimensions'): + apply_ufunc(add_dim, variable) + + with raises_regex(ValueError, 'unexpected number of dimensions'): + apply_ufunc(remove_dim, variable) + + +def test_output_wrong_dim_size(): + array = np.arange(10) + variable = xr.Variable('x', array) + data_array = xr.DataArray(variable, [('x', -array)]) + dataset = xr.Dataset({'y': variable}, {'x': -array}) + + def truncate(array): + return array[:5] + + def apply_truncate_broadcast_invalid(obj): + return apply_ufunc(truncate, obj) + + with raises_regex(ValueError, 'size of dimension'): + apply_truncate_broadcast_invalid(variable) + with raises_regex(ValueError, 'size of dimension'): + apply_truncate_broadcast_invalid(data_array) + with raises_regex(ValueError, 'size of dimension'): + apply_truncate_broadcast_invalid(dataset) + + def apply_truncate_x_x_invalid(obj): + return apply_ufunc(truncate, obj, input_core_dims=[['x']], + output_core_dims=[['x']]) + + with raises_regex(ValueError, 'size of dimension'): + apply_truncate_x_x_invalid(variable) + with raises_regex(ValueError, 'size of dimension'): + apply_truncate_x_x_invalid(data_array) + with raises_regex(ValueError, 'size of dimension'): + apply_truncate_x_x_invalid(dataset) + + def apply_truncate_x_z(obj): + return apply_ufunc(truncate, obj, input_core_dims=[['x']], + 
output_core_dims=[['z']]) + + assert_identical(xr.Variable('z', array[:5]), + apply_truncate_x_z(variable)) + assert_identical(xr.DataArray(array[:5], dims=['z']), + apply_truncate_x_z(data_array)) + assert_identical(xr.Dataset({'y': ('z', array[:5])}), + apply_truncate_x_z(dataset)) + + def apply_truncate_x_x_valid(obj): + return apply_ufunc(truncate, obj, input_core_dims=[['x']], + output_core_dims=[['x']], exclude_dims={'x'}) + + assert_identical(xr.Variable('x', array[:5]), + apply_truncate_x_x_valid(variable)) + assert_identical(xr.DataArray(array[:5], dims=['x']), + apply_truncate_x_x_valid(data_array)) + assert_identical(xr.Dataset({'y': ('x', array[:5])}), + apply_truncate_x_x_valid(dataset)) + + @pytest.mark.parametrize('use_dask', [True, False]) def test_dot(use_dask): if use_dask: From 9d60897a6544d3a2d4b9b3b64008b2bc316d8b98 Mon Sep 17 00:00:00 2001 From: Keisuke Fujii Date: Fri, 1 Jun 2018 10:01:33 +0900 Subject: [PATCH 41/61] Support dot with older dask (#2205) * Support dot with older dask * add an if block for non-dask environment --- doc/whats-new.rst | 12 ++++++++---- xarray/core/computation.py | 13 +++++++++++++ xarray/tests/test_computation.py | 8 +++++--- 3 files changed, 26 insertions(+), 7 deletions(-) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index f3db96f99bf..e3c4b050812 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -37,6 +37,10 @@ Documentation Enhancements ~~~~~~~~~~~~ +- `:py:meth:`~DataArray.dot` and :py:func:`~dot` are partly supported with older + dask<0.17.4. (related to :issue:`2203`) + By `Keisuke Fujii `_. - :py:meth:`~DataArray.rename` now supports supplying `kwargs`, as an - alternative to the existing approach of supplying a ``dict`` as the + alternative to the existing approach of supplying a ``dict`` as the first argument. By `Maximilian Roos `_. @@ -100,7 +104,7 @@ Bug fixes when grouping over dimension coordinates with duplicated entries (:issue:`2153`). By `Stephan Hoyer `_ - + - Fix Dataset.to_netcdf() cannot create group with engine="h5netcdf" (:issue:`2177`). By `Stephan Hoyer `_ diff --git a/xarray/core/computation.py b/xarray/core/computation.py index ebbce114ec3..6a49610cb7b 100644 --- a/xarray/core/computation.py +++ b/xarray/core/computation.py @@ -1061,6 +1061,19 @@ def dot(*arrays, **kwargs): output_core_dims = [tuple(d for d in all_dims if d not in dims + broadcast_dims)] + # older dask than 0.17.4, we use tensordot if possible. + if isinstance(arr.data, dask_array_type): + import dask + if LooseVersion(dask.__version__) < LooseVersion('0.17.4'): + if len(broadcast_dims) == 0 and len(arrays) == 2: + axes = [[arr.get_axis_num(d) for d in arr.dims if d in dims] + for arr in arrays] + return apply_ufunc(duck_array_ops.tensordot, *arrays, + dask='allowed', + input_core_dims=input_core_dims, + output_core_dims=output_core_dims, + kwargs={'axes': axes}) + # construct einsum subscripts, such as '...abc,...ab->...c' # Note: input_core_dims are always moved to the last position subscripts_list = ['...' 
+ ''.join([dim_map[d] for d in ds]) for ds diff --git a/xarray/tests/test_computation.py b/xarray/tests/test_computation.py index 37f97a81f82..a802b91a3db 100644 --- a/xarray/tests/test_computation.py +++ b/xarray/tests/test_computation.py @@ -845,9 +845,6 @@ def test_dot(use_dask): if use_dask: if not has_dask: pytest.skip('test for dask.') - import dask - if LooseVersion(dask.__version__) < LooseVersion('0.17.3'): - pytest.skip("needs dask.array.einsum") a = np.arange(30 * 4).reshape(30, 4) b = np.arange(30 * 4 * 5).reshape(30, 4, 5) @@ -872,6 +869,11 @@ def test_dot(use_dask): assert (actual.data == np.einsum('ij,ijk->k', a, b)).all() assert isinstance(actual.variable.data, type(da_a.variable.data)) + if use_dask: + import dask + if LooseVersion(dask.__version__) < LooseVersion('0.17.3'): + pytest.skip("needs dask.array.einsum") + # for only a single array is passed without dims argument, just return # as is actual = xr.dot(da_a) From 4106b949091d06f96ac3c1d07e95917f235bfb5f Mon Sep 17 00:00:00 2001 From: Stephan Hoyer Date: Thu, 31 May 2018 18:09:38 -0700 Subject: [PATCH 42/61] Fix dtype=S1 encoding in to_netcdf() (#2158) * Fix dtype=S1 encoding in to_netcdf() Fixes GH2149 * Add test_encoding_kwarg_compression from crusaderky * Fix dtype=S1 in kwargs for bytes, too * Fix lint * Move compression encoding kwarg test * Remvoe no longer relevant chanegs * Fix encoding dtype=str * More lint * Fix failed tests * Review comments * oops, we still need to skip that test * check for presence in a tuple rather than making two comparisons --- doc/whats-new.rst | 5 +++ xarray/backends/h5netcdf_.py | 19 +++++++++-- xarray/backends/netCDF4_.py | 27 ++++++++++++--- xarray/coding/strings.py | 7 ++-- xarray/conventions.py | 10 ++---- xarray/tests/test_backends.py | 57 ++++++++++++++++++++++++++++++-- xarray/tests/test_conventions.py | 7 ++++ 7 files changed, 114 insertions(+), 18 deletions(-) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index e3c4b050812..c4c8db243d4 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -71,6 +71,11 @@ Enhancements Bug fixes ~~~~~~~~~ +- Fixed a regression in 0.10.4, where explicitly specifying ``dtype='S1'`` or + ``dtype=str`` in ``encoding`` with ``to_netcdf()`` raised an error + (:issue:`2149`). + `Stephan Hoyer `_ + - :py:func:`apply_ufunc` now directly validates output variables (:issue:`1931`). By `Stephan Hoyer `_. diff --git a/xarray/backends/h5netcdf_.py b/xarray/backends/h5netcdf_.py index 6b3cd9ebb15..ecc83e98691 100644 --- a/xarray/backends/h5netcdf_.py +++ b/xarray/backends/h5netcdf_.py @@ -94,6 +94,8 @@ def __init__(self, filename, mode='r', format=None, group=None, super(H5NetCDFStore, self).__init__(writer, lock=lock) def open_store_variable(self, name, var): + import h5py + with self.ensure_open(autoclose=False): dimensions = var.dimensions data = indexing.LazilyOuterIndexedArray( @@ -119,6 +121,15 @@ def open_store_variable(self, name, var): encoding['source'] = self._filename encoding['original_shape'] = var.shape + vlen_dtype = h5py.check_dtype(vlen=var.dtype) + if vlen_dtype is unicode_type: + encoding['dtype'] = str + elif vlen_dtype is not None: # pragma: no cover + # xarray doesn't support writing arbitrary vlen dtypes yet. 
+ pass + else: + encoding['dtype'] = var.dtype + return Variable(dimensions, data, attrs, encoding) def get_variables(self): @@ -161,7 +172,8 @@ def prepare_variable(self, name, variable, check_encoding=False, import h5py attrs = variable.attrs.copy() - dtype = _get_datatype(variable) + dtype = _get_datatype( + variable, raise_on_invalid_encoding=check_encoding) fillvalue = attrs.pop('_FillValue', None) if dtype is str and fillvalue is not None: @@ -189,8 +201,9 @@ def prepare_variable(self, name, variable, check_encoding=False, raise ValueError("'zlib' and 'compression' encodings mismatch") encoding.setdefault('compression', 'gzip') - if (check_encoding and encoding.get('complevel') not in - (None, encoding.get('compression_opts'))): + if (check_encoding and + 'complevel' in encoding and 'compression_opts' in encoding and + encoding['complevel'] != encoding['compression_opts']): raise ValueError("'complevel' and 'compression_opts' encodings " "mismatch") complevel = encoding.pop('complevel', 0) diff --git a/xarray/backends/netCDF4_.py b/xarray/backends/netCDF4_.py index 5391a890fb3..d26b2b5321e 100644 --- a/xarray/backends/netCDF4_.py +++ b/xarray/backends/netCDF4_.py @@ -89,16 +89,33 @@ def _encode_nc4_variable(var): return var -def _get_datatype(var, nc_format='NETCDF4'): +def _check_encoding_dtype_is_vlen_string(dtype): + if dtype is not str: + raise AssertionError( # pragma: no cover + "unexpected dtype encoding %r. This shouldn't happen: please " + "file a bug report at github.com/pydata/xarray" % dtype) + + +def _get_datatype(var, nc_format='NETCDF4', raise_on_invalid_encoding=False): if nc_format == 'NETCDF4': datatype = _nc4_dtype(var) else: + if 'dtype' in var.encoding: + encoded_dtype = var.encoding['dtype'] + _check_encoding_dtype_is_vlen_string(encoded_dtype) + if raise_on_invalid_encoding: + raise ValueError( + 'encoding dtype=str for vlen strings is only supported ' + 'with format=\'NETCDF4\'.') datatype = var.dtype return datatype def _nc4_dtype(var): - if coding.strings.is_unicode_dtype(var.dtype): + if 'dtype' in var.encoding: + dtype = var.encoding.pop('dtype') + _check_encoding_dtype_is_vlen_string(dtype) + elif coding.strings.is_unicode_dtype(var.dtype): dtype = str elif var.dtype.kind in ['i', 'u', 'f', 'c', 'S']: dtype = var.dtype @@ -172,7 +189,7 @@ def _extract_nc4_variable_encoding(variable, raise_on_invalid=False, safe_to_drop = set(['source', 'original_shape']) valid_encodings = set(['zlib', 'complevel', 'fletcher32', 'contiguous', - 'chunksizes', 'shuffle', '_FillValue']) + 'chunksizes', 'shuffle', '_FillValue', 'dtype']) if lsd_okay: valid_encodings.add('least_significant_digit') if h5py_okay: @@ -344,6 +361,7 @@ def open_store_variable(self, name, var): # save source so __repr__ can detect if it's local or not encoding['source'] = self._filename encoding['original_shape'] = var.shape + encoding['dtype'] = var.dtype return Variable(dimensions, data, attributes, encoding) @@ -398,7 +416,8 @@ def encode_variable(self, variable): def prepare_variable(self, name, variable, check_encoding=False, unlimited_dims=None): - datatype = _get_datatype(variable, self.format) + datatype = _get_datatype(variable, self.format, + raise_on_invalid_encoding=check_encoding) attrs = variable.attrs.copy() fill_value = attrs.pop('_FillValue', None) diff --git a/xarray/coding/strings.py b/xarray/coding/strings.py index 08edeed4153..87b17d9175e 100644 --- a/xarray/coding/strings.py +++ b/xarray/coding/strings.py @@ -43,7 +43,10 @@ def encode(self, variable, name=None): dims, 
data, attrs, encoding = unpack_for_encoding(variable) contains_unicode = is_unicode_dtype(data.dtype) - encode_as_char = 'dtype' in encoding and encoding['dtype'] == 'S1' + encode_as_char = encoding.get('dtype') == 'S1' + + if encode_as_char: + del encoding['dtype'] # no longer relevant if contains_unicode and (encode_as_char or not self.allows_unicode): if '_FillValue' in attrs: @@ -100,7 +103,7 @@ def encode(self, variable, name=None): variable = ensure_fixed_length_bytes(variable) dims, data, attrs, encoding = unpack_for_encoding(variable) - if data.dtype.kind == 'S': + if data.dtype.kind == 'S' and encoding.get('dtype') is not str: data = bytes_to_char(data) dims = dims + ('string%s' % data.shape[-1],) return Variable(dims, data, attrs, encoding) diff --git a/xarray/conventions.py b/xarray/conventions.py index 6171c353a0d..67dcb8d6d4e 100644 --- a/xarray/conventions.py +++ b/xarray/conventions.py @@ -79,7 +79,8 @@ def _var_as_tuple(var): def maybe_encode_nonstring_dtype(var, name=None): - if 'dtype' in var.encoding and var.encoding['dtype'] != 'S1': + if ('dtype' in var.encoding and + var.encoding['dtype'] not in ('S1', str)): dims, data, attrs, encoding = _var_as_tuple(var) dtype = np.dtype(encoding.pop('dtype')) if dtype != var.dtype: @@ -307,12 +308,7 @@ def decode_cf_variable(name, var, concat_characters=True, mask_and_scale=True, data = NativeEndiannessArray(data) original_dtype = data.dtype - if 'dtype' in encoding: - if original_dtype != encoding['dtype']: - warnings.warn("CF decoding is overwriting dtype on variable {!r}" - .format(name)) - else: - encoding['dtype'] = original_dtype + encoding.setdefault('dtype', original_dtype) if 'dtype' in attributes and attributes['dtype'] == 'bool': del attributes['dtype'] diff --git a/xarray/tests/test_backends.py b/xarray/tests/test_backends.py index b80cb18e2be..1e7a09fa55a 100644 --- a/xarray/tests/test_backends.py +++ b/xarray/tests/test_backends.py @@ -753,6 +753,7 @@ def test_encoding_kwarg(self): with self.roundtrip(ds, save_kwargs=kwargs) as actual: pass + def test_encoding_kwarg_dates(self): ds = Dataset({'t': pd.date_range('2000-01-01', periods=3)}) units = 'days since 1900-01-01' kwargs = dict(encoding={'t': {'units': units}}) @@ -760,6 +761,18 @@ def test_encoding_kwarg(self): self.assertEqual(actual.t.encoding['units'], units) assert_identical(actual, ds) + def test_encoding_kwarg_fixed_width_string(self): + # regression test for GH2149 + for strings in [ + [b'foo', b'bar', b'baz'], + [u'foo', u'bar', u'baz'], + ]: + ds = Dataset({'x': strings}) + kwargs = dict(encoding={'x': {'dtype': 'S1'}}) + with self.roundtrip(ds, save_kwargs=kwargs) as actual: + self.assertEqual(actual['x'].encoding['dtype'], 'S1') + assert_identical(actual, ds) + def test_default_fill_value(self): # Test default encoding for float: ds = Dataset({'x': ('y', np.arange(10.0))}) @@ -879,8 +892,8 @@ def create_tmp_files(nfiles, suffix='.nc', allow_cleanup_failure=False): yield files -@requires_netCDF4 class BaseNetCDF4Test(CFEncodedDataTest): + """Tests for both netCDF4-python and h5netcdf.""" engine = 'netcdf4' @@ -942,6 +955,18 @@ def test_write_groups(self): with self.open(tmp_file, group='data/2') as actual2: assert_identical(data2, actual2) + def test_encoding_kwarg_vlen_string(self): + for input_strings in [ + [b'foo', b'bar', b'baz'], + [u'foo', u'bar', u'baz'], + ]: + original = Dataset({'x': input_strings}) + expected = Dataset({'x': [u'foo', u'bar', u'baz']}) + kwargs = dict(encoding={'x': {'dtype': str}}) + with self.roundtrip(original, 
save_kwargs=kwargs) as actual: + assert actual['x'].encoding['dtype'] is str + assert_identical(actual, expected) + def test_roundtrip_string_with_fill_value_vlen(self): values = np.array([u'ab', u'cdef', np.nan], dtype=object) expected = Dataset({'x': ('t', values)}) @@ -1054,6 +1079,23 @@ def test_compression_encoding(self): with self.roundtrip(expected) as actual: assert_equal(expected, actual) + def test_encoding_kwarg_compression(self): + ds = Dataset({'x': np.arange(10.0)}) + encoding = dict(dtype='f4', zlib=True, complevel=9, fletcher32=True, + chunksizes=(5,), shuffle=True) + kwargs = dict(encoding=dict(x=encoding)) + + with self.roundtrip(ds, save_kwargs=kwargs) as actual: + assert_equal(actual, ds) + self.assertEqual(actual.x.encoding['dtype'], 'f4') + self.assertEqual(actual.x.encoding['zlib'], True) + self.assertEqual(actual.x.encoding['complevel'], 9) + self.assertEqual(actual.x.encoding['fletcher32'], True) + self.assertEqual(actual.x.encoding['chunksizes'], (5,)) + self.assertEqual(actual.x.encoding['shuffle'], True) + + self.assertEqual(ds.x.encoding, {}) + def test_encoding_chunksizes_unlimited(self): # regression test for GH1225 ds = Dataset({'x': [1, 2, 3], 'y': ('x', [2, 3, 4])}) @@ -1117,7 +1159,7 @@ def test_already_open_dataset(self): expected = Dataset({'x': ((), 42)}) assert_identical(expected, ds) - def test_variable_len_strings(self): + def test_read_variable_len_strings(self): with create_tmp_file() as tmp_file: values = np.array(['foo', 'bar', 'baz'], dtype=object) @@ -1410,6 +1452,10 @@ def test_group(self): open_kwargs={'group': group}) as actual: assert_identical(original, actual) + def test_encoding_kwarg_fixed_width_string(self): + # not relevant for zarr, since we don't use EncodedStringCoder + pass + # TODO: someone who understand caching figure out whether chaching # makes sense for Zarr backend @pytest.mark.xfail(reason="Zarr caching not implemented") @@ -1579,6 +1625,13 @@ def create_store(self): tmp_file, mode='w', format='NETCDF3_CLASSIC') as store: yield store + def test_encoding_kwarg_vlen_string(self): + original = Dataset({'x': [u'foo', u'bar', u'baz']}) + kwargs = dict(encoding={'x': {'dtype': str}}) + with raises_regex(ValueError, 'encoding dtype=str for vlen'): + with self.roundtrip(original, save_kwargs=kwargs): + pass + class NetCDF3ViaNetCDF4DataTestAutocloseTrue(NetCDF3ViaNetCDF4DataTest): autoclose = True diff --git a/xarray/tests/test_conventions.py b/xarray/tests/test_conventions.py index acc1c978579..5ed482ed2bd 100644 --- a/xarray/tests/test_conventions.py +++ b/xarray/tests/test_conventions.py @@ -272,7 +272,14 @@ def test_roundtrip_coordinates(self): 'CFEncodedInMemoryStore') def test_invalid_dataarray_names_raise(self): + # only relevant for on-disk file formats pass def test_encoding_kwarg(self): + # we haven't bothered to raise errors yet for unexpected encodings in + # this test dummy + pass + + def test_encoding_kwarg_fixed_width_string(self): + # CFEncodedInMemoryStore doesn't support explicit string encodings. pass From cf19528d6d2baf988ad34e024cae28361c9fd693 Mon Sep 17 00:00:00 2001 From: barronh Date: Fri, 1 Jun 2018 00:21:43 -0400 Subject: [PATCH 43/61] Added PNC backend to xarray (#1905) * Added PNC backend to xarray PNC is used for GEOS-Chem, CAMx, CMAQ and other atmospheric data formats that have their own file formats and meta-data conventions. It can provide a CF compliant netCDF-like interface. 
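A minimal sketch of what the new engine enables, assuming PseudoNetCDF is installed (the file name here is a placeholder, not a file shipped with xarray):

>>> ds = xr.open_dataset('camx_output.uamiv', engine='pseudonetcdf')

PseudoNetCDF picks a reader heuristically; a specific reader can also be forced through ``backend_kwargs``, as shown further below in the ``open_dataset`` changes.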
* Added whats-new documentation

* Updating pnc_ to remove DunderArrayMixin dependency

* Adding basic tests for pnc

Right now, pnc is simply being tested as a reader for NetCDF3 files

* Updating for flake8 compliance

* flake does not like unused e

* Updating pnc to PseudoNetCDF

* Remove outer except

* Updating pnc to PseudoNetCDF

* Added open and updated init

Based on shoyer review

* Updated indexing and test fix

Indexing supports #1899

* Added PseudoNetCDF to doc/io.rst

* Changing test subtype

* Changing test subtype removing pdb

* pnc test case requires netcdf3only

For now, pnc is only supporting the classic data model

* adding backend_kwargs default as dict

This ensures **mapping is possible.

* Upgrading tests to CFEncodedDataTest

Some tests are bypassed. PseudoNetCDF string treatment is not currently
compatible with xarray. This will be addressed soon.

* Not currently supporting autoclose

I do not fully understand the use case, so I have not implemented these
tests.

* Minor updates for flake8

* Explicit skipping

Using pytest.mark.skip to skip unsupported tests

* removing trailing whitespace from pytest skip

* Adding pip support

* Addressing comments

* Bypassing pickle, mask/scale, and object

These tests cause errors that do not affect desired backend performance.

* Added uamiv test

PseudoNetCDF reads other formats. This adds a test of uamiv to the
standard test for a backend and skips mask/scale, object, and boolean
tests

* Adding support for autoclose

ensure open must be called before accessing variable data

* Adding backend_kwargs to all backends

Most backends currently take no keywords, so an empty dictionary is
appropriate.

* Small tweaks to PNC backend

* remove warning and update whats-new

* Separating install and io pnc doc and updating whats new

* fixing line length in test

* Tests now use non-netcdf files

* Removing unknown meta-data netcdf support.

* flake8 cleanup

* Using python 2 and 3 compat testing

* Disabling mask_and_scale by default

prevents inadvertent double scaling in PNC formats

* consistent with 3.0.0

Updates in 3.0.1 will fix close in uamiv.

* Updating readers and line length

* Updating readers and line length

* Updating readers and line length

* Adding open_mfdataset test

Testing by opening the same file twice and stacking it.

* Using conda version of PseudoNetCDF

* Removing xfail for netcdf

Mask and scale with PseudoNetCDF and NetCDF4 is not supported, but not
prevented.
* Moving pseudonetcdf to v0.15 * Updating what's new * Fixing open_dataarray CF options mask_and_scale is None (diagnosed by open_dataset) and decode_cf should be True --- ci/requirements-py36.yml | 1 + doc/installing.rst | 7 +- doc/io.rst | 23 ++- doc/whats-new.rst | 4 + xarray/backends/__init__.py | 2 + xarray/backends/api.py | 55 ++++++-- xarray/backends/pseudonetcdf_.py | 101 ++++++++++++++ xarray/tests/__init__.py | 1 + xarray/tests/data/example.ict | 31 +++++ xarray/tests/data/example.uamiv | Bin 0 -> 608 bytes xarray/tests/test_backends.py | 232 ++++++++++++++++++++++++++++++- 11 files changed, 440 insertions(+), 17 deletions(-) create mode 100644 xarray/backends/pseudonetcdf_.py create mode 100644 xarray/tests/data/example.ict create mode 100644 xarray/tests/data/example.uamiv diff --git a/ci/requirements-py36.yml b/ci/requirements-py36.yml index 0790f20764d..fd63fe26130 100644 --- a/ci/requirements-py36.yml +++ b/ci/requirements-py36.yml @@ -20,6 +20,7 @@ dependencies: - rasterio - bottleneck - zarr + - pseudonetcdf>=3.0.1 - pip: - coveralls - pytest-cov diff --git a/doc/installing.rst b/doc/installing.rst index bb42129deea..33f01b8c770 100644 --- a/doc/installing.rst +++ b/doc/installing.rst @@ -28,6 +28,9 @@ For netCDF and IO - `cftime `__: recommended if you want to encode/decode datetimes for non-standard calendars or dates before year 1678 or after year 2262. +- `PseudoNetCDF `__: recommended + for accessing CAMx, GEOS-Chem (bpch), NOAA ARL files, ICARTT files + (ffi1001) and many other. For accelerating xarray ~~~~~~~~~~~~~~~~~~~~~~~ @@ -65,9 +68,9 @@ with its recommended dependencies using the conda command line tool:: .. _conda: http://conda.io/ -We recommend using the community maintained `conda-forge `__ channel if you need difficult\-to\-build dependencies such as cartopy or pynio:: +We recommend using the community maintained `conda-forge `__ channel if you need difficult\-to\-build dependencies such as cartopy, pynio or PseudoNetCDF:: - $ conda install -c conda-forge xarray cartopy pynio + $ conda install -c conda-forge xarray cartopy pynio pseudonetcdf New releases may also appear in conda-forge before being updated in the default channel. diff --git a/doc/io.rst b/doc/io.rst index 668416e714d..e92ecd01cb4 100644 --- a/doc/io.rst +++ b/doc/io.rst @@ -650,7 +650,26 @@ We recommend installing PyNIO via conda:: .. _PyNIO: https://www.pyngl.ucar.edu/Nio.shtml -.. _combining multiple files: +.. _io.PseudoNetCDF: + +Formats supported by PseudoNetCDF +--------------------------------- + +xarray can also read CAMx, BPCH, ARL PACKED BIT, and many other file +formats supported by PseudoNetCDF_, if PseudoNetCDF is installed. +PseudoNetCDF can also provide Climate Forecasting Conventions to +CMAQ files. In addition, PseudoNetCDF can automatically register custom +readers that subclass PseudoNetCDF.PseudoNetCDFFile. PseudoNetCDF can +identify readers heuristically, or format can be specified via a key in +`backend_kwargs`. + +To use PseudoNetCDF to read such files, supply +``engine='pseudonetcdf'`` to :py:func:`~xarray.open_dataset`. + +Add ``backend_kwargs={'format': ''}`` where `` +options are listed on the PseudoNetCDF page. + +.. _PseuodoNetCDF: http://github.com/barronh/PseudoNetCDF Formats supported by Pandas @@ -662,6 +681,8 @@ exporting your objects to pandas and using its broad range of `IO tools`_. .. _IO tools: http://pandas.pydata.org/pandas-docs/stable/io.html +.. 
_combining multiple files: + Combining multiple files ------------------------ diff --git a/doc/whats-new.rst b/doc/whats-new.rst index c4c8db243d4..bfa24340bcd 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -41,6 +41,10 @@ Enhancements dask<0.17.4. (related to :issue:`2203`) By `Keisuke Fujii `_. + - :py:meth:`~DataArray.cumsum` and :py:meth:`~DataArray.cumprod` now support aggregation over multiple dimensions at the same time. This is the default behavior when dimensions are not specified (previously this raised an error). diff --git a/xarray/backends/__init__.py b/xarray/backends/__init__.py index d85893afb0b..47a2011a3af 100644 --- a/xarray/backends/__init__.py +++ b/xarray/backends/__init__.py @@ -10,6 +10,7 @@ from .pynio_ import NioDataStore from .scipy_ import ScipyDataStore from .h5netcdf_ import H5NetCDFStore +from .pseudonetcdf_ import PseudoNetCDFDataStore from .zarr import ZarrStore __all__ = [ @@ -21,4 +22,5 @@ 'ScipyDataStore', 'H5NetCDFStore', 'ZarrStore', + 'PseudoNetCDFDataStore', ] diff --git a/xarray/backends/api.py b/xarray/backends/api.py index c3b2aa59fcd..753f8394a7b 100644 --- a/xarray/backends/api.py +++ b/xarray/backends/api.py @@ -152,9 +152,10 @@ def _finalize_store(write, store): def open_dataset(filename_or_obj, group=None, decode_cf=True, - mask_and_scale=True, decode_times=True, autoclose=False, + mask_and_scale=None, decode_times=True, autoclose=False, concat_characters=True, decode_coords=True, engine=None, - chunks=None, lock=None, cache=None, drop_variables=None): + chunks=None, lock=None, cache=None, drop_variables=None, + backend_kwargs=None): """Load and decode a dataset from a file or file-like object. Parameters @@ -178,7 +179,8 @@ def open_dataset(filename_or_obj, group=None, decode_cf=True, taken from variable attributes (if they exist). If the `_FillValue` or `missing_value` attribute contains multiple values a warning will be issued and all array values matching one of the multiple values will - be replaced by NA. + be replaced by NA. mask_and_scale defaults to True except for the + pseudonetcdf backend. decode_times : bool, optional If True, decode times encoded in the standard NetCDF datetime format into datetime objects. Otherwise, leave them encoded as numbers. @@ -194,7 +196,7 @@ def open_dataset(filename_or_obj, group=None, decode_cf=True, decode_coords : bool, optional If True, decode the 'coordinates' attribute to identify coordinates in the resulting dataset. - engine : {'netcdf4', 'scipy', 'pydap', 'h5netcdf', 'pynio'}, optional + engine : {'netcdf4', 'scipy', 'pydap', 'h5netcdf', 'pynio', 'pseudonetcdf'}, optional Engine to use when reading files. If not provided, the default engine is chosen based on available dependencies, with a preference for 'netcdf4'. @@ -219,6 +221,10 @@ def open_dataset(filename_or_obj, group=None, decode_cf=True, A variable or list of variables to exclude from being parsed from the dataset. This may be useful to drop variables with problems or inconsistent values. + backend_kwargs: dictionary, optional + A dictionary of keyword arguments to pass on to the backend. This + may be useful when backend options would improve performance or + allow user control of dataset processing. 
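For example, with the PseudoNetCDF engine added in this patch, ``backend_kwargs`` carries the reader selection (a sketch; the file name is illustrative and ``'ffi1001'`` is the ICARTT reader name mentioned in the installation docs):

>>> ds = xr.open_dataset('flight_data.ict', engine='pseudonetcdf',
...                      backend_kwargs={'format': 'ffi1001'})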
Returns ------- @@ -229,6 +235,10 @@ def open_dataset(filename_or_obj, group=None, decode_cf=True, -------- open_mfdataset """ + + if mask_and_scale is None: + mask_and_scale = not engine == 'pseudonetcdf' + if not decode_cf: mask_and_scale = False decode_times = False @@ -238,6 +248,9 @@ def open_dataset(filename_or_obj, group=None, decode_cf=True, if cache is None: cache = chunks is None + if backend_kwargs is None: + backend_kwargs = {} + def maybe_decode_store(store, lock=False): ds = conventions.decode_cf( store, mask_and_scale=mask_and_scale, decode_times=decode_times, @@ -303,18 +316,26 @@ def maybe_decode_store(store, lock=False): if engine == 'netcdf4': store = backends.NetCDF4DataStore.open(filename_or_obj, group=group, - autoclose=autoclose) + autoclose=autoclose, + **backend_kwargs) elif engine == 'scipy': store = backends.ScipyDataStore(filename_or_obj, - autoclose=autoclose) + autoclose=autoclose, + **backend_kwargs) elif engine == 'pydap': - store = backends.PydapDataStore.open(filename_or_obj) + store = backends.PydapDataStore.open(filename_or_obj, + **backend_kwargs) elif engine == 'h5netcdf': store = backends.H5NetCDFStore(filename_or_obj, group=group, - autoclose=autoclose) + autoclose=autoclose, + **backend_kwargs) elif engine == 'pynio': store = backends.NioDataStore(filename_or_obj, - autoclose=autoclose) + autoclose=autoclose, + **backend_kwargs) + elif engine == 'pseudonetcdf': + store = backends.PseudoNetCDFDataStore.open( + filename_or_obj, autoclose=autoclose, **backend_kwargs) else: raise ValueError('unrecognized engine for open_dataset: %r' % engine) @@ -334,9 +355,10 @@ def maybe_decode_store(store, lock=False): def open_dataarray(filename_or_obj, group=None, decode_cf=True, - mask_and_scale=True, decode_times=True, autoclose=False, + mask_and_scale=None, decode_times=True, autoclose=False, concat_characters=True, decode_coords=True, engine=None, - chunks=None, lock=None, cache=None, drop_variables=None): + chunks=None, lock=None, cache=None, drop_variables=None, + backend_kwargs=None): """Open an DataArray from a netCDF file containing a single data variable. This is designed to read netCDF files with only one data variable. If @@ -363,7 +385,8 @@ def open_dataarray(filename_or_obj, group=None, decode_cf=True, taken from variable attributes (if they exist). If the `_FillValue` or `missing_value` attribute contains multiple values a warning will be issued and all array values matching one of the multiple values will - be replaced by NA. + be replaced by NA. mask_and_scale defaults to True except for the + pseudonetcdf backend. decode_times : bool, optional If True, decode times encoded in the standard NetCDF datetime format into datetime objects. Otherwise, leave them encoded as numbers. @@ -403,6 +426,10 @@ def open_dataarray(filename_or_obj, group=None, decode_cf=True, A variable or list of variables to exclude from being parsed from the dataset. This may be useful to drop variables with problems or inconsistent values. + backend_kwargs: dictionary, optional + A dictionary of keyword arguments to pass on to the backend. This + may be useful when backend options would improve performance or + allow user control of dataset processing. 
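A small sketch of the single-variable round trip ``open_dataarray`` is meant for (the file name is illustrative):

>>> xr.Dataset({'a': ('x', [1, 2, 3])}).to_netcdf('one_var.nc')
>>> da = xr.open_dataarray('one_var.nc')
>>> # raises ValueError if the file does not hold exactly one data variable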
Notes ----- @@ -417,13 +444,15 @@ def open_dataarray(filename_or_obj, group=None, decode_cf=True, -------- open_dataset """ + dataset = open_dataset(filename_or_obj, group=group, decode_cf=decode_cf, mask_and_scale=mask_and_scale, decode_times=decode_times, autoclose=autoclose, concat_characters=concat_characters, decode_coords=decode_coords, engine=engine, chunks=chunks, lock=lock, cache=cache, - drop_variables=drop_variables) + drop_variables=drop_variables, + backend_kwargs=backend_kwargs) if len(dataset.data_vars) != 1: raise ValueError('Given file dataset contains more than one data ' diff --git a/xarray/backends/pseudonetcdf_.py b/xarray/backends/pseudonetcdf_.py new file mode 100644 index 00000000000..c481bf848b9 --- /dev/null +++ b/xarray/backends/pseudonetcdf_.py @@ -0,0 +1,101 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import functools + +import numpy as np + +from .. import Variable +from ..core.pycompat import OrderedDict +from ..core.utils import (FrozenOrderedDict, Frozen) +from ..core import indexing + +from .common import AbstractDataStore, DataStorePickleMixin, BackendArray + + +class PncArrayWrapper(BackendArray): + + def __init__(self, variable_name, datastore): + self.datastore = datastore + self.variable_name = variable_name + array = self.get_array() + self.shape = array.shape + self.dtype = np.dtype(array.dtype) + + def get_array(self): + self.datastore.assert_open() + return self.datastore.ds.variables[self.variable_name] + + def __getitem__(self, key): + key, np_inds = indexing.decompose_indexer( + key, self.shape, indexing.IndexingSupport.OUTER_1VECTOR) + + with self.datastore.ensure_open(autoclose=True): + array = self.get_array()[key.tuple] # index backend array + + if len(np_inds.tuple) > 0: + # index the loaded np.ndarray + array = indexing.NumpyIndexingAdapter(array)[np_inds] + return array + + +class PseudoNetCDFDataStore(AbstractDataStore, DataStorePickleMixin): + """Store for accessing datasets via PseudoNetCDF + """ + @classmethod + def open(cls, filename, format=None, writer=None, + autoclose=False, **format_kwds): + from PseudoNetCDF import pncopen + opener = functools.partial(pncopen, filename, **format_kwds) + ds = opener() + mode = format_kwds.get('mode', 'r') + return cls(ds, mode=mode, writer=writer, opener=opener, + autoclose=autoclose) + + def __init__(self, pnc_dataset, mode='r', writer=None, opener=None, + autoclose=False): + + if autoclose and opener is None: + raise ValueError('autoclose requires an opener') + + self._ds = pnc_dataset + self._autoclose = autoclose + self._isopen = True + self._opener = opener + self._mode = mode + super(PseudoNetCDFDataStore, self).__init__() + + def open_store_variable(self, name, var): + with self.ensure_open(autoclose=False): + data = indexing.LazilyOuterIndexedArray( + PncArrayWrapper(name, self) + ) + attrs = OrderedDict((k, getattr(var, k)) for k in var.ncattrs()) + return Variable(var.dimensions, data, attrs) + + def get_variables(self): + with self.ensure_open(autoclose=False): + return FrozenOrderedDict((k, self.open_store_variable(k, v)) + for k, v in self.ds.variables.items()) + + def get_attrs(self): + with self.ensure_open(autoclose=True): + return Frozen(dict([(k, getattr(self.ds, k)) + for k in self.ds.ncattrs()])) + + def get_dimensions(self): + with self.ensure_open(autoclose=True): + return Frozen(self.ds.dimensions) + + def get_encoding(self): + encoding = {} + encoding['unlimited_dims'] = set( + [k for k in 
self.ds.dimensions + if self.ds.dimensions[k].isunlimited()]) + return encoding + + def close(self): + if self._isopen: + self.ds.close() + self._isopen = False diff --git a/xarray/tests/__init__.py b/xarray/tests/__init__.py index 3acd26235ce..e93d9a80145 100644 --- a/xarray/tests/__init__.py +++ b/xarray/tests/__init__.py @@ -68,6 +68,7 @@ def _importorskip(modname, minversion=None): has_netCDF4, requires_netCDF4 = _importorskip('netCDF4') has_h5netcdf, requires_h5netcdf = _importorskip('h5netcdf') has_pynio, requires_pynio = _importorskip('Nio') +has_pseudonetcdf, requires_pseudonetcdf = _importorskip('PseudoNetCDF') has_cftime, requires_cftime = _importorskip('cftime') has_dask, requires_dask = _importorskip('dask') has_bottleneck, requires_bottleneck = _importorskip('bottleneck') diff --git a/xarray/tests/data/example.ict b/xarray/tests/data/example.ict new file mode 100644 index 00000000000..bc04888fb80 --- /dev/null +++ b/xarray/tests/data/example.ict @@ -0,0 +1,31 @@ +27, 1001 +Henderson, Barron +U.S. EPA +Example file with artificial data +JUST_A_TEST +1, 1 +2018, 04, 27, 2018, 04, 27 +0 +Start_UTC +7 +1, 1, 1, 1, 1 +-9999, -9999, -9999, -9999, -9999 +lat, degrees_north +lon, degrees_east +elev, meters +TEST_ppbv, ppbv +TESTM_ppbv, ppbv +0 +8 +ULOD_FLAG: -7777 +ULOD_VALUE: N/A +LLOD_FLAG: -8888 +LLOD_VALUE: N/A, N/A, N/A, N/A, 0.025 +OTHER_COMMENTS: www-air.larc.nasa.gov/missions/etc/IcarttDataFormat.htm +REVISION: R0 +R0: No comments for this revision. +Start_UTC, lat, lon, elev, TEST_ppbv, TESTM_ppbv +43200, 41.00000, -71.00000, 5, 1.2345, 2.220 +46800, 42.00000, -72.00000, 15, 2.3456, -9999 +50400, 42.00000, -73.00000, 20, 3.4567, -7777 +50400, 42.00000, -74.00000, 25, 4.5678, -8888 \ No newline at end of file diff --git a/xarray/tests/data/example.uamiv b/xarray/tests/data/example.uamiv new file mode 100644 index 0000000000000000000000000000000000000000..fcedcd53097122839b5b94d1fabd2cb70d7c003e GIT binary patch literal 608 zcmb8rv1$TA5XSL2h+tviBU~jm38#rx0dEtcl_(OdQiP~rL`g_OlET6=WlBpQ#a5rf zN6G)YTSY{W4E%29cK2pS&4S2zd|YF)5KZjoR;gEOedvC#K<_&avzwN`8~gXTIF xI-B-6oH6M=RsVnVGX1`ok76FN>IIhAm^lN}xe)vVE=C)Vc*P7q_{H25(?3@UPTv3k literal 0 HcmV?d00001 diff --git a/xarray/tests/test_backends.py b/xarray/tests/test_backends.py index 1e7a09fa55a..0e6151b2db5 100644 --- a/xarray/tests/test_backends.py +++ b/xarray/tests/test_backends.py @@ -32,7 +32,7 @@ assert_identical, has_dask, has_netCDF4, has_scipy, network, raises_regex, requires_dask, requires_h5netcdf, requires_netCDF4, requires_pathlib, requires_pydap, requires_pynio, requires_rasterio, requires_scipy, - requires_scipy_or_netCDF4, requires_zarr, + requires_scipy_or_netCDF4, requires_zarr, requires_pseudonetcdf, requires_cftime) from .test_dataset import create_test_data @@ -63,6 +63,13 @@ def open_example_dataset(name, *args, **kwargs): *args, **kwargs) +def open_example_mfdataset(names, *args, **kwargs): + return open_mfdataset( + [os.path.join(os.path.dirname(__file__), 'data', name) + for name in names], + *args, **kwargs) + + def create_masked_and_scaled_data(): x = np.array([np.nan, np.nan, 10, 10.1, 10.2], dtype=np.float32) encoding = {'_FillValue': -1, 'add_offset': 10, @@ -2483,6 +2490,229 @@ class PyNioTestAutocloseTrue(PyNioTest): autoclose = True +@requires_pseudonetcdf +class PseudoNetCDFFormatTest(TestCase): + autoclose = True + + def open(self, path, **kwargs): + return open_dataset(path, engine='pseudonetcdf', + autoclose=self.autoclose, + **kwargs) + + @contextlib.contextmanager + def roundtrip(self, data, 
save_kwargs={}, open_kwargs={}, + allow_cleanup_failure=False): + with create_tmp_file( + allow_cleanup_failure=allow_cleanup_failure) as path: + self.save(data, path, **save_kwargs) + with self.open(path, **open_kwargs) as ds: + yield ds + + def test_ict_format(self): + """ + Open a CAMx file and test data variables + """ + ictfile = open_example_dataset('example.ict', + engine='pseudonetcdf', + autoclose=False, + backend_kwargs={'format': 'ffi1001'}) + stdattr = { + 'fill_value': -9999.0, + 'missing_value': -9999, + 'scale': 1, + 'llod_flag': -8888, + 'llod_value': 'N/A', + 'ulod_flag': -7777, + 'ulod_value': 'N/A' + } + + def myatts(**attrs): + outattr = stdattr.copy() + outattr.update(attrs) + return outattr + + input = { + 'coords': {}, + 'attrs': { + 'fmt': '1001', 'n_header_lines': 27, + 'PI_NAME': 'Henderson, Barron', + 'ORGANIZATION_NAME': 'U.S. EPA', + 'SOURCE_DESCRIPTION': 'Example file with artificial data', + 'MISSION_NAME': 'JUST_A_TEST', + 'VOLUME_INFO': '1, 1', + 'SDATE': '2018, 04, 27', 'WDATE': '2018, 04, 27', + 'TIME_INTERVAL': '0', + 'INDEPENDENT_VARIABLE': 'Start_UTC', + 'ULOD_FLAG': '-7777', 'ULOD_VALUE': 'N/A', + 'LLOD_FLAG': '-8888', + 'LLOD_VALUE': ('N/A, N/A, N/A, N/A, 0.025'), + 'OTHER_COMMENTS': ('www-air.larc.nasa.gov/missions/etc/' + + 'IcarttDataFormat.htm'), + 'REVISION': 'R0', + 'R0': 'No comments for this revision.', + 'TFLAG': 'Start_UTC' + }, + 'dims': {'POINTS': 4}, + 'data_vars': { + 'Start_UTC': { + 'data': [43200.0, 46800.0, 50400.0, 50400.0], + 'dims': ('POINTS',), + 'attrs': myatts( + units='Start_UTC', + standard_name='Start_UTC', + ) + }, + 'lat': { + 'data': [41.0, 42.0, 42.0, 42.0], + 'dims': ('POINTS',), + 'attrs': myatts( + units='degrees_north', + standard_name='lat', + ) + }, + 'lon': { + 'data': [-71.0, -72.0, -73.0, -74.], + 'dims': ('POINTS',), + 'attrs': myatts( + units='degrees_east', + standard_name='lon', + ) + }, + 'elev': { + 'data': [5.0, 15.0, 20.0, 25.0], + 'dims': ('POINTS',), + 'attrs': myatts( + units='meters', + standard_name='elev', + ) + }, + 'TEST_ppbv': { + 'data': [1.2345, 2.3456, 3.4567, 4.5678], + 'dims': ('POINTS',), + 'attrs': myatts( + units='ppbv', + standard_name='TEST_ppbv', + ) + }, + 'TESTM_ppbv': { + 'data': [2.22, -9999.0, -7777.0, -8888.0], + 'dims': ('POINTS',), + 'attrs': myatts( + units='ppbv', + standard_name='TESTM_ppbv', + llod_value=0.025 + ) + } + } + } + chkfile = Dataset.from_dict(input) + assert_identical(ictfile, chkfile) + + def test_ict_format_write(self): + fmtkw = {'format': 'ffi1001'} + expected = open_example_dataset('example.ict', + engine='pseudonetcdf', + autoclose=False, + backend_kwargs=fmtkw) + with self.roundtrip(expected, save_kwargs=fmtkw, + open_kwargs={'backend_kwargs': fmtkw}) as actual: + assert_identical(expected, actual) + + def test_uamiv_format_read(self): + """ + Open a CAMx file and test data variables + """ + with warnings.catch_warnings(): + warnings.filterwarnings('ignore', category=UserWarning, + message=('IOAPI_ISPH is assumed to be ' + + '6370000.; consistent with WRF')) + camxfile = open_example_dataset('example.uamiv', + engine='pseudonetcdf', + autoclose=True, + backend_kwargs={'format': 'uamiv'}) + data = np.arange(20, dtype='f').reshape(1, 1, 4, 5) + expected = xr.Variable(('TSTEP', 'LAY', 'ROW', 'COL'), data, + dict(units='ppm', long_name='O3'.ljust(16), + var_desc='O3'.ljust(80))) + actual = camxfile.variables['O3'] + assert_allclose(expected, actual) + + data = np.array(['2002-06-03'], 'datetime64[ns]') + expected = xr.Variable(('TSTEP',), data, + 
dict(bounds='time_bounds', + long_name=('synthesized time coordinate ' + + 'from SDATE, STIME, STEP ' + + 'global attributes'))) + actual = camxfile.variables['time'] + assert_allclose(expected, actual) + camxfile.close() + + def test_uamiv_format_mfread(self): + """ + Open a CAMx file and test data variables + """ + with warnings.catch_warnings(): + warnings.filterwarnings('ignore', category=UserWarning, + message=('IOAPI_ISPH is assumed to be ' + + '6370000.; consistent with WRF')) + camxfile = open_example_mfdataset( + ['example.uamiv', + 'example.uamiv'], + engine='pseudonetcdf', + autoclose=True, + concat_dim='TSTEP', + backend_kwargs={'format': 'uamiv'}) + + data1 = np.arange(20, dtype='f').reshape(1, 1, 4, 5) + data = np.concatenate([data1] * 2, axis=0) + expected = xr.Variable(('TSTEP', 'LAY', 'ROW', 'COL'), data, + dict(units='ppm', long_name='O3'.ljust(16), + var_desc='O3'.ljust(80))) + actual = camxfile.variables['O3'] + assert_allclose(expected, actual) + + data1 = np.array(['2002-06-03'], 'datetime64[ns]') + data = np.concatenate([data1] * 2, axis=0) + expected = xr.Variable(('TSTEP',), data, + dict(bounds='time_bounds', + long_name=('synthesized time coordinate ' + + 'from SDATE, STIME, STEP ' + + 'global attributes'))) + actual = camxfile.variables['time'] + assert_allclose(expected, actual) + camxfile.close() + + def test_uamiv_format_write(self): + fmtkw = {'format': 'uamiv'} + with warnings.catch_warnings(): + warnings.filterwarnings('ignore', category=UserWarning, + message=('IOAPI_ISPH is assumed to be ' + + '6370000.; consistent with WRF')) + expected = open_example_dataset('example.uamiv', + engine='pseudonetcdf', + autoclose=False, + backend_kwargs=fmtkw) + with self.roundtrip(expected, + save_kwargs=fmtkw, + open_kwargs={'backend_kwargs': fmtkw}) as actual: + assert_identical(expected, actual) + + def save(self, dataset, path, **save_kwargs): + import PseudoNetCDF as pnc + pncf = pnc.PseudoNetCDFFile() + pncf.dimensions = {k: pnc.PseudoNetCDFDimension(pncf, k, v) + for k, v in dataset.dims.items()} + pncf.variables = {k: pnc.PseudoNetCDFVariable(pncf, k, v.dtype.char, + v.dims, + values=v.data[...], + **v.attrs) + for k, v in dataset.variables.items()} + for pk, pv in dataset.attrs.items(): + setattr(pncf, pk, pv) + + pnc.pncwrite(pncf, path, **save_kwargs) + + @requires_rasterio @contextlib.contextmanager def create_tmp_geotiff(nx=4, ny=3, nz=3, From ad47ced88c1c99fd961617943e02613b67c9cea9 Mon Sep 17 00:00:00 2001 From: Stephan Hoyer Date: Thu, 31 May 2018 22:17:25 -0700 Subject: [PATCH 44/61] Release v0.10.5 --- doc/whats-new.rst | 54 ++++++++++++++++++++++++----------------------- 1 file changed, 28 insertions(+), 26 deletions(-) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index bfa24340bcd..4e4ed20a093 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -28,36 +28,25 @@ What's New .. _whats-new.0.10.5: -v0.10.5 (unreleased) --------------------- +v0.10.5 (31 May 2018) +--------------------- -Documentation -~~~~~~~~~~~~~ +The minor release includes a number of bug-fixes and backwards compatible +enhancements. Enhancements ~~~~~~~~~~~~ -- `:py:meth:`~DataArray.dot` and :py:func:`~dot` are partly supported with older - dask<0.17.4. (related to :issue:`2203`) - By `Keisuke Fujii `_. -- :py:meth:`~DataArray.cumsum` and :py:meth:`~DataArray.cumprod` now support - aggregation over multiple dimensions at the same time. This is the default - behavior when dimensions are not specified (previously this raised an error). 
- By `Stephan Hoyer `_ - -- Xarray now uses `Versioneer `__ - to manage its version strings. (:issue:`1300`). - By `Joe Hamman `_. - -- `:py:class:`Dataset`s align `:py:class:`DataArray`s to coords that are explicitly - passed into the constructor, where previously an error would be raised. +- The :py:class:`Dataset` constructor now aligns :py:class:`DataArray` + arguments in ``data_vars`` to indexes set explicitly in ``coords``, + where previously an error would be raised. (:issue:`674`) - By `Maximilian Roos `_. - :py:meth:`~DataArray.sel`, :py:meth:`~DataArray.isel` & :py:meth:`~DataArray.reindex`, (and their :py:class:`Dataset` counterparts) now support supplying a ``dict`` @@ -67,11 +56,24 @@ Enhancements not strings. By `Maximilian Roos `_. -- :py:meth:`~DataArray.rename` now supports supplying `kwargs`, as an +- :py:meth:`~DataArray.rename` now supports supplying ``**kwargs``, as an alternative to the existing approach of supplying a ``dict`` as the first argument. By `Maximilian Roos `_. +- :py:meth:`~DataArray.cumsum` and :py:meth:`~DataArray.cumprod` now support + aggregation over multiple dimensions at the same time. This is the default + behavior when dimensions are not specified (previously this raised an error). + By `Stephan Hoyer `_ + +- :py:meth:`~DataArray.dot` and :py:func:`~dot` are partly supported with older + dask<0.17.4. (related to :issue:`2203`) + By `Keisuke Fujii `_. + +- Xarray now uses `Versioneer `__ + to manage its version strings. (:issue:`1300`). + By `Joe Hamman `_. + Bug fixes ~~~~~~~~~ @@ -114,13 +116,13 @@ Bug fixes (:issue:`2153`). By `Stephan Hoyer `_ -- Fix Dataset.to_netcdf() cannot create group with engine="h5netcdf" +- Fix ``Dataset.to_netcdf()`` cannot create group with ``engine="h5netcdf"`` (:issue:`2177`). By `Stephan Hoyer `_ .. _whats-new.0.10.4: -v0.10.4 (May 16, 2018) +v0.10.4 (16 May 2018) ---------------------- The minor release includes a number of bug-fixes and backwards compatible @@ -208,7 +210,7 @@ Bug fixes .. _whats-new.0.10.3: -v0.10.3 (April 13, 2018) +v0.10.3 (13 April 2018) ------------------------ The minor release includes a number of bug-fixes and backwards compatible enhancements. From a3cf251cc04ed5731912171a6c2b63f8927e610e Mon Sep 17 00:00:00 2001 From: Stephan Hoyer Date: Thu, 31 May 2018 22:20:52 -0700 Subject: [PATCH 45/61] Add whats-new for v0.10.6 --- doc/whats-new.rst | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 4e4ed20a093..730ba20b850 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -26,6 +26,20 @@ What's New - `Tips on porting to Python 3 `__ +.. _whats-new.0.10.6: + +v0.10.6 (unreleased) +-------------------- + +Documentation +~~~~~~~~~~~~~ + +Enhancements +~~~~~~~~~~~~ + +Bug fixes +~~~~~~~~~ + .. _whats-new.0.10.5: v0.10.5 (31 May 2018) From 1c9b4b2e556b81f1e668ae7aa3aaea8aa91b7983 Mon Sep 17 00:00:00 2001 From: Stephan Hoyer Date: Thu, 31 May 2018 22:45:58 -0700 Subject: [PATCH 46/61] Fix versioneer, release v0.10.6 --- doc/whats-new.rst | 17 +---------------- setup.cfg | 2 +- xarray/_version.py | 2 +- 3 files changed, 3 insertions(+), 18 deletions(-) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 730ba20b850..ad88289e43f 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -25,24 +25,9 @@ What's New - `Python 3 Statement `__ - `Tips on porting to Python 3 `__ - .. _whats-new.0.10.6: -v0.10.6 (unreleased) --------------------- - -Documentation -~~~~~~~~~~~~~ - -Enhancements -~~~~~~~~~~~~ - -Bug fixes -~~~~~~~~~ - -.. 
_whats-new.0.10.5: - -v0.10.5 (31 May 2018) +v0.10.6 (31 May 2018) --------------------- The minor release includes a number of bug-fixes and backwards compatible diff --git a/setup.cfg b/setup.cfg index 4dd1bffe043..17f24b3f1ce 100644 --- a/setup.cfg +++ b/setup.cfg @@ -22,7 +22,7 @@ VCS = git style = pep440 versionfile_source = xarray/_version.py versionfile_build = xarray/_version.py -tag_prefix = +tag_prefix = v parentdir_prefix = xarray- [aliases] diff --git a/xarray/_version.py b/xarray/_version.py index 2fa32b69798..df4ee95ade4 100644 --- a/xarray/_version.py +++ b/xarray/_version.py @@ -41,7 +41,7 @@ def get_config(): cfg = VersioneerConfig() cfg.VCS = "git" cfg.style = "pep440" - cfg.tag_prefix = "" + cfg.tag_prefix = "v" cfg.parentdir_prefix = "xarray-" cfg.versionfile_source = "xarray/_version.py" cfg.verbose = False From 16ef0cf508d99b4154f41004992e648eb6cd6eb9 Mon Sep 17 00:00:00 2001 From: Stephan Hoyer Date: Thu, 31 May 2018 22:51:57 -0700 Subject: [PATCH 47/61] Add whats-new for v0.10.7 --- doc/whats-new.rst | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index ad88289e43f..ea1c2114a41 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -25,6 +25,20 @@ What's New - `Python 3 Statement `__ - `Tips on porting to Python 3 `__ +.. _whats-new.0.10.7: + +v0.10.7 (unreleased) +-------------------- + +Documentation +~~~~~~~~~~~~~ + +Enhancements +~~~~~~~~~~~~ + +Bug fixes +~~~~~~~~~ + .. _whats-new.0.10.6: v0.10.6 (31 May 2018) From 1c37d9ce526fecb9fdab2c82b3f46be06f55a128 Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Fri, 1 Jun 2018 12:15:49 -0400 Subject: [PATCH 48/61] Remove height=12in from facetgrid example plots. (#2210) --- doc/plotting.rst | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/doc/plotting.rst b/doc/plotting.rst index 28fbe7062a6..b10f0e7fc64 100644 --- a/doc/plotting.rst +++ b/doc/plotting.rst @@ -436,7 +436,7 @@ arguments to the xarray plotting methods/functions. This returns a .. ipython:: python - @savefig plot_facet_dataarray.png height=12in + @savefig plot_facet_dataarray.png g_simple = t.plot(x='lon', y='lat', col='time', col_wrap=3) 4 dimensional @@ -454,7 +454,7 @@ one were much hotter. # This is a 4d array t4d.coords - @savefig plot_facet_4d.png height=12in + @savefig plot_facet_4d.png t4d.plot(x='lon', y='lat', col='time', row='fourth_dim') Other features @@ -468,7 +468,7 @@ Faceted plotting supports other arguments common to xarray 2d plots. hasoutliers[0, 0, 0] = -100 hasoutliers[-1, -1, -1] = 400 - @savefig plot_facet_robust.png height=12in + @savefig plot_facet_robust.png g = hasoutliers.plot.pcolormesh('lon', 'lat', col='time', col_wrap=3, robust=True, cmap='viridis') @@ -509,7 +509,7 @@ they have been plotted. bottomright = g.axes[-1, -1] bottomright.annotate('bottom right', (240, 40)) - @savefig plot_facet_iterator.png height=12in + @savefig plot_facet_iterator.png plt.show() TODO: add an example of using the ``map`` method to plot dataset variables From 1e6984b247a08c5c98baa808dcc7552eb1c372a0 Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Fri, 1 Jun 2018 20:10:25 -0400 Subject: [PATCH 49/61] Plot labels use CF convention information if available. (#2151) * Plot labels use CF convention information if available. Uses attrs long_name/standard_name, units if available. * Follow review feedback. * More informative docs. * Minor edit. 
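A minimal sketch of the behaviour described above (the attribute names follow the CF conventions; the label text matches ``label_from_attrs`` added below):

    import numpy as np
    import xarray as xr

    air = xr.DataArray(np.random.randn(10), dims='time', name='air')
    air.attrs['long_name'] = 'air temperature'
    air.attrs['units'] = 'K'

    # With this change the y-axis label reads "air temperature [K]"
    # instead of just the DataArray name "air".
    air.plot()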
--- doc/plotting.rst | 15 +++++- doc/whats-new.rst | 2 + xarray/plot/facetgrid.py | 27 +++++++---- xarray/plot/plot.py | 26 ++++++----- xarray/plot/utils.py | 22 +++++++++ xarray/tests/test_plot.py | 96 +++++++++++++++++++++++++++++---------- 6 files changed, 142 insertions(+), 46 deletions(-) diff --git a/doc/plotting.rst b/doc/plotting.rst index b10f0e7fc64..fa364d4838e 100644 --- a/doc/plotting.rst +++ b/doc/plotting.rst @@ -66,6 +66,13 @@ For these examples we'll use the North American air temperature dataset. # Convert to celsius air = airtemps.air - 273.15 + # copy attributes to get nice figure labels and change Kelvin to Celsius + air.attrs = airtemps.air.attrs + air.attrs['units'] = 'deg C' + +.. note:: + Until :issue:`1614` is solved, you might need to copy over the metadata in ``attrs`` to get informative figure labels (as was done above). + One Dimension ------------- @@ -73,7 +80,7 @@ One Dimension Simple Example ~~~~~~~~~~~~~~ -xarray uses the coordinate name to label the x axis. +The simplest way to make a plot is to call the :py:func:`xarray.DataArray.plot()` method. .. ipython:: python @@ -82,6 +89,12 @@ xarray uses the coordinate name to label the x axis. @savefig plotting_1d_simple.png width=4in air1d.plot() +xarray uses the coordinate name along with metadata ``attrs.long_name``, ``attrs.standard_name``, ``DataArray.name`` and ``attrs.units`` (if available) to label the axes. The names ``long_name``, ``standard_name`` and ``units`` are copied from the `CF-conventions spec `_. When choosing names, the order of precedence is ``long_name``, ``standard_name`` and finally ``DataArray.name``. The y-axis label in the above plot was constructed from the ``long_name`` and ``units`` attributes of ``air1d``. + +.. ipython:: python + + air1d.attrs + Additional Arguments ~~~~~~~~~~~~~~~~~~~~~ diff --git a/doc/whats-new.rst b/doc/whats-new.rst index ea1c2114a41..749726ee190 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -35,6 +35,8 @@ Documentation Enhancements ~~~~~~~~~~~~ +- Plot labels now make use of metadata that follow CF conventions. + By `Deepak Cherian `_ and `Ryan Abernathey `_. 
Bug fixes ~~~~~~~~~ diff --git a/xarray/plot/facetgrid.py b/xarray/plot/facetgrid.py index 5abae214c9f..361b47262c9 100644 --- a/xarray/plot/facetgrid.py +++ b/xarray/plot/facetgrid.py @@ -9,7 +9,8 @@ from ..core.formatting import format_item from ..core.pycompat import getargspec from .utils import ( - _determine_cmap_params, _infer_xy_labels, import_matplotlib_pyplot) + _determine_cmap_params, _infer_xy_labels, import_matplotlib_pyplot, + label_from_attrs) # Overrides axes.labelsize, xtick.major.size, ytick.major.size # from mpl.rcParams @@ -282,8 +283,8 @@ def add_colorbar(self, **kwargs): kwargs = kwargs.copy() if self._cmap_extend is not None: kwargs.setdefault('extend', self._cmap_extend) - if getattr(self.data, 'name', None) is not None: - kwargs.setdefault('label', self.data.name) + if 'label' not in kwargs: + kwargs.setdefault('label', label_from_attrs(self.data)) self.cbar = self.fig.colorbar(self._mappables[-1], ax=list(self.axes.flat), **kwargs) @@ -292,17 +293,25 @@ def add_colorbar(self, **kwargs): def set_axis_labels(self, x_var=None, y_var=None): """Set axis labels on the left column and bottom row of the grid.""" if x_var is not None: - self._x_var = x_var - self.set_xlabels(x_var) + if x_var in self.data.coords: + self._x_var = x_var + self.set_xlabels(label_from_attrs(self.data[x_var])) + else: + # x_var is a string + self.set_xlabels(x_var) + if y_var is not None: - self._y_var = y_var - self.set_ylabels(y_var) + if y_var in self.data.coords: + self._y_var = y_var + self.set_ylabels(label_from_attrs(self.data[y_var])) + else: + self.set_ylabels(y_var) return self def set_xlabels(self, label=None, **kwargs): """Label the x axis on the bottom row of the grid.""" if label is None: - label = self._x_var + label = label_from_attrs(self.data[self._x_var]) for ax in self._bottom_axes: ax.set_xlabel(label, **kwargs) return self @@ -310,7 +319,7 @@ def set_xlabels(self, label=None, **kwargs): def set_ylabels(self, label=None, **kwargs): """Label the y axis on the left column of the grid.""" if label is None: - label = self._y_var + label = label_from_attrs(self.data[self._y_var]) for ax in self._left_axes: ax.set_ylabel(label, **kwargs) return self diff --git a/xarray/plot/plot.py b/xarray/plot/plot.py index ee1df611d3b..f49ec5c52d9 100644 --- a/xarray/plot/plot.py +++ b/xarray/plot/plot.py @@ -20,7 +20,7 @@ from .facetgrid import FacetGrid from .utils import ( ROBUST_PERCENTILE, _determine_cmap_params, _infer_xy_labels, get_axis, - import_matplotlib_pyplot) + import_matplotlib_pyplot, label_from_attrs) def _valid_numpy_subdtype(x, numpy_types): @@ -240,14 +240,10 @@ def line(darray, *args, **kwargs): if (x is None and y is None) or x == dim: xplt = darray.coords[dim] yplt = darray - xlabel = dim - ylabel = darray.name else: yplt = darray.coords[dim] xplt = darray - xlabel = darray.name - ylabel = dim else: if x is None and y is None and hue is None: @@ -265,6 +261,12 @@ def line(darray, *args, **kwargs): xplt = darray.transpose(ylabel, huelabel) yplt = darray.coords[ylabel] + huecoords = darray[huelabel] + huelabel = label_from_attrs(huecoords) + + xlabel = label_from_attrs(xplt) + ylabel = label_from_attrs(yplt) + _ensure_plottable(xplt) primitive = ax.plot(xplt, yplt, *args, **kwargs) @@ -279,7 +281,7 @@ def line(darray, *args, **kwargs): if darray.ndim == 2 and add_legend: ax.legend(handles=primitive, - labels=list(darray.coords[huelabel].values), + labels=list(huecoords.values), title=huelabel) # Rotate dates on xlabels @@ -333,8 +335,8 @@ def hist(darray, figsize=None, 
size=None, aspect=None, ax=None, **kwargs): ax.set_ylabel('Count') - if darray.name is not None: - ax.set_title('Histogram of {0}'.format(darray.name)) + ax.set_title('Histogram') + ax.set_xlabel(label_from_attrs(darray)) return primitive @@ -652,8 +654,8 @@ def newplotfunc(darray, x=None, y=None, figsize=None, size=None, # Label the plot with metadata if add_labels: - ax.set_xlabel(xlab) - ax.set_ylabel(ylab) + ax.set_xlabel(label_from_attrs(darray[xlab])) + ax.set_ylabel(label_from_attrs(darray[ylab])) ax.set_title(darray._title_for_slice()) if add_colorbar: @@ -664,8 +666,8 @@ def newplotfunc(darray, x=None, y=None, figsize=None, size=None, else: cbar_kwargs.setdefault('cax', cbar_ax) cbar = plt.colorbar(primitive, **cbar_kwargs) - if darray.name and add_labels and 'label' not in cbar_kwargs: - cbar.set_label(darray.name, rotation=90) + if add_labels and 'label' not in cbar_kwargs: + cbar.set_label(label_from_attrs(darray), rotation=90) elif cbar_ax is not None or cbar_kwargs is not None: # inform the user about keywords which aren't used raise ValueError("cbar_ax and cbar_kwargs can't be used with " diff --git a/xarray/plot/utils.py b/xarray/plot/utils.py index 7ba48819518..6846c553b8b 100644 --- a/xarray/plot/utils.py +++ b/xarray/plot/utils.py @@ -5,6 +5,7 @@ import numpy as np import pandas as pd import pkg_resources +import textwrap from ..core.pycompat import basestring from ..core.utils import is_scalar @@ -354,3 +355,24 @@ def get_axis(figsize, size, aspect, ax): ax = plt.gca() return ax + + +def label_from_attrs(da): + ''' Makes informative labels if variable metadata (attrs) follows + CF conventions. ''' + + if da.attrs.get('long_name'): + name = da.attrs['long_name'] + elif da.attrs.get('standard_name'): + name = da.attrs['standard_name'] + elif da.name is not None: + name = da.name + else: + name = '' + + if da.attrs.get('units'): + units = ' [{}]'.format(da.attrs['units']) + else: + units = '' + + return '\n'.join(textwrap.wrap(name + units, 30)) diff --git a/xarray/tests/test_plot.py b/xarray/tests/test_plot.py index 70ed1156643..db1fb2fd081 100644 --- a/xarray/tests/test_plot.py +++ b/xarray/tests/test_plot.py @@ -13,7 +13,7 @@ from xarray.plot.plot import _infer_interval_breaks from xarray.plot.utils import ( _build_discrete_cmap, _color_palette, _determine_cmap_params, - import_seaborn) + import_seaborn, label_from_attrs) from . 
import ( TestCase, assert_array_equal, assert_equal, raises_regex, @@ -89,6 +89,28 @@ class TestPlot(PlotTestCase): def setUp(self): self.darray = DataArray(easy_array((2, 3, 4))) + def test_label_from_attrs(self): + da = self.darray.copy() + assert '' == label_from_attrs(da) + + da.name = 'a' + da.attrs['units'] = 'a_units' + da.attrs['long_name'] = 'a_long_name' + da.attrs['standard_name'] = 'a_standard_name' + assert 'a_long_name [a_units]' == label_from_attrs(da) + + da.attrs.pop('long_name') + assert 'a_standard_name [a_units]' == label_from_attrs(da) + da.attrs.pop('units') + assert 'a_standard_name' == label_from_attrs(da) + + da.attrs['units'] = 'a_units' + da.attrs.pop('standard_name') + assert 'a [a_units]' == label_from_attrs(da) + + da.attrs.pop('units') + assert 'a' == label_from_attrs(da) + def test1d(self): self.darray[:, 0, 0].plot() @@ -303,10 +325,11 @@ def setUp(self): d = [0, 1.1, 0, 2] self.darray = DataArray( d, coords={'period': range(len(d))}, dims='period') + self.darray.period.attrs['units'] = 's' def test_xlabel_is_index_name(self): self.darray.plot() - assert 'period' == plt.gca().get_xlabel() + assert 'period [s]' == plt.gca().get_xlabel() def test_no_label_name_on_x_axis(self): self.darray.plot(y='period') @@ -318,13 +341,15 @@ def test_no_label_name_on_y_axis(self): def test_ylabel_is_data_name(self): self.darray.name = 'temperature' + self.darray.attrs['units'] = 'degrees_Celsius' self.darray.plot() - assert self.darray.name == plt.gca().get_ylabel() + assert 'temperature [degrees_Celsius]' == plt.gca().get_ylabel() def test_xlabel_is_data_name(self): self.darray.name = 'temperature' + self.darray.attrs['units'] = 'degrees_Celsius' self.darray.plot(y='period') - self.assertEqual(self.darray.name, plt.gca().get_xlabel()) + assert 'temperature [degrees_Celsius]' == plt.gca().get_xlabel() def test_format_string(self): self.darray.plot.line('ro') @@ -374,19 +399,20 @@ def setUp(self): def test_3d_array(self): self.darray.plot.hist() - def test_title_no_name(self): - self.darray.plot.hist() - assert '' == plt.gca().get_title() - - def test_title_uses_name(self): + def test_xlabel_uses_name(self): self.darray.name = 'testpoints' + self.darray.attrs['units'] = 'testunits' self.darray.plot.hist() - assert self.darray.name in plt.gca().get_title() + assert 'testpoints [testunits]' == plt.gca().get_xlabel() def test_ylabel_is_count(self): self.darray.plot.hist() assert 'Count' == plt.gca().get_ylabel() + def test_title_is_histogram(self): + self.darray.plot.hist() + assert 'Histogram' == plt.gca().get_title() + def test_can_pass_in_kwargs(self): nbins = 5 self.darray.plot.hist(bins=nbins) @@ -654,7 +680,10 @@ class Common2dMixin: """ def setUp(self): - da = DataArray(easy_array((10, 15), start=-1), dims=['y', 'x']) + da = DataArray(easy_array((10, 15), start=-1), + dims=['y', 'x'], + coords={'y': np.arange(10), + 'x': np.arange(15)}) # add 2d coords ds = da.to_dataset(name='testvar') x, y = np.meshgrid(da.x.values, da.y.values) @@ -663,12 +692,21 @@ def setUp(self): ds.set_coords(['x2d', 'y2d'], inplace=True) # set darray and plot method self.darray = ds.testvar + + # Add CF-compliant metadata + self.darray.attrs['long_name'] = 'a_long_name' + self.darray.attrs['units'] = 'a_units' + self.darray.x.attrs['long_name'] = 'x_long_name' + self.darray.x.attrs['units'] = 'x_units' + self.darray.y.attrs['long_name'] = 'y_long_name' + self.darray.y.attrs['units'] = 'y_units' + self.plotmethod = getattr(self.darray.plot, self.plotfunc.__name__) def test_label_names(self): 
self.plotmethod() - assert 'x' == plt.gca().get_xlabel() - assert 'y' == plt.gca().get_ylabel() + assert 'x_long_name [x_units]' == plt.gca().get_xlabel() + assert 'y_long_name [y_units]' == plt.gca().get_ylabel() def test_1d_raises_valueerror(self): with raises_regex(ValueError, r'DataArray must be 2d'): @@ -761,19 +799,19 @@ def test_diverging_color_limits(self): def test_xy_strings(self): self.plotmethod('y', 'x') ax = plt.gca() - assert 'y' == ax.get_xlabel() - assert 'x' == ax.get_ylabel() + assert 'y_long_name [y_units]' == ax.get_xlabel() + assert 'x_long_name [x_units]' == ax.get_ylabel() def test_positional_coord_string(self): self.plotmethod(y='x') ax = plt.gca() - assert 'x' == ax.get_ylabel() - assert 'y' == ax.get_xlabel() + assert 'x_long_name [x_units]' == ax.get_ylabel() + assert 'y_long_name [y_units]' == ax.get_xlabel() self.plotmethod(x='x') ax = plt.gca() - assert 'x' == ax.get_xlabel() - assert 'y' == ax.get_ylabel() + assert 'x_long_name [x_units]' == ax.get_xlabel() + assert 'y_long_name [y_units]' == ax.get_ylabel() def test_bad_x_string_exception(self): with raises_regex(ValueError, 'x and y must be coordinate variables'): @@ -797,7 +835,7 @@ def test_non_linked_coords(self): # Normal case, without transpose self.plotfunc(self.darray, x='x', y='newy') ax = plt.gca() - assert 'x' == ax.get_xlabel() + assert 'x_long_name [x_units]' == ax.get_xlabel() assert 'newy' == ax.get_ylabel() # ax limits might change between plotfuncs # simply ensure that these high coords were passed over @@ -812,7 +850,7 @@ def test_non_linked_coords_transpose(self): self.plotfunc(self.darray, x='newy', y='x') ax = plt.gca() assert 'newy' == ax.get_xlabel() - assert 'x' == ax.get_ylabel() + assert 'x_long_name [x_units]' == ax.get_ylabel() # ax limits might change between plotfuncs # simply ensure that these high coords were passed over assert np.min(ax.get_xlim()) > 100. @@ -826,19 +864,29 @@ def test_default_title(self): assert 'c = 1, d = foo' == title or 'd = foo, c = 1' == title def test_colorbar_default_label(self): - self.darray.name = 'testvar' self.plotmethod(add_colorbar=True) - assert self.darray.name in text_in_fig() + assert ('a_long_name [a_units]' in text_in_fig()) def test_no_labels(self): self.darray.name = 'testvar' + self.darray.attrs['units'] = 'test_units' self.plotmethod(add_labels=False) alltxt = text_in_fig() - for string in ['x', 'y', 'testvar']: + for string in ['x_long_name [x_units]', + 'y_long_name [y_units]', + 'testvar [test_units]']: assert string not in alltxt def test_colorbar_kwargs(self): # replace label + self.darray.attrs.pop('long_name') + self.darray.attrs['units'] = 'test_units' + # check default colorbar label + self.plotmethod(add_colorbar=True) + alltxt = text_in_fig() + assert 'testvar [test_units]' in alltxt + self.darray.attrs.pop('units') + self.darray.name = 'testvar' self.plotmethod(add_colorbar=True, cbar_kwargs={'label': 'MyLabel'}) alltxt = text_in_fig() From 69c9c45bf7a9d572200c4649605a5875e96b650c Mon Sep 17 00:00:00 2001 From: crusaderky Date: Sat, 2 Jun 2018 13:15:32 +0100 Subject: [PATCH 50/61] Trivial docs fix (#2212) --- doc/io.rst | 2 +- doc/whats-new.rst | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/doc/io.rst b/doc/io.rst index e92ecd01cb4..7f7e7a2a66a 100644 --- a/doc/io.rst +++ b/doc/io.rst @@ -669,7 +669,7 @@ To use PseudoNetCDF to read such files, supply Add ``backend_kwargs={'format': ''}`` where `` options are listed on the PseudoNetCDF page. -.. 
_PseuodoNetCDF: http://github.com/barronh/PseudoNetCDF +.. _PseudoNetCDF: http://github.com/barronh/PseudoNetCDF Formats supported by Pandas diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 749726ee190..5e5da295186 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -81,7 +81,7 @@ Enhancements behavior when dimensions are not specified (previously this raised an error). By `Stephan Hoyer `_ -- :py:meth:`~DataArray.dot` and :py:func:`~dot` are partly supported with older +- :py:meth:`DataArray.dot` and :py:func:`dot` are partly supported with older dask<0.17.4. (related to :issue:`2203`) By `Keisuke Fujii `_. From bc52f8aa64833d8c97f9ef5253b6a78c7033f521 Mon Sep 17 00:00:00 2001 From: Yohai Bar Sinai <6164157+yohai@users.noreply.github.com> Date: Mon, 4 Jun 2018 11:54:44 -0400 Subject: [PATCH 51/61] ENH: added FacetGrid functionality to line plots (#2107) * ENH: added FacetGrid functionality to line plots a) plot.line can now accept also 'row' and 'col' keywords. b) If 'hue' is passed as a keyword to DataArray.plot() it generates a line plot FacetGrid. c) Line plots are automatically generated if the number of dimensions after faceting along row and/or col is one. * minor formatting issues * minor formatting issues * fix kwargs bug * added tests and refactoring line_legend * add documentation * added tests * minor formatting * Fix merge. All tests pass now. --- doc/plotting.rst | 15 ++- xarray/plot/facetgrid.py | 67 +++++++++++++- xarray/plot/plot.py | 187 ++++++++++++++++++++++++-------------- xarray/tests/test_plot.py | 66 ++++++++++++++ 4 files changed, 263 insertions(+), 72 deletions(-) diff --git a/doc/plotting.rst b/doc/plotting.rst index fa364d4838e..54fa2f57ac8 100644 --- a/doc/plotting.rst +++ b/doc/plotting.rst @@ -208,7 +208,11 @@ It is required to explicitly specify either 2. ``hue``: the dimension you want to represent by multiple lines. Thus, we could have made the previous plot by specifying ``hue='lat'`` instead of ``x='time'``. -If required, the automatic legend can be turned off using ``add_legend=False``. +If required, the automatic legend can be turned off using ``add_legend=False``. Alternatively, +``hue`` can be passed directly to :py:func:`xarray.plot` as `air.isel(lon=10, lat=[19,21,22]).plot(hue='lat')`. + + + Dimension along y-axis ~~~~~~~~~~~~~~~~~~~~~~ @@ -218,7 +222,7 @@ It is also possible to make line plots such that the data are on the x-axis and .. ipython:: python @savefig plotting_example_xy_kwarg.png - air.isel(time=10, lon=[10, 11]).plot.line(y='lat', hue='lon') + air.isel(time=10, lon=[10, 11]).plot(y='lat', hue='lon') Changing Axes Direction ----------------------- @@ -452,6 +456,13 @@ arguments to the xarray plotting methods/functions. This returns a @savefig plot_facet_dataarray.png g_simple = t.plot(x='lon', y='lat', col='time', col_wrap=3) +Faceting also works for line plots. + +.. 
ipython:: python + + @savefig plot_facet_dataarray_line.png + g_simple_line = t.isel(lat=slice(0,None,4)).plot(x='lon', hue='lat', col='time', col_wrap=3) + 4 dimensional ~~~~~~~~~~~~~ diff --git a/xarray/plot/facetgrid.py b/xarray/plot/facetgrid.py index 361b47262c9..771f0879408 100644 --- a/xarray/plot/facetgrid.py +++ b/xarray/plot/facetgrid.py @@ -5,7 +5,6 @@ import warnings import numpy as np - from ..core.formatting import format_item from ..core.pycompat import getargspec from .utils import ( @@ -267,6 +266,44 @@ def map_dataarray(self, func, x, y, **kwargs): return self + def map_dataarray_line(self, x=None, y=None, hue=None, **kwargs): + """ + Apply a line plot to a 2d facet subset of the data. + + Parameters + ---------- + x, y, hue: string + dimension names for the axes and hues of each facet + + Returns + ------- + self : FacetGrid object + + """ + from .plot import line, _infer_line_data + + add_legend = kwargs.pop('add_legend', True) + kwargs['add_legend'] = False + + for d, ax in zip(self.name_dicts.flat, self.axes.flat): + # None is the sentinel value + if d is not None: + subset = self.data.loc[d] + mappable = line(subset, x=x, y=y, hue=hue, + ax=ax, _labels=False, + **kwargs) + self._mappables.append(mappable) + _, _, _, xlabel, ylabel, huelabel = _infer_line_data( + darray=self.data.loc[self.name_dicts.flat[0]], + x=x, y=y, hue=hue) + + self._finalize_grid(xlabel, ylabel) + + if add_legend and huelabel: + self.add_line_legend(huelabel) + + return self + def _finalize_grid(self, *axlabels): """Finalize the annotations and layout.""" self.set_axis_labels(*axlabels) @@ -277,6 +314,34 @@ def _finalize_grid(self, *axlabels): if namedict is None: ax.set_visible(False) + def add_line_legend(self, huelabel): + figlegend = self.fig.legend( + handles=self._mappables[-1], + labels=list(self.data.coords[huelabel].values), + title=huelabel, + loc="center right") + + # Draw the plot to set the bounding boxes correctly + self.fig.draw(self.fig.canvas.get_renderer()) + + # Calculate and set the new width of the figure so the legend fits + legend_width = figlegend.get_window_extent().width / self.fig.dpi + figure_width = self.fig.get_figwidth() + self.fig.set_figwidth(figure_width + legend_width) + + # Draw the plot again to get the new transformations + self.fig.draw(self.fig.canvas.get_renderer()) + + # Now calculate how much space we need on the right side + legend_width = figlegend.get_window_extent().width / self.fig.dpi + space_needed = legend_width / (figure_width + legend_width) + 0.02 + # margin = .01 + # _space_needed = margin + space_needed + right = 1 - space_needed + + # Place the subplot axes to give space for the legend + self.fig.subplots_adjust(right=right) + def add_colorbar(self, **kwargs): """Draw a colorbar """ diff --git a/xarray/plot/plot.py b/xarray/plot/plot.py index f49ec5c52d9..6322fc09d92 100644 --- a/xarray/plot/plot.py +++ b/xarray/plot/plot.py @@ -83,8 +83,32 @@ def _easy_facetgrid(darray, plotfunc, x, y, row=None, col=None, return g.map_dataarray(plotfunc, x, y, **kwargs) -def plot(darray, row=None, col=None, col_wrap=None, ax=None, rtol=0.01, - subplot_kws=None, **kwargs): +def _line_facetgrid(darray, row=None, col=None, hue=None, + col_wrap=None, sharex=True, sharey=True, aspect=None, + size=None, subplot_kws=None, **kwargs): + """ + Convenience method to call xarray.plot.FacetGrid for line plots + kwargs are the arguments to pyplot.plot() + """ + ax = kwargs.pop('ax', None) + figsize = kwargs.pop('figsize', None) + if ax is not None: + raise 
ValueError("Can't use axes when making faceted plots.") + if aspect is None: + aspect = 1 + if size is None: + size = 3 + elif figsize is not None: + raise ValueError('cannot provide both `figsize` and `size` arguments') + + g = FacetGrid(data=darray, col=col, row=row, col_wrap=col_wrap, + sharex=sharex, sharey=sharey, figsize=figsize, + aspect=aspect, size=size, subplot_kws=subplot_kws) + return g.map_dataarray_line(hue=hue, **kwargs) + + +def plot(darray, row=None, col=None, col_wrap=None, ax=None, hue=None, + rtol=0.01, subplot_kws=None, **kwargs): """ Default plot of DataArray using matplotlib.pyplot. @@ -106,6 +130,8 @@ def plot(darray, row=None, col=None, col_wrap=None, ax=None, rtol=0.01, If passed, make row faceted plots on this dimension name col : string, optional If passed, make column faceted plots on this dimension name + hue : string, optional + If passed, make faceted line plots with hue on this dimension name col_wrap : integer, optional Use together with ``col`` to wrap faceted plots ax : matplotlib axes, optional @@ -129,26 +155,28 @@ def plot(darray, row=None, col=None, col_wrap=None, ax=None, rtol=0.01, plot_dims = set(darray.dims) plot_dims.discard(row) plot_dims.discard(col) + plot_dims.discard(hue) ndims = len(plot_dims) - error_msg = ('Only 2d plots are supported for facets in xarray. ' + error_msg = ('Only 1d and 2d plots are supported for facets in xarray. ' 'See the package `Seaborn` for more options.') - if ndims == 1: + if ndims in [1, 2]: if row or col: - raise ValueError(error_msg) - plotfunc = line - elif ndims == 2: - # Only 2d can FacetGrid - kwargs['row'] = row - kwargs['col'] = col - kwargs['col_wrap'] = col_wrap - kwargs['subplot_kws'] = subplot_kws - - plotfunc = pcolormesh + kwargs['row'] = row + kwargs['col'] = col + kwargs['col_wrap'] = col_wrap + kwargs['subplot_kws'] = subplot_kws + if ndims == 1: + plotfunc = line + kwargs['hue'] = hue + elif ndims == 2: + if hue: + raise ValueError('hue is not compatible with 2d data') + plotfunc = pcolormesh else: - if row or col: + if row or col or hue: raise ValueError(error_msg) plotfunc = hist @@ -157,6 +185,61 @@ def plot(darray, row=None, col=None, col_wrap=None, ax=None, rtol=0.01, return plotfunc(darray, **kwargs) +def _infer_line_data(darray, x, y, hue): + error_msg = ('must be either None or one of ({0:s})' + .format(', '.join([repr(dd) for dd in darray.dims]))) + ndims = len(darray.dims) + + if x is not None and x not in darray.dims: + raise ValueError('x ' + error_msg) + + if y is not None and y not in darray.dims: + raise ValueError('y ' + error_msg) + + if x is not None and y is not None: + raise ValueError('You cannot specify both x and y kwargs' + 'for line plots.') + + if ndims == 1: + dim, = darray.dims # get the only dimension name + huename = None + hueplt = None + huelabel = '' + + if (x is None and y is None) or x == dim: + xplt = darray.coords[dim] + yplt = darray + + else: + yplt = darray.coords[dim] + xplt = darray + + else: + if x is None and y is None and hue is None: + raise ValueError('For 2D inputs, please' + 'specify either hue, x or y.') + + if y is None: + xname, huename = _infer_xy_labels(darray=darray, x=x, y=hue) + yname = darray.name + xplt = darray.coords[xname] + yplt = darray.transpose(xname, huename) + + else: + yname, huename = _infer_xy_labels(darray=darray, x=y, y=hue) + xname = darray.name + xplt = darray.transpose(yname, huename) + yplt = darray.coords[yname] + + hueplt = darray.coords[huename] + huelabel = label_from_attrs(darray[huename]) + + xlabel = 
label_from_attrs(xplt) + ylabel = label_from_attrs(yplt) + + return xplt, yplt, hueplt, xlabel, ylabel, huelabel + + # This function signature should not change so that it can use # matplotlib format strings def line(darray, *args, **kwargs): @@ -182,8 +265,7 @@ def line(darray, *args, **kwargs): Axis on which to plot this figure. By default, use the current axis. Mutually exclusive with ``size`` and ``figsize``. hue : string, optional - Coordinate for which you want multiple lines plotted - (2D DataArrays only). + Coordinate for which you want multiple lines plotted. x, y : string, optional Coordinates for x, y axis. Only one of these may be specified. The other coordinate plots values from the DataArray on which this @@ -201,6 +283,15 @@ def line(darray, *args, **kwargs): """ + # Handle facetgrids first + row = kwargs.pop('row', None) + col = kwargs.pop('col', None) + if row or col: + allargs = locals().copy() + allargs.update(allargs.pop('kwargs')) + allargs.update(allargs.pop('args')) + return _line_facetgrid(**allargs) + ndims = len(darray.dims) if ndims > 2: raise ValueError('Line plots are for 1- or 2-dimensional DataArrays. ' @@ -218,70 +309,28 @@ def line(darray, *args, **kwargs): xincrease = kwargs.pop('xincrease', True) yincrease = kwargs.pop('yincrease', True) add_legend = kwargs.pop('add_legend', True) + _labels = kwargs.pop('_labels', True) ax = get_axis(figsize, size, aspect, ax) - - error_msg = ('must be either None or one of ({0:s})' - .format(', '.join([repr(dd) for dd in darray.dims]))) - - if x is not None and x not in darray.dims: - raise ValueError('x ' + error_msg) - - if y is not None and y not in darray.dims: - raise ValueError('y ' + error_msg) - - if x is not None and y is not None: - raise ValueError('You cannot specify both x and y kwargs' - 'for line plots.') - - if ndims == 1: - dim, = darray.dims # get the only dimension name - - if (x is None and y is None) or x == dim: - xplt = darray.coords[dim] - yplt = darray - - else: - yplt = darray.coords[dim] - xplt = darray - - else: - if x is None and y is None and hue is None: - raise ValueError('For 2D inputs, please specify either hue or x.') - - if y is None: - xlabel, huelabel = _infer_xy_labels(darray=darray, x=x, y=hue) - ylabel = darray.name - xplt = darray.coords[xlabel] - yplt = darray.transpose(xlabel, huelabel) - - else: - ylabel, huelabel = _infer_xy_labels(darray=darray, x=y, y=hue) - xlabel = darray.name - xplt = darray.transpose(ylabel, huelabel) - yplt = darray.coords[ylabel] - - huecoords = darray[huelabel] - huelabel = label_from_attrs(huecoords) - - xlabel = label_from_attrs(xplt) - ylabel = label_from_attrs(yplt) + xplt, yplt, hueplt, xlabel, ylabel, huelabel = \ + _infer_line_data(darray, x, y, hue) _ensure_plottable(xplt) primitive = ax.plot(xplt, yplt, *args, **kwargs) - if xlabel is not None: - ax.set_xlabel(xlabel) + if _labels: + if xlabel is not None: + ax.set_xlabel(xlabel) - if ylabel is not None: - ax.set_ylabel(ylabel) + if ylabel is not None: + ax.set_ylabel(ylabel) - ax.set_title(darray._title_for_slice()) + ax.set_title(darray._title_for_slice()) if darray.ndim == 2 and add_legend: ax.legend(handles=primitive, - labels=list(huecoords.values), + labels=list(hueplt.values), title=huelabel) # Rotate dates on xlabels diff --git a/xarray/tests/test_plot.py b/xarray/tests/test_plot.py index db1fb2fd081..cdb515ba92e 100644 --- a/xarray/tests/test_plot.py +++ b/xarray/tests/test_plot.py @@ -1521,6 +1521,72 @@ def test_default_labels(self): assert substring_in_axes(label, ax) +class 
TestFacetedLinePlots(PlotTestCase): + def setUp(self): + self.darray = DataArray(np.random.randn(10, 6, 3, 4), + dims=['hue', 'x', 'col', 'row'], + coords=[range(10), range(6), + range(3), ['A', 'B', 'C', 'C++']], + name='Cornelius Ortega the 1st') + + def test_facetgrid_shape(self): + g = self.darray.plot(row='row', col='col', hue='hue') + assert g.axes.shape == (len(self.darray.row), len(self.darray.col)) + + g = self.darray.plot(row='col', col='row', hue='hue') + assert g.axes.shape == (len(self.darray.col), len(self.darray.row)) + + def test_default_labels(self): + g = self.darray.plot(row='row', col='col', hue='hue') + # Rightmost column should be labeled + for label, ax in zip(self.darray.coords['row'].values, g.axes[:, -1]): + assert substring_in_axes(label, ax) + + # Top row should be labeled + for label, ax in zip(self.darray.coords['col'].values, g.axes[0, :]): + assert substring_in_axes(str(label), ax) + + # Leftmost column should have array name + for ax in g.axes[:, 0]: + assert substring_in_axes(self.darray.name, ax) + + def test_test_empty_cell(self): + g = self.darray.isel(row=1).drop('row').plot(col='col', + hue='hue', + col_wrap=2) + bottomright = g.axes[-1, -1] + assert not bottomright.has_data() + assert not bottomright.get_visible() + + def test_set_axis_labels(self): + g = self.darray.plot(row='row', col='col', hue='hue') + g.set_axis_labels('longitude', 'latitude') + alltxt = text_in_fig() + + assert 'longitude' in alltxt + assert 'latitude' in alltxt + + def test_both_x_and_y(self): + with pytest.raises(ValueError): + self.darray.plot.line(row='row', col='col', + x='x', y='hue') + + def test_axes_in_faceted_plot(self): + with pytest.raises(ValueError): + self.darray.plot.line(row='row', col='col', + x='x', ax=plt.axes()) + + def test_figsize_and_size(self): + with pytest.raises(ValueError): + self.darray.plot.line(row='row', col='col', + x='x', size=3, figsize=4) + + def test_wrong_num_of_dimensions(self): + with pytest.raises(ValueError): + self.darray.plot(row='row', hue='hue') + self.darray.plot.line(row='row', hue='hue') + + class TestDatetimePlot(PlotTestCase): def setUp(self): ''' From 21a9f3d7e3a5dd729aeafd08dda966c365520965 Mon Sep 17 00:00:00 2001 From: Joe Hamman Date: Thu, 7 Jun 2018 14:02:55 -0400 Subject: [PATCH 52/61] Feature/pickle rasterio (#2131) * add regression test * add PickleByReconstructionWrapper * docs * load in context manager * add distributed integration test * add test_pickle_reconstructor * drop lazy opening/caching and use partial function for open * stop using clever getattr hack * allow_cleanup_failure=ON_WINDOWS in tests for windows * whats new fix * fix bug in multiple pickles * fix for windows --- doc/whats-new.rst | 4 +++ xarray/backends/common.py | 28 ++++++++++++++++ xarray/backends/rasterio_.py | 56 +++++++++++++++++--------------- xarray/tests/test_backends.py | 40 +++++++++++++++++++++-- xarray/tests/test_distributed.py | 18 ++++++++-- 5 files changed, 114 insertions(+), 32 deletions(-) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 5e5da295186..980f996cb6d 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -41,6 +41,9 @@ Enhancements Bug fixes ~~~~~~~~~ +- Fixed a bug in ``rasterio`` backend which prevented use with ``distributed``. + The ``rasterio`` backend now returns pickleable objects (:issue:`2021`). + .. _whats-new.0.10.6: v0.10.6 (31 May 2018) @@ -220,6 +223,7 @@ Bug fixes By `Deepak Cherian `_. - Colorbar limits are now determined by excluding ±Infs too. By `Deepak Cherian `_. + By `Joe Hamman `_. 
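A short round-trip sketch of what this change enables (the GeoTIFF path is a placeholder; this mirrors ``test_pickle_rasterio`` added below):

    import pickle
    import xarray as xr

    # Previously this failed because the open rasterio file handle was not
    # picklable; the wrapper now reopens the file when the object is restored.
    da = xr.open_rasterio('example.tif')
    roundtripped = pickle.loads(pickle.dumps(da))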
- Fixed ``to_iris`` to maintain lazy dask array after conversion (:issue:`2046`). By `Alex Hilson `_ and `Stephan Hoyer `_. diff --git a/xarray/backends/common.py b/xarray/backends/common.py index 2961838e85f..d5eccd9be52 100644 --- a/xarray/backends/common.py +++ b/xarray/backends/common.py @@ -8,6 +8,7 @@ import traceback import warnings from collections import Mapping, OrderedDict +from functools import partial import numpy as np @@ -507,3 +508,30 @@ def assert_open(self): if not self._isopen: raise AssertionError('internal failure: file must be open ' 'if `autoclose=True` is used.') + + +class PickleByReconstructionWrapper(object): + + def __init__(self, opener, file, mode='r', **kwargs): + self.opener = partial(opener, file, mode=mode, **kwargs) + self.mode = mode + self._ds = None + + @property + def value(self): + self._ds = self.opener() + return self._ds + + def __getstate__(self): + state = self.__dict__.copy() + del state['_ds'] + if self.mode == 'w': + # file has already been created, don't override when restoring + state['mode'] = 'a' + return state + + def __setstate__(self, state): + self.__dict__.update(state) + + def close(self): + self._ds.close() diff --git a/xarray/backends/rasterio_.py b/xarray/backends/rasterio_.py index 8c0764c3ec9..0f19a1b51be 100644 --- a/xarray/backends/rasterio_.py +++ b/xarray/backends/rasterio_.py @@ -8,7 +8,7 @@ from .. import DataArray from ..core import indexing from ..core.utils import is_scalar -from .common import BackendArray +from .common import BackendArray, PickleByReconstructionWrapper try: from dask.utils import SerializableLock as Lock @@ -25,15 +25,15 @@ class RasterioArrayWrapper(BackendArray): """A wrapper around rasterio dataset objects""" - def __init__(self, rasterio_ds): - self.rasterio_ds = rasterio_ds - self._shape = (rasterio_ds.count, rasterio_ds.height, - rasterio_ds.width) + def __init__(self, riods): + self.riods = riods + self._shape = (riods.value.count, riods.value.height, + riods.value.width) self._ndims = len(self.shape) @property def dtype(self): - dtypes = self.rasterio_ds.dtypes + dtypes = self.riods.value.dtypes if not np.all(np.asarray(dtypes) == dtypes[0]): raise ValueError('All bands should have the same dtype') return np.dtype(dtypes[0]) @@ -105,7 +105,7 @@ def _get_indexer(self, key): def __getitem__(self, key): band_key, window, squeeze_axis, np_inds = self._get_indexer(key) - out = self.rasterio_ds.read(band_key, window=tuple(window)) + out = self.riods.value.read(band_key, window=tuple(window)) if squeeze_axis: out = np.squeeze(out, axis=squeeze_axis) return indexing.NumpyIndexingAdapter(out)[np_inds] @@ -194,7 +194,8 @@ def open_rasterio(filename, parse_coordinates=None, chunks=None, cache=None, """ import rasterio - riods = rasterio.open(filename, mode='r') + + riods = PickleByReconstructionWrapper(rasterio.open, filename, mode='r') if cache is None: cache = chunks is None @@ -202,20 +203,20 @@ def open_rasterio(filename, parse_coordinates=None, chunks=None, cache=None, coords = OrderedDict() # Get bands - if riods.count < 1: + if riods.value.count < 1: raise ValueError('Unknown dims') - coords['band'] = np.asarray(riods.indexes) + coords['band'] = np.asarray(riods.value.indexes) # Get coordinates if LooseVersion(rasterio.__version__) < '1.0': - transform = riods.affine + transform = riods.value.affine else: - transform = riods.transform + transform = riods.value.transform if transform.is_rectilinear: # 1d coordinates parse = True if parse_coordinates is None else parse_coordinates if parse: - nx, 
ny = riods.width, riods.height + nx, ny = riods.value.width, riods.value.height # xarray coordinates are pixel centered x, _ = (np.arange(nx) + 0.5, np.zeros(nx) + 0.5) * transform _, y = (np.zeros(ny) + 0.5, np.arange(ny) + 0.5) * transform @@ -238,41 +239,42 @@ def open_rasterio(filename, parse_coordinates=None, chunks=None, cache=None, # For serialization store as tuple of 6 floats, the last row being # always (0, 0, 1) per definition (see https://github.com/sgillies/affine) attrs['transform'] = tuple(transform)[:6] - if hasattr(riods, 'crs') and riods.crs: + if hasattr(riods.value, 'crs') and riods.value.crs: # CRS is a dict-like object specific to rasterio # If CRS is not None, we convert it back to a PROJ4 string using # rasterio itself - attrs['crs'] = riods.crs.to_string() - if hasattr(riods, 'res'): + attrs['crs'] = riods.value.crs.to_string() + if hasattr(riods.value, 'res'): # (width, height) tuple of pixels in units of CRS - attrs['res'] = riods.res - if hasattr(riods, 'is_tiled'): + attrs['res'] = riods.value.res + if hasattr(riods.value, 'is_tiled'): # Is the TIF tiled? (bool) # We cast it to an int for netCDF compatibility - attrs['is_tiled'] = np.uint8(riods.is_tiled) + attrs['is_tiled'] = np.uint8(riods.value.is_tiled) with warnings.catch_warnings(): - # casting riods.transform to a tuple makes this future proof + # casting riods.value.transform to a tuple makes this future proof warnings.simplefilter('ignore', FutureWarning) - if hasattr(riods, 'transform'): + if hasattr(riods.value, 'transform'): # Affine transformation matrix (tuple of floats) # Describes coefficients mapping pixel coordinates to CRS - attrs['transform'] = tuple(riods.transform) - if hasattr(riods, 'nodatavals'): + attrs['transform'] = tuple(riods.value.transform) + if hasattr(riods.value, 'nodatavals'): # The nodata values for the raster bands attrs['nodatavals'] = tuple([np.nan if nodataval is None else nodataval - for nodataval in riods.nodatavals]) + for nodataval in riods.value.nodatavals]) # Parse extra metadata from tags, if supported parsers = {'ENVI': _parse_envi} - driver = riods.driver + driver = riods.value.driver if driver in parsers: - meta = parsers[driver](riods.tags(ns=driver)) + meta = parsers[driver](riods.value.tags(ns=driver)) for k, v in meta.items(): # Add values as coordinates if they match the band count, # as attributes otherwise - if isinstance(v, (list, np.ndarray)) and len(v) == riods.count: + if (isinstance(v, (list, np.ndarray)) and + len(v) == riods.value.count): coords[k] = ('band', np.asarray(v)) else: attrs[k] = v diff --git a/xarray/tests/test_backends.py b/xarray/tests/test_backends.py index 0e6151b2db5..df7ed66f4fd 100644 --- a/xarray/tests/test_backends.py +++ b/xarray/tests/test_backends.py @@ -19,7 +19,8 @@ from xarray import ( DataArray, Dataset, backends, open_dataarray, open_dataset, open_mfdataset, save_mfdataset) -from xarray.backends.common import robust_getitem +from xarray.backends.common import (robust_getitem, + PickleByReconstructionWrapper) from xarray.backends.netCDF4_ import _extract_nc4_variable_encoding from xarray.backends.pydap_ import PydapDataStore from xarray.core import indexing @@ -2724,7 +2725,8 @@ def create_tmp_geotiff(nx=4, ny=3, nz=3, # yields a temporary geotiff file and a corresponding expected DataArray import rasterio from rasterio.transform import from_origin - with create_tmp_file(suffix='.tif') as tmp_file: + with create_tmp_file(suffix='.tif', + allow_cleanup_failure=ON_WINDOWS) as tmp_file: # allow 2d or 3d shapes if 
nz == 1: data_shape = ny, nx @@ -2996,6 +2998,14 @@ def test_chunks(self): ex = expected.sel(band=1).mean(dim='x') assert_allclose(ac, ex) + def test_pickle_rasterio(self): + # regression test for https://github.com/pydata/xarray/issues/2121 + with create_tmp_geotiff() as (tmp_file, expected): + with xr.open_rasterio(tmp_file) as rioda: + temp = pickle.dumps(rioda) + with pickle.loads(temp) as actual: + assert_equal(actual, rioda) + def test_ENVI_tags(self): rasterio = pytest.importorskip('rasterio', minversion='1.0a') from rasterio.transform import from_origin @@ -3260,3 +3270,29 @@ def test_dataarray_to_netcdf_no_name_pathlib(self): with open_dataarray(tmp) as loaded_da: assert_identical(original_da, loaded_da) + + +def test_pickle_reconstructor(): + + lines = ['foo bar spam eggs'] + + with create_tmp_file(allow_cleanup_failure=ON_WINDOWS) as tmp: + with open(tmp, 'w') as f: + f.writelines(lines) + + obj = PickleByReconstructionWrapper(open, tmp) + + assert obj.value.readlines() == lines + + p_obj = pickle.dumps(obj) + obj.value.close() # for windows + obj2 = pickle.loads(p_obj) + + assert obj2.value.readlines() == lines + + # roundtrip again to make sure we can fully restore the state + p_obj2 = pickle.dumps(obj2) + obj2.value.close() # for windows + obj3 = pickle.loads(p_obj2) + + assert obj3.value.readlines() == lines diff --git a/xarray/tests/test_distributed.py b/xarray/tests/test_distributed.py index 0ac03327494..8679e892be4 100644 --- a/xarray/tests/test_distributed.py +++ b/xarray/tests/test_distributed.py @@ -17,13 +17,14 @@ from distributed.client import futures_of import xarray as xr -from xarray.tests.test_backends import ON_WINDOWS, create_tmp_file +from xarray.tests.test_backends import (ON_WINDOWS, create_tmp_file, + create_tmp_geotiff) from xarray.tests.test_dataset import create_test_data from xarray.backends.common import HDF5_LOCK, CombinedLock from . import ( - assert_allclose, has_h5netcdf, has_netCDF4, has_scipy, requires_zarr, - raises_regex) + assert_allclose, has_h5netcdf, has_netCDF4, requires_rasterio, has_scipy, + requires_zarr, raises_regex) # this is to stop isort throwing errors. May have been easier to just use # `isort:skip` in retrospect @@ -136,6 +137,17 @@ def test_dask_distributed_zarr_integration_test(loop): assert_allclose(original, computed) +@requires_rasterio +def test_dask_distributed_rasterio_integration_test(loop): + with create_tmp_geotiff() as (tmp_file, expected): + with cluster() as (s, [a, b]): + with Client(s['address'], loop=loop) as c: + da_tiff = xr.open_rasterio(tmp_file, chunks={'band': 1}) + assert isinstance(da_tiff.data, da.Array) + actual = da_tiff.compute() + assert_allclose(actual, expected) + + @pytest.mark.skipif(distributed.__version__ <= '1.19.3', reason='Need recent distributed version to clean up get') @gen_cluster(client=True, timeout=None) From e39729928544204894e65c187d66c1a2b1900fea Mon Sep 17 00:00:00 2001 From: Keisuke Fujii Date: Fri, 8 Jun 2018 09:33:51 +0900 Subject: [PATCH 53/61] implement interp() (#2104) * Start working * interp1d for numpy backed array. * interp1d for dask backed array. * Support scalar interpolation. * more docs * flake8. Remove an unnecessary file. * Remove non-unicode characters * refactoring... * flake8. whats new * Make tests skip if scipy is not installed * skipif -> skip * move skip into every function * remove reuires_scipy * refactoring exceptions. * assert_equal -> assert_allclose * Remove unintended word. * More tests. More docs. * More docs. * Added a benchmark * doc. 
Remove *.png file. * add .load to benchmark with dask. * add assume_sorted kwarg. * Support dimension without coordinate * flake8 * More docs. test for attrs. * Updates based on comments * rename test * update docs * Add transpose for python 2 * More strict ordering * Cleanup * Update doc * Add skipif in tests * minor grammar/language edits in docs * Support dict arguments for interp. * update based on comments * Remove unused if-block * ValueError -> NotImpletedError. Doc improvement * Using OrderedSet * Drop object array after interpolation. * flake8 * Add keep_attrs keyword * flake8 (reverted from commit 6e0099963a50dc622204a690a0058b4db527b8ef) * flake8 * Remove keep_attrs keywords * Returns copy for not-interpolated variable. * Fix docs --- asv_bench/benchmarks/interp.py | 54 ++ .../advanced_selection_interpolation.svg | 731 ++++++++++++++++++ doc/api.rst | 2 + doc/index.rst | 2 + doc/indexing.rst | 2 +- doc/installing.rst | 1 + doc/interpolation.rst | 261 +++++++ doc/whats-new.rst | 10 + xarray/core/computation.py | 2 +- xarray/core/dataarray.py | 46 +- xarray/core/dataset.py | 89 ++- xarray/core/missing.py | 222 +++++- xarray/tests/test_interp.py | 432 +++++++++++ 13 files changed, 1840 insertions(+), 14 deletions(-) create mode 100644 asv_bench/benchmarks/interp.py create mode 100644 doc/_static/advanced_selection_interpolation.svg create mode 100644 doc/interpolation.rst create mode 100644 xarray/tests/test_interp.py diff --git a/asv_bench/benchmarks/interp.py b/asv_bench/benchmarks/interp.py new file mode 100644 index 00000000000..edec6df34dd --- /dev/null +++ b/asv_bench/benchmarks/interp.py @@ -0,0 +1,54 @@ +from __future__ import absolute_import, division, print_function + +import numpy as np +import pandas as pd + +import xarray as xr + +from . 
import parameterized, randn, requires_dask + +nx = 3000 +long_nx = 30000000 +ny = 2000 +nt = 1000 +window = 20 + +randn_xy = randn((nx, ny), frac_nan=0.1) +randn_xt = randn((nx, nt)) +randn_t = randn((nt, )) +randn_long = randn((long_nx, ), frac_nan=0.1) + + +new_x_short = np.linspace(0.3 * nx, 0.7 * nx, 100) +new_x_long = np.linspace(0.3 * nx, 0.7 * nx, 1000) +new_y_long = np.linspace(0.1, 0.9, 1000) + + +class Interpolation(object): + def setup(self, *args, **kwargs): + self.ds = xr.Dataset( + {'var1': (('x', 'y'), randn_xy), + 'var2': (('x', 't'), randn_xt), + 'var3': (('t', ), randn_t)}, + coords={'x': np.arange(nx), + 'y': np.linspace(0, 1, ny), + 't': pd.date_range('1970-01-01', periods=nt, freq='D'), + 'x_coords': ('x', np.linspace(1.1, 2.1, nx))}) + + @parameterized(['method', 'is_short'], + (['linear', 'cubic'], [True, False])) + def time_interpolation(self, method, is_short): + new_x = new_x_short if is_short else new_x_long + self.ds.interp(x=new_x, method=method).load() + + @parameterized(['method'], + (['linear', 'nearest'])) + def time_interpolation_2d(self, method): + self.ds.interp(x=new_x_long, y=new_y_long, method=method).load() + + +class InterpolationDask(Interpolation): + def setup(self, *args, **kwargs): + requires_dask() + super(InterpolationDask, self).setup(**kwargs) + self.ds = self.ds.chunk({'t': 50}) diff --git a/doc/_static/advanced_selection_interpolation.svg b/doc/_static/advanced_selection_interpolation.svg new file mode 100644 index 00000000000..096563a604f --- /dev/null +++ b/doc/_static/advanced_selection_interpolation.svg @@ -0,0 +1,731 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + image/svg+xml + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + y + x + + + + + z + + + + + + + + + + + + + + + + + + + + + + + + + + + + y + x + + + + + z + + + + + + + + + Advanced indexing + Advanced interpolation + + + + diff --git a/doc/api.rst b/doc/api.rst index a528496bb6a..cb44ef82c8f 100644 --- a/doc/api.rst +++ b/doc/api.rst @@ -110,6 +110,7 @@ Indexing Dataset.isel Dataset.sel Dataset.squeeze + Dataset.interp Dataset.reindex Dataset.reindex_like Dataset.set_index @@ -263,6 +264,7 @@ Indexing DataArray.isel DataArray.sel DataArray.squeeze + DataArray.interp DataArray.reindex DataArray.reindex_like DataArray.set_index diff --git a/doc/index.rst b/doc/index.rst index dc00c548b35..7528f3cb1fa 100644 --- a/doc/index.rst +++ b/doc/index.rst @@ -40,6 +40,7 @@ Documentation * :doc:`data-structures` * :doc:`indexing` +* :doc:`interpolation` * :doc:`computation` * :doc:`groupby` * :doc:`reshaping` @@ -57,6 +58,7 @@ Documentation data-structures indexing + interpolation computation groupby reshaping diff --git a/doc/indexing.rst b/doc/indexing.rst index cec438dd2e4..a44e64e4079 100644 --- a/doc/indexing.rst +++ b/doc/indexing.rst @@ -510,7 +510,7 @@ where three elements at ``(ix, iy) = ((0, 0), (1, 1), (6, 0))`` are selected and mapped along a new dimension ``z``. If you want to add a coordinate to the new dimension ``z``, -you can supply a :py:meth:`~xarray.DataArray` with a coordinate, +you can supply a :py:class:`~xarray.DataArray` with a coordinate, .. 
ipython:: python diff --git a/doc/installing.rst b/doc/installing.rst index 33f01b8c770..31fc109ee2e 100644 --- a/doc/installing.rst +++ b/doc/installing.rst @@ -35,6 +35,7 @@ For netCDF and IO For accelerating xarray ~~~~~~~~~~~~~~~~~~~~~~~ +- `scipy `__: necessary to enable the interpolation features for xarray objects - `bottleneck `__: speeds up NaN-skipping and rolling window aggregations by a large factor (1.1 or later) diff --git a/doc/interpolation.rst b/doc/interpolation.rst new file mode 100644 index 00000000000..c5fd5166aeb --- /dev/null +++ b/doc/interpolation.rst @@ -0,0 +1,261 @@ +.. _interp: + +Interpolating data +================== + +.. ipython:: python + :suppress: + + import numpy as np + import pandas as pd + import xarray as xr + np.random.seed(123456) + +xarray offers flexible interpolation routines, which have a similar interface +to our :ref:`indexing `. + +.. note:: + + ``interp`` requires `scipy` installed. + + +Scalar and 1-dimensional interpolation +-------------------------------------- + +Interpolating a :py:class:`~xarray.DataArray` works mostly like labeled +indexing of a :py:class:`~xarray.DataArray`, + +.. ipython:: python + + da = xr.DataArray(np.sin(0.3 * np.arange(12).reshape(4, 3)), + [('time', np.arange(4)), + ('space', [0.1, 0.2, 0.3])]) + # label lookup + da.sel(time=3) + + # interpolation + da.interp(time=3.5) + + +Similar to the indexing, :py:meth:`~xarray.DataArray.interp` also accepts an +array-like, which gives the interpolated result as an array. + +.. ipython:: python + + # label lookup + da.sel(time=[2, 3]) + + # interpolation + da.interp(time=[2.5, 3.5]) + +.. note:: + + Currently, our interpolation only works for regular grids. + Therefore, similarly to :py:meth:`~xarray.DataArray.sel`, + only 1D coordinates along a dimension can be used as the + original coordinate to be interpolated. + + +Multi-dimensional Interpolation +------------------------------- + +Like :py:meth:`~xarray.DataArray.sel`, :py:meth:`~xarray.DataArray.interp` +accepts multiple coordinates. In this case, multidimensional interpolation +is carried out. + +.. ipython:: python + + # label lookup + da.sel(time=2, space=0.1) + + # interpolation + da.interp(time=2.5, space=0.15) + +Array-like coordinates are also accepted: + +.. ipython:: python + + # label lookup + da.sel(time=[2, 3], space=[0.1, 0.2]) + + # interpolation + da.interp(time=[1.5, 2.5], space=[0.15, 0.25]) + + +Interpolation methods +--------------------- + +We use :py:func:`scipy.interpolate.interp1d` for 1-dimensional interpolation and +:py:func:`scipy.interpolate.interpn` for multi-dimensional interpolation. + +The interpolation method can be specified by the optional ``method`` argument. + +.. ipython:: python + + da = xr.DataArray(np.sin(np.linspace(0, 2 * np.pi, 10)), dims='x', + coords={'x': np.linspace(0, 1, 10)}) + + da.plot.line('o', label='original') + da.interp(x=np.linspace(0, 1, 100)).plot.line(label='linear (default)') + da.interp(x=np.linspace(0, 1, 100), method='cubic').plot.line(label='cubic') + @savefig interpolation_sample1.png width=4in + plt.legend() + +Additional keyword arguments can be passed to scipy's functions. + +.. ipython:: python + + # fill 0 for the outside of the original coordinates. 
+ da.interp(x=np.linspace(-0.5, 1.5, 10), kwargs={'fill_value': 0.0}) + # extrapolation + da.interp(x=np.linspace(-0.5, 1.5, 10), kwargs={'fill_value': 'extrapolate'}) + + +Advanced Interpolation +---------------------- + +:py:meth:`~xarray.DataArray.interp` accepts :py:class:`~xarray.DataArray` +as similar to :py:meth:`~xarray.DataArray.sel`, which enables us more advanced interpolation. +Based on the dimension of the new coordinate passed to :py:meth:`~xarray.DataArray.interp`, the dimension of the result are determined. + +For example, if you want to interpolate a two dimensional array along a particular dimension, as illustrated below, +you can pass two 1-dimensional :py:class:`~xarray.DataArray` s with +a common dimension as new coordinate. + +.. image:: _static/advanced_selection_interpolation.svg + :height: 200px + :width: 400 px + :alt: advanced indexing and interpolation + :align: center + +For example: + +.. ipython:: python + + da = xr.DataArray(np.sin(0.3 * np.arange(20).reshape(5, 4)), + [('x', np.arange(5)), + ('y', [0.1, 0.2, 0.3, 0.4])]) + # advanced indexing + x = xr.DataArray([0, 2, 4], dims='z') + y = xr.DataArray([0.1, 0.2, 0.3], dims='z') + da.sel(x=x, y=y) + + # advanced interpolation + x = xr.DataArray([0.5, 1.5, 2.5], dims='z') + y = xr.DataArray([0.15, 0.25, 0.35], dims='z') + da.interp(x=x, y=y) + +where values on the original coordinates +``(x, y) = ((0.5, 0.15), (1.5, 0.25), (2.5, 0.35))`` are obtained by the +2-dimensional interpolation and mapped along a new dimension ``z``. + +If you want to add a coordinate to the new dimension ``z``, you can supply +:py:class:`~xarray.DataArray` s with a coordinate, + +.. ipython:: python + + x = xr.DataArray([0.5, 1.5, 2.5], dims='z', coords={'z': ['a', 'b','c']}) + y = xr.DataArray([0.15, 0.25, 0.35], dims='z', + coords={'z': ['a', 'b','c']}) + da.interp(x=x, y=y) + +For the details of the advanced indexing, +see :ref:`more advanced indexing `. + + +Interpolating arrays with NaN +----------------------------- + +Our :py:meth:`~xarray.DataArray.interp` works with arrays with NaN +the same way that +`scipy.interpolate.interp1d `_ and +`scipy.interpolate.interpn `_ do. +``linear`` and ``nearest`` methods return arrays including NaN, +while other methods such as ``cubic`` or ``quadratic`` return all NaN arrays. + +.. ipython:: python + + da = xr.DataArray([0, 2, np.nan, 3, 3.25], dims='x', + coords={'x': range(5)}) + da.interp(x=[0.5, 1.5, 2.5]) + da.interp(x=[0.5, 1.5, 2.5], method='cubic') + +To avoid this, you can drop NaN by :py:meth:`~xarray.DataArray.dropna`, and +then make the interpolation + +.. ipython:: python + + dropped = da.dropna('x') + dropped + dropped.interp(x=[0.5, 1.5, 2.5], method='cubic') + +If NaNs are distributed rondomly in your multidimensional array, +dropping all the columns containing more than one NaNs by +:py:meth:`~xarray.DataArray.dropna` may lose a significant amount of information. +In such a case, you can fill NaN by :py:meth:`~xarray.DataArray.interpolate_na`, +which is similar to :py:meth:`pandas.Series.interpolate`. + +.. ipython:: python + + filled = da.interpolate_na(dim='x') + filled + +This fills NaN by interpolating along the specified dimension. +After filling NaNs, you can interpolate: + +.. ipython:: python + + filled.interp(x=[0.5, 1.5, 2.5], method='cubic') + +For the details of :py:meth:`~xarray.DataArray.interpolate_na`, +see :ref:`Missing values `. + + +Example +------- + +Let's see how :py:meth:`~xarray.DataArray.interp` works on real data. + +.. 
ipython:: python + + # Raw data + ds = xr.tutorial.load_dataset('air_temperature') + fig, axes = plt.subplots(ncols=2, figsize=(10, 4)) + ds.air.isel(time=0).plot(ax=axes[0]) + axes[0].set_title('Raw data') + + # Interpolated data + new_lon = np.linspace(ds.lon[0], ds.lon[-1], ds.dims['lon'] * 4) + new_lat = np.linspace(ds.lat[0], ds.lat[-1], ds.dims['lat'] * 4) + dsi = ds.interp(lat=new_lat, lon=new_lon) + dsi.air.isel(time=0).plot(ax=axes[1]) + @savefig interpolation_sample3.png width=8in + axes[1].set_title('Interpolated data') + +Our advanced interpolation can be used to remap the data to the new coordinate. +Consider the new coordinates x and z on the two dimensional plane. +The remapping can be done as follows + +.. ipython:: python + + # new coordinate + x = np.linspace(240, 300, 100) + z = np.linspace(20, 70, 100) + # relation between new and original coordinates + lat = xr.DataArray(z, dims=['z'], coords={'z': z}) + lon = xr.DataArray((x[:, np.newaxis]-270)/np.cos(z*np.pi/180)+270, + dims=['x', 'z'], coords={'x': x, 'z': z}) + + fig, axes = plt.subplots(ncols=2, figsize=(10, 4)) + ds.air.isel(time=0).plot(ax=axes[0]) + # draw the new coordinate on the original coordinates. + for idx in [0, 33, 66, 99]: + axes[0].plot(lon.isel(x=idx), lat, '--k') + for idx in [0, 33, 66, 99]: + axes[0].plot(*xr.broadcast(lon.isel(z=idx), lat.isel(z=idx)), '--k') + axes[0].set_title('Raw data') + + dsi = ds.interp(lon=lon, lat=lat) + dsi.air.isel(time=0).plot(ax=axes[1]) + @savefig interpolation_sample4.png width=8in + axes[1].set_title('Remapped data') diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 980f996cb6d..44f829874ac 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -38,6 +38,15 @@ Enhancements - Plot labels now make use of metadata that follow CF conventions. By `Deepak Cherian `_ and `Ryan Abernathey `_. +- :py:meth:`~xarray.DataArray.interp` and :py:meth:`~xarray.Dataset.interp` + methods are newly added. + See :ref:`interpolating values with interp` for the detail. + (:issue:`2079`) + By `Keisuke Fujii `_. + +- `:py:meth:`~DataArray.dot` and :py:func:`~dot` are partly supported with older + dask<0.17.4. (related to :issue:`2203`) + By `Keisuke Fujii `_. + Bug fixes ~~~~~~~~~ diff --git a/xarray/core/computation.py b/xarray/core/computation.py index 6a49610cb7b..9b251bb2c4b 100644 --- a/xarray/core/computation.py +++ b/xarray/core/computation.py @@ -10,7 +10,7 @@ import numpy as np -from . import duck_array_ops, utils, dtypes +from . import duck_array_ops, utils from .alignment import deep_align from .merge import expand_and_merge_variables from .pycompat import OrderedDict, dask_array_type, basestring diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index fd2b49cc08a..4129a3c5f26 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -906,10 +906,54 @@ def reindex(self, indexers=None, method=None, tolerance=None, copy=True, indexers=indexers, method=method, tolerance=tolerance, copy=copy) return self._from_temp_dataset(ds) + def interp(self, coords=None, method='linear', assume_sorted=False, + kwargs={}, **coords_kwargs): + """ Multidimensional interpolation of variables. + + coords : dict, optional + Mapping from dimension names to the new coordinates. + new coordinate can be an scalar, array-like or DataArray. + If DataArrays are passed as new coordates, their dimensions are + used for the broadcasting. 
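As a quick illustration of the broadcasting behaviour described above, passing 1-d DataArrays that share a new dimension interpolates pointwise and returns a result on that dimension. This sketch mirrors the example in doc/interpolation.rst and assumes scipy is installed:

    import numpy as np
    import xarray as xr

    da = xr.DataArray(np.random.randn(5, 4), dims=['x', 'y'],
                      coords={'x': np.arange(5), 'y': [0.1, 0.2, 0.3, 0.4]})
    xdest = xr.DataArray([0.5, 1.5, 2.5], dims='z')
    ydest = xr.DataArray([0.15, 0.25, 0.35], dims='z')
    da.interp(x=xdest, y=ydest)   # result has dims ('z',), one value per (x, y) pair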
+ method: {'linear', 'nearest'} for multidimensional array, + {'linear', 'nearest', 'zero', 'slinear', 'quadratic', 'cubic'} + for 1-dimensional array. + assume_sorted: boolean, optional + If False, values of x can be in any order and they are sorted + first. If True, x has to be an array of monotonically increasing + values. + kwargs: dictionary + Additional keyword passed to scipy's interpolator. + **coords_kwarg : {dim: coordinate, ...}, optional + The keyword arguments form of ``coords``. + One of coords or coords_kwargs must be provided. + + Returns + ------- + interpolated: xr.DataArray + New dataarray on the new coordinates. + + Note + ---- + scipy is required. + + See Also + -------- + scipy.interpolate.interp1d + scipy.interpolate.interpn + """ + if self.dtype.kind not in 'uifc': + raise TypeError('interp only works for a numeric type array. ' + 'Given {}.'.format(self.dtype)) + + ds = self._to_temp_dataset().interp( + coords, method=method, kwargs=kwargs, assume_sorted=assume_sorted, + **coords_kwargs) + return self._from_temp_dataset(ds) + def rename(self, new_name_or_name_dict=None, **names): """Returns a new DataArray with renamed coordinates or a new name. - Parameters ---------- new_name_or_name_dict : str or dict-like, optional diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 08f5f70d72b..90712c953da 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -1318,7 +1318,7 @@ def _validate_indexers(self, indexers): # all indexers should be int, slice, np.ndarrays, or Variable indexers_list = [] for k, v in iteritems(indexers): - if isinstance(v, integer_types + (slice, Variable)): + if isinstance(v, (slice, Variable)): pass elif isinstance(v, DataArray): v = v.variable @@ -1328,6 +1328,14 @@ def _validate_indexers(self, indexers): raise TypeError('cannot use a Dataset as an indexer') else: v = np.asarray(v) + if v.ndim == 0: + v = as_variable(v) + elif v.ndim == 1: + v = as_variable((k, v)) + else: + raise IndexError( + "Unlabeled multi-dimensional array cannot be " + "used for indexing: {}".format(k)) indexers_list.append((k, v)) return indexers_list @@ -1806,6 +1814,85 @@ def reindex(self, indexers=None, method=None, tolerance=None, copy=True, coord_names.update(indexers) return self._replace_vars_and_dims(variables, coord_names) + def interp(self, coords=None, method='linear', assume_sorted=False, + kwargs={}, **coords_kwargs): + """ Multidimensional interpolation of Dataset. + + Parameters + ---------- + coords : dict, optional + Mapping from dimension names to the new coordinates. + New coordinate can be a scalar, array-like or DataArray. + If DataArrays are passed as new coordates, their dimensions are + used for the broadcasting. + method: string, optional. + {'linear', 'nearest'} for multidimensional array, + {'linear', 'nearest', 'zero', 'slinear', 'quadratic', 'cubic'} + for 1-dimensional array. 'linear' is used by default. + assume_sorted: boolean, optional + If False, values of coordinates that are interpolated over can be + in any order and they are sorted first. If True, interpolated + coordinates are assumed to be an array of monotonically increasing + values. + kwargs: dictionary, optional + Additional keyword passed to scipy's interpolator. + **coords_kwarg : {dim: coordinate, ...}, optional + The keyword arguments form of ``coords``. + One of coords or coords_kwargs must be provided. + + Returns + ------- + interpolated: xr.Dataset + New dataset on the new coordinates. + + Note + ---- + scipy is required. 
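A brief usage sketch for the new method on a hand-built dataset (assuming scipy is installed):

    import numpy as np
    import xarray as xr

    ds = xr.Dataset({'var1': (('x', 'y'), np.random.randn(3, 4))},
                    coords={'x': [0.0, 0.5, 1.0], 'y': [0, 1, 2, 3]})
    ds.interp(x=[0.25, 0.75])                             # keyword form
    ds.interp({'x': [0.25, 0.75]}, method='linear')       # equivalent dict form
    ds.interp(x=[-0.5, 1.5], kwargs={'fill_value': 0.0})  # fill out-of-range points with 0.0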
+ + See Also + -------- + scipy.interpolate.interp1d + scipy.interpolate.interpn + """ + from . import missing + + coords = either_dict_or_kwargs(coords, coords_kwargs, 'rename') + indexers = OrderedDict(self._validate_indexers(coords)) + + obj = self if assume_sorted else self.sortby([k for k in coords]) + + def maybe_variable(obj, k): + # workaround to get variable for dimension without coordinate. + try: + return obj._variables[k] + except KeyError: + return as_variable((k, range(obj.dims[k]))) + + variables = OrderedDict() + for name, var in iteritems(obj._variables): + if name not in indexers: + if var.dtype.kind in 'uifc': + var_indexers = {k: (maybe_variable(obj, k), v) for k, v + in indexers.items() if k in var.dims} + variables[name] = missing.interp( + var, var_indexers, method, **kwargs) + elif all(d not in indexers for d in var.dims): + # keep unrelated object array + variables[name] = var + + coord_names = set(variables).intersection(obj._coord_names) + selected = obj._replace_vars_and_dims(variables, + coord_names=coord_names) + # attach indexer as coordinate + variables.update(indexers) + # Extract coordinates from indexers + coord_vars = selected._get_indexers_coordinates(coords) + variables.update(coord_vars) + coord_names = (set(variables) + .intersection(obj._coord_names) + .union(coord_vars)) + return obj._replace_vars_and_dims(variables, coord_names=coord_names) + def rename(self, name_dict=None, inplace=False, **names): """Returns a new object with renamed variables and dimensions. diff --git a/xarray/core/missing.py b/xarray/core/missing.py index 0da6750f5bc..e10f37d58d8 100644 --- a/xarray/core/missing.py +++ b/xarray/core/missing.py @@ -10,7 +10,9 @@ from .computation import apply_ufunc from .npcompat import flip from .pycompat import iteritems -from .utils import is_scalar +from .utils import is_scalar, OrderedSet +from .variable import Variable, broadcast_variables +from .duck_array_ops import dask_array_type class BaseInterpolator(object): @@ -203,7 +205,8 @@ def interp_na(self, dim=None, use_coordinate=True, method='linear', limit=None, # method index = get_clean_interp_index(self, dim, use_coordinate=use_coordinate, **kwargs) - interpolator = _get_interpolator(method, **kwargs) + interp_class, kwargs = _get_interpolator(method, **kwargs) + interpolator = partial(func_interpolate_na, interp_class, **kwargs) arr = apply_ufunc(interpolator, index, self, input_core_dims=[[dim], [dim]], @@ -219,7 +222,7 @@ def interp_na(self, dim=None, use_coordinate=True, method='linear', limit=None, return arr -def wrap_interpolator(interpolator, x, y, **kwargs): +def func_interpolate_na(interpolator, x, y, **kwargs): '''helper function to apply interpolation along 1 dimension''' # it would be nice if this wasn't necessary, works around: # "ValueError: assignment destination is read-only" in assignment below @@ -281,29 +284,41 @@ def bfill(arr, dim=None, limit=None): kwargs=dict(n=_limit, axis=axis)).transpose(*arr.dims) -def _get_interpolator(method, **kwargs): +def _get_interpolator(method, vectorizeable_only=False, **kwargs): '''helper function to select the appropriate interpolator class - returns a partial of wrap_interpolator + returns interpolator class and keyword arguments for the class ''' interp1d_methods = ['linear', 'nearest', 'zero', 'slinear', 'quadratic', 'cubic', 'polynomial'] valid_methods = interp1d_methods + ['barycentric', 'krog', 'pchip', 'spline', 'akima'] + has_scipy = True + try: + from scipy import interpolate + except ImportError: + has_scipy = False + 
+ # prioritize scipy.interpolate if (method == 'linear' and not - kwargs.get('fill_value', None) == 'extrapolate'): + kwargs.get('fill_value', None) == 'extrapolate' and + not vectorizeable_only): kwargs.update(method=method) interp_class = NumpyInterpolator + elif method in valid_methods: - try: - from scipy import interpolate - except ImportError: + if not has_scipy: raise ImportError( 'Interpolation with method `%s` requires scipy' % method) + if method in interp1d_methods: kwargs.update(method=method) interp_class = ScipyInterpolator + elif vectorizeable_only: + raise ValueError('{} is not a vectorizeable interpolator. ' + 'Available methods are {}'.format( + method, interp1d_methods)) elif method == 'barycentric': interp_class = interpolate.BarycentricInterpolator elif method == 'krog': @@ -320,7 +335,30 @@ def _get_interpolator(method, **kwargs): else: raise ValueError('%s is not a valid interpolator' % method) - return partial(wrap_interpolator, interp_class, **kwargs) + return interp_class, kwargs + + +def _get_interpolator_nd(method, **kwargs): + '''helper function to select the appropriate interpolator class + + returns interpolator class and keyword arguments for the class + ''' + valid_methods = ['linear', 'nearest'] + + try: + from scipy import interpolate + except ImportError: + raise ImportError( + 'Interpolation with method `%s` requires scipy' % method) + + if method in valid_methods: + kwargs.update(method=method) + interp_class = interpolate.interpn + else: + raise ValueError('%s is not a valid interpolator for interpolating ' + 'over multiple dimensions.' % method) + + return interp_class, kwargs def _get_valid_fill_mask(arr, dim, limit): @@ -332,3 +370,167 @@ def _get_valid_fill_mask(arr, dim, limit): return (arr.isnull().rolling(min_periods=1, **kw) .construct(new_dim, fill_value=False) .sum(new_dim, skipna=False)) <= limit + + +def _assert_single_chunk(var, axes): + for axis in axes: + if len(var.chunks[axis]) > 1 or var.chunks[axis][0] < var.shape[axis]: + raise NotImplementedError( + 'Chunking along the dimension to be interpolated ' + '({}) is not yet supported.'.format(axis)) + + +def _localize(var, indexes_coords): + """ Speed up for linear and nearest neighbor method. + Only consider a subspace that is needed for the interpolation + """ + indexes = {} + for dim, [x, new_x] in indexes_coords.items(): + index = x.to_index() + imin = index.get_loc(np.min(new_x.values), method='nearest') + imax = index.get_loc(np.max(new_x.values), method='nearest') + + indexes[dim] = slice(max(imin - 2, 0), imax + 2) + indexes_coords[dim] = (x[indexes[dim]], new_x) + return var.isel(**indexes), indexes_coords + + +def interp(var, indexes_coords, method, **kwargs): + """ Make an interpolation of Variable + + Parameters + ---------- + var: Variable + index_coords: + Mapping from dimension name to a pair of original and new coordinates. + Original coordinates should be sorted in strictly ascending order. + Note that all the coordinates should be Variable objects. + method: string + One of {'linear', 'nearest', 'zero', 'slinear', 'quadratic', + 'cubic'}. For multidimensional interpolation, only + {'linear', 'nearest'} can be used. 
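One practical limitation worth illustrating: interpolating dask-backed data requires the interpolated dimension to lie in a single chunk, which is what ``_assert_single_chunk`` above enforces. A small sketch, assuming dask and scipy are installed:

    import numpy as np
    import xarray as xr

    da = xr.DataArray(np.random.randn(100, 30), dims=['x', 'y'],
                      coords={'x': np.linspace(0, 1, 100),
                              'y': np.linspace(0, 0.1, 30)})
    chunked = da.chunk({'y': 3})                 # 'x' stays in a single chunk
    chunked.interp(x=[0.25, 0.75])               # fine: interpolated dim is unchunked
    # chunked.interp(y=[0.05])                   # raises NotImplementedError
    chunked.chunk({'y': 30}).interp(y=[0.05])    # rechunk along 'y' first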
+ **kwargs: + keyword arguments to be passed to scipy.interpolate + + Returns + ------- + Interpolated Variable + + See Also + -------- + DataArray.interp + Dataset.interp + """ + if not indexes_coords: + return var.copy() + + # simple speed up for the local interpolation + if method in ['linear', 'nearest']: + var, indexes_coords = _localize(var, indexes_coords) + + # default behavior + kwargs['bounds_error'] = kwargs.get('bounds_error', False) + + # target dimensions + dims = list(indexes_coords) + x, new_x = zip(*[indexes_coords[d] for d in dims]) + destination = broadcast_variables(*new_x) + + # transpose to make the interpolated axis to the last position + broadcast_dims = [d for d in var.dims if d not in dims] + original_dims = broadcast_dims + dims + new_dims = broadcast_dims + list(destination[0].dims) + interped = interp_func(var.transpose(*original_dims).data, + x, destination, method, kwargs) + + result = Variable(new_dims, interped, attrs=var.attrs) + + # dimension of the output array + out_dims = OrderedSet() + for d in var.dims: + if d in dims: + out_dims.update(indexes_coords[d][1].dims) + else: + out_dims.add(d) + return result.transpose(*tuple(out_dims)) + + +def interp_func(var, x, new_x, method, kwargs): + """ + multi-dimensional interpolation for array-like. Interpolated axes should be + located in the last position. + + Parameters + ---------- + var: np.ndarray or dask.array.Array + Array to be interpolated. The final dimension is interpolated. + x: a list of 1d array. + Original coordinates. Should not contain NaN. + new_x: a list of 1d array + New coordinates. Should not contain NaN. + method: string + {'linear', 'nearest', 'zero', 'slinear', 'quadratic', 'cubic'} for + 1-dimensional itnterpolation. + {'linear', 'nearest'} for multidimensional interpolation + **kwargs: + Optional keyword arguments to be passed to scipy.interpolator + + Returns + ------- + interpolated: array + Interpolated array + + Note + ---- + This requiers scipy installed. + + See Also + -------- + scipy.interpolate.interp1d + """ + if not x: + return var.copy() + + if len(x) == 1: + func, kwargs = _get_interpolator(method, vectorizeable_only=True, + **kwargs) + else: + func, kwargs = _get_interpolator_nd(method, **kwargs) + + if isinstance(var, dask_array_type): + import dask.array as da + + _assert_single_chunk(var, range(var.ndim - len(x), var.ndim)) + chunks = var.chunks[:-len(x)] + new_x[0].shape + drop_axis = range(var.ndim - len(x), var.ndim) + new_axis = range(var.ndim - len(x), var.ndim - len(x) + new_x[0].ndim) + return da.map_blocks(_interpnd, var, x, new_x, func, kwargs, + dtype=var.dtype, chunks=chunks, + new_axis=new_axis, drop_axis=drop_axis) + + return _interpnd(var, x, new_x, func, kwargs) + + +def _interp1d(var, x, new_x, func, kwargs): + # x, new_x are tuples of size 1. 
+ x, new_x = x[0], new_x[0] + rslt = func(x, var, assume_sorted=True, **kwargs)(np.ravel(new_x)) + if new_x.ndim > 1: + return rslt.reshape(var.shape[:-1] + new_x.shape) + if new_x.ndim == 0: + return rslt[..., -1] + return rslt + + +def _interpnd(var, x, new_x, func, kwargs): + if len(x) == 1: + return _interp1d(var, x, new_x, func, kwargs) + + # move the interpolation axes to the start position + var = var.transpose(range(-len(x), var.ndim - len(x))) + # stack new_x to 1 vector, with reshape + xi = np.stack([x1.values.ravel() for x1 in new_x], axis=-1) + rslt = func(x, var, xi, **kwargs) + # move back the interpolation axes to the last position + rslt = rslt.transpose(range(-rslt.ndim + 1, 1)) + return rslt.reshape(rslt.shape[:-1] + new_x[0].shape) diff --git a/xarray/tests/test_interp.py b/xarray/tests/test_interp.py new file mode 100644 index 00000000000..592854a4d1b --- /dev/null +++ b/xarray/tests/test_interp.py @@ -0,0 +1,432 @@ +from __future__ import absolute_import, division, print_function + +import numpy as np +import pytest + +import xarray as xr +from xarray.tests import assert_allclose, assert_equal, requires_scipy +from . import has_dask, has_scipy +from .test_dataset import create_test_data + +try: + import scipy +except ImportError: + pass + + +def get_example_data(case): + x = np.linspace(0, 1, 100) + y = np.linspace(0, 0.1, 30) + data = xr.DataArray( + np.sin(x[:, np.newaxis]) * np.cos(y), dims=['x', 'y'], + coords={'x': x, 'y': y, 'x2': ('x', x**2)}) + + if case == 0: + return data + elif case == 1: + return data.chunk({'y': 3}) + elif case == 2: + return data.chunk({'x': 25, 'y': 3}) + elif case == 3: + x = np.linspace(0, 1, 100) + y = np.linspace(0, 0.1, 30) + z = np.linspace(0.1, 0.2, 10) + return xr.DataArray( + np.sin(x[:, np.newaxis, np.newaxis]) * np.cos( + y[:, np.newaxis]) * z, + dims=['x', 'y', 'z'], + coords={'x': x, 'y': y, 'x2': ('x', x**2), 'z': z}) + elif case == 4: + return get_example_data(3).chunk({'z': 5}) + + +def test_keywargs(): + if not has_scipy: + pytest.skip('scipy is not installed.') + + da = get_example_data(0) + assert_equal(da.interp(x=[0.5, 0.8]), da.interp({'x': [0.5, 0.8]})) + + +@pytest.mark.parametrize('method', ['linear', 'cubic']) +@pytest.mark.parametrize('dim', ['x', 'y']) +@pytest.mark.parametrize('case', [0, 1]) +def test_interpolate_1d(method, dim, case): + if not has_scipy: + pytest.skip('scipy is not installed.') + + if not has_dask and case in [1]: + pytest.skip('dask is not installed in the environment.') + + da = get_example_data(case) + xdest = np.linspace(0.0, 0.9, 80) + + if dim == 'y' and case == 1: + with pytest.raises(NotImplementedError): + actual = da.interp(method=method, **{dim: xdest}) + pytest.skip('interpolation along chunked dimension is ' + 'not yet supported') + + actual = da.interp(method=method, **{dim: xdest}) + + # scipy interpolation for the reference + def func(obj, new_x): + return scipy.interpolate.interp1d( + da[dim], obj.data, axis=obj.get_axis_num(dim), bounds_error=False, + fill_value=np.nan, kind=method)(new_x) + + if dim == 'x': + coords = {'x': xdest, 'y': da['y'], 'x2': ('x', func(da['x2'], xdest))} + else: # y + coords = {'x': da['x'], 'y': xdest, 'x2': da['x2']} + + expected = xr.DataArray(func(da, xdest), dims=['x', 'y'], coords=coords) + assert_allclose(actual, expected) + + +@pytest.mark.parametrize('method', ['cubic', 'zero']) +def test_interpolate_1d_methods(method): + if not has_scipy: + pytest.skip('scipy is not installed.') + + da = get_example_data(0) + dim = 'x' + xdest = 
np.linspace(0.0, 0.9, 80) + + actual = da.interp(method=method, **{dim: xdest}) + + # scipy interpolation for the reference + def func(obj, new_x): + return scipy.interpolate.interp1d( + da[dim], obj.data, axis=obj.get_axis_num(dim), bounds_error=False, + fill_value=np.nan, kind=method)(new_x) + + coords = {'x': xdest, 'y': da['y'], 'x2': ('x', func(da['x2'], xdest))} + expected = xr.DataArray(func(da, xdest), dims=['x', 'y'], coords=coords) + assert_allclose(actual, expected) + + +@pytest.mark.parametrize('use_dask', [False, True]) +def test_interpolate_vectorize(use_dask): + if not has_scipy: + pytest.skip('scipy is not installed.') + + if not has_dask and use_dask: + pytest.skip('dask is not installed in the environment.') + + # scipy interpolation for the reference + def func(obj, dim, new_x): + shape = [s for i, s in enumerate(obj.shape) + if i != obj.get_axis_num(dim)] + for s in new_x.shape[::-1]: + shape.insert(obj.get_axis_num(dim), s) + + return scipy.interpolate.interp1d( + da[dim], obj.data, axis=obj.get_axis_num(dim), + bounds_error=False, fill_value=np.nan)(new_x).reshape(shape) + + da = get_example_data(0) + if use_dask: + da = da.chunk({'y': 5}) + + # xdest is 1d but has different dimension + xdest = xr.DataArray(np.linspace(0.1, 0.9, 30), dims='z', + coords={'z': np.random.randn(30), + 'z2': ('z', np.random.randn(30))}) + + actual = da.interp(x=xdest, method='linear') + + expected = xr.DataArray(func(da, 'x', xdest), dims=['z', 'y'], + coords={'z': xdest['z'], 'z2': xdest['z2'], + 'y': da['y'], + 'x': ('z', xdest.values), + 'x2': ('z', func(da['x2'], 'x', xdest))}) + assert_allclose(actual, expected.transpose('z', 'y')) + + # xdest is 2d + xdest = xr.DataArray(np.linspace(0.1, 0.9, 30).reshape(6, 5), + dims=['z', 'w'], + coords={'z': np.random.randn(6), + 'w': np.random.randn(5), + 'z2': ('z', np.random.randn(6))}) + + actual = da.interp(x=xdest, method='linear') + + expected = xr.DataArray( + func(da, 'x', xdest), + dims=['z', 'w', 'y'], + coords={'z': xdest['z'], 'w': xdest['w'], 'z2': xdest['z2'], + 'y': da['y'], 'x': (('z', 'w'), xdest), + 'x2': (('z', 'w'), func(da['x2'], 'x', xdest))}) + assert_allclose(actual, expected.transpose('z', 'w', 'y')) + + +@pytest.mark.parametrize('case', [3, 4]) +def test_interpolate_nd(case): + if not has_scipy: + pytest.skip('scipy is not installed.') + + if not has_dask and case == 4: + pytest.skip('dask is not installed in the environment.') + + da = get_example_data(case) + + # grid -> grid + xdest = np.linspace(0.1, 1.0, 11) + ydest = np.linspace(0.0, 0.2, 10) + actual = da.interp(x=xdest, y=ydest, method='linear') + + # linear interpolation is separateable + expected = da.interp(x=xdest, method='linear') + expected = expected.interp(y=ydest, method='linear') + assert_allclose(actual.transpose('x', 'y', 'z'), + expected.transpose('x', 'y', 'z')) + + # grid -> 1d-sample + xdest = xr.DataArray(np.linspace(0.1, 1.0, 11), dims='y') + ydest = xr.DataArray(np.linspace(0.0, 0.2, 11), dims='y') + actual = da.interp(x=xdest, y=ydest, method='linear') + + # linear interpolation is separateable + expected_data = scipy.interpolate.RegularGridInterpolator( + (da['x'], da['y']), da.transpose('x', 'y', 'z').values, + method='linear', bounds_error=False, + fill_value=np.nan)(np.stack([xdest, ydest], axis=-1)) + expected = xr.DataArray( + expected_data, dims=['y', 'z'], + coords={'z': da['z'], 'y': ydest, 'x': ('y', xdest.values), + 'x2': da['x2'].interp(x=xdest)}) + assert_allclose(actual.transpose('y', 'z'), expected) + + # reversed order + actual 
= da.interp(y=ydest, x=xdest, method='linear') + assert_allclose(actual.transpose('y', 'z'), expected) + + +@pytest.mark.parametrize('method', ['linear']) +@pytest.mark.parametrize('case', [0, 1]) +def test_interpolate_scalar(method, case): + if not has_scipy: + pytest.skip('scipy is not installed.') + + if not has_dask and case in [1]: + pytest.skip('dask is not installed in the environment.') + + da = get_example_data(case) + xdest = 0.4 + + actual = da.interp(x=xdest, method=method) + + # scipy interpolation for the reference + def func(obj, new_x): + return scipy.interpolate.interp1d( + da['x'], obj.data, axis=obj.get_axis_num('x'), bounds_error=False, + fill_value=np.nan)(new_x) + + coords = {'x': xdest, 'y': da['y'], 'x2': func(da['x2'], xdest)} + expected = xr.DataArray(func(da, xdest), dims=['y'], coords=coords) + assert_allclose(actual, expected) + + +@pytest.mark.parametrize('method', ['linear']) +@pytest.mark.parametrize('case', [3, 4]) +def test_interpolate_nd_scalar(method, case): + if not has_scipy: + pytest.skip('scipy is not installed.') + + if not has_dask and case in [4]: + pytest.skip('dask is not installed in the environment.') + + da = get_example_data(case) + xdest = 0.4 + ydest = 0.05 + + actual = da.interp(x=xdest, y=ydest, method=method) + # scipy interpolation for the reference + expected_data = scipy.interpolate.RegularGridInterpolator( + (da['x'], da['y']), da.transpose('x', 'y', 'z').values, + method='linear', bounds_error=False, + fill_value=np.nan)(np.stack([xdest, ydest], axis=-1)) + + coords = {'x': xdest, 'y': ydest, 'x2': da['x2'].interp(x=xdest), + 'z': da['z']} + expected = xr.DataArray(expected_data[0], dims=['z'], coords=coords) + assert_allclose(actual, expected) + + +@pytest.mark.parametrize('use_dask', [True, False]) +def test_nans(use_dask): + if not has_scipy: + pytest.skip('scipy is not installed.') + + da = xr.DataArray([0, 1, np.nan, 2], dims='x', coords={'x': range(4)}) + + if not has_dask and use_dask: + pytest.skip('dask is not installed in the environment.') + da = da.chunk() + + actual = da.interp(x=[0.5, 1.5]) + # not all values are nan + assert actual.count() > 0 + + +@pytest.mark.parametrize('use_dask', [True, False]) +def test_errors(use_dask): + if not has_scipy: + pytest.skip('scipy is not installed.') + + # akima and spline are unavailable + da = xr.DataArray([0, 1, np.nan, 2], dims='x', coords={'x': range(4)}) + if not has_dask and use_dask: + pytest.skip('dask is not installed in the environment.') + da = da.chunk() + + for method in ['akima', 'spline']: + with pytest.raises(ValueError): + da.interp(x=[0.5, 1.5], method=method) + + # not sorted + if use_dask: + da = get_example_data(3) + else: + da = get_example_data(1) + + result = da.interp(x=[-1, 1, 3], kwargs={'fill_value': 0.0}) + assert not np.isnan(result.values).any() + result = da.interp(x=[-1, 1, 3]) + assert np.isnan(result.values).any() + + # invalid method + with pytest.raises(ValueError): + da.interp(x=[2, 0], method='boo') + with pytest.raises(ValueError): + da.interp(x=[2, 0], y=2, method='cubic') + with pytest.raises(ValueError): + da.interp(y=[2, 0], method='boo') + + # object-type DataArray cannot be interpolated + da = xr.DataArray(['a', 'b', 'c'], dims='x', coords={'x': [0, 1, 2]}) + with pytest.raises(TypeError): + da.interp(x=0) + + +@requires_scipy +def test_dtype(): + ds = xr.Dataset({'var1': ('x', [0, 1, 2]), 'var2': ('x', ['a', 'b', 'c'])}, + coords={'x': [0.1, 0.2, 0.3], 'z': ('x', ['a', 'b', 'c'])}) + actual = ds.interp(x=[0.15, 0.25]) + assert 'var1' 
in actual + assert 'var2' not in actual + # object array should be dropped + assert 'z' not in actual.coords + + +@requires_scipy +def test_sorted(): + # unsorted non-uniform gridded data + x = np.random.randn(100) + y = np.random.randn(30) + z = np.linspace(0.1, 0.2, 10) * 3.0 + da = xr.DataArray( + np.cos(x[:, np.newaxis, np.newaxis]) * np.cos( + y[:, np.newaxis]) * z, + dims=['x', 'y', 'z'], + coords={'x': x, 'y': y, 'x2': ('x', x**2), 'z': z}) + + x_new = np.linspace(0, 1, 30) + y_new = np.linspace(0, 1, 20) + + da_sorted = da.sortby('x') + assert_allclose(da.interp(x=x_new), + da_sorted.interp(x=x_new, assume_sorted=True)) + da_sorted = da.sortby(['x', 'y']) + assert_allclose(da.interp(x=x_new, y=y_new), + da_sorted.interp(x=x_new, y=y_new, assume_sorted=True)) + + with pytest.raises(ValueError): + da.interp(x=[0, 1, 2], assume_sorted=True) + + +@requires_scipy +def test_dimension_wo_coords(): + da = xr.DataArray(np.arange(12).reshape(3, 4), dims=['x', 'y'], + coords={'y': [0, 1, 2, 3]}) + da_w_coord = da.copy() + da_w_coord['x'] = np.arange(3) + + assert_equal(da.interp(x=[0.1, 0.2, 0.3]), + da_w_coord.interp(x=[0.1, 0.2, 0.3])) + assert_equal(da.interp(x=[0.1, 0.2, 0.3], y=[0.5]), + da_w_coord.interp(x=[0.1, 0.2, 0.3], y=[0.5])) + + +@requires_scipy +def test_dataset(): + ds = create_test_data() + ds.attrs['foo'] = 'var' + ds['var1'].attrs['buz'] = 'var2' + new_dim2 = xr.DataArray([0.11, 0.21, 0.31], dims='z') + interpolated = ds.interp(dim2=new_dim2) + + assert_allclose(interpolated['var1'], ds['var1'].interp(dim2=new_dim2)) + assert interpolated['var3'].equals(ds['var3']) + + # make sure modifying interpolated does not affect the original dataset + interpolated['var1'][:, 1] = 1.0 + interpolated['var2'][:, 1] = 1.0 + interpolated['var3'][:, 1] = 1.0 + + assert not interpolated['var1'].equals(ds['var1']) + assert not interpolated['var2'].equals(ds['var2']) + assert not interpolated['var3'].equals(ds['var3']) + # attrs should be kept + assert interpolated.attrs['foo'] == 'var' + assert interpolated['var1'].attrs['buz'] == 'var2' + + +@pytest.mark.parametrize('case', [0, 3]) +def test_interpolate_dimorder(case): + """ Make sure the resultant dimension order is consistent with .sel() """ + if not has_scipy: + pytest.skip('scipy is not installed.') + + da = get_example_data(case) + + new_x = xr.DataArray([0, 1, 2], dims='x') + assert da.interp(x=new_x).dims == da.sel(x=new_x, method='nearest').dims + + new_y = xr.DataArray([0, 1, 2], dims='y') + actual = da.interp(x=new_x, y=new_y).dims + expected = da.sel(x=new_x, y=new_y, method='nearest').dims + assert actual == expected + # reversed order + actual = da.interp(y=new_y, x=new_x).dims + expected = da.sel(y=new_y, x=new_x, method='nearest').dims + assert actual == expected + + new_x = xr.DataArray([0, 1, 2], dims='a') + assert da.interp(x=new_x).dims == da.sel(x=new_x, method='nearest').dims + assert da.interp(y=new_x).dims == da.sel(y=new_x, method='nearest').dims + new_y = xr.DataArray([0, 1, 2], dims='a') + actual = da.interp(x=new_x, y=new_y).dims + expected = da.sel(x=new_x, y=new_y, method='nearest').dims + assert actual == expected + + new_x = xr.DataArray([[0], [1], [2]], dims=['a', 'b']) + assert da.interp(x=new_x).dims == da.sel(x=new_x, method='nearest').dims + assert da.interp(y=new_x).dims == da.sel(y=new_x, method='nearest').dims + + if case == 3: + new_x = xr.DataArray([[0], [1], [2]], dims=['a', 'b']) + new_z = xr.DataArray([[0], [1], [2]], dims=['a', 'b']) + actual = da.interp(x=new_x, z=new_z).dims + expected = 
da.sel(x=new_x, z=new_z, method='nearest').dims + assert actual == expected + + actual = da.interp(z=new_z, x=new_x).dims + expected = da.sel(z=new_z, x=new_x, method='nearest').dims + assert actual == expected + + actual = da.interp(x=0.5, z=new_z).dims + expected = da.sel(x=0.5, z=new_z, method='nearest').dims + assert actual == expected From 98e6a4b84dd2cf4296a3e0aa9710bb79411354e4 Mon Sep 17 00:00:00 2001 From: Keisuke Fujii Date: Fri, 8 Jun 2018 10:31:18 +0900 Subject: [PATCH 54/61] reduce memory consumption. (#2220) --- doc/interpolation.rst | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/doc/interpolation.rst b/doc/interpolation.rst index c5fd5166aeb..98cd89afbd4 100644 --- a/doc/interpolation.rst +++ b/doc/interpolation.rst @@ -219,16 +219,16 @@ Let's see how :py:meth:`~xarray.DataArray.interp` works on real data. .. ipython:: python # Raw data - ds = xr.tutorial.load_dataset('air_temperature') + ds = xr.tutorial.load_dataset('air_temperature').isel(time=0) fig, axes = plt.subplots(ncols=2, figsize=(10, 4)) - ds.air.isel(time=0).plot(ax=axes[0]) + ds.air.plot(ax=axes[0]) axes[0].set_title('Raw data') # Interpolated data new_lon = np.linspace(ds.lon[0], ds.lon[-1], ds.dims['lon'] * 4) new_lat = np.linspace(ds.lat[0], ds.lat[-1], ds.dims['lat'] * 4) dsi = ds.interp(lat=new_lat, lon=new_lon) - dsi.air.isel(time=0).plot(ax=axes[1]) + dsi.air.plot(ax=axes[1]) @savefig interpolation_sample3.png width=8in axes[1].set_title('Interpolated data') @@ -247,7 +247,7 @@ The remapping can be done as follows dims=['x', 'z'], coords={'x': x, 'z': z}) fig, axes = plt.subplots(ncols=2, figsize=(10, 4)) - ds.air.isel(time=0).plot(ax=axes[0]) + ds.air.plot(ax=axes[0]) # draw the new coordinate on the original coordinates. for idx in [0, 33, 66, 99]: axes[0].plot(lon.isel(x=idx), lat, '--k') @@ -256,6 +256,6 @@ The remapping can be done as follows axes[0].set_title('Raw data') dsi = ds.interp(lon=lon, lat=lat) - dsi.air.isel(time=0).plot(ax=axes[1]) + dsi.air.plot(ax=axes[1]) @savefig interpolation_sample4.png width=8in axes[1].set_title('Remapped data') From 43c189806264e15cbcae9a37d6f22e2b3e609348 Mon Sep 17 00:00:00 2001 From: Stephan Hoyer Date: Thu, 7 Jun 2018 21:08:53 -0700 Subject: [PATCH 55/61] DOC: misc fixes to whats-new for 0.10.7 (#2221) --- doc/whats-new.rst | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 44f829874ac..94c3164247f 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -35,18 +35,21 @@ Documentation Enhancements ~~~~~~~~~~~~ -- Plot labels now make use of metadata that follow CF conventions. + +- Plot labels now make use of metadata that follow CF conventions + (:issue:`2135`). By `Deepak Cherian `_ and `Ryan Abernathey `_. +- Line plots now support facetting with ``row`` and ``col`` arguments + (:issue:`2107`). + By `Yohai Bar Sinai `_. + - :py:meth:`~xarray.DataArray.interp` and :py:meth:`~xarray.Dataset.interp` methods are newly added. See :ref:`interpolating values with interp` for the detail. (:issue:`2079`) By `Keisuke Fujii `_. -- `:py:meth:`~DataArray.dot` and :py:func:`~dot` are partly supported with older - dask<0.17.4. 
(related to :issue:`2203`) - By `Keisuke Fujii Date: Thu, 7 Jun 2018 21:35:13 -0700 Subject: [PATCH 56/61] Release v0.10.7 --- doc/whats-new.rst | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 94c3164247f..4e5df777e3a 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -27,11 +27,8 @@ What's New .. _whats-new.0.10.7: -v0.10.7 (unreleased) --------------------- - -Documentation -~~~~~~~~~~~~~ +v0.10.7 (7 June 2018) +--------------------- Enhancements ~~~~~~~~~~~~ From 6c3abedf906482111b06207b9016ea8493c42713 Mon Sep 17 00:00:00 2001 From: Stephan Hoyer Date: Thu, 7 Jun 2018 21:38:07 -0700 Subject: [PATCH 57/61] Revert to dev version for v0.10.8 --- doc/whats-new.rst | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 4e5df777e3a..0dc92fbce58 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -25,6 +25,20 @@ What's New - `Python 3 Statement `__ - `Tips on porting to Python 3 `__ +.. _whats-new.0.10.8: + +v0.10.8 (unreleased) +-------------------- + +Documentation +~~~~~~~~~~~~~ + +Enhancements +~~~~~~~~~~~~ + +Bug fixes +~~~~~~~~~ + .. _whats-new.0.10.7: v0.10.7 (7 June 2018) From 66be9c5db7d86ea385c3a4cd4295bfce67e3f25b Mon Sep 17 00:00:00 2001 From: Joe Hamman Date: Tue, 12 Jun 2018 22:51:35 -0700 Subject: [PATCH 58/61] fix zarr chunking bug (#2228) --- doc/whats-new.rst | 7 ++++++- xarray/backends/zarr.py | 30 ++++++++++++------------------ xarray/tests/test_backends.py | 11 +++++++++++ 3 files changed, 29 insertions(+), 19 deletions(-) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 0dc92fbce58..5871b8bb0a3 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -39,6 +39,10 @@ Enhancements Bug fixes ~~~~~~~~~ +- Fixed a bug in ``zarr`` backend which prevented use with datasets with + incomplete chunks in multiple dimensions (:issue:`2225`). + By `Joe Hamman `_. + .. _whats-new.0.10.7: v0.10.7 (7 June 2018) @@ -60,12 +64,13 @@ Enhancements See :ref:`interpolating values with interp` for the detail. (:issue:`2079`) By `Keisuke Fujii `_. - + Bug fixes ~~~~~~~~~ - Fixed a bug in ``rasterio`` backend which prevented use with ``distributed``. The ``rasterio`` backend now returns pickleable objects (:issue:`2021`). + By `Joe Hamman `_. .. _whats-new.0.10.6: diff --git a/xarray/backends/zarr.py b/xarray/backends/zarr.py index 343690eaabd..c5043ce8a47 100644 --- a/xarray/backends/zarr.py +++ b/xarray/backends/zarr.py @@ -78,24 +78,18 @@ def _determine_zarr_chunks(enc_chunks, var_chunks, ndim): # while dask chunks can be variable sized # http://dask.pydata.org/en/latest/array-design.html#chunks if var_chunks and enc_chunks is None: - all_var_chunks = list(product(*var_chunks)) - first_var_chunk = all_var_chunks[0] - # all but the last chunk have to match exactly - for this_chunk in all_var_chunks[:-1]: - if this_chunk != first_var_chunk: - raise ValueError( - "Zarr requires uniform chunk sizes excpet for final chunk." - " Variable %r has incompatible chunks. Consider " - "rechunking using `chunk()`." % (var_chunks,)) - # last chunk is allowed to be smaller - last_var_chunk = all_var_chunks[-1] - for len_first, len_last in zip(first_var_chunk, last_var_chunk): - if len_last > len_first: - raise ValueError( - "Final chunk of Zarr array must be smaller than first. " - "Variable %r has incompatible chunks. Consider rechunking " - "using `chunk()`." 
% var_chunks) - return first_var_chunk + if any(len(set(chunks[:-1])) > 1 for chunks in var_chunks): + raise ValueError( + "Zarr requires uniform chunk sizes excpet for final chunk." + " Variable %r has incompatible chunks. Consider " + "rechunking using `chunk()`." % (var_chunks,)) + if any((chunks[0] < chunks[-1]) for chunks in var_chunks): + raise ValueError( + "Final chunk of Zarr array must be smaller than first. " + "Variable %r has incompatible chunks. Consider rechunking " + "using `chunk()`." % var_chunks) + # return the first chunk for each dimension + return tuple(chunk[0] for chunk in var_chunks) # from here on, we are dealing with user-specified chunks in encoding # zarr allows chunks to be an integer, in which case it uses the same chunk diff --git a/xarray/tests/test_backends.py b/xarray/tests/test_backends.py index df7ed66f4fd..e83b80a6dd8 100644 --- a/xarray/tests/test_backends.py +++ b/xarray/tests/test_backends.py @@ -1330,6 +1330,17 @@ def test_auto_chunk(self): # chunk size should be the same as original self.assertEqual(v.chunks, original[k].chunks) + def test_write_uneven_dask_chunks(self): + # regression for GH#2225 + original = create_test_data().chunk({'dim1': 3, 'dim2': 4, 'dim3': 3}) + + with self.roundtrip( + original, open_kwargs={'auto_chunk': True}) as actual: + for k, v in actual.data_vars.items(): + print(k) + assert v.chunks == actual[k].chunks + + def test_chunk_encoding(self): # These datasets have no dask chunks. All chunking specified in # encoding From 59ad782f29a0f4766bac7802be6650be61f018b8 Mon Sep 17 00:00:00 2001 From: Keisuke Fujii Date: Wed, 20 Jun 2018 10:39:23 +0900 Subject: [PATCH 59/61] implement interp_like (#2222) * implement interp_like * flake8 * interp along datetime * Support datetime coordinate * Using reindex for object coordinate. * Update based on the comments --- doc/api.rst | 2 ++ doc/indexing.rst | 8 ++++++ doc/interpolation.rst | 27 +++++++++++++++++- doc/whats-new.rst | 6 ++++ xarray/core/dataarray.py | 48 +++++++++++++++++++++++++++++++ xarray/core/dataset.py | 57 +++++++++++++++++++++++++++++++++++++ xarray/core/missing.py | 22 ++++++++++++++ xarray/tests/test_interp.py | 48 +++++++++++++++++++++++++++++++ 8 files changed, 217 insertions(+), 1 deletion(-) diff --git a/doc/api.rst b/doc/api.rst index cb44ef82c8f..927c0aa072c 100644 --- a/doc/api.rst +++ b/doc/api.rst @@ -111,6 +111,7 @@ Indexing Dataset.sel Dataset.squeeze Dataset.interp + Dataset.interp_like Dataset.reindex Dataset.reindex_like Dataset.set_index @@ -265,6 +266,7 @@ Indexing DataArray.sel DataArray.squeeze DataArray.interp + DataArray.interp_like DataArray.reindex DataArray.reindex_like DataArray.set_index diff --git a/doc/indexing.rst b/doc/indexing.rst index a44e64e4079..c05bf9994fc 100644 --- a/doc/indexing.rst +++ b/doc/indexing.rst @@ -193,6 +193,14 @@ Indexing axes with monotonic decreasing labels also works, as long as the reversed_da.loc[3.1:0.9] +.. note:: + + If you want to interpolate along coordinates rather than looking up the + nearest neighbors, use :py:meth:`~xarray.Dataset.interp` and + :py:meth:`~xarray.Dataset.interp_like`. + See :ref:`interpolation ` for the details. 
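For example, the two operations differ as follows (a minimal sketch, assuming scipy is installed for the interpolation):

    import xarray as xr

    da = xr.DataArray([0.0, 10.0, 20.0], dims='x', coords={'x': [0, 1, 2]})
    da.sel(x=0.4, method='nearest')   # 0.0, the value at the nearest label x=0
    da.interp(x=0.4)                  # 4.0, linearly interpolated between x=0 and x=1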
+ + Dataset indexing ---------------- diff --git a/doc/interpolation.rst b/doc/interpolation.rst index 98cd89afbd4..cd1c078fb2d 100644 --- a/doc/interpolation.rst +++ b/doc/interpolation.rst @@ -34,7 +34,7 @@ indexing of a :py:class:`~xarray.DataArray`, da.sel(time=3) # interpolation - da.interp(time=3.5) + da.interp(time=2.5) Similar to the indexing, :py:meth:`~xarray.DataArray.interp` also accepts an @@ -82,6 +82,31 @@ Array-like coordinates are also accepted: da.interp(time=[1.5, 2.5], space=[0.15, 0.25]) +:py:meth:`~xarray.DataArray.interp_like` method is a useful shortcut. This +method interpolates an xarray object onto the coordinates of another xarray +object. For example, if we want to compute the difference between +two :py:class:`~xarray.DataArray` s (``da`` and ``other``) staying on slightly +different coordinates, + +.. ipython:: python + + other = xr.DataArray(np.sin(0.4 * np.arange(9).reshape(3, 3)), + [('time', [0.9, 1.9, 2.9]), + ('space', [0.15, 0.25, 0.35])]) + +it might be a good idea to first interpolate ``da`` so that it will stay on the +same coordinates of ``other``, and then subtract it. +:py:meth:`~xarray.DataArray.interp_like` can be used for such a case, + +.. ipython:: python + + # interpolate da along other's coordinates + interpolated = da.interp_like(other) + interpolated + +It is now possible to safely compute the difference ``other - interpolated``. + + Interpolation methods --------------------- diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 5871b8bb0a3..55bd0d974f4 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -36,6 +36,12 @@ Documentation Enhancements ~~~~~~~~~~~~ +- :py:meth:`~xarray.DataArray.interp_like` and + :py:meth:`~xarray.Dataset.interp_like` methods are newly added. + (:issue:`2218`) + By `Keisuke Fujii `_. + + Bug fixes ~~~~~~~~~ diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index 4129a3c5f26..35def72c64a 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -951,6 +951,54 @@ def interp(self, coords=None, method='linear', assume_sorted=False, **coords_kwargs) return self._from_temp_dataset(ds) + def interp_like(self, other, method='linear', assume_sorted=False, + kwargs={}): + """Interpolate this object onto the coordinates of another object, + filling out of range values with NaN. + + Parameters + ---------- + other : Dataset or DataArray + Object with an 'indexes' attribute giving a mapping from dimension + names to an 1d array-like, which provides coordinates upon + which to index the variables in this dataset. + method: string, optional. + {'linear', 'nearest'} for multidimensional array, + {'linear', 'nearest', 'zero', 'slinear', 'quadratic', 'cubic'} + for 1-dimensional array. 'linear' is used by default. + assume_sorted: boolean, optional + If False, values of coordinates that are interpolated over can be + in any order and they are sorted first. If True, interpolated + coordinates are assumed to be an array of monotonically increasing + values. + kwargs: dictionary, optional + Additional keyword passed to scipy's interpolator. + + Returns + ------- + interpolated: xr.DataArray + Another dataarray by interpolating this dataarray's data along the + coordinates of the other object. + + Note + ---- + scipy is required. + If the dataarray has object-type coordinates, reindex is used for these + coordinates instead of the interpolation. 
+ + See Also + -------- + DataArray.interp + DataArray.reindex_like + """ + if self.dtype.kind not in 'uifc': + raise TypeError('interp only works for a numeric type array. ' + 'Given {}.'.format(self.dtype)) + + ds = self._to_temp_dataset().interp_like( + other, method=method, kwargs=kwargs, assume_sorted=assume_sorted) + return self._from_temp_dataset(ds) + def rename(self, new_name_or_name_dict=None, **names): """Returns a new DataArray with renamed coordinates or a new name. diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 90712c953da..8e039572237 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -1893,6 +1893,63 @@ def maybe_variable(obj, k): .union(coord_vars)) return obj._replace_vars_and_dims(variables, coord_names=coord_names) + def interp_like(self, other, method='linear', assume_sorted=False, + kwargs={}): + """Interpolate this object onto the coordinates of another object, + filling the out of range values with NaN. + + Parameters + ---------- + other : Dataset or DataArray + Object with an 'indexes' attribute giving a mapping from dimension + names to an 1d array-like, which provides coordinates upon + which to index the variables in this dataset. + method: string, optional. + {'linear', 'nearest'} for multidimensional array, + {'linear', 'nearest', 'zero', 'slinear', 'quadratic', 'cubic'} + for 1-dimensional array. 'linear' is used by default. + assume_sorted: boolean, optional + If False, values of coordinates that are interpolated over can be + in any order and they are sorted first. If True, interpolated + coordinates are assumed to be an array of monotonically increasing + values. + kwargs: dictionary, optional + Additional keyword passed to scipy's interpolator. + + Returns + ------- + interpolated: xr.Dataset + Another dataset by interpolating this dataset's data along the + coordinates of the other object. + + Note + ---- + scipy is required. + If the dataset has object-type coordinates, reindex is used for these + coordinates instead of the interpolation. + + See Also + -------- + Dataset.interp + Dataset.reindex_like + """ + coords = alignment.reindex_like_indexers(self, other) + + numeric_coords = OrderedDict() + object_coords = OrderedDict() + for k, v in coords.items(): + if v.dtype.kind in 'uifcMm': + numeric_coords[k] = v + else: + object_coords[k] = v + + ds = self + if object_coords: + # We do not support interpolation along object coordinate. + # reindex instead. + ds = self.reindex(object_coords) + return ds.interp(numeric_coords, method, assume_sorted, kwargs) + def rename(self, name_dict=None, inplace=False, **names): """Returns a new object with renamed variables and dimensions. diff --git a/xarray/core/missing.py b/xarray/core/missing.py index e10f37d58d8..743627bb381 100644 --- a/xarray/core/missing.py +++ b/xarray/core/missing.py @@ -395,6 +395,26 @@ def _localize(var, indexes_coords): return var.isel(**indexes), indexes_coords +def _floatize_x(x, new_x): + """ Make x and new_x float. + This is particulary useful for datetime dtype. + x, new_x: tuple of np.ndarray + """ + x = list(x) + new_x = list(new_x) + for i in range(len(x)): + if x[i].dtype.kind in 'Mm': + # Scipy casts coordinates to np.float64, which is not accurate + # enough for datetime64 (uses 64bit integer). + # We assume that the most of the bits are used to represent the + # offset (min(x)) and the variation (x - min(x)) can be + # represented by float. 
+ xmin = np.min(x[i]) + x[i] = (x[i] - xmin).astype(np.float64) + new_x[i] = (new_x[i] - xmin).astype(np.float64) + return x, new_x + + def interp(var, indexes_coords, method, **kwargs): """ Make an interpolation of Variable @@ -523,6 +543,8 @@ def _interp1d(var, x, new_x, func, kwargs): def _interpnd(var, x, new_x, func, kwargs): + x, new_x = _floatize_x(x, new_x) + if len(x) == 1: return _interp1d(var, x, new_x, func, kwargs) diff --git a/xarray/tests/test_interp.py b/xarray/tests/test_interp.py index 592854a4d1b..69a4644bc97 100644 --- a/xarray/tests/test_interp.py +++ b/xarray/tests/test_interp.py @@ -1,6 +1,7 @@ from __future__ import absolute_import, division, print_function import numpy as np +import pandas as pd import pytest import xarray as xr @@ -430,3 +431,50 @@ def test_interpolate_dimorder(case): actual = da.interp(x=0.5, z=new_z).dims expected = da.sel(x=0.5, z=new_z, method='nearest').dims assert actual == expected + + +@requires_scipy +def test_interp_like(): + ds = create_test_data() + ds.attrs['foo'] = 'var' + ds['var1'].attrs['buz'] = 'var2' + + other = xr.DataArray(np.random.randn(3), dims=['dim2'], + coords={'dim2': [0, 1, 2]}) + interpolated = ds.interp_like(other) + + assert_allclose(interpolated['var1'], + ds['var1'].interp(dim2=other['dim2'])) + assert_allclose(interpolated['var1'], + ds['var1'].interp_like(other)) + assert interpolated['var3'].equals(ds['var3']) + + # attrs should be kept + assert interpolated.attrs['foo'] == 'var' + assert interpolated['var1'].attrs['buz'] == 'var2' + + other = xr.DataArray(np.random.randn(3), dims=['dim3'], + coords={'dim3': ['a', 'b', 'c']}) + + actual = ds.interp_like(other) + expected = ds.reindex_like(other) + assert_allclose(actual, expected) + + +@requires_scipy +def test_datetime(): + da = xr.DataArray(np.random.randn(24), dims='time', + coords={'time': pd.date_range('2000-01-01', periods=24)}) + + x_new = pd.date_range('2000-01-02', periods=3) + actual = da.interp(time=x_new) + expected = da.isel(time=[1, 2, 3]) + assert_allclose(actual, expected) + + x_new = np.array([np.datetime64('2000-01-01T12:00'), + np.datetime64('2000-01-02T12:00')]) + actual = da.interp(time=x_new) + assert_allclose(actual.isel(time=0).drop('time'), + 0.5 * (da.isel(time=0) + da.isel(time=1))) + assert_allclose(actual.isel(time=1).drop('time'), + 0.5 * (da.isel(time=1) + da.isel(time=2))) From 73b476e4db6631b2203954dd5b138cb650e4fb8c Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Wed, 20 Jun 2018 09:26:36 -0700 Subject: [PATCH 60/61] Bugfix for faceting line plots. (#2229) * Bugfix for faceting line plots. * Make add_legend public and usable. 
--- xarray/plot/facetgrid.py | 16 +++++++++------- xarray/tests/test_plot.py | 6 ++++++ 2 files changed, 15 insertions(+), 7 deletions(-) diff --git a/xarray/plot/facetgrid.py b/xarray/plot/facetgrid.py index 771f0879408..a0d7c4dd5e2 100644 --- a/xarray/plot/facetgrid.py +++ b/xarray/plot/facetgrid.py @@ -293,14 +293,16 @@ def map_dataarray_line(self, x=None, y=None, hue=None, **kwargs): ax=ax, _labels=False, **kwargs) self._mappables.append(mappable) - _, _, _, xlabel, ylabel, huelabel = _infer_line_data( + _, _, hueplt, xlabel, ylabel, huelabel = _infer_line_data( darray=self.data.loc[self.name_dicts.flat[0]], x=x, y=y, hue=hue) + self._hue_var = hueplt + self._hue_label = huelabel self._finalize_grid(xlabel, ylabel) - if add_legend and huelabel: - self.add_line_legend(huelabel) + if add_legend and hueplt is not None and huelabel is not None: + self.add_legend() return self @@ -314,12 +316,12 @@ def _finalize_grid(self, *axlabels): if namedict is None: ax.set_visible(False) - def add_line_legend(self, huelabel): + def add_legend(self, **kwargs): figlegend = self.fig.legend( handles=self._mappables[-1], - labels=list(self.data.coords[huelabel].values), - title=huelabel, - loc="center right") + labels=list(self._hue_var.values), + title=self._hue_label, + loc="center right", **kwargs) # Draw the plot to set the bounding boxes correctly self.fig.draw(self.fig.canvas.get_renderer()) diff --git a/xarray/tests/test_plot.py b/xarray/tests/test_plot.py index cdb515ba92e..15729f25e22 100644 --- a/xarray/tests/test_plot.py +++ b/xarray/tests/test_plot.py @@ -1529,6 +1529,12 @@ def setUp(self): range(3), ['A', 'B', 'C', 'C++']], name='Cornelius Ortega the 1st') + self.darray.hue.name = 'huename' + self.darray.hue.attrs['units'] = 'hunits' + self.darray.x.attrs['units'] = 'xunits' + self.darray.col.attrs['units'] = 'colunits' + self.darray.row.attrs['units'] = 'rowunits' + def test_facetgrid_shape(self): g = self.darray.plot(row='row', col='col', hue='hue') assert g.axes.shape == (len(self.darray.row), len(self.darray.col)) From 9491318e29b478234e6f96c3547d724504b4a1bb Mon Sep 17 00:00:00 2001 From: Maximilian Roos <5635139+maxim-lian@users.noreply.github.com> Date: Fri, 22 Jun 2018 18:00:05 -0400 Subject: [PATCH 61/61] no mode arg to open_zarr (#2246) --- doc/io.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/doc/io.rst b/doc/io.rst index 7f7e7a2a66a..093ee773e15 100644 --- a/doc/io.rst +++ b/doc/io.rst @@ -603,7 +603,7 @@ pass to xarray:: # write to the bucket ds.to_zarr(store=gcsmap) # read it back - ds_gcs = xr.open_zarr(gcsmap, mode='r') + ds_gcs = xr.open_zarr(gcsmap) .. _Zarr: http://zarr.readthedocs.io/ .. _Amazon S3: https://aws.amazon.com/s3/