diff --git a/python/taichi/lang/expr.py b/python/taichi/lang/expr.py index 85a3bfd9eeeb8..2a2c29a44e251 100644 --- a/python/taichi/lang/expr.py +++ b/python/taichi/lang/expr.py @@ -1,3 +1,6 @@ +import sys +import traceback + import numpy as np from taichi.core.util import ti_core as _ti_core from taichi.lang import impl @@ -35,6 +38,20 @@ def __init__(self, *args, tb=None): assert False if self.tb: self.ptr.set_tb(self.tb) + try: + self.ptr.type_check() + except RuntimeError as e: + if str(e).startswith('TypeError: '): + s = traceback.extract_stack() + for i, l in enumerate(s): + if 'taichi_ast_generator' in l: + s = s[i + 1:] + break + print('[Taichi] Compilation failed', file=sys.stderr) + print(traceback.format_list(s[:1])[0], end='', file=sys.stderr) + print(f'TaichiTypeError: {str(e)[11:]}', file=sys.stderr) + sys.exit(1) + raise e def __hash__(self): return self.ptr.get_raw_address() diff --git a/taichi/ir/expr.cpp b/taichi/ir/expr.cpp index 72d758d3ba428..ba06aadee11a6 100644 --- a/taichi/ir/expr.cpp +++ b/taichi/ir/expr.cpp @@ -33,6 +33,10 @@ DataType Expr::get_ret_type() const { return expr->ret_type; } +void Expr::type_check() { + expr->type_check(); +} + Expr select(const Expr &cond, const Expr &true_val, const Expr &false_val) { return Expr::make(TernaryOpType::select, cond, true_val, false_val); diff --git a/taichi/ir/expr.h b/taichi/ir/expr.h index 2afe5a540eaaf..b0ad4635ebf9b 100644 --- a/taichi/ir/expr.h +++ b/taichi/ir/expr.h @@ -109,6 +109,8 @@ class Expr { std::string get_attribute(const std::string &key) const; DataType get_ret_type() const; + + void type_check(); }; Expr select(const Expr &cond, const Expr &true_val, const Expr &false_val); diff --git a/taichi/ir/expression.h b/taichi/ir/expression.h index d8e900bc75690..ba78fbe9fa11a 100644 --- a/taichi/ir/expression.h +++ b/taichi/ir/expression.h @@ -38,6 +38,11 @@ class Expression { stmt = nullptr; } + virtual void type_check() { + // TODO: make it pure virtual after type_check for all expressions are + // implemented + } + virtual void serialize(std::ostream &ss) = 0; virtual void flatten(FlattenContext *ctx) { diff --git a/taichi/ir/frontend_ir.cpp b/taichi/ir/frontend_ir.cpp index d9c645da7648d..fd58ea49f3e01 100644 --- a/taichi/ir/frontend_ir.cpp +++ b/taichi/ir/frontend_ir.cpp @@ -88,12 +88,24 @@ FrontendForStmt::FrontendForStmt(const Expr &loop_var, loop_var_id[0] = loop_var.cast()->id; } +void ArgLoadExpression::type_check() { + TI_ASSERT_INFO(dt->is() && dt != PrimitiveType::unknown, + "Invalid dt [{}] for ArgLoadExpression", dt->to_string()); + ret_type = dt; +} + void ArgLoadExpression::flatten(FlattenContext *ctx) { auto arg_load = std::make_unique(arg_id, dt); ctx->push_back(std::move(arg_load)); stmt = ctx->back_stmt(); } +void RandExpression::type_check() { + TI_ASSERT_INFO(dt->is() && dt != PrimitiveType::unknown, + "Invalid dt [{}] for RandExpression", dt->to_string()); + ret_type = dt; +} + void RandExpression::flatten(FlattenContext *ctx) { auto ran = std::make_unique(dt); ctx->push_back(std::move(ran)); @@ -128,22 +140,24 @@ void UnaryOpExpression::flatten(FlattenContext *ctx) { ctx->push_back(std::move(unary)); } -BinaryOpExpression::BinaryOpExpression(const BinaryOpType &type, - const Expr &lhs, - const Expr &rhs) - : type(type) { - this->lhs.set(load_if_ptr(lhs)); - this->rhs.set(load_if_ptr(rhs)); - auto lhs_type = this->lhs->ret_type; - auto rhs_type = this->rhs->ret_type; - // TODO: report error messages for unsuccessful inference - if (!lhs_type->is() || !rhs_type->is()) - return; +void BinaryOpExpression::type_check() { + auto lhs_type = lhs->ret_type; + auto rhs_type = rhs->ret_type; + // TODO: assert no unknowns after type_check for all expressions are + // implemented if (lhs_type == PrimitiveType::unknown || rhs_type == PrimitiveType::unknown) return; + auto error = [&]() { + throw std::runtime_error(fmt::format( + "TypeError: unsupported operand type(s) for {}: '{}' and '{}'", + binary_op_type_symbol(type), lhs->ret_type->to_string(), + rhs->ret_type->to_string())); + }; + if (!lhs_type->is() || !rhs_type->is()) + error(); if (binary_is_bitwise(type) && (!is_integral(lhs_type) || !is_integral(rhs_type))) - return; + error(); if (is_comparison(type)) { ret_type = PrimitiveType::i32; return; @@ -487,6 +501,13 @@ void GlobalLoadExpression::flatten(FlattenContext *ctx) { stmt = ctx->back_stmt(); } +void ConstExpression::type_check() { + TI_ASSERT_INFO( + val.dt->is() && val.dt != PrimitiveType::unknown, + "Invalid dt [{}] for ConstExpression", val.dt->to_string()); + ret_type = val.dt; +} + void ConstExpression::flatten(FlattenContext *ctx) { ctx->push_back(Stmt::make(val)); stmt = ctx->back_stmt(); diff --git a/taichi/ir/frontend_ir.h b/taichi/ir/frontend_ir.h index 79575883df19e..382cb257f1acb 100644 --- a/taichi/ir/frontend_ir.h +++ b/taichi/ir/frontend_ir.h @@ -240,9 +240,10 @@ class ArgLoadExpression : public Expression { DataType dt; ArgLoadExpression(int arg_id, DataType dt) : arg_id(arg_id), dt(dt) { - ret_type = dt; } + void type_check() override; + void serialize(std::ostream &ss) override { ss << fmt::format("arg[{}] (dt={})", arg_id, data_type_name(dt)); } @@ -255,9 +256,10 @@ class RandExpression : public Expression { DataType dt; RandExpression(DataType dt) : dt(dt) { - ret_type = dt; } + void type_check() override; + void serialize(std::ostream &ss) override { ss << fmt::format("rand<{}>()", data_type_name(dt)); } @@ -288,9 +290,11 @@ class BinaryOpExpression : public Expression { BinaryOpType type; Expr lhs, rhs; - BinaryOpExpression(const BinaryOpType &type, - const Expr &lhs, - const Expr &rhs); + BinaryOpExpression(const BinaryOpType &type, const Expr &lhs, const Expr &rhs) + : type(type), lhs(load_if_ptr(lhs)), rhs(load_if_ptr(rhs)) { + } + + void type_check() override; void serialize(std::ostream &ss) override { ss << '('; @@ -590,6 +594,9 @@ class IdExpression : public Expression { IdExpression(const Identifier &id) : id(id) { } + void type_check() override { + } + void serialize(std::ostream &ss) override { ss << id.name(); } @@ -680,6 +687,8 @@ class ConstExpression : public Expression { ret_type = val.dt; } + void type_check() override; + void serialize(std::ostream &ss) override { ss << val.stringify(); } diff --git a/taichi/python/export_lang.cpp b/taichi/python/export_lang.cpp index 45a37f802348e..29992c48f0293 100644 --- a/taichi/python/export_lang.cpp +++ b/taichi/python/export_lang.cpp @@ -468,6 +468,7 @@ void export_lang(py::module &m) { .def("set_grad", &Expr::set_grad) .def("set_attribute", &Expr::set_attribute) .def("get_ret_type", &Expr::get_ret_type) + .def("type_check", &Expr::type_check) .def("get_expr_name", [](Expr *expr) { return expr->cast()->name; diff --git a/tests/cpp/ir/frontend_type_inference_test.cpp b/tests/cpp/ir/frontend_type_inference_test.cpp index a50baffa1b17f..971e7d57e8e7c 100644 --- a/tests/cpp/ir/frontend_type_inference_test.cpp +++ b/tests/cpp/ir/frontend_type_inference_test.cpp @@ -8,16 +8,19 @@ namespace lang { TEST(FrontendTypeInference, Const) { auto const_i64 = Expr::make(1LL << 63); + const_i64->type_check(); EXPECT_EQ(const_i64->ret_type, PrimitiveType::i64); } TEST(FrontendTypeInference, ArgLoad) { auto arg_load_u64 = Expr::make(2, PrimitiveType::u64); + arg_load_u64->type_check(); EXPECT_EQ(arg_load_u64->ret_type, PrimitiveType::u64); } TEST(FrontendTypeInference, Rand) { auto rand_f16 = Expr::make(PrimitiveType::f16); + rand_f16->type_check(); EXPECT_EQ(rand_f16->ret_type, PrimitiveType::f16); } @@ -27,6 +30,7 @@ TEST(FrontendTypeInference, Id) { auto kernel = std::make_unique(*prog, func, "fake_kernel"); Callable::CurrentCallableGuard _(kernel->program, kernel.get()); auto const_i32 = Expr::make(-(1 << 20)); + const_i32->type_check(); auto id_i32 = Var(const_i32); EXPECT_EQ(id_i32->ret_type, PrimitiveType::i32); } @@ -35,8 +39,11 @@ TEST(FrontendTypeInference, BinaryOp) { auto prog = std::make_unique(Arch::x64); prog->config.default_fp = PrimitiveType::f64; auto const_i32 = Expr::make(-(1 << 20)); + const_i32->type_check(); auto const_f32 = Expr::make(5.0); + const_f32->type_check(); auto truediv_f64 = expr_truediv(const_i32, const_f32); + truediv_f64->type_check(); EXPECT_EQ(truediv_f64->ret_type, PrimitiveType::f64); } diff --git a/tests/python/test_type_check.py b/tests/python/test_type_check.py new file mode 100644 index 0000000000000..ecada35b8429f --- /dev/null +++ b/tests/python/test_type_check.py @@ -0,0 +1,15 @@ +import pytest + +import taichi as ti + + +@ti.test(arch=ti.cpu) +def test_binary_op(): + @ti.kernel + def bitwise_float(): + a = 1 + b = 3.1 + c = a & b + + with pytest.raises(SystemExit): + bitwise_float()