diff --git a/taichi/ir/expr.cpp b/taichi/ir/expr.cpp
index 18c116a92a00b..d7c624a0c4037 100644
--- a/taichi/ir/expr.cpp
+++ b/taichi/ir/expr.cpp
@@ -72,12 +72,6 @@ Expr &Expr::operator=(const Expr &o) {
   } else if (expr->is_lvalue()) {
     current_ast_builder().insert(
         std::make_unique<FrontendAssignStmt>(*this, load_if_ptr(o)));
-    if (this->is<IdExpression>()) {
-      expr->ret_type = current_ast_builder()
-                           .get_last_stmt()
-                           ->cast<FrontendAssignStmt>()
-                           ->rhs->ret_type;
-    }
   } else {
     TI_ERROR("Cannot assign to non-lvalue: {}", serialize());
   }
diff --git a/taichi/ir/frontend_ir.cpp b/taichi/ir/frontend_ir.cpp
index 137c194e678a4..da47ab903f97e 100644
--- a/taichi/ir/frontend_ir.cpp
+++ b/taichi/ir/frontend_ir.cpp
@@ -5,6 +5,10 @@
 
 TLANG_NAMESPACE_BEGIN
 
+#define TI_ASSERT_TYPE_CHECKED(x)                       \
+  TI_ASSERT_INFO(x->ret_type != PrimitiveType::unknown, \
+                 "[{}] was not type-checked", x.serialize())
+
 FrontendSNodeOpStmt::FrontendSNodeOpStmt(SNodeOpType op_type,
                                          SNode *snode,
                                          const ExprGroup &indices,
@@ -21,6 +25,9 @@ FrontendSNodeOpStmt::FrontendSNodeOpStmt(SNodeOpType op_type,
 FrontendAssignStmt::FrontendAssignStmt(const Expr &lhs, const Expr &rhs)
     : lhs(lhs), rhs(rhs) {
   TI_ASSERT(lhs->is_lvalue());
+  if (lhs.is<IdExpression>() && lhs->ret_type == PrimitiveType::unknown) {
+    lhs.expr->ret_type = rhs->ret_type;
+  }
 }
 
 IRNode *FrontendContext::root() {
@@ -128,10 +135,11 @@ void UnaryOpExpression::serialize(std::ostream &ss) {
 }
 
 void UnaryOpExpression::type_check() {
-  // TODO: assert no unknowns after type_check for all expressions are
-  // implemented
-  if (operand->ret_type == PrimitiveType::unknown)
-    return;
+  TI_ASSERT_TYPE_CHECKED(operand);
+  if (!operand->ret_type->is<PrimitiveType>())
+    throw std::runtime_error(
+        fmt::format("TypeError: unsupported operand type(s) for '{}': '{}'",
+                    unary_op_type_name(type), operand->ret_type->to_string()));
   if ((type == UnaryOpType::floor || type == UnaryOpType::ceil ||
        is_trigonometric(type)) &&
       !is_real(operand->ret_type))
@@ -157,12 +165,10 @@ void UnaryOpExpression::flatten(FlattenContext *ctx) {
 }
 
 void BinaryOpExpression::type_check() {
+  TI_ASSERT_TYPE_CHECKED(lhs);
+  TI_ASSERT_TYPE_CHECKED(rhs);
   auto lhs_type = lhs->ret_type;
   auto rhs_type = rhs->ret_type;
-  // TODO: assert no unknowns after type_check for all expressions are
-  // implemented
-  if (lhs_type == PrimitiveType::unknown || rhs_type == PrimitiveType::unknown)
-    return;
   auto error = [&]() {
     throw std::runtime_error(fmt::format(
         "TypeError: unsupported operand type(s) for '{}': '{}' and '{}'",
@@ -201,12 +207,12 @@ void BinaryOpExpression::flatten(FlattenContext *ctx) {
 }
 
 void TernaryOpExpression::type_check() {
+  TI_ASSERT_TYPE_CHECKED(op1);
+  TI_ASSERT_TYPE_CHECKED(op2);
+  TI_ASSERT_TYPE_CHECKED(op3);
   auto op1_type = op1->ret_type;
   auto op2_type = op2->ret_type;
   auto op3_type = op3->ret_type;
-  if (op1_type == PrimitiveType::unknown ||
-      op2_type == PrimitiveType::unknown || op3_type == PrimitiveType::unknown)
-    return;
   auto error = [&]() {
     throw std::runtime_error(fmt::format(
         "TypeError: unsupported operand type(s) for '{}': '{}', '{}' and '{}'",
@@ -230,6 +236,15 @@ void TernaryOpExpression::flatten(FlattenContext *ctx) {
   stmt = ctx->back_stmt();
 }
 
+void InternalFuncCallExpression::type_check() {
+  for (auto &arg : args) {
+    TI_ASSERT_TYPE_CHECKED(arg);
+    // no arg type compatibility check for now due to lack of specification
+  }
+  // internal func calls have default return type
+  ret_type = PrimitiveType::i32;
+}
+
 void InternalFuncCallExpression::flatten(FlattenContext *ctx) {
   std::vector<Stmt *> args_stmts(args.size());
   for (int i = 0; i < (int)args.size(); ++i) {
@@ -240,6 +255,18 @@ void InternalFuncCallExpression::flatten(FlattenContext *ctx) {
   stmt = ctx->back_stmt();
 }
 
+void ExternalFuncCallExpression::type_check() {
+  for (auto &arg : args) {
+    TI_ASSERT_TYPE_CHECKED(arg);
+    // no arg type compatibility check for now due to lack of specification
+  }
+  for (auto &output : outputs) {
+    TI_ASSERT_TYPE_CHECKED(output);
+    // no output type compatibility check for now due to lack of specification
+  }
+  // external func calls have no return type for now
+}
+
 void ExternalFuncCallExpression::flatten(FlattenContext *ctx) {
   TI_ASSERT((int)(so_func != nullptr) + (int)(!asm_source.empty()) +
                 (int)(!bc_filename.empty()) ==
@@ -295,10 +322,7 @@ void GlobalPtrExpression::type_check() {
   } else if (var.is<ExternalTensorExpression>()) {
     for (int i = 0; i < indices.exprs.size(); i++) {
       auto &expr = indices.exprs[i];
-      // TODO: assert no unknowns after type_check for all expressions are
-      // implemented
-      if (expr->ret_type == PrimitiveType::unknown)
-        return;
+      TI_ASSERT_TYPE_CHECKED(expr);
       if (!is_integral(expr->ret_type))
         throw std::runtime_error(
             fmt::format("TypeError: indices must be integers, however '{}' is "
@@ -441,11 +465,8 @@ void TensorElementExpression::flatten(FlattenContext *ctx) {
 }
 
 void RangeAssumptionExpression::type_check() {
-  // TODO: assert no unknowns after type_check for all expressions are
-  // implemented
-  if (input->ret_type == PrimitiveType::unknown ||
-      base->ret_type == PrimitiveType::unknown)
-    return;
+  TI_ASSERT_TYPE_CHECKED(input);
+  TI_ASSERT_TYPE_CHECKED(base);
   if (!input->ret_type->is<PrimitiveType>() ||
       !base->ret_type->is<PrimitiveType>() || input->ret_type != base->ret_type)
     throw std::runtime_error(
@@ -464,10 +485,7 @@ void RangeAssumptionExpression::flatten(FlattenContext *ctx) {
 }
 
 void LoopUniqueExpression::type_check() {
-  // TODO: assert no unknowns after type_check for all expressions are
-  // implemented
-  if (input->ret_type == PrimitiveType::unknown)
-    return;
+  TI_ASSERT_TYPE_CHECKED(input);
   if (!input->ret_type->is<PrimitiveType>())
     throw std::runtime_error(fmt::format(
         "TypeError: unsupported operand type(s) for 'loop_unique': '{}'",
@@ -516,11 +534,8 @@ void IdExpression::flatten(FlattenContext *ctx) {
 }
 
 void AtomicOpExpression::type_check() {
-  // TODO: assert no unknowns after type_check for all expressions are
-  // implemented
-  if (dest->ret_type == PrimitiveType::unknown ||
-      val->ret_type == PrimitiveType::unknown)
-    return;
+  TI_ASSERT_TYPE_CHECKED(dest);
+  TI_ASSERT_TYPE_CHECKED(val);
   auto error = [&]() {
     throw std::runtime_error(fmt::format(
         "TypeError: unsupported operand type(s) for 'atomic_{}': '{}' and '{}'",
@@ -685,6 +700,18 @@ void ExternalTensorShapeAlongAxisExpression::flatten(FlattenContext *ctx) {
   stmt = ctx->back_stmt();
 }
 
+void FuncCallExpression::type_check() {
+  for (auto &arg : args.exprs) {
+    TI_ASSERT_TYPE_CHECKED(arg);
+    // no arg type compatibility check for now due to lack of specification
+  }
+  TI_ASSERT_INFO(func->rets.size() <= 1,
+                 "Too many (> 1) return values for FuncCallExpression");
+  if (func->rets.size() == 1) {
+    ret_type = func->rets[0].dt;
+  }
+}
+
 void FuncCallExpression::flatten(FlattenContext *ctx) {
   std::vector<Stmt *> stmt_args;
   for (auto &arg : args.exprs) {
diff --git a/taichi/ir/frontend_ir.h b/taichi/ir/frontend_ir.h
index 553e8072327f0..a4ee41f961a10 100644
--- a/taichi/ir/frontend_ir.h
+++ b/taichi/ir/frontend_ir.h
@@ -357,6 +357,8 @@ class InternalFuncCallExpression : public Expression {
     }
   }
 
+  void type_check() override;
+
   void serialize(std::ostream &ss) override {
     ss << "internal call " << func_name << '(';
     std::string args_str;
@@ -395,6 +397,8 @@ class ExternalFuncCallExpression : public Expression {
         outputs(outputs) {
   }
 
+  void type_check() override;
+
   void serialize(std::ostream &ss) override {
     if (so_func != nullptr) {
       ss << fmt::format("so {:x} (", (uint64)so_func);
@@ -750,6 +754,8 @@ class FuncCallExpression : public Expression {
   Function *func;
   ExprGroup args;
 
+  void type_check() override;
+
   void serialize(std::ostream &ss) override;
 
   FuncCallExpression(Function *func, const ExprGroup &args)
diff --git a/tests/cpp/ir/frontend_type_inference_test.cpp b/tests/cpp/ir/frontend_type_inference_test.cpp
index 36c1d19466137..4f73c8bb3be01 100644
--- a/tests/cpp/ir/frontend_type_inference_test.cpp
+++ b/tests/cpp/ir/frontend_type_inference_test.cpp
@@ -173,5 +173,12 @@ TEST(FrontendTypeInference, LoopUnique) {
   EXPECT_EQ(loop_unique->ret_type, PrimitiveType::i64);
 }
 
+TEST(FrontendTypeInference, InternalFuncCall) {
+  auto internal_func_call =
+      Expr::make<InternalFuncCallExpression>("do_nothing", std::vector<Expr>{});
+  internal_func_call->type_check();
+  EXPECT_EQ(internal_func_call->ret_type, PrimitiveType::i32);
+}
+
 } // namespace lang
 } // namespace taichi