Merge pull request #83 from chengchingwen/master
implement gelu
MikeInnes authored Jan 28, 2019
2 parents 085adb7 + 4b4c3a3 commit 09482fb
Showing 3 changed files with 19 additions and 2 deletions.
2 changes: 1 addition & 1 deletion src/NNlib.jl
@@ -2,7 +2,7 @@ module NNlib

using Requires, Libdl

-export σ, sigmoid, relu, leakyrelu, elu, swish, selu, softplus, softsign, logσ, logsigmoid,
+export σ, sigmoid, relu, leakyrelu, elu, gelu, swish, selu, softplus, softsign, logσ, logsigmoid,
softmax, logsoftmax, maxpool, meanpool

include("numeric.jl")
14 changes: 14 additions & 0 deletions src/activation.jl
@@ -66,6 +66,20 @@ You can also specify the coefficient explicitly, e.g. `elu(x, 1)`.
"""
elu(x, α = one(x)) = ifelse(x ≥ 0, x/1, α * (exp(x) - one(x)))

"""
gelu(x) = 0.5x*(1 + tanh(√(2/π)*(x + 0.044715x^3)))
[Gaussian Error Linear Unit](https://arxiv.org/pdf/1606.08415.pdf)
activation function.
"""
function gelu(x)
λ = oftype(x/1, (2/π))
α = oftype(x/1, 0.044715)
h = oftype(x/1, 0.5)
h * x * (one(x) + tanh* (x + α * x^3)))
end


"""
swish(x) = x * σ(x)
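
An aside on the new function (not part of the commit): the oftype(x/1, …) calls convert each constant to the floating-point type of x (x/1 promotes integer input to a float), so gelu preserves the precision of its argument. A minimal usage sketch, assuming NNlib is built from this commit:

using NNlib

gelu(1.0)                # 0.8411919906082768 (Float64 in, Float64 out)
gelu(1f0)                # ≈ 0.841192f0 (Float32 is preserved)
gelu.([-1.0, 0.0, 1.0])  # broadcasts elementwise, like the other activations
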
5 changes: 4 additions & 1 deletion test/activation.jl
@@ -1,4 +1,4 @@
-ACTIVATION_FUNCTIONS = [σ, relu, leakyrelu, elu, swish, selu, softplus, softsign];
+ACTIVATION_FUNCTIONS = [σ, relu, leakyrelu, elu, gelu, swish, selu, softplus, softsign];

function test_value_float_precision_preserving(a)
@testset "$(a): " begin
@@ -42,6 +42,7 @@ end
@test relu(0.0) == 0.0
@test leakyrelu(0.0) == 0.0
@test elu(0.0) == 0.0
+@test gelu(0.0) == 0.0
@test swish(0.0) == 0.0
@test softplus(0.0) ≈ log(2.0)
@test softsign(0.0) == 0.0
@@ -51,6 +52,7 @@ end
@test relu(1.0) == 1.0
@test leakyrelu(1.0) == 1.0
@test elu(1.0) == 1.0
+@test gelu(1.0) == 0.8411919906082768
@test swish(1.0) == 1.0 / (1.0 + exp(-1.0))
@test softplus(1.0) ≈ log(exp(1.0) + 1.0)
@test softsign(1.0) == 0.5
@@ -60,6 +62,7 @@ end
@test relu(-1.0) == 0.0
@test leakyrelu(-1.0) == -0.01
@test elu(-1.0) == exp(-1.0) - 1.0
+@test gelu(-1.0) == -0.15880800939172324
@test swish(-1.0) == -1.0 / (1.0 + exp(1.0))
@test softplus(-1.0) ≈ log(exp(-1.0) + 1.0)
@test softsign(-1.0) == -0.5
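
For reference, the hard-coded expectations follow directly from the tanh formula in the docstring; a worked sketch of the arithmetic in plain Julia (not part of the commit):

λ = √(2/π)                    # ≈ 0.7978845608028654
inner = λ * (1.0 + 0.044715)  # λ*(x + 0.044715x^3) at x = 1.0
t = tanh(inner)               # ≈ 0.6823839812165536
0.5 * (1 + t)                 # 0.8411919906082768, the gelu(1.0) test value
-0.5 * (1 - t)                # -0.15880800939172324, the gelu(-1.0) test value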
