Skip to content

Commit

Permalink
ENH: remove Operator parent class from FunctionSpaceElement, closes #949
Browse files Browse the repository at this point in the history
  • Loading branch information
Holger Kohr committed Sep 2, 2017
1 parent 417101a commit 40e98a8
Show file tree
Hide file tree
Showing 4 changed files with 138 additions and 39 deletions.
39 changes: 16 additions & 23 deletions odl/operator/operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,46 +79,39 @@ def _default_call_in_place(op, x, out, **kwargs):
out.assign(op.range.element(op._call_out_of_place(x, **kwargs)))


def _signature_from_spec(func):
"""Return the signature of a python function as a string.
def _get_signature(func):
"""Return the signature of a callable as a string.
Parameters
----------
func : `function`
Function whose signature to compile
func : callable
Function whose signature to extract.
Returns
-------
sig : string
Signature of the function
Signature of the function.
"""
py3 = (sys.version_info.major > 2)
if py3:
spec = inspect.getfullargspec(func)
else:
spec = inspect.getargspec(func)
if sys.version_info.major > 2:
# Python 3 already implements this functionality
return func.__name__ + str(inspect.signature(func))

# In Python 2 we have to do it manually, unfortunately
spec = inspect.getargspec(func)
posargs = spec.args
defaults = spec.defaults if spec.defaults is not None else []
varargs = spec.varargs
kwargs = spec.varkw if py3 else spec.keywords
kwargs = spec.keywords
deflen = 0 if defaults is None else len(defaults)
nodeflen = 0 if posargs is None else len(posargs) - deflen

args = ['{}'.format(arg) for arg in posargs[:nodeflen]]
args += ['{}={}'.format(arg, dval)
for arg, dval in zip(posargs[nodeflen:], defaults)]
args.extend('{}={}'.format(arg, dval)
for arg, dval in zip(posargs[nodeflen:], defaults))
if varargs:
args += ['*{}'.format(varargs)]
if py3:
kw_only = spec.kwonlyargs
kw_only_defaults = spec.kwonlydefaults
if kw_only and not varargs:
args += ['*']
args += ['{}={}'.format(arg, kw_only_defaults[arg])
for arg in kw_only]
args.append('*{}'.format(varargs))
if kwargs:
args += ['**{}'.format(kwargs)]
args.append('**{}'.format(kwargs))

argstr = ', '.join(args)

Expand Down Expand Up @@ -236,7 +229,7 @@ def _dispatch_call_args(cls=None, bound_call=None, unbound_call=None,
kw_only = ()
kw_only_defaults = {}

signature = _signature_from_spec(call)
signature = _get_signature(call)

pos_args = spec.args
if unbound_call is not None:
Expand Down
85 changes: 72 additions & 13 deletions odl/space/fspace.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,10 @@
standard_library.install_aliases()
from builtins import super

from inspect import isfunction
import inspect
import numpy as np
import sys

from odl.operator.operator import Operator, _dispatch_call_args
from odl.set import RealNumbers, ComplexNumbers, Set, LinearSpace
from odl.set.space import LinearSpaceElement
from odl.util import (
Expand All @@ -31,6 +31,57 @@
__all__ = ('FunctionSpace',)


def _check_out_arg(func):
"""Check if of ``func`` has an (optional) ``out`` argument.
Also verify that the signature of ``func`` has no ``*args`` since
they make argument propagation a hassle.
Parameters
----------
func : callable
Object that should be inspected.
Returns
-------
has_out : bool
``True`` if the signature has an ``out`` argument, ``False``
otherwise.
out_is_optional : bool
``True`` if ``out`` is present and optional in the signature,
``False`` otherwise.
Raises
------
TypeError
If ``func``'s signature has ``*args``.
"""
if sys.version_info.major > 2:
spec = inspect.getfullargspec(func)
kw_only = spec.kwonlyargs
else:
spec = inspect.getargspec(func)
kw_only = ()

if spec.varargs is not None:
raise TypeError('*args not allowed in function signature')

pos_args = spec.args
pos_defaults = () if spec.defaults is None else spec.defaults

has_out = 'out' in pos_args or 'out' in kw_only
if 'out' in pos_args:
has_out = True
out_is_optional = (
pos_args.index('out') >= len(pos_args) - len(pos_defaults))
elif 'out' in kw_only:
has_out = out_is_optional = True
else:
has_out = out_is_optional = False

return has_out, out_is_optional


def _default_in_place(func, x, out, **kwargs):
"""Default in-place evaluation method."""
out[:] = func(x, **kwargs)
Expand Down Expand Up @@ -681,7 +732,7 @@ def __str__(self):
return repr(self)


class FunctionSpaceElement(LinearSpaceElement, Operator):
class FunctionSpaceElement(LinearSpaceElement):

"""Representation of a `FunctionSpace` element."""

Expand All @@ -697,14 +748,14 @@ def __init__(self, fspace, fcall):
It must return a `FunctionSpace.range` element or a
`numpy.ndarray` of such (vectorized call).
"""
LinearSpaceElement.__init__(self, fspace)
Operator.__init__(self, self.space.domain, self.space.range,
linear=False)
super().__init__(fspace)
self.__domain = self.space.domain
self.__range = self.space.range

# Determine which type of implementation fcall is
if isinstance(fcall, FunctionSpaceElement):
call_has_out, call_out_optional, _ = _dispatch_call_args(
bound_call=fcall._call)
call_has_out = fcall._call_has_out
call_out_optional = fcall._call_out_optional

# Numpy Ufuncs and similar objects (e.g. Numba DUfuncs)
elif hasattr(fcall, 'nin') and hasattr(fcall, 'nout'):
Expand All @@ -717,12 +768,10 @@ def __init__(self, fspace, fcall):
'expected at most 1'
''.format(fcall.__name__, fcall.nout))
call_has_out = call_out_optional = (fcall.nout == 1)
elif isfunction(fcall):
call_has_out, call_out_optional, _ = _dispatch_call_args(
unbound_call=fcall)
elif inspect.isfunction(fcall):
call_has_out, call_out_optional = _check_out_arg(fcall)
elif callable(fcall):
call_has_out, call_out_optional, _ = _dispatch_call_args(
bound_call=fcall.__call__)
call_has_out, call_out_optional = _check_out_arg(fcall.__call__)
else:
raise TypeError('type {!r} not callable')

Expand All @@ -743,6 +792,16 @@ def __init__(self, fspace, fcall):
self._call_out_of_place = preload_first_arg(self, 'out-of-place')(
_default_out_of_place)

@property
def domain(self):
"""Set of objects on which this function can be evaluated."""
return self.__domain

@property
def range(self):
"""Set in which the result of an evaluation of this function lies."""
return self.__range

@property
def out_dtype(self):
"""Output data type of this function.
Expand Down
6 changes: 3 additions & 3 deletions odl/test/operator/operator_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
FunctionalLeftVectorMult, OperatorRightVectorMult,
MatrixOperator, OperatorLeftVectorMult,
OpTypeError, OpDomainError, OpRangeError)
from odl.operator.operator import _signature_from_spec, _dispatch_call_args
from odl.operator.operator import _get_signature, _dispatch_call_args
from odl.util.testutils import almost_equal, all_almost_equal, noise_element


Expand Down Expand Up @@ -857,10 +857,10 @@ def func(request):
return request.param


def test_signature_from_spec(func):
def test_get_signature(func):

true_sig = func.__doc__.splitlines()[0].strip()
sig = _signature_from_spec(func)
sig = _get_signature(func)
assert true_sig == sig


Expand Down
47 changes: 47 additions & 0 deletions odl/test/space/fspace_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -316,6 +316,53 @@ def test_fspace_vector_eval_complex():
assert all_equal(out_mg, true_mg)


def test_fspace_vector_with_params():
rect, points, mg = _standard_setup_2d()

def f(x, c):
return sum(x) + c

def f_out1(x, out, c):
out[:] = sum(x) + c

def f_out2(x, c, out):
out[:] = sum(x) + c

fspace = FunctionSpace(rect)
true_result_arr = f(points, c=2)
true_result_mg = f(mg, c=2)

f_elem = fspace.element(f)
assert all_equal(f_elem(points, c=2), true_result_arr)
out_arr = np.empty((5,))
f_elem(points, c=2, out=out_arr)
assert all_equal(out_arr, true_result_arr)
assert all_equal(f_elem(mg, c=2), true_result_mg)
out_mg = np.empty((2, 3))
f_elem(mg, c=2, out=out_mg)
assert all_equal(out_mg, true_result_mg)

f_out1_elem = fspace.element(f_out1)
assert all_equal(f_out1_elem(points, c=2), true_result_arr)
out_arr = np.empty((5,))
f_out1_elem(points, c=2, out=out_arr)
assert all_equal(out_arr, true_result_arr)
assert all_equal(f_out1_elem(mg, c=2), true_result_mg)
out_mg = np.empty((2, 3))
f_out1_elem(mg, c=2, out=out_mg)
assert all_equal(out_mg, true_result_mg)

f_out2_elem = fspace.element(f_out2)
assert all_equal(f_out2_elem(points, c=2), true_result_arr)
out_arr = np.empty((5,))
f_out2_elem(points, c=2, out=out_arr)
assert all_equal(out_arr, true_result_arr)
assert all_equal(f_out2_elem(mg, c=2), true_result_mg)
out_mg = np.empty((2, 3))
f_out2_elem(mg, c=2, out=out_mg)
assert all_equal(out_mg, true_result_mg)


def test_fspace_vector_ufunc():
intv = odl.IntervalProd(0, 1)
points = _points(intv, num=5)
Expand Down

0 comments on commit 40e98a8

Please sign in to comment.