Skip to content

Commit

Permalink
[Lang] Better type error messages (#3345)
Browse files Browse the repository at this point in the history
* Move type_check out of Expression constructors; Add error messages

* Add type_check to Expr

* Better typecheck error messages

* Fix tests

* Auto Format

Co-authored-by: Taichi Gardener <[email protected]>
  • Loading branch information
strongoier and taichi-gardener authored Nov 2, 2021
1 parent 4a4b3fc commit ecc345e
Show file tree
Hide file tree
Showing 9 changed files with 98 additions and 17 deletions.
17 changes: 17 additions & 0 deletions python/taichi/lang/expr.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
import sys
import traceback

import numpy as np
from taichi.core.util import ti_core as _ti_core
from taichi.lang import impl
Expand Down Expand Up @@ -35,6 +38,20 @@ def __init__(self, *args, tb=None):
assert False
if self.tb:
self.ptr.set_tb(self.tb)
try:
self.ptr.type_check()
except RuntimeError as e:
if str(e).startswith('TypeError: '):
s = traceback.extract_stack()
for i, l in enumerate(s):
if 'taichi_ast_generator' in l:
s = s[i + 1:]
break
print('[Taichi] Compilation failed', file=sys.stderr)
print(traceback.format_list(s[:1])[0], end='', file=sys.stderr)
print(f'TaichiTypeError: {str(e)[11:]}', file=sys.stderr)
sys.exit(1)
raise e

def __hash__(self):
return self.ptr.get_raw_address()
Expand Down
4 changes: 4 additions & 0 deletions taichi/ir/expr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,10 @@ DataType Expr::get_ret_type() const {
return expr->ret_type;
}

void Expr::type_check() {
expr->type_check();
}

Expr select(const Expr &cond, const Expr &true_val, const Expr &false_val) {
return Expr::make<TernaryOpExpression>(TernaryOpType::select, cond, true_val,
false_val);
Expand Down
2 changes: 2 additions & 0 deletions taichi/ir/expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,8 @@ class Expr {
std::string get_attribute(const std::string &key) const;

DataType get_ret_type() const;

void type_check();
};

Expr select(const Expr &cond, const Expr &true_val, const Expr &false_val);
Expand Down
5 changes: 5 additions & 0 deletions taichi/ir/expression.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,11 @@ class Expression {
stmt = nullptr;
}

virtual void type_check() {
// TODO: make it pure virtual after type_check for all expressions are
// implemented
}

virtual void serialize(std::ostream &ss) = 0;

virtual void flatten(FlattenContext *ctx) {
Expand Down
45 changes: 33 additions & 12 deletions taichi/ir/frontend_ir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -88,12 +88,24 @@ FrontendForStmt::FrontendForStmt(const Expr &loop_var,
loop_var_id[0] = loop_var.cast<IdExpression>()->id;
}

void ArgLoadExpression::type_check() {
TI_ASSERT_INFO(dt->is<PrimitiveType>() && dt != PrimitiveType::unknown,
"Invalid dt [{}] for ArgLoadExpression", dt->to_string());
ret_type = dt;
}

void ArgLoadExpression::flatten(FlattenContext *ctx) {
auto arg_load = std::make_unique<ArgLoadStmt>(arg_id, dt);
ctx->push_back(std::move(arg_load));
stmt = ctx->back_stmt();
}

void RandExpression::type_check() {
TI_ASSERT_INFO(dt->is<PrimitiveType>() && dt != PrimitiveType::unknown,
"Invalid dt [{}] for RandExpression", dt->to_string());
ret_type = dt;
}

void RandExpression::flatten(FlattenContext *ctx) {
auto ran = std::make_unique<RandStmt>(dt);
ctx->push_back(std::move(ran));
Expand Down Expand Up @@ -128,22 +140,24 @@ void UnaryOpExpression::flatten(FlattenContext *ctx) {
ctx->push_back(std::move(unary));
}

BinaryOpExpression::BinaryOpExpression(const BinaryOpType &type,
const Expr &lhs,
const Expr &rhs)
: type(type) {
this->lhs.set(load_if_ptr(lhs));
this->rhs.set(load_if_ptr(rhs));
auto lhs_type = this->lhs->ret_type;
auto rhs_type = this->rhs->ret_type;
// TODO: report error messages for unsuccessful inference
if (!lhs_type->is<PrimitiveType>() || !rhs_type->is<PrimitiveType>())
return;
void BinaryOpExpression::type_check() {
auto lhs_type = lhs->ret_type;
auto rhs_type = rhs->ret_type;
// TODO: assert no unknowns after type_check for all expressions are
// implemented
if (lhs_type == PrimitiveType::unknown || rhs_type == PrimitiveType::unknown)
return;
auto error = [&]() {
throw std::runtime_error(fmt::format(
"TypeError: unsupported operand type(s) for {}: '{}' and '{}'",
binary_op_type_symbol(type), lhs->ret_type->to_string(),
rhs->ret_type->to_string()));
};
if (!lhs_type->is<PrimitiveType>() || !rhs_type->is<PrimitiveType>())
error();
if (binary_is_bitwise(type) &&
(!is_integral(lhs_type) || !is_integral(rhs_type)))
return;
error();
if (is_comparison(type)) {
ret_type = PrimitiveType::i32;
return;
Expand Down Expand Up @@ -487,6 +501,13 @@ void GlobalLoadExpression::flatten(FlattenContext *ctx) {
stmt = ctx->back_stmt();
}

void ConstExpression::type_check() {
TI_ASSERT_INFO(
val.dt->is<PrimitiveType>() && val.dt != PrimitiveType::unknown,
"Invalid dt [{}] for ConstExpression", val.dt->to_string());
ret_type = val.dt;
}

void ConstExpression::flatten(FlattenContext *ctx) {
ctx->push_back(Stmt::make<ConstStmt>(val));
stmt = ctx->back_stmt();
Expand Down
19 changes: 14 additions & 5 deletions taichi/ir/frontend_ir.h
Original file line number Diff line number Diff line change
Expand Up @@ -240,9 +240,10 @@ class ArgLoadExpression : public Expression {
DataType dt;

ArgLoadExpression(int arg_id, DataType dt) : arg_id(arg_id), dt(dt) {
ret_type = dt;
}

void type_check() override;

void serialize(std::ostream &ss) override {
ss << fmt::format("arg[{}] (dt={})", arg_id, data_type_name(dt));
}
Expand All @@ -255,9 +256,10 @@ class RandExpression : public Expression {
DataType dt;

RandExpression(DataType dt) : dt(dt) {
ret_type = dt;
}

void type_check() override;

void serialize(std::ostream &ss) override {
ss << fmt::format("rand<{}>()", data_type_name(dt));
}
Expand Down Expand Up @@ -288,9 +290,11 @@ class BinaryOpExpression : public Expression {
BinaryOpType type;
Expr lhs, rhs;

BinaryOpExpression(const BinaryOpType &type,
const Expr &lhs,
const Expr &rhs);
BinaryOpExpression(const BinaryOpType &type, const Expr &lhs, const Expr &rhs)
: type(type), lhs(load_if_ptr(lhs)), rhs(load_if_ptr(rhs)) {
}

void type_check() override;

void serialize(std::ostream &ss) override {
ss << '(';
Expand Down Expand Up @@ -590,6 +594,9 @@ class IdExpression : public Expression {
IdExpression(const Identifier &id) : id(id) {
}

void type_check() override {
}

void serialize(std::ostream &ss) override {
ss << id.name();
}
Expand Down Expand Up @@ -680,6 +687,8 @@ class ConstExpression : public Expression {
ret_type = val.dt;
}

void type_check() override;

void serialize(std::ostream &ss) override {
ss << val.stringify();
}
Expand Down
1 change: 1 addition & 0 deletions taichi/python/export_lang.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -468,6 +468,7 @@ void export_lang(py::module &m) {
.def("set_grad", &Expr::set_grad)
.def("set_attribute", &Expr::set_attribute)
.def("get_ret_type", &Expr::get_ret_type)
.def("type_check", &Expr::type_check)
.def("get_expr_name",
[](Expr *expr) {
return expr->cast<GlobalVariableExpression>()->name;
Expand Down
7 changes: 7 additions & 0 deletions tests/cpp/ir/frontend_type_inference_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,16 +8,19 @@ namespace lang {

TEST(FrontendTypeInference, Const) {
auto const_i64 = Expr::make<ConstExpression, int64>(1LL << 63);
const_i64->type_check();
EXPECT_EQ(const_i64->ret_type, PrimitiveType::i64);
}

TEST(FrontendTypeInference, ArgLoad) {
auto arg_load_u64 = Expr::make<ArgLoadExpression>(2, PrimitiveType::u64);
arg_load_u64->type_check();
EXPECT_EQ(arg_load_u64->ret_type, PrimitiveType::u64);
}

TEST(FrontendTypeInference, Rand) {
auto rand_f16 = Expr::make<RandExpression>(PrimitiveType::f16);
rand_f16->type_check();
EXPECT_EQ(rand_f16->ret_type, PrimitiveType::f16);
}

Expand All @@ -27,6 +30,7 @@ TEST(FrontendTypeInference, Id) {
auto kernel = std::make_unique<Kernel>(*prog, func, "fake_kernel");
Callable::CurrentCallableGuard _(kernel->program, kernel.get());
auto const_i32 = Expr::make<ConstExpression, int32>(-(1 << 20));
const_i32->type_check();
auto id_i32 = Var(const_i32);
EXPECT_EQ(id_i32->ret_type, PrimitiveType::i32);
}
Expand All @@ -35,8 +39,11 @@ TEST(FrontendTypeInference, BinaryOp) {
auto prog = std::make_unique<Program>(Arch::x64);
prog->config.default_fp = PrimitiveType::f64;
auto const_i32 = Expr::make<ConstExpression, int32>(-(1 << 20));
const_i32->type_check();
auto const_f32 = Expr::make<ConstExpression, float32>(5.0);
const_f32->type_check();
auto truediv_f64 = expr_truediv(const_i32, const_f32);
truediv_f64->type_check();
EXPECT_EQ(truediv_f64->ret_type, PrimitiveType::f64);
}

Expand Down
15 changes: 15 additions & 0 deletions tests/python/test_type_check.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
import pytest

import taichi as ti


@ti.test(arch=ti.cpu)
def test_binary_op():
@ti.kernel
def bitwise_float():
a = 1
b = 3.1
c = a & b

with pytest.raises(SystemExit):
bitwise_float()

0 comments on commit ecc345e

Please sign in to comment.