Skip to content

Commit

Permalink
Merge pull request JiaweiZhuang#102 from brews/respect_dtype
Browse files Browse the repository at this point in the history
Regridded output now same dtype as input
  • Loading branch information
raphaeldussin authored Aug 12, 2021
2 parents 6cad2ee + 0b5a29d commit a85025c
Show file tree
Hide file tree
Showing 4 changed files with 36 additions and 5 deletions.
8 changes: 5 additions & 3 deletions xesmf/frontend.py
Original file line number Diff line number Diff line change
Expand Up @@ -504,7 +504,7 @@ def regrid_dask(self, indata, skipna=False, na_thres=1.0):
outdata = da.map_blocks(
self._regrid_array,
indata,
dtype=float,
dtype=indata.dtype,
chunks=output_chunk_shape,
skipna=skipna,
na_thres=na_thres,
Expand All @@ -526,7 +526,7 @@ def regrid_dataarray(self, dr_in, keep_attrs=False, skipna=False, na_thres=1.0):
input_core_dims=[input_horiz_dims],
output_core_dims=[temp_horiz_dims],
dask='parallelized',
output_dtypes=[float],
output_dtypes=[dr_in.dtype],
dask_gufunc_kwargs={
'output_sizes': {
temp_horiz_dims[0]: self.shape_out[0],
Expand Down Expand Up @@ -554,14 +554,16 @@ def regrid_dataset(self, ds_in, keep_attrs=False, skipna=False, na_thres=1.0):
]
ds_in = ds_in.drop_vars(non_regriddable)

ds_dtypes = [d.dtype for d in ds_in.data_vars.values()]

ds_out = xr.apply_ufunc(
self._regrid_array,
ds_in,
kwargs=kwargs,
input_core_dims=[input_horiz_dims],
output_core_dims=[temp_horiz_dims],
dask='parallelized',
output_dtypes=[float],
output_dtypes=ds_dtypes,
dask_gufunc_kwargs={
'output_sizes': {
temp_horiz_dims[0]: self.shape_out[0],
Expand Down
2 changes: 2 additions & 0 deletions xesmf/smm.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,8 @@ def apply_weights(weights, indata, shape_in, shape_out):
indata_flat = indata.reshape(-1, shape_in[0] * shape_in[1])
outdata_flat = weights.dot(indata_flat.T).T

outdata_flat = outdata_flat.astype(indata_flat.dtype)

# unflattened output array
outdata = outdata_flat.reshape([*extra_shape, shape_out[0], shape_out[1]])
return outdata
Expand Down
29 changes: 28 additions & 1 deletion xesmf/tests/test_frontend.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
ds_in.coords['time'] = np.arange(7) + 1
ds_in.coords['lev'] = np.arange(11) + 1
ds_in['data4D'] = ds_in['time'] * ds_in['lev'] * ds_in['data']
ds_in['data4D_f4'] = ds_in['data4D'].astype('f4')
ds_out['data4D_ref'] = ds_in['time'] * ds_in['lev'] * ds_out['data_ref']

# use non-divisible chunk size to catch edge cases
Expand All @@ -50,7 +51,7 @@
'lat_b': (('lat_b',), [0, 1, 2]),
'lon_b': (('lon_b',), [0, 1, 2, 3]),
},
data_vars={'abc': (('lon', 'lat'), [[1, 2], [3, 4], [2, 4]])},
data_vars={'abc': (('lon', 'lat'), [[1.0, 2.0], [3.0, 4.0], [2.0, 4.0]])},
)
polys = [
Polygon([[0.5, 0.5], [0.5, 1.5], [1.5, 0.5]]), # Simple triangle polygon
Expand Down Expand Up @@ -276,6 +277,29 @@ def ds_2d_to_1d(ds):
return ds_1d


@pytest.mark.parametrize('dtype', ['float32', 'float64'])
@pytest.mark.parametrize(
'data_in',
[
pytest.param(np.array(ds_in['data']), id='np.ndarray'),
pytest.param(xr.DataArray(ds_in['data']), id='xr.DataArray input'),
pytest.param(xr.Dataset(ds_in[['data']]), id='xr.Dataset input'),
pytest.param(ds_in['data'].chunk(), id='da.Array input'),
],
)
def test_regridded_respects_input_dtype(dtype, data_in):
"""Tests regridded output has same dtype as input"""
data_in = data_in.astype(dtype)
regridder = xe.Regridder(ds_in, ds_out, 'bilinear') # Make this a fixture?
out = regridder(data_in)

if 'data' in data_in:
# When data_in is xr.Dataset, a mapping...
assert out['data'].dtype == data_in['data'].dtype
else:
assert out.dtype == data_in.dtype


def test_regrid_with_1d_grid():
ds_in_1d = ds_2d_to_1d(ds_in)
ds_out_1d = ds_2d_to_1d(ds_out)
Expand Down Expand Up @@ -528,6 +552,9 @@ def test_regrid_dataset():
decimal=10,
)

assert ds_result['data4D'].dtype == np.dtype('f8')
assert ds_result['data4D_f4'].dtype == np.dtype('f4')

# check metadata
xr.testing.assert_identical(ds_result['time'], ds_in['time'])
xr.testing.assert_identical(ds_result['lev'], ds_in['lev'])
Expand Down
2 changes: 1 addition & 1 deletion xesmf/tests/test_smm.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@


def test_add_nans_to_weights():
""" testing adding Nans to empty rows in sparse matrix """
"""testing adding Nans to empty rows in sparse matrix"""
# create input sparse matrix with one empty row (j=2)
row = np.array([0, 3, 1, 0])
col = np.array([0, 3, 1, 2])
Expand Down

0 comments on commit a85025c

Please sign in to comment.