Different RNG used during first gradient call on Julia 1.9 #1351
Comments
I assume the change to the primal does not occur outside of a |
That's right. The first call to the primal, as well as subsequent calls, all agree with Zygote's result for Zygote's subsequent calls (i.e. out = |
But if I make five calls to the primal without resetting the seed, I get what Zygote's first call gives us!
julia> using Distributions, Random
julia> function f(x)
out = rand(Normal(x, x))
@show out
return out
end
f (generic function with 1 method)
julia> Random.seed!(123)
TaskLocalRNG()
julia> f(.5)
out = 0.17713466394801164
0.17713466394801164
julia> f(.5)
out = -0.23162568944446071
-0.23162568944446071
julia> f(.5)
out = -0.3118018727930403
-0.3118018727930403
julia> f(.5)
out = 0.3911674466082269
0.3911674466082269
julia> f(.5)
out = 0.7461228432625914
0.7461228432625914
I am able to reproduce similar behaviour without
julia> using Zygote, Random
julia> function f(x)
#=
The sqrt actually makes this not equivalent to the previous sampling procedure.
But with it, we can reproduce the same bug.
=#
out = sqrt(x) * randn() + x
@show out
return out
end
f (generic function with 1 method)
julia> Random.seed!(123)
TaskLocalRNG()
julia> Zygote.gradient(f, .5) # Take the gradient for the first time
out = 1.193657477360091
(1.693657477360091,)
julia> Random.seed!(123)
TaskLocalRNG()
julia> Zygote.gradient(f, .5) # Take the gradient subsequent times
out = 0.04339946293513103
(0.5433994629351311,)
julia> Random.seed!(123)
TaskLocalRNG()
julia> Zygote.gradient(f, .5)
out = 0.04339946293513103
(0.5433994629351311,)
In this case, the RNG state when Zygote evaluates the gradient for the first time seems to match what the RNG state would be on the sixth primal call made without resetting the seed.
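One way to check the "sixth call" observation without Zygote is to reseed and call the primal repeatedly, then compare the value with the `out` printed by the first gradient call. This is only a sketch: `nth_call` is a hypothetical helper, not part of any package, and the values it produces depend on the Julia version, so none are listed here.

```julia
using Random

# Same primal as in the comment above, without the @show.
f(x) = sqrt(x) * randn() + x

# Hypothetical helper: reseed the task-local RNG, then return the
# value produced by the n-th consecutive call to f(x).
function nth_call(f, x, n; seed=123)
    Random.seed!(seed)
    out = f(x)
    for _ in 2:n
        out = f(x)
    end
    return out
end

first_call = nth_call(f, 0.5, 1)
sixth_call = nth_call(f, 0.5, 6)
@show first_call sixth_call
```

If the first gradient call prints the same `out` as `sixth_call`, that would suggest something drew from the RNG between seeding and the primal evaluation.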
Is the compiler possibly pulling values from the RNG on 1.9? I know there were some sorting algorithm changes made internally recently.
Good guess! I ran a bisect and the bug is caused by JuliaLang/julia#45222.
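A rough way to test whether anything other than the function itself draws from the RNG is to snapshot the RNG state after a first (compiling) call and after a second (already compiled) call, reseeding before each. The sketch below assumes Julia 1.7 or later, where `copy(Random.default_rng())` returns a snapshot that can be compared with `==`; `rng_state_after` and `g` are hypothetical names introduced for illustration.

```julia
using Random

# Hypothetical helper: reseed, run `thunk`, and return a snapshot of
# the task-local RNG state afterwards.
function rng_state_after(thunk; seed=123)
    Random.seed!(seed)
    thunk()
    return copy(Random.default_rng())
end

# Stand-in for the random primal discussed in this thread.
g() = sqrt(0.5) * randn() + 0.5

first_state  = rng_state_after(g)  # includes any compilation triggered by the first call
second_state = rng_state_after(g)  # g is already compiled here
@show first_state == second_state
```

If the two snapshots differ, something besides the body of `g` consumed random numbers during the first call. Passing `() -> Zygote.gradient(f, 0.5)` as the thunk (with Zygote loaded, in a fresh session) would apply the same check to the case reported here.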
I've run into a very strange bug on Julia 1.9, where Zygote gives a different result (for both primal and gradient) when taking the gradient of a random function from Distributions.jl for the first time. (The same seed is used in all cases.) This bug led to an error when I tried to run a package's test suite with 1.9.
Here is the MWE:
The version info where the above occurred is:
and my environment was
This does not occur on Julia 1.8.1 for me, with the same environment. I also haven't been able to reproduce it without Distributions.jl, which muddies the waters further.