This repository has been archived by the owner on Nov 4, 2024. It is now read-only.

feat: use sleefpirates #101

Merged · 9 commits · Jul 25, 2024
4 changes: 3 additions & 1 deletion Project.toml
@@ -1,7 +1,7 @@
name = "LuxLib"
uuid = "82251201-b29d-42c6-8e01-566dec8acb11"
authors = ["Avik Pal <[email protected]> and contributors"]
version = "0.3.33"
version = "0.3.34"

[deps]
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
@@ -18,6 +18,7 @@ Markdown = "d6f4376e-aef5-505a-96c1-9c027394607a"
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
SLEEFPirates = "476501e8-09a2-5ece-8869-fb82de89a1fa"
StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
UnrolledUtilities = "0fe1646c-419e-43be-ac14-22321958931b"
@@ -62,6 +63,7 @@ Random = "1.10"
ReTestItems = "1.23.1"
Reexport = "1"
ReverseDiff = "1.15"
SLEEFPirates = "0.6.43"
StableRNGs = "1"
StaticArrays = "1.9"
StaticArraysCore = "1.4.3"
4 changes: 3 additions & 1 deletion src/LuxLib.jl
@@ -17,14 +17,16 @@ using Random: Random, AbstractRNG, rand!
using Reexport: @reexport
using StaticArraysCore: StaticArraysCore, StaticVector
using Statistics: Statistics, mean, var
using UnrolledUtilities: unrolled_any, unrolled_all, unrolled_filter
using SLEEFPirates: SLEEFPirates
using UnrolledUtilities: unrolled_any, unrolled_all, unrolled_filter, unrolled_mapreduce

@reexport using NNlib

const CRC = ChainRulesCore
const KA = KernelAbstractions

include("utils.jl")
include("patches.jl")

# User Facing
include("api/activation.jl")
10 changes: 9 additions & 1 deletion src/api/activation.jl
@@ -10,6 +10,13 @@ generic implementation.
This function doesn't replace `σ` with `NNlib.fast_act(σ, ...)`; that replacement needs to
be done by the user if needed.

!!! tip

Certain activation functions are replaced with specialized implementations from
[SLEEFPirates.jl](https://github.com/JuliaSIMD/SLEEFPirates.jl) for FP32 inputs. This can
improve performance but may cause a slight loss of accuracy (at the limit of floating-point
precision).

## Arguments

- `σ`: Activation function
@@ -20,7 +27,8 @@ generic implementation.
- Output Array with the same size as `x`
"""
function fast_activation!!(σ::F, x::AbstractArray) where {F}
return _fast_activation!!(__is_immutable_array_or_dual_val((x,)), σ, x)
return _fast_activation!!(
__is_immutable_array_or_dual_val((x,)), select_fastest_activation(σ, x), x)
end

function _fast_activation!!(::Val{true}, σ::F, x::AbstractArray) where {F}
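
A minimal usage sketch of the updated entry point (hypothetical shapes; whether the
SLEEFPirates variant is actually selected depends on `internal_operation_mode` and the
element type, so the dispatch noted in the comments is illustrative, not guaranteed):

```julia
using LuxLib  # re-exports NNlib, so `gelu` and friends are in scope

x32 = rand(Float32, 128, 32)
x64 = rand(Float64, 128, 32)

# For Float32 arrays on the CPU the activation may be swapped for a
# SLEEFPirates-backed variant; for Float64 (or GPU arrays) the original
# NNlib function is used unchanged.
y32 = fast_activation!!(gelu, copy(x32))
y64 = fast_activation!!(gelu, copy(x64))
```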
3 changes: 2 additions & 1 deletion src/api/batchnorm.jl
@@ -43,7 +43,8 @@ function batchnorm(x::AbstractArray{<:Real, N}, scale::Optional{<:AbstractVector
running_var::Optional{<:AbstractVector}, training::Val, σ::F=identity,
momentum::Real=0.1f0, epsilon::Real=__default_epsilon(x)) where {F, N}
x_, xm, xv = _normalization(x, __value(running_mean), __value(running_var), scale, bias,
_get_batchnorm_reduce_dims(x), training, momentum, epsilon, σ)
_get_batchnorm_reduce_dims(x), training, momentum, epsilon,
select_fastest_activation(σ, x, scale, bias, running_mean, running_var))
return (x_, (; running_mean=__value(xm), running_var=__value(xv)))
end

4 changes: 2 additions & 2 deletions src/api/bias_activation.jl
@@ -15,7 +15,7 @@ See also [`bias_activation!!`](@ref), [`fast_activation!!`](@ref).
"""
function bias_activation(σ::F, x::AbstractArray, bias::Optional{<:AbstractVector}) where {F}
_bias_act_check(x, bias)
return __bias_activation_impl(σ, x, bias)
return __bias_activation_impl(select_fastest_activation(σ, x, bias), x, bias)
end

"""
@@ -30,7 +30,7 @@ See also [`bias_activation`](@ref), [`fast_activation!!`](@ref).
function bias_activation!!(
σ::F, x::AbstractArray, bias::Optional{<:AbstractVector}) where {F}
_bias_act_check(x, bias)
return __bias_activation_impl!!(σ, x, bias)
return __bias_activation_impl!!(select_fastest_activation(σ, x, bias), x, bias)
end

_bias_act_check(x, b) = nothing
3 changes: 2 additions & 1 deletion src/api/conv.jl
@@ -33,7 +33,8 @@ function fused_conv_bias_activation(
b::AbstractArray{<:Number, N}, cdims::ConvDims) where {F, N}
__depwarn("Passing `bias` as a N-D array is deprecated, pass it as a Vector instead",
:fused_conv_bias_activation)
return fused_conv_bias_activation(σ, weight, x, _vec(b), cdims)
return fused_conv_bias_activation(
select_fastest_activation(σ, weight, x, b), weight, x, _vec(b), cdims)
end

function fused_conv_bias_activation(
4 changes: 2 additions & 2 deletions src/api/dense.jl
@@ -28,8 +28,8 @@ multiple operations.
"""
function fused_dense_bias_activation(σ::F, weight::AbstractMatrix, x::AbstractMatrix,
b::Optional{<:AbstractVector}) where {F}
return fused_dense_bias_activation(
σ, __is_immutable_array_or_dual_val((weight, x, b)), weight, x, b)
return fused_dense_bias_activation(select_fastest_activation(σ, weight, x, b),
__is_immutable_array_or_dual_val((weight, x, b)), weight, x, b)
end

for (check, fop) in (
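
The dense path is unchanged for callers; a rough sketch of the public call (array sizes
chosen arbitrarily for illustration):

```julia
using LuxLib

weight = rand(Float32, 32, 64)
x = rand(Float32, 64, 16)
bias = rand(Float32, 32)

# `select_fastest_activation` now runs before the mutability dispatch, so a
# Float32 CPU call like this one may use `gelu_sleefpirates` internally.
y = fused_dense_bias_activation(gelu, weight, x, bias)
```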
3 changes: 2 additions & 1 deletion src/api/groupnorm.jl
@@ -35,7 +35,8 @@ function groupnorm(x::AbstractArray{<:Real, N}, scale::Optional{<:AbstractVector

sz = size(x)
x_reshaped = reshape(x, sz[1:(N - 2)]..., sz[N - 1] ÷ groups, groups, sz[N])
x_ = _groupnorm_impl(x_reshaped, scale, bias, _get_groupnorm_reduce_dims(x), epsilon, σ)
x_ = _groupnorm_impl(x_reshaped, scale, bias, _get_groupnorm_reduce_dims(x), epsilon,
select_fastest_activation(σ, x, scale, bias, x_reshaped))

return reshape(x_, sz)
end
5 changes: 3 additions & 2 deletions src/api/instancenorm.jl
@@ -33,8 +33,9 @@ function instancenorm(x::AbstractArray{<:Real, N}, scale::Optional{<:AbstractVec
σ::F=identity, epsilon::Real=__default_epsilon(x)) where {N, F}
_test_valid_instancenorm_arguments(x)

x_, xm, xv = _normalization(x, nothing, nothing, scale, bias,
_get_instancenorm_reduce_dims(x), training, nothing, epsilon, σ)
x_, xm, xv = _normalization(
x, nothing, nothing, scale, bias, _get_instancenorm_reduce_dims(x),
training, nothing, epsilon, select_fastest_activation(σ, x, scale, bias))

return x_, (; running_mean=xm, running_var=xv)
end
3 changes: 2 additions & 1 deletion src/api/layernorm.jl
@@ -36,5 +36,6 @@ function layernorm(
bias::Optional{<:AbstractArray{<:Number, N}}, σ::F=identity,
dims=Colon(), epsilon::Real=__default_epsilon(x)) where {N, F}
μ, σ² = fast_mean_var(x; dims, corrected=false)
return _affine_normalize(σ, x, μ, σ², scale, bias, epsilon)
return _affine_normalize(
select_fastest_activation(σ, x, scale, bias), x, μ, σ², scale, bias, epsilon)
end
108 changes: 108 additions & 0 deletions src/impl/activation.jl
@@ -87,3 +87,111 @@ function CRC.rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(fast_activation!!

return CRC.rrule_via_ad(cfg, _fast_activation, σ, x)
end

# Specialized functions that use SLEEFPirates.jl to speed up the activation functions
sigmoid_fast_sleefpirates(x::Number) = SLEEFPirates.sigmoid_fast(x)

softplus_sleefpirates(x::Number) = SLEEFPirates.softplus(x)

logsigmoid_sleefpirates(x::Number) = -softplus_sleefpirates(-x)

gelu_sleefpirates(x::Number) = SLEEFPirates.gelu(x)

const gelu_λ = √(2 / π)
const gelu_2λ = √(8 / π)

function ∂gelu_sleefpirates(x::Number)
α = oftype(x, 0.044715)
α2 = oftype(x, 0.08943)
λλ = oftype(x, gelu_2λ)
x2 = Base.FastMath.mul_fast(x, x)
t = muladd(x2, α, one(x))
Ω = sigmoid_fast_sleefpirates(λλ * x * t)
dσ = conj(Ω * (1 - Ω))
return muladd(dσ * λλ * muladd(x2, α2, t), x, Ω)
end

swish_sleefpirates(x::Number) = Base.FastMath.mul_fast(x, sigmoid_fast_sleefpirates(x))

lisht_sleefpirates(x::Number) = Base.FastMath.mul_fast(x, tanh_fast_sleefpirates(x))

tanh_sleefpirates(x::Number) = SLEEFPirates.tanh(x)

tanh_fast_sleefpirates(x::Number) = SLEEFPirates.tanh_fast(x)

for (f, dfdx) in [
#! format: off
(:sigmoid_fast_sleefpirates, :(conj(Base.FastMath.mul_fast(Ω, Base.FastMath.sub_fast(1, Ω))))),
(:softplus_sleefpirates, :(sigmoid_fast_sleefpirates(x))),
(:logsigmoid_sleefpirates, :(sigmoid_fast_sleefpirates(-x))),
(:gelu_sleefpirates, :(∂gelu_sleefpirates(x))),
(:swish_sleefpirates, :(Base.FastMath.add_fast(Ω, Base.FastMath.mul_fast(sigmoid_fast_sleefpirates(x), Base.FastMath.sub_fast(1, Ω))))),
(:tanh_sleefpirates, :(conj(Base.FastMath.sub_fast(1, Base.FastMath.mul_fast(Ω, Ω))))),
(:tanh_fast_sleefpirates, :(conj(Base.FastMath.sub_fast(1, Base.FastMath.mul_fast(Ω, Ω)))))
#! format: on
]
@eval CRC.@scalar_rule($f(x), $dfdx)

pullback = Symbol(:broadcasted_, f, :_pullback)
@eval function CRC.rrule(::typeof(Broadcast.broadcasted), ::typeof($f),
x::Union{Numeric, Broadcast.Broadcasted})
Ω = $f.(x)
function $pullback(dΩ)
x_thunk = CRC.InplaceableThunk(
dx -> @.(dx+=dΩ * $dfdx), CRC.@thunk @.(dΩ*$dfdx))
return ∂∅, ∂∅, x_thunk
end
return Ω, $pullback
end
end

# Enzyme works for all of these except `gelu`.
# See https://github.com/EnzymeAD/Enzyme.jl/issues/1671
function EnzymeRules.augmented_primal(
cfg::EnzymeRules.ConfigWidth{1}, func::EnzymeCore.Const{typeof(gelu_sleefpirates)},
::Type{<:EnzymeCore.Active}, x::EnzymeCore.Active{<:Number})
primal = EnzymeRules.needs_primal(cfg) ? func.val(x.val) : nothing
return EnzymeRules.AugmentedReturn(primal, nothing, nothing)
end

function EnzymeRules.reverse(
cfg::EnzymeRules.ConfigWidth{1}, ::EnzymeCore.Const{typeof(gelu_sleefpirates)},
dret::EnzymeCore.Active, ::Nothing, x::EnzymeCore.Active{<:Number})
return (dret.val * ∂gelu_sleefpirates(x.val),)
end

# Select SLEEFPirates.jl-backed replacements for supported activation functions
function select_fastest_activation(f::F, xs...) where {F}
return select_fastest_activation(
f, internal_operation_mode(xs), unrolled_mapreduce(__eltype, promote_type, xs))
end

select_fastest_activation(f::F, ::AbstractInternalArrayOpMode, ::Type{T}) where {F, T} = f
function select_fastest_activation(f::F, ::LoopedArrayOp, ::Type{T}) where {F, T}
return sleefpirates_activation(f, T)
end

CRC.@non_differentiable select_fastest_activation(::Any...)
EnzymeRules.inactive_noinl(::typeof(select_fastest_activation), ::Any...) = nothing

sleefpirates_activation(f::F, ::Type{T}) where {F, T} = f
sleefpirates_activation(f::F, ::Type{Float32}) where {F} = sleefpirates_activation(f)

for (fbase, ffast) in [
#! format: off
(NNlib.sigmoid_fast, sigmoid_fast_sleefpirates),
(NNlib.softplus, softplus_sleefpirates),
(NNlib.logsigmoid, logsigmoid_sleefpirates),
(NNlib.gelu, gelu_sleefpirates),
(NNlib.swish, swish_sleefpirates),
(NNlib.lisht, lisht_sleefpirates),
(Base.tanh, tanh_sleefpirates),
(NNlib.tanh_fast, tanh_fast_sleefpirates)
#! format: on
]
@eval sleefpirates_activation(::typeof($fbase)) = $ffast
end
sleefpirates_activation(f::F) where {F} = f

CRC.@non_differentiable sleefpirates_activation(::Any...)
EnzymeRules.inactive_noinl(::typeof(sleefpirates_activation), ::Any...) = nothing
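
A quick sketch of how the internal (unexported) helper resolves activations — useful when
reviewing the dispatch chain above; only `Float32` triggers the rewrite:

```julia
using LuxLib, NNlib

LuxLib.sleefpirates_activation(NNlib.gelu, Float32)  # -> gelu_sleefpirates
LuxLib.sleefpirates_activation(NNlib.gelu, Float64)  # -> gelu (non-Float32 eltype, unchanged)
LuxLib.sleefpirates_activation(abs2, Float32)        # -> abs2 (no specialized variant registered)
```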
70 changes: 70 additions & 0 deletions src/patches.jl
@@ -0,0 +1,70 @@
# This is type piracy, but it is needed to fix a blocking issue. TODO: upstream to NNlib.
# Without this patch, Enzyme emits an "active variables passed by value to jl_new_task are
# not yet supported" warning.
function EnzymeRules.augmented_primal(
cfg::EnzymeRules.ConfigWidth, func::EnzymeCore.Const{typeof(NNlib.batched_mul!)},
::Type{RT}, C::EnzymeCore.Annotation{<:AbstractArray{<:Any, 3}},
A::EnzymeCore.Annotation{<:AbstractArray{<:Any, 3}},
B::EnzymeCore.Annotation{<:AbstractArray{<:Any, 3}}) where {RT}
if typeof(C) <: EnzymeCore.Duplicated || typeof(C) <: EnzymeCore.BatchDuplicated
func.val(C.val, A.val, B.val)
end

primal = EnzymeRules.needs_primal(cfg) ? C.val : nothing
shadow = EnzymeRules.needs_shadow(cfg) ? C.dval : nothing

cache_A = (EnzymeRules.overwritten(cfg)[3] &&
!(typeof(C) <: EnzymeCore.Const) &&
!(typeof(B) <: EnzymeCore.Const)) ? copy(A.val) : nothing
cache_B = (EnzymeRules.overwritten(cfg)[3] &&
!(typeof(C) <: EnzymeCore.Const) &&
!(typeof(A) <: EnzymeCore.Const)) ? copy(B.val) : nothing

return EnzymeRules.AugmentedReturn(primal, shadow, (cache_A, cache_B))
end

function EnzymeRules.reverse(
cfg::EnzymeRules.ConfigWidth, func::EnzymeCore.Const{typeof(NNlib.batched_mul!)},
::Type{RT}, cache, C::EnzymeCore.Annotation{<:AbstractArray{<:Any, 3}},
A::EnzymeCore.Annotation{<:AbstractArray{<:Any, 3}},
B::EnzymeCore.Annotation{<:AbstractArray{<:Any, 3}}) where {RT}
cache_A, cache_B = cache

if !(typeof(B) <: EnzymeCore.Const) && !(typeof(C) <: EnzymeCore.Const)
if !EnzymeRules.overwritten(cfg)[3]
cache_A = A.val
end
end

if !(typeof(A) <: EnzymeCore.Const) && !(typeof(C) <: EnzymeCore.Const)
if !EnzymeRules.overwritten(cfg)[3]
cache_B = B.val
end
end

dCs = C.dval
dAs = (typeof(A) <: EnzymeCore.Const) ? dCs : A.dval
dBs = (typeof(B) <: EnzymeCore.Const) ? dCs : B.dval

if EnzymeRules.width(cfg) == 1
dCs = (dCs,)
dAs = (dAs,)
dBs = (dBs,)
end

for (dC, dA, dB) in zip(dCs, dAs, dBs)
if !(typeof(C) <: EnzymeCore.Const) && dC !== C.val
if !(typeof(A) <: EnzymeCore.Const) && dA !== A.val
NNlib.batched_mul!(dA, dC, NNlib.batched_adjoint(B.val), true, true)
end

if !(typeof(B) <: EnzymeCore.Const) && dB !== B.val
NNlib.batched_mul!(dB, NNlib.batched_adjoint(A.val), dC, true, true)
end

dC .= 0
end
end

return ntuple(Returns(nothing), 3)
end
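
A hedged sketch of the kind of reverse-mode call these rules are meant to unblock (the
function and variable names below are made up for illustration):

```julia
using Enzyme, NNlib

# Sum of a batched matrix product; `batched_mul!` mutates its first argument.
function batched_mul_sum!(C, A, B)
    NNlib.batched_mul!(C, A, B)
    return sum(C)
end

A = rand(Float32, 4, 5, 3)
B = rand(Float32, 5, 6, 3)
C = zeros(Float32, 4, 6, 3)

dA = Enzyme.make_zero(A)
dB = Enzyme.make_zero(B)
dC = Enzyme.make_zero(C)

# With the custom rules above, this should no longer hit the
# "active variables passed by value to jl_new_task" warning.
Enzyme.autodiff(Reverse, batched_mul_sum!, Active,
    Duplicated(C, dC), Duplicated(A, dA), Duplicated(B, dB))
```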
2 changes: 1 addition & 1 deletion src/utils.jl
@@ -1,5 +1,5 @@
const Optional{T} = Union{Nothing, T}

const Numeric = Union{AbstractArray{<:T}, T} where {T <: Number}
const ∂∅ = NoTangent()

# Bias Gradient -- can't be used inside gradient rules
51 changes: 51 additions & 0 deletions test/common_ops/activation_tests.jl
@@ -0,0 +1,51 @@
@testitem "Activation Functions" tags=[:common_ops] setup=[SharedTestSetup] begin
rng = StableRNG(1234)

apply_act(f::F, x) where {F} = sum(abs2, f.(x))
apply_act_fast(f::F, x) where {F} = sum(abs2, fast_activation!!(f, copy(x)))

@testset "$mode" for (mode, aType, on_gpu) in MODES
@testset "$f: $T" for f in [identity, relu, sigmoid, sigmoid_fast, softplus,
logsigmoid, gelu, swish, lisht, tanh, tanh_fast],
T in [Float16, Float32, Float64]

x = rand(rng, T, 4, 3) |> aType

y1 = apply_act(f, x)
y2 = apply_act_fast(f, x)

fp16 = T == Float16
atol = fp16 ? 1.0f-1 : 1.0f-3
rtol = fp16 ? 1.0f-1 : 1.0f-3

@test y1≈y2 atol=atol rtol=rtol
@test eltype(y1) == T

@test @inferred(apply_act(f, x)) isa Any
@test @inferred(apply_act_fast(f, x)) isa Any

@jet apply_act_fast(f, x)

@test @inferred(Zygote.gradient(apply_act, f, x)) isa Any

@eval @test_gradients apply_act $f $x gpu_testing=$on_gpu atol=$atol rtol=$rtol skip_finite_differences=$fp16

∂x1 = Zygote.gradient(apply_act, f, x)[2]
∂x2 = Zygote.gradient(apply_act_fast, f, x)[2]

@test ∂x1≈∂x2 atol=atol rtol=rtol

if !on_gpu
∂x1_enz = Enzyme.make_zero(x)
Enzyme.autodiff(
Reverse, apply_act, Active, Const(f), Duplicated(x, ∂x1_enz))
@test ∂x1≈∂x1_enz atol=atol rtol=rtol

∂x2_enz = Enzyme.make_zero(x)
Enzyme.autodiff(
Reverse, apply_act_fast, Active, Const(f), Duplicated(x, ∂x2_enz))
@test ∂x2≈∂x2_enz atol=atol rtol=rtol
end
end
end
end