diff --git a/taichi/transforms/auto_diff.cpp b/taichi/transforms/auto_diff.cpp
index 04cea51486769..d5bb2bfee93a1 100644
--- a/taichi/transforms/auto_diff.cpp
+++ b/taichi/transforms/auto_diff.cpp
@@ -14,12 +14,14 @@ class IndependentBlocksJudger : public BasicStmtVisitor {
 
   void visit(LocalLoadStmt *stmt) override {
     for (auto &lane : stmt->src.data) {
-      touched_allocas_.insert(lane.var->as<AllocaStmt>());
+      TI_ASSERT(lane.var->is<AllocaStmt>() || lane.var->is<PtrOffsetStmt>());
+      touched_allocas_.insert(lane.var);
     }
   }
 
   void visit(LocalStoreStmt *stmt) override {
-    touched_allocas_.insert(stmt->dest->as<AllocaStmt>());
+    TI_ASSERT(stmt->dest->is<AllocaStmt>() || stmt->dest->is<PtrOffsetStmt>());
+    touched_allocas_.insert(stmt->dest);
   }
 
   void visit(AtomicOpStmt *stmt) override {
@@ -75,7 +77,7 @@ class IndependentBlocksJudger : public BasicStmtVisitor {
   }
 
  private:
-  std::set<AllocaStmt *> touched_allocas_;
+  std::set<Stmt *> touched_allocas_;
   bool qualified_atomics_ = true;
   bool inner_most_loop_ = true;
   bool is_inside_loop_ = false;
@@ -578,6 +580,10 @@ class ADTransform : public IRVisitor {
     // do nothing.
   }
 
+  void visit(PtrOffsetStmt *stmt) override {
+    // do nothing.
+  }
+
   void visit(PrintStmt *print_stmt) override {
     // do nothing
   }
@@ -994,7 +1000,16 @@ class MakeAdjoint : public ADTransform {
           "Importing data from external array (such as numpy array) not "
           "supported in AutoDiff for now")
     }
-    GlobalPtrStmt *src = stmt->src->as<GlobalPtrStmt>();
+
+    GlobalPtrStmt *src = nullptr;
+    bool is_ptr_offset = false;
+    if (stmt->src->is<PtrOffsetStmt>()) {
+      is_ptr_offset = true;
+      src = stmt->src->as<PtrOffsetStmt>()->origin->as<GlobalPtrStmt>();
+    } else {
+      src = stmt->src->as<GlobalPtrStmt>();
+    }
+
     TI_ASSERT(src->width() == 1);
     auto snodes = src->snodes;
     if (!snodes[0]->has_adjoint()) {
@@ -1008,6 +1023,10 @@ class MakeAdjoint : public ADTransform {
     TI_ASSERT(snodes[0]->get_adjoint() != nullptr);
     snodes[0] = snodes[0]->get_adjoint();
     auto adj_ptr = insert<GlobalPtrStmt>(snodes, src->indices);
+    if (is_ptr_offset) {
+      adj_ptr = insert<PtrOffsetStmt>(adj_ptr,
+                                      stmt->src->as<PtrOffsetStmt>()->offset);
+    }
     insert<AtomicOpStmt>(AtomicOpType::add, adj_ptr, load(adjoint(stmt)));
   }
 
@@ -1018,7 +1037,16 @@ class MakeAdjoint : public ADTransform {
           "Exporting data to external array (such as numpy array) not "
           "supported in AutoDiff for now")
     }
-    GlobalPtrStmt *dest = stmt->dest->as<GlobalPtrStmt>();
+
+    GlobalPtrStmt *dest = nullptr;
+    bool is_ptr_offset = false;
+    if (stmt->dest->is<PtrOffsetStmt>()) {
+      is_ptr_offset = true;
+      dest = stmt->dest->as<PtrOffsetStmt>()->origin->as<GlobalPtrStmt>();
+    } else {
+      dest = stmt->dest->as<GlobalPtrStmt>();
+    }
+
     TI_ASSERT(dest->width() == 1);
     auto snodes = dest->snodes;
     if (!snodes[0]->has_adjoint()) {
@@ -1028,24 +1056,40 @@ class MakeAdjoint : public ADTransform {
     TI_ASSERT(snodes[0]->get_adjoint() != nullptr);
     snodes[0] = snodes[0]->get_adjoint();
     auto adjoint_ptr = insert<GlobalPtrStmt>(snodes, dest->indices);
-    auto load = insert<GlobalLoadStmt>(adjoint_ptr);
-    accumulate(stmt->val, load);
+    if (is_ptr_offset) {
+      adjoint_ptr = insert<PtrOffsetStmt>(
+          adjoint_ptr, stmt->dest->as<PtrOffsetStmt>()->offset);
+    }
+    accumulate(stmt->val, insert<GlobalLoadStmt>(adjoint_ptr));
     stmt->parent->erase(stmt);
   }
 
   void visit(AtomicOpStmt *stmt) override {
     // erase and replace with global load adjoint
-    GlobalPtrStmt *dest = stmt->dest->as<GlobalPtrStmt>();
+    GlobalPtrStmt *dest = nullptr;
+    bool is_ptr_offset = false;
+    if (stmt->dest->is<PtrOffsetStmt>()) {
+      is_ptr_offset = true;
+      dest = stmt->dest->as<PtrOffsetStmt>()->origin->as<GlobalPtrStmt>();
+    } else {
+      dest = stmt->dest->as<GlobalPtrStmt>();
+    }
+
     TI_ASSERT(dest->width() == 1);
     auto snodes = dest->snodes;
-    if (snodes[0]->has_adjoint()) {
-      TI_ASSERT(snodes[0]->get_adjoint() != nullptr);
-      snodes[0] = snodes[0]->get_adjoint();
-      auto adjoint_ptr = insert<GlobalPtrStmt>(snodes, dest->indices);
-      accumulate(stmt->val, insert<GlobalLoadStmt>(adjoint_ptr));
-    } else {
+    if (!snodes[0]->has_adjoint()) {
       // no gradient (likely integer types)
+      return;
+    }
+
+    TI_ASSERT(snodes[0]->get_adjoint() != nullptr);
+    snodes[0] = snodes[0]->get_adjoint();
+    auto adjoint_ptr = insert<GlobalPtrStmt>(snodes, dest->indices);
+    if (is_ptr_offset) {
+      adjoint_ptr = insert<PtrOffsetStmt>(
+          adjoint_ptr, stmt->dest->as<PtrOffsetStmt>()->offset);
     }
+    accumulate(stmt->val, insert<GlobalLoadStmt>(adjoint_ptr));
     stmt->parent->erase(stmt);
   }
 };
@@ -1288,7 +1332,14 @@ class MakeDual : public ADTransform {
 
   void visit(GlobalLoadStmt *stmt) override {
     // issue global store to dual
-    GlobalPtrStmt *src = stmt->src->as<GlobalPtrStmt>();
+    GlobalPtrStmt *src = nullptr;
+    bool is_ptr_offset = false;
+    if (stmt->src->is<PtrOffsetStmt>()) {
+      is_ptr_offset = true;
+      src = stmt->src->as<PtrOffsetStmt>()->origin->as<GlobalPtrStmt>();
+    } else {
+      src = stmt->src->as<GlobalPtrStmt>();
+    }
     TI_ASSERT(src->width() == 1);
     auto snodes = src->snodes;
     if (!snodes[0]->has_dual()) {
@@ -1302,11 +1353,22 @@ class MakeDual : public ADTransform {
     TI_ASSERT(snodes[0]->get_dual() != nullptr);
     snodes[0] = snodes[0]->get_dual();
     auto dual_ptr = insert<GlobalPtrStmt>(snodes, src->indices);
+    if (is_ptr_offset) {
+      dual_ptr = insert<PtrOffsetStmt>(dual_ptr,
+                                       stmt->src->as<PtrOffsetStmt>()->offset);
+    }
     accumulate(stmt, insert<GlobalLoadStmt>(dual_ptr));
   }
 
   void visit(GlobalStoreStmt *stmt) override {
-    GlobalPtrStmt *dest = stmt->dest->as<GlobalPtrStmt>();
+    GlobalPtrStmt *dest = nullptr;
+    bool is_ptr_offset = false;
+    if (stmt->dest->is<PtrOffsetStmt>()) {
+      is_ptr_offset = true;
+      dest = stmt->dest->as<PtrOffsetStmt>()->origin->as<GlobalPtrStmt>();
+    } else {
+      dest = stmt->dest->as<GlobalPtrStmt>();
+    }
     TI_ASSERT(dest->width() == 1);
     auto snodes = dest->snodes;
     if (!snodes[0]->has_dual()) {
@@ -1316,11 +1378,22 @@ class MakeDual : public ADTransform {
     TI_ASSERT(snodes[0]->get_dual() != nullptr);
     snodes[0] = snodes[0]->get_dual();
     auto dual_ptr = insert<GlobalPtrStmt>(snodes, dest->indices);
+    if (is_ptr_offset) {
+      dual_ptr = insert<PtrOffsetStmt>(dual_ptr,
+                                       stmt->dest->as<PtrOffsetStmt>()->offset);
+    }
     insert<AtomicOpStmt>(AtomicOpType::add, dual_ptr, load(dual(stmt->val)));
   }
 
   void visit(AtomicOpStmt *stmt) override {
-    GlobalPtrStmt *dest = stmt->dest->as<GlobalPtrStmt>();
+    GlobalPtrStmt *dest = nullptr;
+    bool is_ptr_offset = false;
+    if (stmt->dest->is<PtrOffsetStmt>()) {
+      is_ptr_offset = true;
+      dest = stmt->dest->as<PtrOffsetStmt>()->origin->as<GlobalPtrStmt>();
+    } else {
+      dest = stmt->dest->as<GlobalPtrStmt>();
+    }
     TI_ASSERT(dest->width() == 1);
     auto snodes = dest->snodes;
     if (!snodes[0]->has_dual()) {
@@ -1330,6 +1403,10 @@ class MakeDual : public ADTransform {
     TI_ASSERT(snodes[0]->get_dual() != nullptr);
     snodes[0] = snodes[0]->get_dual();
     auto dual_ptr = insert<GlobalPtrStmt>(snodes, dest->indices);
+    if (is_ptr_offset) {
+      dual_ptr = insert<PtrOffsetStmt>(dual_ptr,
+                                       stmt->dest->as<PtrOffsetStmt>()->offset);
+    }
     insert<AtomicOpStmt>(AtomicOpType::add, dual_ptr, load(dual(stmt->val)));
   }
 };
diff --git a/tests/python/test_ad_dynamic_index.py b/tests/python/test_ad_dynamic_index.py
new file mode 100644
index 0000000000000..fc3e504a9978f
--- /dev/null
+++ b/tests/python/test_ad_dynamic_index.py
@@ -0,0 +1,28 @@
+import taichi as ti
+from tests import test_utils
+
+
+@test_utils.test(require=ti.extension.dynamic_index,
+                 dynamic_index=True,
+                 debug=True)
+def test_matrix_non_constant_index():
+    m = ti.Matrix.field(2, 2, ti.f32, 5, needs_grad=True)
+    n = ti.Matrix.field(2, 2, ti.f32, 5, needs_grad=True)
+    loss = ti.field(ti.f32, (), needs_grad=True)
+
+    n.fill(0)
+
+    @ti.kernel
+    def func1():
+        for i in range(5):
+            for j, k in ti.ndrange(2, 2):
+                m[i][j, k] = (j + 1) * (k + 1) * n[i][j, k]
+                loss[None] += m[i][j, k]
+
+    loss.grad[None] = 1.0
+    func1.grad()
+
+    for i in range(5):
+        for j in range(2):
+            for k in range(2):
+                assert n.grad[i][j, k] == (j + 1) * (k + 1)
diff --git a/tests/python/test_ad_math_func.py b/tests/python/test_ad_math_func.py
index 34192d5d9beb9..63be38b560f19 100644
--- a/tests/python/test_ad_math_func.py
+++ b/tests/python/test_ad_math_func.py
@@ -2,7 +2,7 @@
 from tests import test_utils
 
 
-@test_utils.test(require=ti.extension.adstack, dynamic_index=False)
+@test_utils.test(require=ti.extension.adstack, dynamic_index=True)
 def test_polar_decompose_2D():
     # `polar_decompose3d` in current Taichi version (v1.1) does not support autodiff,
     # becasue it mixed usage of for-loops and statements without looping.
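
For illustration only (not part of the patch): with this change, a matrix or
vector element addressed by a runtime index can participate in autodiff. Such
a dynamic access lowers to a PtrOffsetStmt, which the visitors above peel back
to the underlying GlobalPtrStmt before emitting the adjoint/dual access. Below
is a minimal standalone sketch, assuming Taichi ~v1.1 on an LLVM backend; the
field and kernel names (`v`, `loss`, `pick`) are hypothetical:

import taichi as ti

ti.init(debug=True, dynamic_index=True)

v = ti.Vector.field(3, ti.f32, shape=(), needs_grad=True)
loss = ti.field(ti.f32, shape=(), needs_grad=True)

@ti.kernel
def pick(idx: ti.i32):
    # `idx` is only known at runtime, so v[None][idx] lowers to a
    # PtrOffsetStmt rather than a constant-offset GlobalPtrStmt.
    loss[None] += 2.0 * v[None][idx]

loss.grad[None] = 1.0
pick.grad(1)  # reverse-mode pass generated by MakeAdjoint
assert v.grad[None][1] == 2.0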