Skip to content

Commit

Permalink
Merge branch 'master' of github.com:hgrecco/pint
Browse files Browse the repository at this point in the history
  • Loading branch information
hgrecco committed Dec 3, 2023
2 parents 236b00c + cf86f71 commit 57764be
Show file tree
Hide file tree
Showing 8 changed files with 75 additions and 36 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ jobs:
- python-version: 3.9
numpy: "numpy"
uncertainties: "uncertainties"
extras: "sparse xarray netCDF4 dask[complete]==2023.4.0 graphviz babel==2.8"
extras: "sparse xarray netCDF4 dask[complete]==2023.4.0 graphviz babel==2.8 mip>=1.13"
runs-on: ubuntu-latest

env:
Expand Down
1 change: 1 addition & 0 deletions docs/changes.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
.. include:: ../CHANGES
1 change: 1 addition & 0 deletions docs/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ Pint: makes units easy
Advanced topics <advanced/index>
ecosystem
API Reference <api/index>
changes

.. toctree::
:maxdepth: 1
Expand Down
12 changes: 6 additions & 6 deletions pint/facets/numpy/numpy_func.py
Original file line number Diff line number Diff line change
Expand Up @@ -741,23 +741,23 @@ def _base_unit_if_needed(a):


@implements("trapz", "function")
def _trapz(a, x=None, dx=1.0, **kwargs):
a = _base_unit_if_needed(a)
units = a.units
def _trapz(y, x=None, dx=1.0, **kwargs):
y = _base_unit_if_needed(y)
units = y.units
if x is not None:
if hasattr(x, "units"):
x = _base_unit_if_needed(x)
units *= x.units
x = x._magnitude
ret = np.trapz(a._magnitude, x, **kwargs)
ret = np.trapz(y._magnitude, x, **kwargs)
else:
if hasattr(dx, "units"):
dx = _base_unit_if_needed(dx)
units *= dx.units
dx = dx._magnitude
ret = np.trapz(a._magnitude, dx=dx, **kwargs)
ret = np.trapz(y._magnitude, dx=dx, **kwargs)

return a.units._REGISTRY.Quantity(ret, units)
return y.units._REGISTRY.Quantity(ret, units)


def implement_mul_func(func):
Expand Down
2 changes: 1 addition & 1 deletion pint/facets/numpy/quantity.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,7 +266,7 @@ def __setitem__(self, key, value):
isinstance(self._magnitude, np.ma.MaskedArray)
and np.ma.is_masked(value)
and getattr(value, "size", 0) == 1
) or math.isnan(value):
) or (getattr(value, "ndim", 0) == 0 and math.isnan(value)):
self._magnitude[key] = value
return
except TypeError:
Expand Down
54 changes: 36 additions & 18 deletions pint/registry_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from __future__ import annotations

import functools
from inspect import signature
from inspect import signature, Parameter
from itertools import zip_longest
from typing import TYPE_CHECKING, Callable, TypeVar, Any, Union, Optional
from collections.abc import Iterable
Expand Down Expand Up @@ -119,22 +119,27 @@ def _parse_wrap_args(args, registry=None):
"Not all variable referenced in %s are defined using !" % args[ndx]
)

def _converter(ureg, values, strict):
new_values = list(value for value in values)
def _converter(ureg, sig, values, kw, strict):
len_initial_values = len(values)

# pack kwargs
for i, param_name in enumerate(sig.parameters):
if i >= len_initial_values:
values.append(kw[param_name])

values_by_name = {}

# first pass: Grab named values
for ndx in defs_args_ndx:
value = values[ndx]
values_by_name[args_as_uc[ndx][0]] = value
new_values[ndx] = getattr(value, "_magnitude", value)
values[ndx] = getattr(value, "_magnitude", value)

# second pass: calculate derived values based on named values
for ndx in dependent_args_ndx:
value = values[ndx]
assert _replace_units(args_as_uc[ndx][0], values_by_name) is not None
new_values[ndx] = ureg._convert(
values[ndx] = ureg._convert(
getattr(value, "_magnitude", value),
getattr(value, "_units", UnitsContainer({})),
_replace_units(args_as_uc[ndx][0], values_by_name),
Expand All @@ -143,27 +148,32 @@ def _converter(ureg, values, strict):
# third pass: convert other arguments
for ndx in unit_args_ndx:
if isinstance(values[ndx], ureg.Quantity):
new_values[ndx] = ureg._convert(
values[ndx] = ureg._convert(
values[ndx]._magnitude, values[ndx]._units, args_as_uc[ndx][0]
)
else:
if strict:
if isinstance(values[ndx], str):
# if the value is a string, we try to parse it
tmp_value = ureg.parse_expression(values[ndx])
new_values[ndx] = ureg._convert(
values[ndx] = ureg._convert(
tmp_value._magnitude, tmp_value._units, args_as_uc[ndx][0]
)
else:
raise ValueError(
"A wrapped function using strict=True requires "
"quantity or a string for all arguments with not None units. "
"(error found for {}, {})".format(
args_as_uc[ndx][0], new_values[ndx]
args_as_uc[ndx][0], values[ndx]
)
)

return new_values, values_by_name
# unpack kwargs
for i, param_name in enumerate(sig.parameters):
if i >= len_initial_values:
kw[param_name] = values[i]

return values[:len_initial_values], kw, values_by_name

return _converter

Expand All @@ -175,12 +185,14 @@ def _apply_defaults(sig, args, kwargs):
values so that every argument is defined.
"""

bound_arguments = sig.bind(*args, **kwargs)
for param in sig.parameters.values():
if param.name not in bound_arguments.arguments:
bound_arguments.arguments[param.name] = param.default
args = [bound_arguments.arguments[key] for key in sig.parameters.keys()]
return args, {}
for i, param in enumerate(sig.parameters.values()):
if (
i >= len(args)
and param.default != Parameter.empty
and param.name not in kwargs
):
kwargs[param.name] = param.default
return list(args), kwargs


def wraps(
Expand Down Expand Up @@ -274,9 +286,11 @@ def wrapper(*values, **kw) -> Quantity:

# In principle, the values are used as is
# When then extract the magnitudes when needed.
new_values, values_by_name = converter(ureg, values, strict)
new_values, new_kw, values_by_name = converter(
ureg, sig, values, kw, strict
)

result = func(*new_values, **kw)
result = func(*new_values, **new_kw)

if is_ret_container:
out_units = (
Expand Down Expand Up @@ -352,7 +366,11 @@ def decorator(func):

@functools.wraps(func, assigned=assigned, updated=updated)
def wrapper(*args, **kwargs):
list_args, empty = _apply_defaults(sig, args, kwargs)
list_args, kw = _apply_defaults(sig, args, kwargs)

for i, param_name in enumerate(sig.parameters):
if i >= len(args):
list_args.append(kw[param_name])

for dim, value in zip(dimensions, list_args):
if dim is None:
Expand Down
17 changes: 9 additions & 8 deletions pint/testsuite/test_quantity.py
Original file line number Diff line number Diff line change
Expand Up @@ -374,8 +374,8 @@ def test_convert(self):

@helpers.requires_mip
def test_to_preferred(self):
ureg = UnitRegistry()
Q_ = ureg.Quantity
ureg = self.ureg
Q_ = self.Q_

ureg.define("pound_force_per_square_foot = 47.8803 pascals = psf")
ureg.define("pound_mass = 0.45359237 kg = lbm")
Expand Down Expand Up @@ -412,9 +412,9 @@ def test_to_preferred(self):

@helpers.requires_mip
def test_to_preferred_registry(self):
ureg = UnitRegistry()
Q_ = ureg.Quantity
ureg.preferred_units = [
ureg = self.ureg
Q_ = self.Q_
ureg.default_preferred_units = [
ureg.m, # distance L
ureg.kg, # mass M
ureg.s, # duration T
Expand All @@ -427,9 +427,10 @@ def test_to_preferred_registry(self):

@helpers.requires_mip
def test_autoconvert_to_preferred(self):
ureg = UnitRegistry()
Q_ = ureg.Quantity
ureg.preferred_units = [
ureg = self.ureg
Q_ = self.Q_
ureg.autoconvert_to_preferred = True
ureg.default_preferred_units = [
ureg.m, # distance L
ureg.kg, # mass M
ureg.s, # duration T
Expand Down
22 changes: 20 additions & 2 deletions pint/testsuite/test_unit.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,11 +295,11 @@ def test_define(self):
assert len(dir(ureg)) > 0

def test_load(self):
import pkg_resources
from importlib.resources import files

from .. import compat

data = pkg_resources.resource_filename(compat.__name__, "default_en.txt")
data = files(compat.__package__).joinpath("default_en.txt")
ureg1 = UnitRegistry()
ureg2 = UnitRegistry(data)
assert dir(ureg1) == dir(ureg2)
Expand Down Expand Up @@ -595,6 +595,23 @@ def hfunc(x, y):
h3 = ureg.wraps((None,), (None, None))(hfunc)
assert h3(3, 1) == (3, 1)

def kfunc(a, /, b, c=5, *, d=6):
return a, b, c, d

k1 = ureg.wraps((None,), (None, None, None, None))(kfunc)
assert k1(1, 2, 3, d=4) == (1, 2, 3, 4)
assert k1(1, 2, c=3, d=4) == (1, 2, 3, 4)
assert k1(1, b=2, c=3, d=4) == (1, 2, 3, 4)
assert k1(1, d=4, b=2, c=3) == (1, 2, 3, 4)
assert k1(1, 2, c=3) == (1, 2, 3, 6)
assert k1(1, 2, d=4) == (1, 2, 5, 4)
assert k1(1, 2) == (1, 2, 5, 6)

k2 = ureg.wraps((None,), ("meter", "centimeter", "meter", "centimeter"))(kfunc)
assert k2(
1 * ureg.meter, 2 * ureg.centimeter, 3 * ureg.meter, d=4 * ureg.centimeter
) == (1, 2, 3, 4)

def test_wrap_referencing(self):
ureg = self.ureg

Expand Down Expand Up @@ -643,6 +660,7 @@ def func(x):
assert f0(3.0 * ureg.centimeter) == 0.03 * ureg.meter
with pytest.raises(DimensionalityError):
f0(3.0 * ureg.kilogram)
assert f0(x=3.0 * ureg.centimeter) == 0.03 * ureg.meter

f0b = ureg.check(ureg.meter)(func)
with pytest.raises(DimensionalityError):
Expand Down

0 comments on commit 57764be

Please sign in to comment.