Skip to content

Commit

Permalink
[bug] MatrixType bug fix: Fix error with static-grouped-ndrange (#6839)
Browse files Browse the repository at this point in the history
Issue: #5819

### Brief Summary

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Yi Xu <[email protected]>
  • Loading branch information
3 people authored Dec 11, 2022
1 parent bd54211 commit 22d3fee
Show file tree
Hide file tree
Showing 5 changed files with 65 additions and 19 deletions.
10 changes: 3 additions & 7 deletions python/taichi/lang/_ndrange.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
import collections.abc

import numpy as np
from taichi.lang import impl, ops
from taichi.lang import ops
from taichi.lang.exception import TaichiSyntaxError, TaichiTypeError
from taichi.lang.expr import Expr
from taichi.lang.matrix import _IntermediateMatrix, make_matrix
from taichi.types import primitive_types
from taichi.lang.matrix import _IntermediateMatrix
from taichi.types.utils import is_integral


Expand Down Expand Up @@ -145,10 +144,7 @@ def __init__(self, r):

def __iter__(self):
for ind in self.r:
if impl.current_cfg().real_matrix:
yield make_matrix(list(ind), dt=primitive_types.i32)
else:
yield _IntermediateMatrix(len(ind), 1, list(ind), ndim=1)
yield _IntermediateMatrix(len(ind), 1, list(ind), ndim=1)


__all__ = ['ndrange']
5 changes: 3 additions & 2 deletions python/taichi/lang/impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -1071,9 +1071,10 @@ def static(x, *xs):
return [static(x)] + [static(x) for x in xs]

if isinstance(x,
(bool, int, float, range, list, tuple, enumerate, _Ndrange,
GroupedNDRange, zip, filter, map)) or x is None:
(bool, int, float, range, list, tuple, enumerate,
GroupedNDRange, _Ndrange, zip, filter, map)) or x is None:
return x

if isinstance(x, AnyArray):
return x
if isinstance(x, Field):
Expand Down
43 changes: 37 additions & 6 deletions tests/python/test_ast_refactor.py
Original file line number Diff line number Diff line change
Expand Up @@ -493,8 +493,7 @@ def foo(a: ti.template()):
assert a[i] == 0


@test_utils.test(print_preprocessed_ir=True)
def test_static_grouped_for_break():
def _test_static_grouped_for_break():
n = 4

@ti.kernel
Expand All @@ -518,6 +517,18 @@ def foo(a: ti.template()):
assert a[i, j] == 0


@test_utils.test(print_preprocessed_ir=True)
def test_static_grouped_for_break():
_test_static_grouped_for_break()


@test_utils.test(print_preprocessed_ir=True,
real_matrix=True,
real_matrix_scalarize=True)
def test_static_grouped_for_break_matrix_scalarize():
_test_static_grouped_for_break()


@test_utils.test(print_preprocessed_ir=True)
def test_static_for_continue():
n = 10
Expand All @@ -540,8 +551,7 @@ def foo(a: ti.template()):
assert a[i] == 3


@test_utils.test(print_preprocessed_ir=True)
def test_static_grouped_for_continue():
def _test_static_grouped_for_continue():
n = 4

@ti.kernel
Expand All @@ -563,6 +573,18 @@ def foo(a: ti.template()):
assert a[i, j] == 3


@test_utils.test(print_preprocessed_ir=True)
def test_static_grouped_for_continue():
_test_static_grouped_for_continue()


@test_utils.test(print_preprocessed_ir=True,
real_matrix=True,
real_matrix_scalarize=True)
def test_static_grouped_for_continue_matrix_scalarize():
_test_static_grouped_for_continue()


@test_utils.test(print_preprocessed_ir=True)
def test_for_break():
n = 4
Expand Down Expand Up @@ -1039,8 +1061,7 @@ def foo() -> ti.i32:
assert foo() == 123


@test_utils.test()
def test_grouped_static_for_cast():
def _test_grouped_static_for_cast():
@ti.kernel
def foo() -> ti.f32:
ret = 0.
Expand All @@ -1050,3 +1071,13 @@ def foo() -> ti.f32:
return ret

assert foo() == test_utils.approx(10)


@test_utils.test()
def test_grouped_static_for_cast():
_test_grouped_static_for_cast()


@test_utils.test(real_matrix=True, real_matrix_scalarize=True)
def test_grouped_static_for_cast_matrix_scalarize():
_test_grouped_static_for_cast()
13 changes: 11 additions & 2 deletions tests/python/test_grouped.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,8 +185,7 @@ def test():
assert val[None] == 42


@test_utils.test()
def test_static_grouped_func():
def _test_static_grouped_func():

K = 3
dim = 2
Expand All @@ -207,3 +206,13 @@ def p2g():
for j in range(K):
for k in range(K):
assert v[i, j][k] == i + j * 3 + k * 10


@test_utils.test()
def test_static_grouped_func():
_test_static_grouped_func()


@test_utils.test(real_matrix=True, real_matrix_scalarize=True)
def test_static_grouped_func_matrix_scalarize():
_test_static_grouped_func()
13 changes: 11 additions & 2 deletions tests/python/test_ndrange.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,8 +108,7 @@ def func():
assert x[i, j, k] == 0


@test_utils.test()
def test_static_grouped_static():
def _test_static_grouped_static():
x = ti.Matrix.field(2, 3, dtype=ti.f32, shape=(16, 4))

@ti.kernel
Expand All @@ -126,6 +125,16 @@ def func():
assert x[i, j][k, l] == k + l * 10 + i + j * 4


@test_utils.test()
def test_static_grouped_static():
_test_static_grouped_static()


@test_utils.test(real_matrix=True, real_matrix_scalarize=True)
def test_static_grouped_static_matrix_scalarize():
_test_static_grouped_static()


@test_utils.test()
def test_field_init_eye():
# https://github.com/taichi-dev/taichi/issues/1824
Expand Down

0 comments on commit 22d3fee

Please sign in to comment.