From 34abdb666bcfe923cdd77d11f0dc85dde101338b Mon Sep 17 00:00:00 2001 From: lin-hitonami Date: Tue, 30 May 2023 16:40:31 +0800 Subject: [PATCH] [ir] [refactor] Let the type of Alloca be pointer ghstack-source-id: 790315984a78fee0410ed85c89229f2785027338 Pull Request resolved: https://github.com/taichi-dev/taichi/pull/8007 --- python/taichi/lang/_ndrange.py | 2 +- python/taichi/lang/ast/ast_transformer.py | 4 +- python/taichi/lang/matrix.py | 6 +- taichi/analysis/gen_offline_cache_key.cpp | 18 +- taichi/codegen/cuda/codegen_cuda.cpp | 4 +- taichi/codegen/llvm/codegen_llvm.cpp | 18 +- taichi/codegen/llvm/llvm_codegen_utils.h | 4 +- taichi/codegen/spirv/spirv_codegen.cpp | 6 +- taichi/ir/control_flow_graph.cpp | 7 +- taichi/ir/expression_printer.h | 17 +- taichi/ir/frontend_ir.cpp | 264 +++++++++++------- taichi/ir/frontend_ir.h | 46 ++- taichi/ir/statements.cpp | 1 + taichi/ir/statements.h | 9 +- taichi/ir/type_factory.cpp | 3 +- taichi/python/export_lang.cpp | 14 +- taichi/transforms/auto_diff.cpp | 62 ++-- taichi/transforms/compile_to_offloads.cpp | 1 + taichi/transforms/frontend_type_check.cpp | 11 +- taichi/transforms/lower_ast.cpp | 6 +- taichi/transforms/offload.cpp | 10 +- taichi/transforms/scalarize.cpp | 26 +- taichi/transforms/simplify.cpp | 31 ++ taichi/transforms/type_check.cpp | 15 +- tests/cpp/ir/frontend_type_inference_test.cpp | 15 +- 25 files changed, 387 insertions(+), 213 deletions(-) diff --git a/python/taichi/lang/_ndrange.py b/python/taichi/lang/_ndrange.py index 5ea7c2af9a6cd5..a729b217ee7fb0 100644 --- a/python/taichi/lang/_ndrange.py +++ b/python/taichi/lang/_ndrange.py @@ -23,7 +23,7 @@ def __init__(self, *args): for arg in args: for bound in arg: if not isinstance(bound, (int, np.integer)) and not ( - isinstance(bound, Expr) and is_integral(bound.ptr.get_ret_type()) + isinstance(bound, Expr) and is_integral(bound.ptr.get_rvalue_type()) ): raise TaichiTypeError( "Every argument of ndrange should be an integer scalar or a tuple/list of (int, int)" diff --git a/python/taichi/lang/ast/ast_transformer.py b/python/taichi/lang/ast/ast_transformer.py index 471676d3457b08..b0b1ee35da6b9a 100644 --- a/python/taichi/lang/ast/ast_transformer.py +++ b/python/taichi/lang/ast/ast_transformer.py @@ -709,7 +709,7 @@ def transform_as_kernel(): f"Argument {arg.arg} of type {ctx.func.arguments[i].annotation} is expected to be a Matrix, but got {type(data)}." ) - element_shape = data.ptr.get_ret_type().shape() + element_shape = data.ptr.get_rvalue_type().shape() if len(element_shape) != ctx.func.arguments[i].annotation.ndim: raise TaichiSyntaxError( f"Argument {arg.arg} of type {ctx.func.arguments[i].annotation} is expected to be a Matrix with ndim {ctx.func.arguments[i].annotation.ndim}, but got {len(element_shape)}." @@ -1449,7 +1449,7 @@ def ti_format_list_to_assert_msg(raw): if isinstance(entry, str): msg += entry elif isinstance(entry, _ti_core.Expr): - ty = entry.get_ret_type() + ty = entry.get_rvalue_type() if ty in primitive_types.real_types: msg += "%f" elif ty in primitive_types.integer_types: diff --git a/python/taichi/lang/matrix.py b/python/taichi/lang/matrix.py index 03daa799044f6d..e3c1ed6496a2fb 100644 --- a/python/taichi/lang/matrix.py +++ b/python/taichi/lang/matrix.py @@ -133,7 +133,7 @@ def _infer_entry_dt(entry): if isinstance(entry, (float, np.floating)): return impl.get_runtime().default_fp if isinstance(entry, expr.Expr): - dt = entry.ptr.get_ret_type() + dt = entry.ptr.get_rvalue_type() if dt == ti_python_core.DataType_unknown: raise TaichiTypeError("Element type of the matrix cannot be inferred. Please set dt instead for now.") return dt @@ -1412,7 +1412,7 @@ def __call__(self, *args): # Init from a real Matrix if isinstance(args[0], expr.Expr) and args[0].ptr.is_tensor(): arg = args[0] - shape = arg.ptr.get_ret_type().shape() + shape = arg.ptr.get_rvalue_type().shape() assert self.ndim == len(shape) assert self.n == shape[0] if self.ndim > 1: @@ -1554,7 +1554,7 @@ def __call__(self, *args): # Init from a real Matrix if isinstance(args[0], expr.Expr) and args[0].ptr.is_tensor(): arg = args[0] - shape = arg.ptr.get_ret_type().shape() + shape = arg.ptr.get_rvalue_type().shape() assert len(shape) == 1 assert self.n == shape[0] return expr.Expr(arg.ptr) diff --git a/taichi/analysis/gen_offline_cache_key.cpp b/taichi/analysis/gen_offline_cache_key.cpp index 13cb82809c1843..a420fb67b05502 100644 --- a/taichi/analysis/gen_offline_cache_key.cpp +++ b/taichi/analysis/gen_offline_cache_key.cpp @@ -12,23 +12,6 @@ namespace taichi::lang { namespace { -enum class ExprOpCode : std::uint8_t { - NIL, -#define PER_EXPRESSION(x) x, -#include "taichi/inc/expressions.inc.h" -#undef PER_EXPRESSION -}; - -enum class StmtOpCode : std::uint8_t { - NIL, - EnterBlock, - ExitBlock, - StopGrad, -#define PER_STATEMENT(x) x, -#include "taichi/inc/frontend_statements.inc.h" -#undef PER_STATEMENT -}; - enum class ForLoopType : std::uint8_t { StructForOnSNode, StructForOnExternalTensor, @@ -198,6 +181,7 @@ class ASTSerializer : public IRVisitor, public ExpressionVisitor { void visit(IdExpression *expr) override { emit(ExprOpCode::IdExpression); emit(expr->id); + emit(expr->op); } void visit(AtomicOpExpression *expr) override { diff --git a/taichi/codegen/cuda/codegen_cuda.cpp b/taichi/codegen/cuda/codegen_cuda.cpp index 9ffb507b3905a1..9b87b5a91f0638 100644 --- a/taichi/codegen/cuda/codegen_cuda.cpp +++ b/taichi/codegen/cuda/codegen_cuda.cpp @@ -172,8 +172,8 @@ class TaskCodeGenCUDA : public TaskCodeGenLLVM { void visit(AllocaStmt *stmt) override { // Override shared memory codegen logic for large shared memory - if (stmt->ret_type->is() && stmt->is_shared) { - auto tensor_type = stmt->ret_type->cast(); + auto tensor_type = stmt->ret_type.ptr_removed()->cast(); + if (tensor_type && stmt->is_shared) { size_t shared_array_bytes = tensor_type->get_num_elements() * data_type_size(tensor_type->get_element_type()); diff --git a/taichi/codegen/llvm/codegen_llvm.cpp b/taichi/codegen/llvm/codegen_llvm.cpp index 4e717e24ba3694..c09e5c7989f0a5 100644 --- a/taichi/codegen/llvm/codegen_llvm.cpp +++ b/taichi/codegen/llvm/codegen_llvm.cpp @@ -126,8 +126,10 @@ void TaskCodeGenLLVM::visit(Block *stmt_list) { } void TaskCodeGenLLVM::visit(AllocaStmt *stmt) { - if (stmt->ret_type->is()) { - auto tensor_type = stmt->ret_type->cast(); + TI_ASSERT(stmt->ret_type.is_pointer()); + auto alloca_type = stmt->ret_type->as()->get_pointee_type(); + if (alloca_type->is()) { + auto tensor_type = alloca_type->cast(); auto type = tlctx->get_data_type(tensor_type); if (stmt->is_shared) { auto base = new llvm::GlobalVariable( @@ -141,12 +143,10 @@ void TaskCodeGenLLVM::visit(AllocaStmt *stmt) { llvm_val[stmt] = create_entry_block_alloca(type); } } else { - llvm_val[stmt] = - create_entry_block_alloca(stmt->ret_type, stmt->ret_type.is_pointer()); + llvm_val[stmt] = create_entry_block_alloca(alloca_type); // initialize as zero if element is not a pointer - if (!stmt->ret_type.is_pointer()) - builder->CreateStore(tlctx->get_constant(stmt->ret_type, 0), - llvm_val[stmt]); + if (!alloca_type->is()) + builder->CreateStore(tlctx->get_constant(alloca_type, 0), llvm_val[stmt]); } } @@ -878,6 +878,7 @@ void TaskCodeGenLLVM::visit(IfStmt *if_stmt) { llvm::BasicBlock::Create(*llvm_context, "false_block", func); llvm::BasicBlock *after_if = llvm::BasicBlock::Create(*llvm_context, "after_if", func); + llvm_val[if_stmt->cond]->dump(); llvm::Value *cond = builder->CreateIsNotNull(llvm_val[if_stmt->cond]); builder->CreateCondBr(cond, true_block, false_block); builder->SetInsertPoint(true_block); @@ -1310,6 +1311,9 @@ void TaskCodeGenLLVM::visit(LocalLoadStmt *stmt) { } void TaskCodeGenLLVM::visit(LocalStoreStmt *stmt) { + // irpass::print(stmt); + // llvm_val[stmt->val]->dump(); + // llvm_val[stmt->dest]->dump(); builder->CreateStore(llvm_val[stmt->val], llvm_val[stmt->dest]); } diff --git a/taichi/codegen/llvm/llvm_codegen_utils.h b/taichi/codegen/llvm/llvm_codegen_utils.h index 51b08efc52a49e..6cc36518a2a4ba 100644 --- a/taichi/codegen/llvm/llvm_codegen_utils.h +++ b/taichi/codegen/llvm/llvm_codegen_utils.h @@ -83,8 +83,8 @@ class LLVMModuleBuilder { llvm::Value *create_entry_block_alloca(DataType dt, bool is_pointer = false) { auto type = tlctx->get_data_type(dt); - if (is_pointer) - type = llvm::PointerType::get(type, 0); + // if (is_pointer) + // type = llvm::PointerType::get(type, 0); return create_entry_block_alloca(type); } diff --git a/taichi/codegen/spirv/spirv_codegen.cpp b/taichi/codegen/spirv/spirv_codegen.cpp index 700434684d7714..e588d18746d45d 100644 --- a/taichi/codegen/spirv/spirv_codegen.cpp +++ b/taichi/codegen/spirv/spirv_codegen.cpp @@ -274,8 +274,8 @@ class TaskCodegen : public IRVisitor { void visit(AllocaStmt *alloca) override { spirv::Value ptr_val; - if (alloca->ret_type->is()) { - auto tensor_type = alloca->ret_type->cast(); + auto alloca_type = alloca->ret_type->as()->get_pointee_type(); + if (auto tensor_type = alloca_type->cast()) { auto elem_num = tensor_type->get_num_elements(); spirv::SType elem_type = ir_->get_primitive_type(tensor_type->get_element_type()); @@ -288,7 +288,7 @@ class TaskCodegen : public IRVisitor { } } else { // Alloca for a single variable - spirv::SType src_type = ir_->get_primitive_type(alloca->element_type()); + spirv::SType src_type = ir_->get_primitive_type(alloca_type); ptr_val = ir_->alloca_variable(src_type); ir_->store_variable(ptr_val, ir_->get_zero(src_type)); } diff --git a/taichi/ir/control_flow_graph.cpp b/taichi/ir/control_flow_graph.cpp index cc8e9472d4f705..f95bf8934cdc35 100644 --- a/taichi/ir/control_flow_graph.cpp +++ b/taichi/ir/control_flow_graph.cpp @@ -247,7 +247,7 @@ Stmt *CFGNode::get_store_forwarding_data(Stmt *var, int position) const { // result: the value to store Stmt *result = irpass::analysis::get_store_data( block->statements[last_def_position].get()); - bool is_tensor_involved = var->ret_type->is(); + bool is_tensor_involved = var->ret_type.ptr_removed()->is(); if (!(var->is() && !is_tensor_involved)) { // In between the store stmt and current stmt, // if there's a third-stmt that "may" have stored a "different value" to @@ -355,7 +355,7 @@ Stmt *CFGNode::get_store_forwarding_data(Stmt *var, int position) const { // Check for aliased address // There's a store to the same dest_addr before this stmt - bool is_tensor_involved = var->ret_type->is(); + bool is_tensor_involved = var->ret_type.ptr_removed()->is(); if (!(var->is() && !is_tensor_involved)) { // In between the store stmt and current stmt, // if there's a third-stmt that "may" have stored a "different value" to @@ -440,7 +440,8 @@ bool CFGNode::store_to_load_forwarding(bool after_lower_access, continue; // special case of alloca (initialized to 0) - auto zero = Stmt::make(TypedConstant(result->ret_type, 0)); + auto zero = Stmt::make( + TypedConstant(result->ret_type.ptr_removed(), 0)); replace_with(i, std::move(zero), true); } else { if (result->ret_type.ptr_removed()->is() && diff --git a/taichi/ir/expression_printer.h b/taichi/ir/expression_printer.h index 7d4750ffe9e5e0..ece4411d96d45a 100644 --- a/taichi/ir/expression_printer.h +++ b/taichi/ir/expression_printer.h @@ -76,11 +76,11 @@ class ExpressionHumanFriendlyPrinter : public ExpressionPrinter { } void visit(TernaryOpExpression *expr) override { - emit(ternary_type_name(expr->type), '('); + emit(ternary_type_name(expr->type), "(op1: "); expr->op1->accept(this); - emit(' '); + emit(", op2: "); expr->op2->accept(this); - emit(' '); + emit(", op3: "); expr->op3->accept(this); emit(')'); } @@ -125,6 +125,7 @@ class ExpressionHumanFriendlyPrinter : public ExpressionPrinter { } void visit(IndexExpression *expr) override { + emit("<" + expr->ret_type->to_string() + ">"); expr->var->accept(this); emit('['); if (expr->ret_shape.empty()) { @@ -164,7 +165,10 @@ class ExpressionHumanFriendlyPrinter : public ExpressionPrinter { } void visit(IdExpression *expr) override { + emit("<" + expr->ret_type->to_string() + ">"); emit(expr->id.name()); + emit(": "); + emit(to_string(expr->op)); } void visit(AtomicOpExpression *expr) override { @@ -251,6 +255,13 @@ class ExpressionHumanFriendlyPrinter : public ExpressionPrinter { return oss.str(); } + static std::string expr_to_string(Expression *expr) { + std::ostringstream oss; + ExpressionHumanFriendlyPrinter printer(&oss); + expr->accept(&printer); + return oss.str(); + } + protected: template void emit(Args &&...args) { diff --git a/taichi/ir/frontend_ir.cpp b/taichi/ir/frontend_ir.cpp index 1caf2977fef88d..0b9b95f6acd9a8 100644 --- a/taichi/ir/frontend_ir.cpp +++ b/taichi/ir/frontend_ir.cpp @@ -37,7 +37,15 @@ FrontendAssignStmt::FrontendAssignStmt(const Expr &lhs, const Expr &rhs) : lhs(lhs), rhs(rhs) { TI_ASSERT(lhs->is_lvalue()); if (lhs.is() && lhs->ret_type == PrimitiveType::unknown) { - lhs.expr->ret_type = rhs->ret_type; + TI_ASSERT(lhs.cast()->op == StmtOpCode::FrontendAllocaStmt); + // TI_INFO("lhs: {}, rhs: {}", + // ExpressionHumanFriendlyPrinter::expr_to_string(lhs.expr.get()), + // ExpressionHumanFriendlyPrinter::expr_to_string(rhs.expr.get())); + lhs.expr->ret_type = + TypeFactory::get_instance().get_pointer_type(get_rvalue_dtype(rhs)); + // TI_INFO("after: lhs: {}, rhs: {}", + // ExpressionHumanFriendlyPrinter::expr_to_string(lhs.expr.get()), + // ExpressionHumanFriendlyPrinter::expr_to_string(rhs.expr.get())); } } @@ -127,7 +135,8 @@ void FrontendForStmt::init_loop_vars(const ExprGroup &loop_vars) { void FrontendForStmt::add_loop_var(const Expr &loop_var) { loop_var_ids.push_back(loop_var.cast()->id); - loop_var.expr->ret_type = PrimitiveType::i32; + loop_var.expr->ret_type = + TypeFactory::get_instance().get_pointer_type(PrimitiveType::i32); } FrontendFuncDefStmt::FrontendFuncDefStmt(const FrontendFuncDefStmt &o) @@ -188,8 +197,8 @@ void UnaryOpExpression::type_check(const CompileConfig *config) { the same. Therefore we extract the primitive type to perform the type inference, and then reconstruct the TensorType once neccessary. */ - - auto operand_primitive_type = operand->ret_type.get_element_type(); + auto operand_type = get_rvalue_dtype(operand); + auto operand_primitive_type = operand_type.get_element_type(); auto ret_primitive_type = ret_type; if (!operand_primitive_type->is()) { @@ -245,11 +254,11 @@ void UnaryOpExpression::type_check(const CompileConfig *config) { unary_op_type_name(type), operand_primitive_type->to_string())); } - if (operand->ret_type->is()) { + if (operand_type->is()) { ret_type = taichi::lang::TypeFactory::get_instance().get_tensor_type( - operand->ret_type.get_shape(), ret_primitive_type); + operand_type.get_shape(), ret_primitive_type); } else { - TI_ASSERT(operand->ret_type->is()); + TI_ASSERT(operand_type->is()); ret_type = ret_primitive_type; } } @@ -271,13 +280,14 @@ void UnaryOpExpression::flatten(FlattenContext *ctx) { } Expr to_broadcast_tensor(const Expr &elt, const DataType &dt) { - if (!elt->ret_type->is() && !dt->is()) + auto elt_type = get_rvalue_dtype(elt); + if (!elt_type->is() && !dt->is()) return elt; - if (elt->ret_type->is() && dt->is()) { + if (elt_type->is() && dt->is()) { // Only tensor shape will be checked here, since the dtype will // be promoted later at irpass::type_check() - if (elt->ret_type.get_shape() != dt.get_shape()) { + if (elt_type.get_shape() != dt.get_shape()) { TI_ERROR("Cannot broadcast tensor to tensor"); } else { return elt; @@ -285,23 +295,24 @@ Expr to_broadcast_tensor(const Expr &elt, const DataType &dt) { } auto tensor_type = dt->as(); - auto elt_type = tensor_type->get_element_type(); - TI_ASSERT_INFO(elt_type->is(), + auto tensor_elt_type = tensor_type->get_element_type(); + TI_ASSERT_INFO(tensor_elt_type->is(), "Only primitive types are supported in Tensors, got {}", - elt_type->to_string()); + tensor_elt_type->to_string()); std::vector broadcast_values(tensor_type->get_num_elements(), elt); auto matrix_expr = Expr::make( - broadcast_values, tensor_type->get_shape(), elt->ret_type); + broadcast_values, tensor_type->get_shape(), elt_type); matrix_expr->type_check(nullptr); return matrix_expr; } std::tuple unify_binop_operands(const Expr &e1, const Expr &e2) { - if (e1->ret_type->is() && e2->ret_type->is()) { - return std::tuple(to_broadcast_tensor(e1, e2->ret_type), e2); - } else if (e1->ret_type->is() && - e2->ret_type->is()) { - return std::tuple(e1, to_broadcast_tensor(e2, e1->ret_type)); + auto e1_type = get_rvalue_dtype(e1); + auto e2_type = get_rvalue_dtype(e2); + if (e1_type->is() && e2_type->is()) { + return std::tuple(to_broadcast_tensor(e1, e2_type), e2); + } else if (e1_type->is() && e2_type->is()) { + return std::tuple(e1, to_broadcast_tensor(e2, e1_type)); } else { return std::tuple(e1, e2); } @@ -310,13 +321,16 @@ std::tuple unify_binop_operands(const Expr &e1, const Expr &e2) { void BinaryOpExpression::type_check(const CompileConfig *config) { TI_ASSERT_TYPE_CHECKED(lhs); TI_ASSERT_TYPE_CHECKED(rhs); - auto lhs_type = lhs->ret_type; - auto rhs_type = rhs->ret_type; + + auto lhs_type = get_rvalue_dtype(lhs); + auto rhs_type = get_rvalue_dtype(rhs); + // TI_INFO("BinExpr: {}", + // ExpressionHumanFriendlyPrinter::expr_to_string(this)); auto error = [&]() { throw TaichiTypeError( fmt::format("unsupported operand type(s) for '{}': '{}' and '{}'", - binary_op_type_symbol(type), lhs->ret_type->to_string(), - rhs->ret_type->to_string())); + binary_op_type_symbol(type), lhs_type->to_string(), + rhs_type->to_string())); }; if (!is_primitive_or_tensor_type(lhs_type) || @@ -330,14 +344,14 @@ void BinaryOpExpression::type_check(const CompileConfig *config) { auto [unified_l, unified_r] = unify_binop_operands(lhs, rhs); lhs = unified_l; rhs = unified_r; - if (lhs->ret_type == PrimitiveType::unknown) + if (lhs_type == PrimitiveType::unknown) lhs.type_check(config); - if (rhs->ret_type == PrimitiveType::unknown) + if (rhs_type == PrimitiveType::unknown) rhs.type_check(config); - TI_ASSERT(lhs->ret_type->is()); - TI_ASSERT(rhs->ret_type->is()); - lhs_type = lhs->ret_type; - rhs_type = rhs->ret_type; + lhs_type = get_rvalue_dtype(lhs); + rhs_type = get_rvalue_dtype(rhs); + TI_ASSERT(lhs_type->is()); + TI_ASSERT(rhs_type->is()); } bool is_tensor_op = false; @@ -351,10 +365,10 @@ void BinaryOpExpression::type_check(const CompileConfig *config) { error(); } - auto make_dt = [&is_tensor_op, this](DataType dt) { + auto make_dt = [&is_tensor_op, lhs_type](DataType dt) { if (is_tensor_op) { return TypeFactory::create_tensor_type( - this->lhs->ret_type->cast()->get_shape(), dt); + lhs_type->cast()->get_shape(), dt); } else { return dt; } @@ -404,8 +418,10 @@ void BinaryOpExpression::flatten(FlattenContext *ctx) { // return; auto lhs_stmt = flatten_rvalue(lhs, ctx); - if (binary_is_logical(type) && !is_tensor(lhs->ret_type) && - !is_tensor(rhs->ret_type)) { + auto lhs_type = get_rvalue_dtype(lhs); + auto rhs_type = get_rvalue_dtype(rhs); + + if (binary_is_logical(type) && !is_tensor(lhs_type) && !is_tensor(rhs_type)) { auto result = ctx->push_back(ret_type); ctx->push_back(result, lhs_stmt); auto cond = ctx->push_back(result); @@ -478,12 +494,15 @@ static std::tuple unify_ternaryop_operands(const Expr &e1, auto target_dtype = PrimitiveType::unknown; // Since we don't support broadcasting between two TensorTypes, // we can simply use the first TensorType's dtype as the target dtype. - if (e1->ret_type->is()) { - target_dtype = e1->ret_type; - } else if (e2->ret_type->is()) { - target_dtype = e2->ret_type; - } else if (e3->ret_type->is()) { - target_dtype = e3->ret_type; + auto e1_type = get_rvalue_dtype(e1); + auto e2_type = get_rvalue_dtype(e2); + auto e3_type = get_rvalue_dtype(e3); + if (e1_type->is()) { + target_dtype = e1_type; + } else if (e2_type->is()) { + target_dtype = e2_type; + } else if (e3_type->is()) { + target_dtype = e3_type; } if (target_dtype == PrimitiveType::unknown) { @@ -498,6 +517,8 @@ void TernaryOpExpression::type_check(const CompileConfig *config) { TI_ASSERT_TYPE_CHECKED(op1); TI_ASSERT_TYPE_CHECKED(op2); TI_ASSERT_TYPE_CHECKED(op3); + // TI_INFO("Ternary op {}", + // ExpressionHumanFriendlyPrinter::expr_to_string(this)); bool is_valid = true; bool is_tensor = false; @@ -507,17 +528,17 @@ void TernaryOpExpression::type_check(const CompileConfig *config) { op1 = unified_cond; op2 = unified_l; op3 = unified_r; - auto op1_type = op1->ret_type; - auto op2_type = op2->ret_type; - auto op3_type = op3->ret_type; + auto op1_type = get_rvalue_dtype(op1); + auto op2_type = get_rvalue_dtype(op2); + auto op3_type = get_rvalue_dtype(op3); auto error = [&]() { throw TaichiTypeError( fmt::format("unsupported operand type(s) for '{}': '{}', '{}' and '{}'", - ternary_type_name(type), op1->ret_type->to_string(), - op2->ret_type->to_string(), op3->ret_type->to_string())); + ternary_type_name(type), op1_type->to_string(), + op2_type->to_string(), op3_type->to_string())); }; - + std::vector shape; if (op2_type->is() && op3_type->is()) { // valid is_tensor = true; @@ -534,6 +555,7 @@ void TernaryOpExpression::type_check(const CompileConfig *config) { if (op1_type->is()) { op1_type = op1_type->cast()->get_element_type(); } + shape = op2_type->cast()->get_shape(); op2_type = op2_type->cast()->get_element_type(); op3_type = op3_type->cast()->get_element_type(); @@ -556,7 +578,6 @@ void TernaryOpExpression::type_check(const CompileConfig *config) { if (is_tensor) { auto primitive_dtype = promoted_type(op2_type, op3_type); - auto shape = op2->ret_type->cast()->get_shape(); ret_type = TypeFactory::create_tensor_type(shape, primitive_dtype); } else { ret_type = promoted_type(op2_type, op3_type); @@ -584,7 +605,7 @@ void InternalFuncCallExpression::type_check(const CompileConfig *) { std::vector arg_types; for (auto &arg : args) { TI_ASSERT_TYPE_CHECKED(arg); - arg_types.push_back(arg.get_ret_type()); + arg_types.push_back(get_rvalue_dtype(arg)); } ret_type = op->type_check(arg_types); } @@ -710,11 +731,11 @@ Stmt *make_tensor_access(Expression::FlattenContext *ctx, const std::string &tb) { auto var_stmt = flatten_lvalue(var, ctx); if (!var->is_lvalue()) { - auto alloca_stmt = ctx->push_back(var->ret_type); + auto alloca_stmt = ctx->push_back(var->ret_type.ptr_removed()); ctx->push_back(alloca_stmt, var_stmt); var_stmt = alloca_stmt; } - if (is_tensor(ret_type)) { + if (ret_type->as()->get_pointee_type()->is()) { std::vector stmts; for (auto &indices : indices_group) { stmts.push_back( @@ -727,12 +748,13 @@ Stmt *make_tensor_access(Expression::FlattenContext *ctx, } void MatrixExpression::type_check(const CompileConfig *config) { - TI_ASSERT(dt->as()->get_num_elements() == elements.size()); + auto tensor_type = dt->as(); + TI_ASSERT(tensor_type->get_num_elements() == elements.size()); for (auto &arg : elements) { TI_ASSERT_TYPE_CHECKED(arg); - if (arg->ret_type != dt.get_element_type()) { - arg = cast(arg, dt.get_element_type()); + if (get_rvalue_dtype(arg)->get_type() != tensor_type->get_element_type()) { + arg = cast(arg, tensor_type->get_element_type()); arg->type_check(config); } } @@ -740,13 +762,13 @@ void MatrixExpression::type_check(const CompileConfig *config) { } void MatrixExpression::flatten(FlattenContext *ctx) { - TI_ASSERT(this->dt->is()); + TI_ASSERT(dt->is()); std::vector values; for (auto &elt : elements) { values.push_back(flatten_rvalue(elt, ctx)); } stmt = ctx->push_back(values); - stmt->ret_type = this->dt; + stmt->ret_type = dt; } IndexExpression::IndexExpression(const Expr &var, @@ -782,7 +804,7 @@ bool IndexExpression::is_ndarray() const { } bool IndexExpression::is_tensor() const { - return var->ret_type->is(); + return var->ret_type.ptr_removed()->is(); } bool IndexExpression::is_local() const { @@ -822,16 +844,15 @@ void IndexExpression::type_check(const CompileConfig *) { std::multiplies<>())); int index_dim = indices_group.empty() ? 0 : indices_group[0].size(); bool has_slice = !ret_shape.empty(); + auto var_type = get_rvalue_dtype(var); if (has_slice) { TI_ASSERT_INFO(is_tensor(), "Slice or swizzle can only apply on matrices"); - auto element_type = var->ret_type->as()->get_element_type(); + auto element_type = var_type->as()->get_element_type(); ret_type = TypeFactory::create_tensor_type(ret_shape, element_type); - } else if (is_field()) { // field auto field_expr = var.cast(); field_validation(field_expr.get(), index_dim); ret_type = field_expr->dt->get_compute_type(); - } else if (is_matrix_field()) { auto matrix_field_expr = var.cast(); @@ -862,27 +883,32 @@ void IndexExpression::type_check(const CompileConfig *) { ret_type = var.cast()->dt; } } else if (is_tensor()) { // local tensor - auto shape = var->ret_type->as()->get_shape(); + auto tensor_type = var_type->as(); + auto shape = tensor_type->get_shape(); if (indices_group[0].size() != shape.size()) { TI_ERROR("Expected {} indices, got {}.", shape.size(), indices_group[0].size()); } - ret_type = var->ret_type->cast()->get_element_type(); + ret_type = tensor_type->get_element_type(); } else { throw TaichiTypeError( "Invalid IndexExpression: the source is not among field, ndarray or " "local tensor"); } - + ret_type = TypeFactory::get_instance().get_pointer_type(ret_type); + // TI_INFO("IndexExpression {} type checked : {}.", + // ExpressionHumanFriendlyPrinter::expr_to_string(this), + // ret_type->to_string()); for (auto &indices : indices_group) { for (int i = 0; i < indices.exprs.size(); i++) { auto &expr = indices.exprs[i]; TI_ASSERT_TYPE_CHECKED(expr); - if (!is_integral(expr->ret_type)) + auto expr_type = get_rvalue_dtype(expr); + if (!is_integral(expr_type)) throw TaichiTypeError( fmt::format("indices must be integers, however '{}' is " "provided as index {}", - expr->ret_type->to_string(), i)); + expr_type->to_string(), i)); } } } @@ -897,9 +923,10 @@ void IndexExpression::flatten(FlattenContext *ctx) { } else if (is_ndarray()) { stmt = make_ndarray_access(ctx, var, indices_group[0]); } else if (is_tensor()) { - stmt = - make_tensor_access(ctx, var, indices_group, ret_type, - var->ret_type->cast()->get_shape(), tb); + // TI_INFO("{}", ExpressionHumanFriendlyPrinter::expr_to_string(var)); + stmt = make_tensor_access( + ctx, var, indices_group, ret_type, + var->ret_type.ptr_removed()->as()->get_shape(), tb); } else { throw TaichiTypeError( "Invalid IndexExpression: the source is not among field, ndarray or " @@ -911,13 +938,15 @@ void IndexExpression::flatten(FlattenContext *ctx) { void RangeAssumptionExpression::type_check(const CompileConfig *) { TI_ASSERT_TYPE_CHECKED(input); TI_ASSERT_TYPE_CHECKED(base); - if (!input->ret_type->is() || - !base->ret_type->is() || input->ret_type != base->ret_type) + auto input_type = get_rvalue_dtype(input); + auto base_type = get_rvalue_dtype(base); + if (!input_type->is() || !base_type->is() || + input_type != base_type) throw TaichiTypeError( fmt::format("unsupported operand type(s) for " "'range_assumption': '{}' and '{}'", - input->ret_type->to_string(), base->ret_type->to_string())); - ret_type = input->ret_type; + input_type->to_string(), base_type->to_string())); + ret_type = input_type; } void RangeAssumptionExpression::flatten(FlattenContext *ctx) { @@ -930,11 +959,13 @@ void RangeAssumptionExpression::flatten(FlattenContext *ctx) { void LoopUniqueExpression::type_check(const CompileConfig *) { TI_ASSERT_TYPE_CHECKED(input); - if (!input->ret_type->is()) + auto input_type = get_rvalue_dtype(input); + + if (!input_type->is()) throw TaichiTypeError( fmt::format("unsupported operand type(s) for 'loop_unique': '{}'", - input->ret_type->to_string())); - ret_type = input->ret_type; + input_type->to_string())); + ret_type = input_type; } void LoopUniqueExpression::flatten(FlattenContext *ctx) { @@ -945,7 +976,7 @@ void LoopUniqueExpression::flatten(FlattenContext *ctx) { void IdExpression::flatten(FlattenContext *ctx) { stmt = ctx->current_block->lookup_var(id); - if (!ret_type->is_primitive(PrimitiveTypeID::unknown)) { + if (stmt->ret_type->is_primitive(PrimitiveTypeID::unknown)) { stmt->ret_type = ret_type; } } @@ -1038,10 +1069,11 @@ void SNodeOpExpression::type_check(const CompileConfig *config) { for (int i = 0; i < values.size(); i++) { TI_ASSERT_TYPE_CHECKED(values[i]); auto &dst_type = snode->ch[i]->dt; - auto promoted = promoted_type(dst_type, values[i]->ret_type); + auto value_type = get_rvalue_dtype(values[i]); + auto promoted = promoted_type(dst_type, value_type); if (dst_type != promoted) { TI_WARN("Append may lose precision: {} <- {}\n{}", - dst_type->to_string(), values[i]->ret_type->to_string(), tb); + dst_type->to_string(), value_type->to_string(), tb); } values[i] = cast(values[i], dst_type); values[i]->type_check(config); @@ -1105,11 +1137,12 @@ void TextureOpExpression::type_check(const CompileConfig *config) { ptr->num_dims); for (int i = 0; i < ptr->num_dims; i++) { TI_ASSERT_TYPE_CHECKED(args[i]); - if (args[i].get_ret_type() != PrimitiveType::f32) { + auto arg_type = get_rvalue_dtype(args[i]); + if (arg_type != PrimitiveType::f32) { throw TaichiTypeError( fmt::format("Invalid type for texture sample_lod: '{}', all " "arguments must be f32", - args[i].get_ret_type()->to_string())); + arg_type->to_string())); } } } else if (op == TextureOpType::kFetchTexel) { @@ -1120,11 +1153,12 @@ void TextureOpExpression::type_check(const CompileConfig *config) { ptr->num_dims); for (int i = 0; i < ptr->num_dims; i++) { TI_ASSERT_TYPE_CHECKED(args[i]); - if (args[i].get_ret_type() != PrimitiveType::i32) { + auto arg_type = get_rvalue_dtype(args[i]); + if (arg_type != PrimitiveType::i32) { throw TaichiTypeError( fmt::format("Invalid type for texture fetch_texel: '{}', all " "arguments must be i32", - args[i].get_ret_type()->to_string())); + arg_type->to_string())); } } } else if (op == TextureOpType::kLoad) { @@ -1135,11 +1169,12 @@ void TextureOpExpression::type_check(const CompileConfig *config) { ptr->num_dims); for (int i = 0; i < ptr->num_dims; i++) { TI_ASSERT_TYPE_CHECKED(args[i]); - if (args[i].get_ret_type() != PrimitiveType::i32) { + auto arg_type = get_rvalue_dtype(args[i]); + if (arg_type != PrimitiveType::i32) { throw TaichiTypeError( fmt::format("Invalid type for texture load: '{}', all " "arguments must be i32", - args[i].get_ret_type()->to_string())); + arg_type->to_string())); } } } else if (op == TextureOpType::kStore) { @@ -1150,20 +1185,22 @@ void TextureOpExpression::type_check(const CompileConfig *config) { ptr->num_dims); for (int i = 0; i < ptr->num_dims; i++) { TI_ASSERT_TYPE_CHECKED(args[i]); - if (args[i].get_ret_type() != PrimitiveType::i32) { + auto arg_type = get_rvalue_dtype(args[i]); + if (arg_type != PrimitiveType::i32) { throw TaichiTypeError( fmt::format("Invalid type for texture load: '{}', index " "arguments must be i32", - args[i].get_ret_type()->to_string())); + arg_type->to_string())); } } for (int i = ptr->num_dims; i < ptr->num_dims + 4; i++) { TI_ASSERT_TYPE_CHECKED(args[i]); - if (args[i].get_ret_type() != PrimitiveType::f32) { + auto arg_type = get_rvalue_dtype(args[i]); + if (arg_type != PrimitiveType::f32) { throw TaichiTypeError( fmt::format("Invalid type for texture load: '{}', value " "arguments must be f32", - args[i].get_ret_type()->to_string())); + arg_type->to_string())); } } } else { @@ -1213,12 +1250,12 @@ void ExternalTensorShapeAlongAxisExpression::flatten(FlattenContext *ctx) { void GetElementExpression::type_check(const CompileConfig *config) { TI_ASSERT_TYPE_CHECKED(src); - TI_ASSERT_INFO(src->ret_type->is(), + auto src_type = src->ret_type; + TI_ASSERT_INFO(src_type->is(), "Invalid src [{}] for GetElementExpression", ExpressionHumanFriendlyPrinter::expr_to_string(src)); - ret_type = - src->ret_type.ptr_removed()->as()->get_element_type(index); + ret_type = src_type.ptr_removed()->as()->get_element_type(index); } void GetElementExpression::flatten(FlattenContext *ctx) { @@ -1334,7 +1371,8 @@ Expr ASTBuilder::make_id_expr(const std::string &name) { void ASTBuilder::insert_for(const Expr &s, const Expr &e, const std::function &func) { - auto i = Expr(std::make_shared(get_next_id())); + auto i = Expr(std::make_shared(get_next_id(), + StmtOpCode::FrontendForStmt)); auto stmt_unique = std::make_unique(i, s, e, this->arch_, for_loop_dec_.config); for_loop_dec_.reset(); @@ -1431,7 +1469,8 @@ void ASTBuilder::insert_external_func_call(std::size_t func_addr, } Expr ASTBuilder::expr_alloca() { - auto var = Expr(std::make_shared(get_next_id())); + auto var = Expr(std::make_shared( + get_next_id(), StmtOpCode::FrontendAllocaStmt)); this->insert(std::make_unique( std::static_pointer_cast(var.expr)->id, PrimitiveType::unknown)); @@ -1443,7 +1482,8 @@ std::optional ASTBuilder::insert_func_call(Function *func, ExprGroup expanded_args; expanded_args.exprs = this->expand_exprs(args.exprs); if (!func->rets.empty()) { - auto var = Expr(std::make_shared(get_next_id())); + auto var = Expr(std::make_shared( + get_next_id(), StmtOpCode::FrontendFuncCallStmt)); this->insert(std::make_unique( func, expanded_args, std::static_pointer_cast(var.expr)->id)); @@ -1473,7 +1513,8 @@ Expr ASTBuilder::make_matrix_expr(const std::vector &shape, Expr ASTBuilder::expr_alloca_shared_array(const std::vector &shape, const DataType &element_type) { - auto var = Expr(std::make_shared(get_next_id())); + auto var = Expr(std::make_shared( + get_next_id(), StmtOpCode::FrontendAllocaStmt)); this->insert(std::make_unique( std::static_pointer_cast(var.expr)->id, shape, element_type, true)); @@ -1493,7 +1534,7 @@ Expr ASTBuilder::expr_subscript(const Expr &expr, std::string tb) { TI_ASSERT(expr.is() || expr.is() || expr.is() || - is_tensor(expr.expr->ret_type)); + is_tensor(expr.expr->ret_type.ptr_removed())); // IndexExpression without ret_shape is used for matrix indexing, // where each entry of ExprGroup is interpreted as indexing into a specific @@ -1656,7 +1697,7 @@ std::vector ASTBuilder::expand_exprs(const std::vector &exprs) { elem.expr->ret_type = struct_type->get_element_type(indices); expanded_exprs.push_back(elem); } - } else if (!expr->ret_type->is()) { + } else if (!expr->ret_type.ptr_removed()->is()) { expanded_exprs.push_back(expr); } else { // Expand TensorType expr @@ -1674,7 +1715,7 @@ std::vector ASTBuilder::expand_exprs(const std::vector &exprs) { return {ind0, ind1, ind2, ind3} */ - auto tensor_type = expr->ret_type->cast(); + auto tensor_type = expr->ret_type.ptr_removed()->cast(); Expr id_expr; if (expr.is()) { @@ -1687,7 +1728,8 @@ std::vector ASTBuilder::expand_exprs(const std::vector &exprs) { for (int i = 0; i < shape[0]; i++) { auto ind = Expr(std::make_shared( id_expr, ExprGroup(Expr(i)), expr->tb)); - ind.expr->ret_type = tensor_type->get_element_type(); + ind->type_check(nullptr); + // ind.expr->ret_type = tensor_type->get_element_type(); expanded_exprs.push_back(ind); } } else { @@ -1696,7 +1738,8 @@ std::vector ASTBuilder::expand_exprs(const std::vector &exprs) { for (int j = 0; j < shape[1]; j++) { auto ind = Expr(std::make_shared( id_expr, ExprGroup(Expr(i), Expr(j)), expr->tb)); - ind.expr->ret_type = tensor_type->get_element_type(); + ind->type_check(nullptr); + // ind.expr->ret_type = tensor_type->get_element_type(); expanded_exprs.push_back(ind); } } @@ -1797,13 +1840,34 @@ Stmt *flatten_rvalue(Expr ptr, Expression::FlattenContext *ctx) { return ptr_stmt; } -DataType get_rvalue_dtype(Expr expr) { +DataType get_rvalue_dtype(const Expr &expr) { if (auto argload = expr.cast()) { if (argload->is_ptr) { - return argload->ret_type.ptr_removed(); + return argload->ret_type->as()->get_pointee_type(); } return argload->ret_type; } + if (auto id = expr.cast()) { + // if (id->op == StmtOpCode::FrontendAllocaStmt) { + return id->ret_type->as()->get_pointee_type(); + // } + // return id->ret_type; + } + if (auto index_expr = expr.cast()) { + return index_expr->ret_type->as()->get_pointee_type(); + } + if (auto unary = expr.cast()) { + if (unary->type == UnaryOpType::frexp) { + return unary->ret_type->as()->get_pointee_type(); + } + return unary->ret_type; + } + if (auto texture_op = expr.cast()) { + if (texture_op->op == TextureOpType::kStore) { + return texture_op->ret_type->as()->get_pointee_type(); + } + return texture_op->ret_type; + } return expr->ret_type; } diff --git a/taichi/ir/frontend_ir.h b/taichi/ir/frontend_ir.h index f0d116399d5f74..0d50c18957abc5 100644 --- a/taichi/ir/frontend_ir.h +++ b/taichi/ir/frontend_ir.h @@ -15,7 +15,41 @@ namespace taichi::lang { class ASTBuilder; +enum class ExprOpCode : std::uint8_t { + NIL, +#define PER_EXPRESSION(x) x, +#include "taichi/inc/expressions.inc.h" +#undef PER_EXPRESSION +}; +enum class StmtOpCode : std::uint8_t { + NIL, + EnterBlock, + ExitBlock, + StopGrad, +#define PER_STATEMENT(x) x, +#include "taichi/inc/frontend_statements.inc.h" +#undef PER_STATEMENT +}; +inline std::string to_string(StmtOpCode op) { + switch (op) { +#define PER_STATEMENT(x) \ + case StmtOpCode::x: \ + return #x; +#include "taichi/inc/frontend_statements.inc.h" +#undef PER_STATEMENT + case StmtOpCode::NIL: + return "NIL"; + case StmtOpCode::EnterBlock: + return "EnterBlock"; + case StmtOpCode::ExitBlock: + return "ExitBlock"; + case StmtOpCode::StopGrad: + return "StopGrad"; + default: + TI_NOT_IMPLEMENTED + } +} struct ForLoopConfig { bool is_bit_vectorized{false}; int num_cpu_threads{0}; @@ -86,7 +120,8 @@ class FrontendAllocaStmt : public Stmt { DataType element, bool is_shared = false) : ident(lhs), is_shared(is_shared) { - ret_type = DataType(TypeFactory::create_tensor_type(shape, element)); + ret_type = TypeFactory::get_instance().get_pointer_type( + DataType(TypeFactory::create_tensor_type(shape, element))); } bool is_shared; @@ -500,6 +535,7 @@ class ExternalTensorExpression : public Expression { void type_check(const CompileConfig *config) override { ret_type = dt; + ret_type.set_is_pointer(true); config_ = config; } @@ -585,7 +621,7 @@ class MatrixExpression : public Expression { std::vector shape, DataType element_type) : elements(elements) { - this->dt = DataType(TypeFactory::create_tensor_type(shape, element_type)); + dt = TypeFactory::create_tensor_type(shape, element_type); } void type_check(const CompileConfig *config) override; @@ -675,8 +711,10 @@ class LoopUniqueExpression : public Expression { class IdExpression : public Expression { public: Identifier id; + StmtOpCode op; - explicit IdExpression(const Identifier &id) : id(id) { + explicit IdExpression(const Identifier &id, StmtOpCode op = StmtOpCode::NIL) + : id(id), op(op) { } void type_check(const CompileConfig *config) override { @@ -1083,6 +1121,6 @@ Stmt *flatten_lvalue(Expr expr, Expression::FlattenContext *ctx); Stmt *flatten_rvalue(Expr expr, Expression::FlattenContext *ctx); -DataType get_rvalue_dtype(Expr expr); +DataType get_rvalue_dtype(const Expr &expr); } // namespace taichi::lang diff --git a/taichi/ir/statements.cpp b/taichi/ir/statements.cpp index 88ed18f0cea3b4..c01691abf4434b 100644 --- a/taichi/ir/statements.cpp +++ b/taichi/ir/statements.cpp @@ -86,6 +86,7 @@ MatrixOfMatrixPtrStmt::MatrixOfMatrixPtrStmt(const std::vector &stmts, DataType dt) : stmts(stmts) { ret_type = dt; + ret_type.set_is_pointer(true); TI_STMT_REG_FIELDS; } diff --git a/taichi/ir/statements.h b/taichi/ir/statements.h index b84122ae3c3476..935124d6ee0158 100644 --- a/taichi/ir/statements.h +++ b/taichi/ir/statements.h @@ -19,7 +19,11 @@ class Function; class AllocaStmt : public Stmt, public ir_traits::Store { public: explicit AllocaStmt(DataType type) : is_shared(false) { - ret_type = type; + if (type->is_primitive(PrimitiveTypeID::unknown)) { + ret_type = type; + } else { + ret_type = TypeFactory::get_instance().get_pointer_type(type); + } TI_STMT_REG_FIELDS; } @@ -27,7 +31,8 @@ class AllocaStmt : public Stmt, public ir_traits::Store { DataType type, bool is_shared = false) : is_shared(is_shared) { - ret_type = TypeFactory::create_tensor_type(shape, type); + ret_type = TypeFactory::get_instance().get_pointer_type( + TypeFactory::create_tensor_type(shape, type)); TI_STMT_REG_FIELDS; } diff --git a/taichi/ir/type_factory.cpp b/taichi/ir/type_factory.cpp index 08a7fde586942d..76ec7658c38d90 100644 --- a/taichi/ir/type_factory.cpp +++ b/taichi/ir/type_factory.cpp @@ -222,7 +222,7 @@ static bool compare_types(DataType x, DataType y) { static DataType to_primitive_type(DataType d) { if (d->is()) { d = d->as()->get_pointee_type(); - TI_WARN("promoted_type got a pointer input."); + TI_ERROR("promoted_type got a pointer input."); } if (d->is()) { @@ -256,6 +256,7 @@ DataType promoted_type(DataType a, DataType b) { return TypeFactory::create_tensor_type(tensor_ty_a->get_shape(), promoted_dt); } else { + // TI_INFO("a = {}, b = {}", a->to_string(), b->to_string()); return promoted_primitive_type(a, b); } }; diff --git a/taichi/python/export_lang.cpp b/taichi/python/export_lang.cpp index a41316b44bad44..ada578f4eaf8b5 100644 --- a/taichi/python/export_lang.cpp +++ b/taichi/python/export_lang.cpp @@ -16,6 +16,7 @@ #include "taichi/ir/expression_ops.h" #include "taichi/ir/frontend_ir.h" #include "taichi/ir/statements.h" +#include "taichi/ir/expression_printer.h" #include "taichi/program/graph_builder.h" #include "taichi/program/extension.h" #include "taichi/program/ndarray.h" @@ -772,17 +773,22 @@ void export_lang(py::module &m) { }, py::return_value_policy::reference) .def("get_ret_type", &Expr::get_ret_type) + .def("get_rvalue_type", + [](Expr *expr) { return get_rvalue_dtype(*expr); }) .def("is_tensor", - [](Expr *expr) { return expr->expr->ret_type->is(); }) + [](Expr *expr) { + return expr->expr->ret_type.ptr_removed()->is(); + }) .def("is_struct", [](Expr *expr) { return expr->expr->ret_type.ptr_removed()->is(); }) .def("get_shape", [](Expr *expr) -> std::optional> { - if (expr->expr->ret_type->is()) { - return std::optional>( - expr->expr->ret_type->cast()->get_shape()); + auto tensor_type = + expr->expr->ret_type.ptr_removed()->cast(); + if (tensor_type) { + return std::optional>(tensor_type->get_shape()); } return std::nullopt; }) diff --git a/taichi/transforms/auto_diff.cpp b/taichi/transforms/auto_diff.cpp index 4cda1ebf5eca22..4a16726ab3cb18 100644 --- a/taichi/transforms/auto_diff.cpp +++ b/taichi/transforms/auto_diff.cpp @@ -15,23 +15,24 @@ Stmt *insert_const(const DataType &dtype, Stmt *stmt, const T &value, bool insert_before_me = false) { + auto type = dtype.ptr_removed(); Stmt *zero = nullptr; if (insert_before_me) zero = stmt->insert_before_me( - Stmt::make(TypedConstant(dtype.get_element_type(), value))); + Stmt::make(TypedConstant(type.get_element_type(), value))); else zero = stmt->insert_after_me( - Stmt::make(TypedConstant(dtype.get_element_type(), value))); + Stmt::make(TypedConstant(type.get_element_type(), value))); - if (dtype->is()) { - auto t_dtype = dtype->as(); + if (type->is()) { + auto t_dtype = type->as(); std::vector values(t_dtype->get_num_elements(), zero); if (insert_before_me) { zero = zero->insert_before_me(Stmt::make(values)); } else { zero = zero->insert_after_me(Stmt::make(values)); } - zero->ret_type = dtype; + zero->ret_type = type; } return zero; } @@ -344,7 +345,8 @@ class PromoteSSA2LocalVar : public BasicStmtVisitor { if (stmt->is()) { // Create a new alloc at the top of an ib to replace the old alloca - auto alloc = Stmt::make(stmt->ret_type); + auto dtype = stmt->ret_type.ptr_removed(); + auto alloc = Stmt::make(dtype); auto alloc_ptr = alloc.get(); TI_ASSERT(alloca_block_); alloca_block_->insert(std::move(alloc), 0); @@ -354,7 +356,6 @@ class PromoteSSA2LocalVar : public BasicStmtVisitor { // Replace the old alloca with a local store // and it will be replaced by a AdStackPushStmt in the following // ReplaceLocalVarWithStacks pass - auto dtype = stmt->ret_type; auto zero = insert_const(dtype, stmt, 0); zero->insert_after_me(Stmt::make(alloc_ptr, zero)); @@ -362,7 +363,7 @@ class PromoteSSA2LocalVar : public BasicStmtVisitor { stmt->parent->erase(stmt); } else { // Create a alloc - auto alloc = Stmt::make(stmt->ret_type); + auto alloc = Stmt::make(stmt->ret_type.ptr_removed()); auto alloc_ptr = alloc.get(); TI_ASSERT(alloca_block_); alloca_block_->insert(std::move(alloc), 0); @@ -696,7 +697,7 @@ class ReplaceLocalVarWithStacks : public BasicStmtVisitor { void visit(AllocaStmt *alloc) override { bool is_stack_needed = AdStackAllocaJudger::run(alloc); if (is_stack_needed) { - auto dtype = alloc->ret_type; + auto dtype = alloc->ret_type.ptr_removed(); auto stack_alloca = Stmt::make(dtype, ad_stack_size); auto stack_alloca_ptr = stack_alloca.get(); @@ -937,6 +938,7 @@ class ReverseOuterLoops : public BasicStmtVisitor { class ADTransform : public IRVisitor { protected: Stmt *constant(float32 x, DataType dtype = PrimitiveType::unknown) { + dtype.set_is_pointer(false); if (!dtype->is()) return insert(TypedConstant(x)); @@ -1024,12 +1026,13 @@ class ADTransform : public IRVisitor { template Stmt *insert_const_for_grad(const DataType &dtype, Stmt *stmt, const T &val) { - auto zero = insert(TypedConstant(dtype.get_element_type(), val)); - if (dtype->is()) { - auto t_dtype = dtype->as(); + auto zero = insert( + TypedConstant(dtype.ptr_removed().get_element_type(), val)); + if (dtype.ptr_removed()->is()) { + auto t_dtype = dtype.ptr_removed()->as(); std::vector values(t_dtype->get_num_elements(), zero); zero = insert(values); - zero->ret_type = dtype; + zero->ret_type = dtype.ptr_removed(); } return zero; } @@ -1209,16 +1212,16 @@ class MakeAdjoint : public ADTransform { Stmt *adjoint(Stmt *stmt) { DataType adjoint_dtype = stmt->ret_type.ptr_removed(); - if (stmt->ret_type->is()) { + if (adjoint_dtype->is()) { DataType prim_dtype = PrimitiveType::f32; - if (is_real(stmt->ret_type.ptr_removed().get_element_type())) { - prim_dtype = stmt->ret_type.ptr_removed().get_element_type(); + if (is_real(adjoint_dtype.get_element_type())) { + prim_dtype = adjoint_dtype.get_element_type(); } adjoint_dtype = TypeFactory::get_instance().get_tensor_type( - stmt->ret_type->as()->get_shape(), prim_dtype); + adjoint_dtype->as()->get_shape(), prim_dtype); } else if (stmt->is()) { // pass - } else if (!is_real(stmt->ret_type) || stmt->is()) { + } else if (!is_real(adjoint_dtype) || stmt->is()) { return constant(0); } @@ -1482,8 +1485,9 @@ class MakeAdjoint : public ADTransform { // iteration should be cleared after this iteration has been done // 2. If the alloca serves as the dest of multiple LocalStoreStmt, only the // last LocalStoreStmt should be taken account of - if (is_real(stmt->dest->ret_type.get_element_type())) { - auto dtype = stmt->dest->ret_type; + auto dest_type = stmt->dest->ret_type.ptr_removed(); + if (is_real(dest_type.get_element_type())) { + auto dtype = dest_type; auto zero = insert_const_for_grad(dtype, stmt, 0); insert(adjoint(stmt->dest), zero); } @@ -1748,7 +1752,7 @@ class MakeAdjoint : public ADTransform { */ int offset = stmt->offset->as()->val.val_int32(); - auto tensor_type = stmt->origin->ret_type->as(); + auto tensor_type = stmt->origin->ret_type.ptr_removed()->as(); int num_elements = tensor_type->get_num_elements(); auto zero = insert_const_for_grad(prim_dtype, stmt, 0); @@ -1783,7 +1787,7 @@ class MakeAdjoint : public ADTransform { accumulate($0_adj, $7) */ - auto tensor_type = stmt->origin->ret_type->as(); + auto tensor_type = stmt->origin->ret_type.ptr_removed()->as(); auto tensor_shape = tensor_type->get_shape(); int num_elements = tensor_type->get_num_elements(); @@ -1894,7 +1898,8 @@ class MakeDual : public ADTransform { } Stmt *dual(Stmt *stmt) { - if (!is_real(stmt->ret_type.get_element_type()) || stmt->is()) { + auto dual_type = stmt->ret_type.ptr_removed(); + if (!is_real(dual_type.get_element_type()) || stmt->is()) { return constant(0); } if (dual_stmt.find(stmt) == dual_stmt.end()) { @@ -1904,7 +1909,7 @@ class MakeDual : public ADTransform { // auto alloca = // Stmt::make(get_current_program().config.gradient_dt); // maybe it's better to use the statement data type than the default type - auto alloca = Stmt::make(stmt->ret_type); + auto alloca = Stmt::make(dual_type); dual_stmt[stmt] = alloca.get(); // TODO: check whether there are any edge cases for the alloca_block @@ -2074,8 +2079,8 @@ class MakeDual : public ADTransform { // If the alloca serves as the dest of multiple LocalStoreStmt, only the // last LocalStoreStmt should be taken account of, i.e, its history should // be cleared - if (is_real(stmt->dest->ret_type.get_element_type())) { - auto dtype = stmt->dest->ret_type; + auto dtype = stmt->dest->ret_type.ptr_removed(); + if (is_real(dtype.get_element_type())) { auto zero = insert_const_for_grad(dtype, stmt, 0); insert(dual(stmt->dest), zero); } @@ -2200,7 +2205,7 @@ class BackupSSA : public BasicStmtVisitor { Stmt *load(Stmt *stmt) { if (backup_alloca.find(stmt) == backup_alloca.end()) { - auto alloca = Stmt::make(stmt->ret_type); + auto alloca = Stmt::make(stmt->ret_type.ptr_removed()); auto alloca_ptr = alloca.get(); independent_block->insert(std::move(alloca), 0); auto local_store = Stmt::make(alloca_ptr, stmt); @@ -2378,6 +2383,7 @@ void auto_diff(IRNode *root, PromoteSSA2LocalVar::run(ib); ReplaceLocalVarWithStacks replace(config.ad_stack_size); ib->accept(&replace); + // irpass::print(ib); type_check(root, config); MakeAdjoint::run(ib); @@ -2419,7 +2425,7 @@ class GloablDataAccessRuleChecker : public BasicStmtVisitor { snode = snode->get_adjoint_checkbit(); auto global_ptr = stmt->insert_after_me(Stmt::make(snode, src->indices)); - auto dtype = global_ptr->ret_type; + auto dtype = global_ptr->ret_type.ptr_removed(); auto one = global_ptr->insert_after_me( Stmt::make(TypedConstant(dtype, 1))); one->insert_after_me(Stmt::make(global_ptr, one)); diff --git a/taichi/transforms/compile_to_offloads.cpp b/taichi/transforms/compile_to_offloads.cpp index 83acbcd3aa9c36..ac4661f3688da7 100644 --- a/taichi/transforms/compile_to_offloads.cpp +++ b/taichi/transforms/compile_to_offloads.cpp @@ -57,6 +57,7 @@ void compile_to_offloads(IRNode *ir, if (start_from_ast) { irpass::frontend_type_check(ir); + print("Frontend Typechecked"); irpass::lower_ast(ir); print("Lowered"); } diff --git a/taichi/transforms/frontend_type_check.cpp b/taichi/transforms/frontend_type_check.cpp index ae7c9af88664cf..8f652c3d4a038f 100644 --- a/taichi/transforms/frontend_type_check.cpp +++ b/taichi/transforms/frontend_type_check.cpp @@ -7,11 +7,12 @@ namespace taichi::lang { class FrontendTypeCheck : public IRVisitor { void check_cond_type(const Expr &cond, std::string stmt_name) { - if (!cond->ret_type->is() || !is_integral(cond->ret_type)) + auto cond_type = get_rvalue_dtype(cond); + if (!cond_type->is() || !is_integral(cond_type)) throw TaichiTypeError(fmt::format( "`{0}` conditions must be an integer; found {1}. Consider using " "`{0} x != 0` instead of `{0} x` for float values.", - stmt_name, cond->ret_type->to_string())); + stmt_name, cond_type->to_string())); } public: @@ -49,8 +50,8 @@ class FrontendTypeCheck : public IRVisitor { } void visit(FrontendAssignStmt *stmt) override { - auto lhs_type = stmt->lhs->ret_type; - auto rhs_type = stmt->rhs->ret_type; + auto lhs_type = stmt->lhs->ret_type.ptr_removed(); + auto rhs_type = stmt->rhs->ret_type.ptr_removed(); auto error = [&]() { throw TaichiTypeError(fmt::format("{}cannot assign '{}' to '{}'", @@ -85,7 +86,7 @@ class FrontendTypeCheck : public IRVisitor { Expr const &expr = std::get(content); TI_ASSERT(expr.expr != nullptr); - DataType data_type = expr->ret_type; + DataType data_type = get_rvalue_dtype(expr); if (data_type->is()) { data_type = DataType(data_type->as()->get_element_type()); } diff --git a/taichi/transforms/lower_ast.cpp b/taichi/transforms/lower_ast.cpp index 63d1227e4d223f..d3fab5c7663d80 100644 --- a/taichi/transforms/lower_ast.cpp +++ b/taichi/transforms/lower_ast.cpp @@ -68,15 +68,15 @@ class LowerAST : public IRVisitor { auto ident = stmt->ident; TI_ASSERT(block->local_var_to_stmt.find(ident) == block->local_var_to_stmt.end()); - if (stmt->ret_type->is()) { - auto tensor_type = stmt->ret_type->cast(); + auto alloca_type = stmt->ret_type.ptr_removed(); + if (auto tensor_type = alloca_type->cast()) { auto lowered = std::make_unique( tensor_type->get_shape(), tensor_type->get_element_type(), stmt->is_shared); block->local_var_to_stmt.insert(std::make_pair(ident, lowered.get())); stmt->parent->replace_with(stmt, std::move(lowered)); } else { - auto lowered = std::make_unique(stmt->ret_type); + auto lowered = std::make_unique(alloca_type); block->local_var_to_stmt.insert(std::make_pair(ident, lowered.get())); stmt->parent->replace_with(stmt, std::move(lowered)); } diff --git a/taichi/transforms/offload.cpp b/taichi/transforms/offload.cpp index d16030ca86c3e8..0adec11c50b870 100644 --- a/taichi/transforms/offload.cpp +++ b/taichi/transforms/offload.cpp @@ -530,16 +530,16 @@ class FixCrossOffloadReferences : public BasicStmtVisitor { if (local_to_global_offset_.find(stmt) == local_to_global_offset_.end()) return; VecStatement replacement; - auto ret_type = stmt->ret_type; - local_to_global_vector_type_[stmt] = ret_type; + auto alloca_type = stmt->ret_type.ptr_removed(); + local_to_global_vector_type_[stmt] = alloca_type; auto ptr = replacement.push_back( - local_to_global_offset_.at(stmt), ret_type); + local_to_global_offset_.at(stmt), alloca_type); auto offloaded = stmt_to_offloaded_[stmt]; stmt_to_offloaded_[ptr] = offloaded; - TypedConstant zero(stmt->ret_type.get_element_type()); + TypedConstant zero(alloca_type.get_element_type()); auto const_zero_stmt = replacement.push_back(zero); - if (auto tensor_type = stmt->ret_type->cast()) { + if (auto tensor_type = alloca_type->cast()) { std::vector zero_values(tensor_type->get_num_elements(), const_zero_stmt); auto zero_matrix_init_stmt = diff --git a/taichi/transforms/scalarize.cpp b/taichi/transforms/scalarize.cpp index 7e84c6bd401216..4a0daae6c67c12 100644 --- a/taichi/transforms/scalarize.cpp +++ b/taichi/transforms/scalarize.cpp @@ -41,7 +41,7 @@ class Scalarize : public BasicStmtVisitor { template void scalarize_store_stmt(T *stmt) { auto dest_dtype = stmt->dest->ret_type.ptr_removed(); - auto val_dtype = stmt->val->ret_type; + auto val_dtype = stmt->val->ret_type.ptr_removed(); if (dest_dtype->template is() && val_dtype->template is()) { // Needs scalarize @@ -50,6 +50,9 @@ class Scalarize : public BasicStmtVisitor { TI_ASSERT(dest_tensor_type->get_shape() == val_tensor_type->get_shape()); + // irpass::print(stmt); + // irpass::print(stmt->val); + TI_ASSERT(stmt->val->template is()); auto matrix_init_stmt = stmt->val->template as(); @@ -182,7 +185,8 @@ class Scalarize : public BasicStmtVisitor { stmt->replace_all_usages_with(tmp) */ void visit(UnaryOpStmt *stmt) override { - auto operand_dtype = stmt->operand->ret_type; + auto operand_dtype = stmt->operand->ret_type.ptr_removed(); + auto stmt_dtype = stmt->ret_type.ptr_removed(); if (operand_dtype->is()) { // Needs scalarize auto operand_tensor_type = operand_dtype->as(); @@ -195,7 +199,7 @@ class Scalarize : public BasicStmtVisitor { std::vector matrix_init_values; int num_elements = operand_tensor_type->get_num_elements(); - auto primitive_type = stmt->ret_type.get_element_type(); + auto primitive_type = stmt_dtype.get_element_type(); for (size_t i = 0; i < num_elements; i++) { auto unary_stmt = std::make_unique( stmt->op_type, operand_matrix_init_stmt->values[i]); @@ -243,8 +247,9 @@ class Scalarize : public BasicStmtVisitor { stmt->replace_all_usages_with(tmp) */ void visit(BinaryOpStmt *stmt) override { - auto lhs_dtype = stmt->lhs->ret_type; - auto rhs_dtype = stmt->rhs->ret_type; + auto lhs_dtype = stmt->lhs->ret_type.ptr_removed(); + auto rhs_dtype = stmt->rhs->ret_type.ptr_removed(); + auto stmt_dtype = stmt->ret_type.ptr_removed(); if (lhs_dtype->is() || rhs_dtype->is()) { // Make sure broadcasting has been correctly applied by // BinaryOpExpression::type_check(). @@ -267,7 +272,7 @@ class Scalarize : public BasicStmtVisitor { TI_ASSERT(rhs_vals.size() == lhs_vals.size()); size_t num_elements = lhs_vals.size(); - auto primitive_type = stmt->ret_type.get_element_type(); + auto primitive_type = stmt_dtype.get_element_type(); std::vector matrix_init_values; for (size_t i = 0; i < num_elements; i++) { auto binary_stmt = std::make_unique( @@ -495,9 +500,9 @@ class Scalarize : public BasicStmtVisitor { stmt->replace_all_usages_with(tmp) */ void visit(TernaryOpStmt *stmt) override { - auto cond_dtype = stmt->op1->ret_type; - auto op2_dtype = stmt->op2->ret_type; - auto op3_dtype = stmt->op3->ret_type; + auto cond_dtype = stmt->op1->ret_type.ptr_removed(); + auto op2_dtype = stmt->op2->ret_type.ptr_removed(); + auto op3_dtype = stmt->op3->ret_type.ptr_removed(); if (cond_dtype->is()) { // Make sure broadcasting has been correctly applied by // TernaryOpExpression::type_check(). @@ -940,7 +945,8 @@ class ScalarizePointers : public BasicStmtVisitor { for (size_t i = 0; i < tensor_type->get_num_elements(); i++) { auto scalarized_alloca_stmt = std::make_unique(primitive_type); - scalarized_alloca_stmt->ret_type = primitive_type; + scalarized_alloca_stmt->ret_type = + TypeFactory::get_instance().get_pointer_type(primitive_type); scalarized_local_tensor_map_[stmt].push_back( scalarized_alloca_stmt.get()); diff --git a/taichi/transforms/simplify.cpp b/taichi/transforms/simplify.cpp index 0bf75e1866c55c..e67334eab4df89 100644 --- a/taichi/transforms/simplify.cpp +++ b/taichi/transforms/simplify.cpp @@ -509,9 +509,28 @@ bool simplify(IRNode *root, const CompileConfig &config) { return modified; } +namespace { + +std::function +make_pass_printer(bool verbose, const std::string &kernel_name, IRNode *ir) { + if (!verbose) { + return [](const std::string &) {}; + } + return [ir, kernel_name](const std::string &pass) { + TI_INFO("[{}] {}:", kernel_name, pass); + std::cout << std::flush; + irpass::re_id(ir); + irpass::print(ir); + std::cout << std::flush; + }; +} + +} // namespace + void full_simplify(IRNode *root, const CompileConfig &config, const FullSimplifyPass::Args &args) { + auto print = make_pass_printer(false && config.print_ir, "simplify", root); TI_AUTO_PROF; if (config.advanced_optimization) { bool first_iteration = true; @@ -519,32 +538,44 @@ void full_simplify(IRNode *root, bool modified = false; if (extract_constant(root, config)) modified = true; + print("extract_constant"); if (unreachable_code_elimination(root)) modified = true; + print("unreachable_code_elimination"); if (binary_op_simplify(root, config)) modified = true; + print("binary_op_simplify"); if (config.constant_folding && constant_fold(root)) modified = true; + print("constant_fold"); if (die(root)) modified = true; + print("die"); if (alg_simp(root, config)) modified = true; + print("alg_simp"); if (loop_invariant_code_motion(root, config)) modified = true; + print("loop_invariant_code_motion"); if (die(root)) modified = true; + print("die"); if (simplify(root, config)) modified = true; + print("simplify"); if (die(root)) modified = true; + print("die"); if (config.opt_level > 0 && whole_kernel_cse(root)) modified = true; + print("whole_kernel_cse"); // Don't do this time-consuming optimization pass again if the IR is // not modified. if (config.opt_level > 0 && first_iteration && config.cfg_optimization && cfg_optimization(root, args.after_lower_access, args.autodiff_enabled, !config.real_matrix_scalarize)) modified = true; + print("cfg_optimization"); first_iteration = false; if (!modified) break; diff --git a/taichi/transforms/type_check.cpp b/taichi/transforms/type_check.cpp index 3a5bb1188c634e..d36dfd28042943 100644 --- a/taichi/transforms/type_check.cpp +++ b/taichi/transforms/type_check.cpp @@ -23,13 +23,14 @@ class TypeCheck : public IRVisitor { Stmt *&val, const std::string &stmt_name) { auto dst_type = dst->ret_type.ptr_removed(); + auto val_type = val->ret_type.ptr_removed(); if (is_quant(dst_type)) { // We force the value type to be the compute_type of the bit pointer. // Casting from compute_type to physical_type is handled in codegen. dst_type = dst_type->get_compute_type(); } - if (dst_type != val->ret_type) { - auto promoted = promoted_type(dst_type, val->ret_type); + if (dst_type != val_type) { + auto promoted = promoted_type(dst_type, val_type); if (dst_type != promoted) { TI_WARN("[{}] {} may lose precision: {} <- {}\n{}", stmt->name(), stmt_name, dst_type->to_string(), val->ret_data_type_name(), @@ -88,13 +89,15 @@ class TypeCheck : public IRVisitor { TI_ASSERT(stmt->src->is() || stmt->src->is() || stmt->src->is()); if (auto ptr_offset_stmt = stmt->src->cast()) { - auto lookup = DataType(ptr_offset_stmt->origin->ret_type->as() + auto lookup = DataType(ptr_offset_stmt->origin->ret_type.ptr_removed() + ->as() ->get_element_type()) .ptr_removed(); stmt->ret_type = lookup; } else { auto lookup = stmt->src->ret_type; - stmt->ret_type = lookup; + TI_ASSERT(lookup.is_pointer()); + stmt->ret_type = lookup->as()->get_pointee_type(); } } @@ -366,6 +369,10 @@ class TypeCheck : public IRVisitor { insert_shift_op_assertion_before(stmt, stmt->lhs, stmt->rhs); } } else { + // TI_INFO("promote type bin op"); + // irpass::print(stmt); + // irpass::print(stmt->lhs); + // irpass::print(stmt->rhs); ret_type = promoted_type(stmt->lhs->ret_type, stmt->rhs->ret_type); } diff --git a/tests/cpp/ir/frontend_type_inference_test.cpp b/tests/cpp/ir/frontend_type_inference_test.cpp index bce26256c0ac5f..c27efbdde34cea 100644 --- a/tests/cpp/ir/frontend_type_inference_test.cpp +++ b/tests/cpp/ir/frontend_type_inference_test.cpp @@ -32,7 +32,9 @@ TEST(FrontendTypeInference, Id) { auto const_i32 = value(-(1 << 20)); const_i32->type_check(nullptr); auto id_i32 = kernel->context->builder().make_var(const_i32, const_i32->tb); - EXPECT_EQ(id_i32->ret_type, PrimitiveType::i32); + EXPECT_EQ(id_i32->ret_type, + DataType(TypeFactory::get_instance().get_pointer_type( + PrimitiveType::i32))); } TEST(FrontendTypeInference, BinaryOp) { @@ -139,7 +141,9 @@ TEST(FrontendTypeInference, GlobalPtr_Field) { index->type_check(nullptr); auto global_ptr = ast_builder->expr_subscript(global_var, ExprGroup(index)); global_ptr->type_check(nullptr); - EXPECT_EQ(global_ptr->ret_type, PrimitiveType::u8); + EXPECT_EQ(global_ptr->ret_type, + DataType(TypeFactory::get_instance().get_pointer_type( + PrimitiveType::u8))); } TEST(FrontendTypeInference, GlobalPtr_ExternalTensor) { @@ -163,7 +167,8 @@ TEST(FrontendTypeInference, TensorElement) { auto kernel = std::make_unique(*prog, func, "fake_kernel"); auto *ast_builder = &kernel->context->builder(); const std::vector shape{3}; - auto var = Expr(std::make_shared(ast_builder->get_next_id())); + auto var = Expr(std::make_shared( + ast_builder->get_next_id(), StmtOpCode::FrontendAllocaStmt)); ast_builder->insert(std::make_unique( std::static_pointer_cast(var.expr)->id, shape, PrimitiveType::u32)); @@ -172,7 +177,9 @@ TEST(FrontendTypeInference, TensorElement) { index->type_check(nullptr); auto tensor_element = Expr::make(var, ExprGroup(index)); tensor_element->type_check(nullptr); - EXPECT_EQ(tensor_element->ret_type, PrimitiveType::u32); + EXPECT_EQ(tensor_element->ret_type, + DataType(TypeFactory::get_instance().get_pointer_type( + PrimitiveType::u32))); } TEST(FrontendTypeInference, AtomicOp) {