Skip to content

Commit

Permalink
[Lang] Support formatted string with args in assert (#1806)
Browse files Browse the repository at this point in the history
* [Lang] Support formatted string with args in assert

* Update tests/python/test_assert.py

Co-authored-by: 彭于斌 <[email protected]>

* add success test

Co-authored-by: 彭于斌 <[email protected]>
  • Loading branch information
k-ye and archibate authored Aug 30, 2020
1 parent c3f1a5d commit 58363bb
Show file tree
Hide file tree
Showing 6 changed files with 94 additions and 6 deletions.
9 changes: 9 additions & 0 deletions python/taichi/lang/impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -454,6 +454,15 @@ def fused_string(entries):
taichi_lang_core.create_print(contentries)


@taichi_scope
def ti_assert(cond, msg, extra_args):
# Mostly a wrapper to help us convert from ti.Expr (defined in Python) to
# taichi_lang_core.Expr (defined in C++)
import taichi as ti
taichi_lang_core.create_assert_stmt(
ti.Expr(cond).ptr, msg, [ti.Expr(x).ptr for x in extra_args])


@taichi_scope
def ti_int(var):
_taichi_skip_traceback = 1
Expand Down
35 changes: 32 additions & 3 deletions python/taichi/lang/transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -805,23 +805,52 @@ def make_node(a, b, token):

return new_node

def _is_string_mod_args(self, msg):
# 1. str % (a, b, c, ...)
# 2. str % single_item
# Note that |msg.right| may not be a tuple.
return isinstance(msg, ast.BinOp) and isinstance(
msg.left, ast.Str) and isinstance(msg.op, ast.Mod)

def _handle_string_mod_args(self, msg):
assert self._is_string_mod_args(msg)
s = msg.left.s
t = None
if isinstance(msg.right, ast.Tuple):
t = msg.right
else:
# assuming the format is `str % single_item`
t = ast.Tuple(elts=[msg.right], ctx=ast.Load())
self.generic_visit(t)
return s, t

def visit_Assert(self, node):
is_str_mod = False
if node.msg is not None:
if isinstance(node.msg, ast.Constant):
msg = node.msg.value
elif isinstance(node.msg, ast.Str):
msg = node.msg.s
elif self._is_string_mod_args(node.msg):
# Delay the handling until we call generic_visit() on |node|.
is_str_mod = True
else:
raise ValueError(
f"assert info must be constant, not {ast.dump(node.msg)}")
else:
import astor
msg = astor.to_source(node.test)
self.generic_visit(node)
new_node = self.parse_stmt(
'ti.core.create_assert_stmt(ti.Expr(0).ptr, 0)')
new_node.value.args[0].value.args[0] = node.test

extra_args = ast.List(elts=[], ctx=ast.Load())
if is_str_mod:
msg, extra_args = self._handle_string_mod_args(node.msg)

new_node = self.parse_stmt('ti.ti_assert(0, 0, [])')
new_node.value.args[0] = node.test
new_node.value.args[1] = self.parse_expr("'{}'".format(msg.strip()))
new_node.value.args[2] = extra_args
new_node = ast.copy_location(new_node, node)
return new_node

def visit_Return(self, node):
Expand Down
10 changes: 10 additions & 0 deletions taichi/ir/frontend_ir.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,11 +41,21 @@ class FrontendAssertStmt : public Stmt {
public:
std::string text;
Expr val;
std::vector<Expr> args;

FrontendAssertStmt(const std::string &text, const Expr &val)
: text(text), val(val) {
}

FrontendAssertStmt(const std::string &text,
const Expr &val,
const std::vector<Expr> &args_)
: text(text), val(val) {
for (auto &a : args_) {
args.push_back(load_if_ptr(a));
}
}

TI_DEFINE_ACCEPT
};

Expand Down
5 changes: 3 additions & 2 deletions taichi/python/export_lang.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -288,8 +288,9 @@ void export_lang(py::module &m) {
return Length(snode, indices);
});

m.def("create_assert_stmt", [&](const Expr &cond, const std::string &msg) {
auto stmt_unique = std::make_unique<FrontendAssertStmt>(msg, cond);
m.def("create_assert_stmt", [&](const Expr &cond, const std::string &msg,
const std::vector<Expr> &args) {
auto stmt_unique = std::make_unique<FrontendAssertStmt>(msg, cond, args);
current_ast_builder().insert(std::move(stmt_unique));
});

Expand Down
9 changes: 8 additions & 1 deletion taichi/transforms/lower_ast.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -404,7 +404,14 @@ class LowerAST : public IRVisitor {
expr->flatten(&fctx);
val_stmt = expr->stmt;
}
fctx.push_back<AssertStmt>(stmt->text, val_stmt);

auto &fargs = stmt->args; // frontend stmt args
std::vector<Stmt *> args_stmts(fargs.size());
for (int i = 0; i < (int)fargs.size(); ++i) {
fargs[i]->flatten(&fctx);
args_stmts[i] = fargs[i]->stmt;
}
fctx.push_back<AssertStmt>(val_stmt, stmt->text, args_stmts);
stmt->parent->replace_with(stmt, std::move(fctx.stmts));
throw IRModified();
}
Expand Down
32 changes: 32 additions & 0 deletions tests/python/test_assert.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,14 @@ def test_assert_minimal():
def func():
assert 0

@ti.kernel
def func2():
assert False

with pytest.raises(RuntimeError):
func()
with pytest.raises(RuntimeError):
func2()


@ti.require(ti.extension.assertion)
Expand All @@ -39,6 +45,32 @@ def func():
func()


@ti.require(ti.extension.assertion)
@ti.all_archs_with(debug=True, gdb_trigger=False)
def test_assert_message_formatted():
x = ti.field(dtype=int, shape=16)
x[10] = 42

@ti.kernel
def assert_formatted():
for i in x:
assert x[i] == 0, 'x[%d] expect=%d got=%d' % (i, 0, x[i])

@ti.kernel
def assert_float():
y = 0.5
assert y < 0, 'y = %f' % y

with pytest.raises(RuntimeError, match=r'x\[10\] expect=0 got=42'):
assert_formatted()
with pytest.raises(RuntimeError, match=r'y = 0.5'):
assert_float()

# success case
x[10] = 0
assert_formatted()


@ti.require(ti.extension.assertion)
@ti.all_archs_with(debug=True, gdb_trigger=False)
def test_assert_ok():
Expand Down

0 comments on commit 58363bb

Please sign in to comment.