Skip to content

Commit

Permalink
Merge pull request #94 from FluxML/sf/overhaul
Browse files Browse the repository at this point in the history
Major overhaul of NNlib
  • Loading branch information
staticfloat authored Mar 28, 2019
2 parents 11f840d + 936e71a commit 60b2b92
Show file tree
Hide file tree
Showing 35 changed files with 3,726 additions and 1,767 deletions.
3 changes: 3 additions & 0 deletions .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@ notifications:
email: false
git:
depth: 99999999
env:
# Disable test fuzzing for the moment, as we're a little too slow for Travis
- NNLIB_TEST_FUZZING=false

# Submit to Codecov
after_success:
Expand Down
59 changes: 14 additions & 45 deletions Manifest.toml
Original file line number Diff line number Diff line change
@@ -1,31 +1,22 @@
# This file is machine-generated - editing it directly is not advised

[[Base64]]
uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f"

[[Compat]]
deps = ["Base64", "Dates", "DelimitedFiles", "Distributed", "InteractiveUtils", "LibGit2", "Libdl", "LinearAlgebra", "Markdown", "Mmap", "Pkg", "Printf", "REPL", "Random", "Serialization", "SharedArrays", "Sockets", "SparseArrays", "Statistics", "Test", "UUIDs", "Unicode"]
git-tree-sha1 = "ff2595695fc4f14427358ce2593f867085c45dcb"
uuid = "34da2185-b29b-5c13-b0c7-acf172513d20"
version = "1.2.0"

[[Dates]]
deps = ["Printf"]
uuid = "ade2ca70-3891-5945-98fb-dc099432e06a"

[[DelimitedFiles]]
deps = ["Mmap"]
uuid = "8bb1440f-4735-579b-a4ab-409b98df4dab"
[[Crayons]]
deps = ["Test"]
git-tree-sha1 = "3017c662a988bcb8a3f43306a793617c6524d476"
uuid = "a8cc5b0e-0ffa-5ad4-8c14-923d3ee1735f"
version = "1.0.0"

[[Distributed]]
deps = ["LinearAlgebra", "Random", "Serialization", "Sockets"]
deps = ["Random", "Serialization", "Sockets"]
uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b"

[[InteractiveUtils]]
deps = ["LinearAlgebra", "Markdown"]
deps = ["Markdown"]
uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240"

[[LibGit2]]
uuid = "76f85450-5226-5b5a-8eaa-529ad045b433"

[[Libdl]]
uuid = "8f399da3-3557-5675-b5ff-fb832c97cbdb"

Expand All @@ -36,31 +27,14 @@ uuid = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
[[Logging]]
uuid = "56ddb016-857b-54e1-b83d-db4d58db5568"

[[MacroTools]]
deps = ["Compat"]
git-tree-sha1 = "c443e1c8d58a4e9f61b708ad0a88286c7042145b"
uuid = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
version = "0.4.4"

[[Markdown]]
deps = ["Base64"]
uuid = "d6f4376e-aef5-505a-96c1-9c027394607a"

[[Mmap]]
uuid = "a63ad114-7e13-5084-954f-fe012c677804"

[[Pkg]]
deps = ["Dates", "LibGit2", "Markdown", "Printf", "REPL", "Random", "SHA", "UUIDs"]
uuid = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"

[[Printf]]
deps = ["Unicode"]
uuid = "de0858da-6303-5e67-8744-51eddeeeb8d7"

[[REPL]]
deps = ["InteractiveUtils", "Markdown", "Sockets"]
uuid = "3fa0cd96-eef1-5676-8a61-b3b8758bbffb"

[[Random]]
deps = ["Serialization"]
uuid = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Expand All @@ -71,16 +45,9 @@ git-tree-sha1 = "f6fbf4ba64d295e146e49e021207993b6b48c7d1"
uuid = "ae029012-a4dd-5104-9daa-d747884805df"
version = "0.5.2"

[[SHA]]
uuid = "ea8e919c-243c-51af-8825-aaa63cd721ce"

[[Serialization]]
uuid = "9e88b42a-f829-5b0c-bbe9-9e923198166b"

[[SharedArrays]]
deps = ["Distributed", "Mmap", "Random", "Serialization"]
uuid = "1a1011a3-84de-559e-8e89-a11a2f7dc383"

[[Sockets]]
uuid = "6462fe0b-24de-5631-8697-dd941f90decc"

Expand All @@ -96,9 +63,11 @@ uuid = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
deps = ["Distributed", "InteractiveUtils", "Logging", "Random"]
uuid = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[[UUIDs]]
deps = ["Random"]
uuid = "cf7118a7-6976-5b1a-9a39-7adc72f591a4"
[[TimerOutputs]]
deps = ["Crayons", "Printf", "Test", "Unicode"]
git-tree-sha1 = "b80671c06f8f8bae08c55d67b5ce292c5ae2660c"
uuid = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f"
version = "0.5.0"

[[Unicode]]
uuid = "4ec0a83e-493e-50e2-b9ac-8f72acf5a8f5"
3 changes: 2 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ uuid = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
[deps]
Libdl = "8f399da3-3557-5675-b5ff-fb832c97cbdb"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
TimerOutputs = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f"
3 changes: 1 addition & 2 deletions REQUIRE
Original file line number Diff line number Diff line change
@@ -1,3 +1,2 @@
julia 0.7-
julia 1.0
Requires
MacroTools
31 changes: 21 additions & 10 deletions src/NNlib.jl
Original file line number Diff line number Diff line change
@@ -1,16 +1,27 @@
module NNlib
using Requires, TimerOutputs

using Requires, Libdl

export σ, sigmoid, relu, leakyrelu, elu, gelu, swish, selu, softplus, softsign, logσ, logsigmoid,
softmax, logsoftmax, maxpool, meanpool

include("numeric.jl")
# Include APIs
include("dim_helpers.jl")
include("activation.jl")
include("softmax.jl")
include("logsoftmax.jl")
include("linalg.jl")
include("gemm.jl")
include("conv.jl")
include("cubroadcast.jl")
include("pooling.jl")

## Include implementations
include("impl/padding_edges.jl")

# Direct implementations of convolutional and depthwise-convolutional algorithms
include("impl/conv_direct.jl")
include("impl/depthwiseconv_direct.jl")
# im2col implementations of convolutional and depthwise-convolutional algorithms
include("impl/conv_im2col.jl")
include("impl/depthwiseconv_im2col.jl")

# Direct implementations of pooling
include("impl/pooling_direct.jl")

to = TimerOutput()

end # module
end # module NNlib
18 changes: 11 additions & 7 deletions src/activation.jl
Original file line number Diff line number Diff line change
@@ -1,22 +1,23 @@
export σ, sigmoid, relu, leakyrelu, elu, gelu, swish, selu, softplus, softsign, logσ,
logsigmoid

"""
σ(x) = 1 / (1 + exp(-x))
Classic [sigmoid](https://en.wikipedia.org/wiki/Sigmoid_function) activation
function.
"""
σ(x) = one(x) / (one(x) + exp(-x))

const sigmoid = σ

# ForwardDiff numerical stability hack
σ_stable(x) = ifelse(x < -80, zero(x), one(x) / (one(x) + exp(-x)))

σ(x::Float32) = σ_stable(x)

@init @require ForwardDiff="f6369f11-7733-5829-9624-2563aa707210" begin
σ(x::ForwardDiff.Dual{T,Float32}) where T = σ_stable(x)
end


"""
logσ(x)
Expand All @@ -31,13 +32,13 @@ Return `log(σ(x))` which is computed in a numerically stable way.
-0.0
"""
function logσ(x)
max_v = max(zero(x), -x)
z = exp(-max_v) + exp(-x-max_v)
-(max_v + log(z))
max_v = max(zero(x), -x)
z = exp(-max_v) + exp(-x-max_v)
return -(max_v + log(z))
end

const logsigmoid = logσ


"""
relu(x) = max(0, x)
Expand All @@ -56,6 +57,7 @@ You can also specify the coefficient explicitly, e.g. `leakyrelu(x, 0.01)`.
"""
leakyrelu(x, a = oftype(x/1, 0.01)) = max(a*x, x/1)


"""
elu(x, α = 1) =
x > 0 ? x : α * (exp(x) - 1)
Expand All @@ -66,6 +68,7 @@ You can also specify the coefficient explicitly, e.g. `elu(x, 1)`.
"""
elu(x, α = one(x)) = ifelse(x 0, x/1, α * (exp(x) - one(x)))


"""
gelu(x) = 0.5x*(1 + tanh(√(2/π)*(x + 0.044715x^3)))
Expand Down Expand Up @@ -103,6 +106,7 @@ function selu(x)
λ * ifelse(x > 0, x/1, α * (exp(x) - 1))
end


"""
softsign(x) = x / (1 + |x|)
Expand Down
Loading

0 comments on commit 60b2b92

Please sign in to comment.