Skip to content

Commit

Permalink
ENH: fill missing variables during concat by reindexing
Browse files Browse the repository at this point in the history
  • Loading branch information
kmuehlbauer committed Dec 22, 2022
1 parent 80c3e8e commit 3f6206f
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 15 deletions.
62 changes: 51 additions & 11 deletions xarray/core/concat.py
Original file line number Diff line number Diff line change
Expand Up @@ -378,7 +378,9 @@ def process_subset_opt(opt, subset):

elif opt == "all":
concat_over.update(
set(getattr(datasets[0], subset)) - set(datasets[0].dims)
set().union(
*list((set(getattr(d, subset)) - set(d.dims) for d in datasets))
)
)
elif opt == "minimal":
pass
Expand Down Expand Up @@ -553,16 +555,35 @@ def get_indexes(name):
data = var.set_dims(dim).values
yield PandasIndex(data, dim, coord_dtype=var.dtype)

# preserve variable order for variables in first dataset
data_var_order = list(datasets[0].variables)
# append additional variables to the end
data_var_order += [e for e in data_names if e not in data_var_order]
# create concatenation index, needed for later reindexing
concat_index = list(range(sum(concat_dim_lengths)))

# stack up each variable and/or index to fill-out the dataset (in order)
# n.b. this loop preserves variable order, needed for groupby.
for name in datasets[0].variables:
for name in data_var_order:
if name in concat_over and name not in result_indexes:
try:
vars = ensure_common_dims([ds[name].variable for ds in datasets])
except KeyError:
raise ValueError(f"{name!r} is not present in all datasets.")

# Try concatenate the indexes, concatenate the variables when no index
variables = []
variable_index = []
for i, ds in enumerate(datasets):
if name in ds.variables:
variables.append(ds.variables[name])
# add to variable index, needed for reindexing
variable_index.extend(
[sum(concat_dim_lengths[:i]) + k for k in range(concat_dim_lengths[i])]
)
else:
# raise if coordinate not in all datasets
if name in coord_names:
raise ValueError(
f"coordinate {name!r} not present in all datasets."
)
vars = list(ensure_common_dims(variables))

# Try to concatenate the indexes, concatenate the variables when no index
# is found on all datasets.
indexes: list[Index] = list(get_indexes(name))
if indexes:
Expand All @@ -586,9 +607,28 @@ def get_indexes(name):
)
result_vars[k] = v
else:
combined_var = concat_vars(
vars, dim, positions, combine_attrs=combine_attrs
)
# if variable is only present in one dataset of multiple datasets,
# then do not concat
if len(variables) == 1 and len(datasets) > 1:
combined_var = variables[0]
# only concat if variable is in multiple datasets
# or if single dataset (GH1988)
else:
combined_var = concat_vars(
vars, dim, positions, combine_attrs=combine_attrs
)
# reindex if variable is not present in all datasets
if len(variable_index) < len(concat_index):
try:
fill = fill_value[name]
except (TypeError, KeyError):
fill = fill_value
combined_var = (
DataArray(data=combined_var, name=name)
.assign_coords({dim: variable_index})
.reindex({dim: concat_index}, fill_value=fill)
.variable
)
result_vars[name] = combined_var

elif name in result_vars:
Expand Down
6 changes: 2 additions & 4 deletions xarray/tests/test_concat.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,6 @@ def test_concat_compat() -> None:
ValueError, match=r"coordinates in some datasets but not others"
):
concat([ds1, ds2], dim="q")
with pytest.raises(ValueError, match=r"'q' is not present in all datasets"):
concat([ds2, ds1], dim="q")


class TestConcatDataset:
Expand Down Expand Up @@ -776,15 +774,15 @@ def test_concat_merge_single_non_dim_coord():
actual = concat([da1, da2], "x", coords=coords)
assert_identical(actual, expected)

with pytest.raises(ValueError, match=r"'y' is not present in all datasets."):
with pytest.raises(ValueError, match=r"'y' not present in all datasets."):
concat([da1, da2], dim="x", coords="all")

da1 = DataArray([1, 2, 3], dims="x", coords={"x": [1, 2, 3], "y": 1})
da2 = DataArray([4, 5, 6], dims="x", coords={"x": [4, 5, 6]})
da3 = DataArray([7, 8, 9], dims="x", coords={"x": [7, 8, 9], "y": 1})
for coords in ["different", "all"]:
with pytest.raises(ValueError, match=r"'y' not present in all datasets"):
concat([da1, da2, da3], dim="x")
concat([da1, da2, da3], dim="x", coords=coords)


def test_concat_preserve_coordinate_order() -> None:
Expand Down

0 comments on commit 3f6206f

Please sign in to comment.