Add defaults during concat 508 #3545

Closed
wants to merge 14 commits
2 changes: 2 additions & 0 deletions doc/whats-new.rst
@@ -100,6 +100,8 @@ Bug fixes
(:issue:`3402`). By `Deepak Cherian <https://github.com/dcherian/>`_
- Allow appending datetime and bool data variables to zarr stores.
(:issue:`3480`). By `Akihiro Matsukawa <https://github.com/amatsukawa/>`_.
- Make :py:func:`~xarray.concat` more robust when concatenating variables present in some datasets but
not others (:issue:`508`). By `Scott Chamberlin <https://github.com/scottcha>`_.

Documentation
~~~~~~~~~~~~~
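As an illustrative sketch of the behavior this changelog entry describes (the variable names are hypothetical, and the exact semantics depend on this branch):

```python
import xarray as xr

ds1 = xr.Dataset({"a": ("x", [1, 2]), "b": ("x", [3, 4])}, coords={"x": [0, 1]})
ds2 = xr.Dataset({"a": ("x", [5, 6])}, coords={"x": [2, 3]})  # "b" is absent

# Previously this raised "'b' is not present in all datasets."; with this
# change, "b" would instead be filled with fill_value (NaN by default,
# after dtype promotion) where it is missing.
combined = xr.concat([ds1, ds2], dim="x")
```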
104 changes: 82 additions & 22 deletions xarray/core/concat.py
@@ -1,7 +1,9 @@
import pandas as pd
from collections import OrderedDict
Review comment (Contributor):
Just plain dict should be fine now since we are Python 3.6+.

Reply (Author):
Ok, I didn't realize that it was 3.6+ only. Will change to dict.
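As a side note on this exchange, a minimal sketch (not part of the PR) of the ordering behavior being discussed:

```python
# Plain dicts remember insertion order in CPython 3.6 (an implementation
# detail there) and the language guarantees it from 3.7 onward, so
# OrderedDict is no longer needed just for ordering.
union = dict.fromkeys(["temp", "pressure", "temp"])  # dedupes, keeps order
print(list(union))  # ['temp', 'pressure']
```

One caveat: the loop later in this diff relies on `OrderedDict.popitem(last=False)`, which plain dict does not offer, so that call site would also need rewriting as part of the switch.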


from . import dtypes, utils
from .alignment import align
from .common import full_like
from .duck_array_ops import lazy_array_equiv
from .merge import _VALID_COMPAT, unique_variable
from .variable import IndexVariable, Variable, as_variable
@@ -26,7 +28,7 @@ def concat(
xarray objects to concatenate together. Each object is expected to
consist of variables and coordinates with matching shapes except for
along the concatenated dimension.
dim : str or DataArray or pandas.Index
dim : str, DataArray, Variable, or pandas.Index
Name of the dimension to concatenate along. This can either be a new
dimension name, in which case it is added along axis=0, or an existing
dimension name, in which case the location of the dimension is
@@ -77,7 +79,8 @@
to assign each dataset along the concatenated dimension. If not
supplied, objects are concatenated in the provided order.
fill_value : scalar, optional
Value to use for newly missing values
Value to use for newly missing values as well as to fill values where the
variable is not present in all datasets.
join : {'outer', 'inner', 'left', 'right', 'exact'}, optional
String indicating how to combine differing indexes
(excluding dim) in objects
@@ -129,6 +132,7 @@
"can only concatenate xarray Dataset and DataArray "
"objects, got %s" % type(first_obj)
)

return f(objs, dim, data_vars, coords, compat, positions, fill_value, join)


@@ -261,21 +265,21 @@ def _parse_datasets(datasets):

dims = set()
all_coord_names = set()
data_vars = set() # list of data_vars
data_vars = {} # data_var names; a dict (rather than a set) is used to preserve insertion order
dim_coords = {} # maps dim name to variable
dims_sizes = {} # shared dimension sizes to expand variables

for ds in datasets:
dims_sizes.update(ds.dims)
all_coord_names.update(ds.coords)
data_vars.update(ds.data_vars)
data_vars.update(dict.fromkeys(ds.data_vars))

for dim in set(ds.dims) - dims:
if dim not in dim_coords:
dim_coords[dim] = ds.coords[dim].variable
dims = dims | set(ds.dims)

return dim_coords, dims_sizes, all_coord_names, data_vars
return dim_coords, dims_sizes, all_coord_names, list(data_vars.keys())
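A small sketch (with invented variable names) of why a dict gives _parse_datasets an order-preserving union where a set would not:

```python
# Each update() appends only unseen names, so the union keeps
# first-seen order across datasets; a set would lose that order.
data_vars = {}
data_vars.update(dict.fromkeys(["t2m", "precip"]))   # from the first dataset
data_vars.update(dict.fromkeys(["precip", "wind"]))  # from the second dataset
print(list(data_vars))  # ['t2m', 'precip', 'wind']
```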


def _dataset_concat(
@@ -304,7 +308,7 @@
dim_names = set(dim_coords)
unlabeled_dims = dim_names - coord_names

both_data_and_coords = coord_names & data_names
both_data_and_coords = coord_names & set(data_names)
if both_data_and_coords:
raise ValueError(
"%r is a coordinate in some datasets but not others." % both_data_and_coords
@@ -323,7 +327,7 @@
)

# determine which variables to merge, and then merge them according to compat
variables_to_merge = (coord_names | data_names) - concat_over - dim_names
variables_to_merge = (coord_names | set(data_names)) - concat_over - dim_names

result_vars = {}
if variables_to_merge:
@@ -366,25 +370,81 @@ def ensure_common_dims(vars):
var = var.set_dims(common_dims, common_shape)
yield var

# stack up each variable to fill-out the dataset (in order)
# n.b. this loop preserves variable order, needed for groupby.
for k in datasets[0].variables:
if k in concat_over:
try:
vars = ensure_common_dims([ds.variables[k] for ds in datasets])
except KeyError:
raise ValueError("%r is not present in all datasets." % k)
# Find the union of all data variables, preserving order.
# This assumes the datasets list their variables in roughly the same
# order, so missing variables can be inserted at sensible positions;
# if the datasets order their variables very differently, the result
# depends on the order of the datasets in the list passed to concat.
data_var_order = list(datasets[0].data_vars)
data_var_order += [e for e in data_names if e not in data_var_order]

union_of_variables = OrderedDict.fromkeys(data_var_order)
union_of_coordinates = OrderedDict.fromkeys(coord_names)

# we don't want to fill coordinate variables so remove them
for k in union_of_coordinates.keys():
union_of_variables.pop(k, None)

# Cache a filled temporary variable with the correct dims for each
# missing variable; doing this here lets concat proceed even when a
# variable is missing from some datasets. The search stops once one
# prototype is found per variable in the concat list, and defaults are
# only filled for data_vars, not coordinates.

# helper so the per-variable search can return early once a prototype is found
def find_fill_variable_from_ds(variable_key, union_of_variables, datasets):
for ds in datasets:
if union_of_variables[variable_key] is not None:
continue

if variable_key not in ds.variables:
continue

dtype, v_fill_value = dtypes.get_fill_value_for_variable(
ds[variable_key], fill_value
)

union_of_variables[variable_key] = full_like(
ds.variables[variable_key], fill_value=v_fill_value, dtype=dtype
)
return

for v in union_of_variables.keys():
find_fill_variable_from_ds(v, union_of_variables, datasets)

# create the concat list filling in missing variables
filling_coordinates = False
while len(union_of_variables) > 0 or len(union_of_coordinates) > 0:
k = None
# get the variables in order
if len(union_of_variables) > 0:
k = union_of_variables.popitem(last=False)
elif len(union_of_coordinates) > 0:
filling_coordinates = True
k = union_of_coordinates.popitem()

if k[0] in concat_over:
variables = []
for ds in datasets:
if k[0] in ds.variables:
variables.append(ds.variables[k[0]])
else:
if filling_coordinates:
# in this case the coordinate is missing from a dataset
raise ValueError(
"Variables %r are coordinates in some datasets but not others."
% k[0]
)
# var is missing, fill with cached value
variables.append(k[1])

vars = ensure_common_dims(variables)
combined = concat_vars(vars, dim, positions)
assert isinstance(combined, Variable)
result_vars[k] = combined
result_vars[k[0]] = combined

result = Dataset(result_vars, attrs=result_attrs)
absent_coord_names = coord_names - set(result.variables)
if absent_coord_names:
raise ValueError(
"Variables %r are coordinates in some datasets but not others."
% absent_coord_names
)
result = result.set_coords(coord_names)
result.encoding = result_encoding

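To make the prototype-caching step above concrete, a minimal standalone sketch (with invented data) of building one placeholder per missing variable via full_like:

```python
import numpy as np
import xarray as xr

# Take the variable from the first dataset that has it as a template.
template = xr.Variable(("x",), np.array([1, 2], dtype="int64"))

# Promote int64 -> float64 so NaN is representable, then build the
# placeholder that gets spliced in wherever the variable is missing.
placeholder = xr.full_like(template, fill_value=np.nan, dtype=np.float64)
print(placeholder.dtype)  # float64
```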
32 changes: 32 additions & 0 deletions xarray/core/dtypes.py
@@ -4,6 +4,7 @@

from . import utils


# Use as a sentinel value to indicate a dtype appropriate NA value.
NA = utils.ReprObject("<NA>")

@@ -96,6 +97,37 @@ def get_fill_value(dtype):
return fill_value


def get_fill_value_for_variable(variable, fill_value=NA):
"""Return an appropriate fill value for this variable

Parameters
----------
variable : Dataset or DataArray
fill_value : a suggested fill value to evaluate and promote if necessary

Returns
-------
dtype : Promoted dtype for fill value.
new_fill_value : Missing value corresponding to this dtype.
"""
from .dataset import Dataset
from .dataarray import DataArray

if not isinstance(variable, (DataArray, Dataset)):
raise TypeError(
"can only get fill value for xarray Dataset and DataArray "
"objects, got %s" % type(variable)
)

new_fill_value = fill_value
if fill_value is NA:
dtype, new_fill_value = maybe_promote(variable.dtype)
else:
dtype = variable.dtype

return dtype, new_fill_value


def get_pos_infinity(dtype):
"""Return an appropriate positive infinity for this dtype.

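For reference, a quick sketch of the maybe_promote helper that get_fill_value_for_variable defers to when fill_value is the NA sentinel (note that xarray.core.dtypes is internal API):

```python
import numpy as np
from xarray.core.dtypes import maybe_promote

# Integer dtypes cannot represent NaN, so the dtype is promoted and a
# dtype-appropriate missing value is returned alongside it.
dtype, fill = maybe_promote(np.dtype("int64"))
print(dtype, fill)  # float64 nan
```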
4 changes: 3 additions & 1 deletion xarray/tests/test_combine.py
@@ -755,7 +755,9 @@ def test_auto_combine(self):
auto_combine(objs)

objs = [Dataset({"x": [0], "y": [0]}), Dataset({"x": [0]})]
with raises_regex(ValueError, "'y' is not present in all datasets"):
with raises_regex(
ValueError, ".* are coordinates in some datasets but not others"
):
auto_combine(objs)

def test_auto_combine_previously_failed(self):