Improve docs for initialisation (#1912)
* add some methods

* init documentation

* fixup

* note for identity_init

* keyword bug

* test that Float64 keywords don't promote

* branch once in orthogonal

* can't use ;;; on 1.6
mcabbott authored Mar 27, 2022
1 parent dc6f286 commit 57beb23
Showing 4 changed files with 239 additions and 184 deletions.
41 changes: 31 additions & 10 deletions docs/src/utilities.md
@@ -3,37 +3,58 @@
Flux provides utility functions which can be used to initialise your layers
or to regularly execute callback functions.

## Layer Initialisation

These are primarily useful if you are planning to write your own layers.
Flux initialises convolutional layers and recurrent cells with `glorot_uniform` by default.
Most layers accept a function as an `init` keyword, which replaces this default. For example:

```jldoctest; setup = :(using Flux)
julia> conv = Conv((3, 3), 3 => 2, relu; init=Flux.glorot_normal)
Conv((3, 3), 3 => 2, relu)  # 56 parameters

julia> conv.bias
2-element Vector{Float32}:
 0.0
 0.0
```

Note that `init` creates the weight array, but not the bias vector.
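
For instance, the effect is easy to see with an all-ones initialiser (a quick sketch, using `Flux.ones32` purely to make the weights recognisable):

```julia
using Flux

d = Dense(2 => 3; init=Flux.ones32)

d.weight  # 3×2 Matrix{Float32} of ones       -- created by `init`
d.bias    # 3-element Vector{Float32} of zeros -- Flux's default bias, untouched by `init`
```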

Many of the initialisation functions accept keyword arguments such as `gain`,
and some also take a random number generator as their first argument. To make these
easy to pass to layers, there are methods which return a function:

```jldoctest; setup = :(using Flux, Random)
julia> Dense(4 => 5, tanh; init=Flux.glorot_uniform(gain=2))
Dense(4 => 5, tanh)  # 25 parameters

julia> Dense(4 => 5, tanh; init=Flux.randn32(MersenneTwister(1)))
Dense(4 => 5, tanh)  # 25 parameters
```
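
In other words, each initialiser can be called in two ways: with dimensions, to produce an array immediately, or with only keywords (or a random number generator), to produce a function for later use. A rough sketch:

```julia
using Flux

Flux.glorot_uniform(3, 4)           # called with dims: returns a 3×4 Matrix{Float32}

init = Flux.glorot_uniform(gain=2)  # called with keywords only: returns a function
init(3, 4)                          # ...which produces the 3×4 Matrix{Float32} later
```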

```@docs
Flux.glorot_uniform
Flux.glorot_normal
Flux.kaiming_uniform
Flux.kaiming_normal
Flux.truncated_normal
Flux.orthogonal
Flux.sparse_init
Flux.identity_init
Flux.ones32
Flux.rand32
Flux.randn32
```

## Changing the type of model parameters

The default `eltype` for models is `Float32`, since models are often trained and run on GPUs.
The `eltype` of a model `m` can be changed to `Float64` by `f64(m)`, or back to `Float32` by `f32(m)`:

```@docs
Flux.f64
Flux.f32
```
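
For example (a small sketch; the exact parameter values are random):

```julia
using Flux

m = Chain(Dense(2 => 3, relu), Dense(3 => 1))
eltype(m[1].weight)         # Float32, the default

m64 = f64(m)                # convert every parameter array to Float64
eltype(m64[1].weight)       # Float64

eltype(f32(m64)[1].weight)  # Float32 again
```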

## Model Building

Flux provides some utility functions to help you generate models in an automated fashion.
6 changes: 4 additions & 2 deletions src/functor.jl
@@ -213,14 +213,16 @@ paramtype(T::Type{<:Real}, m) = fmap(x -> adapt(T, x), m)
"""
f32(m)
Convert the `eltype` of model's parameters to `Float32`.
Converts the `eltype` of model's parameters to `Float32` (which is Flux's default).
Recurses into structs marked with [`@functor`](@ref).
"""
f32(m) = paramtype(Float32, m)

"""
f64(m)
Convert the `eltype` of model's parameters to `Float64`.
Converts the `eltype` of model's parameters to `Float64`.
Recurses into structs marked with [`@functor`](@ref).
"""
f64(m) = paramtype(Float64, m)
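
Both helpers delegate to `paramtype`, visible in the hunk context above, which walks the model with `fmap` and converts each array via `adapt`. A minimal sketch of the same mechanism (an illustrative re-implementation, not part of Flux's exported API):

```julia
using Flux, Functors, Adapt

# Recursively convert every array in a model to eltype T,
# mirroring Flux's internal `paramtype`.
my_paramtype(T::Type{<:Real}, m) = fmap(x -> adapt(T, x), m)

m = Dense(2 => 3)
eltype(my_paramtype(Float64, m).weight)  # Float64
```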

