diff --git a/taichi/ir/control_flow_graph.cpp b/taichi/ir/control_flow_graph.cpp index d546900f83724f..4939480674f47c 100644 --- a/taichi/ir/control_flow_graph.cpp +++ b/taichi/ir/control_flow_graph.cpp @@ -247,7 +247,7 @@ Stmt *CFGNode::get_store_forwarding_data(Stmt *var, int position) const { // result: the value to store Stmt *result = irpass::analysis::get_store_data( block->statements[last_def_position].get()); - bool is_tensor_involved = var->ret_type->is(); + bool is_tensor_involved = var->ret_type.ptr_removed()->is(); if (!(var->is() && !is_tensor_involved)) { // In between the store stmt and current stmt, // if there's a third-stmt that "may" have stored a "different value" to @@ -355,7 +355,7 @@ Stmt *CFGNode::get_store_forwarding_data(Stmt *var, int position) const { // Check for aliased address // There's a store to the same dest_addr before this stmt - bool is_tensor_involved = var->ret_type->is(); + bool is_tensor_involved = var->ret_type.ptr_removed()->is(); if (!(var->is() && !is_tensor_involved)) { // In between the store stmt and current stmt, // if there's a third-stmt that "may" have stored a "different value" to @@ -443,7 +443,8 @@ bool CFGNode::store_to_load_forwarding(bool after_lower_access, continue; // special case of alloca (initialized to 0) - auto zero = Stmt::make(TypedConstant(result->ret_type, 0)); + auto zero = Stmt::make( + TypedConstant(result->ret_type.ptr_removed(), 0)); replace_with(i, std::move(zero), true); } else { if (result->ret_type.ptr_removed()->is() && diff --git a/taichi/ir/frontend_ir.cpp b/taichi/ir/frontend_ir.cpp index fb4a2721679c48..cb41c9ab5ad561 100644 --- a/taichi/ir/frontend_ir.cpp +++ b/taichi/ir/frontend_ir.cpp @@ -37,7 +37,8 @@ FrontendAssignStmt::FrontendAssignStmt(const Expr &lhs, const Expr &rhs) : lhs(lhs), rhs(rhs) { TI_ASSERT(lhs->is_lvalue()); if (lhs.is() && lhs->ret_type == PrimitiveType::unknown) { - lhs.expr->ret_type = rhs.get_rvalue_type(); + lhs.expr->ret_type = + TypeFactory::get_instance().get_pointer_type(rhs.get_rvalue_type()); } } @@ -127,7 +128,8 @@ void FrontendForStmt::init_loop_vars(const ExprGroup &loop_vars) { void FrontendForStmt::add_loop_var(const Expr &loop_var) { loop_var_ids.push_back(loop_var.cast()->id); - loop_var.expr->ret_type = PrimitiveType::i32; + loop_var.expr->ret_type = + TypeFactory::get_instance().get_pointer_type(PrimitiveType::i32); } FrontendFuncDefStmt::FrontendFuncDefStmt(const FrontendFuncDefStmt &o) @@ -883,7 +885,7 @@ void IndexExpression::type_check(const CompileConfig *) { "Invalid IndexExpression: the source is not among field, ndarray or " "local tensor"); } - + ret_type = TypeFactory::get_instance().get_pointer_type(ret_type); for (auto &indices : indices_group) { for (int i = 0; i < indices.exprs.size(); i++) { auto &expr = indices.exprs[i]; @@ -960,7 +962,7 @@ void LoopUniqueExpression::flatten(FlattenContext *ctx) { void IdExpression::flatten(FlattenContext *ctx) { stmt = ctx->current_block->lookup_var(id); - if (!ret_type->is_primitive(PrimitiveTypeID::unknown)) { + if (stmt->ret_type->is_primitive(PrimitiveTypeID::unknown)) { stmt->ret_type = ret_type; } } @@ -1514,7 +1516,7 @@ Expr ASTBuilder::expr_subscript(const Expr &expr, std::string tb) { TI_ASSERT(expr.is() || expr.is() || expr.is() || - is_tensor(expr.expr->ret_type)); + is_tensor(expr.expr->ret_type.ptr_removed())); // IndexExpression without ret_shape is used for matrix indexing, // where each entry of ExprGroup is interpreted as indexing into a specific @@ -1677,7 +1679,7 @@ std::vector ASTBuilder::expand_exprs(const std::vector &exprs) { elem.expr->ret_type = struct_type->get_element_type(indices); expanded_exprs.push_back(elem); } - } else if (!expr->ret_type->is()) { + } else if (!expr->ret_type.ptr_removed()->is()) { expanded_exprs.push_back(expr); } else { // Expand TensorType expr @@ -1695,7 +1697,7 @@ std::vector ASTBuilder::expand_exprs(const std::vector &exprs) { return {ind0, ind1, ind2, ind3} */ - auto tensor_type = expr->ret_type->cast(); + auto tensor_type = expr->ret_type.ptr_removed()->cast(); Expr id_expr; if (expr.is()) { @@ -1708,7 +1710,7 @@ std::vector ASTBuilder::expand_exprs(const std::vector &exprs) { for (int i = 0; i < shape[0]; i++) { auto ind = Expr(std::make_shared( id_expr, ExprGroup(Expr(i)), expr->tb)); - ind.expr->ret_type = tensor_type->get_element_type(); + ind->type_check(nullptr); expanded_exprs.push_back(ind); } } else { @@ -1717,7 +1719,7 @@ std::vector ASTBuilder::expand_exprs(const std::vector &exprs) { for (int j = 0; j < shape[1]; j++) { auto ind = Expr(std::make_shared( id_expr, ExprGroup(Expr(i), Expr(j)), expr->tb)); - ind.expr->ret_type = tensor_type->get_element_type(); + ind->type_check(nullptr); expanded_exprs.push_back(ind); } } diff --git a/taichi/ir/frontend_ir.h b/taichi/ir/frontend_ir.h index dc9f6fad7d7da0..35ecf42241e6d0 100644 --- a/taichi/ir/frontend_ir.h +++ b/taichi/ir/frontend_ir.h @@ -86,7 +86,8 @@ class FrontendAllocaStmt : public Stmt { DataType element, bool is_shared = false) : ident(lhs), is_shared(is_shared) { - ret_type = DataType(TypeFactory::create_tensor_type(shape, element)); + ret_type = TypeFactory::get_instance().get_pointer_type( + DataType(TypeFactory::create_tensor_type(shape, element))); } bool is_shared; @@ -500,6 +501,7 @@ class ExternalTensorExpression : public Expression { void type_check(const CompileConfig *config) override { ret_type = dt; + ret_type.set_is_pointer(true); config_ = config; } @@ -585,7 +587,7 @@ class MatrixExpression : public Expression { std::vector shape, DataType element_type) : elements(elements) { - this->dt = DataType(TypeFactory::create_tensor_type(shape, element_type)); + dt = TypeFactory::create_tensor_type(shape, element_type); } void type_check(const CompileConfig *config) override; diff --git a/taichi/ir/statements.cpp b/taichi/ir/statements.cpp index 88ed18f0cea3b4..c01691abf4434b 100644 --- a/taichi/ir/statements.cpp +++ b/taichi/ir/statements.cpp @@ -86,6 +86,7 @@ MatrixOfMatrixPtrStmt::MatrixOfMatrixPtrStmt(const std::vector &stmts, DataType dt) : stmts(stmts) { ret_type = dt; + ret_type.set_is_pointer(true); TI_STMT_REG_FIELDS; } diff --git a/taichi/ir/statements.h b/taichi/ir/statements.h index b84122ae3c3476..935124d6ee0158 100644 --- a/taichi/ir/statements.h +++ b/taichi/ir/statements.h @@ -19,7 +19,11 @@ class Function; class AllocaStmt : public Stmt, public ir_traits::Store { public: explicit AllocaStmt(DataType type) : is_shared(false) { - ret_type = type; + if (type->is_primitive(PrimitiveTypeID::unknown)) { + ret_type = type; + } else { + ret_type = TypeFactory::get_instance().get_pointer_type(type); + } TI_STMT_REG_FIELDS; } @@ -27,7 +31,8 @@ class AllocaStmt : public Stmt, public ir_traits::Store { DataType type, bool is_shared = false) : is_shared(is_shared) { - ret_type = TypeFactory::create_tensor_type(shape, type); + ret_type = TypeFactory::get_instance().get_pointer_type( + TypeFactory::create_tensor_type(shape, type)); TI_STMT_REG_FIELDS; } diff --git a/taichi/transforms/frontend_type_check.cpp b/taichi/transforms/frontend_type_check.cpp index ae7c9af88664cf..4a75b6683ac8d9 100644 --- a/taichi/transforms/frontend_type_check.cpp +++ b/taichi/transforms/frontend_type_check.cpp @@ -7,11 +7,12 @@ namespace taichi::lang { class FrontendTypeCheck : public IRVisitor { void check_cond_type(const Expr &cond, std::string stmt_name) { - if (!cond->ret_type->is() || !is_integral(cond->ret_type)) + auto cond_type = cond.get_rvalue_type(); + if (!cond_type->is() || !is_integral(cond_type)) throw TaichiTypeError(fmt::format( "`{0}` conditions must be an integer; found {1}. Consider using " "`{0} x != 0` instead of `{0} x` for float values.", - stmt_name, cond->ret_type->to_string())); + stmt_name, cond_type->to_string())); } public: @@ -49,8 +50,8 @@ class FrontendTypeCheck : public IRVisitor { } void visit(FrontendAssignStmt *stmt) override { - auto lhs_type = stmt->lhs->ret_type; - auto rhs_type = stmt->rhs->ret_type; + auto lhs_type = stmt->lhs->ret_type.ptr_removed(); + auto rhs_type = stmt->rhs->ret_type.ptr_removed(); auto error = [&]() { throw TaichiTypeError(fmt::format("{}cannot assign '{}' to '{}'", @@ -85,7 +86,7 @@ class FrontendTypeCheck : public IRVisitor { Expr const &expr = std::get(content); TI_ASSERT(expr.expr != nullptr); - DataType data_type = expr->ret_type; + DataType data_type = expr.get_rvalue_type(); if (data_type->is()) { data_type = DataType(data_type->as()->get_element_type()); } diff --git a/taichi/transforms/lower_ast.cpp b/taichi/transforms/lower_ast.cpp index 63d1227e4d223f..d3fab5c7663d80 100644 --- a/taichi/transforms/lower_ast.cpp +++ b/taichi/transforms/lower_ast.cpp @@ -68,15 +68,15 @@ class LowerAST : public IRVisitor { auto ident = stmt->ident; TI_ASSERT(block->local_var_to_stmt.find(ident) == block->local_var_to_stmt.end()); - if (stmt->ret_type->is()) { - auto tensor_type = stmt->ret_type->cast(); + auto alloca_type = stmt->ret_type.ptr_removed(); + if (auto tensor_type = alloca_type->cast()) { auto lowered = std::make_unique( tensor_type->get_shape(), tensor_type->get_element_type(), stmt->is_shared); block->local_var_to_stmt.insert(std::make_pair(ident, lowered.get())); stmt->parent->replace_with(stmt, std::move(lowered)); } else { - auto lowered = std::make_unique(stmt->ret_type); + auto lowered = std::make_unique(alloca_type); block->local_var_to_stmt.insert(std::make_pair(ident, lowered.get())); stmt->parent->replace_with(stmt, std::move(lowered)); } diff --git a/taichi/transforms/offload.cpp b/taichi/transforms/offload.cpp index d16030ca86c3e8..0adec11c50b870 100644 --- a/taichi/transforms/offload.cpp +++ b/taichi/transforms/offload.cpp @@ -530,16 +530,16 @@ class FixCrossOffloadReferences : public BasicStmtVisitor { if (local_to_global_offset_.find(stmt) == local_to_global_offset_.end()) return; VecStatement replacement; - auto ret_type = stmt->ret_type; - local_to_global_vector_type_[stmt] = ret_type; + auto alloca_type = stmt->ret_type.ptr_removed(); + local_to_global_vector_type_[stmt] = alloca_type; auto ptr = replacement.push_back( - local_to_global_offset_.at(stmt), ret_type); + local_to_global_offset_.at(stmt), alloca_type); auto offloaded = stmt_to_offloaded_[stmt]; stmt_to_offloaded_[ptr] = offloaded; - TypedConstant zero(stmt->ret_type.get_element_type()); + TypedConstant zero(alloca_type.get_element_type()); auto const_zero_stmt = replacement.push_back(zero); - if (auto tensor_type = stmt->ret_type->cast()) { + if (auto tensor_type = alloca_type->cast()) { std::vector zero_values(tensor_type->get_num_elements(), const_zero_stmt); auto zero_matrix_init_stmt = diff --git a/taichi/transforms/scalarize.cpp b/taichi/transforms/scalarize.cpp index 796a9939e491e2..1687b7a48bba3f 100644 --- a/taichi/transforms/scalarize.cpp +++ b/taichi/transforms/scalarize.cpp @@ -44,7 +44,7 @@ class Scalarize : public BasicStmtVisitor { template void scalarize_store_stmt(T *stmt) { auto dest_dtype = stmt->dest->ret_type.ptr_removed(); - auto val_dtype = stmt->val->ret_type; + auto val_dtype = stmt->val->ret_type.ptr_removed(); if (dest_dtype->template is() && val_dtype->template is()) { // Needs scalarize @@ -185,7 +185,8 @@ class Scalarize : public BasicStmtVisitor { stmt->replace_all_usages_with(tmp) */ void visit(UnaryOpStmt *stmt) override { - auto operand_dtype = stmt->operand->ret_type; + auto operand_dtype = stmt->operand->ret_type.ptr_removed(); + auto stmt_dtype = stmt->ret_type.ptr_removed(); if (operand_dtype->is()) { // Needs scalarize auto operand_tensor_type = operand_dtype->as(); @@ -198,7 +199,7 @@ class Scalarize : public BasicStmtVisitor { std::vector matrix_init_values; int num_elements = operand_tensor_type->get_num_elements(); - auto primitive_type = stmt->ret_type.get_element_type(); + auto primitive_type = stmt_dtype.get_element_type(); for (size_t i = 0; i < num_elements; i++) { auto unary_stmt = std::make_unique( stmt->op_type, operand_matrix_init_stmt->values[i]); @@ -246,8 +247,9 @@ class Scalarize : public BasicStmtVisitor { stmt->replace_all_usages_with(tmp) */ void visit(BinaryOpStmt *stmt) override { - auto lhs_dtype = stmt->lhs->ret_type; - auto rhs_dtype = stmt->rhs->ret_type; + auto lhs_dtype = stmt->lhs->ret_type.ptr_removed(); + auto rhs_dtype = stmt->rhs->ret_type.ptr_removed(); + auto stmt_dtype = stmt->ret_type.ptr_removed(); if (lhs_dtype->is() || rhs_dtype->is()) { // Make sure broadcasting has been correctly applied by // BinaryOpExpression::type_check(). @@ -270,7 +272,7 @@ class Scalarize : public BasicStmtVisitor { TI_ASSERT(rhs_vals.size() == lhs_vals.size()); size_t num_elements = lhs_vals.size(); - auto primitive_type = stmt->ret_type.get_element_type(); + auto primitive_type = stmt_dtype.get_element_type(); std::vector matrix_init_values; for (size_t i = 0; i < num_elements; i++) { auto binary_stmt = std::make_unique( @@ -581,9 +583,9 @@ class Scalarize : public BasicStmtVisitor { stmt->replace_all_usages_with(tmp) */ void visit(TernaryOpStmt *stmt) override { - auto cond_dtype = stmt->op1->ret_type; - auto op2_dtype = stmt->op2->ret_type; - auto op3_dtype = stmt->op3->ret_type; + auto cond_dtype = stmt->op1->ret_type.ptr_removed(); + auto op2_dtype = stmt->op2->ret_type.ptr_removed(); + auto op3_dtype = stmt->op3->ret_type.ptr_removed(); if (cond_dtype->is()) { // Make sure broadcasting has been correctly applied by // TernaryOpExpression::type_check(). @@ -1026,7 +1028,8 @@ class ScalarizePointers : public BasicStmtVisitor { for (size_t i = 0; i < tensor_type->get_num_elements(); i++) { auto scalarized_alloca_stmt = std::make_unique(primitive_type); - scalarized_alloca_stmt->ret_type = primitive_type; + scalarized_alloca_stmt->ret_type = + TypeFactory::get_instance().get_pointer_type(primitive_type); scalarized_local_tensor_map_[stmt].push_back( scalarized_alloca_stmt.get()); diff --git a/taichi/transforms/type_check.cpp b/taichi/transforms/type_check.cpp index 3a5bb1188c634e..859f5fb61432a9 100644 --- a/taichi/transforms/type_check.cpp +++ b/taichi/transforms/type_check.cpp @@ -23,13 +23,14 @@ class TypeCheck : public IRVisitor { Stmt *&val, const std::string &stmt_name) { auto dst_type = dst->ret_type.ptr_removed(); + auto val_type = val->ret_type.ptr_removed(); if (is_quant(dst_type)) { // We force the value type to be the compute_type of the bit pointer. // Casting from compute_type to physical_type is handled in codegen. dst_type = dst_type->get_compute_type(); } - if (dst_type != val->ret_type) { - auto promoted = promoted_type(dst_type, val->ret_type); + if (dst_type != val_type) { + auto promoted = promoted_type(dst_type, val_type); if (dst_type != promoted) { TI_WARN("[{}] {} may lose precision: {} <- {}\n{}", stmt->name(), stmt_name, dst_type->to_string(), val->ret_data_type_name(), @@ -88,13 +89,14 @@ class TypeCheck : public IRVisitor { TI_ASSERT(stmt->src->is() || stmt->src->is() || stmt->src->is()); if (auto ptr_offset_stmt = stmt->src->cast()) { - auto lookup = DataType(ptr_offset_stmt->origin->ret_type->as() + auto lookup = DataType(ptr_offset_stmt->origin->ret_type.ptr_removed() + ->as() ->get_element_type()) .ptr_removed(); stmt->ret_type = lookup; } else { auto lookup = stmt->src->ret_type; - stmt->ret_type = lookup; + stmt->ret_type = lookup.ptr_removed(); } } diff --git a/tests/cpp/ir/frontend_type_inference_test.cpp b/tests/cpp/ir/frontend_type_inference_test.cpp index bce26256c0ac5f..dd21a5ec7b92ce 100644 --- a/tests/cpp/ir/frontend_type_inference_test.cpp +++ b/tests/cpp/ir/frontend_type_inference_test.cpp @@ -32,7 +32,9 @@ TEST(FrontendTypeInference, Id) { auto const_i32 = value(-(1 << 20)); const_i32->type_check(nullptr); auto id_i32 = kernel->context->builder().make_var(const_i32, const_i32->tb); - EXPECT_EQ(id_i32->ret_type, PrimitiveType::i32); + EXPECT_EQ(id_i32->ret_type, + DataType(TypeFactory::get_instance().get_pointer_type( + PrimitiveType::i32))); } TEST(FrontendTypeInference, BinaryOp) { @@ -139,7 +141,9 @@ TEST(FrontendTypeInference, GlobalPtr_Field) { index->type_check(nullptr); auto global_ptr = ast_builder->expr_subscript(global_var, ExprGroup(index)); global_ptr->type_check(nullptr); - EXPECT_EQ(global_ptr->ret_type, PrimitiveType::u8); + EXPECT_EQ(global_ptr->ret_type, + DataType(TypeFactory::get_instance().get_pointer_type( + PrimitiveType::u8))); } TEST(FrontendTypeInference, GlobalPtr_ExternalTensor) { @@ -172,7 +176,9 @@ TEST(FrontendTypeInference, TensorElement) { index->type_check(nullptr); auto tensor_element = Expr::make(var, ExprGroup(index)); tensor_element->type_check(nullptr); - EXPECT_EQ(tensor_element->ret_type, PrimitiveType::u32); + EXPECT_EQ(tensor_element->ret_type, + DataType(TypeFactory::get_instance().get_pointer_type( + PrimitiveType::u32))); } TEST(FrontendTypeInference, AtomicOp) { diff --git a/tests/python/test_unary_ops.py b/tests/python/test_unary_ops.py index 2dc7f4c68310ef..6fe014bee872b4 100644 --- a/tests/python/test_unary_ops.py +++ b/tests/python/test_unary_ops.py @@ -78,7 +78,7 @@ def test(x: ti.f32) -> ti.i32: @test_utils.test(arch=[ti.cuda, ti.vulkan, ti.opengl, ti.metal]) -def _test_frexp(): # Fails in this PR, but will be fixed in the last PR of this series +def test_frexp(): @ti.kernel def get_frac(x: ti.f32) -> ti.f32: a, b = ti.frexp(x)