Skip to content

Commit

Permalink
Add ForwardDiff Extension: Dropout
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Feb 22, 2023
1 parent 6650103 commit 3639971
Show file tree
Hide file tree
Showing 7 changed files with 52 additions and 2 deletions.
11 changes: 10 additions & 1 deletion lib/LuxLib/Project.toml
Original file line number Diff line number Diff line change
@@ -1,23 +1,32 @@
name = "LuxLib"
uuid = "82251201-b29d-42c6-8e01-566dec8acb11"
authors = ["Avik Pal <[email protected]> and contributors"]
version = "0.1.8"
version = "0.1.9"

[deps]
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
CUDAKernels = "72cfdca4-0801-4ab0-bf6a-d52aa10adc57"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
Markdown = "d6f4376e-aef5-505a-96c1-9c027394607a"
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
NNlibCUDA = "a00861dc-f156-4864-bf3c-e6376f28a68d"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"

[weakdeps]
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"

[extensions]
LuxLibForwardDiffExt = "ForwardDiff"

[compat]
CUDA = "3, 4"
CUDAKernels = "0.3, 0.4"
ChainRulesCore = "1"
ForwardDiff = "0.10"
KernelAbstractions = "0.7, 0.8"
NNlib = "0.8"
NNlibCUDA = "0.2"
Expand Down
10 changes: 10 additions & 0 deletions lib/LuxLib/ext/LuxLibForwardDiffExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
module LuxLibForwardDiffExt

isdefined(Base, :get_extension) ? (using ForwardDiff) : (using ..ForwardDiff)
using LuxLib

function LuxLib._dropout_fptype(x::AbstractArray{<:ForwardDiff.Dual})
return ForwardDiff.valtype(eltype(x))
end

end
12 changes: 12 additions & 0 deletions lib/LuxLib/src/LuxLib.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,18 @@ using ChainRulesCore, CUDA, CUDAKernels, KernelAbstractions, Markdown, NNlib, NN
Random, Statistics
import ChainRulesCore as CRC

# Extensions
if !isdefined(Base, :get_extension)
using Requires
end

function __init__()
@static if !isdefined(Base, :get_extension)
# Handling ForwardDiff
@require ForwardDiff="f6369f11-7733-5829-9624-2563aa707210" begin include("../ext/LuxLibForwardDiffExt.jl") end
end
end

include("utils.jl")

include("deprecated.jl")
Expand Down
4 changes: 3 additions & 1 deletion lib/LuxLib/src/api/dropout.jl
Original file line number Diff line number Diff line change
Expand Up @@ -120,8 +120,10 @@ end

@inline _dropout_kernel(y, p, invp) = y > p ? invp : oftype(y, 0)

@inline _dropout_fptype(x) = float(real(eltype(x)))

@inline function _generate_dropout_mask(rng::AbstractRNG, x, p, invp; dims)
realfptype = float(real(eltype(x)))
realfptype = _dropout_fptype(x)
y = rand!(rng, similar(x, realfptype, _dropout_shape(x, dims)))
y .= _dropout_kernel.(y, p, invp)
return y
Expand Down
2 changes: 2 additions & 0 deletions lib/LuxLib/test/Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
[deps]
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
Expand All @@ -11,6 +12,7 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
[compat]
CUDA = "3, 4"
FiniteDifferences = "0.12"
ForwardDiff = "0.10"
JET = "0.4, 0.5, 0.6, 0.7"
SafeTestsets = "0.0.1"
Zygote = "0.6"
Expand Down
13 changes: 13 additions & 0 deletions lib/LuxLib/test/ext/LuxLibForwardDiffExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
using LuxLib, ForwardDiff, Random, Test

rng = MersenneTwister(0)

x = randn(rng, Float32, 10, 2)
x_dual = ForwardDiff.Dual.(x)

@test_nowarn dropout(rng, x_dual, 0.5f0, Val(true); dims=:)

x_dropout = dropout(rng, x, 0.5f0, Val(true); dims=:)[1]
x_dual_dropout = ForwardDiff.value.(dropout(rng, x_dual, 0.5f0, Val(true); dims=:)[1])

@test isapprox(x_dropout, x_dual_dropout)
2 changes: 2 additions & 0 deletions lib/LuxLib/test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,6 @@ using SafeTestsets, Test
@time @safetestset "GroupNorm" begin include("api/groupnorm.jl") end
@time @safetestset "InstanceNorm" begin include("api/instancenorm.jl") end
@time @safetestset "LayerNorm" begin include("api/layernorm.jl") end

@time @safetestset "ForwardDiff Extension" begin include("ext/LuxLibForwardDiffExt.jl") end
end

0 comments on commit 3639971

Please sign in to comment.