diff --git a/python/taichi/ad/_ad.py b/python/taichi/ad/_ad.py
index 6ffe3d654dbba..a2647ee06101a 100644
--- a/python/taichi/ad/_ad.py
+++ b/python/taichi/ad/_ad.py
@@ -212,6 +212,7 @@ def __enter__(self):
                     "Gradients of loss are not allocated, please set needs_grad=True for all ndarrays that are required by autodiff."
                 )
             self.loss.fill(0.0)
+            self.loss.grad.fill(1.0)
         else:
             import torch  # pylint: disable=C0415
 
@@ -224,6 +225,12 @@ def __enter__(self):
 
             with torch.no_grad():
                 self.loss.fill_(0.0)
+            if self.loss.grad is None:
+                self.loss.grad = torch.ones_like(self.loss)
+            else:
+                with torch.no_grad():
+                    self.loss.grad.fill_(1.0)
+
         # Attach the context manager to runtime
         self.runtime.target_tape = self
         return self
@@ -256,16 +263,6 @@ def grad(self):
             # since we insert write_int and write_float kernels to self.calls
             # e.g. x[None] = 0.0, this func has no grad attribute
             if hasattr(func, "grad"):
-                if isinstance(self.loss, (Field, Ndarray)):
-                    self.loss.grad.fill(1.0)
-                else:
-                    import torch  # pylint: disable=C0415
-
-                    if self.loss.grad is None:
-                        self.loss.grad = torch.ones_like(self.loss)
-                    else:
-                        with torch.no_grad():
-                            self.loss.grad.fill_(1.0)
                 func.grad(*args)
 
         self.gradient_evaluated = True
diff --git a/tests/python/test_ad_basics.py b/tests/python/test_ad_basics.py
index f59ec4f22b0af..29a8a53bd4ec1 100644
--- a/tests/python/test_ad_basics.py
+++ b/tests/python/test_ad_basics.py
@@ -529,3 +529,34 @@ def func():
     assert b.grad[None] == 0.0
     assert c.grad[None] == 0.0
     assert d.grad[None] == 0.0
+
+
+@test_utils.test()
+def test_ad_set_loss_grad():
+    x = ti.field(dtype=ti.f32, shape=(), needs_grad=True)
+    loss = ti.field(dtype=ti.f32, shape=(), needs_grad=True)
+
+    @ti.kernel
+    def eval_x(x: ti.template()):
+        x[None] = 1.0
+
+    @ti.kernel
+    def compute_1(x: ti.template(), loss: ti.template()):
+        loss[None] = x[None]
+
+    @ti.kernel
+    def compute_2(x: ti.template(), loss: ti.template()):
+        loss[None] = 2 * x[None]
+
+    @ti.kernel
+    def compute_3(x: ti.template(), loss: ti.template()):
+        loss[None] = 4 * x[None]
+
+    eval_x(x)
+    with ti.ad.Tape(loss=loss):
+        compute_1(x, loss)
+        compute_2(x, loss)
+        compute_3(x, loss)
+
+    assert loss[None] == 4
+    assert x.grad[None] == 4
\ No newline at end of file
diff --git a/tests/python/test_ad_ndarray.py b/tests/python/test_ad_ndarray.py
index 544454102aec6..ca6e3bb7cfbb2 100644
--- a/tests/python/test_ad_ndarray.py
+++ b/tests/python/test_ad_ndarray.py
@@ -1229,6 +1229,37 @@ def compute_sum(a: ti.types.ndarray(), p: ti.types.ndarray()):
         assert a.grad[i][1] == 3
 
 
+@test_utils.test(arch=archs_support_ndarray_ad)
+def test_ad_set_loss_grad():
+    x = ti.ndarray(dtype=ti.f32, shape=(), needs_grad=True)
+    loss = ti.ndarray(dtype=ti.f32, shape=(), needs_grad=True)
+
+    @ti.kernel
+    def eval_x(x: ti.types.ndarray()):
+        x[None] = 1.0
+
+    @ti.kernel
+    def compute_1(x: ti.types.ndarray(), loss: ti.types.ndarray()):
+        loss[None] = x[None]
+
+    @ti.kernel
+    def compute_2(x: ti.types.ndarray(), loss: ti.types.ndarray()):
+        loss[None] = 2 * x[None]
+
+    @ti.kernel
+    def compute_3(x: ti.types.ndarray(), loss: ti.types.ndarray()):
+        loss[None] = 4 * x[None]
+
+    eval_x(x)
+    with ti.ad.Tape(loss=loss):
+        compute_1(x, loss)
+        compute_2(x, loss)
+        compute_3(x, loss)
+
+    assert loss[None] == 4
+    assert x.grad[None] == 4
+
+
 @pytest.mark.skipif(not has_pytorch(), reason="Pytorch not installed.")
 @test_utils.test(arch=archs_support_ndarray_ad)
 def test_ad_mixed_with_torch():
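
For context, a minimal user-side sketch of the behavior this patch introduces (not part of the patch itself): with the change, loss is zeroed and loss.grad is seeded with 1.0 when the ti.ad.Tape context is entered, instead of inside grad(). The snippet below mirrors the new tests in a reduced form; it assumes ti.init() selects an autodiff-capable backend such as CPU.

    import taichi as ti

    ti.init()  # any autodiff-capable backend

    x = ti.field(dtype=ti.f32, shape=(), needs_grad=True)
    loss = ti.field(dtype=ti.f32, shape=(), needs_grad=True)

    @ti.kernel
    def compute(x: ti.template(), loss: ti.template()):
        loss[None] = 4 * x[None]

    x[None] = 1.0
    with ti.ad.Tape(loss=loss):
        # On entry, the Tape fills loss with 0.0 and loss.grad with 1.0 (per this patch).
        compute(x, loss)

    print(x.grad[None])  # expected 4.0, matching the added tests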