Skip to content

Commit

Permalink
[autodiff] Support shift ptr in dynamic index (#5770)
Browse files Browse the repository at this point in the history
* [autodiff] Support shift ptr in dynamic index

* update the offset

* update offset

* add dynamic index test for ad

* forcibly enable dynamic index for polar decompose
  • Loading branch information
erizmr authored Aug 18, 2022
1 parent 19c0272 commit ce86a1d
Show file tree
Hide file tree
Showing 3 changed files with 123 additions and 18 deletions.
111 changes: 94 additions & 17 deletions taichi/transforms/auto_diff.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,14 @@ class IndependentBlocksJudger : public BasicStmtVisitor {

void visit(LocalLoadStmt *stmt) override {
for (auto &lane : stmt->src.data) {
touched_allocas_.insert(lane.var->as<AllocaStmt>());
TI_ASSERT(lane.var->is<AllocaStmt>() || lane.var->is<PtrOffsetStmt>());
touched_allocas_.insert(lane.var);
}
}

void visit(LocalStoreStmt *stmt) override {
touched_allocas_.insert(stmt->dest->as<AllocaStmt>());
TI_ASSERT(stmt->dest->is<AllocaStmt>() || stmt->dest->is<PtrOffsetStmt>());
touched_allocas_.insert(stmt->dest);
}

void visit(AtomicOpStmt *stmt) override {
Expand Down Expand Up @@ -75,7 +77,7 @@ class IndependentBlocksJudger : public BasicStmtVisitor {
}

private:
std::set<AllocaStmt *> touched_allocas_;
std::set<Stmt *> touched_allocas_;
bool qualified_atomics_ = true;
bool inner_most_loop_ = true;
bool is_inside_loop_ = false;
Expand Down Expand Up @@ -578,6 +580,10 @@ class ADTransform : public IRVisitor {
// do nothing.
}

// Intentionally a no-op for PtrOffsetStmt.
// NOTE(review): the global load/store/atomic visitors of MakeAdjoint and
// MakeDual visible in this file read the statement's `origin` and `offset`
// themselves when it is their pointer operand, and rebuild an equivalent
// PtrOffsetStmt on the adjoint/dual pointer, so no separate handling appears
// to be needed here -- confirm no other subclass relies on this visit.
void visit(PtrOffsetStmt *stmt) override {
// do nothing.
}

void visit(PrintStmt *print_stmt) override {
// do nothing
}
Expand Down Expand Up @@ -994,7 +1000,16 @@ class MakeAdjoint : public ADTransform {
"Importing data from external array (such as numpy array) not "
"supported in AutoDiff for now")
}
GlobalPtrStmt *src = stmt->src->as<GlobalPtrStmt>();

GlobalPtrStmt *src = nullptr;
bool is_ptr_offset = false;
if (stmt->src->is<PtrOffsetStmt>()) {
is_ptr_offset = true;
src = stmt->src->as<PtrOffsetStmt>()->origin->as<GlobalPtrStmt>();
} else {
src = stmt->src->as<GlobalPtrStmt>();
}

TI_ASSERT(src->width() == 1);
auto snodes = src->snodes;
if (!snodes[0]->has_adjoint()) {
Expand All @@ -1008,6 +1023,10 @@ class MakeAdjoint : public ADTransform {
TI_ASSERT(snodes[0]->get_adjoint() != nullptr);
snodes[0] = snodes[0]->get_adjoint();
auto adj_ptr = insert<GlobalPtrStmt>(snodes, src->indices);
if (is_ptr_offset) {
adj_ptr = insert<PtrOffsetStmt>(adj_ptr,
stmt->src->as<PtrOffsetStmt>()->offset);
}
insert<AtomicOpStmt>(AtomicOpType::add, adj_ptr, load(adjoint(stmt)));
}

Expand All @@ -1018,7 +1037,16 @@ class MakeAdjoint : public ADTransform {
"Exporting data to external array (such as numpy array) not "
"supported in AutoDiff for now")
}
GlobalPtrStmt *dest = stmt->dest->as<GlobalPtrStmt>();

GlobalPtrStmt *dest = nullptr;
bool is_ptr_offset = false;
if (stmt->dest->is<PtrOffsetStmt>()) {
is_ptr_offset = true;
dest = stmt->dest->as<PtrOffsetStmt>()->origin->as<GlobalPtrStmt>();
} else {
dest = stmt->dest->as<GlobalPtrStmt>();
}

TI_ASSERT(dest->width() == 1);
auto snodes = dest->snodes;
if (!snodes[0]->has_adjoint()) {
Expand All @@ -1028,24 +1056,40 @@ class MakeAdjoint : public ADTransform {
TI_ASSERT(snodes[0]->get_adjoint() != nullptr);
snodes[0] = snodes[0]->get_adjoint();
auto adjoint_ptr = insert<GlobalPtrStmt>(snodes, dest->indices);
auto load = insert<GlobalLoadStmt>(adjoint_ptr);
accumulate(stmt->val, load);
if (is_ptr_offset) {
adjoint_ptr = insert<PtrOffsetStmt>(
adjoint_ptr, stmt->dest->as<PtrOffsetStmt>()->offset);
}
accumulate(stmt->val, insert<GlobalLoadStmt>(adjoint_ptr));
stmt->parent->erase(stmt);
}

void visit(AtomicOpStmt *stmt) override {
// erase and replace with global load adjoint
GlobalPtrStmt *dest = stmt->dest->as<GlobalPtrStmt>();
GlobalPtrStmt *dest = nullptr;
bool is_ptr_offset = false;
if (stmt->dest->is<PtrOffsetStmt>()) {
is_ptr_offset = true;
dest = stmt->dest->as<PtrOffsetStmt>()->origin->as<GlobalPtrStmt>();
} else {
dest = stmt->dest->as<GlobalPtrStmt>();
}

TI_ASSERT(dest->width() == 1);
auto snodes = dest->snodes;
if (snodes[0]->has_adjoint()) {
TI_ASSERT(snodes[0]->get_adjoint() != nullptr);
snodes[0] = snodes[0]->get_adjoint();
auto adjoint_ptr = insert<GlobalPtrStmt>(snodes, dest->indices);
accumulate(stmt->val, insert<GlobalLoadStmt>(adjoint_ptr));
} else {
if (!snodes[0]->has_adjoint()) {
// no gradient (likely integer types)
return;
}

TI_ASSERT(snodes[0]->get_adjoint() != nullptr);
snodes[0] = snodes[0]->get_adjoint();
auto adjoint_ptr = insert<GlobalPtrStmt>(snodes, dest->indices);
if (is_ptr_offset) {
adjoint_ptr = insert<PtrOffsetStmt>(
adjoint_ptr, stmt->dest->as<PtrOffsetStmt>()->offset);
}
accumulate(stmt->val, insert<GlobalLoadStmt>(adjoint_ptr));
stmt->parent->erase(stmt);
}
};
Expand Down Expand Up @@ -1288,7 +1332,14 @@ class MakeDual : public ADTransform {

void visit(GlobalLoadStmt *stmt) override {
// issue global store to dual
GlobalPtrStmt *src = stmt->src->as<GlobalPtrStmt>();
GlobalPtrStmt *src = nullptr;
bool is_ptr_offset = false;
if (stmt->src->is<PtrOffsetStmt>()) {
is_ptr_offset = true;
src = stmt->src->as<PtrOffsetStmt>()->origin->as<GlobalPtrStmt>();
} else {
src = stmt->src->as<GlobalPtrStmt>();
}
TI_ASSERT(src->width() == 1);
auto snodes = src->snodes;
if (!snodes[0]->has_dual()) {
Expand All @@ -1302,11 +1353,22 @@ class MakeDual : public ADTransform {
TI_ASSERT(snodes[0]->get_dual() != nullptr);
snodes[0] = snodes[0]->get_dual();
auto dual_ptr = insert<GlobalPtrStmt>(snodes, src->indices);
if (is_ptr_offset) {
dual_ptr = insert<PtrOffsetStmt>(dual_ptr,
stmt->src->as<PtrOffsetStmt>()->offset);
}
accumulate(stmt, insert<GlobalLoadStmt>(dual_ptr));
}

void visit(GlobalStoreStmt *stmt) override {
GlobalPtrStmt *dest = stmt->dest->as<GlobalPtrStmt>();
GlobalPtrStmt *dest = nullptr;
bool is_ptr_offset = false;
if (stmt->dest->is<PtrOffsetStmt>()) {
is_ptr_offset = true;
dest = stmt->dest->as<PtrOffsetStmt>()->origin->as<GlobalPtrStmt>();
} else {
dest = stmt->dest->as<GlobalPtrStmt>();
}
TI_ASSERT(dest->width() == 1);
auto snodes = dest->snodes;
if (!snodes[0]->has_dual()) {
Expand All @@ -1316,11 +1378,22 @@ class MakeDual : public ADTransform {
TI_ASSERT(snodes[0]->get_dual() != nullptr);
snodes[0] = snodes[0]->get_dual();
auto dual_ptr = insert<GlobalPtrStmt>(snodes, dest->indices);
if (is_ptr_offset) {
dual_ptr = insert<PtrOffsetStmt>(dual_ptr,
stmt->dest->as<PtrOffsetStmt>()->offset);
}
insert<AtomicOpStmt>(AtomicOpType::add, dual_ptr, load(dual(stmt->val)));
}

void visit(AtomicOpStmt *stmt) override {
GlobalPtrStmt *dest = stmt->dest->as<GlobalPtrStmt>();
GlobalPtrStmt *dest = nullptr;
bool is_ptr_offset = false;
if (stmt->dest->is<PtrOffsetStmt>()) {
is_ptr_offset = true;
dest = stmt->dest->as<PtrOffsetStmt>()->origin->as<GlobalPtrStmt>();
} else {
dest = stmt->dest->as<GlobalPtrStmt>();
}
TI_ASSERT(dest->width() == 1);
auto snodes = dest->snodes;
if (!snodes[0]->has_dual()) {
Expand All @@ -1330,6 +1403,10 @@ class MakeDual : public ADTransform {
TI_ASSERT(snodes[0]->get_dual() != nullptr);
snodes[0] = snodes[0]->get_dual();
auto dual_ptr = insert<GlobalPtrStmt>(snodes, dest->indices);
if (is_ptr_offset) {
dual_ptr = insert<PtrOffsetStmt>(dual_ptr,
stmt->dest->as<PtrOffsetStmt>()->offset);
}
insert<AtomicOpStmt>(AtomicOpType::add, dual_ptr, load(dual(stmt->val)));
}
};
Expand Down
28 changes: 28 additions & 0 deletions tests/python/test_ad_dynamic_index.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
import taichi as ti
from tests import test_utils


@test_utils.test(require=ti.extension.dynamic_index,
                 dynamic_index=True,
                 debug=True)
def test_matrix_non_constant_index():
    """Reverse-mode autodiff through matrix fields indexed by loop variables.

    The kernel writes ``dst[i][j, k] = (j + 1) * (k + 1) * src[i][j, k]`` and
    sums every ``dst`` entry into a scalar loss, so after running the gradient
    kernel each ``src.grad[i][j, k]`` must equal ``(j + 1) * (k + 1)``.
    """
    dst = ti.Matrix.field(2, 2, ti.f32, 5, needs_grad=True)
    src = ti.Matrix.field(2, 2, ti.f32, 5, needs_grad=True)
    total = ti.field(ti.f32, (), needs_grad=True)

    src.fill(0)

    @ti.kernel
    def scale_and_sum():
        for cell in range(5):
            for row, col in ti.ndrange(2, 2):
                # Matrix entries are addressed with non-constant indices here.
                dst[cell][row, col] = (row + 1) * (col + 1) * src[cell][row, col]
                total[None] += dst[cell][row, col]

    # Seed the loss gradient, then back-propagate through the kernel.
    total.grad[None] = 1.0
    scale_and_sum.grad()

    for cell in range(5):
        for row in range(2):
            for col in range(2):
                expected = (row + 1) * (col + 1)
                assert src.grad[cell][row, col] == expected
2 changes: 1 addition & 1 deletion tests/python/test_ad_math_func.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from tests import test_utils


@test_utils.test(require=ti.extension.adstack, dynamic_index=False)
@test_utils.test(require=ti.extension.adstack, dynamic_index=True)
def test_polar_decompose_2D():
# `polar_decompose3d` in current Taichi version (v1.1) does not support autodiff,
# because it mixes for-loops with statements outside of loops.
Expand Down

0 comments on commit ce86a1d

Please sign in to comment.