Add dropout #454
Conversation
One trick for GPU is to save the RNG state and lazily generate the mask in a kernel instead of materializing it (see https://triton-lang.org/master/getting-started/tutorials/04-low-memory-dropout.html for an implementation). This works because the default RNG is counter-based and doesn't have any dependence on previous states. Wikipedia tells me Xoshiro does not fall into this category, but we'd likely be doing a linear traversal on CPU and allocations are cheap(er) anyhow. All this to say that as long as things aren't slower on GPU, we can always optimize them separately in NNlibCUDA.
Ah, that's an idea I did not consider; it would probably pay off on CPU too. Copying Xoshiro is certainly fast and appears to work. It's just some bits... where do you read that it has other state?
What I'm not sure of is whether this is guaranteed to work in general. A quick attempt is not quite as fast, but close, and saves memory.
Edit: Here's a test with CUDA, which seems to reliably repeat on small sizes, but fail on large arrays. I presume that means it's dividing the work up in a non-deterministic way. (Maybe there are ways to control that?)
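A minimal sketch of the copy-the-RNG idea being discussed here, not this PR's code; the name `_lazy_dropout` and the exact broadcasts are illustrative, and it assumes `copy(rng)` yields a bit-identical stream (true for `Xoshiro`):

```julia
using Random

# Snapshot the RNG, draw the mask in the forward pass, and re-draw the *same*
# mask in the pullback from the snapshot, so no boolean mask has to be kept
# alive between the forward and backward passes.
function _lazy_dropout(rng::AbstractRNG, A::AbstractArray, p::Real)
    val = 1 / (1 - p)
    rng2 = copy(rng)                                   # snapshot before drawing
    keep = rand!(rng, similar(A, float(eltype(A))))
    Y = @. (keep > p) * A * val
    back = dY -> begin
        keep2 = rand!(rng2, similar(dY, float(eltype(dY))))  # same bits as `keep`
        @. (keep2 > p) * dY * val                      # gradient w.r.t. A
    end
    return Y, back
end
```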
I was going off https://en.wikipedia.org/wiki/Counter-based_random_number_generator_(CBRNG)#Background. It's not that it has other state, but that you can't get the Nth random value without first calculating some previous N-x values to advance the internal state. Whereas for the Philox RNG CUDA.jl uses, you just set the counter to however many steps forward you want and it gives you a value.
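A toy illustration of the counter-based property described above (not Philox itself): each value is a pure function of a key and a counter, so the Nth value needs no preceding state updates.

```julia
# Toy counter-based generator, for illustration only.  Real CBRNGs
# (Philox, Threefry) use proper bijections rather than `hash`.
cb_rand(key::UInt64, counter::Integer) = hash(counter, key) / typemax(UInt64)

cb_rand(UInt64(42), 10)   # the 10th value, without computing values 1..9
cb_rand(UInt64(42), 10)   # identical: no hidden state was mutated
```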
This was my concern as well. We could always use dispatch to have a fallback path for unknown RNGs that materializes the mask.
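A sketch of that dispatch idea, assuming Julia ≥ 1.7; the trait name `_replayable` is illustrative and not NNlib API:

```julia
using Random

# Which RNGs are known to support the copy-and-replay trick for the mask?
_replayable(::Random.Xoshiro)      = true
_replayable(::Random.TaskLocalRNG) = true
_replayable(::Random.AbstractRNG)  = false   # unknown RNGs: materialise the mask

# The dropout implementation could branch on this trait:
dropout_strategy(rng::Random.AbstractRNG) = _replayable(rng) ? :replay : :materialise
```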
src/dropout.jl (Outdated)
        similar(A, T, ntuple(d -> d in dims ? size(A,d) : 1, ndims(A)))
    end
    rand!(rng, keep)
    Y = @. (keep>p) * A * val
If the mask idea doesn't work out, another idea for making this type stable while still allowing short-circuiting on 0 and 1 is to make a lazy mask wrapper which only creates the mask if necessary. Depending on how devious we want to be, we can even use `promote_op` or a custom `mask_type(xtype::Type{A}, ...) where A <: AbstractArray` to determine what type the wrapped mask should be (since we can't determine it in the short-circuiting case).
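A rough sketch of such a lazy mask wrapper; every name here is hypothetical and the use of `Base.promote_op` follows the suggestion above:

```julia
# The mask is only created on first use, so the p == 0 / p == 1 short-circuits
# never allocate it, while dropout's return type stays the same on every branch.
mutable struct LazyMask{M<:AbstractArray}
    mask::Union{Nothing,M}
    build::Function                     # zero-argument closure producing the mask
end

# promote_op infers the mask eltype (Bool here) instead of hard-coding it.
function lazymask(x::AbstractArray, p::Real)
    T = Base.promote_op(>, eltype(x), typeof(p))
    M = Array{T,ndims(x)}               # a real version would respect x's array type
    LazyMask{M}(nothing, () -> M(rand(float(eltype(x)), size(x)) .> p))
end

materialise!(m::LazyMask) = m.mask === nothing ? (m.mask = m.build()) : m.mask
```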
Ah but that's fine I think. We don't need to skip into the future. We only need that the two copies used forward and backward produce the same values, starting from whatever state we happen to start in.
Oh man, I'd be quite keen not to have to maintain two paths. (Already the no-gradient and gradient versions differ a bit here...)
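A quick check of the claim above that two copies produce the same values from whatever state they start in, assuming Julia ≥ 1.7 for `Xoshiro`:

```julia
using Random

rng  = Xoshiro(0)
rand(rng, 3)                       # advance to some arbitrary state
rng2 = copy(rng)                   # snapshot "whatever we happen to start" from

rand(rng, 5) == rand(rng2, 5)      # true: both copies produce the same values
```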
Yes, the skipping is more relevant on GPU. IIRC there's a way to split a [...]. On divergent paths, I don't think it's too too bad. The dispatch split point could be on what Flux calls [...].
I think we should do this, it's a step forward. The copy-the-RNG story does work, and saves memory on the GPU. But it needs some constructor methods, and perhaps consultation with experts about safety. And copying MersenneTwister on 1.6 is pretty slow... we should kick the can down the road, I think. No chance that changes the interface. This saves memory & time over what's now in Flux, and will let us remove that. It [...]
No objections from me. All the aforementioned tricks can be done as additive or at least non-breaking changes in future PRs.
""" | ||
_rng_from_array(x) | ||
|
||
Return the random number generator most appropriate for `x`: | ||
`CUDA.default_rng()` for `CuArray`, else `Random.default_rng()` | ||
""" | ||
_rng_from_array(::AbstractArray) = Random.default_rng() | ||
|
||
@non_differentiable _rng_from_array(::Any) |
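For completeness, the `CuArray` branch mentioned in the docstring would be a method added by NNlibCUDA rather than here; a hedged sketch (the actual package code may differ):

```julia
# In NNlibCUDA (sketch only):
using CUDA, NNlib
NNlib._rng_from_array(::CUDA.CuArray) = CUDA.default_rng()
```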
Any reason not to copy https://github.com/FluxML/Flux.jl/blob/ee78ce3cefb027228413f8edace3c0385139d786/src/utils.jl#L36-L49 wholesale (minus the CUDA overload)?
Just simplifying a bit; I couldn't figure out why there were so many functions. Why make a different choice on 1.6:
julia> using Random
julia> Random.default_rng()
MersenneTwister(0x9687b6121c4ccb062f473c9c3c8bccc6)
julia> Random.GLOBAL_RNG
Random._GLOBAL_RNG()
julia> VERSION
v"1.6.0"
compared to master:
julia> using Random
julia> Random.default_rng()
TaskLocalRNG()
julia> Random.GLOBAL_RNG
Random._GLOBAL_RNG()
julia> VERSION
v"1.10.0-DEV.204"
I don't remember now, but based on FluxML/Flux.jl#1849 (comment) it might've been related to thread safety?
Cthulhu tells me that `rand(...)` uses `default_rng()` on 1.6 as well and it returns a thread-local RNG, so maybe this was much ado about nothing. cc @darsnack if I've missed something though, and I think this function can be public like the Flux one.
This adds a `dropout` function. At least on CPU, this should be about twice as fast as the one in Flux, and use less memory. There are some Flux PRs, but the code is messy, and I got lost & started from scratch.
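A hedged usage sketch of the added function, based on this PR's diff; the default `dims` value and the optional leading `rng` argument are assumptions:

```julia
using NNlib, Random

x = randn(Float32, 3, 4)

y1 = dropout(x, 0.5)              # independent mask per element, survivors scaled by 1/(1-p)
y2 = dropout(x, 0.5; dims = 1)    # mask of size (3, 1): each row kept or dropped as a whole
y3 = dropout(Xoshiro(1), x, 0.5)  # explicit RNG for reproducibility (assumed signature)
```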
It's not quite the fastest variant. The gradient can be faster if we believe that every zero in the output comes from the mask, never from the input (see the sketch below). That may not be a terrible assumption -- if you don't have a `relu` before it, zeros are rare; if you do, then the gradient associated to a zero will be discarded anyway. But... perhaps it's best to be correct. I can put a gist of these variations somewhere. But local benchmarks are:
Edit: timing things on the GPU, at first the `dims=:` case seemed slow, but now it's not...
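A sketch of the faster-gradient variant mentioned in the description; illustrative only, and only correct under the stated assumption that every zero of `Y` came from the mask:

```julia
# Infer the mask from zeros of the output Y instead of storing it.
# Wrong whenever the input already contained exact zeros that the mask kept.
function dropout_grad_from_output(dY, Y, p)
    val = 1 / (1 - p)
    @. (Y != 0) * dY * val          # treats every zero of Y as a dropped entry
end
```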