Computation of higher order derivatives for recurrent models results in strange errors #1593

Closed
simonmandlik opened this issue May 10, 2021 · 2 comments

@simonmandlik

I've been experimenting with computing higher-order derivatives for recurrent models. For example, the following computes the gradient (with respect to the model parameters) of the gradient of the model's output (with respect to the input data):

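using Flux

function f(i)
    m = RNN(2, 2)                                            # recurrent layer, 2 inputs -> 2 outputs
    xs = [randn(Float32, 2, 2) for _ in 1:i]                 # sequence of i random input matrices
    df() = gradient(b -> mapreduce(sum, +, m.(b)), xs)[1]    # inner gradient, w.r.t. the inputs
    gradient(() -> mapreduce(sum, +, df()), Flux.params(m))  # outer gradient, w.r.t. the parameters
end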
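
To make the nested quantity explicit, here is a sketch of what f(i) appears to compute. The notation is introduced here for illustration and is not from the original code: θ stands for the parameters collected by Flux.params(m), x_t for the t-th element of xs, and y_t for the model output at step t (which, for a recurrent model, also depends on the hidden state carried over from earlier steps).

```latex
% Inner scalar: sum of all output entries over the whole sequence
L(x_1, \dots, x_i; \theta) = \sum_{t=1}^{i} \sum_{j,k} \left[ y_t \right]_{jk}

% Outer scalar: sum of all entries of the input gradient returned by df()
g(\theta) = \sum_{t=1}^{i} \sum_{j,k} \left[ \frac{\partial L}{\partial x_t} \right]_{jk}

% Result of f(i): the gradient of the outer scalar w.r.t. the parameters
f(i) = \nabla_{\theta}\, g(\theta)
```

Both sums correspond to the two mapreduce(sum, +, ...) calls, so the outer gradient has to differentiate through the pullback of the inner one.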

On my machine, for small values of i (1 to 15), it fails with "Mutating arrays is not supported":

julia> f(15)
ERROR: Mutating arrays is not supported
Stacktrace:
  [1] error(s::String)
    @ Base ./error.jl:33
  [2] (::Zygote.var"#403#404")(#unused#::Nothing)
    @ Zygote ~/.julia/packages/Zygote/6HN9x/src/lib/array.jl:58
  [3] (::Zygote.var"#2259#back#405"{Zygote.var"#403#404"})(Δ::Nothing)
    @ Zygote ~/.julia/packages/ZygoteRules/OjfTt/src/adjoint.jl:59
  [4] (::Zygote.var"#180#181"{Tuple{Tuple{Nothing, Nothing}, Tuple{Nothing}}, Zygote.var"#2259#back#405"{Zygote.var"#403#404"}})(Δ::Nothing)
    @ Zygote ~/.julia/packages/Zygote/6HN9x/src/lib/lib.jl:194
  [5] (::Zygote.var"#1689#back#182"{Zygote.var"#180#181"{Tuple{Tuple{Nothing, Nothing}, Tuple{Nothing}}, Zygote.var"#2259#back#405"{Zygote.var"#403#404"}}})(Δ::Nothing)
    @ Zygote ~/.julia/packages/ZygoteRules/OjfTt/src/adjoint.jl:59
  [6] Pullback
    @ ~/.julia/packages/Zygote/6HN9x/src/lib/array.jl:38 [inlined]
  [7] (::typeof(∂(λ)))(Δ::Tuple{Vector{Union{Nothing, Matrix{Float32}}}, Nothing})
    @ Zygote ~/.julia/packages/Zygote/6HN9x/src/compiler/interface2.jl:0
  [8] Pullback
    @ ~/.julia/packages/ZygoteRules/OjfTt/src/adjoint.jl:59 [inlined]
  [9] (::typeof(∂(λ)))(Δ::Tuple{Nothing, Vector{Union{Nothing, Matrix{Float32}}}, Nothing})
    @ Zygote ~/.julia/packages/Zygote/6HN9x/src/compiler/interface2.jl:0
 [10] Pullback
    @ ./reduce.jl:406 [inlined]
 [11] (::typeof(∂(λ)))(Δ::Tuple{Nothing, Nothing, Nothing, Nothing, Vector{Matrix{Float32}}})
    @ Zygote ~/.julia/packages/Zygote/6HN9x/src/compiler/interface2.jl:0
 [12] Pullback
    @ ./reducedim.jl:318 [inlined]
 [13] (::typeof(∂(λ)))(Δ::Tuple{Nothing, Nothing, Nothing, Nothing, Vector{Matrix{Float32}}, Nothing})
    @ Zygote ~/.julia/packages/Zygote/6HN9x/src/compiler/interface2.jl:0
 [14] Pullback
    @ ./reducedim.jl:310 [inlined]
 [15] (::typeof(∂(λ)))(Δ::Tuple{Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Vector{Matrix{Float32}}})
    @ Zygote ~/.julia/packages/Zygote/6HN9x/src/compiler/interface2.jl:0
 [16] Pullback
    @ ./reducedim.jl:310 [inlined]
 [17] (::typeof(∂(λ)))(Δ::Tuple{Nothing, Nothing, Nothing, Vector{Matrix{Float32}}})
    @ Zygote ~/.julia/packages/Zygote/6HN9x/src/compiler/interface2.jl:0
 [18] Pullback
    @ ./REPL[13]:4 [inlined]
 [19] (::typeof(∂(λ)))(Δ::Tuple{Nothing, Vector{FillArrays.Fill{Float32, 2, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}}}})
    @ Zygote ~/.julia/packages/Zygote/6HN9x/src/compiler/interface2.jl:0
 [20] Pullback
    @ ~/.julia/packages/Zygote/6HN9x/src/compiler/interface.jl:41 [inlined]
 [21] (::typeof(∂(λ)))(Δ::Tuple{Vector{FillArrays.Fill{Float32, 2, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}}}})
    @ Zygote ~/.julia/packages/Zygote/6HN9x/src/compiler/interface2.jl:0
 [22] Pullback
    @ ~/.julia/packages/Zygote/6HN9x/src/compiler/interface.jl:59 [inlined]
 [23] (::typeof(∂(gradient)))(Δ::Tuple{Vector{FillArrays.Fill{Float32, 2, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}}}})
    @ Zygote ~/.julia/packages/Zygote/6HN9x/src/compiler/interface2.jl:0
 [24] Pullback
    @ ./REPL[13]:4 [inlined]
 [25] (::typeof(∂(λ)))(Δ::Vector{FillArrays.Fill{Float32, 2, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}}})
    @ Zygote ~/.julia/packages/Zygote/6HN9x/src/compiler/interface2.jl:0
 [26] Pullback
    @ ./REPL[13]:5 [inlined]
 [27] (::typeof(∂(λ)))(Δ::Float32)
    @ Zygote ~/.julia/packages/Zygote/6HN9x/src/compiler/interface2.jl:0
 [28] (::Zygote.var"#69#70"{Zygote.Params, typeof(∂(λ)), Zygote.Context})(Δ::Float32)
    @ Zygote ~/.julia/packages/Zygote/6HN9x/src/compiler/interface.jl:252
 [29] gradient(f::Function, args::Zygote.Params)
    @ Zygote ~/.julia/packages/Zygote/6HN9x/src/compiler/interface.jl:59
 [30] f(i::Int64)
    @ Main ./REPL[13]:5
 [31] top-level scope
    @ show.jl:955
 [32] top-level scope
    @ ~/.julia/packages/CUDA/k52QH/src/initialization.jl:81

This error is more or less expected and can be remedied, for example as in https://github.com/FluxML/Zygote.jl/pull/944/files.

For i >= 16, the following error happens:

julia> f(16)
ERROR: Can't differentiate loopinfo expression
Stacktrace:
  [1] error(s::String)
    @ Base ./error.jl:33
  [2] rrule(#unused#::typeof(error), 667::String)
    @ ChainRules ~/.julia/packages/ChainRules/DIzwo/src/rulesets/Base/nondiff.jl:122
  [3] chain_rrule
    @ ~/.julia/packages/Zygote/6HN9x/src/compiler/chainrules.jl:89 [inlined]
  [4] macro expansion
    @ ~/.julia/packages/Zygote/6HN9x/src/compiler/interface2.jl:0 [inlined]
  [5] _pullback(ctx::Zygote.Context, f::typeof(error), args::String)
    @ Zygote ~/.julia/packages/Zygote/6HN9x/src/compiler/interface2.jl:9
  [6] macro expansion
    @ ./simdloop.jl:79 [inlined]
  [7] _pullback
    @ ./reduce.jl:243 [inlined]
  [8] _pullback(ctx::Zygote.Context, f::typeof(∂(mapreduce_impl)), args::Float32)
    @ Zygote ~/.julia/packages/Zygote/6HN9x/src/compiler/interface2.jl:0
  [9] _pullback
    @ ./reduce.jl:257 [inlined]
 [10] _pullback(ctx::Zygote.Context, f::typeof(∂(mapreduce_impl)), args::Float32)
    @ Zygote ~/.julia/packages/Zygote/6HN9x/src/compiler/interface2.jl:0
 [11] _pullback
    @ ./reduce.jl:415 [inlined]
 [12] _pullback(ctx::Zygote.Context, f::typeof(∂(_mapreduce)), args::Float32)
    @ Zygote ~/.julia/packages/Zygote/6HN9x/src/compiler/interface2.jl:0
 [13] _pullback
    @ ./reducedim.jl:318 [inlined]
 [14] _pullback(ctx::Zygote.Context, f::typeof(∂(_mapreduce_dim)), args::Float32)
    @ Zygote ~/.julia/packages/Zygote/6HN9x/src/compiler/interface2.jl:0
 [15] _pullback
    @ ./reducedim.jl:310 [inlined]
 [16] _pullback(ctx::Zygote.Context, f::typeof(∂(#mapreduce#672)), args::Float32)
    @ Zygote ~/.julia/packages/Zygote/6HN9x/src/compiler/interface2.jl:0
 [17] _pullback
    @ ./reducedim.jl:310 [inlined]
 [18] _pullback(ctx::Zygote.Context, f::typeof(∂(mapreduce)), args::Float32)
    @ Zygote ~/.julia/packages/Zygote/6HN9x/src/compiler/interface2.jl:0
 [19] _pullback
    @ ./REPL[13]:4 [inlined]
 [20] _pullback(ctx::Zygote.Context, f::typeof(∂(λ)), args::Float32)
    @ Zygote ~/.julia/packages/Zygote/6HN9x/src/compiler/interface2.jl:0
 [21] _pullback
    @ ~/.julia/packages/Zygote/6HN9x/src/compiler/interface.jl:41 [inlined]
 [22] _pullback(ctx::Zygote.Context, f::Zygote.var"#41#42"{typeof(∂(λ))}, args::Float32)
    @ Zygote ~/.julia/packages/Zygote/6HN9x/src/compiler/interface2.jl:0
 [23] _pullback
    @ ~/.julia/packages/Zygote/6HN9x/src/compiler/interface.jl:59 [inlined]
 [24] _pullback(::Zygote.Context, ::typeof(gradient), ::var"#23#27"{Flux.Recur{Flux.RNNCell{typeof(tanh), Matrix{Float32}, Vector{Float32}, Matrix{Float32}}, Matrix{Float32}}}, ::Vector{Matrix{Float32}})
    @ Zygote ~/.julia/packages/Zygote/6HN9x/src/compiler/interface2.jl:0
 [25] _pullback
    @ ./REPL[13]:4 [inlined]
 [26] _pullback(::Zygote.Context, ::var"#df#26"{Vector{Matrix{Float32}}, Flux.Recur{Flux.RNNCell{typeof(tanh), Matrix{Float32}, Vector{Float32}, Matrix{Float32}}, Matrix{Float32}}})
    @ Zygote ~/.julia/packages/Zygote/6HN9x/src/compiler/interface2.jl:0
 [27] _pullback
    @ ./REPL[13]:5 [inlined]
 [28] _pullback(::Zygote.Context, ::var"#24#28"{var"#df#26"{Vector{Matrix{Float32}}, Flux.Recur{Flux.RNNCell{typeof(tanh), Matrix{Float32}, Vector{Float32}, Matrix{Float32}}, Matrix{Float32}}}})
    @ Zygote ~/.julia/packages/Zygote/6HN9x/src/compiler/interface2.jl:0
 [29] pullback(f::Function, ps::Zygote.Params)
    @ Zygote ~/.julia/packages/Zygote/6HN9x/src/compiler/interface.jl:247
 [30] gradient(f::Function, args::Zygote.Params)
    @ Zygote ~/.julia/packages/Zygote/6HN9x/src/compiler/interface.jl:58
 [31] f(i::Int64)
    @ Main ./REPL[13]:5
 [32] top-level scope
    @ show.jl:955
 [33] top-level scope
    @ ~/.julia/packages/CUDA/k52QH/src/initialization.jl:81
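
The switch at i = 16 seems to line up with Base rather than with Flux: the two traces go through different lines of reduce.jl (406 versus 415), and only the second reaches mapreduce_impl and simdloop.jl. A minimal sketch for checking this, assuming the internal function Base.mapreduce_impl still has the six-argument method seen in the trace (it is not public API and may differ between Julia versions):

```julia
# Inspect the lowered code of Base's blocked reduction kernel and look for
# the :loopinfo expression that @simd inserts. Base.mapreduce_impl is an
# internal function; its signature here is an assumption taken from the trace.
sig = (typeof(identity), typeof(+), Vector{Float32}, Int, Int, Int)
ci = first(code_lowered(Base.mapreduce_impl, sig))

# true if any lowered statement is an Expr with head :loopinfo
has_loopinfo = any(st -> st isa Expr && st.head === :loopinfo, ci.code)
println("mapreduce_impl contains a loopinfo expression: ", has_loopinfo)
```

If this prints true, the "Can't differentiate loopinfo expression" error would be explained by Zygote differentiating through that @simd loop, which mapreduce over a vector only reaches once the vector has 16 or more elements.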

Is there a fix for this, and is Flux's framework for recurrent models built with nested AD in mind?

@simonmandlik

This may be an issue in Zygote: FluxML/Zygote.jl#897

@simonmandlik

There are multiple issues regarding this problem in Zygote:

FluxML/Zygote.jl#704
FluxML/Zygote.jl#157
FluxML/Zygote.jl#897
FluxML/Zygote.jl#229

Closing, because this is a Zygote issue.
