// Patch summary. The text below is a unified diff over three files under
// taichi/ whose line breaks were collapsed; the `+` / `-` characters are diff
// markers, not C++ operators.
// NOTE(review): include targets and template arguments appear to be missing in
// this text (bare `#include`, `std::unordered_map`, `op->is()`, `op->as()`,
// `Stmt::make(`) -- presumably lost in extraction; restore them from the
// original commit before applying -- TODO confirm.
//
// taichi/ir/ir.cpp and taichi/ir/ir.h:
//   Adds Stmt::locate_operand(Stmt **stmt). It walks indices
//   0 .. num_operands()-1 and returns the first i with operands[i] == stmt,
//   or -1 when nothing matches. The argument is compared against the entries
//   of `operands` directly; the callers below pass the address of a member
//   (&stmt->data, &stmt->val).
//
// taichi/transforms/offload.cpp (class FixCrossOffloadReferences):
//   * Adds one #include (header name not visible here; presumably <utility>
//     for std::move -- TODO confirm).
//   * The constructor now initializes the stmt_to_offloaded member with
//     std::move(stmt_to_offloaded) from its by-value parameter instead of
//     copying that parameter.
//   * New helper visit_operand(Stmt *stmt, int index); returns true when it
//     modified the IR:
//       - asserts 0 <= index < stmt->num_operands();
//       - returns false when the operand is nullptr, or when
//         stmt_to_offloaded[stmt] == stmt_to_offloaded[op];
//       - when advanced_optimization is set and op->is() holds, makes a copy
//         of op, records the copy in stmt_to_offloaded with stmt's entry,
//         points the operand at the copy, inserts the copy before stmt and
//         returns true;
//       - otherwise returns false when op has no entry in
//         local_to_global_offset;
//       - otherwise creates `global` (from local_to_global_offset[op] and
//         op->ret_type) and `load` (from global), records `load` in
//         stmt_to_offloaded with stmt's entry, points the operand at `load`,
//         inserts `global` then `load` before stmt and returns true.
//   * visit(LocalStoreStmt *) and visit(AtomicOpStmt *) now call
//     visit_operand on the operand index of stmt->data / stmt->val first and
//     throw IRModified() when it returns true, ahead of their existing logic.
//   * The generic visit(Stmt *) loop now delegates each operand to
//     visit_operand. Compared with the removed inline code, the helper
//     additionally records `load` in stmt_to_offloaded.
// NOTE(review): locate_operand can return -1; visit_operand's TI_ASSERT then
// rejects that index.
// NOTE(review): stmt_to_offloaded[...] uses operator[]; if the member is a
// std::unordered_map this inserts a default entry for keys that are not yet
// present (`global` is never recorded) -- verify this is intended.
diff --git a/taichi/ir/ir.cpp b/taichi/ir/ir.cpp index 3e6e79ea9dc9a..fbdb923f1f3c0 100644 --- a/taichi/ir/ir.cpp +++ b/taichi/ir/ir.cpp @@ -278,6 +278,15 @@ bool Stmt::have_operand(Stmt *stmt) const { return false; } +int Stmt::locate_operand(Stmt **stmt) { + for (int i = 0; i < num_operands(); i++) { + if (operands[i] == stmt) { + return i; + } + } + return -1; +} + std::string Expression::get_attribute(const std::string &key) const { if (auto it = attributes.find(key); it == attributes.end()) { TI_ERROR("Attribute {} not found.", key); diff --git a/taichi/ir/ir.h b/taichi/ir/ir.h index a88d6df17fe23..cce7c0cbb7119 100644 --- a/taichi/ir/ir.h +++ b/taichi/ir/ir.h @@ -660,6 +660,7 @@ class Stmt : public IRNode { void set_operand(int i, Stmt *stmt); void register_operand(Stmt *&stmt); + int locate_operand(Stmt **stmt); void mark_fields_registered(); virtual void rebuild_operands() { diff --git a/taichi/transforms/offload.cpp b/taichi/transforms/offload.cpp index 7b43bf4eff244..c8a80b216071b 100644 --- a/taichi/transforms/offload.cpp +++ b/taichi/transforms/offload.cpp @@ -1,5 +1,6 @@ #include #include +#include #include "taichi/ir/ir.h" @@ -345,7 +346,7 @@ class FixCrossOffloadReferences : public BasicStmtVisitor { const StmtToOffsetMap &local_to_global_offset, std::unordered_map stmt_to_offloaded) : local_to_global_offset(local_to_global_offset), - stmt_to_offloaded(stmt_to_offloaded) { + stmt_to_offloaded(std::move(stmt_to_offloaded)) { allow_undefined_visitor = true; invoke_default_visitor = true; } @@ -399,6 +400,8 @@ class FixCrossOffloadReferences : public BasicStmtVisitor { } void visit(LocalStoreStmt *stmt) override { + if (visit_operand(stmt, stmt->locate_operand(&stmt->data))) + throw IRModified(); TI_ASSERT(stmt->width() == 1); auto alloca = stmt->ptr; if (local_to_global_offset.find(alloca) == local_to_global_offset.end()) @@ -416,6 +419,8 @@ class FixCrossOffloadReferences : public BasicStmtVisitor { } void visit(AtomicOpStmt *stmt) override { + if
(visit_operand(stmt, stmt->locate_operand(&stmt->val))) + throw IRModified(); TI_ASSERT(stmt->width() == 1); auto alloca = stmt->dest; if (local_to_global_offset.find(alloca) == local_to_global_offset.end()) @@ -432,38 +437,45 @@ class FixCrossOffloadReferences : public BasicStmtVisitor { throw IRModified(); } + bool visit_operand(Stmt *stmt, int index) { + // return true if modified + TI_ASSERT(index >= 0 && index < stmt->num_operands()); + auto op = stmt->operand(index); + if (op == nullptr) + return false; + if (stmt_to_offloaded[stmt] == + stmt_to_offloaded[op]) // same OffloadedStmt + return false; + if (advanced_optimization) { + if (op->is()) { + auto copy = op->as()->copy(); + stmt_to_offloaded[copy.get()] = stmt_to_offloaded[stmt]; + stmt->set_operand(index, copy.get()); + stmt->insert_before_me(std::move(copy)); + return true; + } + } + if (local_to_global_offset.find(op) == local_to_global_offset.end()) + return false; + + auto global = Stmt::make(local_to_global_offset[op], + op->ret_type); + auto load = Stmt::make(global.get()); + stmt_to_offloaded[load.get()] = stmt_to_offloaded[stmt]; + stmt->set_operand(index, load.get()); + stmt->insert_before_me(std::move(global)); + stmt->insert_before_me(std::move(load)); + return true; + } + // Generic visitor void visit(Stmt *stmt) override { TI_ASSERT(stmt->width() == 1); int n_op = stmt->num_operands(); bool modified = false; for (int i = 0; i < n_op; i++) { - auto op = stmt->operand(i); - if (op == nullptr) - continue; - if (stmt_to_offloaded[stmt] == - stmt_to_offloaded[op]) // same OffloadedStmt - continue; - if (advanced_optimization) { - if (op->is()) { - auto copy = op->as()->copy(); - stmt_to_offloaded[copy.get()] = stmt_to_offloaded[stmt]; - stmt->set_operand(i, copy.get()); - stmt->insert_before_me(std::move(copy)); - modified = true; - continue; - } - } - if (local_to_global_offset.find(op) == local_to_global_offset.end()) - continue; - - auto global = Stmt::make(local_to_global_offset[op], -
op->ret_type); - auto load = Stmt::make(global.get()); - stmt->set_operand(i, load.get()); - stmt->insert_before_me(std::move(global)); - stmt->insert_before_me(std::move(load)); - modified = true; + if (visit_operand(stmt, i)) + modified = true; } if (modified) throw IRModified();