Zygote error: UndefVarError: S not defined #1578
Comments
Interesting, thanks for pointing this out. Next step might be to check with older Zygote releases to see where this came up. I suspected the |
The source of the issue is this line. |
julia> d = Chain(Dense(3,3), Dense(3,3))
Chain(Dense(3, 3), Dense(3, 3))
julia> x = rand(Float32, 3,4);
julia> gs = gradient(Flux.params(d)) do
           ds = Flux.modules(d)
           sum(l(x) for l in ds) |> sum
       end
Grads(...) |
There's a generator in the above expression as well. |
julia> regularized_params_(l::Flux.Dense) = [l.W]
regularized_params_ (generic function with 1 method)
julia> regularized_params_(l) = []
regularized_params_ (generic function with 2 methods)
julia> regularized_params_(l::Flux.Conv) = [l.weight]
regularized_params_ (generic function with 3 methods)
julia> d = Chain(Conv((3, 3), 1 => 2), Dense(3,3), Dense(3,3))
Chain(Conv((3, 3), 1=>2), Dense(3, 3), Dense(3, 3))
julia> gs = gradient(Flux.params(d)) do
           ws = [w for l in Flux.modules(d) for w in regularized_params_(l)]
           sum(sum(w) for w in ws)
       end
ERROR: UndefVarError: S not defined
Stacktrace:
[1] (typeof(∂(λ)))(x::Tuple{typeof(∂(getproperty))})
@ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/interface.jl:19
[2] _pullback
@ ~/.julia/packages/ZygoteRules/OjfTt/src/ZygoteRules.jl:11 [inlined]
[3] _pullback(::Zygote.Context, ::typeof(ZygoteRules.literal_getproperty), ::Type{Type{T}}, ::Val{:name})
@ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/interface2.jl:0
[4] _pullback
@ ./promotion.jl:87 [inlined]
[5] _pullback(::Zygote.Context, ::typeof(typejoin), ::Type{Array{Float32, 4}}, ::Type{Matrix{Float32}})
@ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/interface2.jl:0
[6] _pullback
@ ./promotion.jl:149 [inlined]
[7] _pullback
@ ./array.jl:748 [inlined]
[8] _pullback(::Zygote.Context, ::typeof(Base.push_widen), ::Vector{Array{Float32, 4}}, ::Matrix{Float32})
@ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/interface2.jl:0
[9] _pullback
@ ./array.jl:767 [inlined]
[10] _pullback(::Zygote.Context, ::typeof(Base.grow_to!), ::Vector{Array{Float32, 4}}, ::Base.Iterators.Flatten{Base.Generator{Vector{Any}, var"#40#42"}}, ::Tuple{Int64, Base.Generator{Vector{Array{Float32, 4}}, typeof(identity)}, Int64})
@ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/interface2.jl:0
[11] _pullback
@ ./array.jl:743 [inlined]
[12] _pullback(::Zygote.Context, ::typeof(Base.grow_to!), ::Vector{Any}, ::Base.Iterators.Flatten{Base.Generator{Vector{Any}, var"#40#42"}})
@ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/interface2.jl:0
[13] _pullback
@ ./array.jl:652 [inlined]
[14] _pullback(::Zygote.Context, ::typeof(Base._collect), ::UnitRange{Int64}, ::Base.Iterators.Flatten{Base.Generator{Vector{Any}, var"#40#42"}}, ::Base.EltypeUnknown, ::Base.SizeUnknown)
@ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/interface2.jl:0
[15] _pullback
@ ./array.jl:602 [inlined]
[16] _pullback(ctx::Zygote.Context, f::typeof(collect), args::Base.Iterators.Flatten{Base.Generator{Vector{Any}, var"#40#42"}})
@ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/interface2.jl:0
[17] _pullback
@ ./REPL[474]:2 [inlined]
[18] _pullback(::Zygote.Context, ::var"#39#41")
@ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/interface2.jl:0
[19] pullback(f::Function, ps::Params)
@ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/interface.jl:247
[20] gradient(f::Function, args::Params)
@ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/interface.jl:58
[21] top-level scope
@ REPL[474]:1 |
It's more subtle too. You need a
julia> d = Chain(Dense(3, 3), Dense(3, 3))
Chain(Dense(3, 3), Dense(3, 3))
julia> gs = gradient(Flux.params(d)) do
           ws = [w for l in Flux.modules(d) for w in regularized_params_(l)]
           sum(sum(w) for w in ws)
       end
ERROR: Compiling Tuple{Base.var"##depwarn#864", Bool, typeof(Base.depwarn), String, Symbol}: try/catch is not supported.
Stacktrace:
[1] error(s::String)
@ Base ./error.jl:33
[2] instrument(ir::IRTools.Inner.IR)
@ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/reverse.jl:121
[3] #Primal#20
@ ~/.julia/packages/Zygote/RxTZu/src/compiler/reverse.jl:202 [inlined]
[4] Zygote.Adjoint(ir::IRTools.Inner.IR; varargs::Nothing, normalise::Bool)
@ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/reverse.jl:315
[5] _lookup_grad(T::Type)
@ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/emit.jl:101
[6] #s2993#1177
@ ~/.julia/packages/Zygote/RxTZu/src/compiler/interface2.jl:34 [inlined]
[7] var"#s2993#1177"(T::Any, j::Any, Δ::Any)
@ Zygote ./none:0
[8] (::Core.GeneratedFunctionStub)(::Any, ::Vararg{Any, N} where N)
@ Core ./boot.jl:571
[9] Pullback
@ ./deprecated.jl:80 [inlined]
[10] (::typeof(∂(depwarn)))(Δ::Nothing)
@ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/interface2.jl:0
[11] Pullback
@ ~/.julia/packages/Flux/qp1gc/src/deprecations.jl:13 [inlined]
[12] (::typeof(∂(getproperty)))(Δ::Nothing)
@ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/interface2.jl:0
[13] Pullback
@ ~/.julia/packages/ZygoteRules/OjfTt/src/ZygoteRules.jl:11 [inlined]
[14] Pullback
@ ./REPL[457]:1 [inlined]
[15] (::typeof(∂(regularized_params_)))(Δ::Nothing)
@ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/interface2.jl:0
[16] Pullback
@ ./none:0 [inlined]
[17] (::typeof(∂(#44)))(Δ::Nothing)
@ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/interface2.jl:0
[18] Pullback
@ ./generator.jl:47 [inlined]
[19] (::typeof(∂(iterate)))(Δ::Nothing)
@ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/interface2.jl:0
[20] Pullback
@ ./iterators.jl:1093 [inlined]
[21] (::typeof(∂(iterate)))(Δ::Tuple{Nothing, Tuple{Nothing, Nothing, Nothing}})
@ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/interface2.jl:0
[22] Pullback
@ ./array.jl:761 [inlined]
[23] (::typeof(∂(grow_to!)))(Δ::Nothing)
@ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/interface2.jl:0
[24] Pullback
@ ./array.jl:743 [inlined]
[25] (::typeof(∂(grow_to!)))(Δ::Nothing)
@ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/interface2.jl:0
[26] Pullback
@ ./array.jl:652 [inlined]
[27] (::typeof(∂(_collect)))(Δ::Nothing)
@ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/interface2.jl:0
[28] Pullback
@ ./array.jl:602 [inlined]
[29] (::typeof(∂(collect)))(Δ::Nothing)
@ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/interface2.jl:0
[30] Pullback
@ ./REPL[476]:2 [inlined]
[31] (::typeof(∂(#43)))(Δ::Float32)
@ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/interface2.jl:0
[32] (::Zygote.var"#69#70"{Params, typeof(∂(#43)), Zygote.Context})(Δ::Float32)
@ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/interface.jl:252
[33] gradient(f::Function, args::Params)
@ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/interface.jl:59
[34] top-level scope
@ REPL[476]:1
I think if you look at the stack trace, you'll see a call to
Even smaller reproducer without
julia> gs = gradient(Flux.params(d)) do
           ws = [w for l in d for w in regularized_params_(l)]
           sum(sum(w) for w in ws)
       end
ERROR: UndefVarError: S not defined
Stacktrace:
[1] (typeof(∂(λ)))(x::Tuple{typeof(∂(getproperty))})
@ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/interface.jl:19
[2] _pullback
@ ~/.julia/packages/ZygoteRules/OjfTt/src/ZygoteRules.jl:11 [inlined]
[3] _pullback(::Zygote.Context, ::typeof(ZygoteRules.literal_getproperty), ::Type{Type{T}}, ::Val{:name})
@ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/interface2.jl:0
[4] _pullback
@ ./promotion.jl:87 [inlined]
[5] _pullback(::Zygote.Context, ::typeof(typejoin), ::Type{Array{Float32, 4}}, ::Type{Matrix{Float32}})
@ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/interface2.jl:0
[6] _pullback
@ ./promotion.jl:149 [inlined]
[7] _pullback
@ ./array.jl:748 [inlined]
[8] _pullback(::Zygote.Context, ::typeof(Base.push_widen), ::Vector{Array{Float32, 4}}, ::Matrix{Float32})
@ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/interface2.jl:0
[9] _pullback
@ ./array.jl:767 [inlined]
[10] _pullback(::Zygote.Context, ::typeof(Base.grow_to!), ::Vector{Array{Float32, 4}}, ::Base.Iterators.Flatten{Base.Generator{Chain{Tuple{Conv{2, 4, typeof(identity), Array{Float32, 4}, Vector{Float32}}, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}}, var"#64#66"}}, ::Tuple{Int64, Base.Generator{Vector{Array{Float32, 4}}, typeof(identity)}, Int64})
@ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/interface2.jl:0
[11] _pullback
@ ./array.jl:743 [inlined]
[12] _pullback(::Zygote.Context, ::typeof(Base.grow_to!), ::Vector{Any}, ::Base.Iterators.Flatten{Base.Generator{Chain{Tuple{Conv{2, 4, typeof(identity), Array{Float32, 4}, Vector{Float32}}, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}}, var"#64#66"}})
@ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/interface2.jl:0
[13] _pullback
@ ./array.jl:652 [inlined]
[14] _pullback(::Zygote.Context, ::typeof(Base._collect), ::UnitRange{Int64}, ::Base.Iterators.Flatten{Base.Generator{Chain{Tuple{Conv{2, 4, typeof(identity), Array{Float32, 4}, Vector{Float32}}, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}}, var"#64#66"}}, ::Base.EltypeUnknown, ::Base.SizeUnknown)
@ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/interface2.jl:0
[15] _pullback
@ ./array.jl:602 [inlined]
[16] _pullback(ctx::Zygote.Context, f::typeof(collect), args::Base.Iterators.Flatten{Base.Generator{Chain{Tuple{Conv{2, 4, typeof(identity), Array{Float32, 4}, Vector{Float32}}, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}}, var"#64#66"}})
@ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/interface2.jl:0
[17] _pullback
@ ./REPL[484]:2 [inlined]
[18] _pullback(::Zygote.Context, ::var"#63#65")
@ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/interface2.jl:0
[19] pullback(f::Function, ps::Params)
@ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/interface.jl:247
[20] gradient(f::Function, args::Params)
@ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/interface.jl:58
[21] top-level scope
@ REPL[484]:1 |
I found that I needed different types for the elements in the generator, and I need nested generators (to force the call to
julia> gs = gradient(Flux.params(d)) do
           sum(sum(w) for m in d for w in regularized_params_(m))
       end
ERROR: Compiling Tuple{Base.var"##depwarn#864", Bool, typeof(Base.depwarn), String, Symbol}: try/catch is not supported.
Stacktrace:
[1] error(s::String)
@ Base ./error.jl:33
[2] instrument(ir::IRTools.Inner.IR)
@ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/reverse.jl:121
[3] #Primal#20
@ ~/.julia/packages/Zygote/RxTZu/src/compiler/reverse.jl:202 [inlined]
[4] Zygote.Adjoint(ir::IRTools.Inner.IR; varargs::Nothing, normalise::Bool)
@ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/reverse.jl:315
[5] _lookup_grad(T::Type)
@ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/emit.jl:101
[6] #s2993#1177
@ ~/.julia/packages/Zygote/RxTZu/src/compiler/interface2.jl:34 [inlined]
[7] var"#s2993#1177"(T::Any, j::Any, Δ::Any)
@ Zygote ./none:0
[8] (::Core.GeneratedFunctionStub)(::Any, ::Vararg{Any, N} where N)
@ Core ./boot.jl:571
[9] Pullback
@ ./deprecated.jl:80 [inlined]
[10] (::typeof(∂(depwarn)))(Δ::Nothing)
@ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/interface2.jl:0
[11] Pullback
@ ~/.julia/packages/Flux/qp1gc/src/deprecations.jl:13 [inlined]
[12] (::typeof(∂(getproperty)))(Δ::Nothing)
@ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/interface2.jl:0
[13] Pullback
@ ~/.julia/packages/ZygoteRules/OjfTt/src/ZygoteRules.jl:11 [inlined]
[14] Pullback
@ ./REPL[457]:1 [inlined]
[15] (::typeof(∂(regularized_params_)))(Δ::Nothing)
@ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/interface2.jl:0
[16] Pullback
@ ./none:0 [inlined]
[17] (::typeof(∂(#60)))(Δ::Nothing)
@ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/interface2.jl:0
[18] Pullback
@ ./reduce.jl:93 [inlined]
[19] (::typeof(∂(Base.MappingRF{var"#60#62", Base.FlatteningRF{Base.BottomRF{typeof(Base.add_sum)}}}(var"#60#62"(), Base.FlatteningRF{Base.BottomRF{typeof(Base.add_sum)}}(Base.BottomRF{typeof(Base.add_sum)}(Base.add_sum))))))(Δ::Float32)
@ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/interface2.jl:0
[20] Pullback
@ ./reduce.jl:62 [inlined]
[21] (::typeof(∂(_foldl_impl)))(Δ::Float32)
@ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/interface2.jl:0
[22] Pullback
@ ./reduce.jl:48 [inlined]
[23] (::typeof(∂(foldl_impl)))(Δ::Float32)
@ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/interface2.jl:0
[24] Pullback
@ ./reduce.jl:44 [inlined]
[25] (::typeof(∂(mapfoldl_impl)))(Δ::Float32)
@ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/interface2.jl:0
[26] Pullback (repeats 2 times)
@ ./reduce.jl:160 [inlined]
[27] (::typeof(∂(mapfoldl)))(Δ::Float32)
@ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/interface2.jl:0
[28] Pullback
@ ./reduce.jl:287 [inlined]
[29] (::typeof(∂(#mapreduce#218)))(Δ::Float32)
@ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/interface2.jl:0
[30] Pullback
@ ./reduce.jl:287 [inlined]
[31] (::typeof(∂(mapreduce)))(Δ::Float32)
@ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/interface2.jl:0
[32] Pullback
@ ./reduce.jl:501 [inlined]
[33] (::typeof(∂(#sum#221)))(Δ::Float32)
@ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/interface2.jl:0
[34] Pullback
@ ./reduce.jl:501 [inlined]
[35] (::typeof(∂(sum)))(Δ::Float32)
@ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/interface2.jl:0
[36] Pullback
@ ./reduce.jl:528 [inlined]
[37] (::typeof(∂(#sum#222)))(Δ::Float32)
@ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/interface2.jl:0
[38] Pullback
@ ./reduce.jl:528 [inlined]
[39] (::typeof(∂(sum)))(Δ::Float32)
@ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/interface2.jl:0
[40] Pullback
@ ./REPL[483]:2 [inlined]
[41] (::typeof(∂(#59)))(Δ::Float32)
@ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/interface2.jl:0
[42] (::Zygote.var"#69#70"{Params, typeof(∂(#59)), Zygote.Context})(Δ::Float32)
@ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/interface.jl:252
[43] gradient(f::Function, args::Params)
@ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/interface.jl:59
[44] top-level scope
@ REPL[483]:1 |
Okay, finally got an MWE:
julia> gs = gradient(params(d)) do
           x = typejoin(Array{Float32, 4}, Array{Float32, 2})
           return 1
       end
ERROR: UndefVarError: S not defined
Stacktrace:
[1] (typeof(∂(λ)))(x::Tuple{typeof(∂(getproperty))})
@ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/interface.jl:19
[2] _pullback
@ ~/.julia/packages/ZygoteRules/OjfTt/src/ZygoteRules.jl:11 [inlined]
[3] _pullback(::Zygote.Context, ::typeof(ZygoteRules.literal_getproperty), ::Type{Type{T}}, ::Val{:name})
@ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/interface2.jl:0
[4] _pullback
@ ./promotion.jl:87 [inlined]
[5] _pullback(::Zygote.Context, ::typeof(typejoin), ::Type{Array{Float32, 4}}, ::Type{Matrix{Float32}})
@ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/interface2.jl:0
[6] _pullback
@ ./REPL[486]:2 [inlined]
[7] _pullback(::Zygote.Context, ::var"#67#68")
@ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/interface2.jl:0
[8] pullback(f::Function, ps::Params)
@ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/interface.jl:247
[9] gradient(f::Function, args::Params)
@ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/interface.jl:58
[10] top-level scope
@ REPL[486]:1 |
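A note on why `typejoin` shows up at all: the earlier stack traces pass through `Base.push_widen` and `typejoin` while collecting the comprehension, which suggests that Base has to widen the element type of the result when a `Conv` weight (`Array{Float32, 4}`) is followed by a `Dense` weight (`Matrix{Float32}`). The sketch below uses only Base, with no Flux or Zygote involved, and the variable names are made up for illustration.

```julia
# Hypothetical stand-ins for the two weight types in the reproducer.
conv_weight_type = Array{Float32, 4}   # e.g. a Conv kernel
dense_weight_type = Matrix{Float32}    # e.g. a Dense weight

# typejoin returns the closest common supertype of its arguments.
joined = typejoin(conv_weight_type, dense_weight_type)
@assert joined == Array{Float32}

# Collecting a nested generator whose elements differ in type needs a
# wider element type, so the result cannot be a Vector{Array{Float32, 4}}.
ws = [w for ps in ([zeros(Float32, 1, 1, 1, 1)], [zeros(Float32, 2, 2)]) for w in ps]
@assert eltype(ws) != Array{Float32, 4}
@assert length(ws) == 2
```

With layers whose weights all share one type (for example two `Dense` layers), no widening is needed, which would be consistent with the `Dense`-only model failing on a different code path.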
Just opened FluxML/Zygote.jl#947. Looks like we came to the same conclusion lol. @darsnack, I added you as a co-author (hope you don't mind). |
I updated the Manifest to use FluxML/Zygote.jl#947 but now I am seeing a different error.
|
I addressed that in JuliaDiff/ChainRules.jl#398. |
Should be closed in JuliaDiff/ChainRules.jl#398 |
Now it works if I define this:
function Network.regularized_params(net::FluxNetwork)
    return (w for l in Flux.modules(net) for w in regularized_params_(l))
end
Just so you know, I still have an error if I use this definition instead (returning an array instead of a generator):
function Network.regularized_params(net::FluxNetwork)
    return [w for l in Flux.modules(net) for w in regularized_params_(l)]
end
I am not sure whether or not this should be considered a bug. If not, how should I modify the second example to make it work? Would I need to use a
|
Interesting. Is this with the ChainRules update? |
Yes, this is with the ChainRules update. |
The array version is in fact doing array mutation internally, so I'm comfortable saying that Zygote is showing the expected behavior there. However, nested generators are something that we had working before. For now I'd return a literal generator and look into the code generation separately. |
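To make the suggested workaround concrete without depending on Flux, here is a Base-only sketch of the two shapes being compared: a lazy nested generator versus an eagerly collected comprehension. `params_of` and `layers` are hypothetical stand-ins for `regularized_params_` and `Flux.modules(net)`. It only shows that both forms yield the same sum; it says nothing about how Zygote differentiates either one.

```julia
# Hypothetical stand-in for regularized_params_: matrices contribute, anything else does not.
params_of(x::Matrix{Float32}) = [x]
params_of(x) = []

# Hypothetical stand-in for Flux.modules(net).
layers = (zeros(Float32, 2, 2), "not a layer", ones(Float32, 3, 3))

# Lazy: nothing is collected, elements are produced on demand.
lazy = (w for l in layers for w in params_of(l))

# Eager: the comprehension collects into a Vector, growing it as it goes.
eager = [w for l in layers for w in params_of(l)]

total_lazy = sum(sum(w) for w in lazy)
total_eager = sum(sum(w) for w in eager)
@assert total_lazy == total_eager == 9.0f0
@assert eager isa Vector
@assert !(lazy isa AbstractArray)
```

The lazy form never builds an intermediate array, which is the property the workaround relies on.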
Interesting update: This issue is about making sure that code like this works with Flux:
gs = gradient(Flux.params(d)) do
    ds = Flux.modules(d)
    sum(l(x) for l in ds) |> sum
end
However, the code above appears to be very slow. Indeed, I achieved a 3x speedup (and a tenfold reduction in the number of GPU allocs) in the backprop phase of AlphaZero.jl by essentially replacing the code above by:
ds = Flux.modules(d)
gs = gradient(Flux.params(d)) do
    sum(l(x) for l in ds) |> sum
end
For more details, see the exact commit: jonathan-laurent/AlphaZero.jl@b8bac93
Here is the output of
Although it is understandable for the second version to be faster, I definitely did not expect such a gap. Note that this is possibly the result of a recent regression as I think I would have noticed this before otherwise. |
How recent? Unless by recent you mean within the last couple weeks (i.e. comparing Maybe a good cross-check would be to replace EDIT: Actually they are very similar. |
Before Flux v0.12, I was relying on my own implementation of See this commit: jonathan-laurent/AlphaZero.jl@7bbb2cb#diff-837cea9edf0b0d7507b695b63c300d02574ec50c6b3227ee6fe4701544b3e97a |
Could it be that |
It has been |
Ah right, disregard that then. I played around with the examples above on CPU and GPU. Not sure why using
d = Chain((Dense(32, 32) for i in 1:5)...)
x = rand(Float32, 32, 8)
f1(m, x) = gradient(Flux.params(m)) do
    ds = Flux.modules(m)
    sum(l(x) for l in ds) |> sum
end
function f2(m, x)
    ds = Flux.modules(m)
    gradient(Flux.params(m)) do
        sum(l(x) for l in ds) |> sum
    end
end
julia> @timev f1(d, x)
24.901731 seconds (63.67 M allocations: 3.610 GiB, 4.22% gc time)
elapsed time (ns): 24901730841
gc time (ns): 1051987239
bytes allocated: 3875757116
pool allocs: 63656612
non-pool GC allocs:15360
malloc() calls: 88
realloc() calls: 15
GC pauses: 53
full collections: 2
julia> @timev f2(d, x)
1.086023 seconds (2.15 M allocations: 119.429 MiB, 2.75% gc time, 98.94% compilation time)
elapsed time (ns): 1086023325
gc time (ns): 29910632
bytes allocated: 125229937
pool allocs: 2152312
non-pool GC allocs:161
GC pauses: 3
julia> @btime f2($d, $x)
415.283 μs (2140 allocations: 210.67 KiB)
julia> @btime f1($d, $x)
414.055 μs (2146 allocations: 214.48 KiB) |
FWIW, running the above example:
And after re-starting, running them in the opposite order:
So there's 12s of startup time on whichever is first, but no obvious effect of where |
With the benefit of hindsight, the first run is basically measuring any AD compilation overhead, while the second can benefit from caching. It also looks like the original issue is resolved, so worth opening a new one if anything persists. |
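The point about compilation overhead can be illustrated without Flux. In the Base-only sketch below, `loss` is a hypothetical stand-in for the gradient call. The first timed call includes JIT compilation and the second reuses the compiled method, which is one reason a first-call `@timev` and a repeated-call `@btime` can tell such different stories.

```julia
# Hypothetical stand-in for f1/f2; any not-yet-compiled function behaves similarly.
loss(x) = sum(abs2, x)

x = rand(Float32, 32, 8)

t_first = @elapsed loss(x)    # includes compilation of loss for Matrix{Float32}
t_second = @elapsed loss(x)   # reuses the compiled method

println("first call:  ", t_first, " s")
println("second call: ", t_second, " s")
```

To compare two variants fairly, it helps to call each once before timing, or to use a benchmarking macro that runs the call repeatedly.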
I am trying to update AlphaZero.jl so that it works with Flux v0.12 but I am stuck on the following Zygote error:
The bug happens on both the released Flux v0.12 and Flux#master.
Replication instructions
To replicate, you can run the following using Julia 1.6:
Full stacktrace