Move dropout to NNlib #2150

Merged
merged 11 commits into FluxML:master from the dropout branch on Feb 1, 2023

Conversation

mcabbott
Member

@mcabbott mcabbott commented Jan 3, 2023

Needs FluxML/NNlib.jl#454

julia> model = Chain(Dense(10=>100), Dropout(0.5)); data = rand(Float32, 10, 100);

julia> @btime $model($data);
  min 4.536 μs, mean 11.995 μs (4 allocations, 78.22 KiB)  # before
  min 4.692 μs, mean 14.468 μs (4 allocations, 78.22 KiB)  # after

julia> @btime gradient(m -> sum(m($data)), $model);
  min 45.750 μs, mean 77.092 μs (176 allocations, 251.69 KiB)  # before
  min 32.125 μs, mean 55.435 μs (110 allocations, 209.89 KiB)  # after

julia> trainmode!(model);

julia> @btime $model($data);
  min 14.500 μs, mean 25.182 μs (17 allocations, 156.67 KiB)  # before
  min 10.375 μs, mean 24.783 μs (6 allocations, 117.33 KiB)   # after

Comment on lines 70 to 72
function Dropout(p::Real; dims=:, rng = default_rng_value())
  0 ≤ p ≤ 1 || throw(ArgumentError("Dropout expects 0 ≤ p ≤ 1"))
  Dropout(p, dims, nothing, rng)
end
Member Author

Since Dropout is mutable, you can conceivably do d = Dropout(0.0); d.p = 0.1. Is this supported?

If not, then the constructor could do this:

p == 0 && return identity
p == 1 && return zero

That would mean that a model-construction function which happens to be given p == 0 will do the right thing, without imposing a branch on every call.
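
For concreteness, a sketch of how the keyword constructor above could short-circuit (this is the proposal under discussion, not the merged code):

function Dropout(p::Real; dims=:, rng = default_rng_value())
  0 ≤ p ≤ 1 || throw(ArgumentError("Dropout expects 0 ≤ p ≤ 1"))
  p == 0 && return identity   # no-op layer: Chain just calls identity(x)
  p == 1 && return zero       # degenerate case: zero(x) returns an all-zero array
  Dropout(p, dims, nothing, rng)
end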

Member

@ToucheSir ToucheSir Jan 3, 2023

I don't know how to feel about this. @darsnack might be able to give a more definite answer because we ran into this question (when is it appropriate for layer constructors to return a different type) a couple times during the Metalhead rework.

Member Author

6c84a6c adds this for Dropout(0), which you cannot mutate, but not for Dropout(0.0). Not entirely sure that's desirable.

Member

We discussed this during the ML community call. The decision is to keep things simple at the level of Flux/NNlib, i.e. we should return Dropout and not optimize it away. These kinds of optimization could be part of a tool that cuts and stitches the model back together with optimized parts (e.g. in Fluxperimental.jl).
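
As an illustration of that idea only (a hypothetical strip_zero_dropout pass using Functors.fmap, not part of this PR):

using Flux, Functors

# Walk the model and replace any Dropout with p == 0 by identity, leaving everything else alone.
strip_zero_dropout(model) = fmap(model; exclude = l -> l isa Dropout || Functors.isleaf(l)) do l
  l isa Dropout && l.p == 0 ? identity : l
end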

if active
  NNlib.dropout(rng, x, p; dims)
else
  Base.depwarn("Flux.dropout(...; active=false) is deprecated. Please branch outside the function, or call dropout(x, 0) if you must.", :dropout)
Member Author

I made this to preserve the active keyword. Not entirely sure whether use of that was supported outside Flux.

The exported function is this one, not the NNlib one, and I think it lacks a docstring at present.
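
Put together, the compatibility method might look roughly like this (a sketch reconstructing the signature around the excerpt above; not necessarily the exact merged code):

function dropout(rng, x, p; dims = :, active::Bool = true)
  if active
    NNlib.dropout(rng, x, p; dims)
  else
    Base.depwarn("Flux.dropout(...; active=false) is deprecated. Please branch outside the function, or call dropout(x, 0) if you must.", :dropout)
    x
  end
end
dropout(x, p; kwargs...) = dropout(default_rng_value(), x, p; kwargs...)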

Member

I added it to FluxML/NNlib.jl#452 because backends need it, but to my knowledge none do for dropout so this seems fine.

Member Author

You think it might be OK to just go without the deprecation?

Member

Not sure. A quick JuliaHub search didn't turn up anything, but I've enabled downstream tests in case we want to check.

Member Author

Now 0e396a6 removes this entirely, at least to see if anything changes.

@mcabbott mcabbott closed this Jan 5, 2023
@mcabbott mcabbott reopened this Jan 5, 2023
@mcabbott
Member Author

mcabbott commented Jan 7, 2023

Downstream:

  • AtomicGraphNets unrelated
  • Metalhead: loadmodel!(dst::Tuple{Chain{ fails with "but the structures do not match."; unclear where the mismatch is & whether it's related. (Can that have a better error message? IIRC it's recursive, and could tell you which inner layer failed, rather than printing the whole outermost model.)
  • GeometricFlux fails because its ∇indexed_softmax gives "UndefVarError: within_grad not defined" from here. Maybe FluxML/NNlib.jl#434 (Add within_gradient) should have deprecated that, or maybe it was obviously internal.

@ToucheSir
Member

Metalhead load failure is known and can be ignored. Adding a better error message is tricky because loadmodel! is general-purpose, but that's not my area of expertise.

For GeometricFlux, I missed that the NNlib PR renamed the function. A courtesy deprecation seems fine here.
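
Such a courtesy deprecation could be a one-liner in NNlib along these lines (a sketch, assuming within_grad was a plain renamed binding):

# Keep the old name working, with a deprecation warning, until GeometricFlux updates.
Base.@deprecate_binding within_grad within_gradient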

@codecov-commenter

codecov-commenter commented Jan 8, 2023

Codecov Report

Base: 86.36% // Head: 83.61% // Decreases project coverage by 2.76% ⚠️

Coverage data is based on head (fc9855b) compared to base (7997174).
Patch coverage: 75.00% of modified lines in pull request are covered.


Additional details and impacted files
@@            Coverage Diff             @@
##           master    #2150      +/-   ##
==========================================
- Coverage   86.36%   83.61%   -2.76%     
==========================================
  Files          19       19              
  Lines        1482     1459      -23     
==========================================
- Hits         1280     1220      -60     
- Misses        202      239      +37     
Impacted Files Coverage Δ
src/deprecations.jl 46.37% <ø> (ø)
src/layers/normalise.jl 86.71% <75.00%> (-2.11%) ⬇️
src/cuda/cudnn.jl 0.00% <0.00%> (-90.91%) ⬇️
src/functor.jl 44.44% <0.00%> (-42.13%) ⬇️
src/layers/conv.jl 87.93% <0.00%> (-0.07%) ⬇️


@darsnack
Member

We discussed this on the call and resolved the outstanding default-RNG issue (i.e. leave it unchanged). I commented on the other threads with the decisions from the call. Other than those discussions, this PR appears ready.

@mcabbott
Member Author

mcabbott commented Jan 27, 2023

Ok. In its present state this PR doesn't change the RNG handling.

Looking briefly again, I wonder if we should remove the branch if _isactive(a, x) && a.p != 0 here (which is visible to Zygote) in favour of dropout(a.rng, x, a.p * _isactive(a, x); dims=a.dims). FluxML/NNlib.jl#462 made me more worried about branches.
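
Spelled out, the branch-free forward pass would be roughly (a sketch using the existing internal _isactive helper):

# Fold the active check into the rate: p * true == p, p * false == 0, so Zygote sees a single code path.
(a::Dropout)(x) = dropout(a.rng, x, a.p * _isactive(a, x); dims = a.dims)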

Edit, some quick times:

julia> using Flux

julia> model = Chain(Dense(10=>100), Dropout(0.5)); data = rand(Float32, 10, 100);

julia> @btime $model($data);
  min 6.667 μs, mean 15.595 μs (4 allocations, 78.22 KiB)  # tagged
  min 6.917 μs, mean 18.080 μs (4 allocations, 78.22 KiB)  # PR yesterday
  min 7.014 μs, mean 15.138 μs (4 allocations, 78.22 KiB)  # without branch

julia> @btime gradient(m -> sum(m($data)), $model);
  min 79.792 μs, mean 115.690 μs (176 allocations, 251.69 KiB) # tagged
  min 49.500 μs, mean 97.422 μs (110 allocations, 209.89 KiB)  # PR yesterday
  min 43.625 μs, mean 73.894 μs (70 allocations, 209.02 KiB)   # without branch

julia> trainmode!(model);

julia> @btime $model($data);
  min 26.500 μs, mean 48.083 μs (17 allocations, 156.67 KiB)  # tagged
  min 15.042 μs, mean 31.965 μs (6 allocations, 117.33 KiB)   # PR yesterday
  min 14.875 μs, mean 35.244 μs (6 allocations, 117.33 KiB)   # without branch

@mcabbott
Member Author

One more thought. Should we make test mode visible? It's a bit weird that the model has invisible state. Perhaps we should add a keyword to the constructor, and print this:

julia> Flux.trainmode!(m)
Chain(  
  Dense(2 => 3),                        # 9 parameters
  Dropout(0.4, active=true),
)

With more effort, we could make show check _isactive and write, say, # train mode, 9 parameters, but I'm not sure that's better.

@darsnack
Member

Perhaps we should add a keyword to the constructor

I agree, and a keyword seems the best option to me.

@ToucheSir
Member

We could make the active keyword behave like {train,test}mode!, where users can pass :auto or nothing. That would also help with display.
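
For example (hypothetical sketch, not the merged code), with nothing playing the role of :auto to match the existing active::Union{Bool, Nothing} field:

# active = nothing -> automatic (respects trainmode!/testmode!); true/false -> forced on/off.
function Dropout(p::Real; dims = :, active::Union{Bool,Nothing} = nothing, rng = default_rng_value())
  0 ≤ p ≤ 1 || throw(ArgumentError("Dropout expects 0 ≤ p ≤ 1"))
  Dropout(p, dims, active, rng)
end

This would also give show something explicit to print, as in the Dropout(0.4, active=true) example above.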

@mcabbott
Member Author

Yes. Is active the right name? This is only used internally at present. Maybe it should follow {train,test}mode! and be called train=true for the active state?

@ToucheSir
Member

I recall doing a quick survey of this at some point, but not where I posted the results. PyTorch uses training, as inherited from torch.nn.Module. TF relies on a training kwarg at call time, which the framework automagically sets. Flax has a deterministic field which it expects users to set themselves, and provides no automatic functionality for it.

@mcabbott
Member Author

Thanks for digging. Maybe my vote is to leave this for now, and do it later for BatchNorm too.

  p::F
  dims::D
  active::Union{Bool, Nothing}
  rng::R
end
Dropout(p, dims, active) = Dropout(p, dims, active, default_rng_value())
Dropout(p::Real, dims, active) = Dropout(p, dims, active, default_rng_value())
Member

Not sure if this is intentional, but the error checking seems to apply only to the keyword-based constructor.

Member Author

I think that's the only "public" one. I have no idea why we have this 3-arg constructor Dropout(p, dims, active), e.g. Functors will use the 4-arg one. Maybe it was in case someone was relying on it from before the rng field was added?

Member

Yeah that's my recollection.

@mcabbott
Member Author

Maybe this is ready?

@mcabbott mcabbott merged commit e33de0c into FluxML:master Feb 1, 2023
@mcabbott mcabbott deleted the dropout branch February 1, 2023 01:33