Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add kaiming initialization #138

Merged
merged 5 commits into from
Nov 15, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions docs/src/api/utilities.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ Lux.gpu
```@docs
Lux.glorot_normal
Lux.glorot_uniform
Lux.kaiming_normal
Lux.kaiming_uniform
Lux.ones32
Lux.zeros32
```
Expand Down
2 changes: 2 additions & 0 deletions src/autodiff.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ ChainRulesCore.@non_differentiable _get_reshape_dims(::Any, ::Any)
# Mark internal utilities and weight initializers as non-differentiable so AD
# backends (via ChainRules) treat their outputs as constants and do not attempt
# to propagate gradients through them.
ChainRulesCore.@non_differentiable compute_adaptive_pooling_dims(::Any, ::Any)
ChainRulesCore.@non_differentiable glorot_normal(::Any...)
ChainRulesCore.@non_differentiable glorot_uniform(::Any...)
ChainRulesCore.@non_differentiable kaiming_normal(::Any...)
ChainRulesCore.@non_differentiable kaiming_uniform(::Any...)
ChainRulesCore.@non_differentiable check_use_cuda()
ChainRulesCore.@non_differentiable istraining(::Any)
ChainRulesCore.@non_differentiable _get_norm_except_dims(::Any, ::Any)
Expand Down
34 changes: 34 additions & 0 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,40 @@ function glorot_normal(rng::AbstractRNG, dims::Integer...; gain::Real=1)
return randn(rng, Float32, dims...) .* std
end

"""
    kaiming_uniform(rng::AbstractRNG, size...; gain = √2f0)

Return an `Array{Float32}` of the given `size`, with entries drawn from a uniform
distribution over `[-limit, limit]` where `limit = gain * sqrt(3 / fan_in)` (He
initialization).

# References

[1] He, Kaiming, et al. "Delving deep into rectifiers: Surpassing human-level performance on
imagenet classification." _Proceedings of the IEEE international conference on computer
vision_. 2015.
"""
function kaiming_uniform(rng::AbstractRNG, dims::Integer...; gain::Real=√2.0f0)
    fan_in = first(_nfan(dims...))
    limit = Float32(√3.0f0 * gain / sqrt(fan_in))
    # Shift U(0, 1) samples to be symmetric about zero, then scale to U(-limit, limit).
    return (rand(rng, Float32, dims...) .- 0.5f0) .* 2limit
end

"""
    kaiming_normal(rng::AbstractRNG, size...; gain = √2f0)

Return an `Array{Float32}` of the given `size`, with entries sampled from a normal
distribution whose standard deviation is `gain / sqrt(fan_in)` (He initialization).

# References

[1] He, Kaiming, et al. "Delving deep into rectifiers: Surpassing human-level performance on
imagenet classification." _Proceedings of the IEEE international conference on computer
vision_. 2015.
"""
function kaiming_normal(rng::AbstractRNG, dims::Integer...; gain::Real=√2.0f0)
    fan_in = first(_nfan(dims...))
    # Scale standard-normal samples to the target standard deviation.
    stddev = Float32(gain / sqrt(fan_in))
    return randn(rng, Float32, dims...) .* stddev
end

# PRNG Handling
"""
replicate(rng::AbstractRNG)
Expand Down
18 changes: 18 additions & 0 deletions test/utils.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using Lux, ComponentArrays, CUDA, Functors, ReverseDiff, Random, Optimisers, Zygote, Test
using Statistics: std

include("test_utils.jl")

Expand Down Expand Up @@ -29,6 +30,23 @@ end
end
end

@testset "kaiming" begin
    # For these dims the bound/stddev formulas reduce to sqrt(6 / n_out) and
    # sqrt(2 / n_out): kaiming_uniform should fill [-bound, bound] nearly to the
    # edges, and kaiming_normal should have sample stddev close to the target.
    for (n_in, n_out) in [(100, 100), (100, 400)]
        W = Lux.kaiming_uniform(rng, n_in, n_out)
        bound = sqrt(6 / n_out)
        @test -bound < minimum(W) < -0.9bound
        @test 0.9bound < maximum(W) < bound

        W = Lux.kaiming_normal(rng, n_in, n_out)
        target_std = sqrt(2 / n_out)
        @test 0.9target_std < std(W) < 1.1target_std
    end
    # Custom gains must not change the element type.
    @test eltype(Lux.kaiming_uniform(rng, 3, 4; gain=1.5)) == Float32
    @test eltype(Lux.kaiming_normal(rng, 3, 4; gain=1.5)) == Float32
end

@testset "istraining" begin
@test Lux.istraining(Val(true))
@test !Lux.istraining(Val(false))
Expand Down