diff --git a/odl/solvers/nonsmooth/proximal_operators.py b/odl/solvers/nonsmooth/proximal_operators.py index da3849939f8..0e1c6258950 100644 --- a/odl/solvers/nonsmooth/proximal_operators.py +++ b/odl/solvers/nonsmooth/proximal_operators.py @@ -1035,7 +1035,11 @@ def _call(self, x, out): if g is not None: diff = x - self.sigma * g else: - diff = x + if x is out: + # Handle aliased data properly + diff = x.copy() + else: + diff = x if isotropic: # Calculate |x| = pointwise 2-norm of x diff --git a/odl/test/solvers/nonsmooth/proximal_operator_test.py b/odl/test/solvers/nonsmooth/proximal_operator_test.py index 4ca988bee7a..3feb0fdca05 100644 --- a/odl/test/solvers/nonsmooth/proximal_operator_test.py +++ b/odl/test/solvers/nonsmooth/proximal_operator_test.py @@ -295,14 +295,23 @@ def test_proximal_convconj_l1_simple_space_without_data(): assert isinstance(prox, odl.Operator) # Apply the proximal operator returning its optimal point - x_opt = space.element() - prox(x, x_opt) - # Explicit computation: x / max(lam, |x|) denom = np.maximum(lam, np.sqrt(x_arr ** 2)) - x_verify = lam * x_arr / denom + x_exact = lam * x_arr / denom - assert all_almost_equal(x_opt, x_verify, HIGH_ACC) + # Using out + x_opt = space.element() + x_result = prox(x, x_opt) + assert x_result is x_opt + assert all_almost_equal(x_opt, x_exact, HIGH_ACC) + + # Without out + x_result = prox(x) + assert all_almost_equal(x_result, x_exact, HIGH_ACC) + + # With aliased out + x_result = prox(x, x) + assert all_almost_equal(x_result, x_exact, HIGH_ACC) def test_proximal_convconj_l1_simple_space_with_data():