diff --git a/src/NNlib.jl b/src/NNlib.jl
index c6212554e..da0d24025 100644
--- a/src/NNlib.jl
+++ b/src/NNlib.jl
@@ -2,7 +2,7 @@ module NNlib
 
 using Requires, Libdl
 
-export σ, sigmoid, relu, leakyrelu, elu, swish, selu, softplus, softsign, logσ, logsigmoid,
+export σ, sigmoid, relu, leakyrelu, elu, gelu, swish, selu, softplus, softsign, logσ, logsigmoid,
   softmax, logsoftmax, maxpool, meanpool
 
 include("numeric.jl")
diff --git a/src/activation.jl b/src/activation.jl
index 9f45b25d4..b5cc8017d 100644
--- a/src/activation.jl
+++ b/src/activation.jl
@@ -66,6 +66,20 @@ You can also specify the coefficient explicitly, e.g. `elu(x, 1)`.
 """
 elu(x, α = one(x)) = ifelse(x ≥ 0, x/1, α * (exp(x) - one(x)))
 
+"""
+    gelu(x) = 0.5x*(1 + tanh(√(2/π)*(x + 0.044715x^3)))
+
+[Gaussian Error Linear Unit](https://arxiv.org/pdf/1606.08415.pdf)
+activation function.
+"""
+function gelu(x)
+    λ = oftype(x/1, √(2/π))
+    α = oftype(x/1, 0.044715)
+    h = oftype(x/1, 0.5)
+    h * x * (one(x) + tanh(λ * (x + α * x^3)))
+end
+
+
 """
     swish(x) = x * σ(x)
diff --git a/test/activation.jl b/test/activation.jl
index d3400a257..9222721e7 100644
--- a/test/activation.jl
+++ b/test/activation.jl
@@ -1,4 +1,4 @@
-ACTIVATION_FUNCTIONS = [σ, relu, leakyrelu, elu, swish, selu, softplus, softsign];
+ACTIVATION_FUNCTIONS = [σ, relu, leakyrelu, elu, gelu, swish, selu, softplus, softsign];
 
 function test_value_float_precision_preserving(a)
   @testset "$(a): " begin
@@ -42,6 +42,7 @@
   @test relu(0.0) == 0.0
   @test leakyrelu(0.0) == 0.0
   @test elu(0.0) == 0.0
+  @test gelu(0.0) == 0.0
   @test swish(0.0) == 0.0
   @test softplus(0.0) ≈ log(2.0)
   @test softsign(0.0) == 0.0
@@ -51,6 +52,7 @@
   @test relu(1.0) == 1.0
   @test leakyrelu(1.0) == 1.0
   @test elu(1.0) == 1.0
+  @test gelu(1.0) == 0.8411919906082768
   @test swish(1.0) == 1.0 / (1.0 + exp(-1.0))
   @test softplus(1.0) ≈ log(exp(1.0) + 1.0)
   @test softsign(1.0) == 0.5
@@ -60,6 +62,7 @@
   @test relu(-1.0) == 0.0
   @test leakyrelu(-1.0) == -0.01
   @test elu(-1.0) == exp(-1.0) - 1.0
+  @test gelu(-1.0) == -0.15880800939172324
   @test swish(-1.0) == -1.0 / (1.0 + exp(1.0))
   @test softplus(-1.0) ≈ log(exp(-1.0) + 1.0)
   @test softsign(-1.0) == -0.5