
Commit: Allow fusing activation into normalization
avik-pal committed Apr 19, 2024
1 parent eaebfbe commit 8271ed8
Showing 11 changed files with 103 additions and 28 deletions.
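For orientation, a minimal sketch of the call shape this commit enables — the activation is passed positionally and fused into the normalization instead of being applied as a separate broadcast afterwards. Illustrative only, not part of the diff; it assumes `relu` from NNlib and the keyword arguments documented in `src/api/batchnorm.jl` below.

using LuxLib, NNlib

x = randn(Float32, 4, 4, 8, 2)                     # (W, H, C, N)
scale, bias = ones(Float32, 8), zeros(Float32, 8)  # per-channel affine parameters
rmean, rvar = zeros(Float32, 8), ones(Float32, 8)  # running statistics

# `relu` is fused into the batchnorm computation rather than applied afterwards.
y, stats = batchnorm(x, scale, bias, rmean, rvar, relu;
    momentum=0.1f0, training=Val(true), epsilon=1f-5)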
2 changes: 1 addition & 1 deletion Project.toml
@@ -1,7 +1,7 @@
 name = "LuxLib"
 uuid = "82251201-b29d-42c6-8e01-566dec8acb11"
 authors = ["Avik Pal <[email protected]> and contributors"]
-version = "0.3.13"
+version = "0.3.14"
 
 [deps]
 ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
6 changes: 3 additions & 3 deletions ext/LuxLibTrackercuDNNExt.jl
@@ -17,12 +17,12 @@ const TR_BNParamType = Union{
 
 function LuxLib.batchnorm(
         x::TR_CUDNN_BN_ARRAY_TYPE, scale::TR_BNParamType, bias::TR_BNParamType,
-        running_mean::TR_BNParamType, running_var::TR_BNParamType;
-        momentum::Real, training::Val, epsilon::Real)
+        running_mean::TR_BNParamType, running_var::TR_BNParamType,
+        σ::F=identity; momentum::Real, training::Val, epsilon::Real) where {F}
     rm, rv = LuxLib._get_batchnorm_statistics(x, running_mean, running_var, training)
     # NOTE: The following returns a tracked tuple so we can't do `first` on it
     x_ = LuxLib.batchnorm_cudnn(rm, rv, scale, bias, x, momentum, epsilon, training)[1]
-    return x_, (; running_mean=rm, running_var=rv)
+    return LuxLib.fast_activation!!(σ, x_), (; running_mean=rm, running_var=rv)
 end
 
 for RM in (:TrackedVector, :Nothing, :AbstractVector),
6 changes: 3 additions & 3 deletions ext/LuxLibcuDNNExt/LuxLibcuDNNExt.jl
@@ -20,11 +20,11 @@ const CUDNN_BN_ARRAY_TYPE = Union{
 const BNParamType = Union{Nothing, CuVector{<:Union{Float32, Float64}}}
 
 function LuxLib.batchnorm(x::CUDNN_BN_ARRAY_TYPE, scale::BNParamType, bias::BNParamType,
-        running_mean::BNParamType, running_var::BNParamType;
-        momentum::Real, training::Val, epsilon::Real)
+        running_mean::BNParamType, running_var::BNParamType, σ::F=identity;
+        momentum::Real, training::Val, epsilon::Real) where F
     rm, rv = LuxLib._get_batchnorm_statistics(x, running_mean, running_var, training)
     x_ = first(LuxLib.batchnorm_cudnn(rm, rv, scale, bias, x, momentum, epsilon, training))
-    return x_, (; running_mean=rm, running_var=rv)
+    return LuxLib.fast_activation!!(σ, x_), (; running_mean=rm, running_var=rv)
 end
 
 @inline function LuxLib.batchnorm_cudnn(
3 changes: 3 additions & 0 deletions src/LuxLib.jl
@@ -29,6 +29,7 @@ include("impl/groupnorm.jl")
 include("impl/normalization.jl")
 include("impl/fused_dense.jl")
 include("impl/fused_conv.jl")
+include("impl/fast_activation.jl")
 
 # User Facing
 include("api/batchnorm.jl")
@@ -38,8 +39,10 @@ include("api/instancenorm.jl")
 include("api/layernorm.jl")
 include("api/dense.jl")
 include("api/conv.jl")
+include("api/fast_activation.jl")
 
 export batchnorm, groupnorm, instancenorm, layernorm, alpha_dropout, dropout
 export fused_dense_bias_activation, fused_conv_bias_activation
+export fast_activation!!
 
 end
10 changes: 6 additions & 4 deletions src/api/batchnorm.jl
@@ -1,5 +1,6 @@
 @doc doc"""
-    batchnorm(x, scale, bias, running_mean, running_var; momentum, epsilon, training)
+    batchnorm(x, scale, bias, running_mean, running_var, σ=identity; momentum, epsilon,
+        training)
 
 Batch Normalization. For details see [1].
 
@@ -14,6 +15,7 @@ accordingly.
   - `bias`: Bias factor (``\beta``) (can be `nothing`)
   - `running_mean`: Running mean (can be `nothing`)
   - `running_var`: Running variance (can be `nothing`)
+  - `σ`: Activation function (default: `identity`)
 
 ## Keyword Arguments
@@ -41,11 +43,11 @@ fallback is used which is not highly optimized.
 function batchnorm(x::AbstractArray{<:Real, N}, scale::Union{Nothing, <:AbstractVector},
         bias::Union{Nothing, <:AbstractVector},
         running_mean::Union{Nothing, <:AbstractVector},
-        running_var::Union{Nothing, <:AbstractVector};
-        momentum::Real, training::Val, epsilon::Real) where {N}
+        running_var::Union{Nothing, <:AbstractVector}, σ::F=identity;
+        momentum::Real, training::Val, epsilon::Real) where {F, N}
     x_, xm, xv = _normalization(x, _drop_forwarddiff_partials(running_mean),
         _drop_forwarddiff_partials(running_var), scale, bias,
-        _get_batchnorm_reduce_dims(x), training, momentum, epsilon)
+        _get_batchnorm_reduce_dims(x), training, momentum, epsilon, σ)
     stats = (; running_mean=_drop_forwarddiff_partials(xm),
         running_var=_drop_forwarddiff_partials(xv))
     return (x_, stats)
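Because `σ` defaults to `identity`, the fused call is semantically just a post-applied activation; a hedged equivalence check (illustrative sketch, not part of the commit, using NNlib's `tanh_fast`):

using LuxLib, NNlib

x = randn(Float32, 3, 16)                          # (features, batch)
scale, bias = ones(Float32, 3), zeros(Float32, 3)
rmean, rvar = zeros(Float32, 3), ones(Float32, 3)
kw = (; momentum=0.1f0, training=Val(false), epsilon=1f-5)

y_fused, _ = batchnorm(x, scale, bias, rmean, rvar, tanh_fast; kw...)
y_ref, _   = batchnorm(x, scale, bias, rmean, rvar; kw...)
y_fused ≈ tanh_fast.(y_ref)                        # expected to hold up to floating point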
26 changes: 26 additions & 0 deletions src/api/fast_activation.jl
@@ -0,0 +1,26 @@
+"""
+    fast_activation!!(σ::F, x) where {F}
+
+Compute `σ.(x)` with the best possible implementation available. If it is possible to
+rewrite `x` in-place, it does so. If `x` is an immutable array, it falls back to the
+generic implementation.
+
+!!! note
+
+    This function doesn't replace `σ` with `NNlib.fast_act(σ, ...)`, that needs to be
+    done by the user if needed.
+
+## Arguments
+
+  - `σ`: Activation function
+  - `x`: Input array
+
+## Returns
+
+  - Output Array with the same size as `x`
+"""
+@inline function fast_activation!!(σ::F, x::AbstractArray) where {F}
+    σ === identity && return x
+    ArrayInterface.can_setindex(x) && return __fast_activation_impl!(σ, x)
+    return σ.(x)
+end
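A short usage sketch of the new public helper (illustrative only): it mutates `x` when the container supports `setindex!` and falls back to an out-of-place broadcast otherwise.

using LuxLib

x = rand(Float32, 4)
y = fast_activation!!(abs2, x)       # Vector supports setindex!, so x is updated in place
y === x                              # true on the mutable path

fast_activation!!(abs2, 1:4)         # ranges are immutable: falls back to abs2.(1:4)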
3 changes: 1 addition & 2 deletions src/api/layernorm.jl
@@ -37,6 +37,5 @@ end
 
 function layernorm(x::AbstractArray, ::Nothing, ::Nothing; dims, epsilon)
     _mean = mean(x; dims)
-    rstd = 1 ./ (std(x; dims, mean=_mean, corrected=false) .+ epsilon)
-    return (x .- _mean) .* rstd
+    return (x .- _mean) ./ (std(x; dims, mean=_mean, corrected=false) .+ epsilon)
 end
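The rewritten `layernorm` fallback is algebraically the same as the old two-step version (it only drops the intermediate reciprocal); a quick illustrative check of that equivalence:

using Statistics

x, dims, epsilon = randn(Float32, 6, 4), 1, 1f-5
_mean = mean(x; dims)
rstd = 1 ./ (std(x; dims, mean=_mean, corrected=false) .+ epsilon)      # old formulation
y_old = (x .- _mean) .* rstd
y_new = (x .- _mean) ./ (std(x; dims, mean=_mean, corrected=false) .+ epsilon)
y_old ≈ y_new                                                           # true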
44 changes: 44 additions & 0 deletions src/impl/fast_activation.jl
@@ -0,0 +1,44 @@
+# Specialized Implementation based off NNlib._fast_broadcast with added logic from
+# ArrayInterface
+# If we enter here, we already know that we can setindex into the array
+@inline function __fast_activation_impl!(σ::F, x::AbstractArray) where {F}
+    if ArrayInterface.fast_scalar_indexing(x)
+        bc = Broadcast.instantiate(Broadcast.broadcasted(σ, x))
+        @simd ivdep for I in eachindex(bc)
+            @inbounds x[I] = bc[I]
+        end
+    else
+        @. x = σ(x)
+    end
+    return x
+end
+
+function CRC.rrule(cfg::CRC.RuleConfig{>:CRC.HasReverseMode},
+        ::typeof(__fast_activation_impl!), σ::F, x::AbstractArray{T}) where {F, T}
+    σ === identity && return x, @closure(Δ -> (CRC.NoTangent(), CRC.NoTangent(), Δ))
+
+    if isconcretetype(Core.Compiler._return_type(only_derivative, Tuple{T, F, NotaNumber}))
+        __fast_activation_impl!(σ, x)
+        ∇__fast_activation_impl_no_cached = @closure Δ -> begin
+            ∂x = only_derivative.(x, σ, NotaNumber()) .* CRC.unthunk(Δ)
+            return CRC.NoTangent(), CRC.NoTangent(), ∂x
+        end
+        return x, ∇__fast_activation_impl_no_cached
+    end
+
+    if isconcretetype(Core.Compiler._return_type(only_derivative, Tuple{T, F, T}))
+        y = @. σ(x)
+        ∇__fast_activation_impl_cached_crc = @closure Δ -> begin
+            ∂z = only_derivative.(y, σ, x) .* CRC.unthunk(Δ)
+            return CRC.NoTangent(), CRC.NoTangent(), ∂z
+        end
+        return y, ∇__fast_activation_impl_cached_crc
+    end
+
+    y, pb_f = CRC.rrule_via_ad(cfg, broadcast, σ, x)
+    ∇__fast_activation_impl_cached = @closure Δ -> begin
+        _, _, ∂x = pb_f(Δ)
+        return CRC.NoTangent(), CRC.NoTangent(), ∂x
+    end
+    return y, ∇__fast_activation_impl_cached
+end
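The `rrule` above is what lets reverse-mode AD differentiate through the fused (and potentially in-place) activation; an end-to-end sketch with Zygote (illustrative, assumes Zygote is available alongside LuxLib and NNlib):

using LuxLib, NNlib, Zygote

x = randn(Float32, 4)
expected = Float32.(x .> 0)          # relu' at the original inputs (x may be mutated below)

g = only(Zygote.gradient(x -> sum(fast_activation!!(relu, x)), x))
g ≈ expected                         # expected to hold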
2 changes: 1 addition & 1 deletion src/impl/fused_conv.jl
@@ -5,7 +5,7 @@
 end
 
 # This implementation is different from `conv_bias_act` in that it defines the proper rrules
-# and fuses operations into a single kernel if it is possible. Unfortinately there are
+# and fuses operations into a single kernel if it is possible. Unfortunately there are
 # certain configurations where CUDNN allows caching intermediates, but we don't do that rn.
 
 @inline function __fused_conv_bias_activation_impl(
25 changes: 13 additions & 12 deletions src/impl/normalization.jl
@@ -41,37 +41,38 @@ end
     return Expr(:block, calls...)
 end
 
-@generated function _affine_normalize(x::AbstractArray, xmean::ST, xvar::ST,
-        scale::A, bias::A, epsilon::Real) where {ST, A}
+@generated function _affine_normalize(act::F, x::AbstractArray, xmean::ST, xvar::ST,
+        scale::A, bias::A, epsilon::Real) where {F, ST, A}
     if A != Nothing
         return quote
-            x_norm = (x .- xmean) ./ sqrt.(xvar .+ epsilon)
-            return scale .* x_norm .+ bias
+            x_norm = @.((x - xmean)/sqrt(xvar + epsilon))
+            return @. act(scale * x_norm + bias)
         end
     else
-        return :(return (x .- xmean) ./ sqrt.(xvar .+ epsilon))
+        return :(return @. act((x - xmean) / sqrt(xvar + epsilon)))
    end
 end
 
-function _normalization_impl(x::AbstractArray, running_mean::R, running_var::R,
-        scale::A, bias::A, r::Val{reduce_dims}, training::Val,
-        momentum::Union{Real, Nothing}, epsilon::Real) where {R, A, reduce_dims}
+function _normalization_impl(
+        x::AbstractArray, running_mean::R, running_var::R, scale::A, bias::A,
+        r::Val{reduce_dims}, training::Val, momentum::Union{Real, Nothing},
+        epsilon::Real, act::F=identity) where {R, A, reduce_dims, F}
     _stats = _get_batch_statistics(x, running_mean, running_var, r, training, momentum)
     (batchmean, batchvar), (running_mean, running_var) = _stats
-    x_norm = _affine_normalize(x, batchmean, batchvar, scale, bias, epsilon)
+    x_norm = _affine_normalize(act, x, batchmean, batchvar, scale, bias, epsilon)
     return (x_norm, running_mean, running_var)
 end
 
 function _normalization(x::AbstractArray, running_mean::Union{Nothing, <:AbstractVector},
         running_var::Union{Nothing, <:AbstractVector},
         scale::Union{Nothing, <:AbstractVector},
-        bias::Union{Nothing, <:AbstractVector}, reduce_dims::Val,
-        training::Val, momentum::Union{Real, Nothing}, epsilon::Real)
+        bias::Union{Nothing, <:AbstractVector}, reduce_dims::Val, training::Val,
+        momentum::Union{Real, Nothing}, epsilon::Real, act::F=identity) where {F}
     rm_ = _reshape_into_proper_shape(running_mean, x)
     rv_ = _reshape_into_proper_shape(running_var, x)
     s_ = _reshape_into_proper_shape(scale, x)
     b_ = _reshape_into_proper_shape(bias, x)
     x_, rm, rv = _normalization_impl(
-        x, rm_, rv_, s_, b_, reduce_dims, training, momentum, epsilon)
+        x, rm_, rv_, s_, b_, reduce_dims, training, momentum, epsilon, act)
     return x_, _vec(rm), _vec(rv)
 end
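For readability, the affine branch of the `@generated` `_affine_normalize` above expands to code equivalent to this plain sketch (hypothetical helper name, illustration only):

function affine_normalize(act::F, x, xmean, xvar, scale, bias, epsilon) where {F}
    x_norm = @. (x - xmean) / sqrt(xvar + epsilon)   # normalize with the batch statistics
    return @. act(scale * x_norm + bias)             # affine transform, then fused activation
end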
4 changes: 2 additions & 2 deletions src/utils.jl
@@ -113,11 +113,11 @@ end
         b::Union{Nothing, <:AbstractArray}) where {F, Tw, Tx}
     if b === nothing
         Ty = promote_type(Tw, Tx)
-        Tact = Core.Compiler.return_type(act, Tuple{Ty})
+        Tact = Core.Compiler._return_type(act, Tuple{Ty})
         return isconcretetype(Tact) ? promote_type(Ty, Tact) : Ty
     end
     Ty = promote_type(Tw, Tx, eltype(b))
-    Tact = Core.Compiler.return_type(act, Tuple{Ty})
+    Tact = Core.Compiler._return_type(act, Tuple{Ty})
     return isconcretetype(Tact) ? promote_type(Ty, Tact) : Ty
 end
 
