Skip to content

Commit

Permalink
Merge pull request odlgroup#1225 from kohr-h/issue-1181__op_aliasing_bug
Browse files Browse the repository at this point in the history
BUG: fix in-out aliasing issue in OperatorSum, closes odlgroup#1181
  • Loading branch information
Holger Kohr authored Nov 11, 2017
2 parents 27f32e6 + 9b5e1b9 commit 72b7fc8
Show file tree
Hide file tree
Showing 3 changed files with 388 additions and 331 deletions.
12 changes: 8 additions & 4 deletions odl/operator/operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -1154,8 +1154,10 @@ def _call(self, x, out=None):
else:
tmp = (self.__tmp_ran if self.__tmp_ran is not None
else self.range.element())
self.left(x, out=out)
self.right(x, out=tmp)
# Write to `tmp` first, otherwise aliased `x` and `out` lead
# to wrong result
self.left(x, out=tmp)
self.right(x, out=out)
out += tmp

def derivative(self, x):
Expand Down Expand Up @@ -1492,8 +1494,10 @@ def _call(self, x, out=None):
return self.left(x) * self.right(x)
else:
tmp = self.right.range.element()
self.left(x, out=out)
self.right(x, out=tmp)
# Write to `tmp` first, otherwise aliased `x` and `out` lead
# to wrong result
self.left(x, out=tmp)
self.right(x, out=out)
out *= tmp

def derivative(self, x):
Expand Down
12 changes: 10 additions & 2 deletions odl/operator/tensor_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -858,6 +858,8 @@ def inverse(self):

def _call(self, x, out=None):
"""Return ``self(x[, out])``."""
from pkg_resources import parse_version

if out is None:
return self.range.element(self.matrix.dot(x))
else:
Expand All @@ -866,8 +868,14 @@ def _call(self, x, out=None):
# sparse matrices
out[:] = self.matrix.dot(x)
else:
with writable_array(out) as out_arr:
self.matrix.dot(x, out=out_arr)
if (parse_version(np.__version__) < parse_version('1.13.0') and
x is out):
# Workaround for bug in Numpy < 1.13 with aliased in and
# out in np.dot
out[:] = self.matrix.dot(x)
else:
with writable_array(out) as out_arr:
self.matrix.dot(x, out=out_arr)

def __repr__(self):
"""Return ``repr(self)``."""
Expand Down
Loading

0 comments on commit 72b7fc8

Please sign in to comment.