Skip to content

Commit

Permalink
[Bug] [lang] [std] Fix all potential matrix SSA violation by using el…
Browse files Browse the repository at this point in the history
…ement_wise_write_binary (#1424)

* [Bug] Fix all potential matrix SSA violation by using element_wise_write_binary

* really fix

* codecov 50%

* [skip ci] enforce code format

Co-authored-by: Taichi Gardener <[email protected]>
  • Loading branch information
archibate and taichi-gardener authored Jul 8, 2020
1 parent 8f5c84d commit 33515be
Show file tree
Hide file tree
Showing 4 changed files with 24 additions and 5 deletions.
2 changes: 1 addition & 1 deletion codecov.yml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ coverage:
lang:
paths:
- python/taichi/lang
target: 60%
target: 50%
project:
default: false
lang:
Expand Down
21 changes: 18 additions & 3 deletions python/taichi/lang/matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import copy
import numbers
import numpy as np
from .util import taichi_scope, python_scope, deprecated, to_numpy_type, to_pytorch_type, in_python_scope
from .util import taichi_scope, python_scope, deprecated, to_numpy_type, to_pytorch_type, in_python_scope, is_taichi_class
from .common_ops import TaichiOperations
from .exception import TaichiSyntaxError
from collections.abc import Iterable
Expand Down Expand Up @@ -155,6 +155,21 @@ def element_wise_binary(self, foo, other):
ret = self.empty_copy()
if isinstance(other, (list, tuple)):
other = Matrix(other)
if isinstance(other, Matrix):
assert self.m == other.m and self.n == other.n, f"Dimension mismatch between shapes ({self.n}, {self.m}), ({other.n}, {other.m})"
for i in range(self.n * self.m):
ret.entries[i] = foo(self.entries[i], other.entries[i])
else: # assumed to be scalar
for i in range(self.n * self.m):
ret.entries[i] = foo(self.entries[i], other)
return ret

def element_wise_writeback_binary(self, foo, other):
ret = self.empty_copy()
if isinstance(other, (list, tuple)):
other = Matrix(other)
if is_taichi_class(other):
other = other.variable()
if foo.__name__ == 'assign' and not isinstance(other, Matrix):
raise TaichiSyntaxError(
'cannot assign scalar expr to '
Expand Down Expand Up @@ -205,7 +220,7 @@ def linearize_entry_id(self, *args):
' for i in ti.static(range(3)):\n'
' print(i, "-th component is", vec[i])\n'
'See https://taichi.readthedocs.io/en/stable/meta.html#when-to-use-for-loops-with-ti-static for more details.'
)
)
return args[0] * self.m + args[1]

def __call__(self, *args, **kwargs):
Expand Down Expand Up @@ -609,7 +624,7 @@ def assign_renamed(x, y):
import taichi as ti
return ti.assign(x, y)

return self.element_wise_binary(assign_renamed, val)
return self.element_wise_writeback_binary(assign_renamed, val)

if isinstance(val, numbers.Number):
val = tuple(
Expand Down
2 changes: 1 addition & 1 deletion python/taichi/lang/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ def imp_foo(x, y):
@functools.wraps(foo)
def wrapped(a, b):
if ti.is_taichi_class(a):
return a.element_wise_binary(imp_foo, b)
return a.element_wise_writeback_binary(imp_foo, b)
elif ti.is_taichi_class(b):
raise SyntaxError(
f'cannot augassign taichi class {type(b)} to scalar expr')
Expand Down
4 changes: 4 additions & 0 deletions tests/python/test_linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,19 +70,23 @@ def init():
def test_matrix_ssa():
a = ti.Vector(2, ti.f32, ())
b = ti.Matrix(2, 2, ti.f32, ())
c = ti.Vector(2, ti.f32, ())

@ti.kernel
def func():
a[None] = a[None].normalized()
b[None] = b[None].transpose()
c[None] = ti.Vector([c[None][1], c[None][0]])

inv_sqrt2 = 1 / math.sqrt(2)

a[None] = [1, 1]
b[None] = [[1, 2], [3, 4]]
c[None] = [2, 3]
func()
assert a[None].value == ti.Vector([inv_sqrt2, inv_sqrt2])
assert b[None].value == ti.Matrix([[1, 3], [2, 4]])
assert c[None].value == ti.Vector([3, 2])


@ti.all_archs
Expand Down

0 comments on commit 33515be

Please sign in to comment.