Skip to content

Commit

Permalink
Make UADA work with __array_ufunc__
Browse files Browse the repository at this point in the history
Some of xarray and numpy work with __array_ufunc__ nowadays.  Make sure
this is supported by UADA to keep up to date.
  • Loading branch information
gerritholl committed Mar 26, 2018
1 parent abd2367 commit 11d65f9
Showing 1 changed file with 30 additions and 0 deletions.
30 changes: 30 additions & 0 deletions typhon/physics/units/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@ class UnitsAwareDataArray(xarray.DataArray):
"""Like xarray.DataArray, but transfers units
"""

# need to keep both __array_wrap__ and __array_ufunc__. Although the
# former supersedes the latter, xarrays methods explicitly call the
# former sometimes.
def __array_wrap__(self, obj, context=None):
new_var = super().__array_wrap__(obj, context)
if self.attrs.get("units"):
Expand Down Expand Up @@ -56,6 +59,33 @@ def _apply_rbinary_op_to_units(self, func, other, x):
ureg.Quantity(1, self.attrs["units"]),).u)
return x

def __array_ufunc__(self, ufunc, method, *args, **kwargs):
new_var = super().__array_ufunc__(ufunc, method, *args, **kwargs)
# make sure we're still UADA
new_var = self.__class__(new_var)
if self.attrs.get("units"):
if method == "__call__":
q = ufunc(ureg.Quantity(1, self.attrs.get("units")))
try:
u = q.u
except AttributeError:
if (ureg(self.attrs["units"]).dimensionless or
new_var.dtype.kind == "b"):
# expected, see https://github.com/hgrecco/pint/issues/482
u = ureg.dimensionless
else:
raise
# for exp and log, values are not set correctly. I'm
# not sure why. Perhaps related to
# https://github.com/hgrecco/pint/issues/493
new_var.values = ufunc(ureg.Quantity(self.values, self.units))
new_var.attrs["units"] = str(u)
else: # unary operators? always retain units?
raise NotImplementedError("Not implented")
new_var.attrs["units"] = str(self.attrs.get("units"))

return new_var

# pow is different because resulting unit depends on argument, not on
# unit of argument
def __pow__(self, other):
Expand Down

0 comments on commit 11d65f9

Please sign in to comment.