From 746129b225d18918311736722e02cffbd47a5ec1 Mon Sep 17 00:00:00 2001 From: squarefk Date: Thu, 26 Aug 2021 10:46:54 +0800 Subject: [PATCH] [Lang] Fix tensor based grouped ndrange for (#2800) * fix tensor based grouped ndrange for * Auto Format Co-authored-by: Taichi Gardener --- python/taichi/lang/stmt_builder.py | 1 + tests/python/test_ndrange.py | 26 ++++++++++++++++++++++++++ 2 files changed, 27 insertions(+) diff --git a/python/taichi/lang/stmt_builder.py b/python/taichi/lang/stmt_builder.py index cb0d891c4f5f4..39b4725445480 100644 --- a/python/taichi/lang/stmt_builder.py +++ b/python/taichi/lang/stmt_builder.py @@ -409,6 +409,7 @@ def build_grouped_ndrange_for(ctx, node): ti.core.end_frontend_range_for() '''.format(target, target) t = ast.parse(template).body[0] + node.iter.args[0].args = build_exprs(ctx, node.iter.args[0].args) t.body[0].value = node.iter.args[0] cut = len(t.body) - 1 t.body = t.body[:cut] + node.body + t.body[cut:] diff --git a/tests/python/test_ndrange.py b/tests/python/test_ndrange.py index 6b623aa74bf10..5688ee76ff91a 100644 --- a/tests/python/test_ndrange.py +++ b/tests/python/test_ndrange.py @@ -61,6 +61,32 @@ def func(): assert x[i, j, k] == 0 +@ti.test() +def test_tensor_based_3d(): + x = ti.field(ti.i32, shape=(6, 6, 6)) + y = ti.field(ti.i32, shape=(6, 6, 6)) + + @ti.kernel + def func(): + lower = ti.Vector([0, 1, 2]) + upper = ti.Vector([3, 4, 5]) + for I in ti.grouped( + ti.ndrange((lower[0], upper[0]), (lower[1], upper[1]), + (lower[2], upper[2]))): + x[I] = I[0] + I[1] + I[2] + for i in range(0, 3): + for j in range(1, 4): + for k in range(2, 5): + y[i, j, k] = i + j + k + + func() + + for i in range(6): + for j in range(6): + for k in range(6): + assert x[i, j, k] == y[i, j, k] + + @ti.test() def test_static_grouped(): x = ti.field(ti.f32, shape=(16, 32, 64))