
Move dropout to NNlib #2150

Merged (11 commits) on Feb 1, 2023
4 changes: 2 additions & 2 deletions Project.toml
@@ -30,8 +30,8 @@ ChainRulesCore = "1.12"
Functors = "0.3, 0.4"
MLUtils = "0.2, 0.3.1, 0.4"
MacroTools = "0.5"
NNlib = "0.8.14"
NNlibCUDA = "0.2.4"
NNlib = "0.8.15"
NNlibCUDA = "0.2.6"
OneHotArrays = "0.1, 0.2"
Optimisers = "0.2.12"
ProgressLogging = "0.1"
2 changes: 1 addition & 1 deletion docs/src/models/layers.md
@@ -123,7 +123,7 @@ LayerNorm
InstanceNorm
GroupNorm
Flux.normalise
Flux.dropout
NNlib.dropout
```

### Test vs. Train
1 change: 1 addition & 0 deletions src/deprecations.jl
@@ -185,6 +185,7 @@ function update!(opt::Optimise.AbstractOptimiser, ::Params, grads::Union{Tuple,
""")
end


# v0.14 deprecations

# Enable these when 0.14 is released, and delete const ClipGrad = Optimise.ClipValue etc:
124 changes: 49 additions & 75 deletions src/layers/normalise.jl
@@ -1,111 +1,85 @@

# Internal function, used only for layers defined in this file.
_isactive(m, x) = isnothing(m.active) ? NNlib.within_gradient(x) : m.active

_dropout_shape(s, ::Colon) = size(s)
_dropout_shape(s, dims) = tuple((i ∉ dims ? 1 : si for (i, si) ∈ enumerate(size(s)))...)

_dropout_kernel(y::T, p, q) where {T} = y > p ? T(1 / q) : T(0)
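
For orientation, here is a small, self-contained sketch of what these two helpers compute. The names below are stand-ins for illustration only, since the originals are internal.

```julia
# Stand-ins for the two internal helpers above, for illustration only.
shape_demo(x, dims) = Tuple(i in dims ? size(x, i) : 1 for i in 1:ndims(x))
kernel_demo(y, p, q) = y > p ? 1 / q : 0.0

shape_demo(ones(3, 4), (1,))    # (3, 1): one mask entry per row, broadcast over columns
shape_demo(ones(3, 4), (1, 2))  # (3, 4): an independent mask entry for every element

p = 0.25
kernel_demo.(rand(5), p, 1 - p) # roughly 25% zeros, the rest equal to 1/(1-p) ≈ 1.33
```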

"""
dropout([rng = rng_from_array(x)], x, p; dims=:, active=true)

The dropout function. If `active` is `true`,
for each input, either sets that input to `0` (with probability
`p`) or scales it by `1 / (1 - p)`. `dims` specifies the unbroadcasted dimensions,
e.g. `dims=1` applies dropout along columns and `dims=2` along rows.
If `active` is `false`, it just returns the input `x`.

Specify `rng` for custom RNGs instead of the default RNG.
Note that custom RNGs are only supported on the CPU.

Warning: when using this function, you have to manually manage the activation
state. Usually in fact, dropout is used while training
but is deactivated in the inference phase. This can be
automatically managed using the [`Dropout`](@ref) layer instead of the
`dropout` function.

The [`Dropout`](@ref) layer is what you should use in most scenarios.
"""
function dropout(rng, x, p; dims=:, active::Bool=true)
active || return x
y = dropout_mask(rng, x, p, dims=dims)
return x .* y
end
dropout(x, p; kwargs...) = dropout(rng_from_array(x), x, p; kwargs...)
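
To make the behaviour concrete, a short usage sketch of this pre-NNlib functional form as defined just above; the exact zero positions depend on the RNG.

```julia
using Flux, Random

x = ones(Float32, 3, 4)
rng = MersenneTwister(0)              # custom RNGs are CPU-only

y = Flux.dropout(rng, x, 0.5)         # each entry is either 0 or 1/(1 - 0.5) = 2

# dims = 1 draws one random value per row and reuses it for every column,
# so whole rows are zeroed or kept together:
yrow = Flux.dropout(rng, x, 0.5; dims = 1)

Flux.dropout(x, 0.5; active = false) == x   # active = false passes x through untouched
```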

dropout_mask(rng::CUDA.RNG, x::CuArray, p; kwargs...) = _dropout_mask(rng, x, p; kwargs...)
dropout_mask(rng, x::CuArray, p; kwargs...) =
throw(ArgumentError("x isa CuArray, but rng isa $(typeof(rng)). dropout_mask only support CUDA.RNG for CuArrays."))
dropout_mask(rng, x, p; kwargs...) = _dropout_mask(rng, x, p; kwargs...)
function _dropout_mask(rng, x, p; dims=:)
realfptype = float(real(eltype(x)))
y = rand!(rng, similar(x, realfptype, _dropout_shape(x, dims)))
y .= _dropout_kernel.(y, p, 1 - p)
return y
end
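
As a sketch of the dispatch above: the CPU path accepts any RNG, while a `CuArray` is only paired with `CUDA.RNG`. The GPU lines need CUDA.jl and a working GPU, and `dropout_mask` is internal, so this is illustration only.

```julia
using Flux, Random
using CUDA   # GPU lines need CUDA.jl and a working GPU

m = Flux.dropout_mask(MersenneTwister(1), ones(3, 4), 0.25)    # CPU path: entries are 0 or 4/3
Flux.dropout_mask(CUDA.default_rng(), CUDA.ones(3, 4), 0.25)   # CUDA.RNG with a CuArray is accepted
# Flux.dropout_mask(MersenneTwister(1), CUDA.ones(3, 4), 0.25) # would throw the ArgumentError above
```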

# TODO move this to NNlib
ChainRulesCore.@non_differentiable dropout_mask(::Any, ::Any, ::Any)
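
Because the mask is marked non-differentiable, the gradient of a dropout call with respect to `x` is simply the mask that was applied; a quick sketch:

```julia
using Flux, Zygote, Random

x = randn(4)
g = Zygote.gradient(x -> sum(Flux.dropout(MersenneTwister(1), x, 0.5)), x)[1]
# Each entry of g is the mask value applied to that element: either 0 or 1/(1 - 0.5) = 2.
```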

"""
Dropout(p; dims=:, rng = default_rng_value())

Dropout layer.
Layer implementing [dropout](https://arxiv.org/abs/1207.0580) with the given probability.
This is used as a regularisation, i.e. to reduce overfitting.

While training, for each input, this layer either sets that input to `0` (with probability
`p`) or scales it by `1 / (1 - p)`. To apply dropout along certain dimension(s), specify the
`dims` keyword. e.g. `Dropout(p; dims = 3)` will randomly zero out entire channels on WHCN input
(also called 2D dropout). This is used as a regularisation, i.e. it reduces overfitting during
training.
While training, it sets each input to `0` (with probability `p`)
or else scales it by `1 / (1 - p)`, using the [`NNlib.dropout`](@ref) function.
While testing, it has no effect.

In the forward pass, this layer applies the [`Flux.dropout`](@ref) function. See that for more
details.
By default the mode will switch automatically, but it can also
be controlled manually via [`Flux.testmode!`](@ref).

Specify `rng` to use a custom RNG instead of the default.
Custom RNGs are only supported on the CPU.
By default every input is treated independently. The `dims` keyword
instead takes a random choice only along that dimension.
For example `Dropout(p; dims = 3)` will randomly zero out entire channels on WHCN input
(also called 2D dropout).

Does nothing to the input once [`Flux.testmode!`](@ref) is `true`.
Keyword `rng` lets you specify a custom random number generator.
(Only supported on the CPU.)

# Examples
```jldoctest
julia> m = Chain(Dense(1 => 1), Dropout(1));
```julia
julia> m = Chain(Dense(ones(3,2)), Dropout(0.4))
Chain(
Dense(2 => 3), # 9 parameters
Dropout(0.4),
)

julia> Flux.trainmode!(m);
julia> m(ones(2, 7)) # test mode, no effect
3×7 Matrix{Float64}:
2.0 2.0 2.0 2.0 2.0 2.0 2.0
2.0 2.0 2.0 2.0 2.0 2.0 2.0
2.0 2.0 2.0 2.0 2.0 2.0 2.0

julia> y = m([1]);
julia> Flux.trainmode!(m); # would happen within gradient

julia> y == [0]
true
julia> m(ones(2, 7))
3×7 Matrix{Float64}:
0.0 0.0 3.33333 0.0 0.0 0.0 0.0
3.33333 0.0 3.33333 0.0 3.33333 0.0 3.33333
3.33333 3.33333 0.0 3.33333 0.0 0.0 3.33333

julia> m = Chain(Dense(1000 => 1000), Dropout(0.5));
julia> y = m(ones(2, 10_000));

julia> Flux.trainmode!(m);
julia> using Statistics

julia> y = m(ones(1000));
julia> mean(y) # is about 2.0, as for test mode
1.9892222222222182

julia> isapprox(count(==(0), y) / length(y), 0.5, atol=0.1)
true
julia> mean(iszero, y) # is about 0.4
0.40323333333333333
```
"""
mutable struct Dropout{F,D,R<:AbstractRNG}
mutable struct Dropout{F<:Real,D,R<:AbstractRNG}
p::F
dims::D
active::Union{Bool, Nothing}
rng::R
end
Dropout(p, dims, active) = Dropout(p, dims, active, default_rng_value())
Dropout(p::Real, dims, active) = Dropout(p, dims, active, default_rng_value())
Member: Not sure if this is intentional, but the error checking seems to apply only to the keyword-based constructor.

Member (author): I think that's the only "public" one. I have no idea why we have this 3-arg constructor `Dropout(p, dims, active)`; e.g. Functors will use the 4-arg one. Maybe it was in case someone was relying on it from before the `rng` field was added?

Member: Yeah, that's my recollection.


function Dropout(p; dims=:, rng = default_rng_value())
@assert 0 ≤ p ≤ 1
function Dropout(p::Real; dims=:, rng = default_rng_value())
0 ≤ p ≤ 1 || throw(ArgumentError("Dropout expects 0 ≤ p ≤ 1, got p = $p"))
if p isa Integer # Dropout(0)
return p==0 ? identity : zero
end
Dropout(p, dims, nothing, rng)
end
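
A quick sketch of what this constructor returns for the cases handled above, per the new code as shown:

```julia
using Flux

Dropout(0.4)            # the usual case: a Dropout layer with dims = (:) and the default RNG
Dropout(0.4; dims = 3)  # channel-wise ("2D") dropout for WHCN input
Dropout(0)              # integer p == 0 collapses to `identity`
Dropout(1)              # integer p == 1 collapses to `zero`
Dropout(0.0)            # still a real layer, whose forward pass is then a no-op
# Dropout(1.5)          # throws ArgumentError: expects 0 ≤ p ≤ 1
```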

@functor Dropout
trainable(a::Dropout) = (;)

function (a::Dropout)(x)
_isactive(a, x) || return x
return dropout(a.rng, x, a.p; dims=a.dims, active=true)
if _isactive(a, x) && a.p != 0
dropout(a.rng, x, a.p; dims=a.dims)
else
x
end
end
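
Putting the pieces together, a sketch of the layer's forward pass after this change, assuming `dropout(rng, x, p; dims)` now resolves to the NNlib function:

```julia
using Flux

d = Dropout(0.4; dims = 3)
x = ones(Float32, 4, 4, 3, 2)   # WHCN

Flux.trainmode!(d)
y = d(x)          # dropout(d.rng, x, 0.4; dims = 3): each channel is either
                  # all zeros or scaled by 1/(1 - 0.4)

Flux.testmode!(d)
d(x) == x         # inactive (or p == 0) short-circuits and returns x unchanged
```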

testmode!(m::Dropout, mode=true) =
6 changes: 3 additions & 3 deletions test/layers/normalisation.jl
@@ -1,4 +1,4 @@
using Flux, Test, Statistics
using Flux, Test, Statistics, Random
using Zygote: pullback, ForwardDiff

evalwgrad(f, x...) = pullback(f, x...)[1]
@@ -56,10 +56,10 @@ evalwgrad(f, x...) = pullback(f, x...)[1]
y = m(x)
@test count(a->a == 0, y) > 50

y = Flux.dropout(values(rng_kwargs)..., x, 0.9, active=true)
y = Flux.dropout(values(rng_kwargs)..., x, 0.9) # , active=true)
@test count(a->a == 0, y) > 50

y = Flux.dropout(values(rng_kwargs)..., x, 0.9, active=false)
y = Flux.dropout(values(rng_kwargs)..., x, 0.9 * 0) # , active=false)
@test count(a->a == 0, y) == 0
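
Informally, the two updated test lines rely on the following behaviour (a sketch with the `rng_kwargs` splat omitted, not part of the diff):

```julia
using Flux

x = rand(100)

count(iszero, Flux.dropout(x, 0.9)) > 50   # at p = 0.9 most entries are zeroed
Flux.dropout(x, 0.9 * 0) == x              # p = 0 keeps everything, replacing active = false
```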

# CPU RNGs map onto CPU ok