diff --git a/taichi/transforms/offload.cpp b/taichi/transforms/offload.cpp
index d841b9387e12b..9f3215ef1e742 100644
--- a/taichi/transforms/offload.cpp
+++ b/taichi/transforms/offload.cpp
@@ -14,6 +14,28 @@ TLANG_NAMESPACE_BEGIN
 namespace irpass {
 namespace {
 
+class SquashPtrOffset : public IRVisitor {
+ public:
+  SquashPtrOffset() {
+    allow_undefined_visitor = true;
+    invoke_default_visitor = true;
+  }
+  void visit(Stmt *stmt) override {
+    top_level_ptr = stmt;
+  }
+  void visit(PtrOffsetStmt *stmt) override {
+    stmt->origin->accept(this);
+  }
+  static Stmt *run(Stmt *root) {
+    SquashPtrOffset v;
+    root->accept(&v);
+    return v.top_level_ptr;
+  }
+
+ private:
+  Stmt *top_level_ptr = nullptr;
+};
+
 // Offloaded local variables to its offset in the global tmps memory.
 using StmtToOffsetMap = std::unordered_map<const Stmt *, std::size_t>;
 
@@ -305,13 +327,11 @@ class IdentifyValuesUsedInOtherOffloads : public BasicStmtVisitor {
     // Directly insert copies of ConstStmts later
     if (stmt->is<ConstStmt>())
       return;
+    auto top_level_ptr = SquashPtrOffset::run(stmt);
     // We don't support storing a pointer for now.
-    if (stmt->is<GlobalPtrStmt>())
+    if (top_level_ptr->is<GlobalPtrStmt>())
       return;
     // Not yet allocated
-    auto top_level_ptr = stmt;
-    while (top_level_ptr->is<PtrOffsetStmt>())
-      top_level_ptr = top_level_ptr->cast<PtrOffsetStmt>()->origin;
     if (local_to_global.find(top_level_ptr) == local_to_global.end()) {
       local_to_global[top_level_ptr] = allocate_global(top_level_ptr->ret_type);
     }
@@ -445,12 +465,13 @@ class FixCrossOffloadReferences : public BasicStmtVisitor {
     local_to_global_vector_type[stmt] = ret_type;
     auto ptr = replacement.push_back<GlobalTemporaryStmt>(
         local_to_global_offset[stmt], ret_type);
-    stmt_to_offloaded[ptr] = stmt_to_offloaded[stmt];
+    auto offloaded = stmt_to_offloaded[stmt];
+    stmt_to_offloaded[ptr] = offloaded;
     if (auto tensor_type = stmt->ret_type->cast<TensorType>()) {
       LaneAttribute<TypedConstant> zero(std::vector<TypedConstant>(
           1, TypedConstant(tensor_type->get_element_type())));
       auto const_zero_stmt = replacement.push_back<ConstStmt>(zero);
-      stmt_to_offloaded[const_zero_stmt] = stmt_to_offloaded[stmt];
+      stmt_to_offloaded[const_zero_stmt] = offloaded;
       for (int i = 0; i < tensor_type->get_num_elements(); ++i) {
         LaneAttribute<TypedConstant> offset(std::vector<TypedConstant>(
             1, TypedConstant(i *
@@ -460,9 +481,9 @@ class FixCrossOffloadReferences : public BasicStmtVisitor {
             replacement.push_back<PtrOffsetStmt>(ptr, const_offset_stmt);
         auto global_store_stmt = replacement.push_back<GlobalStoreStmt>(
             ptr_offset_stmt, const_zero_stmt);
-        stmt_to_offloaded[const_offset_stmt] = stmt_to_offloaded[stmt];
-        stmt_to_offloaded[ptr_offset_stmt] = stmt_to_offloaded[stmt];
-        stmt_to_offloaded[global_store_stmt] = stmt_to_offloaded[stmt];
+        stmt_to_offloaded[const_offset_stmt] = offloaded;
+        stmt_to_offloaded[ptr_offset_stmt] = offloaded;
+        stmt_to_offloaded[global_store_stmt] = offloaded;
       }
     } else {
       LaneAttribute<TypedConstant> zeros(std::vector<TypedConstant>(
@@ -470,7 +491,7 @@ class FixCrossOffloadReferences : public BasicStmtVisitor {
       auto const_zeros = replacement.push_back<ConstStmt>(zeros);
       auto global_store_stmt =
           replacement.push_back<GlobalStoreStmt>(ptr, const_zeros);
-      stmt_to_offloaded[global_store_stmt] = stmt_to_offloaded[stmt];
+      stmt_to_offloaded[global_store_stmt] = offloaded;
     }
 
     stmt->parent->replace_with(stmt, std::move(replacement), false);
@@ -484,9 +505,7 @@ class FixCrossOffloadReferences : public BasicStmtVisitor {
     generic_visit(stmt);
     TI_ASSERT(stmt->width() == 1)
     auto ptr = stmt->src[0].var;
-    auto top_level_ptr = ptr;
-    while (top_level_ptr->is<PtrOffsetStmt>())
-      top_level_ptr = top_level_ptr->cast<PtrOffsetStmt>()->origin;
+    auto top_level_ptr = SquashPtrOffset::run(ptr);
     if (top_level_ptr->is<GlobalTemporaryStmt>()) {
       VecStatement replacement;
       auto global_load = replacement.push_back<GlobalLoadStmt>(ptr);
@@ -499,9 +518,7 @@ class FixCrossOffloadReferences : public BasicStmtVisitor {
   void visit(LocalStoreStmt *stmt) override {
     generic_visit(stmt);
     auto ptr = stmt->dest;
-    auto top_level_ptr = ptr;
-    while (top_level_ptr->is<PtrOffsetStmt>())
-      top_level_ptr = top_level_ptr->cast<PtrOffsetStmt>()->origin;
+    auto top_level_ptr = SquashPtrOffset::run(ptr);
    if (top_level_ptr->is<GlobalTemporaryStmt>()) {
       VecStatement replacement;
       auto global_store =
@@ -520,25 +537,31 @@ class FixCrossOffloadReferences : public BasicStmtVisitor {
       return false;
     if (stmt_to_offloaded[stmt] == stmt_to_offloaded[op])  // same OffloadedStmt
       return false;
+
+    auto offloaded = stmt_to_offloaded[stmt];
+
     if (op->is<GlobalPtrStmt>()) {
       auto copy = op->clone();
       copy->as<GlobalPtrStmt>()->activate = false;
-      stmt_to_offloaded[copy.get()] = stmt_to_offloaded[stmt];
+      stmt_to_offloaded[copy.get()] = offloaded;
       stmt->set_operand(index, copy.get());
       stmt->insert_before_me(std::move(copy));
       return true;
     }
 
     if (local_to_global_offset.find(op) == local_to_global_offset.end()) {
+      TI_ASSERT_INFO(op->is<ConstStmt>() || op->is<PtrOffsetStmt>() ||
+                         op->is<GlobalTemporaryStmt>(),
+                     "{} is not allowed here.", op->type());
       // For cases like ConstStmt
       auto copy = op->clone();
-      stmt_to_offloaded[copy.get()] = stmt_to_offloaded[stmt];
+      stmt_to_offloaded[copy.get()] = offloaded;
       stmt->set_operand(index, copy.get());
       stmt->insert_before_me(std::move(copy));
     } else {
       auto global_temporary = Stmt::make<GlobalTemporaryStmt>(
           local_to_global_offset[op], op->ret_type);
-      stmt_to_offloaded[global_temporary.get()] = stmt_to_offloaded[stmt];
+      stmt_to_offloaded[global_temporary.get()] = offloaded;
       stmt->set_operand(index, global_temporary.get());
       if (op->is<AllocaStmt>() || op->ret_type.is_pointer()) {
         // For cases like Alloca both TensorType and Scalar which will be
@@ -547,7 +570,7 @@ class FixCrossOffloadReferences : public BasicStmtVisitor {
       } else {
         // For other cases like ArgLoadStmt UnaryOpStmt which needs to load.
         auto load = Stmt::make<GlobalLoadStmt>(global_temporary.get());
-        stmt_to_offloaded[load.get()] = stmt_to_offloaded[stmt];
+        stmt_to_offloaded[load.get()] = offloaded;
        stmt->set_operand(index, load.get());
         stmt->insert_before_me(std::move(global_temporary));
         stmt->insert_before_me(std::move(load));