Skip to content

Commit

Permalink
[Bug] Avoid overwriting global tmp with dynamic_index=True (taichi-de…
Browse files Browse the repository at this point in the history
…v#6820)

Issue: fix taichi-dev#6663

### Brief Summary

In `MatrixPtrStmt`, when `origin` is `GlobalTemporaryStmt`, the
semantics of `offset` has changed from the number of bytes to the number
of elements. This PR fixes the outdated usage which may overwrite the
global tmp buffer.

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
2 people authored and quadpixels committed May 13, 2023
1 parent abe623b commit a3e0094
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 3 deletions.
5 changes: 2 additions & 3 deletions taichi/transforms/offload.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -535,9 +535,8 @@ class FixCrossOffloadReferences : public BasicStmtVisitor {
auto const_zero_stmt = replacement.push_back<ConstStmt>(zero);
stmt_to_offloaded_[const_zero_stmt] = offloaded;
for (int i = 0; i < tensor_type->get_num_elements(); ++i) {
TypedConstant offset(i *
data_type_size(tensor_type->get_element_type()));
auto const_offset_stmt = replacement.push_back<ConstStmt>(offset);
auto const_offset_stmt =
replacement.push_back<ConstStmt>(TypedConstant(i));
auto ptr_offset_stmt =
replacement.push_back<MatrixPtrStmt>(ptr, const_offset_stmt);
auto global_store_stmt = replacement.push_back<GlobalStoreStmt>(
Expand Down
22 changes: 22 additions & 0 deletions tests/python/test_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -1209,3 +1209,25 @@ def foo():
r"`transpose\(\)` cannot apply to a vector. If you want something like `a @ b.transpose\(\)`, write `a.outer_product\(b\)` instead."
):
foo()


@test_utils.test(require=ti.extension.dynamic_index,
dynamic_index=True,
debug=True)
def test_global_tmp_overwrite():
# https://github.com/taichi-dev/taichi/issues/6663
@ti.kernel
def foo() -> ti.i32:
p = ti.Matrix([[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]])
loop = 1
sig = ti.Vector([0, 0, 0, 0])
assert p[0, 0] == 1
while loop == 1:
assert p[0, 0] == 1
loop = 0
p[0, 0] = -1
for i in range(1):
sig[i] = 2
return sig.sum() + p.sum()

assert foo() == 4

0 comments on commit a3e0094

Please sign in to comment.