diff --git a/CHANGES b/CHANGES index 8a2ba2d60..0395b7fc0 100644 --- a/CHANGES +++ b/CHANGES @@ -4,6 +4,13 @@ Pint Changelog 0.10 (unreleased) ----------------- +- Improvements to wraps and check: + - fail upon decoration (not execution) by checking wrapped function signature against + wraps/check arguments. + (might BREAK test code) + - wraps only accepts strings and Units (not quantities) to avoid confusion with magnitude. + (might BREAK code not conforming to documentation) + - when strict=True, strings that can be parsed to quantities are accepted as arguments. - Add revolutions per second (rps) - Improved compatbility for upcast types like xarray's DataArray or Dataset, to which Pint Quantities now fully defer for arithmetic and NumPy operations. A collection of diff --git a/pint/registry_helpers.py b/pint/registry_helpers.py index d51daba9e..970f2efa0 100644 --- a/pint/registry_helpers.py +++ b/pint/registry_helpers.py @@ -139,13 +139,20 @@ def _converter(ureg, values, strict): ) else: if strict: - raise ValueError( - "A wrapped function using strict=True requires " - "quantity for all arguments with not None units. " - "(error found for {}, {})".format( - args_as_uc[ndx][0], new_values[ndx] + if isinstance(values[ndx], str): + # if the value is a string, we try to parse it + tmp_value = ureg.parse_expression(values[ndx]) + new_values[ndx] = ureg._convert( + tmp_value._magnitude, tmp_value._units, args_as_uc[ndx][0] + ) + else: + raise ValueError( + "A wrapped function using strict=True requires " + "quantity or a string for all arguments with not None units. " + "(error found for {}, {})".format( + args_as_uc[ndx][0], new_values[ndx] + ) ) - ) return new_values, values_by_name @@ -179,41 +186,68 @@ def wraps(ureg, ret, args, strict=True): The value returned by the wrapped function will be converted to the units specified in `ret`. - Use None to skip argument conversion. - Set strict to False, to accept also numerical values. - Parameters ---------- - ureg : + ureg : UnitRegistry a UnitRegistry instance. - ret : - output units. - args : - iterable of input units. + ret : iterable of str or iterable of Unit + Units of each of the return values. Use `None` to skip argument conversion. + args : iterable of str or iterable of Unit + Units of each of the input arguments. Use `None` to skip argument conversion. strict : bool Indicates that only quantities are accepted. (Default value = True) Returns ------- callable - the wrapped function. + the wrapper function. + + Raises + ------ + TypeError + if the number of given arguments does not match the number of function parameters. + if the any of the provided arguments is not a unit a string or Quantity """ if not isinstance(args, (list, tuple)): args = (args,) + for arg in args: + if arg is not None and not isinstance(arg, (ureg.Unit, str)): + raise TypeError( + "wraps arguments must by of type str or Unit, not %s (%s)" + % (type(arg), arg) + ) + converter = _parse_wrap_args(args) - if isinstance(ret, (list, tuple)): - container, ret = ( - True, - ret.__class__([_to_units_container(arg, ureg) for arg in ret]), - ) + is_ret_container = isinstance(ret, (list, tuple)) + if is_ret_container: + for arg in ret: + if arg is not None and not isinstance(arg, (ureg.Unit, str)): + raise TypeError( + "wraps 'ret' argument must by of type str or Unit, not %s (%s)" + % (type(arg), arg) + ) + ret = ret.__class__([_to_units_container(arg, ureg) for arg in ret]) else: - container, ret = False, _to_units_container(ret, ureg) + if ret is not None and not isinstance(ret, (ureg.Unit, str)): + raise TypeError( + "wraps 'ret' argument must by of type str or Unit, not %s (%s)" + % (type(ret), ret) + ) + ret = _to_units_container(ret, ureg) def decorator(func): + + count_params = len(signature(func).parameters) + if len(args) != count_params: + raise TypeError( + "%s takes %i parameters, but %i units were passed" + % (func.__name__, count_params, len(args)) + ) + assigned = tuple( attr for attr in functools.WRAPPER_ASSIGNMENTS if hasattr(func, attr) ) @@ -232,7 +266,7 @@ def wrapper(*values, **kw): result = func(*new_values, **kw) - if container: + if is_ret_container: out_units = ( _replace_units(r, values_by_name) if is_ref else r for (r, is_ref) in ret @@ -258,28 +292,27 @@ def check(ureg, *args): """Decorator to for quantity type checking for function inputs. Use it to ensure that the decorated function input parameters match - the expected type of pint quantity. + the expected dimension of pint quantity. - Use None to skip argument checking. + The wrapper function raises: + - `pint.DimensionalityError` if an argument doesn't match the required dimensions. - Parameters - ---------- - ureg : + ureg : UnitRegistry a UnitRegistry instance. - args : - iterable of input units. - *args : - + *args : iterable of str or iterable of UnitContainer + Dimensions of each of the input arguments. Use `None` to skip argument conversion. Returns ------- - type + callable the wrapped function. Raises ------ - pint.DimensionalityError - if the parameters don't match dimensions + TypeError + if the number of given dimensions does not match the number of function parameters. + ValueError + if the any of the provided dimensions cannot be parsed as a dimension. """ dimensions = [ @@ -287,6 +320,14 @@ def check(ureg, *args): ] def decorator(func): + + count_params = len(signature(func).parameters) + if len(dimensions) != count_params: + raise TypeError( + "%s takes %i parameters, but %i dimensions were passed" + % (func.__name__, count_params, len(dimensions)) + ) + assigned = tuple( attr for attr in functools.WRAPPER_ASSIGNMENTS if hasattr(func, attr) ) @@ -297,11 +338,7 @@ def decorator(func): @functools.wraps(func, assigned=assigned, updated=updated) def wrapper(*args, **kwargs): list_args, empty = _apply_defaults(func, args, kwargs) - if len(dimensions) > len(list_args): - raise TypeError( - "%s takes %i parameters, but %i dimensions were passed" - % (func.__name__, len(list_args), len(dimensions)) - ) + for dim, value in zip(dimensions, list_args): if dim is None: diff --git a/pint/testsuite/test_unit.py b/pint/testsuite/test_unit.py index 7f64473cc..0b6c086a8 100644 --- a/pint/testsuite/test_unit.py +++ b/pint/testsuite/test_unit.py @@ -433,6 +433,9 @@ def func(x): ureg = self.ureg + self.assertRaises(TypeError, ureg.wraps, (3 * ureg.meter, [None])) + self.assertRaises(TypeError, ureg.wraps, (None, [3 * ureg.meter])) + f0 = ureg.wraps(None, [None])(func) self.assertEqual(f0(3.0), 3.0) @@ -451,6 +454,16 @@ def func(x): self.assertEqual(f1b(3.0 * ureg.meter), 3.0) self.assertRaises(DimensionalityError, f1b, 3 * ureg.second) + f1c = ureg.wraps("meter", [ureg.meter])(func) + self.assertEqual(f1c(3.0 * ureg.centimeter), 0.03 * ureg.meter) + self.assertEqual(f1c(3.0 * ureg.meter), 3.0 * ureg.meter) + self.assertRaises(DimensionalityError, f1c, 3 * ureg.second) + + f1d = ureg.wraps(ureg.meter, [ureg.meter])(func) + self.assertEqual(f1d(3.0 * ureg.centimeter), 0.03 * ureg.meter) + self.assertEqual(f1d(3.0 * ureg.meter), 3.0 * ureg.meter) + self.assertRaises(DimensionalityError, f1d, 3 * ureg.second) + f1 = ureg.wraps(None, "meter")(func) self.assertRaises(ValueError, f1, 3.0) self.assertEqual(f1(3.0 * ureg.centimeter), 0.03) @@ -565,17 +578,8 @@ def gfunc(x, y): 1 * ureg.meter / ureg.second ** 2, ) - g2 = ureg.check("[speed]")(gfunc) - self.assertRaises(DimensionalityError, g2, 3.0, 1) - self.assertRaises(TypeError, g2, 2 * ureg.parsec) - self.assertRaises(DimensionalityError, g2, 2 * ureg.parsec, 1.0) - self.assertEqual(g2(2.0 * ureg.km / ureg.hour, 2), 1 * ureg.km / ureg.hour) - - g3 = ureg.check("[speed]", "[time]", "[mass]")(gfunc) - self.assertRaises(TypeError, g3, 1 * ureg.parsec, 1 * ureg.angstrom) - self.assertRaises( - TypeError, g3, 1 * ureg.parsec, 1 * ureg.angstrom, 1 * ureg.kilogram - ) + self.assertRaises(TypeError, ureg.check("[speed]"), gfunc) + self.assertRaises(TypeError, ureg.check("[speed]", "[time]", "[mass]"), gfunc) def test_to_ref_vs_to(self): self.ureg.autoconvert_offset_to_baseunit = True