Skip to content

Commit

Permalink
Merge pull request #73 from odlgroup/issue-49__matrix_representation_…
Browse files Browse the repository at this point in the history
…of_linear_operator

Closes #49
  • Loading branch information
kohr-h committed Nov 26, 2015
2 parents 69b729b + 38ba1a6 commit 2f4073c
Show file tree
Hide file tree
Showing 3 changed files with 332 additions and 34 deletions.
118 changes: 118 additions & 0 deletions odl/operator/oputils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
# Copyright 2014, 2015 The ODL development group
#
# This file is part of ODL.
#
# ODL is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# ODL is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with ODL. If not, see <http://www.gnu.org/licenses/>.

"""Usefull utility functions on discrete spaces (i.e., either Rn/Cn or
discretized function spaces), for example obtaining a matrix representation of
an operator. """

# Imports for common Python 2/3 codebase
from __future__ import print_function, division, absolute_import
from future import standard_library
standard_library.install_aliases()

# External
import numpy as np

# Internal
from odl.space.base_ntuples import FnBase
from odl.set.pspace import ProductSpace


def matrix_representation(op):
"""Returns a matrix representation of a linear operator.
Parameters
----------
op : :class:`~odl.Operator`
The linear operator of which one wants a matrix representation.
Returns
----------
matrix : `numpy.ndarray`
The matrix representation of the operator.
Notes
----------
The algorithm works by letting the operator act on all unit vectors, and
stacking the output as a matrix.
"""

if not op.is_linear:
raise ValueError('The operator is not linear')

if not (isinstance(op.domain, FnBase) or
(isinstance(op.domain, ProductSpace) and
all(isinstance(spc, FnBase) for spc in op.domain))):
raise TypeError('Operator domain {} is not FnBase, nor ProductSpace '
'with only FnBase components'.format(op.domain))

if not (isinstance(op.range, FnBase) or
(isinstance(op.range, ProductSpace) and
all(isinstance(spc, FnBase) for spc in op.range))):
raise TypeError('Operator range {} is not FnBase, nor ProductSpace '
'with only FnBase components'.format(op.range))

# Get the size of the range, and handle ProductSpace
# Store for reuse in loop
op_ran_is_prod_space = isinstance(op.range, ProductSpace)
if op_ran_is_prod_space:
num_ran = op.range.size
n = [ran.size for ran in op.range]
else:
num_ran = 1
n = [op.range.size]

# Get the size of the domain, and handle ProductSpace
# Store for reuse in loop
op_dom_is_prod_space = isinstance(op.domain, ProductSpace)
if op_dom_is_prod_space:
num_dom = op.domain.size
m = [dom.size for dom in op.domain]
else:
num_dom = 1
m = [op.domain.size]

# Generate the matrix
matrix = np.zeros([np.sum(n), np.sum(m)])
tmp_ran = op.range.element() # Store for reuse in loop
tmp_dom = op.domain.zero() # Store for reuse in loop
index = 0
last_i = last_j = 0

for i in range(num_dom):
for j in range(m[i]):
if op_dom_is_prod_space:
tmp_dom[last_i][last_j] = 0.0
tmp_dom[i][j] = 1.0
else:
tmp_dom[last_j] = 0.0
tmp_dom[j] = 1.0
op(tmp_dom, out=tmp_ran)
if op_ran_is_prod_space:
tmp_idx = 0
for k in range(num_ran):
matrix[tmp_idx: tmp_idx + op.range[k].size, index] = (
tmp_ran[k])
tmp_idx += op.range[k].size
else:
matrix[:, index] = tmp_ran.asarray()
index += 1
last_j = j
last_i = i

return matrix
47 changes: 13 additions & 34 deletions test/operator/operator_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@
import odl
from odl import (Operator, OperatorSum, OperatorComp,
OperatorLeftScalarMult, OperatorRightScalarMult,
FunctionalLeftVectorMult,
OperatorRightVectorMult)
FunctionalLeftVectorMult, OperatorRightVectorMult,
MatVecOperator)
from odl.util.testutils import almost_equal, all_almost_equal


Expand Down Expand Up @@ -184,36 +184,14 @@ def test_nonlinear_composition():
C = OperatorComp(Bop, Aop)


class MultiplyOp(Operator):

"""Multiply with matrix.
"""

def __init__(self, matrix, domain=None, range=None):
domain = (odl.Rn(matrix.shape[1])
if domain is None else domain)
range = (odl.Rn(matrix.shape[0])
if range is None else range)
self.matrix = matrix

super().__init__(domain, range, linear=True)

def _apply(self, rhs, out):
np.dot(self.matrix, rhs.data, out=out.data)

@property
def adjoint(self):
return MultiplyOp(self.matrix.T, self.range, self.domain)


def test_linear_Op():
# Verify that the multiply op does indeed work as expected

A = np.random.rand(3, 3)
x = np.random.rand(3)
out = np.random.rand(3)

Aop = MultiplyOp(A)
Aop = MatVecOperator(A)
xvec = Aop.domain.element(x)
outvec = Aop.range.element()

Expand All @@ -232,7 +210,8 @@ def test_linear_op_nonsquare():
x = np.random.rand(3)
out = np.random.rand(4)

Aop = MultiplyOp(A)
Aop = MatVecOperator(A)

xvec = Aop.domain.element(x)
outvec = Aop.range.element()

Expand All @@ -250,7 +229,7 @@ def test_linear_adjoint():
x = np.random.rand(4)
out = np.random.rand(3)

Aop = MultiplyOp(A)
Aop = MatVecOperator(A)
xvec = Aop.range.element(x)
outvec = Aop.domain.element()

Expand All @@ -269,8 +248,8 @@ def test_linear_addition():
x = np.random.rand(3)
y = np.random.rand(4)

Aop = MultiplyOp(A)
Bop = MultiplyOp(B)
Aop = MatVecOperator(A)
Bop = MatVecOperator(B)
xvec = Aop.domain.element(x)
yvec = Aop.range.element(y)

Expand All @@ -295,7 +274,7 @@ def test_linear_scale():
x = np.random.rand(3)
y = np.random.rand(4)

Aop = MultiplyOp(A)
Aop = MatVecOperator(A)
xvec = Aop.domain.element(x)
yvec = Aop.range.element(y)

Expand Down Expand Up @@ -325,7 +304,7 @@ def test_linear_scale():
def test_linear_right_vector_mult():
A = np.random.rand(4, 3)

Aop = MultiplyOp(A)
Aop = MatVecOperator(A)
vec = Aop.domain.element([1, 2, 3])
x = Aop.domain.element([4, 5, 6])
y = Aop.range.element([5, 6, 7, 8])
Expand Down Expand Up @@ -354,8 +333,8 @@ def test_linear_composition():
x = np.random.rand(3)
y = np.random.rand(5)

Aop = MultiplyOp(A)
Bop = MultiplyOp(B)
Aop = MatVecOperator(A)
Bop = MatVecOperator(B)
xvec = Bop.domain.element(x)
yvec = Aop.range.element(y)

Expand All @@ -372,7 +351,7 @@ def test_type_errors():
r3 = odl.Rn(3)
r4 = odl.Rn(4)

Aop = MultiplyOp(np.random.rand(3, 3))
Aop = MatVecOperator(np.random.rand(3, 3))
r3Vec1 = r3.zero()
r3Vec2 = r3.zero()
r4Vec1 = r4.zero()
Expand Down
Loading

0 comments on commit 2f4073c

Please sign in to comment.