Skip to content

Commit

Permalink
Combine preprocess_xarray and wrap_output_like from Unidata#1304
Browse files Browse the repository at this point in the history
  • Loading branch information
jthielen committed Sep 2, 2020
1 parent 361f271 commit 0766511
Showing 1 changed file with 105 additions and 10 deletions.
115 changes: 105 additions & 10 deletions src/metpy/xarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
See Also: :doc:`xarray with MetPy Tutorial </tutorials/xarray_tutorial>`.
"""
import functools
from inspect import signature
import logging
import re
import warnings
Expand Down Expand Up @@ -1054,18 +1055,112 @@ def _build_y_x(da, tolerance):
'correpsond to your CRS coordinate.')


def preprocess_xarray(func):
"""Decorate a function to convert all DataArray arguments to pint.Quantities.
def preprocess_and_wrap(broadcast=None, wrap_like=None, match_unit=False, to_magnitude=False):
"""Return decorator to wrap array calculations for type flexibility.
This uses the metpy xarray accessors to do the actual conversion.
Assuming you have a calculation that works internally with `pint.Quantity` or
`numpy.ndarray`, this will wrap the function to be able to handle `xarray.DataArray` and
`pint.Quantity` as well (assuming appropriate match to one of the input arguments).
Parameters
----------
broadcast : iterable of str or None
Iterable of string labels for arguments to broadcast against each other using xarray,
assuming they are supplied as `xarray.DataArray`. No automatic broadcasting will occur
with default of None.
wrap_like : str or array-like or tuple of str or tuple of array-like or None
Wrap the calculation output following a particular input argument (if str) or data
data object (if array-like). If tuple, will assume output is in the form of a tuple,
and wrap iteratively according to the str or array-like contained within. If None,
will not wrap output.
match_unit : bool
If true, force the unit of the final output to be that of wrapping object (as
determined by wrap_like), no matter the original calculation output. Defaults to
False.
to_magnitude : bool
If true, downcast xarray and Pint arguments to their magnitude. If false, downcast
xarray arguments to Quantity, and do not change other array-like arguments.
"""
@functools.wraps(func)
def wrapper(*args, **kwargs):
args = tuple(a.metpy.unit_array if isinstance(a, xr.DataArray) else a for a in args)
kwargs = {name: (v.metpy.unit_array if isinstance(v, xr.DataArray) else v)
for name, v in kwargs.items()}
return func(*args, **kwargs)
return wrapper
def decorator(func):
@functools.wraps(func)
def wrapper(*args, **kwargs):
bound_args = signature(func).bind(*args, **kwargs)

# Obtain proper match if referencing an input
match = list(wrap_like)
if isinstance(wrap_like, str):
match = bound_args.arguments[wrap_like]
elif isinstance(wrap_like, tuple):
for i in len(wrap_like):
if isinstance(wrap_like[i], str):
match[i] = bound_args.arguments[wrap_like[i]]

# Auto-broadcast select xarray arguments, and update bound_args
if broadcast is not None:
arg_names_to_broadcast = (
arg_name for arg_name in broadcast
if arg_name in bound_args.arguments
and isinstance(bound_args.arguments[arg_name], xr.DataArray)
)
broadcasted_args = xr.broadcast(
*(bound_args.arguments[arg_name] for arg_name in arg_names_to_broadcast)
)
for i, arg_name in enumerate(arg_names_to_broadcast):
bound_args.arguments[arg_name] = broadcasted_args[i]

# Cast all DataArrays to Pint Quantities
for arg_name in bound_args.arguments:
if isinstance(bound_args.arguments[arg_name], xr.DataArray):
bound_args.arguments[arg_name] = (
bound_args.arguments[arg_name].metpy.unit_array
)

# Optionally cast all Quantities to their magnitudes
if to_magnitude:
for arg_name in bound_args.arguments:
if isinstance(bound_args.arguments[arg_name], units.Quantity):
bound_args.arguments[arg_name] = bound_args.arguments[arg_name].m

# Evaluate inner calculation
result = func(*bound_args.args, **bound_args.kwargs)

# Wrap output based on match and match_unit
if match is None:
return result
else:
if match_unit:
wrapping = _wrap_output_like_matching_units
else:
wrapping = _wrap_output_like_not_matching_units

if isinstance(match, list):
return tuple(wrapping(*pair) for pair in zip(result, match))
else:
return wrapping(result, match)
return wrapper
return decorator


def _wrap_output_like_matching_units(result, match):
"""Convert result to be like match with matching units for output wrapper."""
output_xarray = isinstance(match, xr.DataArray)
match_units = str(match.metpy.units if output_xarray else getattr(match, 'units', ''))

if isinstance(result, xr.DataArray):
result = result.metpy.convert_units(match_units)
return result if output_xarray else result.metpy.unit_array
else:
result = result.to(match_units) if isinstance(result, units.Quantity) else result
return match.copy(data=result) if output_xarray else result


def _wrap_output_like_not_matching_units(result, match):
"""Convert result to be like match without matching units for output wrapper."""
output_xarray = isinstance(match, xr.DataArray)
if isinstance(result, xr.DataArray):
return result if output_xarray else result.metpy.unit_array
else:
return match.copy(data=result) if output_xarray else result


def check_matching_coordinates(func):
Expand Down

0 comments on commit 0766511

Please sign in to comment.