From 71e81fe52db13324b4331640745d7275e189445b Mon Sep 17 00:00:00 2001 From: Ye Kuang Date: Sat, 29 Aug 2020 22:37:41 +0900 Subject: [PATCH 1/3] [Lang] Support formatted string with args in assert --- python/taichi/lang/impl.py | 9 ++++++++ python/taichi/lang/transformer.py | 35 ++++++++++++++++++++++++++++--- taichi/ir/frontend_ir.h | 10 +++++++++ taichi/python/export_lang.cpp | 5 +++-- taichi/transforms/lower_ast.cpp | 9 +++++++- tests/python/test_assert.py | 27 ++++++++++++++++++++++++ 6 files changed, 89 insertions(+), 6 deletions(-) diff --git a/python/taichi/lang/impl.py b/python/taichi/lang/impl.py index 9ffe22702526a..667168ca343a4 100644 --- a/python/taichi/lang/impl.py +++ b/python/taichi/lang/impl.py @@ -454,6 +454,15 @@ def fused_string(entries): taichi_lang_core.create_print(contentries) +@taichi_scope +def ti_assert(cond, msg, extra_args): + # Mostly a wrapper to help us convert from ti.Expr (defined in Python) to + # taichi_lang_core.Expr (defined in C++) + import taichi as ti + taichi_lang_core.create_assert_stmt( + ti.Expr(cond).ptr, msg, [ti.Expr(x).ptr for x in extra_args]) + + @taichi_scope def ti_int(var): _taichi_skip_traceback = 1 diff --git a/python/taichi/lang/transformer.py b/python/taichi/lang/transformer.py index 8db97e0f0eb7f..884645b0b1f64 100644 --- a/python/taichi/lang/transformer.py +++ b/python/taichi/lang/transformer.py @@ -805,12 +805,35 @@ def make_node(a, b, token): return new_node + def _is_string_mod_args(self, msg): + # 1. str % (a, b, c, ...) + # 2. str % single_item + # Note that |msg.right| may not be a tuple. + return isinstance(msg, ast.BinOp) and isinstance( + msg.left, ast.Str) and isinstance(msg.op, ast.Mod) + + def _handle_string_mod_args(self, msg): + assert self._is_string_mod_args(msg) + s = msg.left.s + t = None + if isinstance(msg.right, ast.Tuple): + t = msg.right + else: + # assuming the format is `str % single_item` + t = ast.Tuple(elts=[msg.right], ctx=ast.Load()) + self.generic_visit(t) + return s, t + def visit_Assert(self, node): + is_str_mod = False if node.msg is not None: if isinstance(node.msg, ast.Constant): msg = node.msg.value elif isinstance(node.msg, ast.Str): msg = node.msg.s + elif self._is_string_mod_args(node.msg): + # Delay the handling until we call generic_visit() on |node|. + is_str_mod = True else: raise ValueError( f"assert info must be constant, not {ast.dump(node.msg)}") @@ -818,10 +841,16 @@ def visit_Assert(self, node): import astor msg = astor.to_source(node.test) self.generic_visit(node) - new_node = self.parse_stmt( - 'ti.core.create_assert_stmt(ti.Expr(0).ptr, 0)') - new_node.value.args[0].value.args[0] = node.test + + extra_args = ast.List(elts=[], ctx=ast.Load()) + if is_str_mod: + msg, extra_args = self._handle_string_mod_args(node.msg) + + new_node = self.parse_stmt('ti.ti_assert(0, 0, [])') + new_node.value.args[0] = node.test new_node.value.args[1] = self.parse_expr("'{}'".format(msg.strip())) + new_node.value.args[2] = extra_args + new_node = ast.copy_location(new_node, node) return new_node def visit_Return(self, node): diff --git a/taichi/ir/frontend_ir.h b/taichi/ir/frontend_ir.h index a95959fe393e6..d79304ea43a59 100644 --- a/taichi/ir/frontend_ir.h +++ b/taichi/ir/frontend_ir.h @@ -41,11 +41,21 @@ class FrontendAssertStmt : public Stmt { public: std::string text; Expr val; + std::vector args; FrontendAssertStmt(const std::string &text, const Expr &val) : text(text), val(val) { } + FrontendAssertStmt(const std::string &text, + const Expr &val, + const std::vector &args_) + : text(text), val(val) { + for (auto &a : args_) { + args.push_back(load_if_ptr(a)); + } + } + TI_DEFINE_ACCEPT }; diff --git a/taichi/python/export_lang.cpp b/taichi/python/export_lang.cpp index 6eb9965745f0f..cfb1d9ffa4bb1 100644 --- a/taichi/python/export_lang.cpp +++ b/taichi/python/export_lang.cpp @@ -288,8 +288,9 @@ void export_lang(py::module &m) { return Length(snode, indices); }); - m.def("create_assert_stmt", [&](const Expr &cond, const std::string &msg) { - auto stmt_unique = std::make_unique(msg, cond); + m.def("create_assert_stmt", [&](const Expr &cond, const std::string &msg, + const std::vector &args) { + auto stmt_unique = std::make_unique(msg, cond, args); current_ast_builder().insert(std::move(stmt_unique)); }); diff --git a/taichi/transforms/lower_ast.cpp b/taichi/transforms/lower_ast.cpp index d128d7ae8feb9..93474c30e9d5f 100644 --- a/taichi/transforms/lower_ast.cpp +++ b/taichi/transforms/lower_ast.cpp @@ -404,7 +404,14 @@ class LowerAST : public IRVisitor { expr->flatten(&fctx); val_stmt = expr->stmt; } - fctx.push_back(stmt->text, val_stmt); + + auto &fargs = stmt->args; // frontend stmt args + std::vector args_stmts(fargs.size()); + for (int i = 0; i < (int)fargs.size(); ++i) { + fargs[i]->flatten(&fctx); + args_stmts[i] = fargs[i]->stmt; + } + fctx.push_back(val_stmt, stmt->text, args_stmts); stmt->parent->replace_with(stmt, std::move(fctx.stmts)); throw IRModified(); } diff --git a/tests/python/test_assert.py b/tests/python/test_assert.py index 670b596a383e0..d7bc31318a264 100644 --- a/tests/python/test_assert.py +++ b/tests/python/test_assert.py @@ -11,8 +11,13 @@ def test_assert_minimal(): def func(): assert 0 + @ti.kernel + def func2(): + assert False + with pytest.raises(RuntimeError): func() + func2() @ti.require(ti.extension.assertion) @@ -39,6 +44,28 @@ def func(): func() +@ti.require(ti.extension.assertion) +@ti.all_archs_with(debug=True, gdb_trigger=False) +def test_assert_message_formatted(): + x = ti.field(dtype=int, shape=16) + x[10] = 42 + + @ti.kernel + def assert_formatted(): + for i in x: + assert x[i] == 0, 'x[%d] expect=%d got=%d' % (i, 0, x[i]) + + @ti.kernel + def assert_float(): + y = 0.5 + assert y < 0, 'y = %f' % y + + with pytest.raises(RuntimeError, match=r'x\[10\] expect=0 got=42'): + assert_formatted() + with pytest.raises(RuntimeError, match=r'y = 0.5'): + assert_float() + + @ti.require(ti.extension.assertion) @ti.all_archs_with(debug=True, gdb_trigger=False) def test_assert_ok(): From 9121487cb241152817cfe8dfd3d174d58a2522c2 Mon Sep 17 00:00:00 2001 From: Ye Kuang Date: Sun, 30 Aug 2020 10:56:20 +0900 Subject: [PATCH 2/3] Update tests/python/test_assert.py MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: 彭于斌 <1931127624@qq.com> --- tests/python/test_assert.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/python/test_assert.py b/tests/python/test_assert.py index d7bc31318a264..a9b3b511aee9f 100644 --- a/tests/python/test_assert.py +++ b/tests/python/test_assert.py @@ -17,6 +17,7 @@ def func2(): with pytest.raises(RuntimeError): func() + with pytest.raises(RuntimeError): func2() From 68659df1f60b5f49f8e61177de6c2ef29dde989d Mon Sep 17 00:00:00 2001 From: Ye Kuang Date: Sun, 30 Aug 2020 10:58:28 +0900 Subject: [PATCH 3/3] add success test --- tests/python/test_assert.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tests/python/test_assert.py b/tests/python/test_assert.py index a9b3b511aee9f..09aba7a3e66a8 100644 --- a/tests/python/test_assert.py +++ b/tests/python/test_assert.py @@ -66,6 +66,10 @@ def assert_float(): with pytest.raises(RuntimeError, match=r'y = 0.5'): assert_float() + # success case + x[10] = 0 + assert_formatted() + @ti.require(ti.extension.assertion) @ti.all_archs_with(debug=True, gdb_trigger=False)