Merge #1496
1496: Add Orthogonal initialization feature. r=DhairyaLGandhi a=SomTambe

As per issue #1431, I have added the orthogonal matrix initialization feature.

I will add the tests gradually; I am still working out what they should cover.

### PR Checklist

- [x] Tests are added
- [x] Entry in NEWS.md
- [x] Documentation, if applicable
- [ ] Final review from `@dhairyagandhi96` (for API changes).


Co-authored-by: SomTambe <[email protected]>
Co-authored-by: Som Tambe <[email protected]>
3 people authored Feb 11, 2021
2 parents 3bc42f2 + 8f2e4ed commit 4c53672
Showing 4 changed files with 84 additions and 1 deletion.
1 change: 1 addition & 0 deletions NEWS.md
@@ -2,6 +2,7 @@

## v0.12.0

* Add [Orthogonal Matrix initialization](https://github.com/FluxML/Flux.jl/pull/1496) as described in [Exact solutions to the nonlinear dynamics of learning in deep linear neural networks](https://arxiv.org/abs/1312.6120).
* Added [Focal Loss function](https://github.com/FluxML/Flux.jl/pull/1489) to Losses module
* The Dense layer now supports inputs with [multiple batch dimensions](https://github.com/FluxML/Flux.jl/pull/1405).
* Dense and Conv layers no longer perform [implicit type conversion](https://github.com/FluxML/Flux.jl/pull/1394).
1 change: 1 addition & 0 deletions docs/src/utilities.md
@@ -36,6 +36,7 @@ Flux.glorot_uniform
Flux.glorot_normal
Flux.kaiming_uniform
Flux.kaiming_normal
Flux.orthogonal
Flux.sparse_init
```

66 changes: 66 additions & 0 deletions src/utils.jl
@@ -174,6 +174,72 @@ end
kaiming_normal(dims...; kwargs...) = kaiming_normal(Random.GLOBAL_RNG, dims...; kwargs...)
kaiming_normal(rng::AbstractRNG; init_kwargs...) = (dims...; kwargs...) -> kaiming_normal(rng, dims...; init_kwargs..., kwargs...)

"""
orthogonal([rng=GLOBAL_RNG], dims...; gain = 1)
Return an `Array` of size `dims` which is a (semi) orthogonal matrix, as described in [1].
The input must have at least 2 dimensions.
For `length(dims) > 2`, a `prod(dims[1:(end - 1)])` by `dims[end]` orthogonal matrix
is computed before reshaping it to the original dimensions.
# Examples
```jldoctest; setup = :(using LinearAlgebra)
julia> W = Flux.orthogonal(5, 7);
julia> summary(W)
"5×7 Array{Float32,2}"
julia> W * W' ≈ I(5)
true
julia> W2 = Flux.orthogonal(7, 5);
julia> W2 * W2' ≈ I(7)
false
julia> W2' * W2 ≈ I(5)
true
julia> W3 = Flux.orthogonal(3, 3, 2, 4);
julia> transpose(reshape(W3, :, 4)) * reshape(W3, :, 4) ≈ I(4)
true
```
# See also
* kaiming initialization using normal distribution: [`kaiming_normal`](@ref Flux.kaiming_normal)
* kaiming initialization using uniform distribution: [`kaiming_uniform`](@ref Flux.kaiming_uniform)
* glorot initialization using normal distribution: [`glorot_normal`](@ref Flux.glorot_normal)
* glorot initialization using uniform distribution: [`glorot_uniform`](@ref Flux.glorot_uniform)
* sparse initialization: [`sparse_init`](@ref Flux.sparse_init)
# References
[1] Saxe, McClelland, Ganguli. "Exact solutions to the nonlinear dynamics of learning in deep linear neural networks", ICLR 2014, https://arxiv.org/abs/1312.6120
"""
function orthogonal(rng::AbstractRNG, rows::Integer, cols::Integer; gain = 1)
  mat = rows > cols ? randn(rng, Float32, rows, cols) : randn(rng, Float32, cols, rows)

  Q, R = LinearAlgebra.qr(mat)
  Q = Array(Q) * sign.(LinearAlgebra.Diagonal(R))
  if rows < cols
    Q = transpose(Q)
  end

  return gain * Q
end

function orthogonal(rng::AbstractRNG, d1::Integer, ds::Integer...; kwargs...)
  dims = (d1, ds...)
  rows = prod(dims[1:end-1])
  cols = dims[end]
  return reshape(orthogonal(rng, rows, cols; kwargs...), dims)
end
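The QR construction above can be checked in isolation. This is a minimal standalone sketch (not part of the diff, using only the standard library): multiplying the thin `Q` factor by the signs of `R`'s diagonal removes the sign ambiguity of the QR factorization while preserving orthonormal columns.

```julia
using LinearAlgebra, Random

rng = MersenneTwister(0)
A = randn(rng, Float32, 5, 3)        # tall case: rows > cols

# Thin QR: Q has orthonormal columns, but their signs are arbitrary.
Q, R = qr(A)
Qc = Array(Q) * sign.(Diagonal(R))   # fix column signs via the diagonal of R

# The sign flips do not disturb orthonormality of the columns.
@assert Qc' * Qc ≈ I(3)
```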

orthogonal(dims::Integer...; kwargs...) = orthogonal(Random.GLOBAL_RNG, dims...; kwargs...)
orthogonal(rng::AbstractRNG; init_kwargs...) = (dims::Integer...; kwargs...) -> orthogonal(rng, dims...; init_kwargs..., kwargs...)
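As a usage sketch (assuming Flux v0.12, where this PR landed and where `Dense(in, out, σ; init)` accepts an `init` keyword): the curried `orthogonal(rng)` method above returns an init function, so a fixed RNG can be threaded into layer construction for reproducible weights.

```julia
using Flux, LinearAlgebra, Random

# Direct call: a 4×6 semi-orthogonal matrix with orthonormal rows.
W = Flux.orthogonal(4, 6)
@assert W * W' ≈ I(4)

# Closure form: pass a seeded RNG once, reuse the resulting init function.
init = Flux.orthogonal(MersenneTwister(0))
m = Dense(6, 4, tanh, init = init)
```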

"""
sparse_init([rng=GLOBAL_RNG], dims...; sparsity, std = 0.01)
17 changes: 16 additions & 1 deletion test/utils.jl
@@ -1,5 +1,5 @@
using Flux
using Flux: throttle, nfan, glorot_uniform, glorot_normal, kaiming_normal, kaiming_uniform, sparse_init, stack, unstack, Zeros
using Flux: throttle, nfan, glorot_uniform, glorot_normal, kaiming_normal, kaiming_uniform, orthogonal, sparse_init, stack, unstack, Zeros
using StatsBase: var, std
using Random
using Test
@@ -96,6 +96,21 @@ end
end
end

@testset "orthogonal" begin
# A matrix of dim = (m,n) with m > n should produce a QR decomposition. In the other case, the transpose should be taken to compute the QR decomposition.
for (rows,cols) in [(5,3),(3,5)]
v = orthogonal(rows, cols)
rows < cols ? (@test v * v' I(rows)) : (@test v' * v I(cols))
end
for mat in [(3,4,5),(2,2,5)]
v = orthogonal(mat...)
cols = mat[end]
rows = div(prod(mat),cols)
v = reshape(v, (rows,cols))
rows < cols ? (@test v * v' I(rows)) : (@test v' * v I(cols))
end
end
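One property the tests above do not exercise is the `gain` keyword. A hedged sketch of an additional check (not part of the PR): with `gain = g`, the rows or columns have norm `g` instead of 1, so the Gram matrix is `g^2 * I`.

```julia
using LinearAlgebra
using Flux: orthogonal

# Tall case with gain = 2: columns are orthogonal with norm 2,
# so v' * v ≈ 4 * I(3) rather than I(3).
v = orthogonal(5, 3; gain = 2)
@assert v' * v ≈ 4I(3)
```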

@testset "sparse_init" begin
# sparse_init should yield an error for non 2-d dimensions
# sparse_init should yield no zero elements if sparsity < 0
