Skip to content

Commit

Permalink
simplify default_rng etc
Browse files Browse the repository at this point in the history
  • Loading branch information
mcabbott committed Jan 7, 2023
1 parent 4ab93b3 commit 9e61486
Show file tree
Hide file tree
Showing 4 changed files with 13 additions and 43 deletions.
1 change: 1 addition & 0 deletions src/Flux.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ module Flux

using Base: tail
using LinearAlgebra, Statistics, Random # standard lib
using Random: default_rng
using MacroTools, Reexport, ProgressLogging, SpecialFunctions
using MacroTools: @forward

Expand Down
15 changes: 2 additions & 13 deletions src/deprecations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -84,8 +84,6 @@ Base.@deprecate_binding ADADelta AdaDelta
# Remove sub-module Data, while making sure Flux.Data.DataLoader keeps working
Base.@deprecate_binding Data Flux false "Sub-module Flux.Data has been removed. The only thing it contained may be accessed as Flux.DataLoader"

@deprecate rng_from_array() default_rng_value()

function istraining()
Base.depwarn("Flux.istraining() is deprecated, use NNlib.within_gradient(x) instead", :istraining)
false
Expand Down Expand Up @@ -185,17 +183,8 @@ function update!(opt::Optimise.AbstractOptimiser, ::Params, grads::Union{Tuple,
""")
end


function dropout(rng, x, p; dims=:, active::Bool=true)
if active
NNlib.dropout(rng, x, p; dims)
else
Base.depwarn("Flux.dropout(...; active=false) is deprecated. Please branch outside the function, or call dropout(x, 0) if you must.", :dropout)
return x
end
end
dropout(x, p; kwargs...) = dropout(NNlib._rng_from_array(x), x, p; kwargs...)

@deprecate rng_from_array() default_rng_value()
@deprecate default_rng_value() Random.default_rng()

# v0.14 deprecations

Expand Down
12 changes: 6 additions & 6 deletions src/layers/normalise.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
_isactive(m, x) = isnothing(m.active) ? NNlib.within_gradient(x) : m.active

"""
Dropout(p; dims=:, rng = default_rng_value())
Dropout(p; dims=:, rng = default_rng())
Layer implementing [dropout](https://arxiv.org/abs/1207.0580) with the given probability.
This is used as a regularisation, i.e. to reduce overfitting.
Expand Down Expand Up @@ -61,9 +61,9 @@ mutable struct Dropout{F<:Real,D,R<:AbstractRNG}
active::Union{Bool, Nothing}
rng::R
end
Dropout(p::Real, dims, active) = Dropout(p, dims, active, default_rng_value())
Dropout(p::Real, dims, active) = Dropout(p, dims, active, default_rng())

function Dropout(p::Real; dims=:, rng = default_rng_value())
function Dropout(p::Real; dims=:, rng = default_rng())
0 p 1 || throw(ArgumentError("Dropout expexts 0 ≤ p ≤ 1, got p = $p"))
if p isa Integer # Dropout(0)
return p==0 ? identity : zero
Expand Down Expand Up @@ -92,7 +92,7 @@ function Base.show(io::IO, d::Dropout)
end

"""
AlphaDropout(p; rng = default_rng_value())
AlphaDropout(p; rng = default_rng())
A dropout layer. Used in
[Self-Normalizing Neural Networks](https://arxiv.org/abs/1706.02515).
Expand Down Expand Up @@ -126,8 +126,8 @@ mutable struct AlphaDropout{F,R<:AbstractRNG}
new{typeof(p), typeof(rng)}(p, active, rng)
end
end
AlphaDropout(p, active) = AlphaDropout(p, active, default_rng_value())
AlphaDropout(p; rng = default_rng_value()) = AlphaDropout(p, nothing, rng)
AlphaDropout(p, active) = AlphaDropout(p, active, default_rng())
AlphaDropout(p; rng = default_rng()) = AlphaDropout(p, nothing, rng)

@functor AlphaDropout
trainable(a::AlphaDropout) = (;)
Expand Down
28 changes: 4 additions & 24 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -36,32 +36,12 @@ epseltype(x) = eps(float(eltype(x)))
"""
rng_from_array([x])
Create an instance of the RNG most appropriate for `x`.
The current defaults are:
- `x isa CuArray`: `CUDA.default_rng()`, else:
- `x isa AbstractArray`, or no `x` provided:
- Julia version is < 1.7: `Random.GLOBAL_RNG`
- Julia version is >= 1.7: `Random.default_rng()`
"""
rng_from_array(::AbstractArray) = default_rng_value()
rng_from_array(::CuArray) = CUDA.default_rng()

@non_differentiable rng_from_array(::Any)

if VERSION >= v"1.7"
default_rng_value() = Random.default_rng()
else
default_rng_value() = Random.GLOBAL_RNG
end

Create an instance of the RNG most appropriate for array `x`.
If `x isa CuArray` then this is `CUDA.default_rng()`,
otherwise `Random.default_rng()`.
"""
default_rng_value()
rng_from_array(x::AbstractArray) = NNlib._rng_from_array(x)

Create an instance of the default RNG depending on Julia's version.
- Julia version is < 1.7: `Random.GLOBAL_RNG`
- Julia version is >= 1.7: `Random.default_rng()`
"""
default_rng_value

"""
glorot_uniform([rng = default_rng_value()], size...; gain = 1) -> Array
Expand Down

0 comments on commit 9e61486

Please sign in to comment.