diff --git a/odl/operator/operator.py b/odl/operator/operator.py index dff197ef2ef..0b114fddacc 100644 --- a/odl/operator/operator.py +++ b/odl/operator/operator.py @@ -79,46 +79,39 @@ def _default_call_in_place(op, x, out, **kwargs): out.assign(op.range.element(op._call_out_of_place(x, **kwargs))) -def _signature_from_spec(func): - """Return the signature of a python function as a string. +def _get_signature(func): + """Return the signature of a callable as a string. Parameters ---------- - func : `function` - Function whose signature to compile + func : callable + Function whose signature to extract. Returns ------- sig : string - Signature of the function + Signature of the function. """ - py3 = (sys.version_info.major > 2) - if py3: - spec = inspect.getfullargspec(func) - else: - spec = inspect.getargspec(func) + if sys.version_info.major > 2: + # Python 3 already implements this functionality + return func.__name__ + str(inspect.signature(func)) + # In Python 2 we have to do it manually, unfortunately + spec = inspect.getargspec(func) posargs = spec.args defaults = spec.defaults if spec.defaults is not None else [] varargs = spec.varargs - kwargs = spec.varkw if py3 else spec.keywords + kwargs = spec.keywords deflen = 0 if defaults is None else len(defaults) nodeflen = 0 if posargs is None else len(posargs) - deflen args = ['{}'.format(arg) for arg in posargs[:nodeflen]] - args += ['{}={}'.format(arg, dval) - for arg, dval in zip(posargs[nodeflen:], defaults)] + args.extend('{}={}'.format(arg, dval) + for arg, dval in zip(posargs[nodeflen:], defaults)) if varargs: - args += ['*{}'.format(varargs)] - if py3: - kw_only = spec.kwonlyargs - kw_only_defaults = spec.kwonlydefaults - if kw_only and not varargs: - args += ['*'] - args += ['{}={}'.format(arg, kw_only_defaults[arg]) - for arg in kw_only] + args.append('*{}'.format(varargs)) if kwargs: - args += ['**{}'.format(kwargs)] + args.append('**{}'.format(kwargs)) argstr = ', '.join(args) @@ -236,7 +229,7 @@ def _dispatch_call_args(cls=None, bound_call=None, unbound_call=None, kw_only = () kw_only_defaults = {} - signature = _signature_from_spec(call) + signature = _get_signature(call) pos_args = spec.args if unbound_call is not None: diff --git a/odl/space/fspace.py b/odl/space/fspace.py index 85f50e0d74d..94203152481 100644 --- a/odl/space/fspace.py +++ b/odl/space/fspace.py @@ -14,10 +14,10 @@ standard_library.install_aliases() from builtins import super -from inspect import isfunction +import inspect import numpy as np +import sys -from odl.operator.operator import Operator, _dispatch_call_args from odl.set import RealNumbers, ComplexNumbers, Set, LinearSpace from odl.set.space import LinearSpaceElement from odl.util import ( @@ -31,6 +31,57 @@ __all__ = ('FunctionSpace',) +def _check_out_arg(func): + """Check if of ``func`` has an (optional) ``out`` argument. + + Also verify that the signature of ``func`` has no ``*args`` since + they make argument propagation a hassle. + + Parameters + ---------- + func : callable + Object that should be inspected. + + Returns + ------- + has_out : bool + ``True`` if the signature has an ``out`` argument, ``False`` + otherwise. + out_is_optional : bool + ``True`` if ``out`` is present and optional in the signature, + ``False`` otherwise. + + Raises + ------ + TypeError + If ``func``'s signature has ``*args``. + """ + if sys.version_info.major > 2: + spec = inspect.getfullargspec(func) + kw_only = spec.kwonlyargs + else: + spec = inspect.getargspec(func) + kw_only = () + + if spec.varargs is not None: + raise TypeError('*args not allowed in function signature') + + pos_args = spec.args + pos_defaults = () if spec.defaults is None else spec.defaults + + has_out = 'out' in pos_args or 'out' in kw_only + if 'out' in pos_args: + has_out = True + out_is_optional = ( + pos_args.index('out') >= len(pos_args) - len(pos_defaults)) + elif 'out' in kw_only: + has_out = out_is_optional = True + else: + has_out = out_is_optional = False + + return has_out, out_is_optional + + def _default_in_place(func, x, out, **kwargs): """Default in-place evaluation method.""" out[:] = func(x, **kwargs) @@ -681,7 +732,7 @@ def __str__(self): return repr(self) -class FunctionSpaceElement(LinearSpaceElement, Operator): +class FunctionSpaceElement(LinearSpaceElement): """Representation of a `FunctionSpace` element.""" @@ -697,14 +748,14 @@ def __init__(self, fspace, fcall): It must return a `FunctionSpace.range` element or a `numpy.ndarray` of such (vectorized call). """ - LinearSpaceElement.__init__(self, fspace) - Operator.__init__(self, self.space.domain, self.space.range, - linear=False) + super().__init__(fspace) + self.__domain = self.space.domain + self.__range = self.space.range # Determine which type of implementation fcall is if isinstance(fcall, FunctionSpaceElement): - call_has_out, call_out_optional, _ = _dispatch_call_args( - bound_call=fcall._call) + call_has_out = fcall._call_has_out + call_out_optional = fcall._call_out_optional # Numpy Ufuncs and similar objects (e.g. Numba DUfuncs) elif hasattr(fcall, 'nin') and hasattr(fcall, 'nout'): @@ -717,12 +768,10 @@ def __init__(self, fspace, fcall): 'expected at most 1' ''.format(fcall.__name__, fcall.nout)) call_has_out = call_out_optional = (fcall.nout == 1) - elif isfunction(fcall): - call_has_out, call_out_optional, _ = _dispatch_call_args( - unbound_call=fcall) + elif inspect.isfunction(fcall): + call_has_out, call_out_optional = _check_out_arg(fcall) elif callable(fcall): - call_has_out, call_out_optional, _ = _dispatch_call_args( - bound_call=fcall.__call__) + call_has_out, call_out_optional = _check_out_arg(fcall.__call__) else: raise TypeError('type {!r} not callable') @@ -743,6 +792,16 @@ def __init__(self, fspace, fcall): self._call_out_of_place = preload_first_arg(self, 'out-of-place')( _default_out_of_place) + @property + def domain(self): + """Set of objects on which this function can be evaluated.""" + return self.__domain + + @property + def range(self): + """Set in which the result of an evaluation of this function lies.""" + return self.__range + @property def out_dtype(self): """Output data type of this function. diff --git a/odl/test/operator/operator_test.py b/odl/test/operator/operator_test.py index aaa518a87dd..92417dd7b23 100644 --- a/odl/test/operator/operator_test.py +++ b/odl/test/operator/operator_test.py @@ -18,7 +18,7 @@ FunctionalLeftVectorMult, OperatorRightVectorMult, MatrixOperator, OperatorLeftVectorMult, OpTypeError, OpDomainError, OpRangeError) -from odl.operator.operator import _signature_from_spec, _dispatch_call_args +from odl.operator.operator import _get_signature, _dispatch_call_args from odl.util.testutils import almost_equal, all_almost_equal, noise_element @@ -857,10 +857,10 @@ def func(request): return request.param -def test_signature_from_spec(func): +def test_get_signature(func): true_sig = func.__doc__.splitlines()[0].strip() - sig = _signature_from_spec(func) + sig = _get_signature(func) assert true_sig == sig diff --git a/odl/test/space/fspace_test.py b/odl/test/space/fspace_test.py index 5b2da902a70..d15cf340cb8 100644 --- a/odl/test/space/fspace_test.py +++ b/odl/test/space/fspace_test.py @@ -316,6 +316,53 @@ def test_fspace_vector_eval_complex(): assert all_equal(out_mg, true_mg) +def test_fspace_vector_with_params(): + rect, points, mg = _standard_setup_2d() + + def f(x, c): + return sum(x) + c + + def f_out1(x, out, c): + out[:] = sum(x) + c + + def f_out2(x, c, out): + out[:] = sum(x) + c + + fspace = FunctionSpace(rect) + true_result_arr = f(points, c=2) + true_result_mg = f(mg, c=2) + + f_elem = fspace.element(f) + assert all_equal(f_elem(points, c=2), true_result_arr) + out_arr = np.empty((5,)) + f_elem(points, c=2, out=out_arr) + assert all_equal(out_arr, true_result_arr) + assert all_equal(f_elem(mg, c=2), true_result_mg) + out_mg = np.empty((2, 3)) + f_elem(mg, c=2, out=out_mg) + assert all_equal(out_mg, true_result_mg) + + f_out1_elem = fspace.element(f_out1) + assert all_equal(f_out1_elem(points, c=2), true_result_arr) + out_arr = np.empty((5,)) + f_out1_elem(points, c=2, out=out_arr) + assert all_equal(out_arr, true_result_arr) + assert all_equal(f_out1_elem(mg, c=2), true_result_mg) + out_mg = np.empty((2, 3)) + f_out1_elem(mg, c=2, out=out_mg) + assert all_equal(out_mg, true_result_mg) + + f_out2_elem = fspace.element(f_out2) + assert all_equal(f_out2_elem(points, c=2), true_result_arr) + out_arr = np.empty((5,)) + f_out2_elem(points, c=2, out=out_arr) + assert all_equal(out_arr, true_result_arr) + assert all_equal(f_out2_elem(mg, c=2), true_result_mg) + out_mg = np.empty((2, 3)) + f_out2_elem(mg, c=2, out=out_mg) + assert all_equal(out_mg, true_result_mg) + + def test_fspace_vector_ufunc(): intv = odl.IntervalProd(0, 1) points = _points(intv, num=5)