From f1e056c110fd149e9e75ac460e7230606054ab3f Mon Sep 17 00:00:00 2001
From: Mingrui Zhang <33411325+erizmr@users.noreply.github.com>
Date: Thu, 17 Nov 2022 11:14:58 +0800
Subject: [PATCH] [autodiff] Clear adjoint after global store (#6579)

To avoid invalid grad accumulation, adjoints of fields on the left-hand side
of assignments (GlobalStoreStmt) need to be reset to zero after the
corresponding adjoint assignments.

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
---
 python/taichi/ad/_ad.py            | 1 +
 taichi/transforms/auto_diff.cpp    | 6 ++++++
 tests/python/test_ad_basics.py     | 8 ++++----
 tests/python/test_loop_grad.py     | 7 ++++---
 tests/python/test_offline_cache.py | 4 ++--
 5 files changed, 17 insertions(+), 9 deletions(-)

diff --git a/python/taichi/ad/_ad.py b/python/taichi/ad/_ad.py
index 3e99da1681d3c..13ef5eb5c2c2e 100644
--- a/python/taichi/ad/_ad.py
+++ b/python/taichi/ad/_ad.py
@@ -235,6 +235,7 @@ def grad(self):
             # since we insert write_int and write_float kernels to self.calls
             # e.g. x[None] = 0.0, this func has no grad attribute
             if hasattr(func, 'grad'):
+                self.loss.grad.fill(1.0)
                 func.grad(*args)
 
         self.gradient_evaluated = True
diff --git a/taichi/transforms/auto_diff.cpp b/taichi/transforms/auto_diff.cpp
index ba03036ab3b74..d1205aa2b9db2 100644
--- a/taichi/transforms/auto_diff.cpp
+++ b/taichi/transforms/auto_diff.cpp
@@ -1061,6 +1061,12 @@ class MakeAdjoint : public ADTransform {
           adjoint_ptr, stmt->dest->as()->offset);
     }
     accumulate(stmt->val, insert<GlobalLoadStmt>(adjoint_ptr));
+
+    // Clear the gradient after accumulation finished.
+    auto zero = insert<ConstStmt>(
+        TypedConstant(adjoint_ptr->ret_type.ptr_removed(), 0));
+    insert<GlobalStoreStmt>(adjoint_ptr, zero);
+
     stmt->parent->erase(stmt);
   }
 
diff --git a/tests/python/test_ad_basics.py b/tests/python/test_ad_basics.py
index 1400f3074d9a6..fa0a72bac5a14 100644
--- a/tests/python/test_ad_basics.py
+++ b/tests/python/test_ad_basics.py
@@ -504,7 +504,7 @@ def func():
     with ti.ad.Tape(loss=e):
         func()
     assert x.grad[None] == 120.0
-    assert a.grad[None] == 120.0
-    assert b.grad[None] == 60.0
-    assert c.grad[None] == 20.0
-    assert d.grad[None] == 5.0
+    assert a.grad[None] == 0.0
+    assert b.grad[None] == 0.0
+    assert c.grad[None] == 0.0
+    assert d.grad[None] == 0.0
diff --git a/tests/python/test_loop_grad.py b/tests/python/test_loop_grad.py
index 11482fb39384c..d8074e2fdd47f 100644
--- a/tests/python/test_loop_grad.py
+++ b/tests/python/test_loop_grad.py
@@ -27,9 +27,10 @@ def func():
     func.grad()
 
     for k in range(n):
-        for i in range(m):
-            assert x[k, i] == 2**i * k
-            assert x.grad[k, i] == 2**(m - 1 - i)
+        # Grads of fields on the left-hand side of assignments (GlobalStoreStmt) are reset to zero
+        # after the corresponding adjoint assignments, so only the grad of the element at index 0 in the second dimension is preserved.
+        assert x[k, 0] == 2**0 * k
+        assert x.grad[k, 0] == 2**(m - 1 - 0)
 
 
 @test_utils.test(exclude=[ti.vulkan, ti.dx11])
diff --git a/tests/python/test_offline_cache.py b/tests/python/test_offline_cache.py
index a7ae97d337d6d..5ce79f26bc5b2 100644
--- a/tests/python/test_offline_cache.py
+++ b/tests/python/test_offline_cache.py
@@ -332,12 +332,12 @@ def compute_y():
             enable_fallback=False,
             **current_thread_ext_options())
     assert added_files(curr_arch) == expected_num_cache_files(
-        curr_arch, [1] * 8)
+        curr_arch, [1] * 9)
 
     helper()
     ti.reset()
     assert added_files(curr_arch) == expected_num_cache_files(
-        curr_arch, [1] * 8)
+        curr_arch, [1] * 9)
 
 
 @pytest.mark.parametrize('curr_arch', supported_archs_offline_cache)
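
The user-visible effect of the change, as reflected in the updated expectations in
tests/python/test_ad_basics.py, can be illustrated with a small Taichi program. This is a
minimal sketch assuming a build that includes this patch; the field names and values are
illustrative only:

```python
import taichi as ti

ti.init(arch=ti.cpu)

x = ti.field(dtype=ti.f32, shape=(), needs_grad=True)
a = ti.field(dtype=ti.f32, shape=(), needs_grad=True)
e = ti.field(dtype=ti.f32, shape=(), needs_grad=True)


@ti.kernel
def func():
    # a and e appear on the left-hand side of assignments (GlobalStoreStmt),
    # so their adjoints are cleared once they have been propagated backwards.
    a[None] = x[None] * x[None]
    e[None] = a[None] * a[None]


x[None] = 2.0
with ti.ad.Tape(loss=e):
    func()

print(x.grad[None])  # 32.0: de/dx = 4 * x**3 is accumulated into the leaf input
print(a.grad[None])  # 0.0: the adjoint of the stored intermediate is reset to zero
```

Before this change, a.grad[None] would have been left holding de/da = 2 * a = 8.0, and that
stale adjoint could then be accumulated again by a later backward pass; the extra
GlobalStoreStmt that MakeAdjoint now emits zeroes it once it has been consumed.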
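
The one-line change in python/taichi/ad/_ad.py follows from the same behavior: the loss field
is itself written by a GlobalStoreStmt, so its adjoint (the seed gradient) is also cleared
during the backward pass, and Tape.grad() therefore refills it with 1.0 before each grad call.
A hedged sketch of the equivalent manual (non-Tape) workflow, with illustrative names:

```python
import taichi as ti

ti.init(arch=ti.cpu)

x = ti.field(dtype=ti.f32, shape=(), needs_grad=True)
loss = ti.field(dtype=ti.f32, shape=(), needs_grad=True)


@ti.kernel
def compute_loss():
    loss[None] = x[None] * x[None]


x[None] = 3.0
compute_loss()

# Seed d(loss)/d(loss) = 1 before running the backward kernel; this mirrors
# the self.loss.grad.fill(1.0) call that Tape.grad() now performs.
loss.grad.fill(1.0)
compute_loss.grad()

print(x.grad[None])     # 6.0: d(x * x)/dx = 2 * x
print(loss.grad[None])  # 0.0: loss is a stored field, so its adjoint is cleared as well
```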