Skip to content

Commit

Permalink
[Lang] Fix tensor based grouped ndrange for (#2800)
Browse files Browse the repository at this point in the history
* fix tensor based grouped ndrange for

* Auto Format

Co-authored-by: Taichi Gardener <[email protected]>
  • Loading branch information
squarefk and taichi-gardener authored Aug 26, 2021
1 parent 33d82f9 commit 746129b
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 0 deletions.
1 change: 1 addition & 0 deletions python/taichi/lang/stmt_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -409,6 +409,7 @@ def build_grouped_ndrange_for(ctx, node):
ti.core.end_frontend_range_for()
'''.format(target, target)
t = ast.parse(template).body[0]
node.iter.args[0].args = build_exprs(ctx, node.iter.args[0].args)
t.body[0].value = node.iter.args[0]
cut = len(t.body) - 1
t.body = t.body[:cut] + node.body + t.body[cut:]
Expand Down
26 changes: 26 additions & 0 deletions tests/python/test_ndrange.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,32 @@ def func():
assert x[i, j, k] == 0


@ti.test()
def test_tensor_based_3d():
x = ti.field(ti.i32, shape=(6, 6, 6))
y = ti.field(ti.i32, shape=(6, 6, 6))

@ti.kernel
def func():
lower = ti.Vector([0, 1, 2])
upper = ti.Vector([3, 4, 5])
for I in ti.grouped(
ti.ndrange((lower[0], upper[0]), (lower[1], upper[1]),
(lower[2], upper[2]))):
x[I] = I[0] + I[1] + I[2]
for i in range(0, 3):
for j in range(1, 4):
for k in range(2, 5):
y[i, j, k] = i + j + k

func()

for i in range(6):
for j in range(6):
for k in range(6):
assert x[i, j, k] == y[i, j, k]


@ti.test()
def test_static_grouped():
x = ti.field(ti.f32, shape=(16, 32, 64))
Expand Down

0 comments on commit 746129b

Please sign in to comment.