diff --git a/src/activations.jl b/src/activations.jl
index d9fc131e6..c7ee13826 100644
--- a/src/activations.jl
+++ b/src/activations.jl
@@ -339,6 +339,21 @@ end
 const gelu_λ = √(2 / π)
 const gelu_2λ = √(8 / π)
 
+# Derivative expression that UNARY_ACTS pairs with `gelu`, evaluated at the input `x`.
+# With g = sigmoid_fast(k * x * (1 + c * x^2)), k = gelu_2λ and c = 0.044715, this returns
+# g + x * k * conj(g * (1 - g)) * (1 + (c + c2) * x^2), where c2 = 0.08943 (i.e. 2c).
+function deriv_gelu(x)
+    c = oftf(x, 0.044715)
+    c2 = oftf(x, 0.08943)
+    k = oftf(x, gelu_2λ)
+    xsq = x * x
+    poly = muladd(xsq, c, one(x))  # 1 + c * x^2
+    g = sigmoid_fast(k * x * poly)
+    dg = conj(g * (1 - g))
+    slope = dg * k * muladd(xsq, c2, poly)  # algebraically k * dg * (1 + (c + c2) * x^2)
+    return muladd(slope, x, g)
+end
+
 """
     swish(x) = x * σ(x)
 
@@ -853,7 +868,7 @@ UNARY_ACTS = [ # f, dfdx
     (:relu6, :((Ω>0) & (Ω<6))),
     # rrelu is random, can't write a rule.
     (:elu, :(deriv_elu(Ω))),
-    # gelu
+    (:gelu, :(deriv_gelu(x))),
     (:swish, :(Ω + sigmoid_fast(x) * (1 - Ω))),
     (:hardswish, :(deriv_hardswish(x))),
     # lisht