-
-
Notifications
You must be signed in to change notification settings - Fork 612
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Move dropout
to NNlib
#2150
Move dropout
to NNlib
#2150
Conversation
src/layers/normalise.jl
Outdated
function Dropout(p::Real; dims=:, rng = default_rng_value()) | ||
0 ≤ p ≤ 1 || throw(ArgumentError("Dropout expexts 0 ≤ p ≤ 1")) | ||
Dropout(p, dims, nothing, rng) | ||
end |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Since Dropout
is mutable, you can conceivably do d = Dropout(0.0); d.p = 0.1
. Is this supported?
If not, then the constructor could do this:
p == 0 && return identity
p == 1 && return zero
That would mean that some model construction function which happens to have p==0
will do the right thing, without imposing a branch etc.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't know how to feel about this. @darsnack might be able to give a more definite answer because we ran into this question (when is it appropriate for layer constructors to return a different type) a couple times during the Metalhead rework.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
6c84a6c adds this for Dropout(0)
, which you cannot mutate, but not Dropout(0.0)
. Not very sure that's desirable.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We discussed this during the ML community call. The decision is to keep things simple at the level of Flux/NNlib, i.e. we should return Dropout
and not optimize it away. These kinds of optimization could be part of a tool that cuts and stitches the model back together with optimized parts (e.g. in Fluxperimental.jl).
src/deprecations.jl
Outdated
if active | ||
NNlib.dropout(rng, x, p; dims) | ||
else | ||
Base.depwarn("Flux.dropout(...; active=false) is deprecated. Please branch outside the function, or call dropout(x, 0) if you must.", :dropout) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I made this to preserve the active
keyword. Not entirely sure whether use of that was supported outside Flux.
The exported function is this one, not the NNlib one, and I think it lacks a docstring at present.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I added it to FluxML/NNlib.jl#452 because backends need it, but to my knowledge none do for dropout so this seems fine.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You think it might be OK to just go without the deprecation?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Not sure. A quick JuliaHub search didn't turn up anything, but I've enabled downstream tests in case we want to check.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Now 0e396a6 removes this entirely. At least to see if anything changes.
Downstream:
|
Metalhead load failure is known and can be ignored. Adding a better error message is tricky because For GeometricFlux, I missed that the NNlib PR renamed the function. A courtesy deprecation seems fine here. |
Codecov ReportBase: 86.36% // Head: 83.61% // Decreases project coverage by
📣 This organization is not using Codecov’s GitHub App Integration. We recommend you install it so Codecov can continue to function properly for your repositories. Learn more Additional details and impacted files@@ Coverage Diff @@
## master #2150 +/- ##
==========================================
- Coverage 86.36% 83.61% -2.76%
==========================================
Files 19 19
Lines 1482 1459 -23
==========================================
- Hits 1280 1220 -60
- Misses 202 239 +37
Help us with your feedback. Take ten seconds to tell us how you rate us. Have a feature suggestion? Share it here. ☔ View full report at Codecov. |
We discussed this on call and resolved the outstanding default RNG issue (i.e. leave unchanged). I commented on other issues with the decision on call. Other than those discussions, this PR appears ready. |
Ok. In its present state this PR doesn't change the RNG handling. Looking briefly again, I wonder if we should remove the branch Edit, some quick times: julia> using Flux
julia> model = Chain(Dense(10=>100), Dropout(0.5)); data = rand(Float32, 10, 100);
julia> @btime $model($data);
min 6.667 μs, mean 15.595 μs (4 allocations, 78.22 KiB) # tagged
min 6.917 μs, mean 18.080 μs (4 allocations, 78.22 KiB) # PR yesterday
min 7.014 μs, mean 15.138 μs (4 allocations, 78.22 KiB) # without branch
julia> @btime gradient(m -> sum(m($data)), $model);
min 79.792 μs, mean 115.690 μs (176 allocations, 251.69 KiB) # tagged
min 49.500 μs, mean 97.422 μs (110 allocations, 209.89 KiB) # PR yesterday
min 43.625 μs, mean 73.894 μs (70 allocations, 209.02 KiB) # without branch
julia> trainmode!(model);
julia> @btime $model($data);
min 26.500 μs, mean 48.083 μs (17 allocations, 156.67 KiB)
min 15.042 μs, mean 31.965 μs (6 allocations, 117.33 KiB)
min 14.875 μs, mean 35.244 μs (6 allocations, 117.33 KiB) |
Co-authored-by: Carlo Lucibello <[email protected]>
One more thought. Should we make test mode visible? It's a bit weird that the model has invisible state. Perhaps we should add a keyword to the constructor, and print this:
With more effort, we could make |
I agree, and a keyword seems the best option to me. |
We could make the |
Yes. Is |
I recall doing a quick survey of this at some point, but not where I posted the results. PyTorch uses |
Thanks for digging. Maybe my vote is to leave this for now, and do it later to BatchNorm too. |
p::F | ||
dims::D | ||
active::Union{Bool, Nothing} | ||
rng::R | ||
end | ||
Dropout(p, dims, active) = Dropout(p, dims, active, default_rng_value()) | ||
Dropout(p::Real, dims, active) = Dropout(p, dims, active, default_rng_value()) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Not sure if this is intentional but the error checking seems to only apply to the keyword based constructor.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think that's the only "public" one. I have no idea why we have this 3-arg constructor Dropout(p, dims, active)
, e.g. Functors will use the 4-arg one. Maybe it was in case someone was relying on it from before the rng
field was added?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah that's my recollection.
Co-authored-by: Kyle Daruwalla <[email protected]>
Maybe this is ready? |
Needs FluxML/NNlib.jl#454