diff --git a/taichi/ir/control_flow_graph.cpp b/taichi/ir/control_flow_graph.cpp index d546900f83724f..4939480674f47c 100644 --- a/taichi/ir/control_flow_graph.cpp +++ b/taichi/ir/control_flow_graph.cpp @@ -247,7 +247,7 @@ Stmt *CFGNode::get_store_forwarding_data(Stmt *var, int position) const { // result: the value to store Stmt *result = irpass::analysis::get_store_data( block->statements[last_def_position].get()); - bool is_tensor_involved = var->ret_type->is(); + bool is_tensor_involved = var->ret_type.ptr_removed()->is(); if (!(var->is() && !is_tensor_involved)) { // In between the store stmt and current stmt, // if there's a third-stmt that "may" have stored a "different value" to @@ -355,7 +355,7 @@ Stmt *CFGNode::get_store_forwarding_data(Stmt *var, int position) const { // Check for aliased address // There's a store to the same dest_addr before this stmt - bool is_tensor_involved = var->ret_type->is(); + bool is_tensor_involved = var->ret_type.ptr_removed()->is(); if (!(var->is() && !is_tensor_involved)) { // In between the store stmt and current stmt, // if there's a third-stmt that "may" have stored a "different value" to @@ -443,7 +443,8 @@ bool CFGNode::store_to_load_forwarding(bool after_lower_access, continue; // special case of alloca (initialized to 0) - auto zero = Stmt::make(TypedConstant(result->ret_type, 0)); + auto zero = Stmt::make( + TypedConstant(result->ret_type.ptr_removed(), 0)); replace_with(i, std::move(zero), true); } else { if (result->ret_type.ptr_removed()->is() && diff --git a/taichi/transforms/frontend_type_check.cpp b/taichi/transforms/frontend_type_check.cpp index ae7c9af88664cf..4a75b6683ac8d9 100644 --- a/taichi/transforms/frontend_type_check.cpp +++ b/taichi/transforms/frontend_type_check.cpp @@ -7,11 +7,12 @@ namespace taichi::lang { class FrontendTypeCheck : public IRVisitor { void check_cond_type(const Expr &cond, std::string stmt_name) { - if (!cond->ret_type->is() || !is_integral(cond->ret_type)) + auto cond_type = cond.get_rvalue_type(); + if (!cond_type->is() || !is_integral(cond_type)) throw TaichiTypeError(fmt::format( "`{0}` conditions must be an integer; found {1}. Consider using " "`{0} x != 0` instead of `{0} x` for float values.", - stmt_name, cond->ret_type->to_string())); + stmt_name, cond_type->to_string())); } public: @@ -49,8 +50,8 @@ class FrontendTypeCheck : public IRVisitor { } void visit(FrontendAssignStmt *stmt) override { - auto lhs_type = stmt->lhs->ret_type; - auto rhs_type = stmt->rhs->ret_type; + auto lhs_type = stmt->lhs->ret_type.ptr_removed(); + auto rhs_type = stmt->rhs->ret_type.ptr_removed(); auto error = [&]() { throw TaichiTypeError(fmt::format("{}cannot assign '{}' to '{}'", @@ -85,7 +86,7 @@ class FrontendTypeCheck : public IRVisitor { Expr const &expr = std::get(content); TI_ASSERT(expr.expr != nullptr); - DataType data_type = expr->ret_type; + DataType data_type = expr.get_rvalue_type(); if (data_type->is()) { data_type = DataType(data_type->as()->get_element_type()); } diff --git a/taichi/transforms/lower_ast.cpp b/taichi/transforms/lower_ast.cpp index 63d1227e4d223f..d3fab5c7663d80 100644 --- a/taichi/transforms/lower_ast.cpp +++ b/taichi/transforms/lower_ast.cpp @@ -68,15 +68,15 @@ class LowerAST : public IRVisitor { auto ident = stmt->ident; TI_ASSERT(block->local_var_to_stmt.find(ident) == block->local_var_to_stmt.end()); - if (stmt->ret_type->is()) { - auto tensor_type = stmt->ret_type->cast(); + auto alloca_type = stmt->ret_type.ptr_removed(); + if (auto tensor_type = alloca_type->cast()) { auto lowered = std::make_unique( tensor_type->get_shape(), tensor_type->get_element_type(), stmt->is_shared); block->local_var_to_stmt.insert(std::make_pair(ident, lowered.get())); stmt->parent->replace_with(stmt, std::move(lowered)); } else { - auto lowered = std::make_unique(stmt->ret_type); + auto lowered = std::make_unique(alloca_type); block->local_var_to_stmt.insert(std::make_pair(ident, lowered.get())); stmt->parent->replace_with(stmt, std::move(lowered)); } diff --git a/taichi/transforms/offload.cpp b/taichi/transforms/offload.cpp index d16030ca86c3e8..0adec11c50b870 100644 --- a/taichi/transforms/offload.cpp +++ b/taichi/transforms/offload.cpp @@ -530,16 +530,16 @@ class FixCrossOffloadReferences : public BasicStmtVisitor { if (local_to_global_offset_.find(stmt) == local_to_global_offset_.end()) return; VecStatement replacement; - auto ret_type = stmt->ret_type; - local_to_global_vector_type_[stmt] = ret_type; + auto alloca_type = stmt->ret_type.ptr_removed(); + local_to_global_vector_type_[stmt] = alloca_type; auto ptr = replacement.push_back( - local_to_global_offset_.at(stmt), ret_type); + local_to_global_offset_.at(stmt), alloca_type); auto offloaded = stmt_to_offloaded_[stmt]; stmt_to_offloaded_[ptr] = offloaded; - TypedConstant zero(stmt->ret_type.get_element_type()); + TypedConstant zero(alloca_type.get_element_type()); auto const_zero_stmt = replacement.push_back(zero); - if (auto tensor_type = stmt->ret_type->cast()) { + if (auto tensor_type = alloca_type->cast()) { std::vector zero_values(tensor_type->get_num_elements(), const_zero_stmt); auto zero_matrix_init_stmt = diff --git a/taichi/transforms/scalarize.cpp b/taichi/transforms/scalarize.cpp index 796a9939e491e2..f14859af99ede3 100644 --- a/taichi/transforms/scalarize.cpp +++ b/taichi/transforms/scalarize.cpp @@ -44,7 +44,7 @@ class Scalarize : public BasicStmtVisitor { template void scalarize_store_stmt(T *stmt) { auto dest_dtype = stmt->dest->ret_type.ptr_removed(); - auto val_dtype = stmt->val->ret_type; + auto val_dtype = stmt->val->ret_type.ptr_removed(); if (dest_dtype->template is() && val_dtype->template is()) { // Needs scalarize @@ -185,7 +185,8 @@ class Scalarize : public BasicStmtVisitor { stmt->replace_all_usages_with(tmp) */ void visit(UnaryOpStmt *stmt) override { - auto operand_dtype = stmt->operand->ret_type; + auto operand_dtype = stmt->operand->ret_type.ptr_removed(); + auto stmt_dtype = stmt->ret_type.ptr_removed(); if (operand_dtype->is()) { // Needs scalarize auto operand_tensor_type = operand_dtype->as(); @@ -198,7 +199,7 @@ class Scalarize : public BasicStmtVisitor { std::vector matrix_init_values; int num_elements = operand_tensor_type->get_num_elements(); - auto primitive_type = stmt->ret_type.get_element_type(); + auto primitive_type = stmt_dtype.get_element_type(); for (size_t i = 0; i < num_elements; i++) { auto unary_stmt = std::make_unique( stmt->op_type, operand_matrix_init_stmt->values[i]); @@ -246,8 +247,9 @@ class Scalarize : public BasicStmtVisitor { stmt->replace_all_usages_with(tmp) */ void visit(BinaryOpStmt *stmt) override { - auto lhs_dtype = stmt->lhs->ret_type; - auto rhs_dtype = stmt->rhs->ret_type; + auto lhs_dtype = stmt->lhs->ret_type.ptr_removed(); + auto rhs_dtype = stmt->rhs->ret_type.ptr_removed(); + auto stmt_dtype = stmt->ret_type.ptr_removed(); if (lhs_dtype->is() || rhs_dtype->is()) { // Make sure broadcasting has been correctly applied by // BinaryOpExpression::type_check(). @@ -270,7 +272,7 @@ class Scalarize : public BasicStmtVisitor { TI_ASSERT(rhs_vals.size() == lhs_vals.size()); size_t num_elements = lhs_vals.size(); - auto primitive_type = stmt->ret_type.get_element_type(); + auto primitive_type = stmt_dtype.get_element_type(); std::vector matrix_init_values; for (size_t i = 0; i < num_elements; i++) { auto binary_stmt = std::make_unique( @@ -581,9 +583,9 @@ class Scalarize : public BasicStmtVisitor { stmt->replace_all_usages_with(tmp) */ void visit(TernaryOpStmt *stmt) override { - auto cond_dtype = stmt->op1->ret_type; - auto op2_dtype = stmt->op2->ret_type; - auto op3_dtype = stmt->op3->ret_type; + auto cond_dtype = stmt->op1->ret_type.ptr_removed(); + auto op2_dtype = stmt->op2->ret_type.ptr_removed(); + auto op3_dtype = stmt->op3->ret_type.ptr_removed(); if (cond_dtype->is()) { // Make sure broadcasting has been correctly applied by // TernaryOpExpression::type_check(). diff --git a/taichi/transforms/type_check.cpp b/taichi/transforms/type_check.cpp index 3a5bb1188c634e..859f5fb61432a9 100644 --- a/taichi/transforms/type_check.cpp +++ b/taichi/transforms/type_check.cpp @@ -23,13 +23,14 @@ class TypeCheck : public IRVisitor { Stmt *&val, const std::string &stmt_name) { auto dst_type = dst->ret_type.ptr_removed(); + auto val_type = val->ret_type.ptr_removed(); if (is_quant(dst_type)) { // We force the value type to be the compute_type of the bit pointer. // Casting from compute_type to physical_type is handled in codegen. dst_type = dst_type->get_compute_type(); } - if (dst_type != val->ret_type) { - auto promoted = promoted_type(dst_type, val->ret_type); + if (dst_type != val_type) { + auto promoted = promoted_type(dst_type, val_type); if (dst_type != promoted) { TI_WARN("[{}] {} may lose precision: {} <- {}\n{}", stmt->name(), stmt_name, dst_type->to_string(), val->ret_data_type_name(), @@ -88,13 +89,14 @@ class TypeCheck : public IRVisitor { TI_ASSERT(stmt->src->is() || stmt->src->is() || stmt->src->is()); if (auto ptr_offset_stmt = stmt->src->cast()) { - auto lookup = DataType(ptr_offset_stmt->origin->ret_type->as() + auto lookup = DataType(ptr_offset_stmt->origin->ret_type.ptr_removed() + ->as() ->get_element_type()) .ptr_removed(); stmt->ret_type = lookup; } else { auto lookup = stmt->src->ret_type; - stmt->ret_type = lookup; + stmt->ret_type = lookup.ptr_removed(); } }