mlp.jl
"""
MLP(in_dims::Integer, hidden_dims::Dims{N}, activation=NNlib.relu; norm_layer=nothing,
dropout_rate::Real=0.0f0, dense_kwargs=(;), norm_kwargs=(;),
last_layer_activation=false) where {N}
Construct a multi-layer perceptron (MLP) with dense layers, optional normalization layers,
and dropout.
## Arguments
- `in_dims`: number of input dimensions
- `hidden_dims`: dimensions of the hidden layers
- `activation`: activation function (stacked after the normalization layer, if present
else after the dense layer)
## Keyword Arguments
- `norm_layer`: $(NORM_LAYER_DOC)
- `dropout_rate`: dropout rate (default: `0.0f0`)
- `dense_kwargs`: keyword arguments for the dense layers
- `norm_kwargs`: keyword arguments for the normalization layers
- `last_layer_activation`: set to `true` to apply the activation function to the last
layer
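
## Example

A minimal usage sketch (assuming `MLP` is in scope; the dimensions and dropout rate
below are illustrative):

```julia
using Lux, Random

mlp = MLP(3, (8, 8, 2); dropout_rate=0.1f0)
ps, st = Lux.setup(Random.default_rng(), mlp)

x = randn(Float32, 3, 16)  # 3 features, batch of 16
y, st = mlp(x, ps, st)     # `y` has size (2, 16)
```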
"""
@concrete struct MLP <: AbstractLuxWrapperLayer{:chain}
    chain <: Lux.Chain
end

function MLP(in_dims::Integer, hidden_dims::Dims{N}, activation::F=NNlib.relu;
        norm_layer::NF=nothing, dropout_rate::Real=0.0f0, last_layer_activation::Bool=false,
        dense_kwargs=(;), norm_kwargs=(;)) where {N, F, NF}
    @argcheck N > 0

    layers = Vector{AbstractLuxLayer}(undef, N)
    for (i, out_dims) in enumerate(hidden_dims)
        # The final layer gets `identity` unless `last_layer_activation` is set.
        act = i != N ? activation : (last_layer_activation ? activation : identity)
        layers[i] = dense_norm_act_dropout(i, in_dims => out_dims, act, norm_layer,
            dropout_rate, dense_kwargs, norm_kwargs)
        in_dims = out_dims
    end

    # Name the blocks `block1`, `block2`, … so they are addressable within the chain.
    inner_blocks = NamedTuple{ntuple(i -> Symbol(:block, i), N)}(layers)
    return MLP(Lux.Chain(inner_blocks))
end
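
# Illustrative sketch (not part of this file's API): `norm_layer` is called as
# `norm_layer(i, out_dims, activation; norm_kwargs...)` for each block index `i`
# and must return a Lux layer, e.g.
#
#     bn = (i, dims, act; kwargs...) -> Lux.BatchNorm(dims, act; kwargs...)
#     MLP(4, (16, 16, 1); norm_layer=bn)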

@concrete struct DenseNormActDropoutBlock <: AbstractLuxWrapperLayer{:block}
    block
end

function dense_norm_act_dropout(
        i::Integer, (in_dims, out_dims)::Pair{<:Integer, <:Integer}, activation::F,
        norm_layer::NF, dropout_rate::Real, dense_kwargs, norm_kwargs) where {F, NF}
    if iszero(dropout_rate)
        # Without dropout: either a fused dense + activation layer, or a dense layer
        # followed by a normalization layer that carries the activation.
        if norm_layer === nothing
            return DenseNormActDropoutBlock(Lux.Chain(;
                dense=Lux.Dense(in_dims => out_dims, activation; dense_kwargs...)))
        end
        return DenseNormActDropoutBlock(Lux.Chain(;
            dense=Lux.Dense(in_dims => out_dims; dense_kwargs...),
            norm=norm_layer(i, out_dims, activation; norm_kwargs...)))
    end
    # With dropout: the same two layouts, each followed by a dropout layer.
    if norm_layer === nothing
        return DenseNormActDropoutBlock(Lux.Chain(;
            dense=Lux.Dense(in_dims => out_dims, activation; dense_kwargs...),
            dropout=Lux.Dropout(dropout_rate)))
    end
    return DenseNormActDropoutBlock(Lux.Chain(;
        dense=Lux.Dense(in_dims => out_dims; dense_kwargs...),
        norm=norm_layer(i, out_dims, activation; norm_kwargs...),
        dropout=Lux.Dropout(dropout_rate)))
end
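
# Summary of the per-block layouts produced above (comment only):
#   dense+activation                     (no norm, no dropout)
#   dense → norm(activation)             (norm, no dropout)
#   dense+activation → dropout           (no norm, dropout)
#   dense → norm(activation) → dropout   (norm and dropout)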