From 8fa6d15217c3115abcec43afb53fd29d7ee7cb01 Mon Sep 17 00:00:00 2001 From: Yi Xu Date: Thu, 11 Nov 2021 15:02:32 +0800 Subject: [PATCH 1/5] Add type check for func calls --- taichi/ir/frontend_ir.cpp | 47 +++++++++++++++++++ taichi/ir/frontend_ir.h | 6 +++ tests/cpp/ir/frontend_type_inference_test.cpp | 6 +++ 3 files changed, 59 insertions(+) diff --git a/taichi/ir/frontend_ir.cpp b/taichi/ir/frontend_ir.cpp index 137c194e678a4..7f3c86e624089 100644 --- a/taichi/ir/frontend_ir.cpp +++ b/taichi/ir/frontend_ir.cpp @@ -230,6 +230,19 @@ void TernaryOpExpression::flatten(FlattenContext *ctx) { stmt = ctx->back_stmt(); } +void InternalFuncCallExpression::type_check() { + for (auto &arg : args) { + // TODO: assert no unknowns after type_check for all expressions are + // implemented + if (arg->ret_type == PrimitiveType::unknown) + return; + // There are no specifications for internal func calls for now, + // so arg types are not checked. + } + // Internal func calls will have default i32 return type. + ret_type = PrimitiveType::i32; +} + void InternalFuncCallExpression::flatten(FlattenContext *ctx) { std::vector args_stmts(args.size()); for (int i = 0; i < (int)args.size(); ++i) { @@ -240,6 +253,26 @@ void InternalFuncCallExpression::flatten(FlattenContext *ctx) { stmt = ctx->back_stmt(); } +void ExternalFuncCallExpression::type_check() { + for (auto &arg : args) { + // TODO: assert no unknowns after type_check for all expressions are + // implemented + if (arg->ret_type == PrimitiveType::unknown) + return; + // There are no specifications for external func calls for now, + // so arg types are not checked. + } + for (auto &output : outputs) { + // TODO: assert no unknowns after type_check for all expressions are + // implemented + if (output->ret_type == PrimitiveType::unknown) + return; + // There are no specifications for external func calls for now, + // so output types are not checked. + } + // External func calls have no return type. +} + void ExternalFuncCallExpression::flatten(FlattenContext *ctx) { TI_ASSERT((int)(so_func != nullptr) + (int)(!asm_source.empty()) + (int)(!bc_filename.empty()) == @@ -685,6 +718,20 @@ void ExternalTensorShapeAlongAxisExpression::flatten(FlattenContext *ctx) { stmt = ctx->back_stmt(); } +void FuncCallExpression::type_check() { + for (auto &arg : args.exprs) { + // TODO: assert no unknowns after type_check for all expressions are + // implemented + if (arg->ret_type == PrimitiveType::unknown) + return; + // There are no specifications for external func calls for now, + // so arg types are not checked. + } + if (func->rets.size() == 1) { + ret_type = func->rets[0].dt; + } +} + void FuncCallExpression::flatten(FlattenContext *ctx) { std::vector stmt_args; for (auto &arg : args.exprs) { diff --git a/taichi/ir/frontend_ir.h b/taichi/ir/frontend_ir.h index 553e8072327f0..a4ee41f961a10 100644 --- a/taichi/ir/frontend_ir.h +++ b/taichi/ir/frontend_ir.h @@ -357,6 +357,8 @@ class InternalFuncCallExpression : public Expression { } } + void type_check() override; + void serialize(std::ostream &ss) override { ss << "internal call " << func_name << '('; std::string args_str; @@ -395,6 +397,8 @@ class ExternalFuncCallExpression : public Expression { outputs(outputs) { } + void type_check() override; + void serialize(std::ostream &ss) override { if (so_func != nullptr) { ss << fmt::format("so {:x} (", (uint64)so_func); @@ -750,6 +754,8 @@ class FuncCallExpression : public Expression { Function *func; ExprGroup args; + void type_check() override; + void serialize(std::ostream &ss) override; FuncCallExpression(Function *func, const ExprGroup &args) diff --git a/tests/cpp/ir/frontend_type_inference_test.cpp b/tests/cpp/ir/frontend_type_inference_test.cpp index 36c1d19466137..f2a0d4e481aeb 100644 --- a/tests/cpp/ir/frontend_type_inference_test.cpp +++ b/tests/cpp/ir/frontend_type_inference_test.cpp @@ -173,5 +173,11 @@ TEST(FrontendTypeInference, LoopUnique) { EXPECT_EQ(loop_unique->ret_type, PrimitiveType::i64); } +TEST(FrontendTypeInference, InternalFuncCall) { + auto internal_func_call = Expr::make("do_nothing", std::vector{}); + internal_func_call->type_check(); + EXPECT_EQ(internal_func_call->ret_type, PrimitiveType::i32); +} + } // namespace lang } // namespace taichi From 154485aaacbbc9d9106fc2752cf596ae26dc3b47 Mon Sep 17 00:00:00 2001 From: Yi Xu Date: Thu, 11 Nov 2021 16:13:14 +0800 Subject: [PATCH 2/5] Enforce type check for all expressions --- taichi/ir/expr.cpp | 6 --- taichi/ir/frontend_ir.cpp | 87 ++++++++++++++------------------------- 2 files changed, 31 insertions(+), 62 deletions(-) diff --git a/taichi/ir/expr.cpp b/taichi/ir/expr.cpp index 18c116a92a00b..d7c624a0c4037 100644 --- a/taichi/ir/expr.cpp +++ b/taichi/ir/expr.cpp @@ -72,12 +72,6 @@ Expr &Expr::operator=(const Expr &o) { } else if (expr->is_lvalue()) { current_ast_builder().insert( std::make_unique(*this, load_if_ptr(o))); - if (this->is()) { - expr->ret_type = current_ast_builder() - .get_last_stmt() - ->cast() - ->rhs->ret_type; - } } else { TI_ERROR("Cannot assign to non-lvalue: {}", serialize()); } diff --git a/taichi/ir/frontend_ir.cpp b/taichi/ir/frontend_ir.cpp index 7f3c86e624089..8a8710d6e2b22 100644 --- a/taichi/ir/frontend_ir.cpp +++ b/taichi/ir/frontend_ir.cpp @@ -21,6 +21,9 @@ FrontendSNodeOpStmt::FrontendSNodeOpStmt(SNodeOpType op_type, FrontendAssignStmt::FrontendAssignStmt(const Expr &lhs, const Expr &rhs) : lhs(lhs), rhs(rhs) { TI_ASSERT(lhs->is_lvalue()); + if (lhs.is() && lhs->ret_type == PrimitiveType::unknown) { + lhs.expr->ret_type = rhs->ret_type; + } } IRNode *FrontendContext::root() { @@ -128,10 +131,11 @@ void UnaryOpExpression::serialize(std::ostream &ss) { } void UnaryOpExpression::type_check() { - // TODO: assert no unknowns after type_check for all expressions are - // implemented - if (operand->ret_type == PrimitiveType::unknown) - return; + TI_ASSERT_INFO(operand->ret_type != PrimitiveType::unknown, "[{}] was not type-checked", operand.serialize()); + if (!operand->ret_type->is()) + throw std::runtime_error(fmt::format( + "TypeError: unsupported operand type(s) for '{}': '{}'", + unary_op_type_name(type), operand->ret_type->to_string())); if ((type == UnaryOpType::floor || type == UnaryOpType::ceil || is_trigonometric(type)) && !is_real(operand->ret_type)) @@ -159,10 +163,8 @@ void UnaryOpExpression::flatten(FlattenContext *ctx) { void BinaryOpExpression::type_check() { auto lhs_type = lhs->ret_type; auto rhs_type = rhs->ret_type; - // TODO: assert no unknowns after type_check for all expressions are - // implemented - if (lhs_type == PrimitiveType::unknown || rhs_type == PrimitiveType::unknown) - return; + TI_ASSERT_INFO(lhs_type != PrimitiveType::unknown, "[{}] was not type-checked", lhs.serialize()); + TI_ASSERT_INFO(rhs_type != PrimitiveType::unknown, "[{}] was not type-checked", rhs.serialize()); auto error = [&]() { throw std::runtime_error(fmt::format( "TypeError: unsupported operand type(s) for '{}': '{}' and '{}'", @@ -204,9 +206,9 @@ void TernaryOpExpression::type_check() { auto op1_type = op1->ret_type; auto op2_type = op2->ret_type; auto op3_type = op3->ret_type; - if (op1_type == PrimitiveType::unknown || - op2_type == PrimitiveType::unknown || op3_type == PrimitiveType::unknown) - return; + TI_ASSERT_INFO(op1_type != PrimitiveType::unknown, "[{}] was not type-checked", op1.serialize()); + TI_ASSERT_INFO(op2_type != PrimitiveType::unknown, "[{}] was not type-checked", op2.serialize()); + TI_ASSERT_INFO(op3_type != PrimitiveType::unknown, "[{}] was not type-checked", op3.serialize()); auto error = [&]() { throw std::runtime_error(fmt::format( "TypeError: unsupported operand type(s) for '{}': '{}', '{}' and '{}'", @@ -232,14 +234,10 @@ void TernaryOpExpression::flatten(FlattenContext *ctx) { void InternalFuncCallExpression::type_check() { for (auto &arg : args) { - // TODO: assert no unknowns after type_check for all expressions are - // implemented - if (arg->ret_type == PrimitiveType::unknown) - return; - // There are no specifications for internal func calls for now, - // so arg types are not checked. - } - // Internal func calls will have default i32 return type. + TI_ASSERT_INFO(arg->ret_type != PrimitiveType::unknown, "[{}] was not type-checked", arg.serialize()); + // no arg type compatibility check for now due to lack of specification + } + // internal func calls have default return type ret_type = PrimitiveType::i32; } @@ -255,22 +253,14 @@ void InternalFuncCallExpression::flatten(FlattenContext *ctx) { void ExternalFuncCallExpression::type_check() { for (auto &arg : args) { - // TODO: assert no unknowns after type_check for all expressions are - // implemented - if (arg->ret_type == PrimitiveType::unknown) - return; - // There are no specifications for external func calls for now, - // so arg types are not checked. + TI_ASSERT_INFO(arg->ret_type != PrimitiveType::unknown, "[{}] was not type-checked", arg.serialize()); + // no arg type compatibility check for now due to lack of specification } for (auto &output : outputs) { - // TODO: assert no unknowns after type_check for all expressions are - // implemented - if (output->ret_type == PrimitiveType::unknown) - return; - // There are no specifications for external func calls for now, - // so output types are not checked. + TI_ASSERT_INFO(output->ret_type != PrimitiveType::unknown, "[{}] was not type-checked", output.serialize()); + // no output type compatibility check for now due to lack of specification } - // External func calls have no return type. + // external func calls have no return type for now } void ExternalFuncCallExpression::flatten(FlattenContext *ctx) { @@ -328,10 +318,7 @@ void GlobalPtrExpression::type_check() { } else if (var.is()) { for (int i = 0; i < indices.exprs.size(); i++) { auto &expr = indices.exprs[i]; - // TODO: assert no unknowns after type_check for all expressions are - // implemented - if (expr->ret_type == PrimitiveType::unknown) - return; + TI_ASSERT_INFO(expr->ret_type != PrimitiveType::unknown, "[{}] was not type-checked", expr.serialize()); if (!is_integral(expr->ret_type)) throw std::runtime_error( fmt::format("TypeError: indices must be integers, however '{}' is " @@ -474,11 +461,8 @@ void TensorElementExpression::flatten(FlattenContext *ctx) { } void RangeAssumptionExpression::type_check() { - // TODO: assert no unknowns after type_check for all expressions are - // implemented - if (input->ret_type == PrimitiveType::unknown || - base->ret_type == PrimitiveType::unknown) - return; + TI_ASSERT_INFO(input->ret_type != PrimitiveType::unknown, "[{}] was not type-checked", input.serialize()); + TI_ASSERT_INFO(base->ret_type != PrimitiveType::unknown, "[{}] was not type-checked", base.serialize()); if (!input->ret_type->is() || !base->ret_type->is() || input->ret_type != base->ret_type) throw std::runtime_error( @@ -497,10 +481,7 @@ void RangeAssumptionExpression::flatten(FlattenContext *ctx) { } void LoopUniqueExpression::type_check() { - // TODO: assert no unknowns after type_check for all expressions are - // implemented - if (input->ret_type == PrimitiveType::unknown) - return; + TI_ASSERT_INFO(input->ret_type != PrimitiveType::unknown, "[{}] was not type-checked", input.serialize()); if (!input->ret_type->is()) throw std::runtime_error(fmt::format( "TypeError: unsupported operand type(s) for 'loop_unique': '{}'", @@ -549,11 +530,8 @@ void IdExpression::flatten(FlattenContext *ctx) { } void AtomicOpExpression::type_check() { - // TODO: assert no unknowns after type_check for all expressions are - // implemented - if (dest->ret_type == PrimitiveType::unknown || - val->ret_type == PrimitiveType::unknown) - return; + TI_ASSERT_INFO(dest->ret_type != PrimitiveType::unknown, "[{}] was not type-checked", dest.serialize()); + TI_ASSERT_INFO(val->ret_type != PrimitiveType::unknown, "[{}] was not type-checked", val.serialize()); auto error = [&]() { throw std::runtime_error(fmt::format( "TypeError: unsupported operand type(s) for 'atomic_{}': '{}' and '{}'", @@ -720,13 +698,10 @@ void ExternalTensorShapeAlongAxisExpression::flatten(FlattenContext *ctx) { void FuncCallExpression::type_check() { for (auto &arg : args.exprs) { - // TODO: assert no unknowns after type_check for all expressions are - // implemented - if (arg->ret_type == PrimitiveType::unknown) - return; - // There are no specifications for external func calls for now, - // so arg types are not checked. + TI_ASSERT_INFO(arg->ret_type != PrimitiveType::unknown, "[{}] was not type-checked", arg.serialize()); + // no arg type compatibility check for now due to lack of specification } + TI_ASSERT_INFO(func->rets.size() <= 1, "Too many (> 1) return values for FuncCallExpression"); if (func->rets.size() == 1) { ret_type = func->rets[0].dt; } From a0c9addf6163e828a83392b57c62ef020846aca9 Mon Sep 17 00:00:00 2001 From: Taichi Gardener Date: Thu, 11 Nov 2021 08:18:45 +0000 Subject: [PATCH 3/5] Auto Format --- taichi/ir/frontend_ir.cpp | 59 ++++++++++++------- tests/cpp/ir/frontend_type_inference_test.cpp | 3 +- 2 files changed, 40 insertions(+), 22 deletions(-) diff --git a/taichi/ir/frontend_ir.cpp b/taichi/ir/frontend_ir.cpp index 8a8710d6e2b22..397a5f2d30265 100644 --- a/taichi/ir/frontend_ir.cpp +++ b/taichi/ir/frontend_ir.cpp @@ -22,7 +22,7 @@ FrontendAssignStmt::FrontendAssignStmt(const Expr &lhs, const Expr &rhs) : lhs(lhs), rhs(rhs) { TI_ASSERT(lhs->is_lvalue()); if (lhs.is() && lhs->ret_type == PrimitiveType::unknown) { - lhs.expr->ret_type = rhs->ret_type; + lhs.expr->ret_type = rhs->ret_type; } } @@ -131,11 +131,12 @@ void UnaryOpExpression::serialize(std::ostream &ss) { } void UnaryOpExpression::type_check() { - TI_ASSERT_INFO(operand->ret_type != PrimitiveType::unknown, "[{}] was not type-checked", operand.serialize()); + TI_ASSERT_INFO(operand->ret_type != PrimitiveType::unknown, + "[{}] was not type-checked", operand.serialize()); if (!operand->ret_type->is()) - throw std::runtime_error(fmt::format( - "TypeError: unsupported operand type(s) for '{}': '{}'", - unary_op_type_name(type), operand->ret_type->to_string())); + throw std::runtime_error( + fmt::format("TypeError: unsupported operand type(s) for '{}': '{}'", + unary_op_type_name(type), operand->ret_type->to_string())); if ((type == UnaryOpType::floor || type == UnaryOpType::ceil || is_trigonometric(type)) && !is_real(operand->ret_type)) @@ -163,8 +164,10 @@ void UnaryOpExpression::flatten(FlattenContext *ctx) { void BinaryOpExpression::type_check() { auto lhs_type = lhs->ret_type; auto rhs_type = rhs->ret_type; - TI_ASSERT_INFO(lhs_type != PrimitiveType::unknown, "[{}] was not type-checked", lhs.serialize()); - TI_ASSERT_INFO(rhs_type != PrimitiveType::unknown, "[{}] was not type-checked", rhs.serialize()); + TI_ASSERT_INFO(lhs_type != PrimitiveType::unknown, + "[{}] was not type-checked", lhs.serialize()); + TI_ASSERT_INFO(rhs_type != PrimitiveType::unknown, + "[{}] was not type-checked", rhs.serialize()); auto error = [&]() { throw std::runtime_error(fmt::format( "TypeError: unsupported operand type(s) for '{}': '{}' and '{}'", @@ -206,9 +209,12 @@ void TernaryOpExpression::type_check() { auto op1_type = op1->ret_type; auto op2_type = op2->ret_type; auto op3_type = op3->ret_type; - TI_ASSERT_INFO(op1_type != PrimitiveType::unknown, "[{}] was not type-checked", op1.serialize()); - TI_ASSERT_INFO(op2_type != PrimitiveType::unknown, "[{}] was not type-checked", op2.serialize()); - TI_ASSERT_INFO(op3_type != PrimitiveType::unknown, "[{}] was not type-checked", op3.serialize()); + TI_ASSERT_INFO(op1_type != PrimitiveType::unknown, + "[{}] was not type-checked", op1.serialize()); + TI_ASSERT_INFO(op2_type != PrimitiveType::unknown, + "[{}] was not type-checked", op2.serialize()); + TI_ASSERT_INFO(op3_type != PrimitiveType::unknown, + "[{}] was not type-checked", op3.serialize()); auto error = [&]() { throw std::runtime_error(fmt::format( "TypeError: unsupported operand type(s) for '{}': '{}', '{}' and '{}'", @@ -234,7 +240,8 @@ void TernaryOpExpression::flatten(FlattenContext *ctx) { void InternalFuncCallExpression::type_check() { for (auto &arg : args) { - TI_ASSERT_INFO(arg->ret_type != PrimitiveType::unknown, "[{}] was not type-checked", arg.serialize()); + TI_ASSERT_INFO(arg->ret_type != PrimitiveType::unknown, + "[{}] was not type-checked", arg.serialize()); // no arg type compatibility check for now due to lack of specification } // internal func calls have default return type @@ -253,11 +260,13 @@ void InternalFuncCallExpression::flatten(FlattenContext *ctx) { void ExternalFuncCallExpression::type_check() { for (auto &arg : args) { - TI_ASSERT_INFO(arg->ret_type != PrimitiveType::unknown, "[{}] was not type-checked", arg.serialize()); + TI_ASSERT_INFO(arg->ret_type != PrimitiveType::unknown, + "[{}] was not type-checked", arg.serialize()); // no arg type compatibility check for now due to lack of specification } for (auto &output : outputs) { - TI_ASSERT_INFO(output->ret_type != PrimitiveType::unknown, "[{}] was not type-checked", output.serialize()); + TI_ASSERT_INFO(output->ret_type != PrimitiveType::unknown, + "[{}] was not type-checked", output.serialize()); // no output type compatibility check for now due to lack of specification } // external func calls have no return type for now @@ -318,7 +327,8 @@ void GlobalPtrExpression::type_check() { } else if (var.is()) { for (int i = 0; i < indices.exprs.size(); i++) { auto &expr = indices.exprs[i]; - TI_ASSERT_INFO(expr->ret_type != PrimitiveType::unknown, "[{}] was not type-checked", expr.serialize()); + TI_ASSERT_INFO(expr->ret_type != PrimitiveType::unknown, + "[{}] was not type-checked", expr.serialize()); if (!is_integral(expr->ret_type)) throw std::runtime_error( fmt::format("TypeError: indices must be integers, however '{}' is " @@ -461,8 +471,10 @@ void TensorElementExpression::flatten(FlattenContext *ctx) { } void RangeAssumptionExpression::type_check() { - TI_ASSERT_INFO(input->ret_type != PrimitiveType::unknown, "[{}] was not type-checked", input.serialize()); - TI_ASSERT_INFO(base->ret_type != PrimitiveType::unknown, "[{}] was not type-checked", base.serialize()); + TI_ASSERT_INFO(input->ret_type != PrimitiveType::unknown, + "[{}] was not type-checked", input.serialize()); + TI_ASSERT_INFO(base->ret_type != PrimitiveType::unknown, + "[{}] was not type-checked", base.serialize()); if (!input->ret_type->is() || !base->ret_type->is() || input->ret_type != base->ret_type) throw std::runtime_error( @@ -481,7 +493,8 @@ void RangeAssumptionExpression::flatten(FlattenContext *ctx) { } void LoopUniqueExpression::type_check() { - TI_ASSERT_INFO(input->ret_type != PrimitiveType::unknown, "[{}] was not type-checked", input.serialize()); + TI_ASSERT_INFO(input->ret_type != PrimitiveType::unknown, + "[{}] was not type-checked", input.serialize()); if (!input->ret_type->is()) throw std::runtime_error(fmt::format( "TypeError: unsupported operand type(s) for 'loop_unique': '{}'", @@ -530,8 +543,10 @@ void IdExpression::flatten(FlattenContext *ctx) { } void AtomicOpExpression::type_check() { - TI_ASSERT_INFO(dest->ret_type != PrimitiveType::unknown, "[{}] was not type-checked", dest.serialize()); - TI_ASSERT_INFO(val->ret_type != PrimitiveType::unknown, "[{}] was not type-checked", val.serialize()); + TI_ASSERT_INFO(dest->ret_type != PrimitiveType::unknown, + "[{}] was not type-checked", dest.serialize()); + TI_ASSERT_INFO(val->ret_type != PrimitiveType::unknown, + "[{}] was not type-checked", val.serialize()); auto error = [&]() { throw std::runtime_error(fmt::format( "TypeError: unsupported operand type(s) for 'atomic_{}': '{}' and '{}'", @@ -698,10 +713,12 @@ void ExternalTensorShapeAlongAxisExpression::flatten(FlattenContext *ctx) { void FuncCallExpression::type_check() { for (auto &arg : args.exprs) { - TI_ASSERT_INFO(arg->ret_type != PrimitiveType::unknown, "[{}] was not type-checked", arg.serialize()); + TI_ASSERT_INFO(arg->ret_type != PrimitiveType::unknown, + "[{}] was not type-checked", arg.serialize()); // no arg type compatibility check for now due to lack of specification } - TI_ASSERT_INFO(func->rets.size() <= 1, "Too many (> 1) return values for FuncCallExpression"); + TI_ASSERT_INFO(func->rets.size() <= 1, + "Too many (> 1) return values for FuncCallExpression"); if (func->rets.size() == 1) { ret_type = func->rets[0].dt; } diff --git a/tests/cpp/ir/frontend_type_inference_test.cpp b/tests/cpp/ir/frontend_type_inference_test.cpp index f2a0d4e481aeb..4f73c8bb3be01 100644 --- a/tests/cpp/ir/frontend_type_inference_test.cpp +++ b/tests/cpp/ir/frontend_type_inference_test.cpp @@ -174,7 +174,8 @@ TEST(FrontendTypeInference, LoopUnique) { } TEST(FrontendTypeInference, InternalFuncCall) { - auto internal_func_call = Expr::make("do_nothing", std::vector{}); + auto internal_func_call = + Expr::make("do_nothing", std::vector{}); internal_func_call->type_check(); EXPECT_EQ(internal_func_call->ret_type, PrimitiveType::i32); } From 02fd7edf956b4cf99cc08f36c05ddec82b251f27 Mon Sep 17 00:00:00 2001 From: Yi Xu Date: Sat, 13 Nov 2021 18:19:34 +0800 Subject: [PATCH 4/5] Add micro for type-check assertions --- taichi/ir/frontend_ir.cpp | 50 ++++++++++++++------------------------- 1 file changed, 18 insertions(+), 32 deletions(-) diff --git a/taichi/ir/frontend_ir.cpp b/taichi/ir/frontend_ir.cpp index 397a5f2d30265..31aead6678673 100644 --- a/taichi/ir/frontend_ir.cpp +++ b/taichi/ir/frontend_ir.cpp @@ -5,6 +5,8 @@ TLANG_NAMESPACE_BEGIN +#define TI_ASSERT_TYPE_CHECKED(x) TI_ASSERT_INFO(x->ret_type != PrimitiveType::unknown, "[{}] was not type-checked", x.serialize()) + FrontendSNodeOpStmt::FrontendSNodeOpStmt(SNodeOpType op_type, SNode *snode, const ExprGroup &indices, @@ -131,8 +133,7 @@ void UnaryOpExpression::serialize(std::ostream &ss) { } void UnaryOpExpression::type_check() { - TI_ASSERT_INFO(operand->ret_type != PrimitiveType::unknown, - "[{}] was not type-checked", operand.serialize()); + TI_ASSERT_TYPE_CHECKED(operand); if (!operand->ret_type->is()) throw std::runtime_error( fmt::format("TypeError: unsupported operand type(s) for '{}': '{}'", @@ -162,12 +163,10 @@ void UnaryOpExpression::flatten(FlattenContext *ctx) { } void BinaryOpExpression::type_check() { + TI_ASSERT_TYPE_CHECKED(lhs); + TI_ASSERT_TYPE_CHECKED(rhs); auto lhs_type = lhs->ret_type; auto rhs_type = rhs->ret_type; - TI_ASSERT_INFO(lhs_type != PrimitiveType::unknown, - "[{}] was not type-checked", lhs.serialize()); - TI_ASSERT_INFO(rhs_type != PrimitiveType::unknown, - "[{}] was not type-checked", rhs.serialize()); auto error = [&]() { throw std::runtime_error(fmt::format( "TypeError: unsupported operand type(s) for '{}': '{}' and '{}'", @@ -206,15 +205,12 @@ void BinaryOpExpression::flatten(FlattenContext *ctx) { } void TernaryOpExpression::type_check() { + TI_ASSERT_TYPE_CHECKED(op1); + TI_ASSERT_TYPE_CHECKED(op2); + TI_ASSERT_TYPE_CHECKED(op3); auto op1_type = op1->ret_type; auto op2_type = op2->ret_type; auto op3_type = op3->ret_type; - TI_ASSERT_INFO(op1_type != PrimitiveType::unknown, - "[{}] was not type-checked", op1.serialize()); - TI_ASSERT_INFO(op2_type != PrimitiveType::unknown, - "[{}] was not type-checked", op2.serialize()); - TI_ASSERT_INFO(op3_type != PrimitiveType::unknown, - "[{}] was not type-checked", op3.serialize()); auto error = [&]() { throw std::runtime_error(fmt::format( "TypeError: unsupported operand type(s) for '{}': '{}', '{}' and '{}'", @@ -240,8 +236,7 @@ void TernaryOpExpression::flatten(FlattenContext *ctx) { void InternalFuncCallExpression::type_check() { for (auto &arg : args) { - TI_ASSERT_INFO(arg->ret_type != PrimitiveType::unknown, - "[{}] was not type-checked", arg.serialize()); + TI_ASSERT_TYPE_CHECKED(arg); // no arg type compatibility check for now due to lack of specification } // internal func calls have default return type @@ -260,13 +255,11 @@ void InternalFuncCallExpression::flatten(FlattenContext *ctx) { void ExternalFuncCallExpression::type_check() { for (auto &arg : args) { - TI_ASSERT_INFO(arg->ret_type != PrimitiveType::unknown, - "[{}] was not type-checked", arg.serialize()); + TI_ASSERT_TYPE_CHECKED(arg); // no arg type compatibility check for now due to lack of specification } for (auto &output : outputs) { - TI_ASSERT_INFO(output->ret_type != PrimitiveType::unknown, - "[{}] was not type-checked", output.serialize()); + TI_ASSERT_TYPE_CHECKED(output); // no output type compatibility check for now due to lack of specification } // external func calls have no return type for now @@ -327,8 +320,7 @@ void GlobalPtrExpression::type_check() { } else if (var.is()) { for (int i = 0; i < indices.exprs.size(); i++) { auto &expr = indices.exprs[i]; - TI_ASSERT_INFO(expr->ret_type != PrimitiveType::unknown, - "[{}] was not type-checked", expr.serialize()); + TI_ASSERT_TYPE_CHECKED(expr); if (!is_integral(expr->ret_type)) throw std::runtime_error( fmt::format("TypeError: indices must be integers, however '{}' is " @@ -471,10 +463,8 @@ void TensorElementExpression::flatten(FlattenContext *ctx) { } void RangeAssumptionExpression::type_check() { - TI_ASSERT_INFO(input->ret_type != PrimitiveType::unknown, - "[{}] was not type-checked", input.serialize()); - TI_ASSERT_INFO(base->ret_type != PrimitiveType::unknown, - "[{}] was not type-checked", base.serialize()); + TI_ASSERT_TYPE_CHECKED(input); + TI_ASSERT_TYPE_CHECKED(base); if (!input->ret_type->is() || !base->ret_type->is() || input->ret_type != base->ret_type) throw std::runtime_error( @@ -493,8 +483,7 @@ void RangeAssumptionExpression::flatten(FlattenContext *ctx) { } void LoopUniqueExpression::type_check() { - TI_ASSERT_INFO(input->ret_type != PrimitiveType::unknown, - "[{}] was not type-checked", input.serialize()); + TI_ASSERT_TYPE_CHECKED(input); if (!input->ret_type->is()) throw std::runtime_error(fmt::format( "TypeError: unsupported operand type(s) for 'loop_unique': '{}'", @@ -543,10 +532,8 @@ void IdExpression::flatten(FlattenContext *ctx) { } void AtomicOpExpression::type_check() { - TI_ASSERT_INFO(dest->ret_type != PrimitiveType::unknown, - "[{}] was not type-checked", dest.serialize()); - TI_ASSERT_INFO(val->ret_type != PrimitiveType::unknown, - "[{}] was not type-checked", val.serialize()); + TI_ASSERT_TYPE_CHECKED(dest); + TI_ASSERT_TYPE_CHECKED(val); auto error = [&]() { throw std::runtime_error(fmt::format( "TypeError: unsupported operand type(s) for 'atomic_{}': '{}' and '{}'", @@ -713,8 +700,7 @@ void ExternalTensorShapeAlongAxisExpression::flatten(FlattenContext *ctx) { void FuncCallExpression::type_check() { for (auto &arg : args.exprs) { - TI_ASSERT_INFO(arg->ret_type != PrimitiveType::unknown, - "[{}] was not type-checked", arg.serialize()); + TI_ASSERT_TYPE_CHECKED(arg); // no arg type compatibility check for now due to lack of specification } TI_ASSERT_INFO(func->rets.size() <= 1, From 4b114c147a55d71598b51c7b5208527042054d4c Mon Sep 17 00:00:00 2001 From: Taichi Gardener Date: Sat, 13 Nov 2021 10:21:54 +0000 Subject: [PATCH 5/5] Auto Format --- taichi/ir/frontend_ir.cpp | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/taichi/ir/frontend_ir.cpp b/taichi/ir/frontend_ir.cpp index 31aead6678673..da47ab903f97e 100644 --- a/taichi/ir/frontend_ir.cpp +++ b/taichi/ir/frontend_ir.cpp @@ -5,7 +5,9 @@ TLANG_NAMESPACE_BEGIN -#define TI_ASSERT_TYPE_CHECKED(x) TI_ASSERT_INFO(x->ret_type != PrimitiveType::unknown, "[{}] was not type-checked", x.serialize()) +#define TI_ASSERT_TYPE_CHECKED(x) \ + TI_ASSERT_INFO(x->ret_type != PrimitiveType::unknown, \ + "[{}] was not type-checked", x.serialize()) FrontendSNodeOpStmt::FrontendSNodeOpStmt(SNodeOpType op_type, SNode *snode,