diff --git a/odl/operator/default_ops.py b/odl/operator/default_ops.py index 8f9be6a9d54..fe2d2faa275 100644 --- a/odl/operator/default_ops.py +++ b/odl/operator/default_ops.py @@ -30,45 +30,6 @@ 'ComplexModulus', 'ComplexModulusSquared') -def _scale_op(operator, scalar): - """Scale an operator, optimizing for ``scalar=0`` and ``scalar=1``.""" - if scalar == 0: - return ZeroOperator(operator.domain, operator.range) - elif scalar == 1: - return operator - else: - return scalar * operator - - -def _lico_ops(a, op1, b, op2): - """Linear combination of operators, optimizing trivial cases.""" - if op1.domain != op2.domain or op1.range != op2.range: - raise ValueError('domain/range mismatch between {!r} and {!r}' - .format(op1, op2)) - dom, ran = op1.domain, op1.range - if a == 0: - if b == 0: - return ZeroOperator(dom, ran) - elif b == 1: - return op2 - else: - return b * op2 - elif a == 1: - if b == 0: - return op1 - elif b == 1: - return op1 + op2 - else: - return op1 + b * op2 - else: - if b == 0: - return a * op1 - elif b == 1: - return a * op1 + op2 - else: - return a * op1 + b * op2 - - class ScalingOperator(Operator): """Operator of multiplication with a scalar. @@ -1538,24 +1499,33 @@ def inverse(self): return ZeroOperator(self.range, self.domain) if self.domain.is_real: - # Real domain - # Optimizations for simple cases. + # Real domain, with optimizations for simple cases if self.scalar.real == self.scalar: - return _scale_op(RealPart(self.range, self.domain), - 1 / self.scalar.real) + # embedding x -> (a + i*0) * x + op = RealPart(self.range, self.domain) + if self.scalar.real != 1: + op = (1 / self.scalar.real) * op + return op elif 1j * self.scalar.imag == self.scalar: - return _scale_op(ImagPart(self.range, self.domain), - 1 / self.scalar.imag) + # embedding x -> (0 + i*b) * x + op = ImagPart(self.range, self.domain) + if self.scalar.imag != 1: + op = (1 / self.scalar.imag) * op + return op else: - # General case - inv_scalar = (1 / self.scalar).conjugate() - return _lico_ops( - inv_scalar.real, RealPart(self.range, self.domain), - inv_scalar.imag, ImagPart(self.range, self.domain)) + # embedding x -> (a + i*b) * x + inv_scalar = 1 / self.scalar + re_op = RealPart(self.range, self.domain) + if inv_scalar.real != 1: + re_op = inv_scalar.real * re_op + im_op = ImagPart(self.range, self.domain) + if inv_scalar.imag != 1: + im_op = inv_scalar.imag * im_op + + return re_op + im_op else: # Complex domain - return ComplexEmbedding(self.range, self.domain, - self.scalar.conjugate()) + return ComplexEmbedding(self.range, self.domain, 1 / self.scalar) @property def adjoint(self): @@ -1602,19 +1572,29 @@ def adjoint(self): return ZeroOperator(self.range, self.domain) if self.domain.is_real: - # Real domain - # Optimizations for simple cases. + # Real domain, with optimizations for simple cases if self.scalar.real == self.scalar: - return _scale_op(self.scalar.real, - ComplexEmbedding(self.range, self.domain)) + # embedding x -> (a + i*0) * x + op = RealPart(self.range, self.domain) + if self.scalar.real != 1: + op = self.scalar.real * op + return op elif 1j * self.scalar.imag == self.scalar: - return _scale_op(self.scalar.imag, - ImagPart(self.range, self.domain)) + # embedding x -> (0 + i*b) * x + op = ImagPart(self.range, self.domain) + if self.scalar.imag != 1: + op = self.scalar.imag * op + return op else: - # General case - return _lico_ops( - self.scalar.real, RealPart(self.range, self.domain), - self.scalar.imag, ImagPart(self.range, self.domain)) + # embedding x -> (a + i*b) * x + re_op = RealPart(self.range, self.domain) + if self.scalar.real != 1: + re_op = self.scalar.real * re_op + im_op = ImagPart(self.range, self.domain) + if self.scalar.imag != 1: + im_op = self.scalar.imag * im_op + + return re_op + im_op else: # Complex domain return ComplexEmbedding(self.range, self.domain,