diff --git a/.github/workflows/benchmarks.yml b/.github/workflows/benchmarks.yml index 07fe7ef6256..b9a8d773c5a 100644 --- a/.github/workflows/benchmarks.yml +++ b/.github/workflows/benchmarks.yml @@ -22,7 +22,7 @@ jobs: fetch-depth: 0 - name: Set up conda environment - uses: mamba-org/provision-with-micromamba@v14 + uses: mamba-org/provision-with-micromamba@v15 with: environment-file: ${{env.CONDA_ENV_FILE}} environment-name: xarray-tests diff --git a/.github/workflows/ci-additional.yaml b/.github/workflows/ci-additional.yaml index bdae56ae6db..a2e3734f534 100644 --- a/.github/workflows/ci-additional.yaml +++ b/.github/workflows/ci-additional.yaml @@ -53,7 +53,7 @@ jobs: echo "TODAY=$(date +'%Y-%m-%d')" >> $GITHUB_ENV - name: Setup micromamba - uses: mamba-org/provision-with-micromamba@v14 + uses: mamba-org/provision-with-micromamba@v15 with: environment-file: ${{env.CONDA_ENV_FILE}} environment-name: xarray-tests @@ -100,7 +100,7 @@ jobs: run: | echo "TODAY=$(date +'%Y-%m-%d')" >> $GITHUB_ENV - name: Setup micromamba - uses: mamba-org/provision-with-micromamba@v14 + uses: mamba-org/provision-with-micromamba@v15 with: environment-file: ${{env.CONDA_ENV_FILE}} environment-name: xarray-tests @@ -156,7 +156,7 @@ jobs: run: | echo "TODAY=$(date +'%Y-%m-%d')" >> $GITHUB_ENV - name: Setup micromamba - uses: mamba-org/provision-with-micromamba@v14 + uses: mamba-org/provision-with-micromamba@v15 with: environment-file: ${{env.CONDA_ENV_FILE}} environment-name: xarray-tests @@ -212,7 +212,7 @@ jobs: fetch-depth: 0 # Fetch all history for all branches and tags. - name: Setup micromamba - uses: mamba-org/provision-with-micromamba@v14 + uses: mamba-org/provision-with-micromamba@v15 with: environment-name: xarray-tests environment-file: false diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index 64aae44e56c..2d190efc14c 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -42,7 +42,7 @@ jobs: matrix: os: ["ubuntu-latest", "macos-latest", "windows-latest"] # Bookend python versions - python-version: ["3.8", "3.10"] + python-version: ["3.8", "3.10", "3.11"] env: [""] include: # Minimum python version: @@ -67,7 +67,13 @@ jobs: run: | echo "TODAY=$(date +'%Y-%m-%d')" >> $GITHUB_ENV - if [[ ${{ matrix.os }} == windows* ]] ; + if [[ "${{matrix.python-version}}" == "3.11" ]]; then + if [[ ${{matrix.os}} == windows* ]]; then + echo "CONDA_ENV_FILE=ci/requirements/environment-windows-py311.yml" >> $GITHUB_ENV + else + echo "CONDA_ENV_FILE=ci/requirements/environment-py311.yml" >> $GITHUB_ENV + fi + elif [[ ${{ matrix.os }} == windows* ]] ; then echo "CONDA_ENV_FILE=ci/requirements/environment-windows.yml" >> $GITHUB_ENV elif [[ "${{ matrix.env }}" != "" ]] ; @@ -86,7 +92,7 @@ jobs: echo "PYTHON_VERSION=${{ matrix.python-version }}" >> $GITHUB_ENV - name: Setup micromamba - uses: mamba-org/provision-with-micromamba@v14 + uses: mamba-org/provision-with-micromamba@v15 with: environment-file: ${{ env.CONDA_ENV_FILE }} environment-name: xarray-tests diff --git a/.github/workflows/upstream-dev-ci.yaml b/.github/workflows/upstream-dev-ci.yaml index 68bd0c15067..c7ccee73414 100644 --- a/.github/workflows/upstream-dev-ci.yaml +++ b/.github/workflows/upstream-dev-ci.yaml @@ -54,7 +54,7 @@ jobs: with: fetch-depth: 0 # Fetch all history for all branches and tags. 
- name: Set up conda environment - uses: mamba-org/provision-with-micromamba@v14 + uses: mamba-org/provision-with-micromamba@v15 with: environment-file: ci/requirements/environment.yml environment-name: xarray-tests diff --git a/.readthedocs.yaml b/.readthedocs.yaml index cf263cc666a..db2e1cd0b9a 100644 --- a/.readthedocs.yaml +++ b/.readthedocs.yaml @@ -1,9 +1,14 @@ version: 2 build: - os: ubuntu-20.04 + os: ubuntu-22.04 tools: python: mambaforge-4.10 + jobs: + post_checkout: + - (git --no-pager log --pretty="tformat:%s" -1 | grep -vqF "[skip-rtd]") || exit 183 + pre_install: + - git update-index --assume-unchanged doc/conf.py ci/requirements/doc.yml conda: environment: ci/requirements/doc.yml diff --git a/HOW_TO_RELEASE.md b/HOW_TO_RELEASE.md index f647263a3a7..b59212f15ad 100644 --- a/HOW_TO_RELEASE.md +++ b/HOW_TO_RELEASE.md @@ -92,7 +92,15 @@ upstream https://github.com/pydata/xarray (push) ``` You're done pushing to main! -13. Issue the release announcement to mailing lists & Twitter. For bug fix releases, I +13. Update the version available on pyodide: + - Open the PyPI page for [Xarray downloads](https://pypi.org/project/xarray/#files) + - Clone the [pyodide repository](https://github.com/pyodide/pyodide). + - Edit `packages/xarray/meta.yaml` to update the + - link to the wheel (under "Built Distribution" on the PyPI page) + - SHA256 hash (Click "Show Hashes" next to the link to the wheel) + - Open a pull request to pyodide + +14. Issue the release announcement to mailing lists & Twitter. For bug fix releases, I usually only email xarray@googlegroups.com. For major/feature releases, I will email a broader list (no more than once every 3-6 months): - pydata@googlegroups.com diff --git a/asv_bench/benchmarks/pandas.py b/asv_bench/benchmarks/pandas.py index 8aaa515d417..2a296ecc4d0 100644 --- a/asv_bench/benchmarks/pandas.py +++ b/asv_bench/benchmarks/pandas.py @@ -3,7 +3,7 @@ import xarray as xr -from . import parameterized +from . 
import parameterized, requires_dask class MultiIndexSeries: @@ -24,3 +24,38 @@ def setup(self, dtype, subset): @parameterized(["dtype", "subset"], ([int, float], [True, False])) def time_from_series(self, dtype, subset): xr.DataArray.from_series(self.series) + + +class ToDataFrame: + def setup(self, *args, **kwargs): + xp = kwargs.get("xp", np) + random_kws = kwargs.get("random_kws", {}) + method = kwargs.get("method", "to_dataframe") + + dim1 = 10_000 + dim2 = 10_000 + ds = xr.Dataset( + { + "x": xr.DataArray( + data=xp.random.random((dim1, dim2), **random_kws), + dims=["dim1", "dim2"], + coords={"dim1": np.arange(0, dim1), "dim2": np.arange(0, dim2)}, + ) + } + ) + self.to_frame = getattr(ds, method) + + def time_to_dataframe(self): + self.to_frame() + + def peakmem_to_dataframe(self): + self.to_frame() + + +class ToDataFrameDask(ToDataFrame): + def setup(self, *args, **kwargs): + requires_dask() + + import dask.array as da + + super().setup(xp=da, random_kws=dict(chunks=5000), method="to_dask_dataframe") diff --git a/ci/requirements/all-but-dask.yml b/ci/requirements/all-but-dask.yml index d30c94348d0..ce819640c76 100644 --- a/ci/requirements/all-but-dask.yml +++ b/ci/requirements/all-but-dask.yml @@ -23,7 +23,8 @@ dependencies: - nc-time-axis - netcdf4 - numba - - numpy + - numbagg + - numpy<1.24 - packaging - pandas - pint @@ -41,5 +42,3 @@ dependencies: - toolz - typing_extensions - zarr - - pip: - - numbagg diff --git a/ci/requirements/doc.yml b/ci/requirements/doc.yml index 5b39a4d7d19..35fdcb8cdb7 100644 --- a/ci/requirements/doc.yml +++ b/ci/requirements/doc.yml @@ -18,7 +18,7 @@ dependencies: - nbsphinx - netcdf4>=1.5 - numba - - numpy>=1.20 + - numpy>=1.20,<1.24 - packaging>=21.0 - pandas>=1.3 - pooch diff --git a/ci/requirements/environment-py311.yml b/ci/requirements/environment-py311.yml new file mode 100644 index 00000000000..e23fa44c683 --- /dev/null +++ b/ci/requirements/environment-py311.yml @@ -0,0 +1,48 @@ +name: xarray-tests +channels: + - conda-forge + - nodefaults +dependencies: + - aiobotocore + - boto3 + - bottleneck + - cartopy + # - cdms2 + - cfgrib + - cftime + - dask-core + - distributed + - flox + - fsspec!=2021.7.0 + - h5netcdf + - h5py + - hdf5 + - hypothesis + - iris + - lxml # Optional dep of pydap + - matplotlib-base + - nc-time-axis + - netcdf4 + # - numba + # - numbagg + - numexpr + - numpy + - packaging + - pandas + - pint + - pip + - pooch + - pre-commit + - pseudonetcdf + - pydap + - pytest + - pytest-cov + - pytest-env + - pytest-xdist + - rasterio + - scipy + - seaborn + # - sparse + - toolz + - typing_extensions + - zarr diff --git a/ci/requirements/environment-windows-py311.yml b/ci/requirements/environment-windows-py311.yml new file mode 100644 index 00000000000..3fc207dc609 --- /dev/null +++ b/ci/requirements/environment-windows-py311.yml @@ -0,0 +1,44 @@ +name: xarray-tests +channels: + - conda-forge +dependencies: + - boto3 + - bottleneck + - cartopy + # - cdms2 # Not available on Windows + # - cfgrib # Causes Python interpreter crash on Windows: https://github.com/pydata/xarray/pull/3340 + - cftime + - dask-core + - distributed + - flox + - fsspec!=2021.7.0 + - h5netcdf + - h5py + - hdf5 + - hypothesis + - iris + - lxml # Optional dep of pydap + - matplotlib-base + - nc-time-axis + - netcdf4 + # - numba + # - numbagg + - numpy + - packaging + - pandas + - pint + - pip + - pre-commit + - pseudonetcdf + - pydap + - pytest + - pytest-cov + - pytest-env + - pytest-xdist + - rasterio + - scipy + - seaborn + # - sparse + - toolz + - 
typing_extensions + - zarr diff --git a/ci/requirements/environment-windows.yml b/ci/requirements/environment-windows.yml index 22bfa3543d3..0941af474f7 100644 --- a/ci/requirements/environment-windows.yml +++ b/ci/requirements/environment-windows.yml @@ -22,7 +22,8 @@ dependencies: - nc-time-axis - netcdf4 - numba - - numpy + - numbagg + - numpy<1.24 - packaging - pandas - pint @@ -41,5 +42,3 @@ dependencies: - toolz - typing_extensions - zarr - - pip: - - numbagg diff --git a/ci/requirements/environment.yml b/ci/requirements/environment.yml index d6bc8466c76..e87e69138ee 100644 --- a/ci/requirements/environment.yml +++ b/ci/requirements/environment.yml @@ -24,8 +24,9 @@ dependencies: - nc-time-axis - netcdf4 - numba + - numbagg - numexpr - - numpy + - numpy<1.24 - packaging - pandas - pint @@ -45,5 +46,3 @@ dependencies: - toolz - typing_extensions - zarr - - pip: - - numbagg diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 8af7a258f5a..b0f6a07841b 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -15,9 +15,9 @@ What's New np.random.seed(123456) -.. _whats-new.2022.12.1: +.. _whats-new.2023.01.1: -v2022.12.1 (unreleased) +v2023.01.1 (unreleased) ----------------------- New Features @@ -27,14 +27,42 @@ New Features Breaking changes ~~~~~~~~~~~~~~~~ -- :py:meth:`CFTimeIndex.get_loc` has removed the ``method`` and ``tolerance`` keyword arguments. - Use ``.get_indexer([key], method=..., tolerance=...)`` instead (:pull:`7361`). - By `Matthew Roeschke `_. Deprecations ~~~~~~~~~~~~ +Bug fixes +~~~~~~~~~ + +- :py:func:`xarray.concat` can now concatenate variables present in some datasets but + not others (:issue:`508`, :pull:`7400`). + By `Kai Mühlbauer `_ and `Scott Chamberlin `_. + +Documentation +~~~~~~~~~~~~~ + + +Internal Changes +~~~~~~~~~~~~~~~~ + + +.. _whats-new.2023.01.0: + +v2023.01.0 (Jan 17, 2023) +------------------------- + +This release includes a number of bug fixes. Thanks to the 14 contributors to this release: +Aron Gergely, Benoit Bovy, Deepak Cherian, Ian Carroll, Illviljan, Joe Hamman, Justus Magin, Mark Harfouche, +Matthew Roeschke, Paige Martin, Pierre, Sam Levang, Tom White, stefank0. + +Breaking changes +~~~~~~~~~~~~~~~~ + +- :py:meth:`CFTimeIndex.get_loc` has removed the ``method`` and ``tolerance`` keyword arguments. + Use ``.get_indexer([key], method=..., tolerance=...)`` instead (:pull:`7361`). + By `Matthew Roeschke `_. + Bug fixes ~~~~~~~~~ @@ -51,10 +79,6 @@ Bug fixes - Preserve original dtype on accessing MultiIndex levels (:issue:`7250`, :pull:`7393`). By `Ian Carroll `_. 
-Documentation -~~~~~~~~~~~~~ - - Internal Changes ~~~~~~~~~~~~~~~~ - Add the pre-commit hook `absolufy-imports` to convert relative xarray imports to diff --git a/setup.cfg b/setup.cfg index 7919908e8ec..70b810307be 100644 --- a/setup.cfg +++ b/setup.cfg @@ -67,6 +67,7 @@ classifiers = Programming Language :: Python :: 3.8 Programming Language :: Python :: 3.9 Programming Language :: Python :: 3.10 + Programming Language :: Python :: 3.11 Topic :: Scientific/Engineering [options] diff --git a/xarray/core/concat.py b/xarray/core/concat.py index 2eea2ecb3ee..95a2199f9c5 100644 --- a/xarray/core/concat.py +++ b/xarray/core/concat.py @@ -5,7 +5,7 @@ import pandas as pd from xarray.core import dtypes, utils -from xarray.core.alignment import align +from xarray.core.alignment import align, reindex_variables from xarray.core.duck_array_ops import lazy_array_equiv from xarray.core.indexes import Index, PandasIndex from xarray.core.merge import ( @@ -378,7 +378,9 @@ def process_subset_opt(opt, subset): elif opt == "all": concat_over.update( - set(getattr(datasets[0], subset)) - set(datasets[0].dims) + set().union( + *list(set(getattr(d, subset)) - set(d.dims) for d in datasets) + ) ) elif opt == "minimal": pass @@ -406,19 +408,26 @@ def process_subset_opt(opt, subset): # determine dimensional coordinate names and a dict mapping name to DataArray def _parse_datasets( - datasets: Iterable[T_Dataset], -) -> tuple[dict[Hashable, Variable], dict[Hashable, int], set[Hashable], set[Hashable]]: - + datasets: list[T_Dataset], +) -> tuple[ + dict[Hashable, Variable], + dict[Hashable, int], + set[Hashable], + set[Hashable], + list[Hashable], +]: dims: set[Hashable] = set() all_coord_names: set[Hashable] = set() data_vars: set[Hashable] = set() # list of data_vars dim_coords: dict[Hashable, Variable] = {} # maps dim name to variable dims_sizes: dict[Hashable, int] = {} # shared dimension sizes to expand variables + variables_order: dict[Hashable, Variable] = {} # variables in order of appearance for ds in datasets: dims_sizes.update(ds.dims) all_coord_names.update(ds.coords) data_vars.update(ds.data_vars) + variables_order.update(ds.variables) # preserves ordering of dimensions for dim in ds.dims: @@ -429,7 +438,7 @@ def _parse_datasets( dim_coords[dim] = ds.coords[dim].variable dims = dims | set(ds.dims) - return dim_coords, dims_sizes, all_coord_names, data_vars + return dim_coords, dims_sizes, all_coord_names, data_vars, list(variables_order) def _dataset_concat( @@ -439,7 +448,7 @@ def _dataset_concat( coords: str | list[str], compat: CompatOptions, positions: Iterable[Iterable[int]] | None, - fill_value: object = dtypes.NA, + fill_value: Any = dtypes.NA, join: JoinOptions = "outer", combine_attrs: CombineAttrsOptions = "override", ) -> T_Dataset: @@ -471,7 +480,9 @@ def _dataset_concat( align(*datasets, join=join, copy=False, exclude=[dim], fill_value=fill_value) ) - dim_coords, dims_sizes, coord_names, data_names = _parse_datasets(datasets) + dim_coords, dims_sizes, coord_names, data_names, vars_order = _parse_datasets( + datasets + ) dim_names = set(dim_coords) unlabeled_dims = dim_names - coord_names @@ -525,7 +536,7 @@ def _dataset_concat( # we've already verified everything is consistent; now, calculate # shared dimension sizes so we can expand the necessary variables - def ensure_common_dims(vars): + def ensure_common_dims(vars, concat_dim_lengths): # ensure each variable with the given name shares the same # dimensions and the same shape for all of them except along the # concat dimension @@ 
-553,16 +564,35 @@ def get_indexes(name): data = var.set_dims(dim).values yield PandasIndex(data, dim, coord_dtype=var.dtype) + # create concatenation index, needed for later reindexing + concat_index = list(range(sum(concat_dim_lengths))) + # stack up each variable and/or index to fill-out the dataset (in order) # n.b. this loop preserves variable order, needed for groupby. - for name in datasets[0].variables: + for name in vars_order: if name in concat_over and name not in result_indexes: - try: - vars = ensure_common_dims([ds[name].variable for ds in datasets]) - except KeyError: - raise ValueError(f"{name!r} is not present in all datasets.") - - # Try concatenate the indexes, concatenate the variables when no index + variables = [] + variable_index = [] + var_concat_dim_length = [] + for i, ds in enumerate(datasets): + if name in ds.variables: + variables.append(ds[name].variable) + # add to variable index, needed for reindexing + var_idx = [ + sum(concat_dim_lengths[:i]) + k + for k in range(concat_dim_lengths[i]) + ] + variable_index.extend(var_idx) + var_concat_dim_length.append(len(var_idx)) + else: + # raise if coordinate not in all datasets + if name in coord_names: + raise ValueError( + f"coordinate {name!r} not present in all datasets." + ) + vars = ensure_common_dims(variables, var_concat_dim_length) + + # Try to concatenate the indexes, concatenate the variables when no index # is found on all datasets. indexes: list[Index] = list(get_indexes(name)) if indexes: @@ -589,6 +619,15 @@ def get_indexes(name): combined_var = concat_vars( vars, dim, positions, combine_attrs=combine_attrs ) + # reindex if variable is not present in all datasets + if len(variable_index) < len(concat_index): + combined_var = reindex_variables( + variables={name: combined_var}, + dim_pos_indexers={ + dim: pd.Index(variable_index).get_indexer(concat_index) + }, + fill_value=fill_value, + )[name] result_vars[name] = combined_var elif name in result_vars: diff --git a/xarray/tests/test_concat.py b/xarray/tests/test_concat.py index e0e0038cd89..eb3cd7fccc8 100644 --- a/xarray/tests/test_concat.py +++ b/xarray/tests/test_concat.py @@ -1,7 +1,7 @@ from __future__ import annotations from copy import deepcopy -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, Callable import numpy as np import pandas as pd @@ -23,6 +23,89 @@ from xarray.core.types import CombineAttrsOptions, JoinOptions +# helper method to create multiple tests datasets to concat +def create_concat_datasets( + num_datasets: int = 2, seed: int | None = None, include_day: bool = True +) -> list[Dataset]: + rng = np.random.default_rng(seed) + lat = rng.standard_normal(size=(1, 4)) + lon = rng.standard_normal(size=(1, 4)) + result = [] + variables = ["temperature", "pressure", "humidity", "precipitation", "cloud_cover"] + for i in range(num_datasets): + if include_day: + data_tuple = ( + ["x", "y", "day"], + rng.standard_normal(size=(1, 4, 2)), + ) + data_vars = {v: data_tuple for v in variables} + result.append( + Dataset( + data_vars=data_vars, + coords={ + "lat": (["x", "y"], lat), + "lon": (["x", "y"], lon), + "day": ["day" + str(i * 2 + 1), "day" + str(i * 2 + 2)], + }, + ) + ) + else: + data_tuple = ( + ["x", "y"], + rng.standard_normal(size=(1, 4)), + ) + data_vars = {v: data_tuple for v in variables} + result.append( + Dataset( + data_vars=data_vars, + coords={"lat": (["x", "y"], lat), "lon": (["x", "y"], lon)}, + ) + ) + + return result + + +# helper method to create multiple tests datasets to concat with 
specific types +def create_typed_datasets( + num_datasets: int = 2, seed: int | None = None +) -> list[Dataset]: + var_strings = ["a", "b", "c", "d", "e", "f", "g", "h"] + result = [] + rng = np.random.default_rng(seed) + lat = rng.standard_normal(size=(1, 4)) + lon = rng.standard_normal(size=(1, 4)) + for i in range(num_datasets): + result.append( + Dataset( + data_vars={ + "float": (["x", "y", "day"], rng.standard_normal(size=(1, 4, 2))), + "float2": (["x", "y", "day"], rng.standard_normal(size=(1, 4, 2))), + "string": ( + ["x", "y", "day"], + rng.choice(var_strings, size=(1, 4, 2)), + ), + "int": (["x", "y", "day"], rng.integers(0, 10, size=(1, 4, 2))), + "datetime64": ( + ["x", "y", "day"], + np.arange( + np.datetime64("2017-01-01"), np.datetime64("2017-01-09") + ).reshape(1, 4, 2), + ), + "timedelta64": ( + ["x", "y", "day"], + np.reshape([pd.Timedelta(days=i) for i in range(8)], [1, 4, 2]), + ), + }, + coords={ + "lat": (["x", "y"], lat), + "lon": (["x", "y"], lon), + "day": ["day" + str(i * 2 + 1), "day" + str(i * 2 + 2)], + }, + ) + ) + return result + + def test_concat_compat() -> None: ds1 = Dataset( { @@ -46,14 +129,324 @@ def test_concat_compat() -> None: for var in ["has_x", "no_x_y"]: assert "y" not in result[var].dims and "y" not in result[var].coords - with pytest.raises( - ValueError, match=r"coordinates in some datasets but not others" - ): + with pytest.raises(ValueError, match=r"'q' not present in all datasets"): concat([ds1, ds2], dim="q") - with pytest.raises(ValueError, match=r"'q' is not present in all datasets"): + with pytest.raises(ValueError, match=r"'q' not present in all datasets"): concat([ds2, ds1], dim="q") +def test_concat_missing_var() -> None: + datasets = create_concat_datasets(2, seed=123) + expected = concat(datasets, dim="day") + vars_to_drop = ["humidity", "precipitation", "cloud_cover"] + + expected = expected.drop_vars(vars_to_drop) + expected["pressure"][..., 2:] = np.nan + + datasets[0] = datasets[0].drop_vars(vars_to_drop) + datasets[1] = datasets[1].drop_vars(vars_to_drop + ["pressure"]) + actual = concat(datasets, dim="day") + + assert list(actual.data_vars.keys()) == ["temperature", "pressure"] + assert_identical(actual, expected) + + +def test_concat_missing_multiple_consecutive_var() -> None: + datasets = create_concat_datasets(3, seed=123) + expected = concat(datasets, dim="day") + vars_to_drop = ["humidity", "pressure"] + + expected["pressure"][..., :4] = np.nan + expected["humidity"][..., :4] = np.nan + + datasets[0] = datasets[0].drop_vars(vars_to_drop) + datasets[1] = datasets[1].drop_vars(vars_to_drop) + actual = concat(datasets, dim="day") + + assert list(actual.data_vars.keys()) == [ + "temperature", + "precipitation", + "cloud_cover", + "pressure", + "humidity", + ] + assert_identical(actual, expected) + + +def test_concat_all_empty() -> None: + ds1 = Dataset() + ds2 = Dataset() + expected = Dataset() + actual = concat([ds1, ds2], dim="new_dim") + + assert_identical(actual, expected) + + +def test_concat_second_empty() -> None: + ds1 = Dataset(data_vars={"a": ("y", [0.1])}, coords={"x": 0.1}) + ds2 = Dataset(coords={"x": 0.1}) + + expected = Dataset(data_vars={"a": ("y", [0.1, np.nan])}, coords={"x": 0.1}) + actual = concat([ds1, ds2], dim="y") + assert_identical(actual, expected) + + expected = Dataset( + data_vars={"a": ("y", [0.1, np.nan])}, coords={"x": ("y", [0.1, 0.1])} + ) + actual = concat([ds1, ds2], dim="y", coords="all") + assert_identical(actual, expected) + + # Check concatenating scalar data_var only present in ds1 + 
ds1["b"] = 0.1 + expected = Dataset( + data_vars={"a": ("y", [0.1, np.nan]), "b": ("y", [0.1, np.nan])}, + coords={"x": ("y", [0.1, 0.1])}, + ) + actual = concat([ds1, ds2], dim="y", coords="all", data_vars="all") + assert_identical(actual, expected) + + expected = Dataset( + data_vars={"a": ("y", [0.1, np.nan]), "b": 0.1}, coords={"x": 0.1} + ) + actual = concat([ds1, ds2], dim="y", coords="different", data_vars="different") + assert_identical(actual, expected) + + +def test_concat_multiple_missing_variables() -> None: + datasets = create_concat_datasets(2, seed=123) + expected = concat(datasets, dim="day") + vars_to_drop = ["pressure", "cloud_cover"] + + expected["pressure"][..., 2:] = np.nan + expected["cloud_cover"][..., 2:] = np.nan + + datasets[1] = datasets[1].drop_vars(vars_to_drop) + actual = concat(datasets, dim="day") + + # check the variables orders are the same + assert list(actual.data_vars.keys()) == [ + "temperature", + "pressure", + "humidity", + "precipitation", + "cloud_cover", + ] + + assert_identical(actual, expected) + + +@pytest.mark.parametrize("include_day", [True, False]) +def test_concat_multiple_datasets_missing_vars(include_day: bool) -> None: + vars_to_drop = [ + "temperature", + "pressure", + "humidity", + "precipitation", + "cloud_cover", + ] + + datasets = create_concat_datasets( + len(vars_to_drop), seed=123, include_day=include_day + ) + expected = concat(datasets, dim="day") + + for i, name in enumerate(vars_to_drop): + if include_day: + expected[name][..., i * 2 : (i + 1) * 2] = np.nan + else: + expected[name][i : i + 1, ...] = np.nan + + # set up the test data + datasets = [ds.drop_vars(varname) for ds, varname in zip(datasets, vars_to_drop)] + + actual = concat(datasets, dim="day") + + assert list(actual.data_vars.keys()) == [ + "pressure", + "humidity", + "precipitation", + "cloud_cover", + "temperature", + ] + assert_identical(actual, expected) + + +def test_concat_multiple_datasets_with_multiple_missing_variables() -> None: + vars_to_drop_in_first = ["temperature", "pressure"] + vars_to_drop_in_second = ["humidity", "precipitation", "cloud_cover"] + datasets = create_concat_datasets(2, seed=123) + expected = concat(datasets, dim="day") + for name in vars_to_drop_in_first: + expected[name][..., :2] = np.nan + for name in vars_to_drop_in_second: + expected[name][..., 2:] = np.nan + + # set up the test data + datasets[0] = datasets[0].drop_vars(vars_to_drop_in_first) + datasets[1] = datasets[1].drop_vars(vars_to_drop_in_second) + + actual = concat(datasets, dim="day") + + assert list(actual.data_vars.keys()) == [ + "humidity", + "precipitation", + "cloud_cover", + "temperature", + "pressure", + ] + assert_identical(actual, expected) + + +def test_concat_type_of_missing_fill() -> None: + datasets = create_typed_datasets(2, seed=123) + expected1 = concat(datasets, dim="day", fill_value=dtypes.NA) + expected2 = concat(datasets[::-1], dim="day", fill_value=dtypes.NA) + vars = ["float", "float2", "string", "int", "datetime64", "timedelta64"] + expected = [expected2, expected1] + for i, exp in enumerate(expected): + sl = slice(i * 2, (i + 1) * 2) + exp["float2"][..., sl] = np.nan + exp["datetime64"][..., sl] = np.nan + exp["timedelta64"][..., sl] = np.nan + var = exp["int"] * 1.0 + var[..., sl] = np.nan + exp["int"] = var + var = exp["string"].astype(object) + var[..., sl] = np.nan + exp["string"] = var + + # set up the test data + datasets[1] = datasets[1].drop_vars(vars[1:]) + + actual = concat(datasets, dim="day", fill_value=dtypes.NA) + + 
assert_identical(actual, expected[1]) + + # reversed + actual = concat(datasets[::-1], dim="day", fill_value=dtypes.NA) + + assert_identical(actual, expected[0]) + + +def test_concat_order_when_filling_missing() -> None: + vars_to_drop_in_first: list[str] = [] + # drop middle + vars_to_drop_in_second = ["humidity"] + datasets = create_concat_datasets(2, seed=123) + expected1 = concat(datasets, dim="day") + for name in vars_to_drop_in_second: + expected1[name][..., 2:] = np.nan + expected2 = concat(datasets[::-1], dim="day") + for name in vars_to_drop_in_second: + expected2[name][..., :2] = np.nan + + # set up the test data + datasets[0] = datasets[0].drop_vars(vars_to_drop_in_first) + datasets[1] = datasets[1].drop_vars(vars_to_drop_in_second) + + actual = concat(datasets, dim="day") + + assert list(actual.data_vars.keys()) == [ + "temperature", + "pressure", + "humidity", + "precipitation", + "cloud_cover", + ] + assert_identical(actual, expected1) + + actual = concat(datasets[::-1], dim="day") + + assert list(actual.data_vars.keys()) == [ + "temperature", + "pressure", + "precipitation", + "cloud_cover", + "humidity", + ] + assert_identical(actual, expected2) + + +@pytest.fixture +def concat_var_names() -> Callable: + # create var names list with one missing value + def get_varnames(var_cnt: int = 10, list_cnt: int = 10) -> list[list[str]]: + orig = [f"d{i:02d}" for i in range(var_cnt)] + var_names = [] + for i in range(0, list_cnt): + l1 = orig.copy() + var_names.append(l1) + return var_names + + return get_varnames + + +@pytest.fixture +def create_concat_ds() -> Callable: + def create_ds( + var_names: list[list[str]], + dim: bool = False, + coord: bool = False, + drop_idx: list[int] | None = None, + ) -> list[Dataset]: + out_ds = [] + ds = Dataset() + ds = ds.assign_coords({"x": np.arange(2)}) + ds = ds.assign_coords({"y": np.arange(3)}) + ds = ds.assign_coords({"z": np.arange(4)}) + for i, dsl in enumerate(var_names): + vlist = dsl.copy() + if drop_idx is not None: + vlist.pop(drop_idx[i]) + foo_data = np.arange(48, dtype=float).reshape(2, 2, 3, 4) + dsi = ds.copy() + if coord: + dsi = ds.assign({"time": (["time"], [i * 2, i * 2 + 1])}) + for k in vlist: + dsi = dsi.assign({k: (["time", "x", "y", "z"], foo_data.copy())}) + if not dim: + dsi = dsi.isel(time=0) + out_ds.append(dsi) + return out_ds + + return create_ds + + +@pytest.mark.parametrize("dim", [True, False]) +@pytest.mark.parametrize("coord", [True, False]) +def test_concat_fill_missing_variables( + concat_var_names, create_concat_ds, dim: bool, coord: bool +) -> None: + var_names = concat_var_names() + drop_idx = [0, 7, 6, 4, 4, 8, 0, 6, 2, 0] + + expected = concat( + create_concat_ds(var_names, dim=dim, coord=coord), dim="time", data_vars="all" + ) + for i, idx in enumerate(drop_idx): + if dim: + expected[var_names[0][idx]][i * 2 : i * 2 + 2] = np.nan + else: + expected[var_names[0][idx]][i] = np.nan + + concat_ds = create_concat_ds(var_names, dim=dim, coord=coord, drop_idx=drop_idx) + actual = concat(concat_ds, dim="time", data_vars="all") + + assert list(actual.data_vars.keys()) == [ + "d01", + "d02", + "d03", + "d04", + "d05", + "d06", + "d07", + "d08", + "d09", + "d00", + ] + assert_identical(actual, expected) + + class TestConcatDataset: @pytest.fixture def data(self) -> Dataset: @@ -86,10 +479,17 @@ def test_concat_merge_variables_present_in_some_datasets(self, data) -> None: split_data = [data.isel(dim1=slice(3)), data.isel(dim1=slice(3, None))] data0, data1 = deepcopy(split_data) data1["foo"] = ("bar", 
np.random.randn(10)) - actual = concat([data0, data1], "dim1") + actual = concat([data0, data1], "dim1", data_vars="minimal") expected = data.copy().assign(foo=data1.foo) assert_identical(expected, actual) + # expand foo + actual = concat([data0, data1], "dim1") + foo = np.ones((8, 10), dtype=data1.foo.dtype) * np.nan + foo[3:] = data1.foo.values[None, ...] + expected = data.copy().assign(foo=(["dim1", "bar"], foo)) + assert_identical(expected, actual) + def test_concat_2(self, data) -> None: dim = "dim2" datasets = [g for _, g in data.groupby(dim, squeeze=True)] @@ -776,7 +1176,7 @@ def test_concat_merge_single_non_dim_coord(): actual = concat([da1, da2], "x", coords=coords) assert_identical(actual, expected) - with pytest.raises(ValueError, match=r"'y' is not present in all datasets."): + with pytest.raises(ValueError, match=r"'y' not present in all datasets."): concat([da1, da2], dim="x", coords="all") da1 = DataArray([1, 2, 3], dims="x", coords={"x": [1, 2, 3], "y": 1}) @@ -784,7 +1184,7 @@ def test_concat_merge_single_non_dim_coord(): da3 = DataArray([7, 8, 9], dims="x", coords={"x": [7, 8, 9], "y": 1}) for coords in ["different", "all"]: with pytest.raises(ValueError, match=r"'y' not present in all datasets"): - concat([da1, da2, da3], dim="x") + concat([da1, da2, da3], dim="x", coords=coords) def test_concat_preserve_coordinate_order() -> None:
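
A minimal usage sketch of the behavior the `xarray/core/concat.py` changes above enable, patterned on `test_concat_missing_var`; the dataset names, values, and the `day` coordinate here are illustrative only, and assume an xarray build with this patch applied:

```python
import numpy as np
import xarray as xr

# "pressure" exists only in ds1; before this patch, concat raised
# "ValueError: 'pressure' is not present in all datasets."
ds1 = xr.Dataset(
    {"temperature": ("day", [10.0, 11.0]), "pressure": ("day", [1.0, 2.0])},
    coords={"day": [1, 2]},
)
ds2 = xr.Dataset({"temperature": ("day", [12.0, 13.0])}, coords={"day": [3, 4]})

# With the patch, the missing slots are reindexed and filled with fill_value
# (NaN by default) along the concat dimension.
combined = xr.concat([ds1, ds2], dim="day")
print(combined["pressure"].values)  # [ 1.  2. nan nan]
```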
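The gap filling in `_dataset_concat` rides on `pandas.Index.get_indexer`: positions of `concat_index` that are absent from `variable_index` come back as -1, which `reindex_variables` then fills with `fill_value`. A small standalone sketch of that step, with illustrative values:

```python
import pandas as pd

# Two datasets, each contributing 2 steps along the concat dim; the variable
# appears only in the first dataset, so it occupies positions [0, 1] of [0..3].
concat_index = list(range(4))
variable_index = [0, 1]

# -1 marks the positions that reindex_variables fills with fill_value.
print(pd.Index(variable_index).get_indexer(concat_index))  # [ 0  1 -1 -1]
```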