Skip to content

Commit

Permalink
Add ForwardDiff package extension (#200)
Browse files Browse the repository at this point in the history
  • Loading branch information
adrhill authored Oct 1, 2024
1 parent f70f532 commit bb7d365
Show file tree
Hide file tree
Showing 7 changed files with 54 additions and 12 deletions.
12 changes: 8 additions & 4 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,27 +14,30 @@ SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"

[weakdeps]
DataInterpolations = "82cc6244-b520-54b8-b5a6-8a565e85f1d0"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688"
NaNMath = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3"
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
NaNMath = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3"
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"

[extensions]
SparseConnectivityTracerDataInterpolationsExt = "DataInterpolations"
SparseConnectivityTracerForwardDiffExt = "ForwardDiff"
SparseConnectivityTracerLogExpFunctionsExt = "LogExpFunctions"
SparseConnectivityTracerNaNMathExt = "NaNMath"
SparseConnectivityTracerNNlibExt = "NNlib"
SparseConnectivityTracerNaNMathExt = "NaNMath"
SparseConnectivityTracerSpecialFunctionsExt = "SpecialFunctions"

[compat]
ADTypes = "1"
DataInterpolations = "6.2"
DocStringExtensions = "0.9"
FillArrays = "1"
ForwardDiff = "0.10"
LinearAlgebra = "<0.0.1, 1"
LogExpFunctions = "0.3.28"
NaNMath = "1"
NNlib = "0.8, 0.9"
NaNMath = "1"
Random = "<0.0.1, 1"
Requires = "1.3"
SparseArrays = "<0.0.1, 1"
Expand All @@ -43,7 +46,8 @@ julia = "1.6"

[extras]
DataInterpolations = "82cc6244-b520-54b8-b5a6-8a565e85f1d0"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688"
NaNMath = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3"
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
NaNMath = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3"
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
14 changes: 14 additions & 0 deletions ext/SparseConnectivityTracerForwardDiffExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
module SparseConnectivityTracerForwardDiffExt

if isdefined(Base, :get_extension)
import SparseConnectivityTracer as SCT
using ForwardDiff: ForwardDiff
else
import ..SparseConnectivityTracer as SCT
using ..ForwardDiff: ForwardDiff
end

# Overload 2-to-1 functions on ForwardDiff.Dual
eval(SCT.generate_code_2_to_1_typed(:Base, SCT.ops_2_to_1, ForwardDiff.Dual))

end # module
13 changes: 8 additions & 5 deletions src/SparseConnectivityTracer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -44,18 +44,21 @@ export jacobian_sparsity, hessian_sparsity

function __init__()
@static if !isdefined(Base, :get_extension)
@require SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" include(
"../ext/SparseConnectivityTracerSpecialFunctionsExt.jl"
)
@require NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" include(
"../ext/SparseConnectivityTracerNNlibExt.jl"
@require ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" include(
"../ext/SparseConnectivityTracerForwardDiffExt.jl"
)
@require LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688" include(
"../ext/SparseConnectivityTracerLogExpFunctionsExt.jl"
)
@require NaNMath = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3" include(
"../ext/SparseConnectivityTracerNaNMathExt.jl"
)
@require NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" include(
"../ext/SparseConnectivityTracerNNlibExt.jl"
)
@require SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" include(
"../ext/SparseConnectivityTracerSpecialFunctionsExt.jl"
)
# NOTE: SparseConnectivityTracerDataInterpolationsExt is not loaded on Julia <1.10
end
end
Expand Down
4 changes: 2 additions & 2 deletions src/overloads/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ for d in dims
g = Symbol("generate_code_gradient_", d)
h = Symbol("generate_code_hessian_", d)

@eval function $f(M::Symbol, f)
@eval function $f(M::Symbol, f::Function)
expr_g = $g(M, f)
expr_h = $h(M, f)
return Expr(:block, expr_g, expr_h)
Expand All @@ -28,7 +28,7 @@ for d in dims
end

# Overloads of 2-argument functions on arbitrary types
function generate_code_2_to_1_typed(M::Symbol, f, Z::Type)
function generate_code_2_to_1_typed(M::Symbol, f::Function, Z::Type)
expr_g = generate_code_gradient_2_to_1_typed(M, f, Z)
expr_h = generate_code_hessian_2_to_1_typed(M, f, Z)
return Expr(:block, expr_g, expr_h)
Expand Down
20 changes: 20 additions & 0 deletions test/ext/test_ForwardDiff.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
using SparseConnectivityTracer
using ForwardDiff: ForwardDiff

using Test

d = ForwardDiff.Dual{ForwardDiff.Tag{*,Float64}}(1.2, 3.4)
@testset "$D" for D in (TracerSparsityDetector, TracerLocalSparsityDetector)
detector = D()
# Testing on multiplication ensures that methods from Base have been overloaded,
# Since this would otherwise throw an ambiguity error:
# https://github.com/adrhill/SparseConnectivityTracer.jl/issues/196
@testset "Jacobian" begin
@test jacobian_sparsity(x -> x * d, 1.0, detector) [1;;]
@test jacobian_sparsity(x -> d * x, 1.0, detector) [1;;]
end
@testset "Hessian" begin
@test hessian_sparsity(x -> x * d, 1.0, detector) [0;;]
@test hessian_sparsity(x -> d * x, 1.0, detector) [0;;]
end
end
1 change: 1 addition & 0 deletions test/linting.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ using JET: JET
using ExplicitImports: ExplicitImports

# Load package extensions so they get tested by ExplicitImports.jl
using ForwardDiff: ForwardDiff
using DataInterpolations: DataInterpolations
using NaNMath: NaNMath
using NNlib: NNlib
Expand Down
2 changes: 1 addition & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ GROUP = get(ENV, "JULIA_SCT_TEST_GROUP", "Core")
if GROUP in ("Core", "All")
@info "Testing package extensions..."
@testset verbose = true "Package extensions" begin
for ext in (:NNlib, :SpecialFunctions, :LogExpFunctions, :NaNMath)
for ext in (:ForwardDiff, :LogExpFunctions, :NaNMath, :NNlib, :SpecialFunctions)
@testset "$ext" begin
@info "...$ext"
include("ext/test_$ext.jl")
Expand Down

0 comments on commit bb7d365

Please sign in to comment.