Skip to content

Commit

Permalink
[autodiff] Fix AdStackAllocaStmt not correctly backup (#5692)
Browse files Browse the repository at this point in the history
* [autodiff] Fix AdStackAllocaStmt not correctly backup

* remove redundant replace

* add comments

* add polar decompose test

* erase outdated AdStackAllocaStmt
  • Loading branch information
erizmr authored Aug 10, 2022
1 parent ca5bf7d commit eb0ab1c
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 0 deletions.
17 changes: 17 additions & 0 deletions taichi/transforms/auto_diff.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1377,6 +1377,23 @@ class BackupSSA : public BasicStmtVisitor {
if (op->is<AdStackLoadTopStmt>()) {
// Just create another AdStackLoadTopStmt
stmt->set_operand(i, stmt->insert_before_me(op->clone()));
} else if (op->is<AdStackAllocaStmt>()) {
// Backup AdStackAllocaStmt because it should not be local stored and
// local loaded
auto stack_alloca = op->as<AdStackAllocaStmt>();
if (backup_alloca.find(op) == backup_alloca.end()) {
auto backup_stack_alloca = Stmt::make<AdStackAllocaStmt>(
stack_alloca->dt, stack_alloca->max_size);
auto backup_stack_alloca_ptr = backup_stack_alloca.get();
independent_block->insert(std::move(backup_stack_alloca), 0);
backup_alloca[op] = backup_stack_alloca_ptr;
// Replace usages of all blocks i.e., the entry point for the
// replace is the top level block
irpass::replace_all_usages_with(leaf_to_root.back(), op,
backup_stack_alloca_ptr);
// Erase the outdated AdStackAllocaStmt
op->parent->erase(op);
}
} else {
auto alloca = load(op);
TI_ASSERT(op->width() == 1);
Expand Down
20 changes: 20 additions & 0 deletions tests/python/test_ad_math_func.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
import taichi as ti
from tests import test_utils


@test_utils.test(require=ti.extension.adstack)
def test_polar_decompose_2D():
# `polar_decompose3d` in current Taichi version (v1.1) does not support autodiff,
# becasue it mixed usage of for-loops and statements without looping.
dim = 2
F_1 = ti.Matrix.field(dim, dim, dtype=ti.f32, shape=(), needs_grad=True)
F = ti.Matrix.field(dim, dim, dtype=ti.f32, shape=(), needs_grad=True)
loss = ti.field(dtype=ti.f32, shape=(), needs_grad=True)

@ti.kernel
def polar_decompose_2D():
r, s = ti.polar_decompose(F[None])
F_1[None] += r

with ti.ad.Tape(loss=loss):
polar_decompose_2D()

0 comments on commit eb0ab1c

Please sign in to comment.