Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Lang] Support formatted string with args in assert #1806

Merged
merged 4 commits into from
Aug 30, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions python/taichi/lang/impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -454,6 +454,15 @@ def fused_string(entries):
taichi_lang_core.create_print(contentries)


@taichi_scope
def ti_assert(cond, msg, extra_args):
# Mostly a wrapper to help us convert from ti.Expr (defined in Python) to
# taichi_lang_core.Expr (defined in C++)
import taichi as ti
taichi_lang_core.create_assert_stmt(
ti.Expr(cond).ptr, msg, [ti.Expr(x).ptr for x in extra_args])
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What if extra_args is Vector? Let's make use of Vector.__ti_repr__ like currently ti_print does.

Copy link
Member Author

@k-ye k-ye Aug 30, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see. Unfortunately, assert format specifier doesn't support such a fancy data structure yet, see

if (dtype == 'd') {
error_message_formatted += fmt::format(
"{}", taichi_union_cast_with_different_sizes<int32>(argument));
} else if (dtype == 'f') {
error_message_formatted += fmt::format(
"{}",
taichi_union_cast_with_different_sizes<float32>(argument));
} else {
TI_ERROR("Data type identifier %{} is not supported", dtype);
}

It only has %d or %f..

I think this comes from the fact that, right now print and assert handle args differently. assert only handles string with format specifiers, which limits what kind of runtime data we can put into the message.



@taichi_scope
def ti_int(var):
_taichi_skip_traceback = 1
Expand Down
35 changes: 32 additions & 3 deletions python/taichi/lang/transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -805,23 +805,52 @@ def make_node(a, b, token):

return new_node

def _is_string_mod_args(self, msg):
# 1. str % (a, b, c, ...)
# 2. str % single_item
# Note that |msg.right| may not be a tuple.
return isinstance(msg, ast.BinOp) and isinstance(
msg.left, ast.Str) and isinstance(msg.op, ast.Mod)

def _handle_string_mod_args(self, msg):
assert self._is_string_mod_args(msg)
s = msg.left.s
t = None
if isinstance(msg.right, ast.Tuple):
t = msg.right
else:
# assuming the format is `str % single_item`
t = ast.Tuple(elts=[msg.right], ctx=ast.Load())
self.generic_visit(t)
return s, t

def visit_Assert(self, node):
is_str_mod = False
if node.msg is not None:
if isinstance(node.msg, ast.Constant):
msg = node.msg.value
elif isinstance(node.msg, ast.Str):
msg = node.msg.s
elif self._is_string_mod_args(node.msg):
# Delay the handling until we call generic_visit() on |node|.
is_str_mod = True
else:
raise ValueError(
f"assert info must be constant, not {ast.dump(node.msg)}")
else:
import astor
msg = astor.to_source(node.test)
self.generic_visit(node)
new_node = self.parse_stmt(
'ti.core.create_assert_stmt(ti.Expr(0).ptr, 0)')
new_node.value.args[0].value.args[0] = node.test

extra_args = ast.List(elts=[], ctx=ast.Load())
if is_str_mod:
msg, extra_args = self._handle_string_mod_args(node.msg)

new_node = self.parse_stmt('ti.ti_assert(0, 0, [])')
new_node.value.args[0] = node.test
new_node.value.args[1] = self.parse_expr("'{}'".format(msg.strip()))
new_node.value.args[2] = extra_args
new_node = ast.copy_location(new_node, node)
return new_node

def visit_Return(self, node):
Expand Down
10 changes: 10 additions & 0 deletions taichi/ir/frontend_ir.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,11 +41,21 @@ class FrontendAssertStmt : public Stmt {
public:
std::string text;
Expr val;
std::vector<Expr> args;

FrontendAssertStmt(const std::string &text, const Expr &val)
: text(text), val(val) {
}

FrontendAssertStmt(const std::string &text,
const Expr &val,
const std::vector<Expr> &args_)
: text(text), val(val) {
for (auto &a : args_) {
args.push_back(load_if_ptr(a));
}
}

TI_DEFINE_ACCEPT
};

Expand Down
5 changes: 3 additions & 2 deletions taichi/python/export_lang.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -288,8 +288,9 @@ void export_lang(py::module &m) {
return Length(snode, indices);
});

m.def("create_assert_stmt", [&](const Expr &cond, const std::string &msg) {
auto stmt_unique = std::make_unique<FrontendAssertStmt>(msg, cond);
m.def("create_assert_stmt", [&](const Expr &cond, const std::string &msg,
const std::vector<Expr> &args) {
auto stmt_unique = std::make_unique<FrontendAssertStmt>(msg, cond, args);
current_ast_builder().insert(std::move(stmt_unique));
});

Expand Down
9 changes: 8 additions & 1 deletion taichi/transforms/lower_ast.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -404,7 +404,14 @@ class LowerAST : public IRVisitor {
expr->flatten(&fctx);
val_stmt = expr->stmt;
}
fctx.push_back<AssertStmt>(stmt->text, val_stmt);

auto &fargs = stmt->args; // frontend stmt args
std::vector<Stmt *> args_stmts(fargs.size());
for (int i = 0; i < (int)fargs.size(); ++i) {
fargs[i]->flatten(&fctx);
args_stmts[i] = fargs[i]->stmt;
}
fctx.push_back<AssertStmt>(val_stmt, stmt->text, args_stmts);
stmt->parent->replace_with(stmt, std::move(fctx.stmts));
throw IRModified();
}
Expand Down
32 changes: 32 additions & 0 deletions tests/python/test_assert.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,14 @@ def test_assert_minimal():
def func():
assert 0

@ti.kernel
def func2():
assert False

with pytest.raises(RuntimeError):
func()
with pytest.raises(RuntimeError):
func2()
k-ye marked this conversation as resolved.
Show resolved Hide resolved


@ti.require(ti.extension.assertion)
Expand All @@ -39,6 +45,32 @@ def func():
func()


@ti.require(ti.extension.assertion)
@ti.all_archs_with(debug=True, gdb_trigger=False)
def test_assert_message_formatted():
x = ti.field(dtype=int, shape=16)
x[10] = 42

@ti.kernel
def assert_formatted():
for i in x:
assert x[i] == 0, 'x[%d] expect=%d got=%d' % (i, 0, x[i])

@ti.kernel
def assert_float():
y = 0.5
assert y < 0, 'y = %f' % y

with pytest.raises(RuntimeError, match=r'x\[10\] expect=0 got=42'):
assert_formatted()
with pytest.raises(RuntimeError, match=r'y = 0.5'):
assert_float()
yuanming-hu marked this conversation as resolved.
Show resolved Hide resolved

# success case
x[10] = 0
assert_formatted()


@ti.require(ti.extension.assertion)
@ti.all_archs_with(debug=True, gdb_trigger=False)
def test_assert_ok():
Expand Down