diff --git a/xarray/core/concat.py b/xarray/core/concat.py index 2eea2ecb3ee..d709b296e28 100644 --- a/xarray/core/concat.py +++ b/xarray/core/concat.py @@ -378,7 +378,9 @@ def process_subset_opt(opt, subset): elif opt == "all": concat_over.update( - set(getattr(datasets[0], subset)) - set(datasets[0].dims) + set().union( + *list((set(getattr(d, subset)) - set(d.dims) for d in datasets)) + ) ) elif opt == "minimal": pass @@ -553,16 +555,35 @@ def get_indexes(name): data = var.set_dims(dim).values yield PandasIndex(data, dim, coord_dtype=var.dtype) + # preserve variable order for variables in first dataset + data_var_order = list(datasets[0].variables) + # append additional variables to the end + data_var_order += [e for e in data_names if e not in data_var_order] + # create concatenation index, needed for later reindexing + concat_index = list(range(sum(concat_dim_lengths))) + # stack up each variable and/or index to fill-out the dataset (in order) # n.b. this loop preserves variable order, needed for groupby. - for name in datasets[0].variables: + for name in data_var_order: if name in concat_over and name not in result_indexes: - try: - vars = ensure_common_dims([ds[name].variable for ds in datasets]) - except KeyError: - raise ValueError(f"{name!r} is not present in all datasets.") - - # Try concatenate the indexes, concatenate the variables when no index + variables = [] + variable_index = [] + for i, ds in enumerate(datasets): + if name in ds.variables: + variables.append(ds.variables[name]) + # add to variable index, needed for reindexing + variable_index.extend( + [sum(concat_dim_lengths[:i]) + k for k in range(concat_dim_lengths[i])] + ) + else: + # raise if coordinate not in all datasets + if name in coord_names: + raise ValueError( + f"coordinate {name!r} not present in all datasets." + ) + vars = list(ensure_common_dims(variables)) + + # Try to concatenate the indexes, concatenate the variables when no index # is found on all datasets. indexes: list[Index] = list(get_indexes(name)) if indexes: @@ -586,9 +607,28 @@ def get_indexes(name): ) result_vars[k] = v else: - combined_var = concat_vars( - vars, dim, positions, combine_attrs=combine_attrs - ) + # if variable is only present in one dataset of multiple datasets, + # then do not concat + if len(variables) == 1 and len(datasets) > 1: + combined_var = variables[0] + # only concat if variable is in multiple datasets + # or if single dataset (GH1988) + else: + combined_var = concat_vars( + vars, dim, positions, combine_attrs=combine_attrs + ) + # reindex if variable is not present in all datasets + if len(variable_index) < len(concat_index): + try: + fill = fill_value[name] + except (TypeError, KeyError): + fill = fill_value + combined_var = ( + DataArray(data=combined_var, name=name) + .assign_coords({dim: variable_index}) + .reindex({dim: concat_index}, fill_value=fill) + .variable + ) result_vars[name] = combined_var elif name in result_vars: diff --git a/xarray/tests/test_concat.py b/xarray/tests/test_concat.py index e0e0038cd89..186ad5b0865 100644 --- a/xarray/tests/test_concat.py +++ b/xarray/tests/test_concat.py @@ -50,8 +50,6 @@ def test_concat_compat() -> None: ValueError, match=r"coordinates in some datasets but not others" ): concat([ds1, ds2], dim="q") - with pytest.raises(ValueError, match=r"'q' is not present in all datasets"): - concat([ds2, ds1], dim="q") class TestConcatDataset: @@ -776,7 +774,7 @@ def test_concat_merge_single_non_dim_coord(): actual = concat([da1, da2], "x", coords=coords) assert_identical(actual, expected) - with pytest.raises(ValueError, match=r"'y' is not present in all datasets."): + with pytest.raises(ValueError, match=r"'y' not present in all datasets."): concat([da1, da2], dim="x", coords="all") da1 = DataArray([1, 2, 3], dims="x", coords={"x": [1, 2, 3], "y": 1}) @@ -784,7 +782,7 @@ def test_concat_merge_single_non_dim_coord(): da3 = DataArray([7, 8, 9], dims="x", coords={"x": [7, 8, 9], "y": 1}) for coords in ["different", "all"]: with pytest.raises(ValueError, match=r"'y' not present in all datasets"): - concat([da1, da2, da3], dim="x") + concat([da1, da2, da3], dim="x", coords=coords) def test_concat_preserve_coordinate_order() -> None: