Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[lang] Raise an error when struct-for indices number mismatch #1357

Merged
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions python/taichi/lang/impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,15 @@ def expr_init_func(rhs): # temporary solution to allow passing in tensors as
return expr_init(rhs)


def begin_frontend_struct_for(group, loop_range):
if group.size() != len(loop_range.shape):
raise IndexError(
'Size mismatch between struct-for indices and loop range: '
f'{group.size()} != {len(loop_range.shape)}. Maybe you want to '
' use ti.grouped(x) to group all indices into a single vector?')
archibate marked this conversation as resolved.
Show resolved Hide resolved
taichi_lang_core.begin_frontend_struct_for(group, loop_range.ptr)


def wrap_scalar(x):
if type(x) in [int, float]:
return Expr(x)
Expand Down
4 changes: 2 additions & 2 deletions python/taichi/lang/transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -437,7 +437,7 @@ def visit_struct_for(self, node, is_grouped):
___loop_var = 0
{} = ti.make_var_vector(size=len(___loop_var.loop_range().shape))
___expr_group = ti.make_expr_group({})
ti.core.begin_frontend_struct_for(___expr_group, ___loop_var.loop_range().ptr)
ti.begin_frontend_struct_for(___expr_group, ___loop_var.loop_range())
ti.core.end_frontend_range_for()
'''.format(vars, vars)
t = ast.parse(template).body[0]
Expand All @@ -450,7 +450,7 @@ def visit_struct_for(self, node, is_grouped):
{}
___loop_var = 0
___expr_group = ti.make_expr_group({})
ti.core.begin_frontend_struct_for(___expr_group, ___loop_var.loop_range().ptr)
ti.begin_frontend_struct_for(___expr_group, ___loop_var.loop_range())
ti.core.end_frontend_range_for()
'''.format(var_decl, vars)
t = ast.parse(template).body[0]
Expand Down
4 changes: 3 additions & 1 deletion taichi/transforms/simplify.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1176,7 +1176,9 @@ class Simplify : public IRVisitor {
}

void visit(StructForStmt *for_stmt) override {
TI_ASSERT(current_struct_for == nullptr);
TI_ASSERT_INFO(current_struct_for == nullptr,
"nesting struct-for is not supported for now, "
"please try use a range-for instead.");
archibate marked this conversation as resolved.
Show resolved Hide resolved
current_struct_for = for_stmt;
for_stmt->body->accept(this);
current_struct_for = nullptr;
Expand Down
96 changes: 96 additions & 0 deletions tests/python/test_for_group_mismatch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
import taichi as ti


@ti.must_throw(IndexError)
@ti.host_arch_only
def test_struct_for_mismatch():
x = ti.var(ti.f32, (3, 4))

@ti.kernel
def func():
for i in x:
print(i)

func()


@ti.must_throw(IndexError)
@ti.host_arch_only
def test_struct_for_mismatch2():
x = ti.var(ti.f32, (3, 4))

@ti.kernel
def func():
for i, j, k in x:
print(i, j, k)

func()


@ti.must_throw(IndexError)
@ti.host_arch_only
def _test_grouped_struct_for_mismatch():
# doesn't work for now
# need grouped refactor
# for now, it just throw a unfriendly message:
# AssertionError: __getitem__ cannot be called in Python-scope
x = ti.var(ti.f32, (3, 4))

@ti.kernel
def func():
for i, j in ti.grouped(x):
print(i, j)

func()


@ti.must_throw(IndexError)
@ti.host_arch_only
def _test_ndrange_for_mismatch():
# doesn't work for now
# need ndrange refactor
@ti.kernel
def func():
for i in ti.ndrange(3, 4):
print(i)

func()


@ti.must_throw(IndexError)
@ti.host_arch_only
def _test_ndrange_for_mismatch2():
# doesn't work for now
# need ndrange and grouped refactor
@ti.kernel
def func():
for i, j, k in ti.ndrange(3, 4):
print(i, j, k)

func()


@ti.must_throw(IndexError)
@ti.host_arch_only
def _test_grouped_ndrange_for_mismatch():
# doesn't work for now
# need ndrange and grouped refactor
@ti.kernel
def func():
for i in ti.grouped(ti.ndrange(3, 4)):
print(i)

func()


@ti.must_throw(IndexError)
@ti.host_arch_only
def _test_static_ndrange_for_mismatch():
# doesn't work for now
# need ndrange and static refactor
@ti.kernel
def func():
for i in ti.static(ti.ndrange(3, 4)):
print(i)

func()