Skip to content

Commit

Permalink
API: make MatVecOperator infer dom and ran by default
Browse files Browse the repository at this point in the history
  • Loading branch information
kohr-h committed Nov 26, 2015
1 parent 6600854 commit e98cdc0
Show file tree
Hide file tree
Showing 3 changed files with 86 additions and 42 deletions.
61 changes: 43 additions & 18 deletions odl/space/ntuples.py
Original file line number Diff line number Diff line change
Expand Up @@ -1314,38 +1314,61 @@ class MatVecOperator(Operator):

"""Matrix multiply operator :math:`\mathbb{F}^n -> \mathbb{F}^m`."""

def __init__(self, dom, ran, matrix):
def __init__(self, matrix, dom=None, ran=None):
"""Initialize a new instance.
Parameters
----------
dom : `Fn`
Space on whose elements the matrix acts. Its dtype must be
castable to the range dtype.
ran : `Fn`
Space to which the matrix maps
matrix : array-like or ``scipy.sparse.spmatrix``
Matrix representing the linear operator. Its shape must be
``(m, n)``, where ``n`` is the size of ``dom`` and ``m`` the size
of ``ran``. Its dtype must be castable to the range dtype.
``(m, n)``, where ``n`` is the size of ``dom`` and ``m`` the
size of ``ran``. Its dtype must be castable to the range
``dtype``.
dom : `Fn`, optional
Space on whose elements the matrix acts. If not provided,
the domain is inferred from the matrix ``dtype`` and
``shape``. If provided, its dtype must be castable to the
range dtype.
ran : `Fn`, optional
Space to which the matrix maps. If not provided,
the domain is inferred from the matrix ``dtype`` and
``shape``.
"""
super().__init__(dom, ran, linear=True)
if not isinstance(dom, Fn):
if isspmatrix(matrix):
self._matrix = matrix
else:
self._matrix = np.asarray(matrix)

if self._matrix.ndim != 2:
raise ValueError('matrix {} has {} axes instead of 2.'
''.format(matrix, self._matrix.ndim))

# Infer domain and range from matrix if necessary
if is_real_floating_dtype(self._matrix):
spc_type = Rn
elif is_complex_floating_dtype(self._matrix):
spc_type = Cn
else:
spc_type = Fn

if dom is None:
dom = spc_type(self._matrix.shape[1], dtype=self._matrix.dtype)
elif not isinstance(dom, Fn):
raise TypeError('domain {!r} is not an `Fn` instance.'
''.format(dom))
if not isinstance(ran, Fn):

if ran is None:
ran = spc_type(self._matrix.shape[0], dtype=self._matrix.dtype)
elif not isinstance(ran, Fn):
raise TypeError('range {!r} is not an `Fn` instance.'
''.format(ran))

# Check compatibility of matrix with domain and range
if not np.can_cast(dom.dtype, ran.dtype):
raise TypeError('domain data type {} cannot be safely cast to '
'range data type {}.'
''.format(dom.dtype, ran.dtype))

if isspmatrix(matrix):
self._matrix = matrix
else:
self._matrix = np.asarray(matrix)

if self._matrix.shape != (ran.size, dom.size):
raise ValueError('matrix shape {} does not match the required '
'shape {} of a matrix {} --> {}.'
Expand All @@ -1357,6 +1380,8 @@ def __init__(self, dom, ran, matrix):
'range data type {}.'
''.format(matrix.dtype, ran.dtype))

super().__init__(dom, ran, linear=True)

@property
def matrix(self):
"""Matrix representing this operator."""
Expand All @@ -1375,8 +1400,8 @@ def adjoint(self):
'of domain and range differ ({} != {}).'
''.format(self.domain.field,
self.range.field))
return MatVecOperator(self.range, self.domain,
self.matrix.conj().T)
return MatVecOperator(self.matrix.conj().T,
dom=self.range, ran=self.domain)

def _call(self, x):
"""Raw call method on input, producing a new output."""
Expand Down
4 changes: 2 additions & 2 deletions test/solvers/vector/newton_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def _apply(self, x, out):
def derivative(self, x):
matrix = np.array([[2 - 400 * x[1] + 1200 * x[0] ** 2, -400 * x[0]],
[-400 * x[0], 200]])
return odl.MatVecOperator(self.domain, self.range, matrix)
return odl.MatVecOperator(matrix, self.domain, self.range)


def test_newton_solver_quadratic():
Expand All @@ -92,7 +92,7 @@ def test_newton_solver_quadratic():
x_opt = np.linalg.solve(H, -c)

# Create derivative operator operator
Aop = odl.MatVecOperator(rn, rn, H)
Aop = odl.MatVecOperator(H, rn, rn)
deriv_op = ResidualOp(Aop, -c)

# Create line search object
Expand Down
63 changes: 41 additions & 22 deletions test/space/ntuples_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -667,40 +667,59 @@ def test_matvec_init(fn):
sparse_mat = _sparse_matrix(fn)
dense_mat = _dense_matrix(fn)

MatVecOperator(fn, fn, sparse_mat)
MatVecOperator(fn, fn, dense_mat)
MatVecOperator(sparse_mat, fn, fn)
MatVecOperator(dense_mat, fn, fn)

# Test defaults
op_float = MatVecOperator([[1.0, 2],
[-1, 0.5]])

assert isinstance(op_float.domain, Rn)
assert isinstance(op_float.range, Rn)

op_complex = MatVecOperator([[1.0, 2 + 1j],
[-1 - 1j, 0.5]])

assert isinstance(op_complex.domain, Cn)
assert isinstance(op_complex.range, Cn)

op_int = MatVecOperator([[1, 2],
[-1, 0]])

assert isinstance(op_int.domain, Fn)
assert isinstance(op_int.range, Fn)

# Rectangular
rect_mat = 2 * np.eye(2, 3)
r2 = Rn(2)
r3 = Rn(3)

MatVecOperator(r3, r2, rect_mat)
MatVecOperator(rect_mat, r3, r2)

with pytest.raises(ValueError):
MatVecOperator(r2, r2, rect_mat)
MatVecOperator(rect_mat, r2, r2)

with pytest.raises(ValueError):
MatVecOperator(r3, r3, rect_mat)
MatVecOperator(rect_mat, r3, r3)

with pytest.raises(ValueError):
MatVecOperator(r2, r3, rect_mat)
MatVecOperator(rect_mat, r2, r3)

# Rn to Cn okay
MatVecOperator(r3, Cn(2), rect_mat)
MatVecOperator(rect_mat, r3, Cn(2))

# Cn to Rn not okay (no safe cast)
with pytest.raises(TypeError):
MatVecOperator(Cn(3), r2)
MatVecOperator(rect_mat, Cn(3), r2)

# Complex matrix between real spaces not okay
rect_complex_mat = rect_mat + 1j
with pytest.raises(TypeError):
MatVecOperator(r3, r2, rect_complex_mat)
MatVecOperator(rect_complex_mat, r3, r2)

# Init with array-like structure (including numpy.matrix)
MatVecOperator(r3, r2, rect_mat.tolist())
MatVecOperator(r3, r2, np.asmatrix(rect_mat))
MatVecOperator(rect_mat.tolist(), r3, r2)
MatVecOperator(np.asmatrix(rect_mat), r3, r2)


def test_matvec_simple_properties():
Expand All @@ -709,18 +728,18 @@ def test_matvec_simple_properties():
r2 = Rn(2)
r3 = Rn(3)

op = MatVecOperator(r3, r2, rect_mat)
op = MatVecOperator(rect_mat, r3, r2)
assert isinstance(op.matrix, np.ndarray)

op = MatVecOperator(r3, r2, np.asmatrix(rect_mat))
op = MatVecOperator(np.asmatrix(rect_mat), r3, r2)
assert isinstance(op.matrix, np.ndarray)

op = MatVecOperator(r3, r2, rect_mat.tolist())
op = MatVecOperator(rect_mat.tolist(), r3, r2)
assert isinstance(op.matrix, np.ndarray)
assert not op.matrix_issparse

sparse_mat = _sparse_matrix(Rn(5))
op = MatVecOperator(Rn(5), Rn(5), sparse_mat)
op = MatVecOperator(sparse_mat, Rn(5), Rn(5))
assert isinstance(op.matrix, sp.sparse.spmatrix)
assert op.matrix_issparse

Expand All @@ -730,8 +749,8 @@ def test_matvec_adjoint(fn):
sparse_mat = _sparse_matrix(fn)
dense_mat = _dense_matrix(fn)

op_sparse = MatVecOperator(fn, fn, sparse_mat)
op_dense = MatVecOperator(fn, fn, dense_mat)
op_sparse = MatVecOperator(sparse_mat, fn, fn)
op_dense = MatVecOperator(dense_mat, fn, fn)

# Just test if it runs, nothing interesting to test here
op_sparse.adjoint
Expand All @@ -742,15 +761,15 @@ def test_matvec_adjoint(fn):
r2, r3 = Rn(2), Rn(3)
c2 = Cn(2)

op = MatVecOperator(r3, r2, rect_mat)
op = MatVecOperator(rect_mat, r3, r2)
op_adj = op.adjoint
assert op_adj.domain == op.range
assert op_adj.range == op.domain
assert np.array_equal(op_adj.matrix, op.matrix.conj().T)
assert np.array_equal(op_adj.adjoint.matrix, op.matrix)

# The operator Rn -> Cn has no adjoint
op_noadj = MatVecOperator(r3, c2, rect_mat)
op_noadj = MatVecOperator(rect_mat, r3, c2)
with pytest.raises(NotImplementedError):
op_noadj.adjoint

Expand All @@ -761,8 +780,8 @@ def test_matvec_call(fn):
dense_mat = _dense_matrix(fn)
xarr, x = _vectors(fn)

op_sparse = MatVecOperator(fn, fn, sparse_mat)
op_dense = MatVecOperator(fn, fn, dense_mat)
op_sparse = MatVecOperator(sparse_mat, fn, fn)
op_dense = MatVecOperator(dense_mat, fn, fn)

yarr_sparse = sparse_mat.dot(xarr)
yarr_dense = dense_mat.dot(xarr)
Expand All @@ -787,7 +806,7 @@ def test_matvec_call(fn):
rect_mat = 2 * np.eye(2, 3)
r2, r3 = Rn(2), Rn(3)

op = MatVecOperator(r3, r2, rect_mat)
op = MatVecOperator(rect_mat, r3, r2)
xarr = np.arange(3, dtype=float)
x = r3.element(xarr)

Expand Down

0 comments on commit e98cdc0

Please sign in to comment.