From a8f649a174331c5c770577c16e73101e69d04dcf Mon Sep 17 00:00:00 2001 From: lin-hitonami Date: Fri, 2 Jun 2023 11:15:50 +0800 Subject: [PATCH] [ir] [refactor] Let the type of Alloca be pointer ghstack-source-id: 663f7b3c9e383c290b47ddd4183b64d540b35c41 Pull Request resolved: https://github.com/taichi-dev/taichi/pull/8007 --- taichi/codegen/cuda/codegen_cuda.cpp | 4 +- taichi/codegen/llvm/codegen_llvm.cpp | 14 ++--- taichi/codegen/llvm/llvm_codegen_utils.h | 2 - taichi/codegen/spirv/spirv_codegen.cpp | 6 +- taichi/ir/control_flow_graph.cpp | 7 ++- taichi/ir/frontend_ir.cpp | 20 +++--- taichi/ir/frontend_ir.h | 6 +- taichi/ir/statements.cpp | 1 + taichi/ir/statements.h | 9 ++- taichi/ir/type_factory.cpp | 2 +- taichi/transforms/auto_diff.cpp | 61 ++++++++++--------- taichi/transforms/frontend_type_check.cpp | 11 ++-- taichi/transforms/lower_ast.cpp | 6 +- taichi/transforms/offload.cpp | 10 +-- taichi/transforms/scalarize.cpp | 23 ++++--- taichi/transforms/type_check.cpp | 10 +-- tests/cpp/ir/frontend_type_inference_test.cpp | 12 +++- tests/python/test_unary_ops.py | 2 +- 18 files changed, 116 insertions(+), 90 deletions(-) diff --git a/taichi/codegen/cuda/codegen_cuda.cpp b/taichi/codegen/cuda/codegen_cuda.cpp index 9ffb507b3905a1..9b87b5a91f0638 100644 --- a/taichi/codegen/cuda/codegen_cuda.cpp +++ b/taichi/codegen/cuda/codegen_cuda.cpp @@ -172,8 +172,8 @@ class TaskCodeGenCUDA : public TaskCodeGenLLVM { void visit(AllocaStmt *stmt) override { // Override shared memory codegen logic for large shared memory - if (stmt->ret_type->is() && stmt->is_shared) { - auto tensor_type = stmt->ret_type->cast(); + auto tensor_type = stmt->ret_type.ptr_removed()->cast(); + if (tensor_type && stmt->is_shared) { size_t shared_array_bytes = tensor_type->get_num_elements() * data_type_size(tensor_type->get_element_type()); diff --git a/taichi/codegen/llvm/codegen_llvm.cpp b/taichi/codegen/llvm/codegen_llvm.cpp index 03fa92e111edd4..782a3daa401291 100644 --- a/taichi/codegen/llvm/codegen_llvm.cpp +++ b/taichi/codegen/llvm/codegen_llvm.cpp @@ -126,8 +126,10 @@ void TaskCodeGenLLVM::visit(Block *stmt_list) { } void TaskCodeGenLLVM::visit(AllocaStmt *stmt) { - if (stmt->ret_type->is()) { - auto tensor_type = stmt->ret_type->cast(); + TI_ASSERT(stmt->ret_type.is_pointer()); + auto alloca_type = stmt->ret_type.ptr_removed(); + if (alloca_type->is()) { + auto tensor_type = alloca_type->cast(); auto type = tlctx->get_data_type(tensor_type); if (stmt->is_shared) { auto base = new llvm::GlobalVariable( @@ -141,12 +143,10 @@ void TaskCodeGenLLVM::visit(AllocaStmt *stmt) { llvm_val[stmt] = create_entry_block_alloca(type); } } else { - llvm_val[stmt] = - create_entry_block_alloca(stmt->ret_type, stmt->ret_type.is_pointer()); + llvm_val[stmt] = create_entry_block_alloca(alloca_type); // initialize as zero if element is not a pointer - if (!stmt->ret_type.is_pointer()) - builder->CreateStore(tlctx->get_constant(stmt->ret_type, 0), - llvm_val[stmt]); + if (!alloca_type->is()) + builder->CreateStore(tlctx->get_constant(alloca_type, 0), llvm_val[stmt]); } } diff --git a/taichi/codegen/llvm/llvm_codegen_utils.h b/taichi/codegen/llvm/llvm_codegen_utils.h index 51b08efc52a49e..b9915edb487f37 100644 --- a/taichi/codegen/llvm/llvm_codegen_utils.h +++ b/taichi/codegen/llvm/llvm_codegen_utils.h @@ -83,8 +83,6 @@ class LLVMModuleBuilder { llvm::Value *create_entry_block_alloca(DataType dt, bool is_pointer = false) { auto type = tlctx->get_data_type(dt); - if (is_pointer) - type = llvm::PointerType::get(type, 0); return create_entry_block_alloca(type); } diff --git a/taichi/codegen/spirv/spirv_codegen.cpp b/taichi/codegen/spirv/spirv_codegen.cpp index 700434684d7714..5326c7df897db3 100644 --- a/taichi/codegen/spirv/spirv_codegen.cpp +++ b/taichi/codegen/spirv/spirv_codegen.cpp @@ -274,8 +274,8 @@ class TaskCodegen : public IRVisitor { void visit(AllocaStmt *alloca) override { spirv::Value ptr_val; - if (alloca->ret_type->is()) { - auto tensor_type = alloca->ret_type->cast(); + auto alloca_type = alloca->ret_type.ptr_removed(); + if (auto tensor_type = alloca_type->cast()) { auto elem_num = tensor_type->get_num_elements(); spirv::SType elem_type = ir_->get_primitive_type(tensor_type->get_element_type()); @@ -288,7 +288,7 @@ class TaskCodegen : public IRVisitor { } } else { // Alloca for a single variable - spirv::SType src_type = ir_->get_primitive_type(alloca->element_type()); + spirv::SType src_type = ir_->get_primitive_type(alloca_type); ptr_val = ir_->alloca_variable(src_type); ir_->store_variable(ptr_val, ir_->get_zero(src_type)); } diff --git a/taichi/ir/control_flow_graph.cpp b/taichi/ir/control_flow_graph.cpp index d546900f83724f..4939480674f47c 100644 --- a/taichi/ir/control_flow_graph.cpp +++ b/taichi/ir/control_flow_graph.cpp @@ -247,7 +247,7 @@ Stmt *CFGNode::get_store_forwarding_data(Stmt *var, int position) const { // result: the value to store Stmt *result = irpass::analysis::get_store_data( block->statements[last_def_position].get()); - bool is_tensor_involved = var->ret_type->is(); + bool is_tensor_involved = var->ret_type.ptr_removed()->is(); if (!(var->is() && !is_tensor_involved)) { // In between the store stmt and current stmt, // if there's a third-stmt that "may" have stored a "different value" to @@ -355,7 +355,7 @@ Stmt *CFGNode::get_store_forwarding_data(Stmt *var, int position) const { // Check for aliased address // There's a store to the same dest_addr before this stmt - bool is_tensor_involved = var->ret_type->is(); + bool is_tensor_involved = var->ret_type.ptr_removed()->is(); if (!(var->is() && !is_tensor_involved)) { // In between the store stmt and current stmt, // if there's a third-stmt that "may" have stored a "different value" to @@ -443,7 +443,8 @@ bool CFGNode::store_to_load_forwarding(bool after_lower_access, continue; // special case of alloca (initialized to 0) - auto zero = Stmt::make(TypedConstant(result->ret_type, 0)); + auto zero = Stmt::make( + TypedConstant(result->ret_type.ptr_removed(), 0)); replace_with(i, std::move(zero), true); } else { if (result->ret_type.ptr_removed()->is() && diff --git a/taichi/ir/frontend_ir.cpp b/taichi/ir/frontend_ir.cpp index fb4a2721679c48..cb41c9ab5ad561 100644 --- a/taichi/ir/frontend_ir.cpp +++ b/taichi/ir/frontend_ir.cpp @@ -37,7 +37,8 @@ FrontendAssignStmt::FrontendAssignStmt(const Expr &lhs, const Expr &rhs) : lhs(lhs), rhs(rhs) { TI_ASSERT(lhs->is_lvalue()); if (lhs.is() && lhs->ret_type == PrimitiveType::unknown) { - lhs.expr->ret_type = rhs.get_rvalue_type(); + lhs.expr->ret_type = + TypeFactory::get_instance().get_pointer_type(rhs.get_rvalue_type()); } } @@ -127,7 +128,8 @@ void FrontendForStmt::init_loop_vars(const ExprGroup &loop_vars) { void FrontendForStmt::add_loop_var(const Expr &loop_var) { loop_var_ids.push_back(loop_var.cast()->id); - loop_var.expr->ret_type = PrimitiveType::i32; + loop_var.expr->ret_type = + TypeFactory::get_instance().get_pointer_type(PrimitiveType::i32); } FrontendFuncDefStmt::FrontendFuncDefStmt(const FrontendFuncDefStmt &o) @@ -883,7 +885,7 @@ void IndexExpression::type_check(const CompileConfig *) { "Invalid IndexExpression: the source is not among field, ndarray or " "local tensor"); } - + ret_type = TypeFactory::get_instance().get_pointer_type(ret_type); for (auto &indices : indices_group) { for (int i = 0; i < indices.exprs.size(); i++) { auto &expr = indices.exprs[i]; @@ -960,7 +962,7 @@ void LoopUniqueExpression::flatten(FlattenContext *ctx) { void IdExpression::flatten(FlattenContext *ctx) { stmt = ctx->current_block->lookup_var(id); - if (!ret_type->is_primitive(PrimitiveTypeID::unknown)) { + if (stmt->ret_type->is_primitive(PrimitiveTypeID::unknown)) { stmt->ret_type = ret_type; } } @@ -1514,7 +1516,7 @@ Expr ASTBuilder::expr_subscript(const Expr &expr, std::string tb) { TI_ASSERT(expr.is() || expr.is() || expr.is() || - is_tensor(expr.expr->ret_type)); + is_tensor(expr.expr->ret_type.ptr_removed())); // IndexExpression without ret_shape is used for matrix indexing, // where each entry of ExprGroup is interpreted as indexing into a specific @@ -1677,7 +1679,7 @@ std::vector ASTBuilder::expand_exprs(const std::vector &exprs) { elem.expr->ret_type = struct_type->get_element_type(indices); expanded_exprs.push_back(elem); } - } else if (!expr->ret_type->is()) { + } else if (!expr->ret_type.ptr_removed()->is()) { expanded_exprs.push_back(expr); } else { // Expand TensorType expr @@ -1695,7 +1697,7 @@ std::vector ASTBuilder::expand_exprs(const std::vector &exprs) { return {ind0, ind1, ind2, ind3} */ - auto tensor_type = expr->ret_type->cast(); + auto tensor_type = expr->ret_type.ptr_removed()->cast(); Expr id_expr; if (expr.is()) { @@ -1708,7 +1710,7 @@ std::vector ASTBuilder::expand_exprs(const std::vector &exprs) { for (int i = 0; i < shape[0]; i++) { auto ind = Expr(std::make_shared( id_expr, ExprGroup(Expr(i)), expr->tb)); - ind.expr->ret_type = tensor_type->get_element_type(); + ind->type_check(nullptr); expanded_exprs.push_back(ind); } } else { @@ -1717,7 +1719,7 @@ std::vector ASTBuilder::expand_exprs(const std::vector &exprs) { for (int j = 0; j < shape[1]; j++) { auto ind = Expr(std::make_shared( id_expr, ExprGroup(Expr(i), Expr(j)), expr->tb)); - ind.expr->ret_type = tensor_type->get_element_type(); + ind->type_check(nullptr); expanded_exprs.push_back(ind); } } diff --git a/taichi/ir/frontend_ir.h b/taichi/ir/frontend_ir.h index dc9f6fad7d7da0..35ecf42241e6d0 100644 --- a/taichi/ir/frontend_ir.h +++ b/taichi/ir/frontend_ir.h @@ -86,7 +86,8 @@ class FrontendAllocaStmt : public Stmt { DataType element, bool is_shared = false) : ident(lhs), is_shared(is_shared) { - ret_type = DataType(TypeFactory::create_tensor_type(shape, element)); + ret_type = TypeFactory::get_instance().get_pointer_type( + DataType(TypeFactory::create_tensor_type(shape, element))); } bool is_shared; @@ -500,6 +501,7 @@ class ExternalTensorExpression : public Expression { void type_check(const CompileConfig *config) override { ret_type = dt; + ret_type.set_is_pointer(true); config_ = config; } @@ -585,7 +587,7 @@ class MatrixExpression : public Expression { std::vector shape, DataType element_type) : elements(elements) { - this->dt = DataType(TypeFactory::create_tensor_type(shape, element_type)); + dt = TypeFactory::create_tensor_type(shape, element_type); } void type_check(const CompileConfig *config) override; diff --git a/taichi/ir/statements.cpp b/taichi/ir/statements.cpp index 88ed18f0cea3b4..c01691abf4434b 100644 --- a/taichi/ir/statements.cpp +++ b/taichi/ir/statements.cpp @@ -86,6 +86,7 @@ MatrixOfMatrixPtrStmt::MatrixOfMatrixPtrStmt(const std::vector &stmts, DataType dt) : stmts(stmts) { ret_type = dt; + ret_type.set_is_pointer(true); TI_STMT_REG_FIELDS; } diff --git a/taichi/ir/statements.h b/taichi/ir/statements.h index b84122ae3c3476..935124d6ee0158 100644 --- a/taichi/ir/statements.h +++ b/taichi/ir/statements.h @@ -19,7 +19,11 @@ class Function; class AllocaStmt : public Stmt, public ir_traits::Store { public: explicit AllocaStmt(DataType type) : is_shared(false) { - ret_type = type; + if (type->is_primitive(PrimitiveTypeID::unknown)) { + ret_type = type; + } else { + ret_type = TypeFactory::get_instance().get_pointer_type(type); + } TI_STMT_REG_FIELDS; } @@ -27,7 +31,8 @@ class AllocaStmt : public Stmt, public ir_traits::Store { DataType type, bool is_shared = false) : is_shared(is_shared) { - ret_type = TypeFactory::create_tensor_type(shape, type); + ret_type = TypeFactory::get_instance().get_pointer_type( + TypeFactory::create_tensor_type(shape, type)); TI_STMT_REG_FIELDS; } diff --git a/taichi/ir/type_factory.cpp b/taichi/ir/type_factory.cpp index 08a7fde586942d..c40d588dc248f2 100644 --- a/taichi/ir/type_factory.cpp +++ b/taichi/ir/type_factory.cpp @@ -222,7 +222,7 @@ static bool compare_types(DataType x, DataType y) { static DataType to_primitive_type(DataType d) { if (d->is()) { d = d->as()->get_pointee_type(); - TI_WARN("promoted_type got a pointer input."); + TI_ERROR("promoted_type got a pointer input."); } if (d->is()) { diff --git a/taichi/transforms/auto_diff.cpp b/taichi/transforms/auto_diff.cpp index 4cda1ebf5eca22..fd9166c09f34ec 100644 --- a/taichi/transforms/auto_diff.cpp +++ b/taichi/transforms/auto_diff.cpp @@ -15,23 +15,24 @@ Stmt *insert_const(const DataType &dtype, Stmt *stmt, const T &value, bool insert_before_me = false) { + auto type = dtype.ptr_removed(); Stmt *zero = nullptr; if (insert_before_me) zero = stmt->insert_before_me( - Stmt::make(TypedConstant(dtype.get_element_type(), value))); + Stmt::make(TypedConstant(type.get_element_type(), value))); else zero = stmt->insert_after_me( - Stmt::make(TypedConstant(dtype.get_element_type(), value))); + Stmt::make(TypedConstant(type.get_element_type(), value))); - if (dtype->is()) { - auto t_dtype = dtype->as(); + if (type->is()) { + auto t_dtype = type->as(); std::vector values(t_dtype->get_num_elements(), zero); if (insert_before_me) { zero = zero->insert_before_me(Stmt::make(values)); } else { zero = zero->insert_after_me(Stmt::make(values)); } - zero->ret_type = dtype; + zero->ret_type = type; } return zero; } @@ -344,7 +345,8 @@ class PromoteSSA2LocalVar : public BasicStmtVisitor { if (stmt->is()) { // Create a new alloc at the top of an ib to replace the old alloca - auto alloc = Stmt::make(stmt->ret_type); + auto dtype = stmt->ret_type.ptr_removed(); + auto alloc = Stmt::make(dtype); auto alloc_ptr = alloc.get(); TI_ASSERT(alloca_block_); alloca_block_->insert(std::move(alloc), 0); @@ -354,7 +356,6 @@ class PromoteSSA2LocalVar : public BasicStmtVisitor { // Replace the old alloca with a local store // and it will be replaced by a AdStackPushStmt in the following // ReplaceLocalVarWithStacks pass - auto dtype = stmt->ret_type; auto zero = insert_const(dtype, stmt, 0); zero->insert_after_me(Stmt::make(alloc_ptr, zero)); @@ -362,7 +363,7 @@ class PromoteSSA2LocalVar : public BasicStmtVisitor { stmt->parent->erase(stmt); } else { // Create a alloc - auto alloc = Stmt::make(stmt->ret_type); + auto alloc = Stmt::make(stmt->ret_type.ptr_removed()); auto alloc_ptr = alloc.get(); TI_ASSERT(alloca_block_); alloca_block_->insert(std::move(alloc), 0); @@ -696,7 +697,7 @@ class ReplaceLocalVarWithStacks : public BasicStmtVisitor { void visit(AllocaStmt *alloc) override { bool is_stack_needed = AdStackAllocaJudger::run(alloc); if (is_stack_needed) { - auto dtype = alloc->ret_type; + auto dtype = alloc->ret_type.ptr_removed(); auto stack_alloca = Stmt::make(dtype, ad_stack_size); auto stack_alloca_ptr = stack_alloca.get(); @@ -937,6 +938,7 @@ class ReverseOuterLoops : public BasicStmtVisitor { class ADTransform : public IRVisitor { protected: Stmt *constant(float32 x, DataType dtype = PrimitiveType::unknown) { + dtype.set_is_pointer(false); if (!dtype->is()) return insert(TypedConstant(x)); @@ -1024,12 +1026,13 @@ class ADTransform : public IRVisitor { template Stmt *insert_const_for_grad(const DataType &dtype, Stmt *stmt, const T &val) { - auto zero = insert(TypedConstant(dtype.get_element_type(), val)); - if (dtype->is()) { - auto t_dtype = dtype->as(); + auto zero = insert( + TypedConstant(dtype.ptr_removed().get_element_type(), val)); + if (dtype.ptr_removed()->is()) { + auto t_dtype = dtype.ptr_removed()->as(); std::vector values(t_dtype->get_num_elements(), zero); zero = insert(values); - zero->ret_type = dtype; + zero->ret_type = dtype.ptr_removed(); } return zero; } @@ -1209,16 +1212,16 @@ class MakeAdjoint : public ADTransform { Stmt *adjoint(Stmt *stmt) { DataType adjoint_dtype = stmt->ret_type.ptr_removed(); - if (stmt->ret_type->is()) { + if (adjoint_dtype->is()) { DataType prim_dtype = PrimitiveType::f32; - if (is_real(stmt->ret_type.ptr_removed().get_element_type())) { - prim_dtype = stmt->ret_type.ptr_removed().get_element_type(); + if (is_real(adjoint_dtype.get_element_type())) { + prim_dtype = adjoint_dtype.get_element_type(); } adjoint_dtype = TypeFactory::get_instance().get_tensor_type( - stmt->ret_type->as()->get_shape(), prim_dtype); + adjoint_dtype->as()->get_shape(), prim_dtype); } else if (stmt->is()) { // pass - } else if (!is_real(stmt->ret_type) || stmt->is()) { + } else if (!is_real(adjoint_dtype) || stmt->is()) { return constant(0); } @@ -1482,8 +1485,9 @@ class MakeAdjoint : public ADTransform { // iteration should be cleared after this iteration has been done // 2. If the alloca serves as the dest of multiple LocalStoreStmt, only the // last LocalStoreStmt should be taken account of - if (is_real(stmt->dest->ret_type.get_element_type())) { - auto dtype = stmt->dest->ret_type; + auto dest_type = stmt->dest->ret_type.ptr_removed(); + if (is_real(dest_type.get_element_type())) { + auto dtype = dest_type; auto zero = insert_const_for_grad(dtype, stmt, 0); insert(adjoint(stmt->dest), zero); } @@ -1748,7 +1752,7 @@ class MakeAdjoint : public ADTransform { */ int offset = stmt->offset->as()->val.val_int32(); - auto tensor_type = stmt->origin->ret_type->as(); + auto tensor_type = stmt->origin->ret_type.ptr_removed()->as(); int num_elements = tensor_type->get_num_elements(); auto zero = insert_const_for_grad(prim_dtype, stmt, 0); @@ -1783,7 +1787,7 @@ class MakeAdjoint : public ADTransform { accumulate($0_adj, $7) */ - auto tensor_type = stmt->origin->ret_type->as(); + auto tensor_type = stmt->origin->ret_type.ptr_removed()->as(); auto tensor_shape = tensor_type->get_shape(); int num_elements = tensor_type->get_num_elements(); @@ -1894,7 +1898,8 @@ class MakeDual : public ADTransform { } Stmt *dual(Stmt *stmt) { - if (!is_real(stmt->ret_type.get_element_type()) || stmt->is()) { + auto dual_type = stmt->ret_type.ptr_removed(); + if (!is_real(dual_type.get_element_type()) || stmt->is()) { return constant(0); } if (dual_stmt.find(stmt) == dual_stmt.end()) { @@ -1904,7 +1909,7 @@ class MakeDual : public ADTransform { // auto alloca = // Stmt::make(get_current_program().config.gradient_dt); // maybe it's better to use the statement data type than the default type - auto alloca = Stmt::make(stmt->ret_type); + auto alloca = Stmt::make(dual_type); dual_stmt[stmt] = alloca.get(); // TODO: check whether there are any edge cases for the alloca_block @@ -2074,8 +2079,8 @@ class MakeDual : public ADTransform { // If the alloca serves as the dest of multiple LocalStoreStmt, only the // last LocalStoreStmt should be taken account of, i.e, its history should // be cleared - if (is_real(stmt->dest->ret_type.get_element_type())) { - auto dtype = stmt->dest->ret_type; + auto dtype = stmt->dest->ret_type.ptr_removed(); + if (is_real(dtype.get_element_type())) { auto zero = insert_const_for_grad(dtype, stmt, 0); insert(dual(stmt->dest), zero); } @@ -2200,7 +2205,7 @@ class BackupSSA : public BasicStmtVisitor { Stmt *load(Stmt *stmt) { if (backup_alloca.find(stmt) == backup_alloca.end()) { - auto alloca = Stmt::make(stmt->ret_type); + auto alloca = Stmt::make(stmt->ret_type.ptr_removed()); auto alloca_ptr = alloca.get(); independent_block->insert(std::move(alloca), 0); auto local_store = Stmt::make(alloca_ptr, stmt); @@ -2419,7 +2424,7 @@ class GloablDataAccessRuleChecker : public BasicStmtVisitor { snode = snode->get_adjoint_checkbit(); auto global_ptr = stmt->insert_after_me(Stmt::make(snode, src->indices)); - auto dtype = global_ptr->ret_type; + auto dtype = global_ptr->ret_type.ptr_removed(); auto one = global_ptr->insert_after_me( Stmt::make(TypedConstant(dtype, 1))); one->insert_after_me(Stmt::make(global_ptr, one)); diff --git a/taichi/transforms/frontend_type_check.cpp b/taichi/transforms/frontend_type_check.cpp index ae7c9af88664cf..4a75b6683ac8d9 100644 --- a/taichi/transforms/frontend_type_check.cpp +++ b/taichi/transforms/frontend_type_check.cpp @@ -7,11 +7,12 @@ namespace taichi::lang { class FrontendTypeCheck : public IRVisitor { void check_cond_type(const Expr &cond, std::string stmt_name) { - if (!cond->ret_type->is() || !is_integral(cond->ret_type)) + auto cond_type = cond.get_rvalue_type(); + if (!cond_type->is() || !is_integral(cond_type)) throw TaichiTypeError(fmt::format( "`{0}` conditions must be an integer; found {1}. Consider using " "`{0} x != 0` instead of `{0} x` for float values.", - stmt_name, cond->ret_type->to_string())); + stmt_name, cond_type->to_string())); } public: @@ -49,8 +50,8 @@ class FrontendTypeCheck : public IRVisitor { } void visit(FrontendAssignStmt *stmt) override { - auto lhs_type = stmt->lhs->ret_type; - auto rhs_type = stmt->rhs->ret_type; + auto lhs_type = stmt->lhs->ret_type.ptr_removed(); + auto rhs_type = stmt->rhs->ret_type.ptr_removed(); auto error = [&]() { throw TaichiTypeError(fmt::format("{}cannot assign '{}' to '{}'", @@ -85,7 +86,7 @@ class FrontendTypeCheck : public IRVisitor { Expr const &expr = std::get(content); TI_ASSERT(expr.expr != nullptr); - DataType data_type = expr->ret_type; + DataType data_type = expr.get_rvalue_type(); if (data_type->is()) { data_type = DataType(data_type->as()->get_element_type()); } diff --git a/taichi/transforms/lower_ast.cpp b/taichi/transforms/lower_ast.cpp index 63d1227e4d223f..d3fab5c7663d80 100644 --- a/taichi/transforms/lower_ast.cpp +++ b/taichi/transforms/lower_ast.cpp @@ -68,15 +68,15 @@ class LowerAST : public IRVisitor { auto ident = stmt->ident; TI_ASSERT(block->local_var_to_stmt.find(ident) == block->local_var_to_stmt.end()); - if (stmt->ret_type->is()) { - auto tensor_type = stmt->ret_type->cast(); + auto alloca_type = stmt->ret_type.ptr_removed(); + if (auto tensor_type = alloca_type->cast()) { auto lowered = std::make_unique( tensor_type->get_shape(), tensor_type->get_element_type(), stmt->is_shared); block->local_var_to_stmt.insert(std::make_pair(ident, lowered.get())); stmt->parent->replace_with(stmt, std::move(lowered)); } else { - auto lowered = std::make_unique(stmt->ret_type); + auto lowered = std::make_unique(alloca_type); block->local_var_to_stmt.insert(std::make_pair(ident, lowered.get())); stmt->parent->replace_with(stmt, std::move(lowered)); } diff --git a/taichi/transforms/offload.cpp b/taichi/transforms/offload.cpp index d16030ca86c3e8..0adec11c50b870 100644 --- a/taichi/transforms/offload.cpp +++ b/taichi/transforms/offload.cpp @@ -530,16 +530,16 @@ class FixCrossOffloadReferences : public BasicStmtVisitor { if (local_to_global_offset_.find(stmt) == local_to_global_offset_.end()) return; VecStatement replacement; - auto ret_type = stmt->ret_type; - local_to_global_vector_type_[stmt] = ret_type; + auto alloca_type = stmt->ret_type.ptr_removed(); + local_to_global_vector_type_[stmt] = alloca_type; auto ptr = replacement.push_back( - local_to_global_offset_.at(stmt), ret_type); + local_to_global_offset_.at(stmt), alloca_type); auto offloaded = stmt_to_offloaded_[stmt]; stmt_to_offloaded_[ptr] = offloaded; - TypedConstant zero(stmt->ret_type.get_element_type()); + TypedConstant zero(alloca_type.get_element_type()); auto const_zero_stmt = replacement.push_back(zero); - if (auto tensor_type = stmt->ret_type->cast()) { + if (auto tensor_type = alloca_type->cast()) { std::vector zero_values(tensor_type->get_num_elements(), const_zero_stmt); auto zero_matrix_init_stmt = diff --git a/taichi/transforms/scalarize.cpp b/taichi/transforms/scalarize.cpp index 796a9939e491e2..1687b7a48bba3f 100644 --- a/taichi/transforms/scalarize.cpp +++ b/taichi/transforms/scalarize.cpp @@ -44,7 +44,7 @@ class Scalarize : public BasicStmtVisitor { template void scalarize_store_stmt(T *stmt) { auto dest_dtype = stmt->dest->ret_type.ptr_removed(); - auto val_dtype = stmt->val->ret_type; + auto val_dtype = stmt->val->ret_type.ptr_removed(); if (dest_dtype->template is() && val_dtype->template is()) { // Needs scalarize @@ -185,7 +185,8 @@ class Scalarize : public BasicStmtVisitor { stmt->replace_all_usages_with(tmp) */ void visit(UnaryOpStmt *stmt) override { - auto operand_dtype = stmt->operand->ret_type; + auto operand_dtype = stmt->operand->ret_type.ptr_removed(); + auto stmt_dtype = stmt->ret_type.ptr_removed(); if (operand_dtype->is()) { // Needs scalarize auto operand_tensor_type = operand_dtype->as(); @@ -198,7 +199,7 @@ class Scalarize : public BasicStmtVisitor { std::vector matrix_init_values; int num_elements = operand_tensor_type->get_num_elements(); - auto primitive_type = stmt->ret_type.get_element_type(); + auto primitive_type = stmt_dtype.get_element_type(); for (size_t i = 0; i < num_elements; i++) { auto unary_stmt = std::make_unique( stmt->op_type, operand_matrix_init_stmt->values[i]); @@ -246,8 +247,9 @@ class Scalarize : public BasicStmtVisitor { stmt->replace_all_usages_with(tmp) */ void visit(BinaryOpStmt *stmt) override { - auto lhs_dtype = stmt->lhs->ret_type; - auto rhs_dtype = stmt->rhs->ret_type; + auto lhs_dtype = stmt->lhs->ret_type.ptr_removed(); + auto rhs_dtype = stmt->rhs->ret_type.ptr_removed(); + auto stmt_dtype = stmt->ret_type.ptr_removed(); if (lhs_dtype->is() || rhs_dtype->is()) { // Make sure broadcasting has been correctly applied by // BinaryOpExpression::type_check(). @@ -270,7 +272,7 @@ class Scalarize : public BasicStmtVisitor { TI_ASSERT(rhs_vals.size() == lhs_vals.size()); size_t num_elements = lhs_vals.size(); - auto primitive_type = stmt->ret_type.get_element_type(); + auto primitive_type = stmt_dtype.get_element_type(); std::vector matrix_init_values; for (size_t i = 0; i < num_elements; i++) { auto binary_stmt = std::make_unique( @@ -581,9 +583,9 @@ class Scalarize : public BasicStmtVisitor { stmt->replace_all_usages_with(tmp) */ void visit(TernaryOpStmt *stmt) override { - auto cond_dtype = stmt->op1->ret_type; - auto op2_dtype = stmt->op2->ret_type; - auto op3_dtype = stmt->op3->ret_type; + auto cond_dtype = stmt->op1->ret_type.ptr_removed(); + auto op2_dtype = stmt->op2->ret_type.ptr_removed(); + auto op3_dtype = stmt->op3->ret_type.ptr_removed(); if (cond_dtype->is()) { // Make sure broadcasting has been correctly applied by // TernaryOpExpression::type_check(). @@ -1026,7 +1028,8 @@ class ScalarizePointers : public BasicStmtVisitor { for (size_t i = 0; i < tensor_type->get_num_elements(); i++) { auto scalarized_alloca_stmt = std::make_unique(primitive_type); - scalarized_alloca_stmt->ret_type = primitive_type; + scalarized_alloca_stmt->ret_type = + TypeFactory::get_instance().get_pointer_type(primitive_type); scalarized_local_tensor_map_[stmt].push_back( scalarized_alloca_stmt.get()); diff --git a/taichi/transforms/type_check.cpp b/taichi/transforms/type_check.cpp index 3a5bb1188c634e..859f5fb61432a9 100644 --- a/taichi/transforms/type_check.cpp +++ b/taichi/transforms/type_check.cpp @@ -23,13 +23,14 @@ class TypeCheck : public IRVisitor { Stmt *&val, const std::string &stmt_name) { auto dst_type = dst->ret_type.ptr_removed(); + auto val_type = val->ret_type.ptr_removed(); if (is_quant(dst_type)) { // We force the value type to be the compute_type of the bit pointer. // Casting from compute_type to physical_type is handled in codegen. dst_type = dst_type->get_compute_type(); } - if (dst_type != val->ret_type) { - auto promoted = promoted_type(dst_type, val->ret_type); + if (dst_type != val_type) { + auto promoted = promoted_type(dst_type, val_type); if (dst_type != promoted) { TI_WARN("[{}] {} may lose precision: {} <- {}\n{}", stmt->name(), stmt_name, dst_type->to_string(), val->ret_data_type_name(), @@ -88,13 +89,14 @@ class TypeCheck : public IRVisitor { TI_ASSERT(stmt->src->is() || stmt->src->is() || stmt->src->is()); if (auto ptr_offset_stmt = stmt->src->cast()) { - auto lookup = DataType(ptr_offset_stmt->origin->ret_type->as() + auto lookup = DataType(ptr_offset_stmt->origin->ret_type.ptr_removed() + ->as() ->get_element_type()) .ptr_removed(); stmt->ret_type = lookup; } else { auto lookup = stmt->src->ret_type; - stmt->ret_type = lookup; + stmt->ret_type = lookup.ptr_removed(); } } diff --git a/tests/cpp/ir/frontend_type_inference_test.cpp b/tests/cpp/ir/frontend_type_inference_test.cpp index bce26256c0ac5f..dd21a5ec7b92ce 100644 --- a/tests/cpp/ir/frontend_type_inference_test.cpp +++ b/tests/cpp/ir/frontend_type_inference_test.cpp @@ -32,7 +32,9 @@ TEST(FrontendTypeInference, Id) { auto const_i32 = value(-(1 << 20)); const_i32->type_check(nullptr); auto id_i32 = kernel->context->builder().make_var(const_i32, const_i32->tb); - EXPECT_EQ(id_i32->ret_type, PrimitiveType::i32); + EXPECT_EQ(id_i32->ret_type, + DataType(TypeFactory::get_instance().get_pointer_type( + PrimitiveType::i32))); } TEST(FrontendTypeInference, BinaryOp) { @@ -139,7 +141,9 @@ TEST(FrontendTypeInference, GlobalPtr_Field) { index->type_check(nullptr); auto global_ptr = ast_builder->expr_subscript(global_var, ExprGroup(index)); global_ptr->type_check(nullptr); - EXPECT_EQ(global_ptr->ret_type, PrimitiveType::u8); + EXPECT_EQ(global_ptr->ret_type, + DataType(TypeFactory::get_instance().get_pointer_type( + PrimitiveType::u8))); } TEST(FrontendTypeInference, GlobalPtr_ExternalTensor) { @@ -172,7 +176,9 @@ TEST(FrontendTypeInference, TensorElement) { index->type_check(nullptr); auto tensor_element = Expr::make(var, ExprGroup(index)); tensor_element->type_check(nullptr); - EXPECT_EQ(tensor_element->ret_type, PrimitiveType::u32); + EXPECT_EQ(tensor_element->ret_type, + DataType(TypeFactory::get_instance().get_pointer_type( + PrimitiveType::u32))); } TEST(FrontendTypeInference, AtomicOp) { diff --git a/tests/python/test_unary_ops.py b/tests/python/test_unary_ops.py index 2dc7f4c68310ef..6fe014bee872b4 100644 --- a/tests/python/test_unary_ops.py +++ b/tests/python/test_unary_ops.py @@ -78,7 +78,7 @@ def test(x: ti.f32) -> ti.i32: @test_utils.test(arch=[ti.cuda, ti.vulkan, ti.opengl, ti.metal]) -def _test_frexp(): # Fails in this PR, but will be fixed in the last PR of this series +def test_frexp(): @ti.kernel def get_frac(x: ti.f32) -> ti.f32: a, b = ti.frexp(x)