Skip to content

Commit

Permalink
implment gelu
Browse files Browse the repository at this point in the history
test for gelu

fix indent
  • Loading branch information
chengchingwen committed Jan 14, 2019
1 parent 085adb7 commit 4b4c3a3
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 2 deletions.
2 changes: 1 addition & 1 deletion src/NNlib.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ module NNlib

using Requires, Libdl

export σ, sigmoid, relu, leakyrelu, elu, swish, selu, softplus, softsign, logσ, logsigmoid,
export σ, sigmoid, relu, leakyrelu, elu, gelu, swish, selu, softplus, softsign, logσ, logsigmoid,
softmax, logsoftmax, maxpool, meanpool

include("numeric.jl")
Expand Down
14 changes: 14 additions & 0 deletions src/activation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,20 @@ You can also specify the coefficient explicitly, e.g. `elu(x, 1)`.
"""
elu(x, α = one(x)) = ifelse(x 0, x/1, α * (exp(x) - one(x)))

"""
gelu(x) = 0.5x*(1 + tanh(√(2/π)*(x + 0.044715x^3)))
[Gaussian Error Linear Unit](https://arxiv.org/pdf/1606.08415.pdf)
activation function.
"""
function gelu(x)
λ = oftype(x/1, (2/π))
α = oftype(x/1, 0.044715)
h = oftype(x/1, 0.5)
h * x * (one(x) + tanh* (x + α * x^3)))
end


"""
swish(x) = x * σ(x)
Expand Down
5 changes: 4 additions & 1 deletion test/activation.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
ACTIVATION_FUNCTIONS = [σ, relu, leakyrelu, elu, swish, selu, softplus, softsign];
ACTIVATION_FUNCTIONS = [σ, relu, leakyrelu, elu, gelu, swish, selu, softplus, softsign];

function test_value_float_precision_preserving(a)
@testset "$(a): " begin
Expand Down Expand Up @@ -42,6 +42,7 @@ end
@test relu(0.0) == 0.0
@test leakyrelu(0.0) == 0.0
@test elu(0.0) == 0.0
@test gelu(0.0) == 0.0
@test swish(0.0) == 0.0
@test softplus(0.0) log(2.0)
@test softsign(0.0) == 0.0
Expand All @@ -51,6 +52,7 @@ end
@test relu(1.0) == 1.0
@test leakyrelu(1.0) == 1.0
@test elu(1.0) == 1.0
@test gelu(1.0) == 0.8411919906082768
@test swish(1.0) == 1.0 / (1.0 + exp(-1.0))
@test softplus(1.0) log(exp(1.0) + 1.0)
@test softsign(1.0) == 0.5
Expand All @@ -60,6 +62,7 @@ end
@test relu(-1.0) == 0.0
@test leakyrelu(-1.0) == -0.01
@test elu(-1.0) == exp(-1.0) - 1.0
@test gelu(-1.0) == -0.15880800939172324
@test swish(-1.0) == -1.0 / (1.0 + exp(1.0))
@test softplus(-1.0) log(exp(-1.0) + 1.0)
@test softsign(-1.0) == -0.5
Expand Down

0 comments on commit 4b4c3a3

Please sign in to comment.