Skip to content
This repository has been archived by the owner on Nov 4, 2024. It is now read-only.

Commit

Permalink
feat: add sleefpirates for CPU activation
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Jul 23, 2024
1 parent d179933 commit 481cee2
Show file tree
Hide file tree
Showing 5 changed files with 40 additions and 3 deletions.
4 changes: 3 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "LuxLib"
uuid = "82251201-b29d-42c6-8e01-566dec8acb11"
authors = ["Avik Pal <[email protected]> and contributors"]
version = "0.3.31"
version = "0.3.32"

[deps]
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
Expand All @@ -18,6 +18,7 @@ Markdown = "d6f4376e-aef5-505a-96c1-9c027394607a"
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
SLEEFPirates = "476501e8-09a2-5ece-8869-fb82de89a1fa"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
UnrolledUtilities = "0fe1646c-419e-43be-ac14-22321958931b"

Expand Down Expand Up @@ -61,6 +62,7 @@ Random = "1.10"
ReTestItems = "1.23.1"
Reexport = "1"
ReverseDiff = "1.15"
SLEEFPirates = "0.6.43"
StableRNGs = "1"
Statistics = "1.10"
Test = "1.10"
Expand Down
1 change: 1 addition & 0 deletions src/LuxLib.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ using NNlib: NNlib, ConvDims, conv, conv!, relu, gelu, σ, ∇conv_data, ∇conv
using Random: Random, AbstractRNG, rand!
using Reexport: @reexport
using Statistics: Statistics, mean, var
using SLEEFPirates: SLEEFPirates
using UnrolledUtilities: unrolled_any, unrolled_all, unrolled_filter

@reexport using NNlib
Expand Down
7 changes: 7 additions & 0 deletions src/api/activation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,13 @@ generic implementation.
This function doesn't replace `σ` with `NNlib.fast_act(σ, ...)`; that needs to be
done by the user if needed.

!!! tip

    Certain activation functions are replaced with specialized implementations from
    [SLEEFPirates.jl](https://github.com/JuliaSIMD/SLEEFPirates.jl). This might lead to
    faster performance but can cause a slight decrease in accuracy (in the floating point
    limit).

## Arguments
- `σ`: Activation function
Expand Down
28 changes: 27 additions & 1 deletion src/impl/activation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,9 @@ end

# Looped (explicit `@simd`) implementation of the in-place activation: writes `σ(x[I])`
# into `y[I]` for every shared index `I`.
#
# Arguments:
#   - `y`: destination array, overwritten element by element.
#   - `σ`: activation function; swapped for its SLEEFPirates.jl counterpart when one is
#     registered AND the element type of `x` is one the two-argument
#     `__sleefpirates_activation` opts in (`Float32` / `Float64`). For any other element
#     type `σ` itself is applied unchanged.
#   - `x`: source array.
#
# Returns `nothing`; the result is delivered through `y`.
# `eachindex(y, x)` throws a `DimensionMismatch` when the two arrays do not share axes,
# which is what makes the `@inbounds` accesses below safe.
function _fast_activation!(
        ::LoopedArrayOp, y::AbstractArray, σ::F, x::AbstractArray) where {F}
    # Resolve the replacement once, outside the hot loop. Dispatching on `eltype(x)`
    # keeps the SLEEFPirates kernels away from element types they were not registered for.
    σ_sleef = __sleefpirates_activation(σ, eltype(x))
    # Each iteration reads and writes only index `I`, so there is no loop-carried
    # dependency even when `y` and `x` are the same array.
    @simd ivdep for I in eachindex(y, x)
        @inbounds y[I] = σ_sleef(x[I])
    end
    return nothing
end
function _fast_activation!(opmode, y::AbstractArray, σ::F, x::AbstractArray) where {F}
Expand Down Expand Up @@ -87,3 +88,28 @@ function CRC.rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(fast_activation!!

return CRC.rrule_via_ad(cfg, _fast_activation, σ, x)
end

# SLEEFPirates.jl-backed counterparts of the activation functions. Each one forwards to
# the corresponding SLEEFPirates kernel, or is composed from such kernels.

# Direct forwards to SLEEFPirates kernels.
function sigmoid_fast_sleefpirates(x)
    return SLEEFPirates.sigmoid_fast(x)
end
function softplus_sleefpirates(x)
    return SLEEFPirates.softplus(x)
end
function gelu_sleefpirates(x)
    return SLEEFPirates.gelu(x)
end
function tanh_sleefpirates(x)
    return SLEEFPirates.tanh(x)
end
function tanh_fast_sleefpirates(x)
    return SLEEFPirates.tanh_fast(x)
end

# `SLEEFPirates.Elu(α)` is built from `α` and then called on `x`.
function elu_sleefpirates(x, α=1)
    elu_op = SLEEFPirates.Elu(α)
    return elu_op(x)
end

# Composite activations, written out from the kernels above.
# logsigmoid(x) is computed as -softplus(-x).
function logsigmoid_sleefpirates(x)
    return -SLEEFPirates.softplus(-x)
end
# swish(x) is computed as x * sigmoid_fast(x), using a fast-math multiply.
function swish_sleefpirates(x)
    gate = SLEEFPirates.sigmoid_fast(x)
    return Base.FastMath.mul_fast(x, gate)
end
# lisht(x) is computed as x * tanh_fast(x), using a fast-math multiply.
function lisht_sleefpirates(x)
    squashed = SLEEFPirates.tanh_fast(x)
    return Base.FastMath.mul_fast(x, squashed)
end

# Lookup from an activation function to its SLEEFPirates.jl-backed replacement.

# One-argument form, fallback: a function with no registered replacement is returned as is.
function __sleefpirates_activation(f::F) where {F}
    return f
end

# One-argument form, registered replacements: one method per `base => replacement` pair.
for (fbase, ffast) in (
    NNlib.sigmoid_fast => sigmoid_fast_sleefpirates,
    NNlib.softplus => softplus_sleefpirates,
    NNlib.logsigmoid => logsigmoid_sleefpirates,
    NNlib.elu => elu_sleefpirates,
    NNlib.gelu => gelu_sleefpirates,
    NNlib.swish => swish_sleefpirates,
    NNlib.lisht => lisht_sleefpirates,
    Base.tanh => tanh_sleefpirates,
    NNlib.tanh_fast => tanh_fast_sleefpirates
)
    @eval __sleefpirates_activation(::typeof($fbase)) = $ffast
end

# Two-argument form, gated on an element type: only `Float32` and `Float64` are routed
# to the one-argument lookup; every other type gets `f` back untouched.
function __sleefpirates_activation(f::F, ::Type{T}) where {F, T}
    return f
end
function __sleefpirates_activation(f::F, ::Type{Float32}) where {F}
    return __sleefpirates_activation(f)
end
function __sleefpirates_activation(f::F, ::Type{Float64}) where {F}
    return __sleefpirates_activation(f)
end
3 changes: 2 additions & 1 deletion src/impl/bias_activation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,8 @@ function __bias_activation_impl!(
opmode = internal_operation_mode((y, x, bias))
bias_ = __reshape_bias_into_xdims(x, bias)
if opmode isa LoopedArrayOp
bc = Broadcast.instantiate(Broadcast.broadcasted +, x, bias_))
σ_sleef = __sleefpirates_activation(σ)
bc = Broadcast.instantiate(Broadcast.broadcasted(σ_sleef +, x, bias_))
@simd ivdep for I in eachindex(bc)
@inbounds y[I] = bc[I]
end
Expand Down

0 comments on commit 481cee2

Please sign in to comment.