Skip to content

Commit

Permalink
MAINT: remove operator scaling and lico
Browse files Browse the repository at this point in the history
  • Loading branch information
kohr-h committed Aug 29, 2018
1 parent c80607c commit 958e0a0
Showing 1 changed file with 42 additions and 62 deletions.
104 changes: 42 additions & 62 deletions odl/operator/default_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,45 +30,6 @@
'ComplexModulus', 'ComplexModulusSquared')


def _scale_op(operator, scalar):
"""Scale an operator, optimizing for ``scalar=0`` and ``scalar=1``."""
if scalar == 0:
return ZeroOperator(operator.domain, operator.range)
elif scalar == 1:
return operator
else:
return scalar * operator


def _lico_ops(a, op1, b, op2):
"""Linear combination of operators, optimizing trivial cases."""
if op1.domain != op2.domain or op1.range != op2.range:
raise ValueError('domain/range mismatch between {!r} and {!r}'
.format(op1, op2))
dom, ran = op1.domain, op1.range
if a == 0:
if b == 0:
return ZeroOperator(dom, ran)
elif b == 1:
return op2
else:
return b * op2
elif a == 1:
if b == 0:
return op1
elif b == 1:
return op1 + op2
else:
return op1 + b * op2
else:
if b == 0:
return a * op1
elif b == 1:
return a * op1 + op2
else:
return a * op1 + b * op2


class ScalingOperator(Operator):

"""Operator of multiplication with a scalar.
Expand Down Expand Up @@ -1538,24 +1499,33 @@ def inverse(self):
return ZeroOperator(self.range, self.domain)

if self.domain.is_real:
# Real domain
# Optimizations for simple cases.
# Real domain, with optimizations for simple cases
if self.scalar.real == self.scalar:
return _scale_op(RealPart(self.range, self.domain),
1 / self.scalar.real)
# embedding x -> (a + i*0) * x
op = RealPart(self.range, self.domain)
if self.scalar.real != 1:
op = (1 / self.scalar.real) * op
return op
elif 1j * self.scalar.imag == self.scalar:
return _scale_op(ImagPart(self.range, self.domain),
1 / self.scalar.imag)
# embedding x -> (0 + i*b) * x
op = ImagPart(self.range, self.domain)
if self.scalar.imag != 1:
op = (1 / self.scalar.imag) * op
return op
else:
# General case
inv_scalar = (1 / self.scalar).conjugate()
return _lico_ops(
inv_scalar.real, RealPart(self.range, self.domain),
inv_scalar.imag, ImagPart(self.range, self.domain))
# embedding x -> (a + i*b) * x
inv_scalar = 1 / self.scalar
re_op = RealPart(self.range, self.domain)
if inv_scalar.real != 1:
re_op = inv_scalar.real * re_op
im_op = ImagPart(self.range, self.domain)
if inv_scalar.imag != 1:
im_op = inv_scalar.imag * im_op

return re_op + im_op
else:
# Complex domain
return ComplexEmbedding(self.range, self.domain,
self.scalar.conjugate())
return ComplexEmbedding(self.range, self.domain, 1 / self.scalar)

@property
def adjoint(self):
Expand Down Expand Up @@ -1602,19 +1572,29 @@ def adjoint(self):
return ZeroOperator(self.range, self.domain)

if self.domain.is_real:
# Real domain
# Optimizations for simple cases.
# Real domain, with optimizations for simple cases
if self.scalar.real == self.scalar:
return _scale_op(self.scalar.real,
ComplexEmbedding(self.range, self.domain))
# embedding x -> (a + i*0) * x
op = RealPart(self.range, self.domain)
if self.scalar.real != 1:
op = self.scalar.real * op
return op
elif 1j * self.scalar.imag == self.scalar:
return _scale_op(self.scalar.imag,
ImagPart(self.range, self.domain))
# embedding x -> (0 + i*b) * x
op = ImagPart(self.range, self.domain)
if self.scalar.imag != 1:
op = self.scalar.imag * op
return op
else:
# General case
return _lico_ops(
self.scalar.real, RealPart(self.range, self.domain),
self.scalar.imag, ImagPart(self.range, self.domain))
# embedding x -> (a + i*b) * x
re_op = RealPart(self.range, self.domain)
if self.scalar.real != 1:
re_op = self.scalar.real * re_op
im_op = ImagPart(self.range, self.domain)
if self.scalar.imag != 1:
im_op = self.scalar.imag * im_op

return re_op + im_op
else:
# Complex domain
return ComplexEmbedding(self.range, self.domain,
Expand Down

0 comments on commit 958e0a0

Please sign in to comment.