Dask-friendly nan check in xr.corr() and xr.cov() #5284
Conversation
Hi @AndrewWilliams3142, thanks for another PR! This looks good. Could we add a test like https://github.com/pydata/xarray/pull/4559/files#diff-74d2dc289aa601b2de094fb3a3b687fd65963401b51b95cc5e0afcd06cc4cb82R45? And maybe reference this or #4559 as an explanation of what's going on.
`xarray/core/computation.py` (Outdated)

```python
else:
    _are_there_nans = _nan_check(valid_values.data)

if not _are_there_nans:
```
Thanks @AndrewWilliams3142.

Unfortunately, I don't think you can avoid this for dask arrays. Fundamentally, `bool(dask_array)` will call compute, so this line (`if not _are_there_nans`) will call compute.

I think the only solution here is to special-case this optimization for non-dask DataArrays and always call `.where(valid_values)` for dask arrays, but maybe someone else has an idea.
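A minimal sketch of that special-casing on plain numpy/dask arrays (the helper name `mask_invalid` is illustrative, not an xarray internal):

```python
import numpy as np
import dask.array as dsa

def mask_invalid(arr, valid_values):
    """Mask invalid positions, staying lazy for dask-backed input."""
    if isinstance(arr, dsa.Array):
        # Dask case: `bool(valid_values.all())` would trigger a compute,
        # so skip the check and always apply the (lazy) mask.
        return dsa.where(valid_values, arr, np.nan)
    # NumPy case: the check is cheap, so keep the fast path.
    if valid_values.all():
        return arr
    return np.where(valid_values, arr, np.nan)

arr = np.array([1.0, np.nan, 3.0])
eager = mask_invalid(arr, ~np.isnan(arr))    # plain ndarray, evaluated eagerly

darr = dsa.from_array(arr, chunks=2)
lazy = mask_invalid(darr, ~dsa.isnan(darr))  # still a dask array, nothing computed
```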
OK, I guess you could stick the

```python
valid_values = ...
if valid_values.any():
    ...
```

in a helper function and `map_blocks` that.
Hmm, ok. Silly question, but does it matter that `_are_there_nans` is a scalar reduction of `valid_values`? So `valid_values` is still left uncomputed. Would this help the bottleneck at all?
No, because the program needs to know `valid_values.any()` to decide which branch of the if/else statement to execute. So it gets computed.
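That behaviour is easy to demonstrate on a toy dask array (an illustrative sketch, not xarray code):

```python
import numpy as np
import dask.array as dsa

x = dsa.from_array(np.array([1.0, np.nan, 3.0]), chunks=2)
valid_values = ~dsa.isnan(x)

# The reduction itself is lazy: `flag` is still a 0-d dask array.
flag = valid_values.any()
print(type(flag))

# But branching on it calls bool(flag), which evaluates the graph.
if flag:
    masked = dsa.where(valid_values, x, np.nan)
```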
Hi @dcherian, just thinking about your suggestion of putting the check in a helper function and using `map_blocks`. Here's a toy example:

```python
import numpy as np
import pandas as pd
import xarray as xr

da_a = xr.DataArray(
    np.array([[1, 2, 3, 4], [1, 0.1, 0.2, 0.3], [2, 3.2, 0.6, 1.8]]),
    dims=("space", "time"),
    coords=[
        ("space", ["IA", "IL", "IN"]),
        ("time", pd.date_range("2000-01-01", freq="1D", periods=4)),
    ],
).chunk({"time": 1})

da_b = xr.DataArray(
    np.array([[0.2, 0.4, 0.6, 2], [15, 10, 5, 1], [1, 3.2, np.nan, 1.8]]),
    dims=("space", "time"),
    coords=[
        ("space", ["IA", "IL", "IN"]),
        ("time", pd.date_range("2000-01-01", freq="1D", periods=4)),
    ],
).chunk({"time": 1})

print(da_a)
# <xarray.DataArray (space: 3, time: 4)>
# array([[1. , 2. , 3. , 4. ],
#        [1. , 0.1, 0.2, 0.3],
#        [2. , 3.2, 0.6, 1.8]])
# Coordinates:
#   * space    (space) <U2 'IA' 'IL' 'IN'
#   * time     (time) datetime64[ns] 2000-01-01 2000-01-02 2000-01-03 2000-01-04

print(da_b)
# <xarray.DataArray (space: 3, time: 4)>
# array([[ 0.2,  0.4,  0.6,  2. ],
#        [15. , 10. ,  5. ,  1. ],
#        [ 1. ,  3.2,  nan,  1.8]])
# Coordinates:
#   * space    (space) <U2 'IA' 'IL' 'IN'
#   * time     (time) datetime64[ns] 2000-01-01 2000-01-02 2000-01-03 2000-01-04

# Define the function to use in map_blocks
def _get_valid_values(da, other):
    da1, da2 = xr.align(da, other, join="inner", copy=False)
    # Ignore the nans
    missing_vals = np.logical_or(da1.isnull(), da2.isnull())
    if missing_vals.any():
        da = da.where(missing_vals)
        return da
    else:
        return da

# Test it
outp = da_a.map_blocks(_get_valid_values, args=[da_b])
print(outp.compute())
# <xarray.DataArray (space: 3, time: 4)>
# array([[1. , 2. , nan, 4. ],
#        [1. , 0.1, nan, 0.3],
#        [2. , 3.2, 0.6, 1.8]])
# Coordinates:
#   * time     (time) datetime64[ns] 2000-01-01 2000-01-02 2000-01-03 2000-01-04
#   * space    (space) object 'IA' 'IL' 'IN'
```
Well that was confusing! I discovered this with

For better performance, we should try a
Thanks for that @dcherian! I didn't know you could use print debugging on chunked operations like this! One thing actually: if I change
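For reference, the masking in the helper above appears inverted: `da.where(missing_vals)` keeps exactly the positions where a value is missing. A sketch of the helper with the mask flipped (same toy data), so that only positions valid in *both* inputs survive:

```python
import numpy as np
import pandas as pd
import xarray as xr

def _get_valid_values(da, other):
    da1, da2 = xr.align(da, other, join="inner", copy=False)
    missing_vals = np.logical_or(da1.isnull(), da2.isnull())
    if missing_vals.any():
        # Note the ~: keep values where *neither* input is nan
        return da.where(~missing_vals)
    return da

da_a = xr.DataArray(
    np.array([[1, 2, 3, 4], [1, 0.1, 0.2, 0.3], [2, 3.2, 0.6, 1.8]]),
    dims=("space", "time"),
    coords=[
        ("space", ["IA", "IL", "IN"]),
        ("time", pd.date_range("2000-01-01", freq="1D", periods=4)),
    ],
).chunk({"time": 1})

da_b = xr.DataArray(
    np.array([[0.2, 0.4, 0.6, 2], [15, 10, 5, 1], [1, 3.2, np.nan, 1.8]]),
    dims=("space", "time"),
    coords=[
        ("space", ["IA", "IL", "IN"]),
        ("time", pd.date_range("2000-01-01", freq="1D", periods=4)),
    ],
).chunk({"time": 1})

outp = da_a.map_blocks(_get_valid_values, args=[da_b]).compute()
# Now only the position where da_b has a nan is masked:
# [[1. , 2. , 3. , 4. ],
#  [1. , 0.1, 0.2, 0.3],
#  [2. , 3.2, nan, 1.8]]
```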
Hey both, I've added a test to check that dask doesn't compute when calling either.

@dcherian, regarding the
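One way to write such a "dask must not compute" test is with a plain dask callback (a sketch; xarray's test suite has a similar helper for this):

```python
import dask.array as dsa
from dask.callbacks import Callback

class RaiseIfComputes(Callback):
    """Raise if any dask computation starts while this context is active."""

    def _start(self, dsk):
        raise RuntimeError("dask computed; the operation should have stayed lazy")

x = dsa.ones(4, chunks=2)

# Building the graph alone triggers no compute, so this passes...
with RaiseIfComputes():
    total = x.sum()

# ...while an actual compute inside the guard raises.
try:
    with RaiseIfComputes():
        total.compute()
except RuntimeError as err:
    print("caught:", err)
```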
Co-authored-by: Deepak Cherian <[email protected]>
Thanks @AndrewWilliams3142!

Please feel free to add a whatsnew @AndrewWilliams3142.

Thank you very much @AndrewWilliams3142!
Was reading the discussion here and thought I'd draft a PR to implement some of the changes people suggested. It seems like the main problem is that currently, in `computation.py`, `if not valid_values.all()` is not a lazy operation, and so can be a bottleneck for very large DataArrays. To get around this, I've lifted some neat tricks from #4559 so that `valid_values` remains a dask array.

- `pre-commit run --all-files`
- `whats-new.rst`
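The gist of the lazy approach can be sketched like this (a simplified population covariance with ddof=0, not the exact merged implementation):

```python
import xarray as xr

def lazy_cov(da_a, da_b, dim):
    """Dask-friendly covariance sketch: the nan mask is built and applied
    lazily, with no bool()/compute() anywhere on the way."""
    da_a, da_b = xr.align(da_a, da_b, join="inner", copy=False)
    valid_values = da_a.notnull() & da_b.notnull()
    # Unconditional masking: no `if not valid_values.all()` branch,
    # so nothing forces a compute for dask-backed inputs.
    da_a = da_a.where(valid_values)
    da_b = da_b.where(valid_values)
    valid_count = valid_values.sum(dim)
    demeaned_a = da_a - da_a.mean(dim)
    demeaned_b = da_b - da_b.mean(dim)
    return (demeaned_a * demeaned_b).sum(dim, skipna=True) / valid_count

a = xr.DataArray([1.0, 2.0, 3.0, 4.0], dims="t")
b = xr.DataArray([2.0, 4.0, 6.0, 8.0], dims="t")
print(float(lazy_cov(a, b, "t")))  # 2.5
```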