Error in pullback construction #368
I have an example where this error shows up, but it's a couple of hundred lines long. I tried to reduce it, but no luck... If I replace Zygote with Tracker, then it works. Should I post my example here?
Is it possible to start from that larger example, and gradually remove code until it's smaller?
I was able to reduce my example down to this (now it's only 100 lines long):

```julia
# Listen, Attend and Spell: arxiv.org/abs/1508.01211
using Flux
using Flux: @functor, reset!
using LinearAlgebra

mutable struct State{M <: AbstractMatrix{<:Real}}
    context    :: M # last attention context
    decoding   :: M # last decoder state
    prediction :: M # last prediction
    # reset values
    context₀    :: M
    decoding₀   :: M
    prediction₀ :: M
end

@functor State

function State(dim_c::Integer, dim_d::Integer, dim_p::Integer)
    context₀    = param(zeros(Float32, dim_c, 1))
    decoding₀   = param(zeros(Float32, dim_d, 1))
    prediction₀ = param(zeros(Float32, dim_p, 1))
    return State(context₀, decoding₀, prediction₀, context₀, decoding₀, prediction₀)
end

function Flux.reset!(s::State)
    s.context    = s.context₀
    s.decoding   = s.decoding₀
    s.prediction = s.prediction₀
    return nothing
end

struct LAS{V, E, Dϕ, Dψ, L, C}
    state       :: State{V} # current state of the model
    listen      :: E        # encoder function
    attention_ϕ :: Dϕ       # attention context function
    attention_ψ :: Dψ       # attention context function
    spell       :: L        # RNN decoder
    infer       :: C        # character distribution inference function
end

@functor LAS

function LAS()
    state = State(8, 8, 4)
    listen = RNN(5, 8)
    attention_ϕ = Dense(8, 8)
    attention_ψ = Dense(8, 8)
    spell = RNN(20, 8)
    infer = Chain(Dense(16, 4), logsoftmax)
    las = LAS(state, listen, attention_ϕ, attention_ψ, spell, infer)
    return las
end

function (m::LAS)(xs::AbstractVector{<:AbstractMatrix})::AbstractVector{<:AbstractMatrix{<:Real}}
    batch_size = size(first(xs), 2)
    # compute input encoding
    hs = m.listen.(xs)
    # concatenate the sequence of D×N matrices into a single D×N×T 3-dimensional array
    Hs = cat(hs...; dims=3)
    # precompute ψ(H)
    ψHs = m.attention_ψ.(hs)
    # compute initial decoder state for a batch
    O = zeros(Float32, size(m.state.decoding, 1), batch_size)
    m.state.decoding = m.spell([m.state.decoding; m.state.prediction; m.state.context]) .+ O
    dim_out = size(m.state.prediction, 1)
    ŷs = map(1:length(xs)) do _
        # compute ϕ(sᵢ)
        ϕSᵢᵀ = m.attention_ϕ(m.state.decoding)'
        # compute attention scores
        Eᵢs = diag.(Ref(ϕSᵢᵀ) .* ψHs)
        # αᵢs = softmax(hcat(Eᵢs...)')
        αᵢs = softmax(vcat(Eᵢs'...))
        # compute attention context, i.e. contextᵢ = Σᵤαᵢᵤhᵤ
        m.state.context = dropdims(sum(reshape(αᵢs, 1, batch_size, :) .* Hs; dims=3); dims=3)
        # predict probability distribution over character alphabet
        m.state.prediction = m.infer([m.state.decoding; m.state.context])
        # compute decoder state
        m.state.decoding = m.spell([m.state.decoding; m.state.prediction; m.state.context])
        return m.state.prediction
    end
    reset!(m)
    return ŷs
end

function Flux.reset!(m::LAS)
    reset!(m.state)
    reset!(m.listen)
    reset!(m.spell)
    return nothing
end

las = LAS()

function loss(xs::AbstractVector{<:AbstractMatrix{<:Real}}, indexes::AbstractVector{<:AbstractVector{<:Integer}})::Real
    ŷs = las(xs)
    l = -sum(sum.(getindex.(ŷs, indexes)))
    return l
end

xs = [rand(Float32, 5, 7) for _ ∈ 1:3]
ys = [rand(1:28, 7) for _ ∈ 1:3] # 4*7
las(xs)
loss(xs, ys)
θ = params(las)
l, pb = Flux.Zygote.pullback(θ) do
    loss(xs, ys)
end
```
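For context, the failing call uses Zygote's implicit-parameter form of `pullback`, which takes a zero-argument closure over the data together with a `Params` collection. A minimal sketch of the same call shape, with a trivial `Dense` layer standing in for the LAS model (the names here are illustrative, not from the issue):

```julia
using Flux, Zygote

m  = Dense(3, 2)                 # trivial stand-in for the LAS model
xs = rand(Float32, 3, 5)

θ = Flux.params(m)               # implicit parameter set (weight and bias)
l, pb = Zygote.pullback(θ) do    # zero-argument closure over m and xs, as in the report
    sum(m(xs))
end
grads = pb(1f0)                  # Grads: gradients indexed by the arrays in θ
grads[first(θ)]                  # gradient for the first parameter array
```

The closure captures `m` and `xs` from the surrounding scope, which is exactly the shape `_pullback(cx, f)` has to compile in the failing example.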
Further reduced to:

```julia
# Listen, Attend and Spell: arxiv.org/abs/1508.01211
using Flux
using Flux: @functor, reset!
using LinearAlgebra

mutable struct State{M <: AbstractMatrix{<:Real}}
    context    :: M # last attention context
    decoding   :: M # last decoder state
    prediction :: M # last prediction
    # reset values
    context₀    :: M
    decoding₀   :: M
    prediction₀ :: M
end

@functor State

function State(dim_c::Integer, dim_d::Integer, dim_p::Integer)
    context₀    = param(zeros(Float32, dim_c, 1))
    decoding₀   = param(zeros(Float32, dim_d, 1))
    prediction₀ = param(zeros(Float32, dim_p, 1))
    return State(context₀, decoding₀, prediction₀, context₀, decoding₀, prediction₀)
end

function Flux.reset!(s::State)
    s.context    = s.context₀
    s.decoding   = s.decoding₀
    s.prediction = s.prediction₀
    return nothing
end

struct LAS{V, E, L, C}
    state  :: State{V} # current state of the model
    listen :: E        # encoder function
    spell  :: L        # RNN decoder
    infer  :: C        # character distribution inference function
end

@functor LAS

function LAS()
    state = State(8, 8, 4)
    listen = RNN(5, 8)
    spell = RNN(20, 8)
    infer = Chain(Dense(16, 4), logsoftmax)
    return LAS(state, listen, spell, infer)
end

function (m::LAS)(xs::AbstractVector{<:AbstractMatrix})::AbstractVector{<:AbstractMatrix{<:Real}}
    batch_size = size(first(xs), 2)
    # compute input encoding
    hs = m.listen.(xs)
    # concatenate the sequence of D×N matrices into a single D×N×T 3-dimensional array
    Hs = cat(hs...; dims=3)
    # compute initial decoder state for a batch
    O = zeros(Float32, size(m.state.decoding, 1), batch_size)
    m.state.decoding = m.spell([m.state.decoding; m.state.prediction; m.state.context]) .+ O
    dim_out = size(m.state.prediction, 1)
    ŷs = map(1:length(xs)) do _
        m.state.context = dropdims(sum(Hs; dims=3); dims=3)
        # predict probability distribution over character alphabet
        m.state.prediction = m.infer([m.state.decoding; m.state.context])
        # compute decoder state
        m.state.decoding = m.spell([m.state.decoding; m.state.prediction; m.state.context])
        return m.state.prediction
    end
    reset!(m)
    return ŷs
end

function Flux.reset!(m::LAS)
    reset!(m.state)
    reset!(m.listen)
    reset!(m.spell)
    return nothing
end

las = LAS()

function loss(xs::AbstractVector{<:AbstractMatrix{<:Real}}, indexes::AbstractVector{<:AbstractVector{<:Integer}})::Real
    ŷs = las(xs)
    l = -sum(sum.(getindex.(ŷs, indexes)))
    return l
end

xs = [rand(Float32, 5, 7) for _ ∈ 1:3]
ys = [rand(1:28, 7) for _ ∈ 1:3] # 4*7
las(xs)
loss(xs, ys)
θ = params(las)
l, pb = Flux.Zygote.pullback(θ) do
    loss(xs, ys)
end
```
After investigating this the whole day, it seems the error is happening here: `Zygote.jl/src/compiler/interface.jl`, line 96 (at commit 16b9aea), where `cx` is a `Zygote.Context(nothing, nothing)` and `f` is a closure defined as `f = () -> loss(xs, ys)`. `_pullback(cx, f)` calls the generated function defined in `Zygote.jl/src/compiler/interface2.jl`, line 6 (at 16b9aea), and debugging any further is beyond my abilities. Sorry. I really hope this is helpful in resolving this nasty bug. It would be great if this could get some attention; it's what is preventing me from using Zygote at the moment. Thanks a lot.
This is a major pain point for us at the moment. @MikeInnes, I would love to prepare a PR that fixes this if you could provide some guidance. I'm hoping we can iron out this bug before the next release of Flux.
Somewhat smaller example:

```julia
julia> using Zygote: pullback

julia> function foo()
           Complex{<:Real}
       end
foo (generic function with 2 methods)

julia> pullback(foo)
ERROR: UndefVarError: S not defined
```

(The problem in the original code is the assertion on return type on the …)
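The repro is tiny because `Complex{<:Real}` appears to be the trigger itself: it is a `UnionAll` type carrying a free `TypeVar`, and Zygote's generated `_pullback` presumably mishandled such unbound type variables while building the `Pullback{T,S}` closure type. A sketch of the plain-Julia facts being relied on here (my reading of the repro, not taken from the issue):

```julia
# Complex{<:Real} is shorthand for `Complex{T} where T<:Real`
U = Complex{<:Real}

U isa UnionAll       # true: a type with a still-unbound parameter
U.var                # the free TypeVar, bounded above by Real
Complex{Int} <: U    # true: concrete instantiations are subtypes
```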
Thank you for fixing this! I can confirm that we don't get this error now. We're now seeing a different error, but I'll file a separate issue for it. Edit: the error we are getting is #198.
We occasionally get errors when constructing pullbacks in the reverse pass, with `S not defined`. (The type is `Pullback{T,S}`, where `T` is a function signature and `S` is the type of the data we're storing; we call `Pullback{T}(data)` and let `S` be inferred.) I've seen this before but don't currently have a test case for it.
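The incomplete-parameterization idiom described here (supply `T` explicitly, let `S` be inferred from the argument) is plain Julia; a minimal sketch with a hypothetical `Box` type, not Zygote's actual struct:

```julia
struct Box{T,S}
    data::S
end

# Partial constructor: the caller fixes T; S is inferred from the field's type.
Box{T}(data::S) where {T,S} = Box{T,S}(data)

b = Box{Symbol}((1, 2.0))   # S is inferred as Tuple{Int64, Float64}
```

The `UndefVarError: S not defined` in this issue arose inside Zygote's compiler while it reconstructed the analogous `Pullback{T,S}` type, where `S` could not always be resolved.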