-
-
Notifications
You must be signed in to change notification settings - Fork 1.1k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add defaults during concat 508 #3545
Changes from 12 commits
baeebed
a96583b
af347e7
418c538
9e35c84
f7124a3
47f7e4d
df3693e
c21dcd4
4e01bd9
515b9c1
cf5b8bd
3bf3931
03f9b3b
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,7 +1,9 @@ | ||
import pandas as pd | ||
from collections import OrderedDict | ||
|
||
from . import dtypes, utils | ||
from .alignment import align | ||
from .common import full_like | ||
from .duck_array_ops import lazy_array_equiv | ||
from .merge import _VALID_COMPAT, unique_variable | ||
from .variable import IndexVariable, Variable, as_variable | ||
|
@@ -26,7 +28,7 @@ def concat( | |
xarray objects to concatenate together. Each object is expected to | ||
consist of variables and coordinates with matching shapes except for | ||
along the concatenated dimension. | ||
dim : str or DataArray or pandas.Index | ||
dim : str, DataArray, Variable, or pandas.Index | ||
Name of the dimension to concatenate along. This can either be a new | ||
dimension name, in which case it is added along axis=0, or an existing | ||
dimension name, in which case the location of the dimension is | ||
|
@@ -77,7 +79,8 @@ def concat( | |
to assign each dataset along the concatenated dimension. If not | ||
supplied, objects are concatenated in the provided order. | ||
fill_value : scalar, optional | ||
Value to use for newly missing values | ||
Value to use for newly missing values as well as to fill values where the | ||
variable is not present in all datasets. | ||
join : {'outer', 'inner', 'left', 'right', 'exact'}, optional | ||
String indicating how to combine differing indexes | ||
(excluding dim) in objects | ||
|
@@ -129,6 +132,7 @@ def concat( | |
"can only concatenate xarray Dataset and DataArray " | ||
"objects, got %s" % type(first_obj) | ||
) | ||
|
||
return f(objs, dim, data_vars, coords, compat, positions, fill_value, join) | ||
|
||
|
||
|
@@ -366,25 +370,89 @@ def ensure_common_dims(vars): | |
var = var.set_dims(common_dims, common_shape) | ||
yield var | ||
|
||
# stack up each variable to fill-out the dataset (in order) | ||
# n.b. this loop preserves variable order, needed for groupby. | ||
for k in datasets[0].variables: | ||
if k in concat_over: | ||
try: | ||
vars = ensure_common_dims([ds.variables[k] for ds in datasets]) | ||
except KeyError: | ||
raise ValueError("%r is not present in all datasets." % k) | ||
# Find union of all data variables (preserving order) | ||
# assumes all datasets are relatively in the same order | ||
# and missing variables are inserted in the correct position | ||
# if datasets have variables in drastically different orders | ||
# the resulting order will be dependent on the order they are in the list | ||
# passed to concat | ||
union_of_variables = OrderedDict() | ||
union_of_coordinates = OrderedDict() | ||
for ds in datasets: | ||
var_list = list(ds.variables.keys()) | ||
|
||
_find_ordering_inplace(var_list, union_of_variables) | ||
|
||
# check that all datasets have the same coordinate set | ||
if len(union_of_coordinates) > 0: | ||
coord_set_diff = ( | ||
union_of_coordinates.keys() ^ ds.coords.keys() | ||
) & concat_over | ||
if len(coord_set_diff) > 0: | ||
raise ValueError( | ||
"Variables %r are coordinates in some datasets but not others." | ||
% coord_set_diff | ||
) | ||
|
||
union_of_coordinates = OrderedDict( | ||
union_of_coordinates.items() | OrderedDict.fromkeys(ds.coords).items() | ||
) | ||
|
||
# we don't want to fill coordinate variables so remove them | ||
for k in union_of_coordinates.keys(): | ||
union_of_variables.pop(k, None) | ||
|
||
# Cache a filled tmp variable with correct dims for filling missing variables | ||
# doing this here allows us to concat with variables missing from any dataset | ||
# only will run until it finds one prototype for each variable in concat list | ||
# we will also only fill defaults for data_vars not coordinates | ||
|
||
# optimization to allow us to break when filling variable | ||
def find_fill_variable_from_ds(variable_key, union_of_variables, datasets): | ||
for ds in datasets: | ||
if union_of_variables[variable_key] is not None: | ||
continue | ||
|
||
if variable_key not in ds.variables: | ||
continue | ||
|
||
v_fill_value = fill_value | ||
dtype, v_fill_value = dtypes.get_fill_value_for_variable( | ||
ds[variable_key], fill_value | ||
) | ||
|
||
union_of_variables[variable_key] = full_like( | ||
ds[variable_key], fill_value=v_fill_value, dtype=dtype | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This need to be There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thanks for the feedback and the above test. I'll try to incorporate your suggested test as well as the rest of the pending comments in the next update. |
||
) | ||
return | ||
|
||
for v in union_of_variables.keys(): | ||
find_fill_variable_from_ds(v, union_of_variables, datasets) | ||
|
||
# create the concat list filling in missing variables | ||
while len(union_of_variables) > 0 or len(union_of_coordinates) > 0: | ||
k = None | ||
# get the variables in order | ||
if len(union_of_variables) > 0: | ||
k = union_of_variables.popitem(last=False) | ||
elif len(union_of_coordinates) > 0: | ||
k = union_of_coordinates.popitem() | ||
|
||
if k[0] in concat_over: | ||
variables = [] | ||
for ds in datasets: | ||
if k[0] in ds.variables: | ||
variables.append(ds.variables[k[0]]) | ||
else: | ||
# var is missing, fill with cached value | ||
variables.append(k[1]) | ||
|
||
vars = ensure_common_dims(variables) | ||
combined = concat_vars(vars, dim, positions) | ||
assert isinstance(combined, Variable) | ||
result_vars[k] = combined | ||
result_vars[k[0]] = combined | ||
|
||
result = Dataset(result_vars, attrs=result_attrs) | ||
absent_coord_names = coord_names - set(result.variables) | ||
if absent_coord_names: | ||
raise ValueError( | ||
"Variables %r are coordinates in some datasets but not others." | ||
% absent_coord_names | ||
) | ||
result = result.set_coords(coord_names) | ||
result.encoding = result_encoding | ||
|
||
|
@@ -397,6 +465,26 @@ def ensure_common_dims(vars): | |
return result | ||
|
||
|
||
def _find_ordering_inplace(l, union): | ||
# this logic maintains the order of the variable list and runs in | ||
# O(n^2) where n is number of variables in the uncommon worst case | ||
# where there are no missing variables this will be O(n) | ||
# could potentially be refactored to a more generic function to determine | ||
# a consistent ordering of variables if proper consideration were | ||
# given both to the runtime as well as to the user scenarios | ||
for i in range(0, len(l)): | ||
if l[i] not in union: | ||
# need to determine the correct place | ||
# first add the new item which will be at the end | ||
union[l[i]] = None | ||
union.move_to_end(l[i]) | ||
# move any items after this in the variables list to the end | ||
# this will only happen for missing variables | ||
for j in range(i + 1, len(l)): | ||
if l[j] in union: | ||
union.move_to_end(l[j]) | ||
|
||
|
||
def _dataarray_concat( | ||
arrays, | ||
dim, | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
just plain
dict
should be fine now since we are python 3.6+
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ok, I didn't realize that it was 3.6+ only. Will change to dict.