Non lazy behavior for weighted average when using resampled data #4625
Comments
I fear it's the weight check 🤦. Try commenting out lines 105 to 121: `xarray/core/weighted.py`, line 105 at 255bc8e.

Oh nooo. So would you suggest that, in addition to #4559, we should have a kwarg to completely skip this?

The issue seems to be just this: `xarray/core/weighted.py`, lines 116 to 118 at 180e76d.

Also, the computation is still triggered even if we remove the `map_blocks` call:

```python
weights = weights.copy(data=weights.data)
```

Not sure why, though.
The weighted fix in #4559 is correct; that's why

```python
with ProgressBar():
    mean_func(ds)
```

does not compute. This is more instructive:

```python
from xarray.tests import raise_if_dask_computes

with raise_if_dask_computes():
    ds.resample(time='3AS').map(mean_func)
```

which produces (traceback abridged):
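For readers unfamiliar with `raise_if_dask_computes`: conceptually it installs a scheduler that counts graph executions and raises once a limit is exceeded, which is where the `RuntimeError` below comes from. A minimal pure-Python sketch of that counting pattern (`CountingScheduler` here is a simplified stand-in, not xarray's actual test helper):

```python
# Simplified sketch of a compute-counting scheduler, modeled on the
# error message in the traceback below ("Too many computes. ...").
# This is an illustration, not xarray's real implementation.

class CountingScheduler:
    def __init__(self, max_computes=0):
        self.total_computes = 0
        self.max_computes = max_computes

    def __call__(self, get, dsk, keys, **kwargs):
        # Each call corresponds to one execution of a task graph.
        self.total_computes += 1
        if self.total_computes > self.max_computes:
            raise RuntimeError(
                "Too many computes. Total: %d > max: %d."
                % (self.total_computes, self.max_computes)
            )
        return get(dsk, keys, **kwargs)
```

With dask, such a scheduler would be activated via `dask.config.set(scheduler=...)`; xarray's helper wraps that in a context manager so any hidden `.compute()` inside the `with` block fails loudly.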
```
    150
    151     def _sum_of_weights(

~/work/python/xarray/xarray/core/computation.py in dot(dims, *arrays, **kwargs)
   1483         output_core_dims=output_core_dims,
   1484         join=join,
-> 1485         dask="allowed",
   1486     )
   1487     return result.transpose(*[d for d in all_dims if d in result.dims])

~/work/python/xarray/xarray/core/computation.py in apply_ufunc(func, input_core_dims, output_core_dims, exclude_dims, vectorize, join, dataset_join, dataset_fill_value, keep_attrs, kwargs, dask, output_dtypes, output_sizes, meta, dask_gufunc_kwargs, *args)
   1132         join=join,
   1133         exclude_dims=exclude_dims,
-> 1134         keep_attrs=keep_attrs,
   1135     )
   1136     # feed Variables directly through apply_variable_ufunc

~/work/python/xarray/xarray/core/computation.py in apply_dataarray_vfunc(func, signature, join, exclude_dims, keep_attrs, *args)
    266     else:
    267         name = result_name(args)
--> 268     result_coords = build_output_coords(args, signature, exclude_dims)
    269
    270     data_vars = [getattr(a, "variable", a) for a in args]

~/work/python/xarray/xarray/core/computation.py in build_output_coords(args, signature, exclude_dims)
    231     # TODO: save these merged indexes, instead of re-computing them later
    232     merged_vars, unused_indexes = merge_coordinates_without_align(
--> 233         coords_list, exclude_dims=exclude_dims
    234     )
    235

~/work/python/xarray/xarray/core/merge.py in merge_coordinates_without_align(objects, prioritized, exclude_dims)
    327         filtered = collected
    328
--> 329     return merge_collected(filtered, prioritized)
    330
    331

~/work/python/xarray/xarray/core/merge.py in merge_collected(grouped, prioritized, compat)
    227         variables = [variable for variable, _ in elements_list]
    228         try:
--> 229             merged_vars[name] = unique_variable(name, variables, compat)
    230         except MergeError:
    231             if compat != "minimal":

~/work/python/xarray/xarray/core/merge.py in unique_variable(name, variables, compat, equals)
    132     if equals is None:
    133         # now compare values with minimum number of computes
--> 134         out = out.compute()
    135         for var in variables[1:]:
    136             equals = getattr(out, compat)(var)

~/work/python/xarray/xarray/core/variable.py in compute(self, **kwargs)
    459         """
    460         new = self.copy(deep=False)
--> 461         return new.load(**kwargs)
    462
    463     def __dask_tokenize__(self):

~/work/python/xarray/xarray/core/variable.py in load(self, **kwargs)
    435         """
    436         if is_duck_dask_array(self._data):
--> 437             self._data = as_compatible_data(self._data.compute(**kwargs))
    438         elif not is_duck_array(self._data):
    439             self._data = np.asarray(self._data)

~/miniconda3/envs/dcpy/lib/python3.7/site-packages/dask/base.py in compute(self, **kwargs)
    165         dask.base.compute
    166         """
--> 167         (result,) = compute(self, traverse=False, **kwargs)
    168         return result
    169

~/miniconda3/envs/dcpy/lib/python3.7/site-packages/dask/base.py in compute(*args, **kwargs)
    450         postcomputes.append(x.__dask_postcompute__())
    451
--> 452     results = schedule(dsk, keys, **kwargs)
    453     return repack([f(r, *a) for r, (f, a) in zip(results, postcomputes)])
    454

~/work/python/xarray/xarray/tests/__init__.py in __call__(self, dsk, keys, **kwargs)
    112             raise RuntimeError(
    113                 "Too many computes. Total: %d > max: %d."
--> 114                 % (self.total_computes, self.max_computes)
    115             )
    116         return dask.get(dsk, keys, **kwargs)

RuntimeError: Too many computes. Total: 1 > max: 0.
```

It looks like we're repeatedly checking:

```
ipdb> up
> /home/deepak/work/python/xarray/xarray/core/merge.py(229)merge_collected()
    227         variables = [variable for variable, _ in elements_list]
    228         try:
--> 229             merged_vars[name] = unique_variable(name, variables, compat)
    230         except MergeError:
    231             if compat != "minimal":

ipdb> name
'weights'
ipdb> variables
[<xarray.Variable (time: 1)>
dask.array<getitem, shape=(1,), dtype=float64, chunksize=(1,), chunktype=numpy.ndarray>,
 <xarray.Variable (time: 1)>
dask.array<copy, shape=(1,), dtype=float64, chunksize=(1,), chunktype=numpy.ndarray>]
ipdb> variables[0].data.name
'getitem-2a74b8ca20ae20100597e397404ba17b'
ipdb> variables[1].data.name
'copy-fff901a87f4a2293c750766c554aa68d'
```
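To see why this comparison forces a compute: both dask arrays hold the same values but come from different graphs (`getitem-…` vs `copy-…`), so a cheap comparison of graph names fails and the merge must materialize the data to compare values. A hedged pure-Python sketch of that logic (`LazyArray` and this simplified `unique_variable` are illustrative stand-ins, not xarray's real classes):

```python
# Illustration of the compute trigger seen in the ipdb session above:
# identical values, different task names, so the fast path fails.
# LazyArray and unique_variable are stand-ins, not xarray's real code.

class LazyArray:
    """Stand-in for a dask array: a graph name plus a deferred thunk."""

    def __init__(self, name, thunk):
        self.name = name
        self._thunk = thunk
        self.computed = False  # records whether compute was triggered

    def compute(self):
        self.computed = True
        return self._thunk()


def unique_variable(variables):
    first, *rest = variables
    # Fast path: identical graph names are trivially equal, no compute.
    if all(v.name == first.name for v in rest):
        return first
    # Slow path: names differ, so compare actual values (triggers compute).
    data = first.compute()
    for v in rest:
        if v.compute() != data:
            raise ValueError("conflicting values for variable")
    return first


a = LazyArray("getitem-2a74b8ca", lambda: [1.0])
b = LazyArray("copy-fff901a8", lambda: [1.0])
unique_variable([a, b])
print(a.computed, b.computed)  # both arrays were materialized
```

This is exactly what the `reset_coords(drop=True)` workaround below sidesteps: with no `weights` coordinate left to merge, there is nothing to compare.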
Ah, this works (but we lose …):

```python
# simple customized weighted mean function
def mean_func(ds):
    return ds.weighted(ds.weights.reset_coords(drop=True)).mean('time')
```

Adding `xarray/core/weighted.py`, line 149 at 180e76d.
`dot` compares the `weights` coord var on `ds` and `weights` to decide if it should keep it.

The new call to `map_blocks`:

```python
enc = weights.encoding
weights = DataArray(
    weights.data.map_blocks(_weight_check, dtype=weights.dtype),
    dims=weights.dims,
    coords=weights.coords,
    attrs=weights.attrs,
)
weights.encoding = enc
```

This works locally.
Sweet. I'll try to apply this fix to my workflow now. Happy to submit a PR with the suggested changes to …
PRs are always welcome!
Untested, but specifying …
Do you have a suggestion for how to test this? Should I write a test involving resample + weighted?
Yes, something like what you have with

```python
with raise_if_dask_computes():
    ds.resample(time='3AS').map(mean_func)
```

BUT something is wrong with my explanation above. The error is only triggered when the number of timesteps is not divisible by the resampling frequency. If you set …
Oh, I remember that too, and I didn't understand it at all...
So I have added a test in #4668, and it confirms that this behavior only occurs if the resample interval is smaller than or equal to the chunk size.
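The chunk-alignment observation can be illustrated with plain index arithmetic (no dask involved; `group_spans_chunks` is a hypothetical helper written for this illustration, not part of xarray, and it only sketches the alignment geometry, not the bug itself): a resample group of `freq` consecutive timesteps stays inside one chunk exactly when its first and last indices fall in the same chunk.

```python
# Hypothetical helper illustrating when resample groups line up with
# chunk boundaries: a group of `freq` consecutive timesteps stays
# inside one chunk iff its first and last indices share a chunk index.

def group_spans_chunks(n_steps, chunk, freq):
    """Return True if any group of `freq` steps crosses a chunk boundary."""
    for start in range(0, n_steps, freq):
        end = min(start + freq, n_steps) - 1
        if start // chunk != end // chunk:
            return True
    return False

# 24 monthly steps stored in year-sized chunks of 12:
print(group_spans_chunks(24, 12, 12))  # -> False: annual groups align with chunks
print(group_spans_chunks(24, 12, 9))   # -> True: 9-step groups straddle a boundary
```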
As @dcherian pointed out above …
I am trying to apply an averaging function to multi-year chunks of monthly model data. At its core, the function performs a weighted average (and then some coordinate manipulations). I am using `resample(time='1AS')` and then try to `map` my custom function onto the data (see example below). Without actually loading the data, this step is prohibitively long in my workflow (20-30 min depending on the model). Is there a way to apply this step completely lazily, as in the case where a simple non-weighted `.mean()` is used?

Using resample with a simple mean works without any computation being triggered. But when I do the same step with my custom function, some computations show up. I am quite sure these are the same kind of computations that make my real-world workflow so slow.

I also confirmed that this is not happening when I do not use resample first; that does not trigger a computation either. So this must be somehow related to `resample`? I would be happy to dig deeper into this if somebody with more knowledge could point me to the right place.

Environment:
Output of `xr.show_versions()`:

```
INSTALLED VERSIONS
commit: None
python: 3.8.6 | packaged by conda-forge | (default, Oct 7 2020, 19:08:05)
[GCC 7.5.0]
python-bits: 64
OS: Linux
OS-release: 3.10.0-1160.2.2.el7.x86_64
machine: x86_64
processor: x86_64
byteorder: little
LC_ALL: None
LANG: en_US.UTF-8
LOCALE: en_US.UTF-8
libhdf5: 1.10.6
libnetcdf: 4.7.4

xarray: 0.16.2.dev77+g1a4f7bd
pandas: 1.1.3
numpy: 1.19.2
scipy: 1.5.2
netCDF4: 1.5.4
pydap: None
h5netcdf: 0.8.1
h5py: 2.10.0
Nio: None
zarr: 2.4.0
cftime: 1.2.1
nc_time_axis: 1.2.0
PseudoNetCDF: None
rasterio: 1.1.3
cfgrib: None
iris: None
bottleneck: None
dask: 2.30.0
distributed: 2.30.0
matplotlib: 3.3.2
cartopy: 0.18.0
seaborn: None
numbagg: None
pint: 0.16.1
setuptools: 49.6.0.post20201009
pip: 20.2.4
conda: None
pytest: 6.1.2
IPython: 7.18.1
sphinx: None
```