From 6e23e61be35a5b443821f1dd37a9aec161856fe9 Mon Sep 17 00:00:00 2001 From: Anton Smirnov Date: Wed, 25 Dec 2024 22:49:10 +0200 Subject: [PATCH 1/3] =?UTF-8?q?Unthunk=20each=20element=20in=20=E2=88=87ea?= =?UTF-8?q?chslice?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/rulesets/Base/indexing.jl | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/src/rulesets/Base/indexing.jl b/src/rulesets/Base/indexing.jl index 61216bda2..c6623822e 100644 --- a/src/rulesets/Base/indexing.jl +++ b/src/rulesets/Base/indexing.jl @@ -262,20 +262,22 @@ end # Using Val(dim) here is worth a factor of 2 in this, on Julia 1.8- # @btime rrule(eachcol, $([1 2; 3 4]))[2]($([[10, 20], [30, 40]])) function ∇eachslice(dys_raw, x::AbstractArray, vd::Val{dim}) where {dim} - dys = unthunk(dys_raw) + dys = unthunk.(unthunk(dys_raw)) i1 = findfirst(dy -> dy isa AbstractArray, dys) if i1 === nothing # all slices are Zero! return _zero_fill!(similar(x, float(eltype(x)), axes(x))) end + T = Base.promote_eltype(dys...) # The whole point of this gradient is that we can allocate one `dx` array: dx = similar(x, T, axes(x)) for i in axes(x, dim) slice = selectdim(dx, dim, i) - if dys[i] isa AbstractZero + dy = dys[i] + if dy isa AbstractZero _zero_fill!(slice) # Avoids this: copyto!([1,2,3], ZeroTangent()) == [0,2,3] else - copyto!(slice, dys[i]) + copyto!(slice, dy) end end return ProjectTo(x)(dx) From 672698d6064249ccb03100c50ad6af9bdd3529ed Mon Sep 17 00:00:00 2001 From: Anton Smirnov Date: Mon, 30 Dec 2024 00:35:18 +0200 Subject: [PATCH 2/3] Add test --- test/rulesets/Base/indexing.jl | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/test/rulesets/Base/indexing.jl b/test/rulesets/Base/indexing.jl index f80a37048..d947567e5 100644 --- a/test/rulesets/Base/indexing.jl +++ b/test/rulesets/Base/indexing.jl @@ -261,4 +261,17 @@ end Val(3); check_inferred=(VERSION >= v"1.7"), ) + + # eachslice: Make sure pulling back an array of thunks unthunks them and does not return all zeros. + x = ones(Float32, 3) + Δ = ones(Float32, 1) + _, norm_back = ChainRules.rrule(norm, x) + dx = norm_back(Δ)[2] + @test dx isa AbstractThunk + + x = ones(Float32, 3, 1) + _, eachcol_back = ChainRules.rrule(eachcol, x) + Δ2 = [dx] + dx2 = eachcol_back(Δ2)[2] + @test all(dx2 .≉ 0f0) end From f9754c1c990e79474af51b0056fc6652380846a0 Mon Sep 17 00:00:00 2001 From: Anton Smirnov Date: Mon, 30 Dec 2024 00:45:32 +0200 Subject: [PATCH 3/3] Bump to 1.72.2 --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 39bda8294..f1ac45605 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "ChainRules" uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2" -version = "1.72.1" +version = "1.72.2" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"