Skip to content

Commit

Permalink
Enable lazy computation of wind vector rotation (#4972)
Browse files Browse the repository at this point in the history
* Lazy rotate_winds and associated tests

* Data ownership bug fix and extra tests

* Fix input data test

* Code formatting

* Add whats new entry

* Add missing author name entry

---------

Co-authored-by: Bill Little <[email protected]>
  • Loading branch information
tinyendian and bjlittle authored Mar 31, 2023
1 parent b8f5eea commit ba3ac6d
Show file tree
Hide file tree
Showing 3 changed files with 96 additions and 14 deletions.
8 changes: 7 additions & 1 deletion docs/src/whatsnew/latest.rst
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,10 @@ This document explains the changes made to Iris for this release
#. `@lbdreyer`_ and `@trexfeathers`_ (reviewer) added :func:`iris.plot.hist`
and :func:`iris.quickplot.hist`. (:pull:`5189`)

#. `@tinyendian`_ edited :func:`~iris.analysis.cartography.rotate_winds` to
enable lazy computation of rotated wind vector components (:issue:`4934`,
:pull:`4972`)


🐛 Bugs Fixed
=============
Expand Down Expand Up @@ -156,9 +160,11 @@ This document explains the changes made to Iris for this release
.. _@ed-hawkins: https://github.com/ed-hawkins
.. _@scottrobinson02: https://github.com/scottrobinson02
.. _@agriyakhetarpal: https://github.com/agriyakhetarpal
.. _@tinyendian: https://github.com/tinyendian


.. comment
Whatsnew resources in alphabetical order:
.. _#ShowYourStripes: https://showyourstripes.info/s/globe/
.. _README.md: https://github.com/SciTools/iris#-----
.. _README.md: https://github.com/SciTools/iris#-----
45 changes: 32 additions & 13 deletions lib/iris/analysis/cartography.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import cartopy.crs as ccrs
import cartopy.img_transform
import cf_units
import dask.array as da
import numpy as np
import numpy.ma as ma

Expand Down Expand Up @@ -1206,9 +1207,15 @@ def rotate_winds(u_cube, v_cube, target_cs):
x = x.transpose()
y = y.transpose()

# Create resulting cubes.
ut_cube = u_cube.copy()
vt_cube = v_cube.copy()
# Create resulting cubes - produce lazy output data if at least
# one input cube has lazy data
lazy_output = u_cube.has_lazy_data() or v_cube.has_lazy_data()
if lazy_output:
ut_cube = u_cube.copy(data=da.empty_like(u_cube.lazy_data()))
vt_cube = v_cube.copy(data=da.empty_like(v_cube.lazy_data()))
else:
ut_cube = u_cube.copy()
vt_cube = v_cube.copy()
ut_cube.rename("transformed_{}".format(u_cube.name()))
vt_cube.rename("transformed_{}".format(v_cube.name()))

Expand Down Expand Up @@ -1236,8 +1243,16 @@ def rotate_winds(u_cube, v_cube, target_cs):
apply_mask = mask.any()
if apply_mask:
# Make masked arrays to accept masking.
ut_cube.data = ma.asanyarray(ut_cube.data)
vt_cube.data = ma.asanyarray(vt_cube.data)
if lazy_output:
ut_cube = ut_cube.copy(
data=da.ma.masked_array(ut_cube.core_data())
)
vt_cube = vt_cube.copy(
data=da.ma.masked_array(vt_cube.core_data())
)
else:
ut_cube.data = ma.asanyarray(ut_cube.data)
vt_cube.data = ma.asanyarray(vt_cube.data)

# Project vectors with u, v components one horiz slice at a time and
# insert into the resulting cubes.
Expand All @@ -1250,16 +1265,20 @@ def rotate_winds(u_cube, v_cube, target_cs):
for dim in dims:
index[dim] = slice(None, None)
index = tuple(index)
u = u_cube.data[index]
v = v_cube.data[index]
u = u_cube.core_data()[index]
v = v_cube.core_data()[index]
ut, vt = _transform_distance_vectors(u, v, ds, dx2, dy2)
if apply_mask:
ut = ma.asanyarray(ut)
ut[mask] = ma.masked
vt = ma.asanyarray(vt)
vt[mask] = ma.masked
ut_cube.data[index] = ut
vt_cube.data[index] = vt
if lazy_output:
ut = da.ma.masked_array(ut, mask=mask)
vt = da.ma.masked_array(vt, mask=mask)
else:
ut = ma.asanyarray(ut)
ut[mask] = ma.masked
vt = ma.asanyarray(vt)
vt[mask] = ma.masked
ut_cube.core_data()[index] = ut
vt_cube.core_data()[index] = vt

# Calculate new coords of locations in target coordinate system.
xyz_tran = target_crs.transform_points(src_crs, x, y)
Expand Down
57 changes: 57 additions & 0 deletions lib/iris/tests/unit/analysis/cartography/test_rotate_winds.py
Original file line number Diff line number Diff line change
Expand Up @@ -511,5 +511,62 @@ def test_non_earth_semimajor_axis(self):
rotate_winds(u, v, other_cs)


class TestLazyRotateWinds(tests.IrisTest):
def _compare_lazy_rotate_winds(self, masked):
# Compute wind rotation with lazy data and compare results

# Choose target coord system that will (not) lead to masked results
if masked:
coord_sys = iris.coord_systems.OSGB()
else:
coord_sys = iris.coord_systems.GeogCS(6371229)

u, v = uv_cubes()

# Create deep copy of the cubes with rechunked lazy data to check if
# input data is modified, and if Dask metadata is preserved
u_lazy = u.copy(data=u.copy().lazy_data().rechunk([2, 1]))
v_lazy = v.copy(data=v.copy().lazy_data().rechunk([1, 2]))

ut_ref, vt_ref = rotate_winds(u, v, coord_sys)
self.assertFalse(ut_ref.has_lazy_data())
self.assertFalse(vt_ref.has_lazy_data())
# Ensure that choice of target coordinates leads to (no) masking
self.assertTrue(ma.isMaskedArray(ut_ref.data) == masked)

# Results are lazy if at least one component is lazy
ut, vt = rotate_winds(u_lazy, v, coord_sys)
self.assertTrue(ut.has_lazy_data())
self.assertTrue(vt.has_lazy_data())
self.assertTrue(ut.core_data().chunksize == (2, 1))
self.assertArrayAllClose(ut.data, ut_ref.data, rtol=1e-5)
self.assertArrayAllClose(vt.data, vt_ref.data, rtol=1e-5)

ut, vt = rotate_winds(u, v_lazy, coord_sys)
self.assertTrue(ut.has_lazy_data())
self.assertTrue(vt.has_lazy_data())
self.assertTrue(vt.core_data().chunksize == (1, 2))
self.assertArrayAllClose(ut.data, ut_ref.data, rtol=1e-5)
self.assertArrayAllClose(vt.data, vt_ref.data, rtol=1e-5)

ut, vt = rotate_winds(u_lazy, v_lazy, coord_sys)
self.assertTrue(ut.has_lazy_data())
self.assertTrue(vt.has_lazy_data())
self.assertTrue(ut.core_data().chunksize == (2, 1))
self.assertTrue(vt.core_data().chunksize == (1, 2))
self.assertArrayAllClose(ut.data, ut_ref.data, rtol=1e-5)
self.assertArrayAllClose(vt.data, vt_ref.data, rtol=1e-5)

# Ensure that input data has not been modified
self.assertArrayAllClose(u.data, u_lazy.data, rtol=1e-5)
self.assertArrayAllClose(v.data, v_lazy.data, rtol=1e-5)

def test_lazy_rotate_winds_masked(self):
self._compare_lazy_rotate_winds(True)

def test_lazy_rotate_winds_notmasked(self):
self._compare_lazy_rotate_winds(False)


if __name__ == "__main__":
tests.main()

0 comments on commit ba3ac6d

Please sign in to comment.