Different RNG used during first gradient call on Julia 1.9 #1351

Open
gaurav-arya opened this issue Jan 10, 2023 · 6 comments

@gaurav-arya

gaurav-arya commented Jan 10, 2023

I've run into a very strange bug on Julia 1.9: Zygote gives a different result (for both the primal and the gradient) the first time it takes the gradient of a random function from Distributions.jl, even though the same seed is used in all cases. This bug led to an error when I tried to run a package's test suite on 1.9.

Here is the MWE:

julia> using Zygote, Distributions, Random

julia> function f(x) 
           out = rand(Normal(x, x))
           @show out
           return out
       end
f (generic function with 1 method)

julia> Random.seed!(123)
TaskLocalRNG()

julia> Zygote.gradient(f, .5) # Take the gradient for the first time
out = 0.7461228432625914
(1.4922456865251827,)

julia> Random.seed!(123)
TaskLocalRNG()

julia> Zygote.gradient(f, .5) # Take the gradient subsequent times
out = 0.17713466394801164
(0.3542693278960233,)

julia> Random.seed!(123)
TaskLocalRNG()

julia> Zygote.gradient(f, .5)
out = 0.17713466394801164
(0.3542693278960233,)

The version info where the above occurred is:

julia> versioninfo()
Julia Version 1.9.0-beta2
Commit 7daffeecb8c (2022-12-29 07:45 UTC)
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 8 × 11th Gen Intel(R) Core(TM) i7-1165G7 @ 2.80GHz
  WORD_SIZE: 64
  LIBM: libopenlibm
  LLVM: libLLVM-14.0.6 (ORCJIT, tigerlake)
  Threads: 1 on 8 virtual cores

and my environment was

  [31c24e10] Distributions v0.25.79
  [e88e6eb3] Zygote v0.6.52

This does not occur for me on Julia 1.8.1 with the same environment. I also haven't been able to reproduce it without Distributions.jl, which muddies the waters further.

@gaurav-arya gaurav-arya changed the title from "Different RNG used during compile time on Julia 1.9" to "Different RNG used during gradient compile time on Julia 1.9" Jan 10, 2023
@darsnack
Member

I assume the change to the primal does not occur outside of a gradient call (i.e. just f(0.5))?

@gaurav-arya
Author

That's right. The first call to the primal, as well as subsequent calls, all agree with what Zygote gives on its subsequent gradient calls (i.e. out = 0.177...).

@gaurav-arya
Author

gaurav-arya commented Jan 10, 2023

But if I make five calls to the primal without resetting the seed, the fifth call gives what Zygote's first gradient call gives!

julia> using Distributions, Random

julia> function f(x) 
           out = rand(Normal(x, x))
           @show out
           return out
       end
f (generic function with 1 method)

julia> Random.seed!(123)
TaskLocalRNG()

julia> f(.5)
out = 0.17713466394801164
0.17713466394801164

julia> f(.5)
out = -0.23162568944446071
-0.23162568944446071

julia> f(.5)
out = -0.3118018727930403
-0.3118018727930403

julia> f(.5)
out = 0.3911674466082269
0.3911674466082269

julia> f(.5)
out = 0.7461228432625914
0.7461228432625914
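
The manual counting above can be automated. The following is a minimal sketch, not something posted in the thread: `draw_index` is a hypothetical helper that reseeds the global RNG and reports which draw in the seeded stream first equals a given value. It only uses the `Random` standard library, and it assumes each call to `draw` consumes the RNG in the same way every time.

```julia
using Random

# Hypothetical helper (not from the thread): reseed the global RNG, then call
# `draw()` repeatedly and return the index of the first draw equal to `target`.
# Returns `nothing` if no match is found within `maxdraws` draws.
# Note: this reseeds the global (task-local) RNG as a side effect.
function draw_index(target, draw; seed=123, maxdraws=100)
    Random.seed!(seed)
    for i in 1:maxdraws
        draw() == target && return i
    end
    return nothing
end

# Self-check on a stream we generate ourselves: the fifth value of the
# seeded `randn` stream should be found at index 5.
Random.seed!(123)
vals = [randn() for _ in 1:5]
println(draw_index(vals[5], randn))
```

To apply it to the MWE, one could pass the `out` value printed by the first gradient call as `target` and a zero-argument closure around the sampling expression (without the `@show`) as `draw`. The index it reports then says how far the RNG had already advanced when Zygote drew its sample.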

@gaurav-arya
Author

gaurav-arya commented Jan 10, 2023

I am able to reproduce similar behaviour without Distributions:

julia> using Zygote, Random

julia> function f(x) 
           #=
           The sqrt actually makes this not equivalent to the previous sampling procedure. 
           But with it, we can reproduce the same bug.
           =#
           out = sqrt(x) * randn() + x
           @show out
           return out
       end
f (generic function with 1 method)

julia> Random.seed!(123)
TaskLocalRNG()

julia> Zygote.gradient(f, .5) # Take the gradient for the first time
out = 1.193657477360091
(1.693657477360091,)

julia> Random.seed!(123)
TaskLocalRNG()

julia> Zygote.gradient(f, .5) # Take the gradient subsequent times
out = 0.04339946293513103
(0.5433994629351311,)

julia> Random.seed!(123)
TaskLocalRNG()

julia> Zygote.gradient(f, .5)
out = 0.04339946293513103
(0.5433994629351311,)

In this case, the RNG state when Zygote evaluates the gradient for the first time seems to match what the RNG state would be at the sixth primal call made without resetting the seed.

@gaurav-arya gaurav-arya changed the title from "Different RNG used during gradient compile time on Julia 1.9" to "Different RNG used during first gradient call on Julia 1.9" Jan 10, 2023
@ToucheSir
Member

Is the compiler possibly pulling values from the RNG on 1.9? I know there were some sorting algorithm changes made internally recently.
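
One way to probe this hypothesis is to snapshot the task-local RNG before and after a call and compare the snapshots. This is only a sketch under assumptions: `consumes_rng` is a hypothetical helper, not part of Zygote or Random, and it relies on `copy` and `==` being defined for the default RNG (which holds for the `Random` standard library as far as I know).

```julia
using Random

# Hypothetical helper (not from the thread): report whether calling `thunk`
# changes the state of the default (task-local) RNG.
function consumes_rng(thunk)
    before = copy(Random.default_rng())  # snapshot of the state before the call
    thunk()
    after = copy(Random.default_rng())   # snapshot of the state after the call
    return before != after
end

# A call that draws a random number advances the RNG.
println(consumes_rng(() -> randn()))
```

Wrapping code that makes no explicit random draws, such as the first-time compilation of a gradient, and getting `true` would suggest that something behind the scenes is pulling values from the global RNG.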

@gaurav-arya
Author

Good guess! I ran a bisect and the bug is caused by JuliaLang/julia#45222.
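
Until this is resolved, a possible way to make sampling independent of whatever else touches the global RNG is to pass an explicit RNG object. The sketch below is an assumption-laden illustration, not something tested in this thread: `f_explicit` and `same_after_global_draws` are hypothetical names, and whether Zygote differentiates through the explicit-RNG call the same way was not checked here. It only demonstrates that a dedicated `Xoshiro` stream is unaffected by draws from the global RNG.

```julia
using Random

# Same sampling formula as the Distributions-free MWE, with the RNG passed in.
# (Hypothetical variant; the thread's `f` uses the global RNG instead.)
f_explicit(rng, x) = sqrt(x) * randn(rng) + x

# Hypothetical check: sample once, advance the *global* RNG by `ndraws` draws
# to mimic hidden consumption, then sample again from a freshly seeded
# explicit RNG. The two samples should be identical.
function same_after_global_draws(x; seed=123, ndraws=5)
    a = f_explicit(Xoshiro(seed), x)
    rand(ndraws)                       # draws from the global RNG only
    b = f_explicit(Xoshiro(seed), x)
    return a == b
end

println(same_after_global_draws(0.5))
```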
