Skip to content

Commit

Permalink
Merge pull request #738 from JuliaDiff/ox/setfield
Browse files Browse the repository at this point in the history
Add rules needed for mutation
  • Loading branch information
oxinabox authored Feb 1, 2024
2 parents 32bf53d + 810c633 commit 354ecbd
Show file tree
Hide file tree
Showing 4 changed files with 25 additions and 2 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ SuiteSparse = "4607b0f0-06f3-5cda-b6b1-a6196a1729e9"

[compat]
Adapt = "3.4.0, 4"
ChainRulesCore = "1.15.3"
ChainRulesCore = "1.20"
ChainRulesTestUtils = "1.5"
Compat = "3.46, 4.2"
Distributed = "1"
Expand Down
8 changes: 7 additions & 1 deletion src/rulesets/Base/array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,13 @@ _instantiate_zeros(ẋs::AbstractArray{<:AbstractArray}, xs) = ẋs
#####

function frule((_, ẏ, ẋ), ::typeof(copyto!), y::AbstractArray, x)
return copyto!(y, x), copyto!(ẏ, ẋ)
ifisa AbstractZero
# it's allowed to have an imutable zero tangent for ẏ as long as ẋ is zero
@assert iszero(ẋ)
else
copyto!(ẏ, ẋ)
end
return copyto!(y, x), ẏ
end

function frule((_, ẏ, _, ẋ), ::typeof(copyto!), y::AbstractArray, i::Integer, x, js::Integer...)
Expand Down
8 changes: 8 additions & 0 deletions src/rulesets/Base/base.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,14 @@ function rrule(::typeof(one), x)
return (one(x), one_pullback)
end


function ChainRulesCore.frule((_, ȯbj, _, ẋ), ::typeof(setfield!), obj, field, x)
ȯbj::MutableTangent
y = setfield!(obj, field, x)
= setproperty!(ȯbj, field, ẋ)
return y, ẏ
end

# `adjoint`

frule((_, Δz), ::typeof(adjoint), z::Number) = (z', Δz')
Expand Down
9 changes: 9 additions & 0 deletions test/rulesets/Base/base.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
mutable struct MDemo
x::Float64
end

@testset "base.jl" begin
@testset "zero/one" begin
for f in [zero, one]
Expand All @@ -18,6 +22,11 @@
end
end
end

@testset "setfield!" begin
test_frule(setfield!, MDemo(3.5) MutableTangent{MDemo}(; x=2.0), :x, 5.0)
test_frule(setfield!, MDemo(3.5) MutableTangent{MDemo}(; x=2.0), 1, 5.0)
end

@testset "Trig" begin
@testset "Basics" for x = (Float64(π)-0.01, Complex(π, π/2))
Expand Down

0 comments on commit 354ecbd

Please sign in to comment.