
Truncated normal initialisation for weights #1877

Merged · 12 commits · Feb 19, 2022
1 change: 1 addition & 0 deletions Project.toml
@@ -16,6 +16,7 @@ ProgressLogging = "33c8b6b6-d38a-422a-b730-caa89a2f386c"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
+SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
2 changes: 1 addition & 1 deletion src/Flux.jl
@@ -3,7 +3,7 @@ module Flux
# Zero Flux Given

using Base: tail
-using Statistics, Random, LinearAlgebra
+using Statistics, Random, LinearAlgebra, SpecialFunctions
using Zygote, MacroTools, ProgressLogging, Reexport
using MacroTools: @forward
@reexport using NNlib
64 changes: 53 additions & 11 deletions src/utils.jl
@@ -39,7 +39,7 @@ epseltype(x) = eps(float(eltype(x)))
Create an instance of the RNG most appropriate for `x`.
The current defaults are:
- `x isa AbstractArray`
-  - Julia version is < 1.7: `Random.GLOBAL_RNG`
+  - Julia version is < 1.7: `rng_from_array()`
  - Julia version is >= 1.7: `Random.default_rng()`
- `x isa CuArray`: `CUDA.default_rng()`
When `x` is unspecified, it is assumed to be an `AbstractArray`.
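As a minimal sketch (not the PR's actual definition), the dispatch rules above could be written like this; the `CuArray` method is assumed to live in CUDA-aware code:

```julia
using Random

# Hypothetical sketch of the rules listed above, not the actual Flux source.
rng_from_array(::AbstractArray) =
  VERSION >= v"1.7" ? Random.default_rng() : Random.GLOBAL_RNG

# An unspecified `x` is treated as an `AbstractArray`.
rng_from_array() = rng_from_array(Float32[])

# With CUDA.jl loaded one would add: rng_from_array(::CuArray) = CUDA.default_rng()
```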
@@ -81,7 +81,7 @@ julia> Flux.glorot_uniform(2, 3)
[1] Glorot, Xavier, and Yoshua Bengio. "Understanding the difficulty of training deep feedforward neural networks." _Proceedings of the thirteenth international conference on artificial intelligence and statistics_. 2010.
"""
glorot_uniform(rng::AbstractRNG, dims...) = (rand(rng, Float32, dims...) .- 0.5f0) .* sqrt(24.0f0 / sum(nfan(dims...)))
-glorot_uniform(dims...) = glorot_uniform(Random.GLOBAL_RNG, dims...)
+glorot_uniform(dims...) = glorot_uniform(rng_from_array(), dims...)
glorot_uniform(rng::AbstractRNG) = (dims...) -> glorot_uniform(rng, dims...)

"""
@@ -114,7 +114,7 @@ julia> Flux.glorot_normal(3, 2)
[1] Glorot, Xavier, and Yoshua Bengio. "Understanding the difficulty of training deep feedforward neural networks." _Proceedings of the thirteenth international conference on artificial intelligence and statistics_. 2010.
"""
glorot_normal(rng::AbstractRNG, dims...) = randn(rng, Float32, dims...) .* sqrt(2.0f0 / sum(nfan(dims...)))
-glorot_normal(dims...) = glorot_normal(Random.GLOBAL_RNG, dims...)
+glorot_normal(dims...) = glorot_normal(rng_from_array(), dims...)
glorot_normal(rng::AbstractRNG) = (dims...) -> glorot_normal(rng, dims...)

"""
@@ -151,7 +151,7 @@ function kaiming_uniform(rng::AbstractRNG, dims...; gain = √2)
return (rand(rng, Float32, dims...) .- 0.5f0) .* 2bound
end

-kaiming_uniform(dims...; kwargs...) = kaiming_uniform(Random.GLOBAL_RNG, dims...; kwargs...)
+kaiming_uniform(dims...; kwargs...) = kaiming_uniform(rng_from_array(), dims...; kwargs...)
kaiming_uniform(rng::AbstractRNG; init_kwargs...) = (dims...; kwargs...) -> kaiming_uniform(rng, dims...; init_kwargs..., kwargs...)

"""
@@ -188,9 +188,50 @@ function kaiming_normal(rng::AbstractRNG, dims...; gain = √2f0)
return randn(rng, Float32, dims...) .* std
end

-kaiming_normal(dims...; kwargs...) = kaiming_normal(Random.GLOBAL_RNG, dims...; kwargs...)
+kaiming_normal(dims...; kwargs...) = kaiming_normal(rng_from_array(), dims...; kwargs...)
kaiming_normal(rng::AbstractRNG; init_kwargs...) = (dims...; kwargs...) -> kaiming_normal(rng, dims...; init_kwargs..., kwargs...)
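`kaiming_normal` likewise scales a standard normal by `gain / √fan_in` (He et al.), so with the default gain the sample standard deviation should be near `√2/√fan_in`. Another sketch along the same lines, not part of the diff:

```julia
using Flux, Statistics

fan_out, fan_in = 1000, 1000
W = Flux.kaiming_normal(fan_out, fan_in)
isapprox(std(W), √2 / √fan_in; rtol = 0.05)  # default gain = √2f0
```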

"""
truncated_normal([rng=GLOBAL_RNG], dims...; mean = 0, std = 1, lo = -2, hi = 2)

Return an `Array` of size `dims` where each element is drawn from a truncated normal distribution.
The values are generated by inverse transform sampling: uniform samples are drawn over the
truncated range of the normal CDF and then mapped back through the inverse CDF. This method
works best when `lo ≤ mean ≤ hi`.

# Examples
```jldoctest; setup = :(using Random; Random.seed!(0))
julia> Flux.truncated_normal(3, 2)
3×2 Matrix{Float32}:
-0.0340547 -1.35207
-0.22757 -0.793773
-1.75771 1.01801
```

# References
[1] Burkardt, John. "The Truncated Normal Distribution"
[PDF](https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf).
Department of Scientific Computing website.
"""
function truncated_normal(rng::AbstractRNG, dims...; mean = 0, std = 1, lo = -2, hi = 2)
  norm_cdf(x) = 0.5 * (1 + erf(x / √2))
  if (mean < lo - 2 * std) || (mean > hi + 2 * std)
    @warn "Mean is more than 2 std outside the limits [lo, hi] in truncated_normal, so the distribution of values may be inaccurate." maxlog=1
  end
  l = norm_cdf((lo - mean) / std)
  u = norm_cdf((hi - mean) / std)
  xs = rand(rng, Float32, dims...)
  broadcast!(xs, xs) do x
    x = x * 2(u - l) + (2l - 1)               # rescale uniform sample to [2l-1, 2u-1]
    x = erfinv(x)                             # invert the (scaled) normal CDF
    x = clamp(x * std * √2f0 + mean, lo, hi)  # `clamp` guards erfinv(±1) = ±Inf
  end
  return xs
end

truncated_normal(dims...; kwargs...) = truncated_normal(rng_from_array(), dims...; kwargs...)
truncated_normal(rng::AbstractRNG; init_kwargs...) = (dims...; kwargs...) -> truncated_normal(rng, dims...; init_kwargs..., kwargs...)
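To unpack the broadcast body above: drawing a uniform sample and pushing it through the normal quantile function restricted to `[l, u]` is inverse transform sampling, so no rejection loop is needed. The same construction unrolled for a single draw (an illustrative sketch; the variable names are mine):

```julia
using SpecialFunctions, Random

norm_cdf(x) = 0.5 * (1 + erf(x / √2))
mean, std, lo, hi = 0.0, 1.0, -2.0, 2.0
l = norm_cdf((lo - mean) / std)   # CDF mass to the left of `lo`
u = norm_cdf((hi - mean) / std)   # CDF mass to the left of `hi`

p = rand()                              # uniform sample on [0, 1)
t = p * 2(u - l) + (2l - 1)             # rescaled to [2l - 1, 2u - 1] ⊂ [-1, 1]
z = erfinv(t)                           # Φ⁻¹(q) = √2 * erfinv(2q - 1)
x = clamp(z * std * √2 + mean, lo, hi)  # one truncated-normal draw
```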

"""
orthogonal([rng=GLOBAL_RNG], dims...; gain = 1)

@@ -232,6 +273,7 @@ true
* sparse initialization: [`sparse_init`](@ref Flux.sparse_init)

# References

[1] Saxe, McClelland, Ganguli. "Exact solutions to the nonlinear dynamics of learning in deep linear neural networks", ICLR 2014, https://arxiv.org/abs/1312.6120

"""
@@ -254,7 +296,7 @@ function orthogonal(rng::AbstractRNG, d1::Integer, ds::Integer...; kwargs...)
return reshape(orthogonal(rng, rows, cols; kwargs...), dims)
end

-orthogonal(dims::Integer...; kwargs...) = orthogonal(Random.GLOBAL_RNG, dims...; kwargs...)
+orthogonal(dims::Integer...; kwargs...) = orthogonal(rng_from_array(), dims...; kwargs...)
orthogonal(rng::AbstractRNG; init_kwargs...) = (dims::Integer...; kwargs...) -> orthogonal(rng, dims...; init_kwargs..., kwargs...)

"""
@@ -298,7 +340,7 @@ function sparse_init(rng::AbstractRNG, dims...; sparsity, std = 0.01)
return mapslices(shuffle, sparse_array, dims=1)
end

-sparse_init(dims...; kwargs...) = sparse_init(Random.GLOBAL_RNG, dims...; kwargs...)
+sparse_init(dims...; kwargs...) = sparse_init(rng_from_array(), dims...; kwargs...)
sparse_init(rng::AbstractRNG; init_kwargs...) = (dims...; kwargs...) -> sparse_init(rng, dims...; init_kwargs..., kwargs...)

"""
@@ -382,7 +424,7 @@ function identity_init(dims...; gain=1, shift=0)
end

identity_init(::AbstractRNG, dims...; kwargs...) = identity_init(dims...; kwargs...)
-identity_init(; init_kwargs...) = identity_init(Random.GLOBAL_RNG; init_kwargs...)
+identity_init(; init_kwargs...) = identity_init(rng_from_array(); init_kwargs...)
identity_init(rng::AbstractRNG; init_kwargs...) = (args...;kwargs...) -> identity_init(rng, args...; init_kwargs..., kwargs...)

ones32(dims...) = Base.ones(Float32, dims...)
@@ -437,8 +479,8 @@ end

Flatten a model's parameters into a single weight vector.

-julia> m = Chain(Dense(10, 5, σ), Dense(5, 2), softmax)
-Chain(Dense(10, 5, σ), Dense(5, 2), softmax)
+julia> m = Chain(Dense(10, 5, std), Dense(5, 2), softmax)
+Chain(Dense(10, 5, std), Dense(5, 2), softmax)

julia> θ, re = destructure(m);

@@ -451,7 +493,7 @@ The second return value `re` allows you to reconstruct the original network after
modifications to the weight vector (for example, with a hypernetwork).

julia> re(θ .* 2)
-Chain(Dense(10, 5, σ), Dense(5, 2), softmax)
+Chain(Dense(10, 5, std), Dense(5, 2), softmax)
"""
function destructure(m)
xs = Zygote.Buffer([])
21 changes: 19 additions & 2 deletions test/utils.jl
@@ -1,9 +1,10 @@
using Flux
using Flux: throttle, nfan, glorot_uniform, glorot_normal,
-    kaiming_normal, kaiming_uniform, orthogonal,
+    kaiming_normal, kaiming_uniform, orthogonal, truncated_normal,
     sparse_init, stack, unstack, Zeros, batch, unbatch,
-    unsqueeze
+    unsqueeze, params
+using StatsBase: var, std
using Statistics, LinearAlgebra
using Random
using Test

@@ -146,6 +147,22 @@ end
end
end

@testset "truncated_normal" begin
size = (100, 100, 100)
for (μ, σ, lo, hi) in [(0., 1, -2, 2), (0, 1, -4., 4)]
v = truncated_normal(size; mean = μ, std = σ, lo, hi)
@test isapprox(mean(v), μ; atol = 1f-2)
@test isapprox(minimum(v), lo; atol = 1f-2)
@test isapprox(maximum(v), hi; atol = 1f-2)
@test eltype(v) == Float32
end
for (μ, σ, lo, hi) in [(6, 2, -100., 100), (7., 10, -100, 100)]
v = truncated_normal(size...; mean = μ, std = σ, lo, hi)
@test isapprox(std(v), σ; atol = 1f-2)
@test eltype(v) == Float32
end
end
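One note on the tolerances: truncating a unit normal at `±2σ` shrinks its standard deviation to about `0.88σ`, which is presumably why `std(v) ≈ σ` is asserted only for the effectively untruncated `(-100, 100)` cases. A closed-form check following Burkardt's formulas (my own sketch, not part of the test suite):

```julia
using SpecialFunctions

ϕ(x) = exp(-x^2 / 2) / √(2π)   # standard normal pdf
Φ(x) = (1 + erf(x / √2)) / 2   # standard normal cdf

# Standard deviation of a standard normal truncated to [a, b].
function truncated_std(a, b)
  Z = Φ(b) - Φ(a)
  m = (ϕ(a) - ϕ(b)) / Z        # mean of the truncated distribution
  sqrt(1 + (a * ϕ(a) - b * ϕ(b)) / Z - m^2)
end

truncated_std(-2, 2)      # ≈ 0.880
truncated_std(-100, 100)  # ≈ 1.0
```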

@testset "partial_application" begin
big = 1e9
