Dask-friendly nan check in xr.corr() and xr.cov() #5284

Merged: 10 commits, May 27, 2021

Conversation

@AndrewILWilliams (Contributor) commented May 9, 2021

Was reading the discussion here and thought I'd draft a PR to implement some of the changes people suggested. It seems like the main problem is that currently in computation.py, if not valid_values.all() is not a lazy operation, and so can be a bottleneck for very large dataarrays. To get around this, I've lifted some neat tricks from #4559 so that valid_values remains a dask array.
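A simplified sketch of the pattern being described (not the actual computation.py source; the helper name _mask_invalid is made up for illustration):

def _mask_invalid(da_a, da_b):
    # Sketch only: mask positions where either input is missing before
    # computing cov/corr.
    valid_values = da_a.notnull() & da_b.notnull()

    # On dask-backed DataArrays this check is eager: reducing to a single
    # Python bool forces the whole valid_values graph to be computed.
    if not valid_values.all():
        da_a = da_a.where(valid_values)
        da_b = da_b.where(valid_values)
    return da_a, da_b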


@max-sixty (Collaborator)

Hi @AndrewWilliams3142, thanks for another PR!

This looks good. Could we add a test like https://github.com/pydata/xarray/pull/4559/files#diff-74d2dc289aa601b2de094fb3a3b687fd65963401b51b95cc5e0afcd06cc4cb82R45? And maybe reference this or #4559 as an explanation of what's going on.

else:
    _are_there_nans = _nan_check(valid_values.data)

if not _are_there_nans:
@dcherian (Contributor) commented May 10, 2021

Thanks @AndrewWilliams3142.

Unfortunately, I don't think you can avoid this for dask arrays. Fundamentally, bool(dask_array) will call compute, so this line (if not _are_there_nans) will call compute.

I think the only solution here is to special-case this optimization for non-dask DataArrays and always call .where(valid_values) for dask arrays, but maybe someone else has an idea.
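A minimal standalone illustration of this point (a sketch, not code from the PR):

import numpy as np
import dask.array

x = dask.array.from_array(np.array([1.0, np.nan, 3.0]), chunks=1)
valid = ~dask.array.isnan(x)   # lazy elementwise mask
flag = valid.all()             # lazy 0-d reduction

# "if not flag:" implicitly calls bool(flag), which forces dask to evaluate
# the whole graph just to produce a single Python True/False.
print(bool(flag))  # False, but only after computing every chunk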

Contributor:

OK I guess you could stick the

valid_values = ...
if valid_values.any():
    ...

in a helper function and map_blocks that.

Contributor (Author):

Hmm, ok. Silly question, but does it matter that _are_there_nans is only a scalar reduction of valid_values, so valid_values itself is still left uncomputed? Would that help the bottleneck at all?

Contributor:

No, because the program needs to know valid_values.any() to decide which branch of the if-else statement to execute, so it gets computed.

@AndrewILWilliams (Contributor, Author)

Hi @dcherian, just thinking about your suggestion of using map_blocks on the actual valid_values check. I've tested this and was wondering if you could maybe point to where I'm going wrong? It does mask out some of the values in a lazy way, but not the correct ones.

import numpy as np
import pandas as pd
import xarray as xr

da_a = xr.DataArray(
    np.array([[1, 2, 3, 4], [1, 0.1, 0.2, 0.3], [2, 3.2, 0.6, 1.8]]),
    dims=("space", "time"),
    coords=[
        ("space", ["IA", "IL", "IN"]),
        ("time", pd.date_range("2000-01-01", freq="1D", periods=4)),
    ],
).chunk({"time": 1})

da_b = xr.DataArray(
    np.array([[0.2, 0.4, 0.6, 2], [15, 10, 5, 1], [1, 3.2, np.nan, 1.8]]),
    dims=("space", "time"),
    coords=[
        ("space", ["IA", "IL", "IN"]),
        ("time", pd.date_range("2000-01-01", freq="1D", periods=4)),
    ],
).chunk({"time": 1})

print(da_a)
>>> <xarray.DataArray (space: 3, time: 4)>
array([[1. , 2. , 3. , 4. ],
       [1. , 0.1, 0.2, 0.3],
       [2. , 3.2, 0.6, 1.8]])
Coordinates:
  * space    (space) <U2 'IA' 'IL' 'IN'
  * time     (time) datetime64[ns] 2000-01-01 2000-01-02 2000-01-03 2000-01-04

print(da_b)
>>> <xarray.DataArray (space: 3, time: 4)>
array([[ 0.2,  0.4,  0.6,  2. ],
       [15. , 10. ,  5. ,  1. ],
       [ 1. ,  3.2,  nan,  1.8]])
Coordinates:
  * space    (space) <U2 'IA' 'IL' 'IN'
  * time     (time) datetime64[ns] 2000-01-01 2000-01-02 2000-01-03 2000-01-04


# Define function to use in map_blocks
def _get_valid_values(da, other):
    da1, da2 = xr.align(da, other, join="inner", copy=False)
    
    # 2. Ignore the nans
    missing_vals = np.logical_or(da1.isnull(), da2.isnull())
    
    if missing_vals.any():
        da = da.where(missing_vals)
        return da
    else:
        return da


# test
outp = da_a.map_blocks(_get_valid_values, args=[da_b])


print(outp.compute())
>>> <xarray.DataArray (space: 3, time: 4)>
array([[1. , 2. , nan, 4. ],
       [1. , 0.1, nan, 0.3],
       [2. , 3.2, 0.6, 1.8]])
Coordinates:
  * time     (time) datetime64[ns] 2000-01-01 2000-01-02 2000-01-03 2000-01-04
  * space    (space) object 'IA' 'IL' 'IN'

@keewis (Collaborator) commented May 10, 2021

#4559 and #4668 fixed a similar issue using map_blocks; maybe you can use those as a reference?

@dcherian (Contributor)

Well that was confusing! if missing_vals.any(): will not be triggered if all the values in a block are valid. With .chunk({"time": 1}), some of the blocks are all valid.

I discovered this with print debugging and the single-threaded scheduler, which loops over blocks in a for loop:

# Define function to use in map_blocks
def _get_valid_values(da, other):
    da1, da2 = xr.align(da, other, join="inner", copy=False)
    
    # 2. Ignore the nans
    missing_vals = np.logical_or(da1.isnull(), da2.isnull())
    print(missing_vals)
    
    if missing_vals.any():
        da = da.where(missing_vals)
    return da

da_a.map_blocks(_get_valid_values, args=[da_b]).compute(scheduler="sync")

For better performance, we should try a dask.array.map_blocks approach with duck_array_ops.isnull.
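Roughly what that could look like (a sketch only: duck_array_ops is xarray-internal rather than public API, and _mask_block is a made-up helper name):

import numpy as np
import dask.array
from xarray.core import duck_array_ops

def _mask_block(a, b):
    # Runs independently on each pair of blocks: NaN out positions where
    # either input is missing, without ever evaluating the full arrays.
    missing = duck_array_ops.isnull(a) | duck_array_ops.isnull(b)
    return np.where(missing, np.nan, a)

# da_a and da_b are the chunked DataArrays from above; .data is the
# underlying dask array, and the chunks of the two inputs line up.
masked_a = dask.array.map_blocks(_mask_block, da_a.data, da_b.data, dtype=da_a.dtype)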

@AndrewILWilliams (Contributor, Author) commented May 11, 2021

Thanks for that, @dcherian! I didn't know you could use print debugging on chunked operations like this!

One thing, actually: if I change da = da.where(missing_vals) to da = da.where(~missing_vals), then we get the results we'd expect. Do you think this fixes the problem?

def _get_valid_values(da, other):
    da1, da2 = xr.align(da, other, join="outer", copy=False)
    
    # 2. Ignore the nans
    missing_vals = np.logical_or(da1.isnull(), da2.isnull())
    
    if missing_vals.any():
        da = da.where(~missing_vals)
        return da
    else:
        return da

print(da_a.map_blocks(_get_valid_values, args=[da_b]).compute())
<xarray.DataArray (space: 3, time: 4)>
array([[1. , 2. , 3. , 4. ],
       [1. , 0.1, 0.2, 0.3],
       [2. , 3.2, nan, 1.8]])
Coordinates:
  * time     (time) datetime64[ns] 2000-01-01 2000-01-02 2000-01-03 2000-01-04
  * space    (space) object 'IA' 'IL' 'IN'

@AndrewILWilliams (Contributor, Author)

Hey both, I've added a test to check that dask doesn't compute when calling either xr.corr() or xr.cov(), and also that the end result is still a dask array. Let me know if there's anything I've missed, though! Thanks for the help :)

@dcherian, regarding the apply_ufunc approach, I might leave that for now, but as you said, it can always be a future PR.
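For reference, a sketch of what such a test could look like. The raise_if_dask_computes helper is the one used in #4559 and lives in xarray's test suite, so treat the exact import path and the fixture-style arguments here as assumptions:

import dask.array
import xarray as xr
from xarray.tests import raise_if_dask_computes  # test-suite helper; path may differ

def test_corr_cov_lazy(da_a, da_b):
    # da_a and da_b are dask-backed DataArrays, e.g. the chunked examples above.
    with raise_if_dask_computes():
        corr = xr.corr(da_a, da_b, dim="time")
        cov = xr.cov(da_a, da_b, dim="time")

    # The results should themselves still be lazy dask arrays.
    assert isinstance(corr.data, dask.array.Array)
    assert isinstance(cov.data, dask.array.Array)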

@AndrewILWilliams marked this pull request as ready for review on May 26, 2021, 13:40.
@dcherian (Contributor) left a comment:

Thanks @AndrewWilliams3142!

@dcherian added the "plan to merge (Final call for comments)" label on May 26, 2021.
@max-sixty (Collaborator)

Please feel free to add a whatsnew entry, @AndrewWilliams3142.

@max-sixty (Collaborator)

Thank you very much, @AndrewWilliams3142!

@max-sixty merged commit 3b81a86 into pydata:master on May 27, 2021.
Labels: plan to merge (Final call for comments)

Successfully merging this pull request may close this issue: Improve performance of xarray.corr() on big datasets
4 participants