Skip to content

Commit

Permalink
[Bug] [ir] Fix compilation crash when there's a cross-offload global …
Browse files Browse the repository at this point in the history
…atomic operation (#1392)

* [Bug] [ir] Fix compilation crash when there's a cross-offload global atomic operation

* add a test

* OFT

Co-authored-by: archibate <[email protected]>
  • Loading branch information
xumingkuan and archibate authored Jul 4, 2020
1 parent 1c1dffc commit fe17ca1
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 13 deletions.
24 changes: 12 additions & 12 deletions taichi/transforms/offload.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -406,6 +406,10 @@ class FixCrossOffloadReferences : public BasicStmtVisitor {
}

void visit(AtomicOpStmt *stmt) override {
if (!stmt->dest->is<AllocaStmt>()) {
generic_visit(stmt);
return;
}
if (visit_operand(stmt, stmt->locate_operand(&stmt->val)))
throw IRModified();
TI_ASSERT(stmt->width() == 1);
Expand Down Expand Up @@ -442,8 +446,8 @@ class FixCrossOffloadReferences : public BasicStmtVisitor {
}
}
if (op->is<GlobalPtrStmt>()) {
TI_ASSERT(!op->has_global_side_effect());
auto copy = op->clone();
copy->as<GlobalPtrStmt>()->activate = false;
stmt_to_offloaded[copy.get()] = stmt_to_offloaded[stmt];
stmt->set_operand(index, copy.get());
stmt->insert_before_me(std::move(copy));
Expand All @@ -462,9 +466,7 @@ class FixCrossOffloadReferences : public BasicStmtVisitor {
return true;
}

// Generic visitor
void visit(Stmt *stmt) override {
TI_ASSERT(stmt->width() == 1);
void generic_visit(Stmt *stmt) {
int n_op = stmt->num_operands();
bool modified = false;
for (int i = 0; i < n_op; i++) {
Expand All @@ -475,15 +477,13 @@ class FixCrossOffloadReferences : public BasicStmtVisitor {
throw IRModified();
}

void visit(Stmt *stmt) override {
TI_ASSERT(stmt->width() == 1);
generic_visit(stmt);
}

void preprocess_container_stmt(Stmt *stmt) override {
int n_op = stmt->num_operands();
bool modified = false;
for (int i = 0; i < n_op; i++) {
if (visit_operand(stmt, i))
modified = true;
}
if (modified)
throw IRModified();
generic_visit(stmt);
}

public:
Expand Down
20 changes: 19 additions & 1 deletion tests/python/test_offload_cross.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def ker():
assert ret[None] == 10


@ti.archs_excluding(ti.opengl) # OpenGL doesn't support dynamic range for now
@ti.all_archs
def test_offload_with_flexible_bounds():
s = ti.var(ti.i32, shape=())
lower = ti.var(ti.i32, shape=())
Expand All @@ -90,3 +90,21 @@ def ker():
ker()

assert s[None] == 29 * 10 // 2


@ti.all_archs
def test_offload_with_cross_block_globals():
ret = ti.var(ti.f32)

ti.root.place(ret)

@ti.kernel
def ker():
ret[None] = 0
for i in range(10):
ret[None] += i
ret[None] += 1

ker()

assert ret[None] == 46

0 comments on commit fe17ca1

Please sign in to comment.