diff --git a/src/metpy/xarray.py b/src/metpy/xarray.py index ee2c0218f76..ba8889cbe9e 100644 --- a/src/metpy/xarray.py +++ b/src/metpy/xarray.py @@ -16,6 +16,7 @@ See Also: :doc:`xarray with MetPy Tutorial `. """ import functools +from inspect import signature import logging import re import warnings @@ -1051,18 +1052,112 @@ def _build_y_x(da, tolerance): 'correpsond to your CRS coordinate.') -def preprocess_xarray(func): - """Decorate a function to convert all DataArray arguments to pint.Quantities. +def preprocess_and_wrap(broadcast=None, wrap_like=None, match_unit=False, to_magnitude=False): + """Return decorator to wrap array calculations for type flexibility. - This uses the metpy xarray accessors to do the actual conversion. + Assuming you have a calculation that works internally with `pint.Quantity` or + `numpy.ndarray`, this will wrap the function to be able to handle `xarray.DataArray` and + `pint.Quantity` as well (assuming appropriate match to one of the input arguments). + + Parameters + ---------- + broadcast : iterable of str or None + Iterable of string labels for arguments to broadcast against each other using xarray, + assuming they are supplied as `xarray.DataArray`. No automatic broadcasting will occur + with default of None. + wrap_like : str or array-like or tuple of str or tuple of array-like or None + Wrap the calculation output following a particular input argument (if str) or data + data object (if array-like). If tuple, will assume output is in the form of a tuple, + and wrap iteratively according to the str or array-like contained within. If None, + will not wrap output. + match_unit : bool + If true, force the unit of the final output to be that of wrapping object (as + determined by wrap_like), no matter the original calculation output. Defaults to + False. + to_magnitude : bool + If true, downcast xarray and Pint arguments to their magnitude. If false, downcast + xarray arguments to Quantity, and do not change other array-like arguments. """ - @functools.wraps(func) - def wrapper(*args, **kwargs): - args = tuple(a.metpy.unit_array if isinstance(a, xr.DataArray) else a for a in args) - kwargs = {name: (v.metpy.unit_array if isinstance(v, xr.DataArray) else v) - for name, v in kwargs.items()} - return func(*args, **kwargs) - return wrapper + def decorator(func): + @functools.wraps(func) + def wrapper(*args, **kwargs): + bound_args = signature(func).bind(*args, **kwargs) + + # Obtain proper match if referencing an input + match = list(wrap_like) + if isinstance(wrap_like, str): + match = bound_args.arguments[wrap_like] + elif isinstance(wrap_like, tuple): + for i in len(wrap_like): + if isinstance(wrap_like[i], str): + match[i] = bound_args.arguments[wrap_like[i]] + + # Auto-broadcast select xarray arguments, and update bound_args + if broadcast is not None: + arg_names_to_broadcast = ( + arg_name for arg_name in broadcast + if arg_name in bound_args.arguments + and isinstance(bound_args.arguments[arg_name], xr.DataArray) + ) + broadcasted_args = xr.broadcast( + *(bound_args.arguments[arg_name] for arg_name in arg_names_to_broadcast) + ) + for i, arg_name in enumerate(arg_names_to_broadcast): + bound_args.arguments[arg_name] = broadcasted_args[i] + + # Cast all DataArrays to Pint Quantities + for arg_name in bound_args.arguments: + if isinstance(bound_args.arguments[arg_name], xr.DataArray): + bound_args.arguments[arg_name] = ( + bound_args.arguments[arg_name].metpy.unit_array + ) + + # Optionally cast all Quantities to their magnitudes + if to_magnitude: + for arg_name in bound_args.arguments: + if isinstance(bound_args.arguments[arg_name], units.Quantity): + bound_args.arguments[arg_name] = bound_args.arguments[arg_name].m + + # Evaluate inner calculation + result = func(*bound_args.args, **bound_args.kwargs) + + # Wrap output based on match and match_unit + if match is None: + return result + else: + if match_unit: + wrapping = _wrap_output_like_matching_units + else: + wrapping = _wrap_output_like_not_matching_units + + if isinstance(match, list): + return tuple(wrapping(*pair) for pair in zip(result, match)) + else: + return wrapping(result, match) + return wrapper + return decorator + + +def _wrap_output_like_matching_units(result, match): + """Convert result to be like match with matching units for output wrapper.""" + output_xarray = isinstance(match, xr.DataArray) + match_units = str(match.metpy.units if output_xarray else getattr(match, 'units', '')) + + if isinstance(result, xr.DataArray): + result = result.metpy.convert_units(match_units) + return result if output_xarray else result.metpy.unit_array + else: + result = result.to(match_units) if isinstance(result, units.Quantity) else result + return match.copy(data=result) if output_xarray else result + + +def _wrap_output_like_not_matching_units(result, match): + """Convert result to be like match without matching units for output wrapper.""" + output_xarray = isinstance(match, xr.DataArray) + if isinstance(result, xr.DataArray): + return result if output_xarray else result.metpy.unit_array + else: + return match.copy(data=result) if output_xarray else result def check_matching_coordinates(func):