From fe17ca1f71216dd88e6da11b68b97af964fa4061 Mon Sep 17 00:00:00 2001
From: xumingkuan <xumingkuan0721@126.com>
Date: Sat, 4 Jul 2020 01:03:33 -0400
Subject: [PATCH] [Bug] [ir] Fix compilation crash when there's a cross-offload global atomic operation (#1392)

* [Bug] [ir] Fix compilation crash when there's a cross-offload global atomic operation

* add a test

* OFT

Co-authored-by: archibate <1931127624@qq.com>
---
 taichi/transforms/offload.cpp      | 24 ++++++++++++------------
 tests/python/test_offload_cross.py | 20 +++++++++++++++++++-
 2 files changed, 31 insertions(+), 13 deletions(-)

diff --git a/taichi/transforms/offload.cpp b/taichi/transforms/offload.cpp
index 2b8ec3b0a4b36..df236786d9a90 100644
--- a/taichi/transforms/offload.cpp
+++ b/taichi/transforms/offload.cpp
@@ -406,6 +406,10 @@ class FixCrossOffloadReferences : public BasicStmtVisitor {
   }
 
   void visit(AtomicOpStmt *stmt) override {
+    if (!stmt->dest->is<AllocaStmt>()) {
+      generic_visit(stmt);
+      return;
+    }
     if (visit_operand(stmt, stmt->locate_operand(&stmt->val)))
       throw IRModified();
     TI_ASSERT(stmt->width() == 1);
@@ -442,8 +446,8 @@ class FixCrossOffloadReferences : public BasicStmtVisitor {
       }
     }
     if (op->is<GlobalPtrStmt>()) {
-      TI_ASSERT(!op->has_global_side_effect());
       auto copy = op->clone();
+      copy->as<GlobalPtrStmt>()->activate = false;
       stmt_to_offloaded[copy.get()] = stmt_to_offloaded[stmt];
       stmt->set_operand(index, copy.get());
       stmt->insert_before_me(std::move(copy));
@@ -462,9 +466,7 @@ class FixCrossOffloadReferences : public BasicStmtVisitor {
     return true;
   }
 
-  // Generic visitor
-  void visit(Stmt *stmt) override {
-    TI_ASSERT(stmt->width() == 1);
+  void generic_visit(Stmt *stmt) {
     int n_op = stmt->num_operands();
     bool modified = false;
     for (int i = 0; i < n_op; i++) {
@@ -475,15 +477,13 @@ class FixCrossOffloadReferences : public BasicStmtVisitor {
       throw IRModified();
   }
 
+  void visit(Stmt *stmt) override {
+    TI_ASSERT(stmt->width() == 1);
+    generic_visit(stmt);
+  }
+
   void preprocess_container_stmt(Stmt *stmt) override {
-    int n_op = stmt->num_operands();
-    bool modified = false;
-    for (int i = 0; i < n_op; i++) {
-      if (visit_operand(stmt, i))
-        modified = true;
-    }
-    if (modified)
-      throw IRModified();
+    generic_visit(stmt);
   }
 
  public:
diff --git a/tests/python/test_offload_cross.py b/tests/python/test_offload_cross.py
index bdf3508655293..3f6bacc99c784 100644
--- a/tests/python/test_offload_cross.py
+++ b/tests/python/test_offload_cross.py
@@ -74,7 +74,7 @@ def ker():
     assert ret[None] == 10
 
 
-@ti.archs_excluding(ti.opengl)  # OpenGL doesn't support dynamic range for now
+@ti.all_archs
 def test_offload_with_flexible_bounds():
     s = ti.var(ti.i32, shape=())
     lower = ti.var(ti.i32, shape=())
@@ -90,3 +90,21 @@ def ker():
     ker()
 
     assert s[None] == 29 * 10 // 2
+
+
+@ti.all_archs
+def test_offload_with_cross_block_globals():
+    ret = ti.var(ti.f32)
+
+    ti.root.place(ret)
+
+    @ti.kernel
+    def ker():
+        ret[None] = 0
+        for i in range(10):
+            ret[None] += i
+        ret[None] += 1
+
+    ker()
+
+    assert ret[None] == 46