[autodiff] Enforce loss seed only set once in the tape
erizmr committed Apr 27, 2023
1 parent 748abdb commit fdbd83a
Showing 3 changed files with 69 additions and 10 deletions.
17 changes: 7 additions & 10 deletions python/taichi/ad/_ad.py
@@ -212,6 +212,7 @@ def __enter__(self):
                     "Gradients of loss are not allocated, please set needs_grad=True for all ndarrays that are required by autodiff."
                 )
             self.loss.fill(0.0)
+            self.loss.grad.fill(1.0)
         else:
             import torch  # pylint: disable=C0415
 
@@ -224,6 +225,12 @@ def __enter__(self):
             with torch.no_grad():
                 self.loss.fill_(0.0)
 
+            if self.loss.grad is None:
+                self.loss.grad = torch.ones_like(self.loss)
+            else:
+                with torch.no_grad():
+                    self.loss.grad.fill_(1.0)
+
         # Attach the context manager to runtime
         self.runtime.target_tape = self
         return self
@@ -256,16 +263,6 @@ def grad(self):
             # since we insert write_int and write_float kernels to self.calls
             # e.g. x[None] = 0.0, this func has no grad attribute
             if hasattr(func, "grad"):
-                if isinstance(self.loss, (Field, Ndarray)):
-                    self.loss.grad.fill(1.0)
-                else:
-                    import torch  # pylint: disable=C0415
-
-                    if self.loss.grad is None:
-                        self.loss.grad = torch.ones_like(self.loss)
-                    else:
-                        with torch.no_grad():
-                            self.loss.grad.fill_(1.0)
                 func.grad(*args)
 
         self.gradient_evaluated = True
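
For context, a minimal usage sketch (not part of this commit), assuming the standard ti.field / ti.ad.Tape API: with this change the tape seeds loss.grad with 1.0 once, when the context is entered, instead of re-filling it before each recorded kernel's grad pass, so the reported gradient corresponds to the final loss value left in loss when the tape runs backward (the new tests below assert exactly this).

import taichi as ti

ti.init(arch=ti.cpu)

x = ti.field(dtype=ti.f32, shape=(), needs_grad=True)
loss = ti.field(dtype=ti.f32, shape=(), needs_grad=True)

@ti.kernel
def compute(x: ti.template(), loss: ti.template()):
    loss[None] = 4 * x[None]

x[None] = 1.0
with ti.ad.Tape(loss=loss):  # __enter__ clears loss and now also seeds loss.grad to 1.0
    compute(x, loss)

print(x.grad[None])  # 4.0, i.e. d(loss)/dx for loss = 4 * x
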
31 changes: 31 additions & 0 deletions tests/python/test_ad_basics.py
@@ -529,3 +529,34 @@ def func():
     assert b.grad[None] == 0.0
     assert c.grad[None] == 0.0
     assert d.grad[None] == 0.0
+
+
+@test_utils.test()
+def test_ad_set_loss_grad():
+    x = ti.field(dtype=ti.f32, shape=(), needs_grad=True)
+    loss = ti.field(dtype=ti.f32, shape=(), needs_grad=True)
+
+    @ti.kernel
+    def eval_x(x: ti.template()):
+        x[None] = 1.0
+
+    @ti.kernel
+    def compute_1(x: ti.template(), loss: ti.template()):
+        loss[None] = x[None]
+
+    @ti.kernel
+    def compute_2(x: ti.template(), loss: ti.template()):
+        loss[None] = 2 * x[None]
+
+    @ti.kernel
+    def compute_3(x: ti.template(), loss: ti.template()):
+        loss[None] = 4 * x[None]
+
+    eval_x(x)
+    with ti.ad.Tape(loss=loss):
+        compute_1(x, loss)
+        compute_2(x, loss)
+        compute_3(x, loss)
+
+    assert loss[None] == 4
+    assert x.grad[None] == 4
31 changes: 31 additions & 0 deletions tests/python/test_ad_ndarray.py
@@ -1229,6 +1229,37 @@ def compute_sum(a: ti.types.ndarray(), p: ti.types.ndarray()):
         assert a.grad[i][1] == 3
 
 
+@test_utils.test(arch=archs_support_ndarray_ad)
+def test_ad_set_loss_grad():
+    x = ti.ndarray(dtype=ti.f32, shape=(), needs_grad=True)
+    loss = ti.ndarray(dtype=ti.f32, shape=(), needs_grad=True)
+
+    @ti.kernel
+    def eval_x(x: ti.types.ndarray()):
+        x[None] = 1.0
+
+    @ti.kernel
+    def compute_1(x: ti.types.ndarray(), loss: ti.types.ndarray()):
+        loss[None] = x[None]
+
+    @ti.kernel
+    def compute_2(x: ti.types.ndarray(), loss: ti.types.ndarray()):
+        loss[None] = 2 * x[None]
+
+    @ti.kernel
+    def compute_3(x: ti.types.ndarray(), loss: ti.types.ndarray()):
+        loss[None] = 4 * x[None]
+
+    eval_x(x)
+    with ti.ad.Tape(loss=loss):
+        compute_1(x, loss)
+        compute_2(x, loss)
+        compute_3(x, loss)
+
+    assert loss[None] == 4
+    assert x.grad[None] == 4
+
+
 @pytest.mark.skipif(not has_pytorch(), reason="Pytorch not installed.")
 @test_utils.test(arch=archs_support_ndarray_ad)
 def test_ad_mixed_with_torch():