
Wrong results for forward-mode exp! half of the time #715

Open
baggepinnen opened this issue May 17, 2023 · 5 comments

@baggepinnen

baggepinnen commented May 17, 2023

If I repeatedly run the example below, I get the wrong result for the gradient through exp! about half of the time.

using LinearAlgebra, ForwardDiff, FiniteDiff, ForwardDiffChainRules

@ForwardDiff_frule LinearAlgebra.exp!(x1::AbstractMatrix{<:ForwardDiff.Dual})

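function test_exp(x)
    X = copy(reshape(x, 4, 4))  # copy, since reshape shares memory with x and exp! may overwrite it
    X2 = LinearAlgebra.exp!(X)  # in-place matrix exponential
    sum(X2)
end

for i = 1:20
    x = randn(16)
    g1 = ForwardDiff.gradient(test_exp, x)                   # forward mode through the frule
    g2 = FiniteDiff.finite_difference_gradient(test_exp, x)  # finite-difference reference
    @show norm(g1-g2)
end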
norm(g1 - g2) = 3.2745988814567806
norm(g1 - g2) = 2.7005934051461515e-9
norm(g1 - g2) = 3.535502368190921
norm(g1 - g2) = 5.376574873194121e-10
norm(g1 - g2) = 2.4158271822718778e-9
norm(g1 - g2) = 3.0885755390647527e-10
norm(g1 - g2) = 4.215282668056846
norm(g1 - g2) = 1.7888448238515218
norm(g1 - g2) = 2.1068558951714456e-10
norm(g1 - g2) = 8.090031857043094
norm(g1 - g2) = 5.8514613833452644
norm(g1 - g2) = 5.859275463330073e-10
norm(g1 - g2) = 3.3486620002856527e-10
norm(g1 - g2) = 1.1628716126438234
norm(g1 - g2) = 2.72443511328846
norm(g1 - g2) = 1.5771975088961793e-10
norm(g1 - g2) = 1.2073055237629486e-9
norm(g1 - g2) = 8.255800634241801
norm(g1 - g2) = 2.2459662479919337e-10
norm(g1 - g2) = 4.433466845638335
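To decide which of g1 and g2 is off without trusting either package, one option is an exact reference gradient. The sketch below is not from this thread: it relies on the standard block-triangular identity exp([X E; 0 X]) = [exp(X) L(X,E); 0 exp(X)], where L(X,E) is the Fréchet derivative of the matrix exponential, and uses only the non-mutating exp from LinearAlgebra. The helper names exp_frechet and sum_exp_gradient are hypothetical.

```julia
using LinearAlgebra

# Fréchet derivative L(X, E) of the matrix exponential at X in direction E,
# read off the top-right block of exp([X E; 0 X]). Uses the non-mutating exp.
function exp_frechet(X::AbstractMatrix, E::AbstractMatrix)
    n = size(X, 1)
    Z = zeros(promote_type(eltype(X), eltype(E)), n, n)
    M = exp([X E; Z X])
    return M[1:n, n+1:2n]
end

# Exact gradient of f(x) = sum(exp(reshape(x, n, n))) with respect to x.
# Since <J, L(X, E)> = <L(X', J), E> for J = ones(n, n), the gradient
# (as an n×n matrix) is L(X', J); vec flattens it back to match x.
function sum_exp_gradient(x::AbstractVector)
    n = isqrt(length(x))
    X = Matrix(reshape(x, n, n))
    J = ones(eltype(X), n, n)
    return vec(exp_frechet(Matrix(X'), J))
end
```

Comparing both g1 and g2 against sum_exp_gradient(x) for the same x should show which of the two is wrong, since test_exp(x) computes exactly sum(exp(reshape(x, 4, 4))).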

Also reported in ThummeTo/ForwardDiffChainRules.jl#14

@baggepinnen
Author

baggepinnen commented May 17, 2023

This only appears to be a problem for non-symmetric inputs to exp!. When I create a symmetric matrix X = X'X and switch to FiniteDifferences.jl for more accurate testing, it works just fine. Non-symmetric input matrices are still problematic, though.

using LinearAlgebra, ForwardDiff, ForwardDiffChainRules, FiniteDifferences
@ForwardDiff_frule LinearAlgebra.exp!(x1::AbstractMatrix{<:ForwardDiff.Dual})
function test_exp(x)
    X = copy(reshape(x, 4, 4))
    X2 = LinearAlgebra.exp!(X)
    sum(X2)
end

for i = 1:20
    X = randn(4,4)
    X = X'X
    x = vec(X)
    g1 = ForwardDiff.gradient(test_exp, x)
    g2 = FiniteDifferences.grad(central_fdm(5, 1), test_exp, x)[1]
    @show norm(g1-g2)
end

I've also tested the reverse rule using Zygote, and there is no problem in reverse mode.
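One possible reading of the symmetric/non-symmetric difference, not established at this point in the thread, is that exp! only overwrites its argument for some inputs. A quick way to check this for a given matrix is a helper like the hypothetical mutates_input below, which runs the function on a copy and compares the copy with the original afterwards.

```julia
using LinearAlgebra

# Hypothetical helper: returns true if f! modifies any entry of its argument
# for this particular input A. A itself is never touched; f! runs on a copy.
function mutates_input(f!, A::AbstractArray)
    B = copy(A)
    f!(B)
    return B != A
end

# Usage sketch: compare a symmetric and a non-symmetric input.
A = randn(4, 4)
S = A' * A                      # symmetric, built the same way as X'X in the comment above
@show mutates_input(LinearAlgebra.exp!, S)
@show mutates_input(LinearAlgebra.exp!, A)
```

The result is only meaningful for the specific matrix tested, so it is worth running on several random inputs.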

@oxinabox
Member

@sethaxen, do you think you might have time to look into this?

@sethaxen
Member

Yes, I can look into this.

@baggepinnen I'm not familiar with ForwardDiffChainRules. Can you provide an MWE that exhibits the observed failure using just ChainRules?

Using our own testing machinery, I am unable to observe any failures on 1000x the number of random matrices:

julia> using ChainRules, ChainRulesTestUtils, LinearAlgebra, Random, Test

julia> Random.seed!(42);

julia> @testset "exp!" begin
           Xs = (randn(4, 4) for _ in 1:20_000)
           @testset for X in Xs
               test_frule(LinearAlgebra.exp!, X)
           end
       end;
Test Summary: |  Pass  Total   Time
exp!          | 80000  80000  13.2s

Btw, in a fresh environment, your example errors on my machine in the for loop with:

ERROR: MethodError: no method matching iterate(::Nothing)

Closest candidates are:
  iterate(::Union{LinRange, StepRangeLen})
   @ Base range.jl:880
  iterate(::Union{LinRange, StepRangeLen}, ::Integer)
   @ Base range.jl:880
  iterate(::T) where T<:Union{Base.KeySet{<:Any, <:Dict}, Base.ValueIterator{<:Dict}}
   @ Base dict.jl:698
  ...

Stacktrace:
 [1] indexed_iterate(I::Nothing, i::Int64)
   @ Base ./tuple.jl:91
 [2] exp!(x1::Matrix{ForwardDiff.Dual{ForwardDiff.Tag{typeof(test_exp), Float64}, Float64, 8}})
   @ Main ~/.julia/packages/ForwardDiffChainRules/2Xt9G/src/ForwardDiffChainRules.jl:81
 [3] test_exp(x::Vector{ForwardDiff.Dual{ForwardDiff.Tag{typeof(test_exp), Float64}, Float64, 8}})
   @ Main ./REPL[4]:3
 [4] chunk_mode_gradient(f::typeof(test_exp), x::Vector{Float64}, cfg::ForwardDiff.GradientConfig{ForwardDiff.Tag{typeof(test_exp), Float64}, Float64, 8, Vector{ForwardDiff.Dual{ForwardDiff.Tag{typeof(test_exp), Float64}, Float64, 8}}})
   @ ForwardDiff ~/.julia/packages/ForwardDiff/vXysl/src/gradient.jl:123
 [5] gradient(f::Function, x::Vector{Float64}, cfg::ForwardDiff.GradientConfig{ForwardDiff.Tag{typeof(test_exp), Float64}, Float64, 8, Vector{ForwardDiff.Dual{ForwardDiff.Tag{typeof(test_exp), Float64}, Float64, 8}}}, ::Val{true})
   @ ForwardDiff ~/.julia/packages/ForwardDiff/vXysl/src/gradient.jl:21
 [6] gradient(f::Function, x::Vector{Float64}, cfg::ForwardDiff.GradientConfig{ForwardDiff.Tag{typeof(test_exp), Float64}, Float64, 8, Vector{ForwardDiff.Dual{ForwardDiff.Tag{typeof(test_exp), Float64}, Float64, 8}}})
   @ ForwardDiff ~/.julia/packages/ForwardDiff/vXysl/src/gradient.jl:17
 [7] gradient(f::Function, x::Vector{Float64})
   @ ForwardDiff ~/.julia/packages/ForwardDiff/vXysl/src/gradient.jl:17
 [8] top-level scope
   @ ./REPL[5]:4

@baggepinnen
Author

I think the problem is related to how ForwardDiffChainRules deals with (or rather, doesn't deal with) the fact that exp! mutates its input argument: if I add a call to copy on the input argument before each invocation of the frule, I get the correct results. So this is probably an issue with ForwardDiffChainRules.
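The workaround described here amounts to handing exp! a fresh copy on every call. A minimal sketch, assuming only LinearAlgebra; exp_copy is a hypothetical name, not part of any package. The wrapper leaves the caller's matrix intact, so passing the same primal twice gives the same answer.

```julia
using LinearAlgebra

# Hypothetical non-mutating wrapper: exp! only ever sees a private copy,
# so the caller's matrix is left intact and repeated calls agree.
exp_copy(X::AbstractMatrix) = LinearAlgebra.exp!(copy(X))

X  = randn(4, 4)
X0 = copy(X)       # snapshot to confirm X is not modified
Y1 = exp_copy(X)
Y2 = exp_copy(X)   # same primal passed a second time
```

From the caller's point of view this behaves like the non-mutating exp; the copy is just made explicit.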

@sethaxen
Member

Sounds right. This line calls frule repeatedly on the same primals, so it assumes the function is nonmutating: https://github.com/ThummeTo/ForwardDiffChainRules.jl/blob/d70301a28f61250c3168446c4b147b195ceee117/src/ForwardDiffChainRules.jl#L88
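To see why calling a rule repeatedly on the same primals breaks for a mutating function, here is a toy, self-contained sketch. It does not reproduce the ForwardDiffChainRules code linked above: toy_frule! is a made-up forward rule for an in-place elementwise exp, and push_each_direction is a hypothetical stand-in for a loop that evaluates the rule once per tangent direction while reusing one primal array.

```julia
# Toy forward rule for an in-place elementwise exp: returns (y, ẏ).
# Like LinearAlgebra.exp!, it overwrites its primal argument.
function toy_frule!(ẋ, x)
    x .= exp.(x)          # primal overwritten in place
    return x, x .* ẋ      # derivative of exp is exp, now stored in x
end

# Hypothetical per-direction loop: one rule call per tangent, optionally
# giving each call a fresh copy of the primal.
function push_each_direction(rule!, x, tangents; copy_primal::Bool)
    results = Vector{Vector{Float64}}()
    for ẋ in tangents
        xin = copy_primal ? copy(x) : x
        _, ẏ = rule!(ẋ, xin)
        push!(results, copy(ẏ))
    end
    return results
end

x = [0.5, -1.0]
tangents = [[1.0, 0.0], [0.0, 1.0]]
bad   = push_each_direction(toy_frule!, copy(x), tangents; copy_primal = false)
good  = push_each_direction(toy_frule!, copy(x), tangents; copy_primal = true)
exact = [[exp(0.5), 0.0], [0.0, exp(-1.0)]]
```

Without the copy, the first direction is correct but the second call sees an already-exponentiated primal, so its tangent is wrong; with a copy per call, every direction matches the exact derivative.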
