Skip to content

Commit

Permalink
[autodiff] Support control flow for forward mode
Browse files Browse the repository at this point in the history
ghstack-source-id: 8ec583584dfc9a33c480361776076dd91c6fe491
Pull Request resolved: #5231
  • Loading branch information
erizmr authored and Ailing Zhang committed Jun 28, 2022
1 parent 39e67cc commit b8b099e
Show file tree
Hide file tree
Showing 2 changed files with 110 additions and 0 deletions.
27 changes: 27 additions & 0 deletions taichi/transforms/auto_diff.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1137,12 +1137,39 @@ class MakeDual : public ADTransform {
// d (x * y) = y * dx + x * dy
accumulate(bin, mul(bin->lhs, dual(bin->rhs)));
accumulate(bin, mul(bin->rhs, dual(bin->lhs)));
} else if (is_comparison(bin->op_type) || is_bit_op(bin->op_type)) {
// do nothing
} else {
TI_WARN("gradient of binary op {}", binary_op_type_name(bin->op_type));
TI_NOT_IMPLEMENTED
}
}

void visit(IfStmt *if_stmt) override {
if (if_stmt->true_statements) {
std::vector<Stmt *> true_statements;
for (auto &stmt : if_stmt->true_statements->statements) {
true_statements.push_back(stmt.get());
}

for (auto stmt : true_statements) {
current_stmt = stmt;
stmt->accept(this);
}
}
if (if_stmt->false_statements) {
std::vector<Stmt *> false_statements;
for (auto &stmt : if_stmt->false_statements->statements) {
false_statements.push_back(stmt.get());
}

for (auto stmt : false_statements) {
current_stmt = stmt;
stmt->accept(this);
}
}
}

void visit(RangeForStmt *for_stmt) override {
std::vector<Stmt *> statements;
// always make a copy since the list can be modified.
Expand Down
83 changes: 83 additions & 0 deletions tests/python/test_ad_if_fwd.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
from taichi.lang import impl
from taichi.lang.misc import get_host_arch_list

import taichi as ti
from tests import test_utils


@test_utils.test()
def test_ad_if_simple_fwd():
x = ti.field(ti.f32, shape=())
y = ti.field(ti.f32, shape=())
ti.root.lazy_dual()

@ti.kernel
def func():
if x[None] > 0.:
y[None] = x[None]

x[None] = 1
with ti.ad.FwdMode(loss=y, parameters=x, seed=[1.0]):
func()

assert y.dual[None] == 1


@test_utils.test()
def test_ad_if():
x = ti.field(ti.f32, shape=2)
y = ti.field(ti.f32, shape=2)
ti.root.lazy_dual()

@ti.kernel
def func(i: ti.i32):
if x[i] > 0:
y[i] = x[i]
else:
y[i] = 2 * x[i]

x[0] = 0
x[1] = 1
with ti.ad.FwdMode(loss=y, parameters=x, seed=[1.0, 1.0]):
func(0)
func(1)
assert y.dual[0] == 2
assert y.dual[1] == 1


@test_utils.test()
def test_ad_if_nested():
n = 20
x = ti.field(ti.f32, shape=n)
y = ti.field(ti.f32, shape=n)
z = ti.field(ti.f32, shape=n)
ti.root.lazy_dual()

@ti.kernel
def func():
for i in x:
if x[i] < 2:
if x[i] == 0:
y[i] = 0
else:
y[i] = z[i] * 1
else:
if x[i] == 2:
y[i] = z[i] * 2
else:
y[i] = z[i] * 3

z.fill(1)

for i in range(n):
x[i] = i % 4

func()
for i in range(n):
assert y[i] == i % 4

with ti.ad.FwdMode(loss=y, parameters=z, seed=[1.0 for _ in range(n)]):
func()

for i in range(n):
assert y.dual[i] == i % 4

0 comments on commit b8b099e

Please sign in to comment.