Zygote error: UndefVarError: S not defined #1578

Closed
jonathan-laurent opened this issue Apr 17, 2021 · 25 comments

@jonathan-laurent

I am trying to update AlphaZero.jl so that it works with Flux v0.12, but I am stuck on the following Zygote error:

ERROR: UndefVarError: S not defined

The bug happens on both Flux v0.12 and Flux#master.

Replication instructions

To replicate, you can run the following using Julia 1.6:

git clone --branch flux-0.12 https://github.com/jonathan-laurent/AlphaZero.jl.git
cd AlphaZero.jl
julia --project -e 'import Pkg; Pkg.instantiate()'
julia --project -e 'using AlphaZero; Scripts.test_grad_updates("connect-four")'

Full stacktrace

ERROR: UndefVarError: S not defined
Stacktrace:
  [1] (typeof(∂(λ)))(x::Tuple{typeof(∂(getproperty))})
    @ Zygote ~/.julia/packages/Zygote/CgsVi/src/compiler/interface.jl:19
  [2] _pullback
    @ ~/.julia/packages/ZygoteRules/OjfTt/src/ZygoteRules.jl:11 [inlined]
  [3] _pullback(::Zygote.Context, ::typeof(ZygoteRules.literal_getproperty), ::Type{Type{T}}, ::Val{:name})
    @ Zygote ~/.julia/packages/Zygote/CgsVi/src/compiler/interface2.jl:0
  [4] _pullback
    @ ./promotion.jl:87 [inlined]
  [5] _pullback(::Zygote.Context, ::typeof(typejoin), ::Type{CUDA.CuArray{Float32, 4}}, ::Type{CUDA.CuArray{Float32, 2}})
    @ Zygote ~/.julia/packages/Zygote/CgsVi/src/compiler/interface2.jl:0
  [6] _pullback
    @ ./promotion.jl:149 [inlined]
  [7] _pullback
    @ ./array.jl:748 [inlined]
  [8] _pullback(::Zygote.Context, ::typeof(Base.push_widen), ::Vector{CUDA.CuArray{Float32, 4}}, ::CUDA.CuArray{Float32, 2})
    @ Zygote ~/.julia/packages/Zygote/CgsVi/src/compiler/interface2.jl:0
  [9] _pullback
    @ ./array.jl:767 [inlined]
 [10] _pullback(::Zygote.Context, ::typeof(Base.grow_to!), ::Vector{CUDA.CuArray{Float32, 4}}, ::Base.Iterators.Flatten{Base.Generator{Vector{Any}, AlphaZero.FluxLib.var"#27#28"}}, ::Tuple{Int64, Base.Generator{Vector{CUDA.CuArray{Float32, 4}}, typeof(identity)}, Int64})
    @ Zygote ~/.julia/packages/Zygote/CgsVi/src/compiler/interface2.jl:0
 [11] _pullback
    @ ./array.jl:743 [inlined]
 [12] _pullback(::Zygote.Context, ::typeof(Base.grow_to!), ::Vector{Any}, ::Base.Iterators.Flatten{Base.Generator{Vector{Any}, AlphaZero.FluxLib.var"#27#28"}})
    @ Zygote ~/.julia/packages/Zygote/CgsVi/src/compiler/interface2.jl:0
 [13] _pullback
    @ ./array.jl:652 [inlined]
 [14] _pullback(::Zygote.Context, ::typeof(Base._collect), ::UnitRange{Int64}, ::Base.Iterators.Flatten{Base.Generator{Vector{Any}, AlphaZero.FluxLib.var"#27#28"}}, ::Base.EltypeUnknown, ::Base.SizeUnknown)
    @ Zygote ~/.julia/packages/Zygote/CgsVi/src/compiler/interface2.jl:0
 [15] _pullback
    @ ./array.jl:602 [inlined]
 [16] _pullback(ctx::Zygote.Context, f::typeof(collect), args::Base.Iterators.Flatten{Base.Generator{Vector{Any}, AlphaZero.FluxLib.var"#27#28"}})
    @ Zygote ~/.julia/packages/Zygote/CgsVi/src/compiler/interface2.jl:0
 [17] _pullback
    @ ~/AlphaZero.jl/src/networks/flux.jl:117 [inlined]
 [18] _pullback(ctx::Zygote.Context, f::typeof(AlphaZero.Network.regularized_params), args::ResNet)
    @ Zygote ~/.julia/packages/Zygote/CgsVi/src/compiler/interface2.jl:0
 [19] _pullback
    @ ~/AlphaZero.jl/src/learning.jl:75 [inlined]
 [20] _pullback(::Zygote.Context, ::typeof(AlphaZero.losses), ::ResNet, ::LearningParams, ::Float32, ::Float32, ::Tuple{CUDA.CuArray{Float32, 2}, CUDA.CuArray{Float32, 4}, CUDA.CuArray{Float32, 2}, CUDA.CuArray{Float32, 2}, CUDA.CuArray{Float32, 2}})
    @ Zygote ~/.julia/packages/Zygote/CgsVi/src/compiler/interface2.jl:0
 [21] _pullback
    @ ~/AlphaZero.jl/src/learning.jl:122 [inlined]
 [22] _pullback(::Zygote.Context, ::AlphaZero.var"#L#110"{AlphaZero.Trainer}, ::CUDA.CuArray{Float32, 2}, ::CUDA.CuArray{Float32, 4}, ::CUDA.CuArray{Float32, 2}, ::CUDA.CuArray{Float32, 2}, ::CUDA.CuArray{Float32, 2})
    @ Zygote ~/.julia/packages/Zygote/CgsVi/src/compiler/interface2.jl:0
 [23] adjoint
    @ ~/.julia/packages/Zygote/CgsVi/src/lib/lib.jl:188 [inlined]
 [24] _pullback
    @ ~/.julia/packages/ZygoteRules/OjfTt/src/adjoint.jl:57 [inlined]
 [25] _pullback
    @ ~/AlphaZero.jl/src/networks/flux.jl:82 [inlined]
 [26] _pullback(::Zygote.Context, ::AlphaZero.FluxLib.var"#1#2"{AlphaZero.var"#L#110"{AlphaZero.Trainer}, Tuple{CUDA.CuArray{Float32, 2}, CUDA.CuArray{Float32, 4}, CUDA.CuArray{Float32, 2}, CUDA.CuArray{Float32, 2}, CUDA.CuArray{Float32, 2}}})
    @ Zygote ~/.julia/packages/Zygote/CgsVi/src/compiler/interface2.jl:0
 [27] pullback(f::Function, ps::Zygote.Params)
    @ Zygote ~/.julia/packages/Zygote/CgsVi/src/compiler/interface.jl:247
 [28] lossgrads(f::Function, args::Zygote.Params)
    @ AlphaZero.FluxLib ~/AlphaZero.jl/src/networks/flux.jl:72
 [29] train!(callback::AlphaZero.var"#109#111"{Vector{Float32}}, nn::ResNet, opt::Adam, loss::Function, data::Base.Iterators.Take{Base.Iterators.Stateful{Base.Iterators.Flatten{Base.Generator{Base.Iterators.Repeated{Nothing}, AlphaZero.Util.var"#12#13"{AlphaZero.var"#106#108"{ResNet}, Tuple{Matrix{Float32}, Array{Float32, 4}, Matrix{Float32}, Matrix{Float32}, Matrix{Float32}}, Int64, Bool}}}, Tuple{NTuple{5, Any}, Tuple{Nothing, Base.Generator{Vector{Tuple{Matrix{Float32}, Array{Float32, 4}, Matrix{Float32}, Matrix{Float32}, Matrix{Float32}}}, AlphaZero.Util.var"#9#11"{AlphaZero.var"#106#108"{ResNet}}}, Int64}}}}, n::Int64)
    @ AlphaZero.FluxLib ~/AlphaZero.jl/src/networks/flux.jl:81
 [30] batch_updates!(tr::AlphaZero.Trainer, n::Int64)
    @ AlphaZero ~/AlphaZero.jl/src/learning.jl:125
 [31] macro expansion
    @ ./timing.jl:356 [inlined]
 [32] learning_step!(env::Env{AlphaZero.Examples.ConnectFour.GameSpec, ResNet, NamedTuple{(:board, :curplayer), Tuple{StaticArrays.SMatrix{7, 6, UInt8, 42}, UInt8}}}, handler::Session{Env{AlphaZero.Examples.ConnectFour.GameSpec, ResNet, NamedTuple{(:board, :curplayer), Tuple{StaticArrays.SMatrix{7, 6, UInt8, 42}, UInt8}}}})
    @ AlphaZero ~/AlphaZero.jl/src/training.jl:224
 [33] test_grad_updates(exp::Experiment; num_games::Int64)
    @ AlphaZero.Scripts ~/AlphaZero.jl/src/scripts/test_grad_updates.jl:17
 [34] test_grad_updates
    @ ~/AlphaZero.jl/src/scripts/test_grad_updates.jl:10 [inlined]
 [35] #test_grad_updates#21
    @ ~/AlphaZero.jl/src/scripts/scripts.jl:57 [inlined]
 [36] test_grad_updates(s::String)
    @ AlphaZero.Scripts ~/AlphaZero.jl/src/scripts/scripts.jl:57
 [37] top-level scope
    @ REPL[2]:1
@DhairyaLGandhi
Member

Interesting, thanks for pointing this out. The next step might be to check older Zygote releases to see where this came up. I suspected the push-related changes caused it, but I didn't find any stray S's in there.

@darsnack
Member

darsnack commented Apr 17, 2021

The source of the issue is this line.

@DhairyaLGandhi
Member

julia> d = Chain(Dense(3,3), Dense(3,3))
Chain(Dense(3, 3), Dense(3, 3))

julia> x = rand(Float32, 3,4);

julia> gs = gradient(Flux.params(d)) do
         ds = Flux.modules(d)
         sum(l(x) for l in ds) |> sum
       end
Grads(...)

@darsnack
Member

Flux.modules has @nograd defined. I think the generator expression might be the problem.

@DhairyaLGandhi
Member

There's a generator in the above expression as well.

@darsnack
Member

julia> regularized_params_(l::Flux.Dense) = [l.W]
regularized_params_ (generic function with 1 method)

julia> regularized_params_(l) = []
regularized_params_ (generic function with 2 methods)

julia> regularized_params_(l::Flux.Conv) = [l.weight]
regularized_params_ (generic function with 3 methods)

julia> d = Chain(Conv((3, 3), 1 => 2), Dense(3,3), Dense(3,3))
Chain(Conv((3, 3), 1=>2), Dense(3, 3), Dense(3, 3))

julia> gs = gradient(Flux.params(d)) do
           ws = [w for l in Flux.modules(d) for w in regularized_params_(l)]
           sum(sum(w) for w in ws)
       end
ERROR: UndefVarError: S not defined
Stacktrace:
  [1] (typeof(∂(λ)))(x::Tuple{typeof(∂(getproperty))})
    @ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/interface.jl:19
  [2] _pullback
    @ ~/.julia/packages/ZygoteRules/OjfTt/src/ZygoteRules.jl:11 [inlined]
  [3] _pullback(::Zygote.Context, ::typeof(ZygoteRules.literal_getproperty), ::Type{Type{T}}, ::Val{:name})
    @ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/interface2.jl:0
  [4] _pullback
    @ ./promotion.jl:87 [inlined]
  [5] _pullback(::Zygote.Context, ::typeof(typejoin), ::Type{Array{Float32, 4}}, ::Type{Matrix{Float32}})
    @ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/interface2.jl:0
  [6] _pullback
    @ ./promotion.jl:149 [inlined]
  [7] _pullback
    @ ./array.jl:748 [inlined]
  [8] _pullback(::Zygote.Context, ::typeof(Base.push_widen), ::Vector{Array{Float32, 4}}, ::Matrix{Float32})
    @ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/interface2.jl:0
  [9] _pullback
    @ ./array.jl:767 [inlined]
 [10] _pullback(::Zygote.Context, ::typeof(Base.grow_to!), ::Vector{Array{Float32, 4}}, ::Base.Iterators.Flatten{Base.Generator{Vector{Any}, var"#40#42"}}, ::Tuple{Int64, Base.Generator{Vector{Array{Float32, 4}}, typeof(identity)}, Int64})
    @ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/interface2.jl:0
 [11] _pullback
    @ ./array.jl:743 [inlined]
 [12] _pullback(::Zygote.Context, ::typeof(Base.grow_to!), ::Vector{Any}, ::Base.Iterators.Flatten{Base.Generator{Vector{Any}, var"#40#42"}})
    @ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/interface2.jl:0
 [13] _pullback
    @ ./array.jl:652 [inlined]
 [14] _pullback(::Zygote.Context, ::typeof(Base._collect), ::UnitRange{Int64}, ::Base.Iterators.Flatten{Base.Generator{Vector{Any}, var"#40#42"}}, ::Base.EltypeUnknown, ::Base.SizeUnknown)
    @ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/interface2.jl:0
 [15] _pullback
    @ ./array.jl:602 [inlined]
 [16] _pullback(ctx::Zygote.Context, f::typeof(collect), args::Base.Iterators.Flatten{Base.Generator{Vector{Any}, var"#40#42"}})
    @ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/interface2.jl:0
 [17] _pullback
    @ ./REPL[474]:2 [inlined]
 [18] _pullback(::Zygote.Context, ::var"#39#41")
    @ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/interface2.jl:0
 [19] pullback(f::Function, ps::Params)
    @ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/interface.jl:247
 [20] gradient(f::Function, args::Params)
    @ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/interface.jl:58
 [21] top-level scope
    @ REPL[474]:1

@darsnack
Member

It's more subtle, too: you need a Conv and a Dense. With only Dense layers, you get a different error:

julia> d = Chain(Dense(3, 3), Dense(3, 3))
Chain(Dense(3, 3), Dense(3, 3))

julia> gs = gradient(Flux.params(d)) do
           ws = [w for l in Flux.modules(d) for w in regularized_params_(l)]
           sum(sum(w) for w in ws)
       end
ERROR: Compiling Tuple{Base.var"##depwarn#864", Bool, typeof(Base.depwarn), String, Symbol}: try/catch is not supported.
Stacktrace:
  [1] error(s::String)
    @ Base ./error.jl:33
  [2] instrument(ir::IRTools.Inner.IR)
    @ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/reverse.jl:121
  [3] #Primal#20
    @ ~/.julia/packages/Zygote/RxTZu/src/compiler/reverse.jl:202 [inlined]
  [4] Zygote.Adjoint(ir::IRTools.Inner.IR; varargs::Nothing, normalise::Bool)
    @ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/reverse.jl:315
  [5] _lookup_grad(T::Type)
    @ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/emit.jl:101
  [6] #s2993#1177
    @ ~/.julia/packages/Zygote/RxTZu/src/compiler/interface2.jl:34 [inlined]
  [7] var"#s2993#1177"(T::Any, j::Any, Δ::Any)
    @ Zygote ./none:0
  [8] (::Core.GeneratedFunctionStub)(::Any, ::Vararg{Any, N} where N)
    @ Core ./boot.jl:571
  [9] Pullback
    @ ./deprecated.jl:80 [inlined]
 [10] (::typeof(∂(depwarn)))(Δ::Nothing)
    @ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/interface2.jl:0
 [11] Pullback
    @ ~/.julia/packages/Flux/qp1gc/src/deprecations.jl:13 [inlined]
 [12] (::typeof(∂(getproperty)))(Δ::Nothing)
    @ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/interface2.jl:0
 [13] Pullback
    @ ~/.julia/packages/ZygoteRules/OjfTt/src/ZygoteRules.jl:11 [inlined]
 [14] Pullback
    @ ./REPL[457]:1 [inlined]
 [15] (::typeof(∂(regularized_params_)))(Δ::Nothing)
    @ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/interface2.jl:0
 [16] Pullback
    @ ./none:0 [inlined]
 [17] (::typeof(∂(#44)))(Δ::Nothing)
    @ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/interface2.jl:0
 [18] Pullback
    @ ./generator.jl:47 [inlined]
 [19] (::typeof(∂(iterate)))(Δ::Nothing)
    @ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/interface2.jl:0
 [20] Pullback
    @ ./iterators.jl:1093 [inlined]
 [21] (::typeof(∂(iterate)))(Δ::Tuple{Nothing, Tuple{Nothing, Nothing, Nothing}})
    @ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/interface2.jl:0
 [22] Pullback
    @ ./array.jl:761 [inlined]
 [23] (::typeof(∂(grow_to!)))(Δ::Nothing)
    @ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/interface2.jl:0
 [24] Pullback
    @ ./array.jl:743 [inlined]
 [25] (::typeof(∂(grow_to!)))(Δ::Nothing)
    @ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/interface2.jl:0
 [26] Pullback
    @ ./array.jl:652 [inlined]
 [27] (::typeof(∂(_collect)))(Δ::Nothing)
    @ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/interface2.jl:0
 [28] Pullback
    @ ./array.jl:602 [inlined]
 [29] (::typeof(∂(collect)))(Δ::Nothing)
    @ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/interface2.jl:0
 [30] Pullback
    @ ./REPL[476]:2 [inlined]
 [31] (::typeof(∂(#43)))(Δ::Float32)
    @ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/interface2.jl:0
 [32] (::Zygote.var"#69#70"{Params, typeof(∂(#43)), Zygote.Context})(Δ::Float32)
    @ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/interface.jl:252
 [33] gradient(f::Function, args::Params)
    @ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/interface.jl:59
 [34] top-level scope
    @ REPL[476]:1

If you look at the stack trace, you'll see a call to typejoin, because the Conv and Dense weights have different types (Array{Float32, 4} vs. Matrix{Float32}).
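For reference, here is the call from frame [5] evaluated in plain Julia, with no Flux or Zygote involved. It returns normally, which is consistent with the MWE further down: the UndefVarError only shows up once Zygote tries to differentiate through typejoin.

```julia
# Weight types seen in the trace: Conv weights are 4-d, Dense weights are 2-d.
conv_w  = Array{Float32, 4}
dense_w = Matrix{Float32}   # same type as Array{Float32, 2}

# typejoin returns the closest common supertype. The element type (Float32)
# agrees but the number of dimensions does not, so the dimension is left free:
# the result is Array{Float32, N} where N, i.e. Array{Float32}.
joined = typejoin(conv_w, dense_w)
@assert joined == Array{Float32}
```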

Even smaller reproducer without Flux.modules:

julia> gs = gradient(Flux.params(d)) do
           ws = [w for l in d for w in regularized_params_(l)]
           sum(sum(w) for w in ws)
       end
ERROR: UndefVarError: S not defined
Stacktrace:
  [1] (typeof(∂(λ)))(x::Tuple{typeof(∂(getproperty))})
    @ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/interface.jl:19
  [2] _pullback
    @ ~/.julia/packages/ZygoteRules/OjfTt/src/ZygoteRules.jl:11 [inlined]
  [3] _pullback(::Zygote.Context, ::typeof(ZygoteRules.literal_getproperty), ::Type{Type{T}}, ::Val{:name})
    @ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/interface2.jl:0
  [4] _pullback
    @ ./promotion.jl:87 [inlined]
  [5] _pullback(::Zygote.Context, ::typeof(typejoin), ::Type{Array{Float32, 4}}, ::Type{Matrix{Float32}})
    @ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/interface2.jl:0
  [6] _pullback
    @ ./promotion.jl:149 [inlined]
  [7] _pullback
    @ ./array.jl:748 [inlined]
  [8] _pullback(::Zygote.Context, ::typeof(Base.push_widen), ::Vector{Array{Float32, 4}}, ::Matrix{Float32})
    @ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/interface2.jl:0
  [9] _pullback
    @ ./array.jl:767 [inlined]
 [10] _pullback(::Zygote.Context, ::typeof(Base.grow_to!), ::Vector{Array{Float32, 4}}, ::Base.Iterators.Flatten{Base.Generator{Chain{Tuple{Conv{2, 4, typeof(identity), Array{Float32, 4}, Vector{Float32}}, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}}, var"#64#66"}}, ::Tuple{Int64, Base.Generator{Vector{Array{Float32, 4}}, typeof(identity)}, Int64})
    @ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/interface2.jl:0
 [11] _pullback
    @ ./array.jl:743 [inlined]
 [12] _pullback(::Zygote.Context, ::typeof(Base.grow_to!), ::Vector{Any}, ::Base.Iterators.Flatten{Base.Generator{Chain{Tuple{Conv{2, 4, typeof(identity), Array{Float32, 4}, Vector{Float32}}, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}}, var"#64#66"}})
    @ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/interface2.jl:0
 [13] _pullback
    @ ./array.jl:652 [inlined]
 [14] _pullback(::Zygote.Context, ::typeof(Base._collect), ::UnitRange{Int64}, ::Base.Iterators.Flatten{Base.Generator{Chain{Tuple{Conv{2, 4, typeof(identity), Array{Float32, 4}, Vector{Float32}}, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}}, var"#64#66"}}, ::Base.EltypeUnknown, ::Base.SizeUnknown)
    @ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/interface2.jl:0
 [15] _pullback
    @ ./array.jl:602 [inlined]
 [16] _pullback(ctx::Zygote.Context, f::typeof(collect), args::Base.Iterators.Flatten{Base.Generator{Chain{Tuple{Conv{2, 4, typeof(identity), Array{Float32, 4}, Vector{Float32}}, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}}, var"#64#66"}})
    @ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/interface2.jl:0
 [17] _pullback
    @ ./REPL[484]:2 [inlined]
 [18] _pullback(::Zygote.Context, ::var"#63#65")
    @ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/interface2.jl:0
 [19] pullback(f::Function, ps::Params)
    @ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/interface.jl:247
 [20] gradient(f::Function, args::Params)
    @ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/interface.jl:58
 [21] top-level scope
    @ REPL[484]:1

@darsnack
Member

darsnack commented Apr 17, 2021

I found that I needed elements of different types in the generator, and nested generators (to force the call to collect, which in turn triggers the typejoin). For example, the following, which never calls collect, gets a different error:
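A Flux-free sketch of why the nested comprehension reaches collect and typejoin while the lazy sum does not. FakeConv, FakeDense, and reg_params are hypothetical stand-ins for the Flux layers and regularized_params_ above; the widening behaviour described in the comments is what the Julia 1.6 trace shows (Base.grow_to! then Base.push_widen).

```julia
# Hypothetical stand-ins for Flux's Conv and Dense layers (not the real Flux types).
struct FakeConv
    weight::Array{Float32, 4}
end
struct FakeDense
    W::Matrix{Float32}
end

# Mirrors regularized_params_ from the snippet above.
reg_params(l::FakeConv)  = [l.weight]
reg_params(l::FakeDense) = [l.W]
reg_params(l)            = []

layers = (FakeConv(ones(Float32, 3, 3, 1, 2)),
          FakeDense(ones(Float32, 3, 3)),
          FakeDense(ones(Float32, 3, 3)))

# Nested comprehension: Base has to materialize a Vector. The flattened
# iterator has no known length, so collect starts from a Vector typed after
# the first element (Array{Float32, 4}) and widens it via typejoin when the
# first Matrix{Float32} arrives.
ws = [w for l in layers for w in reg_params(l)]

# Lazy generator: sum consumes the weights one at a time and never builds a
# Vector, so this path has no widening step and no typejoin call.
total = sum(sum(w) for l in layers for w in reg_params(l))
```

The widened element type is visible afterwards as eltype(ws) == Array{Float32}.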

julia> gs = gradient(Flux.params(d)) do
           sum(sum(w) for m in d for w in regularized_params_(m))
       end
ERROR: Compiling Tuple{Base.var"##depwarn#864", Bool, typeof(Base.depwarn), String, Symbol}: try/catch is not supported.
Stacktrace:
  [1] error(s::String)
    @ Base ./error.jl:33
  [2] instrument(ir::IRTools.Inner.IR)
    @ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/reverse.jl:121
  [3] #Primal#20
    @ ~/.julia/packages/Zygote/RxTZu/src/compiler/reverse.jl:202 [inlined]
  [4] Zygote.Adjoint(ir::IRTools.Inner.IR; varargs::Nothing, normalise::Bool)
    @ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/reverse.jl:315
  [5] _lookup_grad(T::Type)
    @ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/emit.jl:101
  [6] #s2993#1177
    @ ~/.julia/packages/Zygote/RxTZu/src/compiler/interface2.jl:34 [inlined]
  [7] var"#s2993#1177"(T::Any, j::Any, Δ::Any)
    @ Zygote ./none:0
  [8] (::Core.GeneratedFunctionStub)(::Any, ::Vararg{Any, N} where N)
    @ Core ./boot.jl:571
  [9] Pullback
    @ ./deprecated.jl:80 [inlined]
 [10] (::typeof(∂(depwarn)))(Δ::Nothing)
    @ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/interface2.jl:0
 [11] Pullback
    @ ~/.julia/packages/Flux/qp1gc/src/deprecations.jl:13 [inlined]
 [12] (::typeof(∂(getproperty)))(Δ::Nothing)
    @ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/interface2.jl:0
 [13] Pullback
    @ ~/.julia/packages/ZygoteRules/OjfTt/src/ZygoteRules.jl:11 [inlined]
 [14] Pullback
    @ ./REPL[457]:1 [inlined]
 [15] (::typeof(∂(regularized_params_)))(Δ::Nothing)
    @ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/interface2.jl:0
 [16] Pullback
    @ ./none:0 [inlined]
 [17] (::typeof(∂(#60)))(Δ::Nothing)
    @ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/interface2.jl:0
 [18] Pullback
    @ ./reduce.jl:93 [inlined]
 [19] (::typeof(∂(Base.MappingRF{var"#60#62", Base.FlatteningRF{Base.BottomRF{typeof(Base.add_sum)}}}(var"#60#62"(), Base.FlatteningRF{Base.BottomRF{typeof(Base.add_sum)}}(Base.BottomRF{typeof(Base.add_sum)}(Base.add_sum))))))(Δ::Float32)
    @ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/interface2.jl:0
 [20] Pullback
    @ ./reduce.jl:62 [inlined]
 [21] (::typeof(∂(_foldl_impl)))(Δ::Float32)
    @ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/interface2.jl:0
 [22] Pullback
    @ ./reduce.jl:48 [inlined]
 [23] (::typeof(∂(foldl_impl)))(Δ::Float32)
    @ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/interface2.jl:0
 [24] Pullback
    @ ./reduce.jl:44 [inlined]
 [25] (::typeof(∂(mapfoldl_impl)))(Δ::Float32)
    @ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/interface2.jl:0
 [26] Pullback (repeats 2 times)
    @ ./reduce.jl:160 [inlined]
 [27] (::typeof(∂(mapfoldl)))(Δ::Float32)
    @ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/interface2.jl:0
 [28] Pullback
    @ ./reduce.jl:287 [inlined]
 [29] (::typeof(∂(#mapreduce#218)))(Δ::Float32)
    @ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/interface2.jl:0
 [30] Pullback
    @ ./reduce.jl:287 [inlined]
 [31] (::typeof(∂(mapreduce)))(Δ::Float32)
    @ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/interface2.jl:0
 [32] Pullback
    @ ./reduce.jl:501 [inlined]
 [33] (::typeof(∂(#sum#221)))(Δ::Float32)
    @ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/interface2.jl:0
 [34] Pullback
    @ ./reduce.jl:501 [inlined]
 [35] (::typeof(∂(sum)))(Δ::Float32)
    @ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/interface2.jl:0
 [36] Pullback
    @ ./reduce.jl:528 [inlined]
 [37] (::typeof(∂(#sum#222)))(Δ::Float32)
    @ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/interface2.jl:0
 [38] Pullback
    @ ./reduce.jl:528 [inlined]
 [39] (::typeof(∂(sum)))(Δ::Float32)
    @ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/interface2.jl:0
 [40] Pullback
    @ ./REPL[483]:2 [inlined]
 [41] (::typeof(∂(#59)))(Δ::Float32)
    @ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/interface2.jl:0
 [42] (::Zygote.var"#69#70"{Params, typeof(∂(#59)), Zygote.Context})(Δ::Float32)
    @ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/interface.jl:252
 [43] gradient(f::Function, args::Params)
    @ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/interface.jl:59
 [44] top-level scope
    @ REPL[483]:1

@darsnack
Member

Okay, finally got an MWE:

julia> gs = gradient(params(d)) do
           x = typejoin(Array{Float32, 4}, Array{Float32, 2})
           return 1
       end
ERROR: UndefVarError: S not defined
Stacktrace:
  [1] (typeof(∂(λ)))(x::Tuple{typeof(∂(getproperty))})
    @ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/interface.jl:19
  [2] _pullback
    @ ~/.julia/packages/ZygoteRules/OjfTt/src/ZygoteRules.jl:11 [inlined]
  [3] _pullback(::Zygote.Context, ::typeof(ZygoteRules.literal_getproperty), ::Type{Type{T}}, ::Val{:name})
    @ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/interface2.jl:0
  [4] _pullback
    @ ./promotion.jl:87 [inlined]
  [5] _pullback(::Zygote.Context, ::typeof(typejoin), ::Type{Array{Float32, 4}}, ::Type{Matrix{Float32}})
    @ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/interface2.jl:0
  [6] _pullback
    @ ./REPL[486]:2 [inlined]
  [7] _pullback(::Zygote.Context, ::var"#67#68")
    @ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/interface2.jl:0
  [8] pullback(f::Function, ps::Params)
    @ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/interface.jl:247
  [9] gradient(f::Function, args::Params)
    @ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/interface.jl:58
 [10] top-level scope
    @ REPL[486]:1

@DhairyaLGandhi
Member

DhairyaLGandhi commented Apr 17, 2021

Just opened FluxML/Zygote.jl#947

Looks like we came to the same conclusion lol.

@darsnack added you as a co-author (hope you don't mind)

@jonathan-laurent
Author

I updated the Manifest to use FluxML/Zygote.jl#947, but now I am seeing a different error (I updated the flux-0.12 branch, so you can still use the replication instructions above).

ERROR: Compiling Tuple{Base.var"##depwarn#868", Bool, typeof(Base.depwarn), String, Symbol}: try/catch is not supported.
Stacktrace:
  [1] error(s::String)
    @ Base ./error.jl:33
  [2] instrument(ir::IRTools.Inner.IR)
    @ Zygote ~/.julia/packages/Zygote/iSZne/src/compiler/reverse.jl:121
  [3] #Primal#20
    @ ~/.julia/packages/Zygote/iSZne/src/compiler/reverse.jl:202 [inlined]
  [4] Zygote.Adjoint(ir::IRTools.Inner.IR; varargs::Nothing, normalise::Bool)
    @ Zygote ~/.julia/packages/Zygote/iSZne/src/compiler/reverse.jl:315
  [5] _lookup_grad(T::Type)
    @ Zygote ~/.julia/packages/Zygote/iSZne/src/compiler/emit.jl:101
  [6] #s2996#1179
    @ ~/.julia/packages/Zygote/iSZne/src/compiler/interface2.jl:34 [inlined]
  [7] var"#s2996#1179"(T::Any, j::Any, Δ::Any)
    @ Zygote ./none:0
  [8] (::Core.GeneratedFunctionStub)(::Any, ::Vararg{Any, N} where N)
    @ Core ./boot.jl:571
  [9] Pullback
    @ ./deprecated.jl:80 [inlined]
 [10] (::typeof(∂(depwarn)))(Δ::Nothing)
    @ Zygote ~/.julia/packages/Zygote/iSZne/src/compiler/interface2.jl:0
 [11] Pullback
    @ ~/.julia/packages/Flux/qp1gc/src/deprecations.jl:13 [inlined]
 [12] (::typeof(∂(getproperty)))(Δ::Nothing)
    @ Zygote ~/.julia/packages/Zygote/iSZne/src/compiler/interface2.jl:0
 [13] Pullback
    @ ~/.julia/packages/ZygoteRules/OjfTt/src/ZygoteRules.jl:11 [inlined]
 [14] Pullback
    @ ~/AlphaZero.jl/src/networks/flux.jl:113 [inlined]
 [15] (::typeof(∂(regularized_params_)))(Δ::Nothing)
    @ Zygote ~/.julia/packages/Zygote/iSZne/src/compiler/interface2.jl:0
 [16] Pullback
    @ ./none:0 [inlined]
 [17] (::typeof(∂(#5)))(Δ::Nothing)
    @ Zygote ~/.julia/packages/Zygote/iSZne/src/compiler/interface2.jl:0
 [18] Pullback
    @ ./generator.jl:47 [inlined]
 [19] (::typeof(∂(iterate)))(Δ::Nothing)
    @ Zygote ~/.julia/packages/Zygote/iSZne/src/compiler/interface2.jl:0
 [20] Pullback
    @ ./iterators.jl:1097 [inlined]
 [21] (::typeof(∂(iterate)))(Δ::Tuple{Nothing, Tuple{Nothing, Nothing, Nothing}})
    @ Zygote ~/.julia/packages/Zygote/iSZne/src/compiler/interface2.jl:0
 [22] Pullback
    @ ./array.jl:770 [inlined]
 [23] (::typeof(∂(grow_to!)))(Δ::Nothing)
    @ Zygote ~/.julia/packages/Zygote/iSZne/src/compiler/interface2.jl:0
 [24] Pullback
    @ ./array.jl:768 [inlined]
 [25] (::typeof(∂(grow_to!)))(Δ::Nothing)
    @ Zygote ~/.julia/packages/Zygote/iSZne/src/compiler/interface2.jl:0
 [26] Pullback
    @ ./array.jl:743 [inlined]
 [27] (::typeof(∂(grow_to!)))(Δ::Nothing)
    @ Zygote ~/.julia/packages/Zygote/iSZne/src/compiler/interface2.jl:0
 [28] Pullback
    @ ./array.jl:652 [inlined]
 [29] (::typeof(∂(_collect)))(Δ::Nothing)
    @ Zygote ~/.julia/packages/Zygote/iSZne/src/compiler/interface2.jl:0
 [30] Pullback
    @ ./array.jl:602 [inlined]
 [31] (::typeof(∂(collect)))(Δ::Nothing)
    @ Zygote ~/.julia/packages/Zygote/iSZne/src/compiler/interface2.jl:0
 [32] Pullback
    @ ~/AlphaZero.jl/src/networks/flux.jl:117 [inlined]
 [33] (::typeof(∂(regularized_params)))(Δ::Nothing)
    @ Zygote ~/.julia/packages/Zygote/iSZne/src/compiler/interface2.jl:0
 [34] Pullback
    @ ~/AlphaZero.jl/src/learning.jl:75 [inlined]
 [35] (::typeof(∂(losses)))(Δ::Tuple{Float32, Nothing, Nothing, Nothing, Nothing})
    @ Zygote ~/.julia/packages/Zygote/iSZne/src/compiler/interface2.jl:0
 [36] Pullback
    @ ~/AlphaZero.jl/src/learning.jl:122 [inlined]
 [37] (::typeof(∂(λ)))(Δ::Float32)
    @ Zygote ~/.julia/packages/Zygote/iSZne/src/compiler/interface2.jl:0
 [38] (::Zygote.var"#178#179"{Tuple{NTuple{5, Nothing}}, typeof(∂(λ))})(Δ::Float32)
    @ Zygote ~/.julia/packages/Zygote/iSZne/src/lib/lib.jl:194
 [39] #1686#back
    @ ~/.julia/packages/ZygoteRules/OjfTt/src/adjoint.jl:59 [inlined]
 [40] Pullback
    @ ~/AlphaZero.jl/src/networks/flux.jl:82 [inlined]
 [41] (::typeof(∂(λ)))(Δ::Float32)
    @ Zygote ~/.julia/packages/Zygote/iSZne/src/compiler/interface2.jl:0
 [42] (::Zygote.var"#69#70"{Zygote.Params, typeof(∂(λ)), Zygote.Context})(Δ::Float32)
    @ Zygote ~/.julia/packages/Zygote/iSZne/src/compiler/interface.jl:252
 [43] lossgrads(f::Function, args::Zygote.Params)
    @ AlphaZero.FluxLib ~/AlphaZero.jl/src/networks/flux.jl:73
 [44] train!(callback::AlphaZero.var"#109#111"{Vector{Float32}}, nn::ResNet, opt::Adam, loss::Function, data::Base.Iterators.Take{Base.Iterators.Stateful{Base.Iterators.Flatten{Base.Generator{Base.Iterators.Repeated{Nothing}, AlphaZero.Util.var"#12#13"{AlphaZero.var"#106#108"{ResNet}, Tuple{Matrix{Float32}, Array{Float32, 4}, Matrix{Float32}, Matrix{Float32}, Matrix{Float32}}, Int64, Bool}}}, Tuple{NTuple{5, Any}, Tuple{Nothing, Base.Generator{Vector{Tuple{Matrix{Float32}, Array{Float32, 4}, Matrix{Float32}, Matrix{Float32}, Matrix{Float32}}}, AlphaZero.Util.var"#9#11"{AlphaZero.var"#106#108"{ResNet}}}, Int64}}}}, n::Int64)
    @ AlphaZero.FluxLib ~/AlphaZero.jl/src/networks/flux.jl:81
 [45] batch_updates!(tr::AlphaZero.Trainer, n::Int64)
    @ AlphaZero ~/AlphaZero.jl/src/learning.jl:125
 [46] macro expansion
    @ ./timing.jl:356 [inlined]
 [47] learning_step!(env::Env{AlphaZero.Examples.ConnectFour.GameSpec, ResNet, NamedTuple{(:board, :curplayer), Tuple{StaticArrays.SMatrix{7, 6, UInt8, 42}, UInt8}}}, handler::Session{Env{AlphaZero.Examples.ConnectFour.GameSpec, ResNet, NamedTuple{(:board, :curplayer), Tuple{StaticArrays.SMatrix{7, 6, UInt8, 42}, UInt8}}}})
    @ AlphaZero ~/AlphaZero.jl/src/training.jl:224
 [48] test_grad_updates(exp::Experiment; num_games::Int64)
    @ AlphaZero.Scripts ~/AlphaZero.jl/src/scripts/test_grad_updates.jl:17
 [49] test_grad_updates
    @ ~/AlphaZero.jl/src/scripts/test_grad_updates.jl:10 [inlined]
 [50] #test_grad_updates#21
    @ ~/AlphaZero.jl/src/scripts/scripts.jl:57 [inlined]
 [51] test_grad_updates(s::String)
    @ AlphaZero.Scripts ~/AlphaZero.jl/src/scripts/scripts.jl:57

@DhairyaLGandhi
Member

I addressed that in JuliaDiff/ChainRules.jl#398.

@DhairyaLGandhi
Member

Should be closed in JuliaDiff/ChainRules.jl#398

@jonathan-laurent
Author

jonathan-laurent commented Apr 17, 2021

Now it works if I define this:

function Network.regularized_params(net::FluxNetwork)
  return (w for l in Flux.modules(net) for w in regularized_params_(l))
end

Just so you know, I still have an error if I use this definition instead (returning an array instead of a generator):

function Network.regularized_params(net::FluxNetwork)
  return [w for l in Flux.modules(net) for w in regularized_params_(l)]
end

I am not sure whether or not this should be considered a bug. If not, how should I modify the second example to make it work? Would I need to use a @nograd macro somewhere?
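One plausible reading of the "Mutating arrays is not supported" trace that follows: the array version goes through Base.push_widen, which copies the partially built vector with append! (frames [4] to [7]), and Zygote rejects that mutation, while the generator version is consumed lazily and never widens. The sketch below is one possible mutation-free way to still return a Vector: build one small vector per layer with map, then join them with reduce(vcat, ...). It is untested against Zygote. Whether Zygote differentiates map and reduce(vcat, ...) here without tracing Base internals is an assumption. StubConv, StubDense, stub_params, and regularized_params_nomut are hypothetical names, exercised only with Base Julia.

```julia
# Hypothetical stand-ins for Flux's Conv and Dense layers.
struct StubConv
    weight::Array{Float32, 4}
end
struct StubDense
    W::Matrix{Float32}
end

stub_params(l::StubConv)  = [l.weight]
stub_params(l::StubDense) = [l.W]

# Hypothetical mutation-free variant of the array-returning definition:
# map builds one small Vector per layer and reduce(vcat, ...) joins them,
# so no growing buffer is push!-ed or append!-ed in this code. Assumption
# (not verified here): Zygote handles map and vcat with its own rules
# instead of tracing Base's push_widen/append!.
regularized_params_nomut(layers) = reduce(vcat, map(stub_params, layers))

layers = (StubConv(ones(Float32, 3, 3, 1, 2)), StubDense(ones(Float32, 3, 3)))
ps = regularized_params_nomut(layers)
```

The result holds the same weight arrays (not copies), with element type Array{Float32}, just like the comprehension.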

ERROR: Mutating arrays is not supported
Stacktrace:
  [1] error(s::String)
    @ Base ./error.jl:33
  [2] (::Zygote.var"#405#406")(#unused#::Nothing)
    @ Zygote ~/.julia/packages/Zygote/RxTZu/src/lib/array.jl:61
  [3] (::Zygote.var"#2266#back#407"{Zygote.var"#405#406"})(Δ::Nothing)
    @ Zygote ~/.julia/packages/ZygoteRules/OjfTt/src/adjoint.jl:59
  [4] Pullback
    @ ./array.jl:977 [inlined]
  [5] (::typeof(∂(append!)))(Δ::Nothing)
    @ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/interface2.jl:0
  [6] Pullback
    @ ./array.jl:753 [inlined]
  [7] (::typeof(∂(push_widen)))(Δ::Nothing)
    @ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/interface2.jl:0
  [8] Pullback
    @ ./array.jl:767 [inlined]
  [9] (::typeof(∂(grow_to!)))(Δ::Nothing)
    @ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/interface2.jl:0
 [10] Pullback
    @ ./array.jl:743 [inlined]
 [11] (::typeof(∂(grow_to!)))(Δ::Nothing)
    @ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/interface2.jl:0
 [12] Pullback
    @ ./array.jl:652 [inlined]
 [13] (::typeof(∂(_collect)))(Δ::Nothing)
    @ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/interface2.jl:0
 [14] Pullback
    @ ./array.jl:602 [inlined]
 [15] (::typeof(∂(collect)))(Δ::Nothing)
    @ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/interface2.jl:0
 [16] Pullback
    @ ~/AlphaZero.jl/src/networks/flux.jl:117 [inlined]
 [17] (::typeof(∂(regularized_params)))(Δ::Nothing)
    @ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/interface2.jl:0
 [18] Pullback
    @ ~/AlphaZero.jl/src/learning.jl:75 [inlined]
 [19] (::typeof(∂(losses)))(Δ::Tuple{Float32, Nothing, Nothing, Nothing, Nothing})
    @ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/interface2.jl:0
 [20] Pullback
    @ ~/AlphaZero.jl/src/learning.jl:122 [inlined]
 [21] (::typeof(∂(λ)))(Δ::Float32)
    @ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/interface2.jl:0
 [22] (::Zygote.var"#178#179"{Tuple{NTuple{5, Nothing}}, typeof(∂(λ))})(Δ::Float32)
    @ Zygote ~/.julia/packages/Zygote/RxTZu/src/lib/lib.jl:194
 [23] #1686#back
    @ ~/.julia/packages/ZygoteRules/OjfTt/src/adjoint.jl:59 [inlined]
 [24] Pullback
    @ ~/AlphaZero.jl/src/networks/flux.jl:82 [inlined]
 [25] (::typeof(∂(λ)))(Δ::Float32)
    @ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/interface2.jl:0
 [26] (::Zygote.var"#69#70"{Zygote.Params, typeof(∂(λ)), Zygote.Context})(Δ::Float32)
    @ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/interface.jl:252
 [27] lossgrads(f::Function, args::Zygote.Params)
    @ AlphaZero.FluxLib ~/AlphaZero.jl/src/networks/flux.jl:73
 [28] train!(callback::AlphaZero.var"#109#111"{Vector{Float32}}, nn::ResNet, opt::Adam, loss::Function, data::Base.Iterators.Take{Base.Iterators.Stateful{Base.Iterators.Flatten{Base.Generator{Base.Iterators.Repeated{Nothing}, AlphaZero.Util.var"#12#13"{AlphaZero.var"#106#108"{ResNet}, Tuple{Matrix{Float32}, Array{Float32, 4}, Matrix{Float32}, Matrix{Float32}, Matrix{Float32}}, Int64, Bool}}}, Tuple{NTuple{5, Any}, Tuple{Nothing, Base.Generator{Vector{Tuple{Matrix{Float32}, Array{Float32, 4}, Matrix{Float32}, Matrix{Float32}, Matrix{Float32}}}, AlphaZero.Util.var"#9#11"{AlphaZero.var"#106#108"{ResNet}}}, Int64}}}}, n::Int64)
    @ AlphaZero.FluxLib ~/AlphaZero.jl/src/networks/flux.jl:81
 [29] batch_updates!(tr::AlphaZero.Trainer, n::Int64)
    @ AlphaZero ~/AlphaZero.jl/src/learning.jl:125
 [30] macro expansion
    @ ./timing.jl:356 [inlined]
 [31] learning_step!(env::Env{AlphaZero.Examples.ConnectFour.GameSpec, ResNet, NamedTuple{(:board, :curplayer), Tuple{StaticArrays.SMatrix{7, 6, UInt8, 42}, UInt8}}}, handler::Session{Env{AlphaZero.Examples.ConnectFour.GameSpec, ResNet, NamedTuple{(:board, :curplayer), Tuple{StaticArrays.SMatrix{7, 6, UInt8, 42}, UInt8}}}})
    @ AlphaZero ~/AlphaZero.jl/src/training.jl:224
 [32] test_grad_updates(exp::Experiment; num_games::Int64)
    @ AlphaZero.Scripts ~/AlphaZero.jl/src/scripts/test_grad_updates.jl:17
 [33] test_grad_updates
    @ ~/AlphaZero.jl/src/scripts/test_grad_updates.jl:10 [inlined]
 [34] #test_grad_updates#21
    @ ~/AlphaZero.jl/src/scripts/scripts.jl:57 [inlined]
 [35] test_grad_updates(s::String)
    @ AlphaZero.Scripts ~/AlphaZero.jl/src/scripts/scripts.jl:57
 [36] top-level scope
    @ REPL[3]:1

@DhairyaLGandhi
Member

Interesting, is this with the ChainRules update?

@jonathan-laurent
Author

Yes, this is with the ChainRules update.

@DhairyaLGandhi
Member

So this is in fact doing array mutation internally, and I'm comfortable saying that Zygote is showing the expected behavior. However, nested generators are something we had working before. For now I'd return a literal generator and look into the code generation separately.
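The mutation in the trace above comes from how Julia collects a nested comprehension. The sketch below uses plain Julia, with ones-filled `Array`s standing in for the `CuArray` parameters (the names `layers` and `l2_penalty` are hypothetical, not AlphaZero.jl code). It shows that the nested generator has an unknown size, so `[w for l in ... for w in l]` is built by `collect` growing a vector with `push!`. When a 2-D array follows a 4-D one, `collect` also widens the element type via `push_widen`, which matches the `grow_to!`/`push_widen`/`append!` frames in the trace. Consuming the generator directly, for example in a sum, never builds that vector.

```julia
# Stand-ins for two layers' regularized parameters: a 4-D conv kernel and a
# 2-D dense weight (plain Arrays here instead of CUDA.CuArray).
layers = ([ones(Float32, 3, 3, 1, 2)], [ones(Float32, 4, 5)])

# Nested generator: lowered to an Iterators.flatten over a Generator, so
# neither its length nor its element type is known ahead of time.
gen = (w for l in layers for w in l)
println(Base.IteratorSize(typeof(gen)))

# Nested comprehension: equivalent to collect(gen). Because the size is unknown,
# collect grows the result with push!, and when the 2-D array arrives after the
# 4-D one it widens the element type (push_widen -> typejoin), allocating a new
# vector and appending to it. This is the mutation Zygote refuses to differentiate.
arr = [w for l in layers for w in l]
println(eltype(arr))   # the typejoin of Array{Float32,4} and Array{Float32,2}

# Hypothetical L2 penalty that consumes the generator lazily, without building
# an intermediate vector of parameters.
l2_penalty(ws) = sum(sum(abs2, w) for w in ws)
println(l2_penalty(gen))
```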

@jonathan-laurent
Author

Interesting update:

This issue is about making sure that code like this works with Flux:

gs = gradient(Flux.params(d)) do
    ds = Flux.modules(d)
    sum(l(x) for l in ds) |> sum
end

However, the code above appears to be very slow. Indeed, I achieved a 3x speedup (and a roughly tenfold reduction in the number of GPU allocations) in the backprop phase of AlphaZero.jl by essentially replacing the code above with:

ds = Flux.modules(d)
gs = gradient(Flux.params(d)) do
    sum(l(x) for l in ds) |> sum
end

For more details, see the exact commit: jonathan-laurent/AlphaZero.jl@b8bac93

Here is the output of CUDA.@time before and after I made the change in AlphaZero.jl:

BEFORE:
12.463815 seconds (44.18 M CPU allocations: 1.117 GiB, 4.67% gc time) (401.32 k GPU allocations: 314.618 GiB, 18.44% gc time of which 34.41% spent allocating)

AFTER:
4.198313 seconds (11.60 M CPU allocations: 451.503 MiB, 7.84% gc time) (45.81 k GPU allocations: 314.617 GiB, 8.46% gc time of which 38.23% spent allocating)

Although it is understandable for the second version to be faster, I definitely did not expect such a gap. Note that this may be the result of a recent regression, as I think I would have noticed it before otherwise.
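As a quick check of the "3x" and "tenfold" figures, the ratios can be computed directly from the two `CuArray.@time` lines above. This is only arithmetic on the reported numbers. It also shows that the total GPU memory allocated is essentially unchanged, which suggests the gain comes from making fewer allocations rather than smaller ones.

```julia
# Figures copied from the CUDA.@time output above.
before = (time_s = 12.463815, cpu_allocs = 44.18e6, gpu_allocs = 401.32e3, gpu_gib = 314.618)
after  = (time_s = 4.198313,  cpu_allocs = 11.60e6, gpu_allocs = 45.81e3,  gpu_gib = 314.617)

speedup        = before.time_s / after.time_s           # wall-clock ratio
cpu_alloc_drop = before.cpu_allocs / after.cpu_allocs   # CPU allocation count ratio
gpu_alloc_drop = before.gpu_allocs / after.gpu_allocs   # GPU allocation count ratio
gpu_bytes_rel  = before.gpu_gib / after.gpu_gib         # total GPU bytes ratio

for (name, r) in (("wall time", speedup), ("CPU allocations", cpu_alloc_drop),
                  ("GPU allocations", gpu_alloc_drop), ("GPU bytes", gpu_bytes_rel))
    println(rpad(name, 16), round(r; digits = 2), "x")
end
```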

@darsnack
Member

darsnack commented Apr 26, 2021

Note that this is possibly the result of a recent regression as I think I would have noticed this before otherwise.

How recent? Flux.modules is new (as of v0.12) and uses a very different implementation (based on Functors.jl) than the for-loop it replaces in AlphaZero.jl. (see edit) So I am not surprised that they don't perform the same, but it's good to know that it is so slow. We definitely want to make the first code example performant, because that's how I'd expect Flux.modules to be used most of the time.

Unless by recent you mean within the last couple of weeks (i.e. comparing Flux.modules with Flux.modules after some updates), in which case Zygote would be the first place to look.

Maybe a good cross-check would be to replace Flux.modules with some other generator?


EDIT:

Actually they are very similar.
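For the cross-check suggested above, a hand-rolled traversal can stand in for Flux.modules. The following is a minimal sketch only. It is not Flux's fcollect-based implementation, nor the AlphaZero.jl code from the linked commit. The layer types `Dense2`/`Chain2` and the helpers `children`/`collect_modules` are hypothetical. In real use, the traversal would also need to be excluded from differentiation, as Flux.modules is via `@nograd`.

```julia
# Hypothetical layer types standing in for Flux layers.
struct Dense2
    W::Matrix{Float32}
end
struct Chain2
    layers::Vector{Any}
end

# Hypothetical: which sub-objects the traversal descends into.
children(c::Chain2) = c.layers
children(_) = ()

# Depth-first walk that returns each reachable module once, in visit order.
function collect_modules(root)
    out = Any[]
    seen = IdDict{Any,Nothing}()   # identity-based cache so shared layers are visited once
    stack = Any[root]
    while !isempty(stack)
        m = pop!(stack)
        haskey(seen, m) && continue
        seen[m] = nothing
        push!(out, m)
        # Push children in reverse so they are popped in their original order.
        for c in reverse(children(m))
            push!(stack, c)
        end
    end
    return out
end

d1 = Dense2(zeros(Float32, 2, 2))
d2 = Dense2(ones(Float32, 2, 2))
model = Chain2(Any[d1, d2, d1])   # d1 appears twice (weight sharing)
mods = collect_modules(model)
println(length(mods))
```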

@jonathan-laurent
Author

Before Flux v0.12, I was relying on my own implementation of modules.

See this commit: jonathan-laurent/AlphaZero.jl@7bbb2cb#diff-837cea9edf0b0d7507b695b63c300d02574ec50c6b3227ee6fe4701544b3e97a

@ToucheSir
Member

Could it be that fcollect uses a vector as a cache and that's slower in AD? Not sure what else could be a factor.

@DhairyaLGandhi
Member

It is already marked @nograd.

@ToucheSir
Member

Ah right, disregard that then.

I played around with the examples above on CPU and GPU. I'm not sure why calling modules inside the gradient context has such a severe impact on CPU-side allocations, given that modules is @nograd.

using Flux, BenchmarkTools  # BenchmarkTools provides @btime

d = Chain((Dense(32, 32) for i in 1:5)...)
x = rand(Float32, 32, 8)

f1(m, x) = gradient(Flux.params(m)) do
    ds = Flux.modules(m)
    sum(l(x) for l in ds) |> sum
end

function f2(m, x)
    ds = Flux.modules(m)
    gradient(Flux.params(m)) do
        sum(l(x) for l in ds) |> sum
    end
end

julia> @timev f1(d, x)
 24.901731 seconds (63.67 M allocations: 3.610 GiB, 4.22% gc time)
elapsed time (ns): 24901730841
gc time (ns):      1051987239
bytes allocated:   3875757116
pool allocs:       63656612
non-pool GC allocs:15360
malloc() calls:    88
realloc() calls:   15
GC pauses:         53
full collections:  2

julia> @timev f2(d, x)
  1.086023 seconds (2.15 M allocations: 119.429 MiB, 2.75% gc time, 98.94% compilation time)
elapsed time (ns): 1086023325
gc time (ns):      29910632
bytes allocated:   125229937
pool allocs:       2152312
non-pool GC allocs:161
GC pauses:         3

julia> @btime f2($d, $x)
  415.283 μs (2140 allocations: 210.67 KiB)

julia> @btime f1($d, $x)
  414.055 μs (2146 allocations: 214.48 KiB)

@mcabbott
Member

FWIW, running the above example:

julia> @timev f1(d, x)
 12.528920 seconds (61.88 M allocations: 3.262 GiB, 5.23% gc time, 99.97% compilation time)
elapsed time (ns): 12528919792
gc time (ns):      655876707
bytes allocated:   3502653566
pool allocs:       61852547
non-pool GC allocs:23328
malloc() calls:    103
realloc() calls:   20
GC pauses:         42
full collections:  1
Grads(...)

julia> @timev f2(d, x)
  0.456313 seconds (2.03 M allocations: 103.419 MiB, 5.52% gc time, 99.67% compilation time)
elapsed time (ns): 456313042
gc time (ns):      25197167
bytes allocated:   108442549
pool allocs:       2025083
non-pool GC allocs:271
GC pauses:         2
Grads(...)

julia> @btime f2($d, $x)
  min 171.500 μs, mean 209.663 μs (1631 allocations, 309.03 KiB. GC mean 11.82%)
Grads(...)

julia> @btime f1($d, $x)
  min 173.167 μs, mean 209.863 μs (1636 allocations, 312.83 KiB. GC mean 11.27%)
Grads(...)

And after restarting, running them in the opposite order:

julia> @timev f2(d, x)
 12.708308 seconds (61.94 M allocations: 3.266 GiB, 5.74% gc time, 99.98% compilation time)
elapsed time (ns): 12708308041
gc time (ns):      729664123
bytes allocated:   3506520558
pool allocs:       61916466
non-pool GC allocs:23354
malloc() calls:    104
realloc() calls:   20
GC pauses:         39
full collections:  2
Grads(...)

julia> @timev f1(d, x)
  0.447610 seconds (1.98 M allocations: 100.910 MiB, 2.59% gc time, 99.69% compilation time)
elapsed time (ns): 447609750
gc time (ns):      11576334
bytes allocated:   105811871
pool allocs:       1980090
non-pool GC allocs:250
GC pauses:         1
Grads(...)

(@v1.8) pkg> st Flux Zygote
Status `~/.julia/environments/v1.8/Project.toml`
  [587475ba] Flux v0.12.8
  [e88e6eb3] Zygote v0.6.32

So whichever function runs first pays about 12 s of startup (compilation) time, but where Flux.modules(m) is called has no obvious effect.

@ToucheSir
Member

With the benefit of hindsight, the first run is basically measuring AD compilation overhead, while the second benefits from caching. It also looks like the original issue is resolved, so it's worth opening a new issue if anything persists.
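Following that reasoning, a fairer comparison warms up both variants before timing, so neither measurement includes one-off compilation. This is also why the `@btime` figures for f1 and f2 in the thread agree. The sketch below uses plain Julia with only Base's `@elapsed`, and no Flux or Zygote. The functions `g_inside`/`g_outside` and the helper `best_time` are hypothetical stand-ins for f1/f2 that differ only in whether a helper collection is built inside or outside the closure.

```julia
# Hypothetical stand-ins for f1/f2: identical work, differing only in whether
# the helper collection `ds` is built inside or outside the closure.
function g_inside(xs)
    f = () -> begin
        ds = eachindex(xs)
        sum(xs[i]^2 for i in ds)
    end
    return f()
end

function g_outside(xs)
    ds = eachindex(xs)
    f = () -> sum(xs[i]^2 for i in ds)
    return f()
end

# Best (minimum) wall-clock time over n calls, in seconds.
function best_time(f, x; n = 100)
    best = Inf
    for _ in 1:n
        best = min(best, @elapsed(f(x)))
    end
    return best
end

xs = rand(Float32, 32 * 8)

# Warm-up: the first call of each method includes JIT compilation, and code
# shared by both is compiled only once, by whichever runs first.
g_inside(xs); g_outside(xs)

# These timings now measure execution only.
t_inside = best_time(g_inside, xs)
t_outside = best_time(g_outside, xs)
println("inside: ", t_inside, " s, outside: ", t_outside, " s")
```

Running the comparison in both orders after a restart, as done above, is a quick way to confirm that any remaining gap is not just compilation order.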
