From 8271ed8cddb5929212652fb0193568e62388d14c Mon Sep 17 00:00:00 2001
From: Avik Pal
Date: Fri, 19 Apr 2024 15:28:32 -0400
Subject: [PATCH] Allow fusing activation into normalization

---
 Project.toml                         |  2 +-
 ext/LuxLibTrackercuDNNExt.jl         |  6 ++--
 ext/LuxLibcuDNNExt/LuxLibcuDNNExt.jl |  6 ++--
 src/LuxLib.jl                        |  3 ++
 src/api/batchnorm.jl                 | 10 ++++---
 src/api/fast_activation.jl           | 26 ++++++++++++++++
 src/api/layernorm.jl                 |  3 +-
 src/impl/fast_activation.jl          | 44 ++++++++++++++++++++++++++++
 src/impl/fused_conv.jl               |  2 +-
 src/impl/normalization.jl            | 25 ++++++++--------
 src/utils.jl                         |  4 +--
 11 files changed, 103 insertions(+), 28 deletions(-)
 create mode 100644 src/api/fast_activation.jl
 create mode 100644 src/impl/fast_activation.jl

diff --git a/Project.toml b/Project.toml
index 42b06efe..bac45980 100644
--- a/Project.toml
+++ b/Project.toml
@@ -1,7 +1,7 @@
 name = "LuxLib"
 uuid = "82251201-b29d-42c6-8e01-566dec8acb11"
 authors = ["Avik Pal and contributors"]
-version = "0.3.13"
+version = "0.3.14"
 
 [deps]
 ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
diff --git a/ext/LuxLibTrackercuDNNExt.jl b/ext/LuxLibTrackercuDNNExt.jl
index 1694ef8e..5c7ca026 100644
--- a/ext/LuxLibTrackercuDNNExt.jl
+++ b/ext/LuxLibTrackercuDNNExt.jl
@@ -17,12 +17,12 @@ const TR_BNParamType = Union{
 
 function LuxLib.batchnorm(
         x::TR_CUDNN_BN_ARRAY_TYPE, scale::TR_BNParamType, bias::TR_BNParamType,
-        running_mean::TR_BNParamType, running_var::TR_BNParamType;
-        momentum::Real, training::Val, epsilon::Real)
+        running_mean::TR_BNParamType, running_var::TR_BNParamType,
+        σ::F=identity; momentum::Real, training::Val, epsilon::Real) where {F}
     rm, rv = LuxLib._get_batchnorm_statistics(x, running_mean, running_var, training)
     # NOTE: The following returns a tracked tuple so we can't do `first` on it
     x_ = LuxLib.batchnorm_cudnn(rm, rv, scale, bias, x, momentum, epsilon, training)[1]
-    return x_, (; running_mean=rm, running_var=rv)
+    return LuxLib.fast_activation!!(σ, x_), (; running_mean=rm, running_var=rv)
 end
 
 for RM in (:TrackedVector, :Nothing, :AbstractVector),
diff --git a/ext/LuxLibcuDNNExt/LuxLibcuDNNExt.jl b/ext/LuxLibcuDNNExt/LuxLibcuDNNExt.jl
index 044929ea..fda82360 100644
--- a/ext/LuxLibcuDNNExt/LuxLibcuDNNExt.jl
+++ b/ext/LuxLibcuDNNExt/LuxLibcuDNNExt.jl
@@ -20,11 +20,11 @@ const CUDNN_BN_ARRAY_TYPE = Union{
 const BNParamType = Union{Nothing, CuVector{<:Union{Float32, Float64}}}
 
 function LuxLib.batchnorm(x::CUDNN_BN_ARRAY_TYPE, scale::BNParamType, bias::BNParamType,
-        running_mean::BNParamType, running_var::BNParamType;
-        momentum::Real, training::Val, epsilon::Real)
+        running_mean::BNParamType, running_var::BNParamType, σ::F=identity;
+        momentum::Real, training::Val, epsilon::Real) where F
     rm, rv = LuxLib._get_batchnorm_statistics(x, running_mean, running_var, training)
     x_ = first(LuxLib.batchnorm_cudnn(rm, rv, scale, bias, x, momentum, epsilon, training))
-    return x_, (; running_mean=rm, running_var=rv)
+    return LuxLib.fast_activation!!(σ, x_), (; running_mean=rm, running_var=rv)
 end
 
 @inline function LuxLib.batchnorm_cudnn(
diff --git a/src/LuxLib.jl b/src/LuxLib.jl
index 8f132648..8eadfffa 100644
--- a/src/LuxLib.jl
+++ b/src/LuxLib.jl
@@ -29,6 +29,7 @@ include("impl/groupnorm.jl")
 include("impl/normalization.jl")
 include("impl/fused_dense.jl")
 include("impl/fused_conv.jl")
+include("impl/fast_activation.jl")
 
 # User Facing
 include("api/batchnorm.jl")
@@ -38,8 +39,10 @@ include("api/instancenorm.jl")
 include("api/layernorm.jl")
 include("api/dense.jl")
 include("api/conv.jl")
+include("api/fast_activation.jl")
 
 export batchnorm, groupnorm, instancenorm, layernorm, alpha_dropout, dropout
 export fused_dense_bias_activation, fused_conv_bias_activation
+export fast_activation!!
 
 end
diff --git a/src/api/batchnorm.jl b/src/api/batchnorm.jl
index 2161b56f..73f8b01a 100644
--- a/src/api/batchnorm.jl
+++ b/src/api/batchnorm.jl
@@ -1,5 +1,6 @@
 @doc doc"""
-    batchnorm(x, scale, bias, running_mean, running_var; momentum, epsilon, training)
+    batchnorm(x, scale, bias, running_mean, running_var, σ=identity; momentum, epsilon,
+        training)
 
 Batch Normalization. For details see [1].
 
@@ -14,6 +15,7 @@ accordingly.
   - `bias`: Bias factor (``\beta``) (can be `nothing`)
   - `running_mean`: Running mean (can be `nothing`)
   - `running_var`: Running variance (can be `nothing`)
+  - `σ`: Activation function (default: `identity`)
 
 ## Keyword Arguments
 
@@ -41,11 +43,11 @@ fallback is used which is not highly optimized.
 function batchnorm(x::AbstractArray{<:Real, N}, scale::Union{Nothing, <:AbstractVector},
         bias::Union{Nothing, <:AbstractVector},
         running_mean::Union{Nothing, <:AbstractVector},
-        running_var::Union{Nothing, <:AbstractVector};
-        momentum::Real, training::Val, epsilon::Real) where {N}
+        running_var::Union{Nothing, <:AbstractVector}, σ::F=identity;
+        momentum::Real, training::Val, epsilon::Real) where {F, N}
     x_, xm, xv = _normalization(x, _drop_forwarddiff_partials(running_mean),
         _drop_forwarddiff_partials(running_var), scale, bias,
-        _get_batchnorm_reduce_dims(x), training, momentum, epsilon)
+        _get_batchnorm_reduce_dims(x), training, momentum, epsilon, σ)
     stats = (; running_mean=_drop_forwarddiff_partials(xm),
         running_var=_drop_forwarddiff_partials(xv))
     return (x_, stats)
diff --git a/src/api/fast_activation.jl b/src/api/fast_activation.jl
new file mode 100644
index 00000000..232e9dbb
--- /dev/null
+++ b/src/api/fast_activation.jl
@@ -0,0 +1,26 @@
+"""
+    fast_activation!!(σ::F, x) where {F}
+
+Compute `σ.(x)` with the best possible implementation available. If it is possible to
+rewrite `x` in-place, it does so. If `x` is an immutable array, it falls back to the
+generic implementation.
+
+!!! note
+
+    This function doesn't replace `σ` with `NNlib.fast_act(σ, ...)`; that needs to be
+    done by the user if needed.
+
+## Arguments
+
+  - `σ`: Activation function
+  - `x`: Input array
+
+## Returns
+
+  - Output Array with the same size as `x`
+"""
+@inline function fast_activation!!(σ::F, x::AbstractArray) where {F}
+    σ === identity && return x
+    ArrayInterface.can_setindex(x) && return __fast_activation_impl!(σ, x)
+    return σ.(x)
+end
diff --git a/src/api/layernorm.jl b/src/api/layernorm.jl
index 3cc25e93..22adaf99 100644
--- a/src/api/layernorm.jl
+++ b/src/api/layernorm.jl
@@ -37,6 +37,5 @@ end
 
 function layernorm(x::AbstractArray, ::Nothing, ::Nothing; dims, epsilon)
     _mean = mean(x; dims)
-    rstd = 1 ./ (std(x; dims, mean=_mean, corrected=false) .+ epsilon)
-    return (x .- _mean) .* rstd
+    return (x .- _mean) ./ (std(x; dims, mean=_mean, corrected=false) .+ epsilon)
 end
diff --git a/src/impl/fast_activation.jl b/src/impl/fast_activation.jl
new file mode 100644
index 00000000..ba170922
--- /dev/null
+++ b/src/impl/fast_activation.jl
@@ -0,0 +1,44 @@
+# Specialized Implementation based off NNlib._fast_broadcast with added logic from
+# ArrayInterface
+# If we enter here, we already know that we can setindex into the array
+@inline function __fast_activation_impl!(σ::F, x::AbstractArray) where {F}
+    if ArrayInterface.fast_scalar_indexing(x)
+        bc = Broadcast.instantiate(Broadcast.broadcasted(σ, x))
+        @simd ivdep for I in eachindex(bc)
+            @inbounds x[I] = bc[I]
+        end
+    else
+        @. x = σ(x)
+    end
+    return x
+end
+
+function CRC.rrule(cfg::CRC.RuleConfig{>:CRC.HasReverseMode},
+        ::typeof(__fast_activation_impl!), σ::F, x::AbstractArray{T}) where {F, T}
+    σ === identity && return x, @closure(Δ->(CRC.NoTangent(), CRC.NoTangent(), Δ))
+
+    if isconcretetype(Core.Compiler._return_type(only_derivative, Tuple{T, F, NotaNumber}))
+        __fast_activation_impl!(σ, x)
+        ∇__fast_activation_impl_no_cached = @closure Δ -> begin
+            ∂x = only_derivative.(x, σ, NotaNumber()) .* CRC.unthunk(Δ)
+            return CRC.NoTangent(), CRC.NoTangent(), ∂x
+        end
+        return x, ∇__fast_activation_impl_no_cached
+    end
+
+    if isconcretetype(Core.Compiler._return_type(only_derivative, Tuple{T, F, T}))
+        y = @. σ(x)
+        ∇__fast_activation_impl_cached_crc = @closure Δ -> begin
+            ∂y = only_derivative.(y, σ, x) .* CRC.unthunk(Δ)
+            return CRC.NoTangent(), CRC.NoTangent(), ∂y
+        end
+        return y, ∇__fast_activation_impl_cached_crc
+    end
+
+    y, pb_f = CRC.rrule_via_ad(cfg, broadcast, σ, x)
+    ∇__fast_activation_impl_cached = @closure Δ -> begin
+        _, _, ∂x = pb_f(Δ)
+        return CRC.NoTangent(), CRC.NoTangent(), ∂x
+    end
+    return y, ∇__fast_activation_impl_cached
+end
diff --git a/src/impl/fused_conv.jl b/src/impl/fused_conv.jl
index 6746b465..d861474f 100644
--- a/src/impl/fused_conv.jl
+++ b/src/impl/fused_conv.jl
@@ -5,7 +5,7 @@ end
 
 # This implementation is different from `conv_bias_act` in that it defines the proper rrules
-# and fuses operations into a single kernel if it is possible. Unfortinately there are
+# and fuses operations into a single kernel if it is possible. Unfortunately there are
 # certain configurations where CUDNN allows caching intermediates, but we don't do that rn.
 
 @inline function __fused_conv_bias_activation_impl(
diff --git a/src/impl/normalization.jl b/src/impl/normalization.jl
index 8a8ee48b..693baa12 100644
--- a/src/impl/normalization.jl
+++ b/src/impl/normalization.jl
@@ -41,37 +41,38 @@ end
     return Expr(:block, calls...)
 end
 
-@generated function _affine_normalize(x::AbstractArray, xmean::ST, xvar::ST,
-        scale::A, bias::A, epsilon::Real) where {ST, A}
+@generated function _affine_normalize(act::F, x::AbstractArray, xmean::ST, xvar::ST,
+        scale::A, bias::A, epsilon::Real) where {F, ST, A}
     if A != Nothing
         return quote
-            x_norm = (x .- xmean) ./ sqrt.(xvar .+ epsilon)
-            return scale .* x_norm .+ bias
+            x_norm = @.((x - xmean) / sqrt(xvar + epsilon))
+            return @. act(scale * x_norm + bias)
         end
     else
-        return :(return (x .- xmean) ./ sqrt.(xvar .+ epsilon))
+        return :(return @. act((x - xmean) / sqrt(xvar + epsilon)))
     end
 end
 
-function _normalization_impl(x::AbstractArray, running_mean::R, running_var::R,
-        scale::A, bias::A, r::Val{reduce_dims}, training::Val,
-        momentum::Union{Real, Nothing}, epsilon::Real) where {R, A, reduce_dims}
+function _normalization_impl(
+        x::AbstractArray, running_mean::R, running_var::R, scale::A, bias::A,
+        r::Val{reduce_dims}, training::Val, momentum::Union{Real, Nothing},
+        epsilon::Real, act::F=identity) where {R, A, reduce_dims, F}
     _stats = _get_batch_statistics(x, running_mean, running_var, r, training, momentum)
     (batchmean, batchvar), (running_mean, running_var) = _stats
-    x_norm = _affine_normalize(x, batchmean, batchvar, scale, bias, epsilon)
+    x_norm = _affine_normalize(act, x, batchmean, batchvar, scale, bias, epsilon)
     return (x_norm, running_mean, running_var)
 end
 
 function _normalization(x::AbstractArray, running_mean::Union{Nothing, <:AbstractVector},
         running_var::Union{Nothing, <:AbstractVector},
         scale::Union{Nothing, <:AbstractVector},
-        bias::Union{Nothing, <:AbstractVector}, reduce_dims::Val,
-        training::Val, momentum::Union{Real, Nothing}, epsilon::Real)
+        bias::Union{Nothing, <:AbstractVector}, reduce_dims::Val, training::Val,
+        momentum::Union{Real, Nothing}, epsilon::Real, act::F=identity) where {F}
     rm_ = _reshape_into_proper_shape(running_mean, x)
     rv_ = _reshape_into_proper_shape(running_var, x)
     s_ = _reshape_into_proper_shape(scale, x)
     b_ = _reshape_into_proper_shape(bias, x)
     x_, rm, rv = _normalization_impl(
-        x, rm_, rv_, s_, b_, reduce_dims, training, momentum, epsilon)
+        x, rm_, rv_, s_, b_, reduce_dims, training, momentum, epsilon, act)
     return x_, _vec(rm), _vec(rv)
 end
diff --git a/src/utils.jl b/src/utils.jl
index 66f58fee..84f10362 100644
--- a/src/utils.jl
+++ b/src/utils.jl
@@ -113,11 +113,11 @@ end
         b::Union{Nothing, <:AbstractArray}) where {F, Tw, Tx}
     if b === nothing
         Ty = promote_type(Tw, Tx)
-        Tact = Core.Compiler.return_type(act, Tuple{Ty})
+        Tact = Core.Compiler._return_type(act, Tuple{Ty})
         return isconcretetype(Tact) ? promote_type(Ty, Tact) : Ty
     end
     Ty = promote_type(Tw, Tx, eltype(b))
-    Tact = Core.Compiler.return_type(act, Tuple{Ty})
+    Tact = Core.Compiler._return_type(act, Tuple{Ty})
     return isconcretetype(Tact) ? promote_type(Ty, Tact) : Ty
 end
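
Usage sketch (editor's note, not part of the patch): the snippet below illustrates the fused path this commit adds, using the new positional activation argument of `batchnorm` and the exported `fast_activation!!`. The array shapes, the choice of `tanh`, and the keyword values are illustrative assumptions only.

using LuxLib

x = randn(Float32, 4, 4, 3, 2)                     # WHCN input with 3 channels
scale, bias = ones(Float32, 3), zeros(Float32, 3)
rmean, rvar = zeros(Float32, 3), ones(Float32, 3)

# The activation is fused directly into the normalization call via the new positional `σ`
y, stats = batchnorm(x, scale, bias, rmean, rvar, tanh;
    momentum=0.1f0, training=Val(true), epsilon=1.0f-5)

# Standalone fused activation; may overwrite its argument when the array is mutable
z = fast_activation!!(tanh, copy(x))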