Truncated normal initialisation for weights #1877

Merged · 12 commits · Feb 19, 2022
1 change: 1 addition & 0 deletions NEWS.md
@@ -9,6 +9,7 @@ been removed in favour of MLDatasets.jl.
* `Dropout` gained improved compatibility with Int and Complex arrays and is now twice-differentiable.
* Many utility functions and the `DataLoader` are [now provided by MLUtils.jl](https://github.com/FluxML/Flux.jl/pull/1874).
* The DataLoader is now compatible with generic dataset types implementing `MLUtils.numobs` and `MLUtils.getobs`.
* Added truncated normal initialisation of weights.

## v0.12.10
* `Dropout`/`AlphaDropout` now supports [user-specified RNGs](https://github.com/FluxML/Flux.jl/pull/1838)
1 change: 1 addition & 0 deletions Project.toml
@@ -16,6 +16,7 @@ ProgressLogging = "33c8b6b6-d38a-422a-b730-caa89a2f386c"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
2 changes: 1 addition & 1 deletion src/Flux.jl
@@ -3,7 +3,7 @@ module Flux
# Zero Flux Given

using Base: tail
using Statistics, Random, LinearAlgebra
using Statistics, Random, LinearAlgebra, SpecialFunctions
using Zygote, MacroTools, ProgressLogging, Reexport
using MacroTools: @forward
@reexport using NNlib
68 changes: 59 additions & 9 deletions src/utils.jl
@@ -81,7 +81,7 @@ julia> Flux.glorot_uniform(2, 3)
[1] Glorot, Xavier, and Yoshua Bengio. "Understanding the difficulty of training deep feedforward neural networks." _Proceedings of the thirteenth international conference on artificial intelligence and statistics_. 2010.
"""
glorot_uniform(rng::AbstractRNG, dims...) = (rand(rng, Float32, dims...) .- 0.5f0) .* sqrt(24.0f0 / sum(nfan(dims...)))
glorot_uniform(dims...) = glorot_uniform(Random.GLOBAL_RNG, dims...)
glorot_uniform(dims...) = glorot_uniform(rng_from_array(), dims...)
glorot_uniform(rng::AbstractRNG) = (dims...) -> glorot_uniform(rng, dims...)

"""
@@ -114,7 +114,7 @@ julia> Flux.glorot_normal(3, 2)
[1] Glorot, Xavier, and Yoshua Bengio. "Understanding the difficulty of training deep feedforward neural networks." _Proceedings of the thirteenth international conference on artificial intelligence and statistics_. 2010.
"""
glorot_normal(rng::AbstractRNG, dims...) = randn(rng, Float32, dims...) .* sqrt(2.0f0 / sum(nfan(dims...)))
glorot_normal(dims...) = glorot_normal(Random.GLOBAL_RNG, dims...)
glorot_normal(dims...) = glorot_normal(rng_from_array(), dims...)
glorot_normal(rng::AbstractRNG) = (dims...) -> glorot_normal(rng, dims...)

"""
Expand Down Expand Up @@ -151,7 +151,7 @@ function kaiming_uniform(rng::AbstractRNG, dims...; gain = √2)
return (rand(rng, Float32, dims...) .- 0.5f0) .* 2bound
end

kaiming_uniform(dims...; kwargs...) = kaiming_uniform(Random.GLOBAL_RNG, dims...; kwargs...)
kaiming_uniform(dims...; kwargs...) = kaiming_uniform(rng_from_array(), dims...; kwargs...)
kaiming_uniform(rng::AbstractRNG; init_kwargs...) = (dims...; kwargs...) -> kaiming_uniform(rng, dims...; init_kwargs..., kwargs...)

"""
Expand Down Expand Up @@ -188,9 +188,58 @@ function kaiming_normal(rng::AbstractRNG, dims...; gain = √2f0)
return randn(rng, Float32, dims...) .* std
end

kaiming_normal(dims...; kwargs...) = kaiming_normal(Random.GLOBAL_RNG, dims...; kwargs...)
kaiming_normal(dims...; kwargs...) = kaiming_normal(rng_from_array(), dims...; kwargs...)
kaiming_normal(rng::AbstractRNG; init_kwargs...) = (dims...; kwargs...) -> kaiming_normal(rng, dims...; init_kwargs..., kwargs...)

"""
truncated_normal([rng=GLOBAL_RNG], dims...; mean = 0, std = 1, lo = -2, hi = 2)

Return an `Array{Float32}` of size `dims` where each element is drawn from a truncated normal distribution.
The numbers are distributed like `filter(x -> lo<=x<=hi, mean .+ std .* randn(dims...))`.

The values are generated by sampling from a Uniform(0, 1) distribution (`rand()`) and then
applying the inverse CDF of the truncated normal distribution
(see the references for more details).
This method works best when `lo ≤ mean ≤ hi`.
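Concretely, writing `Φ` for the standard normal CDF and letting `l = Φ((lo - mean)/std)` and
`u = Φ((hi - mean)/std)`, each sample is `mean + std * √2 * erfinv(2(l + ξ*(u - l)) - 1)`
for `ξ` drawn from Uniform(0, 1), then clamped to `[lo, hi]`.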

# Examples
```jldoctest
julia> using Statistics

julia> Flux.truncated_normal(3, 4) |> summary
"3×4 Matrix{Float32}"

julia> round.(extrema(Flux.truncated_normal(10^6)); digits=3)
(-2.0f0, 2.0f0)

julia> round(std(Flux.truncated_normal(10^6; lo = -100, hi = 100)))
1.0f0
```

# References
[1] Burkardt, John. "The Truncated Normal Distribution"
[PDF](https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf).
Department of Scientific Computing website.
"""
function truncated_normal(rng::AbstractRNG, dims...; mean = 0, std = 1, lo = -2, hi = 2)
norm_cdf(x) = 0.5 * (1 + erf(x/√2))
if (mean < lo - 2 * std) || (mean > hi + 2 * std)
@warn "Mean is more than 2 std outside the limits in truncated_normal, so the distribution of values may be inaccurate." maxlog=1
end
l = norm_cdf((lo - mean) / std)
u = norm_cdf((hi - mean) / std)
xs = rand(rng, Float32, dims...)
broadcast!(xs, xs) do x
x = x * 2(u - l) + (2l - 1)
x = erfinv(x)
x = clamp(x * std * √2 + mean, lo, hi)
end
return xs
end

truncated_normal(dims...; kwargs...) = truncated_normal(rng_from_array(), dims...; kwargs...)
truncated_normal(rng::AbstractRNG; init_kwargs...) = (dims...; kwargs...) -> truncated_normal(rng, dims...; init_kwargs..., kwargs...)
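
A minimal usage sketch of the new initialiser (the layer sizes and `relu` activation below are illustrative assumptions, not taken from this PR): like the other initialisers in `utils.jl`, `truncated_normal` can be called directly to produce a weight array, or passed as a layer's `init` keyword.

```julia
using Flux

# Draw a 256×784 Float32 weight matrix directly
# (defaults: mean = 0, std = 1, values clipped to [-2, 2]).
W = Flux.truncated_normal(256, 784)

# Or hand it to a layer constructor via the `init` keyword.
layer = Dense(784, 256, relu; init = Flux.truncated_normal)
```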

"""
orthogonal([rng=GLOBAL_RNG], dims...; gain = 1)

@@ -232,6 +281,7 @@ true
* sparse initialization: [`sparse_init`](@ref Flux.sparse_init)

# References

[1] Saxe, McClelland, Ganguli. "Exact solutions to the nonlinear dynamics of learning in deep linear neural networks", ICLR 2014, https://arxiv.org/abs/1312.6120

"""
@@ -254,7 +304,7 @@ function orthogonal(rng::AbstractRNG, d1::Integer, ds::Integer...; kwargs...)
return reshape(orthogonal(rng, rows, cols; kwargs...), dims)
end

orthogonal(dims::Integer...; kwargs...) = orthogonal(Random.GLOBAL_RNG, dims...; kwargs...)
orthogonal(dims::Integer...; kwargs...) = orthogonal(rng_from_array(), dims...; kwargs...)
orthogonal(rng::AbstractRNG; init_kwargs...) = (dims::Integer...; kwargs...) -> orthogonal(rng, dims...; init_kwargs..., kwargs...)

"""
Expand Down Expand Up @@ -298,7 +348,7 @@ function sparse_init(rng::AbstractRNG, dims...; sparsity, std = 0.01)
return mapslices(shuffle, sparse_array, dims=1)
end

sparse_init(dims...; kwargs...) = sparse_init(Random.GLOBAL_RNG, dims...; kwargs...)
sparse_init(dims...; kwargs...) = sparse_init(rng_from_array(), dims...; kwargs...)
sparse_init(rng::AbstractRNG; init_kwargs...) = (dims...; kwargs...) -> sparse_init(rng, dims...; init_kwargs..., kwargs...)

"""
Expand Down Expand Up @@ -382,7 +432,7 @@ function identity_init(dims...; gain=1, shift=0)
end

identity_init(::AbstractRNG, dims...; kwargs...) = identity_init(dims...; kwargs...)
identity_init(; init_kwargs...) = identity_init(Random.GLOBAL_RNG; init_kwargs...)
identity_init(; init_kwargs...) = identity_init(rng_from_array(); init_kwargs...)
identity_init(rng::AbstractRNG; init_kwargs...) = (args...;kwargs...) -> identity_init(rng, args...; init_kwargs..., kwargs...)

ones32(dims...) = Base.ones(Float32, dims...)
@@ -437,8 +487,8 @@ end

Flatten a model's parameters into a single weight vector.

julia> m = Chain(Dense(10, 5, σ), Dense(5, 2), softmax)
Chain(Dense(10, 5, σ), Dense(5, 2), softmax)
julia> m = Chain(Dense(10, 5, std), Dense(5, 2), softmax)
Chain(Dense(10, 5, std), Dense(5, 2), softmax)

julia> θ, re = destructure(m);

21 changes: 19 additions & 2 deletions test/utils.jl
@@ -1,9 +1,10 @@
using Flux
using Flux: throttle, nfan, glorot_uniform, glorot_normal,
kaiming_normal, kaiming_uniform, orthogonal,
kaiming_normal, kaiming_uniform, orthogonal, truncated_normal,
sparse_init, stack, unstack, Zeros, batch, unbatch,
unsqueeze
unsqueeze, params
using StatsBase: var, std
using Statistics, LinearAlgebra
using Random
using Test

@@ -146,6 +147,22 @@ end
end
end

@testset "truncated_normal" begin
size = (100, 100, 100)
for (μ, σ, lo, hi) in [(0., 1, -2, 2), (0, 1, -4., 4)]
v = truncated_normal(size; mean = μ, std = σ, lo, hi)
@test isapprox(mean(v), μ; atol = 1f-1)
@test isapprox(minimum(v), lo; atol = 1f-1)
@test isapprox(maximum(v), hi; atol = 1f-1)
@test eltype(v) == Float32
end
for (μ, σ, lo, hi) in [(6, 2, -100., 100), (7., 10, -100, 100)]
v = truncated_normal(size...; mean = μ, std = σ, lo, hi)
@test isapprox(std(v), σ; atol = 1f-1)
@test eltype(v) == Float32
end
end
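
Why the second loop widens the bounds before checking `std`: truncating a standard normal to `[-2, 2]` cuts off the tails and shrinks the sample standard deviation below the nominal value, so only near-untruncated bounds should recover `σ`. A rough check (approximate values, not part of the test file):

```julia
using Statistics

narrow = Flux.truncated_normal(10^6)                       # default lo = -2, hi = 2
wide   = Flux.truncated_normal(10^6; lo = -100, hi = 100)  # effectively untruncated

std(narrow)  # ≈ 0.88, below the nominal std of 1
std(wide)    # ≈ 1.0
```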

@testset "partial_application" begin
big = 1e9
