diff --git a/python/taichi/lang/impl.py b/python/taichi/lang/impl.py index 9ffe22702526a..667168ca343a4 100644 --- a/python/taichi/lang/impl.py +++ b/python/taichi/lang/impl.py @@ -454,6 +454,15 @@ def fused_string(entries): taichi_lang_core.create_print(contentries) +@taichi_scope +def ti_assert(cond, msg, extra_args): + # Mostly a wrapper to help us convert from ti.Expr (defined in Python) to + # taichi_lang_core.Expr (defined in C++) + import taichi as ti + taichi_lang_core.create_assert_stmt( + ti.Expr(cond).ptr, msg, [ti.Expr(x).ptr for x in extra_args]) + + @taichi_scope def ti_int(var): _taichi_skip_traceback = 1 diff --git a/python/taichi/lang/transformer.py b/python/taichi/lang/transformer.py index 8db97e0f0eb7f..884645b0b1f64 100644 --- a/python/taichi/lang/transformer.py +++ b/python/taichi/lang/transformer.py @@ -805,12 +805,35 @@ def make_node(a, b, token): return new_node + def _is_string_mod_args(self, msg): + # 1. str % (a, b, c, ...) + # 2. str % single_item + # Note that |msg.right| may not be a tuple. + return isinstance(msg, ast.BinOp) and isinstance( + msg.left, ast.Str) and isinstance(msg.op, ast.Mod) + + def _handle_string_mod_args(self, msg): + assert self._is_string_mod_args(msg) + s = msg.left.s + t = None + if isinstance(msg.right, ast.Tuple): + t = msg.right + else: + # assuming the format is `str % single_item` + t = ast.Tuple(elts=[msg.right], ctx=ast.Load()) + self.generic_visit(t) + return s, t + def visit_Assert(self, node): + is_str_mod = False if node.msg is not None: if isinstance(node.msg, ast.Constant): msg = node.msg.value elif isinstance(node.msg, ast.Str): msg = node.msg.s + elif self._is_string_mod_args(node.msg): + # Delay the handling until we call generic_visit() on |node|. + is_str_mod = True else: raise ValueError( f"assert info must be constant, not {ast.dump(node.msg)}") @@ -818,10 +841,16 @@ def visit_Assert(self, node): import astor msg = astor.to_source(node.test) self.generic_visit(node) - new_node = self.parse_stmt( - 'ti.core.create_assert_stmt(ti.Expr(0).ptr, 0)') - new_node.value.args[0].value.args[0] = node.test + + extra_args = ast.List(elts=[], ctx=ast.Load()) + if is_str_mod: + msg, extra_args = self._handle_string_mod_args(node.msg) + + new_node = self.parse_stmt('ti.ti_assert(0, 0, [])') + new_node.value.args[0] = node.test new_node.value.args[1] = self.parse_expr("'{}'".format(msg.strip())) + new_node.value.args[2] = extra_args + new_node = ast.copy_location(new_node, node) return new_node def visit_Return(self, node): diff --git a/taichi/ir/frontend_ir.h b/taichi/ir/frontend_ir.h index a95959fe393e6..d79304ea43a59 100644 --- a/taichi/ir/frontend_ir.h +++ b/taichi/ir/frontend_ir.h @@ -41,11 +41,21 @@ class FrontendAssertStmt : public Stmt { public: std::string text; Expr val; + std::vector args; FrontendAssertStmt(const std::string &text, const Expr &val) : text(text), val(val) { } + FrontendAssertStmt(const std::string &text, + const Expr &val, + const std::vector &args_) + : text(text), val(val) { + for (auto &a : args_) { + args.push_back(load_if_ptr(a)); + } + } + TI_DEFINE_ACCEPT }; diff --git a/taichi/python/export_lang.cpp b/taichi/python/export_lang.cpp index 6eb9965745f0f..cfb1d9ffa4bb1 100644 --- a/taichi/python/export_lang.cpp +++ b/taichi/python/export_lang.cpp @@ -288,8 +288,9 @@ void export_lang(py::module &m) { return Length(snode, indices); }); - m.def("create_assert_stmt", [&](const Expr &cond, const std::string &msg) { - auto stmt_unique = std::make_unique(msg, cond); + m.def("create_assert_stmt", [&](const Expr &cond, const std::string &msg, + const std::vector &args) { + auto stmt_unique = std::make_unique(msg, cond, args); current_ast_builder().insert(std::move(stmt_unique)); }); diff --git a/taichi/transforms/lower_ast.cpp b/taichi/transforms/lower_ast.cpp index d128d7ae8feb9..93474c30e9d5f 100644 --- a/taichi/transforms/lower_ast.cpp +++ b/taichi/transforms/lower_ast.cpp @@ -404,7 +404,14 @@ class LowerAST : public IRVisitor { expr->flatten(&fctx); val_stmt = expr->stmt; } - fctx.push_back(stmt->text, val_stmt); + + auto &fargs = stmt->args; // frontend stmt args + std::vector args_stmts(fargs.size()); + for (int i = 0; i < (int)fargs.size(); ++i) { + fargs[i]->flatten(&fctx); + args_stmts[i] = fargs[i]->stmt; + } + fctx.push_back(val_stmt, stmt->text, args_stmts); stmt->parent->replace_with(stmt, std::move(fctx.stmts)); throw IRModified(); } diff --git a/tests/python/test_assert.py b/tests/python/test_assert.py index 670b596a383e0..09aba7a3e66a8 100644 --- a/tests/python/test_assert.py +++ b/tests/python/test_assert.py @@ -11,8 +11,14 @@ def test_assert_minimal(): def func(): assert 0 + @ti.kernel + def func2(): + assert False + with pytest.raises(RuntimeError): func() + with pytest.raises(RuntimeError): + func2() @ti.require(ti.extension.assertion) @@ -39,6 +45,32 @@ def func(): func() +@ti.require(ti.extension.assertion) +@ti.all_archs_with(debug=True, gdb_trigger=False) +def test_assert_message_formatted(): + x = ti.field(dtype=int, shape=16) + x[10] = 42 + + @ti.kernel + def assert_formatted(): + for i in x: + assert x[i] == 0, 'x[%d] expect=%d got=%d' % (i, 0, x[i]) + + @ti.kernel + def assert_float(): + y = 0.5 + assert y < 0, 'y = %f' % y + + with pytest.raises(RuntimeError, match=r'x\[10\] expect=0 got=42'): + assert_formatted() + with pytest.raises(RuntimeError, match=r'y = 0.5'): + assert_float() + + # success case + x[10] = 0 + assert_formatted() + + @ti.require(ti.extension.assertion) @ti.all_archs_with(debug=True, gdb_trigger=False) def test_assert_ok():