Skip to content

Commit

Permalink
[bug] Matrix refactor bug fix: Fix cross scope matrix operations (#6822)
Browse files Browse the repository at this point in the history
Issue: #5819

### Brief Summary
  • Loading branch information
jim19930609 authored Dec 8, 2022
1 parent c179aca commit b28ca23
Show file tree
Hide file tree
Showing 4 changed files with 104 additions and 1 deletion.
4 changes: 4 additions & 0 deletions python/taichi/lang/matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -421,6 +421,7 @@ class Matrix(TaichiOperations):
0
"""
_is_taichi_class = True
_is_matrix_class = True
__array_priority__ = 1000

def __init__(self, arr, dt=None, is_ref=False, ndim=None):
Expand Down Expand Up @@ -665,6 +666,9 @@ def _subscript(self, *indices):
is_global_mat = isinstance(self, _MatrixFieldElement)
return self._impl._subscript(is_global_mat, *indices)

def _make_matrix(self):
    """Build a matrix expression from this python-scope matrix's entries.

    Delegates to the module-level ``make_matrix`` helper with the raw
    ``self._impl.entries`` payload.  NOTE(review): ``make_matrix`` is defined
    elsewhere in this module — presumably it constructs the taichi-scope
    ("real") matrix expression used by cross-scope operations; confirm there.
    """
    return make_matrix(self._impl.entries)

def to_list(self):
"""Return this matrix as a 1D `list`.
Expand Down
27 changes: 26 additions & 1 deletion python/taichi/lang/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,26 @@
from taichi.lang import expr, impl
from taichi.lang.exception import TaichiSyntaxError
from taichi.lang.field import Field
from taichi.lang.util import cook_dtype, is_taichi_class, taichi_scope
from taichi.lang.util import (cook_dtype, is_matrix_class, is_taichi_class,
taichi_scope)


def uniform_matrix_inputs(*args):
    """Bring mixed python-/taichi-scope matrix operands into a uniform form.

    If any operand is already a taichi expression with tensor type (a "real"
    matrix), every python-scope Matrix-class operand is converted via its
    ``_make_matrix()`` hook so all operands live in the same representation.
    Otherwise the operands are returned unchanged.

    Returns:
        list: the (possibly converted) operands, in the original order.
    """
    any_real_matrix = any(
        is_taichi_expr(arg) and arg.ptr.is_tensor() for arg in args)

    if not any_real_matrix:
        # Nothing to promote — hand the operands back as-is.
        return list(args)

    return [
        arg._make_matrix() if is_matrix_class(arg) else arg for arg in args
    ]


unary_ops = []

Expand Down Expand Up @@ -52,6 +71,8 @@ def rev_foo(x, y):

@functools.wraps(foo)
def wrapped(a, b):
a, b = uniform_matrix_inputs(a, b)

if isinstance(a, Field) or isinstance(b, Field):
return NotImplemented
if is_taichi_class(a):
Expand Down Expand Up @@ -82,6 +103,8 @@ def cab_foo(c, a, b):

@functools.wraps(foo)
def wrapped(a, b, c):
a, b, c = uniform_matrix_inputs(a, b, c)

if isinstance(a, Field) or isinstance(b, Field) or isinstance(
c, Field):
return NotImplemented
Expand All @@ -107,6 +130,8 @@ def imp_foo(x, y):

@functools.wraps(foo)
def wrapped(a, b):
a, b = uniform_matrix_inputs(a, b)

if isinstance(a, Field) or isinstance(b, Field):
return NotImplemented
if is_taichi_class(a):
Expand Down
10 changes: 10 additions & 0 deletions python/taichi/lang/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,16 @@ def get_clangpp():
return _clangpp_presence


def is_matrix_class(rhs):
    """Return True if ``rhs`` is a python-scope Matrix-class object.

    Matrix classes are identified by a truthy ``_is_matrix_class`` marker
    attribute (set on ``taichi.lang.matrix.Matrix``).

    Args:
        rhs: any object to probe.

    Returns:
        bool: True when ``rhs`` carries a truthy ``_is_matrix_class``
        attribute, False otherwise (including when the attribute is absent).
    """
    # getattr with a default replaces the original bare `except: pass`,
    # which silently swallowed *every* exception (KeyboardInterrupt
    # included), not just the AttributeError it meant to handle.
    return bool(getattr(rhs, '_is_matrix_class', False))


def is_taichi_class(rhs):
taichi_class = False
try:
Expand Down
64 changes: 64 additions & 0 deletions tests/python/test_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -1225,6 +1225,70 @@ def foo():
foo()


@test_utils.test(arch=[ti.cuda, ti.cpu],
                 real_matrix=True,
                 real_matrix_scalarize=True,
                 debug=True)
def test_cross_scope_matrix_binary_ops():
    # Python-scope matrices pulled in via ti.static() must combine with
    # taichi-scope (real) matrices in binary ops without error.
    side = 128
    field = ti.Vector.field(3, dtype=int, shape=(side, side))
    offsets = [ti.Vector([1, 2]), ti.Vector([2, 3])]

    @ti.kernel
    def run():
        base = ti.Vector([4, 5])
        # Exercise both the operator form and the functional form.
        idx_a = base + ti.static(offsets)[0]
        idx_b = ti.lang.ops.add(base, ti.static(offsets)[1])

        field[idx_a] = [100, 10, 1]
        field[idx_b] = [1, 10, 100]

    run()

    assert (field[5, 7] == [100, 10, 1]).all()
    assert (field[6, 8] == [1, 10, 100]).all()


@test_utils.test(arch=[ti.cuda, ti.cpu],
                 real_matrix=True,
                 real_matrix_scalarize=True,
                 debug=True)
def test_cross_scope_matrix_ternary_ops():
    # ti.select must accept a python-scope matrix (via ti.static) as one
    # of its branch operands alongside taichi-scope matrices.
    side = 128
    field = ti.Vector.field(3, dtype=int, shape=(side, side))
    offsets = [ti.Vector([1, 2]), ti.Vector([2, 3])]

    @ti.kernel
    def run():
        cond = ti.Vector([0, 1])
        # Element-wise select: [0, 1] picks [offsets[0][0], cond[1]] = [1, 1].
        picked = ti.select(cond, cond, ti.static(offsets)[0])
        field[picked] = [100, 10, 1]

    run()

    assert (field[1, 1] == [100, 10, 1]).all()


@test_utils.test(arch=[ti.cuda, ti.cpu],
                 real_matrix=True,
                 real_matrix_scalarize=True,
                 debug=True)
def test_cross_scope_matrix_atomic_ops():
    # Augmented assignment (+=) on a taichi-scope matrix with a
    # python-scope matrix rhs (via ti.static) must work cross-scope.
    side = 128
    field = ti.Vector.field(3, dtype=int, shape=(side, side))
    offsets = [ti.Vector([1, 2]), ti.Vector([2, 3])]

    @ti.kernel
    def run():
        acc = ti.Vector([0, 1])
        acc += ti.static(offsets)[0]  # [0, 1] + [1, 2] -> [1, 3]
        field[acc] = [100, 10, 1]

    run()

    assert (field[1, 3] == [100, 10, 1]).all()


@test_utils.test(require=ti.extension.dynamic_index,
dynamic_index=True,
debug=True)
Expand Down

0 comments on commit b28ca23

Please sign in to comment.