Skip to content

Commit

Permalink
Add SpecialFunctions extension (#82)
Browse files Browse the repository at this point in the history
  • Loading branch information
gdalle authored May 22, 2024
1 parent 2ea9b06 commit 57def1e
Show file tree
Hide file tree
Showing 13 changed files with 493 additions and 302 deletions.
10 changes: 10 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,21 @@ ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"

[weakdeps]
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"

[extensions]
SparseConnectivityTracerSpecialFunctionsExt = "SpecialFunctions"

[compat]
ADTypes = "1"
Compat = "3,4"
DocStringExtensions = "0.9"
Requires = "1.3"
SparseArrays = "1"
SpecialFunctions = "2.4"
julia = "1.6"
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
module SparseConnectivityTracerSpecialFunctionsExt

if isdefined(Base, :get_extension)
import SparseConnectivityTracer as SCT
using SpecialFunctions
else
import ..SparseConnectivityTracer as SCT
using ..SpecialFunctions
end

#=
Complex functions are ignored.
Functions with more than 2 arguments are ignored.
Functions with integer arguments are ignored.
=#

## 1-to-1

ops_1_to_1_s = (
# Gamma Function
gamma,
loggamma,
digamma,
invdigamma,
trigamma,
# Exponential and Trigonometric Integrals
expinti,
sinint,
cosint,
# Error functions, Dawson's and Fresnel Integrals
erf,
erfc,
erfcinv,
erfcx,
logerfc,
erfinv,
# Airy and Related Functions
airyai,
airyaiprime,
airybi,
airybiprime,
airyaix,
airyaiprimex,
airybix,
airybiprimex,
# Bessel Functions
besselj0,
besselj1,
bessely0,
bessely1,
jinc,
# Elliptic Integrals
ellipk,
ellipe,
)

for op in ops_1_to_1_s
T = typeof(op)
@eval SCT.is_influence_zero_global(::$T) = false
@eval SCT.is_firstder_zero_global(::$T) = false
@eval SCT.is_seconder_zero_global(::$T) = false
end

ops_1_to_1 = ops_1_to_1_s

## 2-to-1

ops_2_to_1_ssc = (
# Gamma Function
gamma,
loggamma,
beta,
logbeta,
# Exponential and Trigonometric Integrals
expint,
expintx,
# Error functions, Dawson's and Fresnel Integrals
erf,
# Bessel Functions
besselj,
besseljx,
sphericalbesselj,
bessely,
besselyx,
sphericalbessely,
besseli,
besselix,
besselk,
besselkx,
)

for op in ops_2_to_1_ssc
T = typeof(op)
@eval SCT.is_influence_arg1_zero_global(::$T) = false
@eval SCT.is_influence_arg2_zero_global(::$T) = false
@eval SCT.is_firstder_arg1_zero_global(::$T) = false
@eval SCT.is_seconder_arg1_zero_global(::$T) = false
@eval SCT.is_firstder_arg2_zero_global(::$T) = false
@eval SCT.is_seconder_arg2_zero_global(::$T) = false
@eval SCT.is_crossder_zero_global(::$T) = false
end

ops_2_to_1 = ops_2_to_1_ssc

## Lists

SCT.list_operators_1_to_1(::Val{:SpecialFunctions}) = ops_1_to_1
SCT.list_operators_2_to_1(::Val{:SpecialFunctions}) = ops_2_to_1
SCT.list_operators_1_to_2(::Val{:SpecialFunctions}) = ()

## Overloads

eval(SCT.overload_all(:SpecialFunctions))

end
17 changes: 17 additions & 0 deletions src/SparseConnectivityTracer.jl
Original file line number Diff line number Diff line change
@@ -1,12 +1,18 @@
module SparseConnectivityTracer

const SCT = SparseConnectivityTracer

using ADTypes: ADTypes
using Compat: Returns
import SparseArrays: sparse
import Random: rand, AbstractRNG, SamplerType

using DocStringExtensions

if !isdefined(Base, :get_extension)
using Requires
end

include("settypes/duplicatevector.jl")
include("settypes/recursiveset.jl")
include("settypes/sortedvector.jl")
Expand All @@ -15,10 +21,13 @@ include("tracers.jl")
include("exceptions.jl")
include("conversion.jl")
include("operators.jl")

include("overload_connectivity.jl")
include("overload_gradient.jl")
include("overload_hessian.jl")
include("overload_dual.jl")
include("overload_all.jl")

include("pattern.jl")
include("adtypes.jl")

Expand All @@ -30,4 +39,12 @@ export hessian_pattern, local_hessian_pattern
export TracerSparsityDetector
export TracerLocalSparsityDetector

function __init__()
@static if !isdefined(Base, :get_extension)
@require SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" include(
"../ext/SparseConnectivityTracerSpecialFunctionsExt/SparseConnectivityTracerSpecialFunctionsExt.jl",
)
end
end

end # module
4 changes: 4 additions & 0 deletions src/operators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -641,3 +641,7 @@ ops_1_to_2 = union(
ops_1_to_2_zz,
)
#! format: on

list_operators_1_to_1(::Val{:Base}) = ops_1_to_1
list_operators_2_to_1(::Val{:Base}) = ops_2_to_1
list_operators_1_to_2(::Val{:Base}) = ops_1_to_2
30 changes: 30 additions & 0 deletions src/overload_all.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
function overload_all(M)
exprs_1_to_1 = [
quote
$(overload_connectivity_1_to_1(M, op))
$(overload_gradient_1_to_1(M, op))
$(overload_hessian_1_to_1(M, op))
end for op in nameof.(list_operators_1_to_1(Val(M)))
]
exprs_2_to_1 = [
quote
$(overload_connectivity_2_to_1(M, op))
$(overload_gradient_2_to_1(M, op))
$(overload_hessian_2_to_1(M, op))
end for op in nameof.(list_operators_2_to_1(Val(M)))
]
exprs_1_to_2 = [
quote
$(overload_connectivity_1_to_2(M, op))
$(overload_gradient_1_to_2(M, op))
$(overload_hessian_1_to_2(M, op))
end for op in nameof.(list_operators_1_to_2(Val(M)))
]
return quote
$(exprs_1_to_1...)
$(exprs_2_to_1...)
$(exprs_1_to_2...)
end
end

eval(overload_all(:Base))
Loading

0 comments on commit 57def1e

Please sign in to comment.