
Commit

Update normalization implementation
avik-pal committed Jul 10, 2022
1 parent f7ba291 commit ccb9b72
Showing 7 changed files with 251 additions and 158 deletions.
3 changes: 2 additions & 1 deletion Project.toml
@@ -1,7 +1,7 @@
name = "Lux"
uuid = "b2108857-7c20-44ae-9111-449ecde12c47"
authors = ["Avik Pal <[email protected]> and contributors"]
version = "0.4.8"
version = "0.4.9"

[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
@@ -10,6 +10,7 @@ ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
JuliaFormatter = "98e50ef6-434e-11e9-1051-2b60c6c9e899"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Markdown = "d6f4376e-aef5-505a-96c1-9c027394607a"
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
59 changes: 58 additions & 1 deletion src/autodiff.jl
@@ -3,6 +3,7 @@ ChainRulesCore.@non_differentiable replicate(::Any)
ChainRulesCore.@non_differentiable update_statistics(::Any, ::Any, ::Any, ::Any, ::Any,
::Any, ::Any)
ChainRulesCore.@non_differentiable generate_dropout_mask(::Any, ::Any, ::Any, ::Any)
ChainRulesCore.@non_differentiable _get_reshape_dims(::Any, ::Any)
ChainRulesCore.@non_differentiable compute_adaptive_pooling_dims(::Any, ::Any)
ChainRulesCore.@non_differentiable glorot_normal(::Any...)
ChainRulesCore.@non_differentiable glorot_uniform(::Any...)
@@ -31,7 +32,63 @@ function ChainRulesCore.rrule(::typeof(dropout), rng::AbstractRNG, x::AbstractAr
return (y, mask, rng), dropout_pullback
end

# Activation Rrules
# Utilities

function ChainRulesCore.rrule(::typeof(_reshape_into_proper_shape), ::Nothing, y)
function _reshape_into_proper_shape_pullback(dx)
return NoTangent(), NoTangent(), NoTangent()
end
return nothing, _reshape_into_proper_shape_pullback
end

function ChainRulesCore.rrule(::typeof(_reshape_into_proper_shape), x, y)
res = _reshape_into_proper_shape(x, y)
function _reshape_into_proper_shape_pullback(dx)
return NoTangent(), reshape(dx, size(x)), NoTangent()
end
return res, _reshape_into_proper_shape_pullback
end

function ChainRulesCore.rrule(::typeof(merge), nt1::NamedTuple{F1},
nt2::NamedTuple{F2}) where {F1, F2}
y = merge(nt1, nt2)
function merge_pullback(dy)
dnt1 = NamedTuple((f1 => (f1 in F2 ? NoTangent() : getproperty(dy, f1))
for f1 in F1))
dnt2 = NamedTuple((f2 => getproperty(dy, f2) for f2 in F2))
return (NoTangent(), dnt1, dnt2)
end
return y, merge_pullback
end

function ChainRulesCore.rrule(::typeof(vec), x::AbstractMatrix)
y = vec(x)
vec_pullback(dy) = NoTangent(), reshape(dy, size(x))
return y, vec_pullback
end

function ChainRulesCore.rrule(::typeof(convert), T::DataType, x::AbstractMatrix)
y = convert(T, x)
function convert_pullback(dy)
if dy isa NoTangent || dy isa ZeroTangent
dx = dy
else
dx = convert(typeof(x), dy)
end
return NoTangent(), NoTangent(), dx
end
return y, convert_pullback
end

function ChainRulesCore.rrule(::typeof(collect), v::Vector)
y = collect(v)
function collect_pullback(dy)
return NoTangent(), dy
end
return y, collect_pullback
end

# Activation rrules
function ChainRulesCore.rrule(::typeof(applyactivation), f::cudnnValidActivationTypes,
x::CuArray{T}) where {T <: CUDNNFloat}
mode = getCUDNNActivationMode(f)
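The `merge` rrule added above sends the cotangent of every field that also appears in the second NamedTuple to that second argument, and gives the first argument `NoTangent()` for those fields. A minimal sketch of that behavior (not part of this commit; it only assumes Lux and ChainRulesCore are loaded so the method is in scope):

```julia
using Lux, ChainRulesCore

nt1 = (a=1.0, b=2.0)
nt2 = (b=3.0, c=4.0)

# Primal: `merge` keeps nt2's value for the shared field :b
y, merge_back = ChainRulesCore.rrule(merge, nt1, nt2)   # y == (a=1.0, b=3.0, c=4.0)

# Pull back a cotangent with one entry per output field
_, dnt1, dnt2 = merge_back((a=1.0, b=1.0, c=1.0))
# dnt1 == (a = 1.0, b = NoTangent())   -- :b was overwritten by nt2, so nt1 gets no gradient for it
# dnt2 == (b = 1.0, c = 1.0)
```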
25 changes: 15 additions & 10 deletions src/layers/normalize.jl
@@ -68,6 +68,10 @@ Use [`Lux.testmode`](@ref) during inference.
m = Chain(Dense(784 => 64), BatchNorm(64, relu), Dense(64 => 10), BatchNorm(10))
```
!!! warning
    Passing a batch size of 1 during training will result in NaNs.

See also [`GroupNorm`](@ref)
"""
struct BatchNorm{affine, track_stats, F1, F2, F3, N} <:
@@ -90,9 +94,13 @@ function BatchNorm(chs::Int, activation=identity; init_bias=zeros32, init_scale=
end

function initialparameters(rng::AbstractRNG, l::BatchNorm{affine}) where {affine}
return affine ? (scale=l.init_scale(rng, l.chs), bias=l.init_bias(rng, l.chs)) :
NamedTuple()
if affine
return (scale=l.init_scale(rng, l.chs), bias=l.init_bias(rng, l.chs))
else
return (scale=nothing, bias=nothing)
end
end

function initialstates(rng::AbstractRNG,
l::BatchNorm{affine, track_stats}) where {affine, track_stats}
return if track_stats
@@ -109,9 +117,6 @@ function statelength(l::BatchNorm{affine, track_stats}) where {affine, track_sta
end

function (BN::BatchNorm)(x::AbstractArray{T, N}, ps, st::NamedTuple) where {T, N}
@assert size(x, N - 1) == BN.chs
@assert !istraining(st)||size(x, N) > 1 "During `training`, `BatchNorm` can't handle Batch Size == 1"

x_normalized, xmean, xvar = normalization(x, st.running_mean, st.running_var, ps.scale,
ps.bias, BN.activation,
collect([1:(N - 2); N]), st.training,
@@ -278,8 +283,11 @@ function GroupNorm(chs::Integer, groups::Integer, activation=identity; init_bias
end

function initialparameters(rng::AbstractRNG, l::GroupNorm{affine}) where {affine}
return affine ? (scale=l.init_scale(rng, l.chs), bias=l.init_bias(rng, l.chs)) :
NamedTuple()
if affine
return (scale=l.init_scale(rng, l.chs), bias=l.init_bias(rng, l.chs))
else
return (scale=nothing, bias=nothing)
end
end

function initialstates(rng::AbstractRNG,
@@ -300,9 +308,6 @@ end

function (GN::GroupNorm)(x::AbstractArray{T, N}, ps, st::NamedTuple) where {T, N}
sz = size(x)
@assert N > 2
@assert sz[N - 1] == GN.chs

x_ = reshape(x, sz[1:(N - 2)]..., sz[N - 1] ÷ GN.groups, GN.groups, sz[N])

x_normalized, xmean, xvar = normalization(x_, st.running_mean, st.running_var, ps.scale,
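A user-visible effect of the `initialparameters` changes above: with `affine=false`, both `BatchNorm` and `GroupNorm` now return a NamedTuple with explicit `nothing` entries instead of an empty NamedTuple. A rough sketch of the difference (not part of this commit; the `affine` keyword and `Lux.setup` are the usual Lux API assumed here):

```julia
using Lux, Random

rng = Random.default_rng()

bn = BatchNorm(4; affine=false)
ps, st = Lux.setup(rng, bn)
# Before this commit: ps == NamedTuple()
# After this commit:  ps == (scale=nothing, bias=nothing)

bn_affine = BatchNorm(4)                  # affine=true by default
ps_affine, _ = Lux.setup(rng, bn_affine)
# ps_affine == (scale=ones(Float32, 4), bias=zeros(Float32, 4)) with the default initializers
```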
23 changes: 12 additions & 11 deletions src/nnlib.jl
@@ -33,8 +33,8 @@ Performs BatchNorm/GroupNorm/InstanceNorm based on input configuration
running_var::Union{Nothing, AbstractVector{T}},
scale::Union{Nothing, AbstractVector{T}},
bias::Union{Nothing, AbstractVector{T}}, activation,
reduce_dims, t::Val, momentum::T=T(0.1), epsilon::T=T(1e-5);
kwargs...) where {T, N}
reduce_dims, t::Val, momentum::T=T(0.1),
epsilon::T=T(1e-5)) where {T, N}
running_mean_reshaped = _reshape_into_proper_shape(running_mean, x)
running_var_reshaped = _reshape_into_proper_shape(running_var, x)
scale_reshaped = _reshape_into_proper_shape(scale, x)
@@ -44,15 +44,16 @@ Performs BatchNorm/GroupNorm/InstanceNorm based on input configuration
scale_reshaped,
bias_reshaped, activation,
reduce_dims, t, momentum,
epsilon; kwargs...)
epsilon)
return x_norm, _safe_vec(running_mean_), _safe_vec(running_var_)
end

@generated function normalization_forward(x::AbstractArray{T, N}, running_mean::RM,
running_var::RV, scale::S, bias::B, activation::A,
reduce_dims, ::Val{training},
momentum::T=T(0.1f0), epsilon::T=T(1.0f-5);
kwargs...) where {RM, RV, S, B, T, N, A, training}
momentum::T=T(0.1f0),
epsilon::T=T(1.0f-5)) where {RM, RV, S, B, T, N,
A, training}
calls = []
if !training
if RM == Nothing
@@ -79,16 +80,16 @@ end

expr = if S != Nothing
if A == typeof(identity)
:(result = @. scale * (x - batchmean) / sqrt(batchvar + epsilon) + bias)
:(result = scale .* (x .- batchmean) ./ sqrt.(batchvar .+ epsilon) .+ bias)
else
:(result = @. activation(scale * (x - batchmean) / sqrt(batchvar + epsilon) +
bias))
:(result = activation.(scale .* (x .- batchmean) ./
sqrt.(batchvar .+ epsilon) .+ bias))
end
else
if A == typeof(identity)
:(result = @. (x - batchmean) / sqrt(batchvar + epsilon))
:(result = (x .- batchmean) ./ sqrt.(batchvar .+ epsilon))
else
:(result = @. activation((x - batchmean) / sqrt(batchvar + epsilon)))
:(result = activation.((x .- batchmean) ./ sqrt.(batchvar .+ epsilon)))
end
end
push!(calls, expr)
@@ -115,7 +116,7 @@ end
@inline function generate_dropout_mask(rng::AbstractRNG, x, p, q; dims=:)
realfptype = float(real(eltype(x)))
y = rand!(rng, similar(x, realfptype, _dropout_shape(x, dims)))
@. y = _dropout_kernel(y, p, q)
y .= _dropout_kernel.(y, p, q)
return y
end

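For reference, the branches of the `@generated` function above (with `@.` rewritten as explicit broadcasts) all build variants of the same affine-normalization expression. A standalone sketch of that expression; the function names here are illustrative, not part of Lux:

```julia
# y = activation(scale * (x - mean) / sqrt(var + epsilon) + bias), elementwise.
function affine_normalize(x, batchmean, batchvar, scale, bias, activation; epsilon=1.0f-5)
    return activation.(scale .* (x .- batchmean) ./ sqrt.(batchvar .+ epsilon) .+ bias)
end

# The `scale === nothing` branch drops the affine part:
function plain_normalize(x, batchmean, batchvar, activation; epsilon=1.0f-5)
    return activation.((x .- batchmean) ./ sqrt.(batchvar .+ epsilon))
end
```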
13 changes: 13 additions & 0 deletions test/autodiff.jl
@@ -30,3 +30,16 @@ end

@test gs_x_1 == gs_x_2
end

@testset "_reshape_into_proper_shape" begin
x = randn(rng, Float32, 3, 2)
y = randn(rng, Float32, 2, 2, 6, 2)

@test size(Lux._reshape_into_proper_shape(x, y)) == (1, 1, 6, 1)
@inferred Lux._reshape_into_proper_shape(x, y)

gs_1 = Zygote.gradient(x -> sum(Lux._reshape_into_proper_shape(x, y)), x)[1]
gs_2 = Zygote.gradient(x -> sum(reshape(x, (1, 1, 6, 1))), x)[1]

@test gs_1 == gs_2
end
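Judging by the shapes this test exercises, `_reshape_into_proper_shape(x, y)` lays all of `x`'s elements along the channel (second-to-last) dimension of `y` so that `x` can broadcast against `y`. A plain-Julia equivalent under that assumption (illustrative only, not the internal implementation):

```julia
function reshape_for_broadcast(x::AbstractArray, y::AbstractArray{T, N}) where {T, N}
    # Singleton dims everywhere except the channel dimension (N - 1)
    dims = ntuple(i -> i == N - 1 ? length(x) : 1, N)
    return reshape(x, dims)
end

x = randn(Float32, 3, 2)             # 6 elements
y = randn(Float32, 2, 2, 6, 2)       # channel dimension of size 6
size(reshape_for_broadcast(x, y))    # (1, 1, 6, 1), matching the test above
```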
