[autodiff] Make loss seed only set once in the tape (taichi-dev#7910)
Issue: #

### Brief Summary

### <samp>🤖 Generated by Copilot at fdbd83a</samp>

This pull request improves the reverse-mode automatic differentiation
(AD) module in `python/taichi/ad/_ad.py`. The loss gradient seed is now
set exactly once per tape, loss initialization and validation are
simplified, and the now-unused `clear_loss` kernel is removed from
`python/taichi/_kernels.py`.

### Walkthrough

### <samp>🤖 Generated by Copilot at fdbd83a</samp>

* Seed the loss gradient with 1.0 exactly once for reverse-mode AD in the tape's `grad()`, before the recorded kernels are replayed in reverse (see the usage sketch below)
([link](https://github.com/taichi-dev/taichi/pull/7910/files?diff=unified&w=0#diff-b986921c47e4b8c903d6bfc906398260dfeb17e16f05e5cd5b52e401eddc0bd0R215),
[link](https://github.com/taichi-dev/taichi/pull/7910/files?diff=unified&w=0#diff-b986921c47e4b8c903d6bfc906398260dfeb17e16f05e5cd5b52e401eddc0bd0R228-R233))
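
For orientation (this sketch is not in the PR body), a minimal usage example of the behavior this change targets, adapted from the new tests added below; the `ti.init` arch choice and the direct Python-scope write `x[None] = 1.0` are additions of this sketch, not taken from the diff:

```python
import taichi as ti

ti.init(arch=ti.cpu)

x = ti.field(dtype=ti.f32, shape=(), needs_grad=True)
loss = ti.field(dtype=ti.f32, shape=(), needs_grad=True)


@ti.kernel
def compute_1(x: ti.template(), loss: ti.template()):
    loss[None] = x[None]


@ti.kernel
def compute_2(x: ti.template(), loss: ti.template()):
    loss[None] = 2 * x[None]


@ti.kernel
def compute_3(x: ti.template(), loss: ti.template()):
    loss[None] = 4 * x[None]


x[None] = 1.0

# Entering the tape clears the loss; leaving it runs grad(), which now seeds
# loss.grad with 1.0 exactly once before replaying the recorded kernels in reverse.
with ti.ad.Tape(loss=loss):
    compute_1(x, loss)
    compute_2(x, loss)
    compute_3(x, loss)

# Each kernel overwrites the loss, so only the last assignment survives.
assert loss[None] == 4
assert x.grad[None] == 4
```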

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
2 people authored and quadpixels committed May 13, 2023
1 parent 630922a commit ce694be
Showing 4 changed files with 75 additions and 21 deletions.
8 changes: 0 additions & 8 deletions python/taichi/_kernels.py
@@ -228,14 +228,6 @@ def clear_gradients(_vars: template()):
             ScalarField(Expr(s))[I] = ops.cast(0, dtype=s.get_dt())


-@kernel
-def clear_loss(l: template()):
-    # Using SNode writers would result in a forced sync, therefore we wrap these
-    # writes into a kernel.
-    l[None] = 0
-    l.grad[None] = 1
-
-
 @kernel
 def field_fill_python_scope(F: template(), val: template()):
     field_fill_taichi_scope(F, val)
26 changes: 13 additions & 13 deletions python/taichi/ad/_ad.py
@@ -201,9 +201,7 @@ def __enter__(self):
             if self.validation:
                 clear_all_gradients(gradient_type=SNodeGradType.ADJOINT_CHECKBIT)

-            from taichi._kernels import clear_loss  # pylint: disable=C0415
-
-            clear_loss(self.loss)
+            self.loss.fill(0.0)
         elif isinstance(self.loss, Ndarray):
             if self.loss._get_nelement() != 1:
                 raise RuntimeError("The loss of `Tape` must be an ndarray with only one element")
@@ -251,21 +249,23 @@ def grad(self):
         assert self.entered, "Before evaluating gradients tape must be entered."
         assert not self.gradient_evaluated, "Gradients of grad can be evaluated only once."

+        # Set grad for loss
+        if isinstance(self.loss, (Field, Ndarray)):
+            self.loss.grad.fill(1.0)
+        else:
+            import torch  # pylint: disable=C0415
+
+            if self.loss.grad is None:
+                self.loss.grad = torch.ones_like(self.loss)
+            else:
+                with torch.no_grad():
+                    self.loss.grad.fill_(1.0)
+
         for func, args in reversed(self.calls):
             # we need to check whether "func" has "grad" attribute
             # since we insert write_int and write_float kernels to self.calls
             # e.g. x[None] = 0.0, this func has no grad attribute
             if hasattr(func, "grad"):
-                if isinstance(self.loss, (Field, Ndarray)):
-                    self.loss.grad.fill(1.0)
-                else:
-                    import torch  # pylint: disable=C0415
-
-                    if self.loss.grad is None:
-                        self.loss.grad = torch.ones_like(self.loss)
-                    else:
-                        with torch.no_grad():
-                            self.loss.grad.fill_(1.0)
                 func.grad(*args)

         self.gradient_evaluated = True
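
A side note, not part of the commit: the removed hunk re-filled `loss.grad` with 1.0 before every replayed kernel, while the added hunk fills it once before the loop. The pure-Python sketch below illustrates why that placement matters for the new tests, in which every recorded kernel overwrites the loss; the adjoint rule used here (an overwriting store consumes the upstream seed) is an assumption inferred from the tests asserting `x.grad[None] == 4`, not something stated in the diff.

```python
# Hand-rolled reverse pass mirroring the tape semantics, to compare seeding the
# loss gradient once against re-seeding it before every replayed kernel.
# Pure Python, no Taichi involved.

# Coefficients of the recorded kernels: loss = x, loss = 2 * x, loss = 4 * x.
CALLS = [1.0, 2.0, 4.0]


def replay(seed_once: bool) -> float:
    x_grad = 0.0
    loss_grad = 1.0  # new behaviour: seed set once, before the reverse loop
    for k in reversed(CALLS):
        if not seed_once:
            loss_grad = 1.0  # old behaviour: re-seed before every grad kernel
        x_grad += k * loss_grad  # adjoint of loss = k * x
        loss_grad = 0.0  # assumed: the overwritten loss value's seed is consumed
    return x_grad


print(replay(seed_once=True))   # 4.0, matching the new tests
print(replay(seed_once=False))  # 7.0, what per-kernel re-seeding would produce
```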
31 changes: 31 additions & 0 deletions tests/python/test_ad_basics.py
@@ -529,3 +529,34 @@ def func():
     assert b.grad[None] == 0.0
     assert c.grad[None] == 0.0
     assert d.grad[None] == 0.0
+
+
+@test_utils.test()
+def test_ad_set_loss_grad():
+    x = ti.field(dtype=ti.f32, shape=(), needs_grad=True)
+    loss = ti.field(dtype=ti.f32, shape=(), needs_grad=True)
+
+    @ti.kernel
+    def eval_x(x: ti.template()):
+        x[None] = 1.0
+
+    @ti.kernel
+    def compute_1(x: ti.template(), loss: ti.template()):
+        loss[None] = x[None]
+
+    @ti.kernel
+    def compute_2(x: ti.template(), loss: ti.template()):
+        loss[None] = 2 * x[None]
+
+    @ti.kernel
+    def compute_3(x: ti.template(), loss: ti.template()):
+        loss[None] = 4 * x[None]
+
+    eval_x(x)
+    with ti.ad.Tape(loss=loss):
+        compute_1(x, loss)
+        compute_2(x, loss)
+        compute_3(x, loss)
+
+    assert loss[None] == 4
+    assert x.grad[None] == 4
31 changes: 31 additions & 0 deletions tests/python/test_ad_ndarray.py
@@ -1229,6 +1229,37 @@ def compute_sum(a: ti.types.ndarray(), p: ti.types.ndarray()):
         assert a.grad[i][1] == 3


+@test_utils.test(arch=archs_support_ndarray_ad)
+def test_ad_set_loss_grad():
+    x = ti.ndarray(dtype=ti.f32, shape=(), needs_grad=True)
+    loss = ti.ndarray(dtype=ti.f32, shape=(), needs_grad=True)
+
+    @ti.kernel
+    def eval_x(x: ti.types.ndarray()):
+        x[None] = 1.0
+
+    @ti.kernel
+    def compute_1(x: ti.types.ndarray(), loss: ti.types.ndarray()):
+        loss[None] = x[None]
+
+    @ti.kernel
+    def compute_2(x: ti.types.ndarray(), loss: ti.types.ndarray()):
+        loss[None] = 2 * x[None]
+
+    @ti.kernel
+    def compute_3(x: ti.types.ndarray(), loss: ti.types.ndarray()):
+        loss[None] = 4 * x[None]
+
+    eval_x(x)
+    with ti.ad.Tape(loss=loss):
+        compute_1(x, loss)
+        compute_2(x, loss)
+        compute_3(x, loss)
+
+    assert loss[None] == 4
+    assert x.grad[None] == 4
+
+
 @pytest.mark.skipif(not has_pytorch(), reason="Pytorch not installed.")
 @test_utils.test(arch=archs_support_ndarray_ad)
 def test_ad_mixed_with_torch():
