Error in pullback construction #368

Closed
MikeInnes opened this issue Oct 9, 2019 · 8 comments
Labels
bug Something isn't working

Comments

@MikeInnes
Member

We occasionally get errors when constructing pullbacks in the reverse pass, failing with UndefVarError: S not defined. (The type is Pullback{T,S}, where T is a function signature and S is the type of the data we're storing; we call Pullback{T}(data) and let S be inferred.) I've seen this before but don't currently have a test case for it.
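The Pullback{T}(data) pattern described above can be sketched with a stand-in type. MyPullback is hypothetical and is not Zygote's actual definition; it only illustrates fixing T at the call site and letting a constructor derive S from typeof(data):

```julia
# Hypothetical stand-in for Zygote's Pullback{T,S}:
# T is a function signature, S is the type of the stored data.
struct MyPullback{T,S}
    data::S
end

# Partial-parameter constructor: the caller supplies T,
# and S is taken from the type of the data being stored.
MyPullback{T}(data) where {T} = MyPullback{T,typeof(data)}(data)

pb = MyPullback{Tuple{typeof(sin),Float64}}((1.0, :cache))
```

Here pb has type MyPullback{Tuple{typeof(sin),Float64},Tuple{Float64,Symbol}}; the caller never spells out S.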

@AzamatB
Contributor

AzamatB commented Oct 10, 2019

I have an example where this error shows up, but it's a couple of hundred lines long. I tried to reduce it, without success. If I replace Zygote with Tracker, it works. Should I post my example here?

@MikeInnes
Member Author

Is it possible to start from that larger example, and gradually remove code until it's smaller?

@AzamatB
Contributor

AzamatB commented Oct 10, 2019

I was able to reduce my example to this (it's now about 100 lines long):

# Listen, Attend and Spell: arxiv.org/abs/1508.01211
using Flux
using Flux: @functor, reset!
using LinearAlgebra

mutable struct State{M <: AbstractMatrix{<:Real}}
   context     :: M   # last attention context
   decoding    :: M   # last decoder state
   prediction  :: M   # last prediction
   # reset values
   context₀    :: M
   decoding₀   :: M
   prediction₀ :: M
end

@functor State

function State(dim_c::Integer, dim_d::Integer, dim_p::Integer)
   context₀    = param(zeros(Float32, dim_c, 1))
   decoding₀   = param(zeros(Float32, dim_d, 1))
   prediction₀ = param(zeros(Float32, dim_p, 1))
   return State(context₀, decoding₀, prediction₀, context₀, decoding₀, prediction₀)
end

function Flux.reset!(s::State)
   s.context    = s.context₀
   s.decoding   = s.decoding₀
   s.prediction = s.prediction₀
   return nothing
end

struct LAS{V, E, Dϕ, Dψ, L, C}
   state       :: State{V} # current state of the model
   listen      :: E   # encoder function
   attention_ϕ :: Dϕ  # attention context function
   attention_ψ :: Dψ  # attention context function
   spell       :: L   # RNN decoder
   infer       :: C   # character distribution inference function
end

@functor LAS

function LAS()
   state       = State(8, 8, 4)
   listen      = RNN(5, 8)
   attention_ϕ = Dense(8, 8)
   attention_ψ = Dense(8, 8)
   spell       = RNN(20, 8)
   infer       = Chain(Dense(16, 4), logsoftmax)
   las = LAS(state, listen, attention_ϕ, attention_ψ, spell, infer)
   return las
end

function (m::LAS)(xs::AbstractVector{<:AbstractMatrix})::AbstractVector{<:AbstractMatrix{<:Real}}
   batch_size = size(first(xs), 2)
   # compute input encoding
   hs = m.listen.(xs)
   # concatenate sequence of D×N matrices into a single D×N×T three-dimensional array
   Hs = cat(hs...; dims=3)
   # precompute ψ(H)
   ψHs = m.attention_ψ.(hs)
   # compute initial decoder state for a batch
   O = zeros(Float32, size(m.state.decoding, 1), batch_size)
   m.state.decoding = m.spell([m.state.decoding; m.state.prediction; m.state.context]) .+ O
   dim_out = size(m.state.prediction, 1)

   ŷs = map(1:length(xs)) do _
      # compute ϕ(sᵢ)
      ϕSᵢᵀ = m.attention_ϕ(m.state.decoding)'
      # compute attention context
      Eᵢs = diag.(Ref(ϕSᵢᵀ) .* ψHs)
      # αᵢs = softmax(hcat(Eᵢs...)')
      αᵢs = softmax(vcat(Eᵢs'...))
      # compute attention context, i.e. contextᵢ = Σᵤαᵢᵤhᵤ
      m.state.context = dropdims(sum(reshape(αᵢs, 1, batch_size, :) .* Hs; dims=3); dims=3)
      # predict probability distribution over character alphabet
      m.state.prediction = m.infer([m.state.decoding; m.state.context])
      # compute decoder state
      m.state.decoding = m.spell([m.state.decoding; m.state.prediction; m.state.context])
      return m.state.prediction
   end
   reset!(m)
   return ŷs
end

function Flux.reset!(m::LAS)
   reset!(m.state)
   reset!(m.listen)
   reset!(m.spell)
   return nothing
end

las = LAS()

function loss(xs::AbstractVector{<:AbstractMatrix{<:Real}}, indexes::AbstractVector{<:AbstractVector{<:Integer}})::Real
   ŷs = las(xs)
   l = -sum(sum.(getindex.(ŷs, indexes)))
   return l
end

xs = [rand(Float32, 5, 7) for _ in 1:3]
ys = [rand(1:28, 7) for _ in 1:3] # linear indices into each 4×7 prediction (4*7 = 28)

las(xs)
loss(xs, ys)

θ = params(las)

l, pb = Flux.Zygote.pullback(θ) do
   loss(xs, ys)
end

@AzamatB
Contributor

AzamatB commented Oct 10, 2019

Further reduced to:

# Listen, Attend and Spell: arxiv.org/abs/1508.01211
using Flux
using Flux: @functor, reset!
using LinearAlgebra

mutable struct State{M <: AbstractMatrix{<:Real}}
   context     :: M   # last attention context
   decoding    :: M   # last decoder state
   prediction  :: M   # last prediction
   # reset values
   context₀    :: M
   decoding₀   :: M
   prediction₀ :: M
end

@functor State

function State(dim_c::Integer, dim_d::Integer, dim_p::Integer)
   context₀    = param(zeros(Float32, dim_c, 1))
   decoding₀   = param(zeros(Float32, dim_d, 1))
   prediction₀ = param(zeros(Float32, dim_p, 1))
   return State(context₀, decoding₀, prediction₀, context₀, decoding₀, prediction₀)
end

function Flux.reset!(s::State)
   s.context    = s.context₀
   s.decoding   = s.decoding₀
   s.prediction = s.prediction₀
   return nothing
end

struct LAS{V, E, L, C}
   state  :: State{V} # current state of the model
   listen :: E   # encoder function
   spell  :: L   # RNN decoder
   infer  :: C   # character distribution inference function
end

@functor LAS

function LAS()
   state  = State(8, 8, 4)
   listen = RNN(5, 8)
   spell  = RNN(20, 8)
   infer  = Chain(Dense(16, 4), logsoftmax)
   return LAS(state, listen, spell, infer)
end

function (m::LAS)(xs::AbstractVector{<:AbstractMatrix})::AbstractVector{<:AbstractMatrix{<:Real}}
   batch_size = size(first(xs), 2)
   # compute input encoding
   hs = m.listen.(xs)
   # concatenate sequence of D×N matrices into a single D×N×T three-dimensional array
   Hs = cat(hs...; dims=3)
   # compute initial decoder state for a batch
   O = zeros(Float32, size(m.state.decoding, 1), batch_size)
   m.state.decoding = m.spell([m.state.decoding; m.state.prediction; m.state.context]) .+ O
   dim_out = size(m.state.prediction, 1)

   ŷs = map(1:length(xs)) do _
      m.state.context = dropdims(sum(Hs; dims=3); dims=3)
      # predict probability distribution over character alphabet
      m.state.prediction = m.infer([m.state.decoding; m.state.context])
      # compute decoder state
      m.state.decoding = m.spell([m.state.decoding; m.state.prediction; m.state.context])
      return m.state.prediction
   end
   reset!(m)
   return ŷs
end

function Flux.reset!(m::LAS)
   reset!(m.state)
   reset!(m.listen)
   reset!(m.spell)
   return nothing
end

las = LAS()

function loss(xs::AbstractVector{<:AbstractMatrix{<:Real}}, indexes::AbstractVector{<:AbstractVector{<:Integer}})::Real
   ŷs = las(xs)
   l = -sum(sum.(getindex.(ŷs, indexes)))
   return l
end

xs = [rand(Float32, 5, 7) for _ in 1:3]
ys = [rand(1:28, 7) for _ in 1:3] # linear indices into each 4×7 prediction (4*7 = 28)

las(xs)
loss(xs, ys)

θ = params(las)

l, pb = Flux.Zygote.pullback(θ) do
   loss(xs, ys)
end

@AzamatB
Contributor

AzamatB commented Oct 10, 2019

After investigating this all day, it seems the error happens here:

y, back = _pullback(cx, f)

where cx is Zygote.Context(nothing, nothing) and f is the closure f = () -> loss(xs, ys). _pullback(cx, f) calls the generated function defined here:
@generated function _pullback(ctx::AContext, f, args...)

Debugging any further is beyond my abilities, unfortunately. I really hope this helps in resolving this nasty bug.

It would be great if this could get some attention; it is what is preventing me from using Zygote at the moment. Thanks a lot.

@AzamatB
Contributor

AzamatB commented Oct 15, 2019

This is a major pain point for us at the moment. @MikeInnes, I would love to prepare a PR that fixes this if you could provide some guidance.

I'm hoping we can iron out this bug before the next release of Flux.

@MikeInnes
Member Author

Somewhat smaller example:

julia> using Zygote: pullback

julia> function foo()
          Complex{<:Real}
       end
foo (generic function with 2 methods)

julia> pullback(foo)
ERROR: UndefVarError: S not defined

(The problem in the original code is the return-type assertion on the (m::LAS) method definition. Removing it makes this error go away, although an unrelated error shows up if you do.)
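Our reading of why the return-type assertion matters: Julia lowers f(...)::T into the method body, so T is evaluated at run time and the result goes through convert and a type assertion. For a UnionAll type such as AbstractVector{<:AbstractMatrix{<:Real}}, that means TypeVar/UnionAll construction inside the code Zygote differentiates, the same shape as the foo reduction. The sketch below checks this in plain Julia, without Zygote. builds_typevar, annotated and plain are hypothetical names, and matching on the printed lowered code is a rough heuristic whose exact output may differ between Julia versions:

```julia
# Hypothetical helper: does the lowered body of f (for argument types `types`)
# construct a TypeVar at run time? Heuristic: search the printed CodeInfo.
builds_typevar(f, types) =
    any(ci -> occursin("TypeVar", sprint(show, ci)), code_lowered(f, types))

# Return-type annotation with a UnionAll: lowering evaluates the type inside
# the method body (TypeVar/UnionAll calls), then applies convert + typeassert.
annotated(xs::AbstractVector{<:AbstractMatrix})::AbstractVector{<:AbstractMatrix{<:Real}} =
    map(x -> 2 .* x, xs)

# Workaround from this thread: drop the return-type annotation.
# Argument annotations live in the method signature, not in the lowered body.
plain(xs::AbstractVector{<:AbstractMatrix}) = map(x -> 2 .* x, xs)

sig = (Vector{Matrix{Float32}},)
println(builds_typevar(annotated, sig))
println(builds_typevar(plain, sig))
```

Only the annotated version should report a TypeVar in its body; the two methods compute the same result.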

@AzamatB
Contributor

AzamatB commented Oct 16, 2019

Thank you for fixing this! I can confirm that we don't get this error now. We're now seeing a different error, but I'll file a separate issue for it.

Edit: the error we are getting is #198.
