diff --git a/Project.toml b/Project.toml index 2bd9c0c4..c256935f 100644 --- a/Project.toml +++ b/Project.toml @@ -8,11 +8,21 @@ ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" Compat = "34da2185-b29b-5c13-b0c7-acf172513d20" DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +Requires = "ae029012-a4dd-5104-9daa-d747884805df" +SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" +[weakdeps] +SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" + +[extensions] +SparseConnectivityTracerSpecialFunctionsExt = "SpecialFunctions" + [compat] ADTypes = "1" Compat = "3,4" DocStringExtensions = "0.9" +Requires = "1.3" SparseArrays = "1" +SpecialFunctions = "2.4" julia = "1.6" diff --git a/ext/SparseConnectivityTracerSpecialFunctionsExt/SparseConnectivityTracerSpecialFunctionsExt.jl b/ext/SparseConnectivityTracerSpecialFunctionsExt/SparseConnectivityTracerSpecialFunctionsExt.jl new file mode 100644 index 00000000..b220b013 --- /dev/null +++ b/ext/SparseConnectivityTracerSpecialFunctionsExt/SparseConnectivityTracerSpecialFunctionsExt.jl @@ -0,0 +1,115 @@ +module SparseConnectivityTracerSpecialFunctionsExt + +if isdefined(Base, :get_extension) + import SparseConnectivityTracer as SCT + using SpecialFunctions +else + import ..SparseConnectivityTracer as SCT + using ..SpecialFunctions +end + +#= +Complex functions are ignored. +Functions with more than 2 arguments are ignored. +Functions with integer arguments are ignored. +=# + +## 1-to-1 + +ops_1_to_1_s = ( + # Gamma Function + gamma, + loggamma, + digamma, + invdigamma, + trigamma, + # Exponential and Trigonometric Integrals + expinti, + sinint, + cosint, + # Error functions, Dawson's and Fresnel Integrals + erf, + erfc, + erfcinv, + erfcx, + logerfc, + erfinv, + # Airy and Related Functions + airyai, + airyaiprime, + airybi, + airybiprime, + airyaix, + airyaiprimex, + airybix, + airybiprimex, + # Bessel Functions + besselj0, + besselj1, + bessely0, + bessely1, + jinc, + # Elliptic Integrals + ellipk, + ellipe, +) + +for op in ops_1_to_1_s + T = typeof(op) + @eval SCT.is_influence_zero_global(::$T) = false + @eval SCT.is_firstder_zero_global(::$T) = false + @eval SCT.is_seconder_zero_global(::$T) = false +end + +ops_1_to_1 = ops_1_to_1_s + +## 2-to-1 + +ops_2_to_1_ssc = ( + # Gamma Function + gamma, + loggamma, + beta, + logbeta, + # Exponential and Trigonometric Integrals + expint, + expintx, + # Error functions, Dawson's and Fresnel Integrals + erf, + # Bessel Functions + besselj, + besseljx, + sphericalbesselj, + bessely, + besselyx, + sphericalbessely, + besseli, + besselix, + besselk, + besselkx, +) + +for op in ops_2_to_1_ssc + T = typeof(op) + @eval SCT.is_influence_arg1_zero_global(::$T) = false + @eval SCT.is_influence_arg2_zero_global(::$T) = false + @eval SCT.is_firstder_arg1_zero_global(::$T) = false + @eval SCT.is_seconder_arg1_zero_global(::$T) = false + @eval SCT.is_firstder_arg2_zero_global(::$T) = false + @eval SCT.is_seconder_arg2_zero_global(::$T) = false + @eval SCT.is_crossder_zero_global(::$T) = false +end + +ops_2_to_1 = ops_2_to_1_ssc + +## Lists + +SCT.list_operators_1_to_1(::Val{:SpecialFunctions}) = ops_1_to_1 +SCT.list_operators_2_to_1(::Val{:SpecialFunctions}) = ops_2_to_1 +SCT.list_operators_1_to_2(::Val{:SpecialFunctions}) = () + +## Overloads + +eval(SCT.overload_all(:SpecialFunctions)) + +end diff --git a/src/SparseConnectivityTracer.jl b/src/SparseConnectivityTracer.jl index eef8a5a7..19fa9d8a 100644 --- a/src/SparseConnectivityTracer.jl +++ b/src/SparseConnectivityTracer.jl @@ -1,5 +1,7 @@ module SparseConnectivityTracer +const SCT = SparseConnectivityTracer + using ADTypes: ADTypes using Compat: Returns import SparseArrays: sparse @@ -7,6 +9,10 @@ import Random: rand, AbstractRNG, SamplerType using DocStringExtensions +if !isdefined(Base, :get_extension) + using Requires +end + include("settypes/duplicatevector.jl") include("settypes/recursiveset.jl") include("settypes/sortedvector.jl") @@ -15,10 +21,13 @@ include("tracers.jl") include("exceptions.jl") include("conversion.jl") include("operators.jl") + include("overload_connectivity.jl") include("overload_gradient.jl") include("overload_hessian.jl") include("overload_dual.jl") +include("overload_all.jl") + include("pattern.jl") include("adtypes.jl") @@ -30,4 +39,12 @@ export hessian_pattern, local_hessian_pattern export TracerSparsityDetector export TracerLocalSparsityDetector +function __init__() + @static if !isdefined(Base, :get_extension) + @require SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" include( + "../ext/SparseConnectivityTracerSpecialFunctionsExt/SparseConnectivityTracerSpecialFunctionsExt.jl", + ) + end +end + end # module diff --git a/src/operators.jl b/src/operators.jl index 50b2e56c..33c950d7 100644 --- a/src/operators.jl +++ b/src/operators.jl @@ -641,3 +641,7 @@ ops_1_to_2 = union( ops_1_to_2_zz, ) #! format: on + +list_operators_1_to_1(::Val{:Base}) = ops_1_to_1 +list_operators_2_to_1(::Val{:Base}) = ops_2_to_1 +list_operators_1_to_2(::Val{:Base}) = ops_1_to_2 diff --git a/src/overload_all.jl b/src/overload_all.jl new file mode 100644 index 00000000..545dc358 --- /dev/null +++ b/src/overload_all.jl @@ -0,0 +1,30 @@ +function overload_all(M) + exprs_1_to_1 = [ + quote + $(overload_connectivity_1_to_1(M, op)) + $(overload_gradient_1_to_1(M, op)) + $(overload_hessian_1_to_1(M, op)) + end for op in nameof.(list_operators_1_to_1(Val(M))) + ] + exprs_2_to_1 = [ + quote + $(overload_connectivity_2_to_1(M, op)) + $(overload_gradient_2_to_1(M, op)) + $(overload_hessian_2_to_1(M, op)) + end for op in nameof.(list_operators_2_to_1(Val(M))) + ] + exprs_1_to_2 = [ + quote + $(overload_connectivity_1_to_2(M, op)) + $(overload_gradient_1_to_2(M, op)) + $(overload_hessian_1_to_2(M, op)) + end for op in nameof.(list_operators_1_to_2(Val(M))) + ] + return quote + $(exprs_1_to_1...) + $(exprs_2_to_1...) + $(exprs_1_to_2...) + end +end + +eval(overload_all(:Base)) diff --git a/src/overload_connectivity.jl b/src/overload_connectivity.jl index f902cca9..289a248a 100644 --- a/src/overload_connectivity.jl +++ b/src/overload_connectivity.jl @@ -10,16 +10,19 @@ function connectivity_tracer_1_to_1( end end -function overload_connectivity_1_to_1(m::Module, fn::Function) - ms, fns = nameof(m), nameof(fn) - @eval function $ms.$fns(t::T) where {T<:ConnectivityTracer} - return connectivity_tracer_1_to_1(t, is_influence_zero_global($ms.$fns)) - end - @eval function $ms.$fns(d::D) where {P,T<:ConnectivityTracer,D<:Dual{P,T}} - x = primal(d) - p_out = $ms.$fns(x) - t_out = connectivity_tracer_1_to_1(tracer(d), is_influence_zero_local($ms.$fns, x)) - return Dual(p_out, t_out) +function overload_connectivity_1_to_1(M, op) + return quote + function $M.$op(t::T) where {T<:$SCT.ConnectivityTracer} + return $SCT.connectivity_tracer_1_to_1(t, $SCT.is_influence_zero_global($M.$op)) + end + function $M.$op(d::D) where {P,T<:$SCT.ConnectivityTracer,D<:$SCT.Dual{P,T}} + x = $SCT.primal(d) + p_out = $M.$op(x) + t_out = $SCT.connectivity_tracer_1_to_1( + $SCT.tracer(d), $SCT.is_influence_zero_local($M.$op, x) + ) + return $SCT.Dual(p_out, t_out) + end end end @@ -43,51 +46,60 @@ function connectivity_tracer_2_to_1( end end -function overload_connectivity_2_to_1(m::Module, fn::Function) - ms, fns = nameof(m), nameof(fn) - @eval function $ms.$fns(tx::T, ty::T) where {T<:ConnectivityTracer} - return connectivity_tracer_2_to_1( - tx, - ty, - is_influence_arg1_zero_global($ms.$fns), - is_influence_arg2_zero_global($ms.$fns), - ) - end - @eval function $ms.$fns(dx::D, dy::D) where {P,T<:ConnectivityTracer,D<:Dual{P,T}} - x = primal(dx) - y = primal(dy) - p_out = $ms.$fns(x, y) - t_out = connectivity_tracer_2_to_1( - tracer(dx), - tracer(dy), - is_influence_arg1_zero_local($ms.$fns, x, y), - is_influence_arg2_zero_local($ms.$fns, x, y), - ) - return Dual(p_out, t_out) - end +function overload_connectivity_2_to_1(M, op) + return quote + function $M.$op(tx::T, ty::T) where {T<:$SCT.ConnectivityTracer} + return $SCT.connectivity_tracer_2_to_1( + tx, + ty, + $SCT.is_influence_arg1_zero_global($M.$op), + $SCT.is_influence_arg2_zero_global($M.$op), + ) + end + function $M.$op(dx::D, dy::D) where {P,T<:$SCT.ConnectivityTracer,D<:$SCT.Dual{P,T}} + x = $SCT.primal(dx) + y = $SCT.primal(dy) + p_out = $M.$op(x, y) + t_out = $SCT.connectivity_tracer_2_to_1( + $SCT.tracer(dx), + $SCT.tracer(dy), + $SCT.is_influence_arg1_zero_local($M.$op, x, y), + $SCT.is_influence_arg2_zero_local($M.$op, x, y), + ) + return $SCT.Dual(p_out, t_out) + end - @eval function $ms.$fns(tx::ConnectivityTracer, ::Number) - return connectivity_tracer_1_to_1(tx, is_influence_arg1_zero_global($fns)) - end - @eval function $ms.$fns(dx::D, y::Number) where {P,T<:ConnectivityTracer,D<:Dual{P,T}} - x = primal(dx) - p_out = $ms.$fns(x, y) - t_out = connectivity_tracer_1_to_1( - tracer(dx), is_influence_arg1_zero_local($ms.$fns, x, y) - ) - return Dual(p_out, t_out) - end + function $M.$op(tx::$SCT.ConnectivityTracer, ::Number) + return $SCT.connectivity_tracer_1_to_1( + tx, $SCT.is_influence_arg1_zero_global($M.$op) + ) + end + function $M.$op( + dx::D, y::Number + ) where {P,T<:$SCT.ConnectivityTracer,D<:$SCT.Dual{P,T}} + x = $SCT.primal(dx) + p_out = $M.$op(x, y) + t_out = $SCT.connectivity_tracer_1_to_1( + $SCT.tracer(dx), $SCT.is_influence_arg1_zero_local($M.$op, x, y) + ) + return $SCT.Dual(p_out, t_out) + end - @eval function $ms.$fns(::Number, ty::ConnectivityTracer) - return connectivity_tracer_1_to_1(ty, is_influence_arg2_zero_global($fns)) - end - @eval function $ms.$fns(x::Number, dy::D) where {P,T<:ConnectivityTracer,D<:Dual{P,T}} - y = primal(dy) - p_out = $ms.$fns(x, y) - t_out = connectivity_tracer_1_to_1( - tracer(dy), is_influence_arg2_zero_local($ms.$fns, x, y) - ) - return Dual(p_out, t_out) + function $M.$op(::Number, ty::$SCT.ConnectivityTracer) + return $SCT.connectivity_tracer_1_to_1( + ty, $SCT.is_influence_arg2_zero_global($M.$op) + ) + end + function $M.$op( + x::Number, dy::D + ) where {P,T<:$SCT.ConnectivityTracer,D<:$SCT.Dual{P,T}} + y = $SCT.primal(dy) + p_out = $M.$op(x, y) + t_out = $SCT.connectivity_tracer_1_to_1( + $SCT.tracer(dy), $SCT.is_influence_arg2_zero_local($M.$op, x, y) + ) + return $SCT.Dual(p_out, t_out) + end end end @@ -101,42 +113,29 @@ function connectivity_tracer_1_to_2( return (t1, t2) end -function overload_connectivity_1_to_2(m::Module, fn::Function) - ms, fns = nameof(m), nameof(fn) - @eval function $ms.$fns(t::ConnectivityTracer) - return connectivity_tracer_1_to_2( - t, - is_influence_out1_zero_global($ms.$fns), - is_influence_out2_zero_global($ms.$fns), - ) - end +function overload_connectivity_1_to_2(M, op) + return quote + function $M.$op(t::$SCT.ConnectivityTracer) + return $SCT.connectivity_tracer_1_to_2( + t, + $SCT.is_influence_out1_zero_global($M.$op), + $SCT.is_influence_out2_zero_global($M.$op), + ) + end - @eval function $ms.$fns(d::D) where {P,T<:ConnectivityTracer,D<:Dual{P,T}} - x = primal(d) - p1_out, p2_out = $ms.$fns(x) - t1_out, t2_out = connectivity_tracer_1_to_2( - t, - is_influence_out1_zero_local($ms.$fns, x), - is_influence_out2_zero_local($ms.$fns, x), - ) - return (Dual(p1_out, t1_out), Dual(p2_out, t2_out)) + function $M.$op(d::D) where {P,T<:$SCT.ConnectivityTracer,D<:$SCT.Dual{P,T}} + x = $SCT.primal(d) + p1_out, p2_out = $M.$op(x) + t1_out, t2_out = $SCT.connectivity_tracer_1_to_2( + t, + $SCT.is_influence_out1_zero_local($M.$op, x), + $SCT.is_influence_out2_zero_local($M.$op, x), + ) + return ($SCT.Dual(p1_out, t1_out), $SCT.Dual(p2_out, t2_out)) + end end end -## Actual overloads - -for op in ops_1_to_1 - overload_connectivity_1_to_1(Base, op) -end - -for op in ops_2_to_1 - overload_connectivity_2_to_1(Base, op) -end - -for op in ops_1_to_2 - overload_connectivity_1_to_2(Base, op) -end - ## Special cases ## Exponent (requires extra types) diff --git a/src/overload_gradient.jl b/src/overload_gradient.jl index ddfb107a..697b5aab 100644 --- a/src/overload_gradient.jl +++ b/src/overload_gradient.jl @@ -8,16 +8,19 @@ function gradient_tracer_1_to_1(t::T, is_firstder_zero::Bool) where {T<:Gradient end end -function overload_gradient_1_to_1(m::Module, fn::Function) - ms, fns = nameof(m), nameof(fn) - @eval function $ms.$fns(t::GradientTracer) - return gradient_tracer_1_to_1(t, is_firstder_zero_global($ms.$fns)) - end - @eval function $ms.$fns(d::D) where {P,T<:GradientTracer,D<:Dual{P,T}} - x = primal(d) - p_out = $ms.$fns(x) - t_out = gradient_tracer_1_to_1(tracer(d), is_firstder_zero_local($fns, x)) - return Dual(p_out, t_out) +function overload_gradient_1_to_1(M, op) + return quote + function $M.$op(t::$SCT.GradientTracer) + return $SCT.gradient_tracer_1_to_1(t, $SCT.is_firstder_zero_global($M.$op)) + end + function $M.$op(d::D) where {P,T<:$SCT.GradientTracer,D<:$SCT.Dual{P,T}} + x = $SCT.primal(d) + p_out = $M.$op(x) + t_out = $SCT.gradient_tracer_1_to_1( + $SCT.tracer(d), $SCT.is_firstder_zero_local($op, x) + ) + return $SCT.Dual(p_out, t_out) + end end end @@ -41,51 +44,56 @@ function gradient_tracer_2_to_1( end end -function overload_gradient_2_to_1(m::Module, fn::Function) - ms, fns = nameof(m), nameof(fn) - @eval function $ms.$fns(tx::T, ty::T) where {T<:GradientTracer} - return gradient_tracer_2_to_1( - tx, - ty, - is_firstder_arg1_zero_global($ms.$fns), - is_firstder_arg2_zero_global($ms.$fns), - ) - end - @eval function $ms.$fns(dx::D, dy::D) where {P,T<:GradientTracer,D<:Dual{P,T}} - x = primal(dx) - y = primal(dy) - p_out = $ms.$fns(x, y) - t_out = gradient_tracer_2_to_1( - tracer(dx), - tracer(dy), - is_firstder_arg1_zero_local($ms.$fns, x, y), - is_firstder_arg2_zero_local($ms.$fns, x, y), - ) - return Dual(p_out, t_out) - end +function overload_gradient_2_to_1(M, op) + return quote + function $M.$op(tx::T, ty::T) where {T<:$SCT.GradientTracer} + return $SCT.gradient_tracer_2_to_1( + tx, + ty, + $SCT.is_firstder_arg1_zero_global($M.$op), + $SCT.is_firstder_arg2_zero_global($M.$op), + ) + end + function $M.$op(dx::D, dy::D) where {P,T<:$SCT.GradientTracer,D<:$SCT.Dual{P,T}} + x = $SCT.primal(dx) + y = $SCT.primal(dy) + p_out = $M.$op(x, y) + t_out = $SCT.gradient_tracer_2_to_1( + $SCT.tracer(dx), + $SCT.tracer(dy), + $SCT.is_firstder_arg1_zero_local($M.$op, x, y), + $SCT.is_firstder_arg2_zero_local($M.$op, x, y), + ) + return $SCT.Dual(p_out, t_out) + end - @eval function $ms.$fns(tx::GradientTracer, ::Number) - return gradient_tracer_1_to_1(tx, is_firstder_arg1_zero_global($ms.$fns)) - end - @eval function $ms.$fns(dx::D, y::Number) where {P,T<:GradientTracer,D<:Dual{P,T}} - x = primal(dx) - p_out = $ms.$fns(x, y) - t_out = gradient_tracer_1_to_1( - tracer(dx), is_firstder_arg1_zero_local($ms.$fns, x, y) - ) - return Dual(p_out, t_out) - end + function $M.$op(tx::$SCT.GradientTracer, ::Number) + return $SCT.gradient_tracer_1_to_1( + tx, $SCT.is_firstder_arg1_zero_global($M.$op) + ) + end + function $M.$op(dx::D, y::Number) where {P,T<:$SCT.GradientTracer,D<:$SCT.Dual{P,T}} + x = $SCT.primal(dx) + p_out = $M.$op(x, y) + t_out = $SCT.gradient_tracer_1_to_1( + $SCT.tracer(dx), $SCT.is_firstder_arg1_zero_local($M.$op, x, y) + ) + return $SCT.Dual(p_out, t_out) + end - @eval function $ms.$fns(::Number, ty::GradientTracer) - return gradient_tracer_1_to_1(ty, is_firstder_arg2_zero_global($ms.$fns)) - end - @eval function $ms.$fns(x::Number, dy::D) where {P,T<:GradientTracer,D<:Dual{P,T}} - y = primal(dy) - p_out = $ms.$fns(x, y) - t_out = gradient_tracer_1_to_1( - tracer(dy), is_firstder_arg2_zero_local($ms.$fns, x, y) - ) - return Dual(p_out, t_out) + function $M.$op(::Number, ty::$SCT.GradientTracer) + return $SCT.gradient_tracer_1_to_1( + ty, $SCT.is_firstder_arg2_zero_global($M.$op) + ) + end + function $M.$op(x::Number, dy::D) where {P,T<:$SCT.GradientTracer,D<:$SCT.Dual{P,T}} + y = $SCT.primal(dy) + p_out = $M.$op(x, y) + t_out = $SCT.gradient_tracer_1_to_1( + $SCT.tracer(dy), $SCT.is_firstder_arg2_zero_local($M.$op, x, y) + ) + return $SCT.Dual(p_out, t_out) + end end end @@ -99,42 +107,29 @@ function gradient_tracer_1_to_2( return (t1, t2) end -function overload_gradient_1_to_2(m::Module, fn::Function) - ms, fns = nameof(m), nameof(fn) - @eval function $ms.$fns(t::GradientTracer) - return gradient_tracer_1_to_2( - t, - is_firstder_out1_zero_global($ms.$fns), - is_firstder_out2_zero_global($ms.$fns), - ) - end +function overload_gradient_1_to_2(M, op) + return quote + function $M.$op(t::$SCT.GradientTracer) + return $SCT.gradient_tracer_1_to_2( + t, + $SCT.is_firstder_out1_zero_global($M.$op), + $SCT.is_firstder_out2_zero_global($M.$op), + ) + end - @eval function $ms.$fns(d::D) where {P,T<:GradientTracer,D<:Dual{P,T}} - x = primal(d) - p1_out, p2_out = $ms.$fns(x) - t1_out, t2_out = gradient_tracer_1_to_2( - tracer(d), - is_firstder_out1_zero_local($ms.$fns, x), - is_firstder_out2_zero_local($ms.$fns, x), - ) - return (Dual(p1_out, t1_out), Dual(p2_out, t2_out)) # TODO: this was wrong, add test + function $M.$op(d::D) where {P,T<:$SCT.GradientTracer,D<:$SCT.Dual{P,T}} + x = $SCT.primal(d) + p1_out, p2_out = $M.$op(x) + t1_out, t2_out = $SCT.gradient_tracer_1_to_2( + $SCT.tracer(d), + $SCT.is_firstder_out1_zero_local($M.$op, x), + $SCT.is_firstder_out2_zero_local($M.$op, x), + ) + return ($SCT.Dual(p1_out, t1_out), $SCT.Dual(p2_out, t2_out)) # TODO: this was wrong, add test + end end end -## Actual overloads - -for op in ops_1_to_1 - overload_gradient_1_to_1(Base, op) -end - -for op in ops_2_to_1 - overload_gradient_2_to_1(Base, op) -end - -for op in ops_1_to_2 - overload_gradient_1_to_2(Base, op) -end - ## Special cases ## Exponent (requires extra types) diff --git a/src/overload_hessian.jl b/src/overload_hessian.jl index 9e985dec..7270f532 100644 --- a/src/overload_hessian.jl +++ b/src/overload_hessian.jl @@ -18,22 +18,25 @@ function hessian_tracer_1_to_1( end end -function overload_hessian_1_to_1(m::Module, fn::Function) - ms, fns = nameof(m), nameof(fn) - @eval function $ms.$fns(t::HessianTracer) - return hessian_tracer_1_to_1( - t, is_firstder_zero_global($ms.$fns), is_seconder_zero_global($ms.$fns) - ) - end - @eval function $ms.$fns(d::D) where {P,T<:HessianTracer,D<:Dual{P,T}} - x = primal(d) - p_out = $ms.$fns(x) - t_out = hessian_tracer_1_to_1( - tracer(d), - is_firstder_zero_local($ms.$fns, x), - is_seconder_zero_local($ms.$fns, x), - ) - return Dual(p_out, t_out) +function overload_hessian_1_to_1(M, op) + return quote + function $M.$op(t::$SCT.HessianTracer) + return $SCT.hessian_tracer_1_to_1( + t, + $SCT.is_firstder_zero_global($M.$op), + $SCT.is_seconder_zero_global($M.$op), + ) + end + function $M.$op(d::D) where {P,T<:$SCT.HessianTracer,D<:$SCT.Dual{P,T}} + x = $SCT.primal(d) + p_out = $M.$op(x) + t_out = $SCT.hessian_tracer_1_to_1( + $SCT.tracer(d), + $SCT.is_firstder_zero_local($M.$op, x), + $SCT.is_seconder_zero_local($M.$op, x), + ) + return $SCT.Dual(p_out, t_out) + end end end @@ -70,69 +73,70 @@ function hessian_tracer_2_to_1( return T(grad, hess) end -function overload_hessian_2_to_1(m::Module, fn::Function) - ms, fns = nameof(m), nameof(fn) - @eval function $ms.$fns(tx::T, ty::T) where {T<:HessianTracer} - return hessian_tracer_2_to_1( - tx, - ty, - is_firstder_arg1_zero_global($ms.$fns), - is_seconder_arg1_zero_global($ms.$fns), - is_firstder_arg2_zero_global($ms.$fns), - is_seconder_arg2_zero_global($ms.$fns), - is_crossder_zero_global($ms.$fns), - ) - end - @eval function $ms.$fns(dx::D, dy::D) where {P,T<:HessianTracer,D<:Dual{P,T}} - x = primal(dx) - y = primal(dy) - p_out = $ms.$fns(x, y) - t_out = hessian_tracer_2_to_1( - tracer(dx), - tracer(dy), - is_firstder_arg1_zero_local($ms.$fns, x, y), - is_seconder_arg1_zero_local($ms.$fns, x, y), - is_firstder_arg2_zero_local($ms.$fns, x, y), - is_seconder_arg2_zero_local($ms.$fns, x, y), - is_crossder_zero_local($ms.$fns, x, y), - ) - return Dual(p_out, t_out) - end +function overload_hessian_2_to_1(M, op) + return quote + function $M.$op(tx::T, ty::T) where {T<:$SCT.HessianTracer} + return $SCT.hessian_tracer_2_to_1( + tx, + ty, + $SCT.is_firstder_arg1_zero_global($M.$op), + $SCT.is_seconder_arg1_zero_global($M.$op), + $SCT.is_firstder_arg2_zero_global($M.$op), + $SCT.is_seconder_arg2_zero_global($M.$op), + $SCT.is_crossder_zero_global($M.$op), + ) + end + function $M.$op(dx::D, dy::D) where {P,T<:$SCT.HessianTracer,D<:$SCT.Dual{P,T}} + x = $SCT.primal(dx) + y = $SCT.primal(dy) + p_out = $M.$op(x, y) + t_out = $SCT.hessian_tracer_2_to_1( + $SCT.tracer(dx), + $SCT.tracer(dy), + $SCT.is_firstder_arg1_zero_local($M.$op, x, y), + $SCT.is_seconder_arg1_zero_local($M.$op, x, y), + $SCT.is_firstder_arg2_zero_local($M.$op, x, y), + $SCT.is_seconder_arg2_zero_local($M.$op, x, y), + $SCT.is_crossder_zero_local($M.$op, x, y), + ) + return $SCT.Dual(p_out, t_out) + end - @eval function $ms.$fns(tx::HessianTracer, y::Number) - return hessian_tracer_1_to_1( - tx, - is_firstder_arg1_zero_global($ms.$fns), - is_seconder_arg1_zero_global($ms.$fns), - ) - end - @eval function $ms.$fns(x::Number, ty::HessianTracer) - return hessian_tracer_1_to_1( - ty, - is_firstder_arg2_zero_global($ms.$fns), - is_seconder_arg2_zero_global($ms.$fns), - ) - end + function $M.$op(tx::$SCT.HessianTracer, y::Number) + return $SCT.hessian_tracer_1_to_1( + tx, + $SCT.is_firstder_arg1_zero_global($M.$op), + $SCT.is_seconder_arg1_zero_global($M.$op), + ) + end + function $M.$op(x::Number, ty::$SCT.HessianTracer) + return $SCT.hessian_tracer_1_to_1( + ty, + $SCT.is_firstder_arg2_zero_global($M.$op), + $SCT.is_seconder_arg2_zero_global($M.$op), + ) + end - @eval function $ms.$fns(dx::D, y::Number) where {P,T<:HessianTracer,D<:Dual{P,T}} - x = primal(dx) - p_out = $ms.$fns(x, y) - t_out = hessian_tracer_1_to_1( - tracer(dx), - is_firstder_arg1_zero_local($ms.$fns, x, y), - is_seconder_arg1_zero_local($ms.$fns, x, y), - ) - return Dual(p_out, t_out) - end - @eval function $ms.$fns(x::Number, dy::D) where {P,T<:HessianTracer,D<:Dual{P,T}} - y = primal(dy) - p_out = $ms.$fns(x, y) - t_out = hessian_tracer_1_to_1( - tracer(dy), - is_firstder_arg2_zero_local($ms.$fns, x, y), - is_seconder_arg2_zero_local($ms.$fns, x, y), - ) - return Dual(p_out, t_out) + function $M.$op(dx::D, y::Number) where {P,T<:$SCT.HessianTracer,D<:$SCT.Dual{P,T}} + x = $SCT.primal(dx) + p_out = $M.$op(x, y) + t_out = $SCT.hessian_tracer_1_to_1( + $SCT.tracer(dx), + $SCT.is_firstder_arg1_zero_local($M.$op, x, y), + $SCT.is_seconder_arg1_zero_local($M.$op, x, y), + ) + return $SCT.Dual(p_out, t_out) + end + function $M.$op(x::Number, dy::D) where {P,T<:$SCT.HessianTracer,D<:$SCT.Dual{P,T}} + y = $SCT.primal(dy) + p_out = $M.$op(x, y) + t_out = $SCT.hessian_tracer_1_to_1( + $SCT.tracer(dy), + $SCT.is_firstder_arg2_zero_local($M.$op, x, y), + $SCT.is_seconder_arg2_zero_local($M.$op, x, y), + ) + return $SCT.Dual(p_out, t_out) + end end end @@ -150,46 +154,33 @@ function hessian_tracer_1_to_2( return (t1, t2) end -function overload_hessian_1_to_2(m::Module, fn::Function) - ms, fns = nameof(m), nameof(fn) - @eval function $ms.$fns(t::HessianTracer) - return hessian_tracer_1_to_2( - t, - is_firstder_out1_zero_global($ms.$fns), - is_seconder_out1_zero_global($ms.$fns), - is_firstder_out2_zero_global($ms.$fns), - is_seconder_out2_zero_global($ms.$fns), - ) - end +function overload_hessian_1_to_2(M, op) + return quote + function $M.$op(t::$SCT.HessianTracer) + return $SCT.hessian_tracer_1_to_2( + t, + $SCT.is_firstder_out1_zero_global($M.$op), + $SCT.is_seconder_out1_zero_global($M.$op), + $SCT.is_firstder_out2_zero_global($M.$op), + $SCT.is_seconder_out2_zero_global($M.$op), + ) + end - @eval function $ms.$fns(d::D) where {P,T<:HessianTracer,D<:Dual{P,T}} - x = primal(d) - p1_out, p2_out = $ms.$fns(x) - t1_out, t2_out = hessian_tracer_1_to_2( - d, - is_firstder_out1_zero_local($ms.$fns, x), - is_seconder_out1_zero_local($ms.$fns, x), - is_firstder_out2_zero_local($ms.$fns, x), - is_seconder_out2_zero_local($ms.$fns, x), - ) - return (Dual(p1_out, t1_out), Dual(p2_out, t2_out)) + function $M.$op(d::D) where {P,T<:$SCT.HessianTracer,D<:$SCT.Dual{P,T}} + x = $SCT.primal(d) + p1_out, p2_out = $M.$op(x) + t1_out, t2_out = $SCT.hessian_tracer_1_to_2( + d, + $SCT.is_firstder_out1_zero_local($M.$op, x), + $SCT.is_seconder_out1_zero_local($M.$op, x), + $SCT.is_firstder_out2_zero_local($M.$op, x), + $SCT.is_seconder_out2_zero_local($M.$op, x), + ) + return ($SCT.Dual(p1_out, t1_out), $SCT.Dual(p2_out, t2_out)) + end end end -## Actual overloads - -for op in ops_1_to_1 - overload_hessian_1_to_1(Base, op) -end - -for op in ops_2_to_1 - overload_hessian_2_to_1(Base, op) -end - -for op in ops_1_to_2 - overload_hessian_1_to_2(Base, op) -end - ## Special cases ## Exponent (requires extra types) diff --git a/test/Project.toml b/test/Project.toml index 60cb4362..2c75e169 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -12,4 +12,5 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" ReferenceTests = "324d217c-45ce-50fc-942e-d289b448e8cf" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" SparseDiffTools = "47a9eef4-7e08-11e9-0b38-333d64bd3804" +SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" diff --git a/test/classification.jl b/test/classification.jl index b68e6076..29436b89 100644 --- a/test/classification.jl +++ b/test/classification.jl @@ -1,10 +1,10 @@ using SparseConnectivityTracer: - ops_1_to_1, + list_operators_1_to_1, is_firstder_zero_global, is_seconder_zero_global, is_firstder_zero_local, is_seconder_zero_local, - ops_2_to_1, + list_operators_2_to_1, is_firstder_arg1_zero_global, is_seconder_arg1_zero_global, is_firstder_arg2_zero_global, @@ -15,7 +15,7 @@ using SparseConnectivityTracer: is_firstder_arg2_zero_local, is_seconder_arg2_zero_local, is_crossder_zero_local, - ops_1_to_2, + list_operators_1_to_2, is_firstder_out1_zero_global, is_seconder_out1_zero_global, is_firstder_out2_zero_global, @@ -24,6 +24,7 @@ using SparseConnectivityTracer: is_firstder_out1_zero_local, is_firstder_out2_zero_local, is_seconder_out2_zero_local +using SpecialFunctions: SpecialFunctions using Test using ForwardDiff: derivative, gradient, hessian @@ -82,11 +83,13 @@ function correct_classification_1_to_1(op, x; atol) end @testset verbose = true "1-to-1" begin - @testset "$op" for op in ops_1_to_1 - @test all( - correct_classification_1_to_1(op, random_input(op); atol=DEFAULT_ATOL) for - _ in 1:DEFAULT_TRIALS - ) + @testset "$m" for m in (Base, SpecialFunctions) + @testset "$op" for op in list_operators_1_to_1(Val(Symbol(m))) + @test all( + correct_classification_1_to_1(op, random_input(op); atol=DEFAULT_ATOL) for + _ in 1:DEFAULT_TRIALS + ) + end end end; @@ -122,12 +125,14 @@ function correct_classification_2_to_1(op, x, y; atol) end @testset verbose = true "2-to-1" begin - @testset "$op" for op in ops_2_to_1 - @test all( - correct_classification_2_to_1( - op, random_first_input(op), random_second_input(op); atol=DEFAULT_ATOL - ) for _ in 1:DEFAULT_TRIALS - ) + @testset "$m" for m in (Base, SpecialFunctions) + @testset "$op" for op in list_operators_2_to_1(Val(Symbol(m))) + @test all( + correct_classification_2_to_1( + op, random_first_input(op), random_second_input(op); atol=DEFAULT_ATOL + ) for _ in 1:DEFAULT_TRIALS + ) + end end end; @@ -159,10 +164,12 @@ function correct_classification_1_to_2(op, x; atol) end @testset verbose = true "1-to-2" begin - @testset "$op" for op in ops_1_to_2 - @test all( - correct_classification_1_to_2(op, random_input(op); atol=DEFAULT_ATOL) for - _ in 1:DEFAULT_TRIALS - ) + @testset "$m" for m in (Base, SpecialFunctions) + @testset "$op" for op in list_operators_1_to_2(Val(Symbol(m))) + @test all( + correct_classification_1_to_2(op, random_input(op); atol=DEFAULT_ATOL) for + _ in 1:DEFAULT_TRIALS + ) + end end end; diff --git a/test/first_order.jl b/test/first_order.jl index d55882bf..0355dba5 100644 --- a/test/first_order.jl +++ b/test/first_order.jl @@ -4,6 +4,7 @@ using SparseConnectivityTracer: using SparseConnectivityTracer: DuplicateVector, RecursiveSet, SortedVector using ADTypes: jacobian_sparsity using LinearAlgebra: det, dot, logdet +using SpecialFunctions: erf, beta using Test const FIRST_ORDER_SET_TYPES = ( @@ -34,6 +35,10 @@ const FIRST_ORDER_SET_TYPES = ( @test connectivity_pattern(x -> ℯ^x, 1, G) ≈ [1;;] @test connectivity_pattern(x -> round(x, RoundNearestTiesUp), 1, G) ≈ [1;;] + # SpecialFunctions + @test connectivity_pattern(x -> erf(x[1]), rand(2), G) == [1 0] + @test connectivity_pattern(x -> beta(x[1], x[2]), rand(3), G) == [1 1 0] + ## Error handling when applying non-dual tracers to "local" functions with control flow @test_throws MissingPrimalError connectivity_pattern( x -> ifelse(x[2] < x[3], x[1] + x[2], x[3] * x[4]), [1 2 3 4], G @@ -78,6 +83,10 @@ end # Linear Algebra @test jacobian_sparsity(x -> dot(x[1:2], x[4:5]), rand(5), method) == [1 1 0 1 1] + # SpecialFunctions + @test jacobian_sparsity(x -> erf(x[1]), rand(2), method) == [1 0] + @test jacobian_sparsity(x -> beta(x[1], x[2]), rand(3), method) == [1 1 0] + ## Error handling when applying non-dual tracers to "local" functions with control flow @test_throws MissingPrimalError jacobian_sparsity( x -> x[1] > x[2] ? x[3] : x[4], [1.0, 2.0, 3.0, 4.0], method diff --git a/test/runtests.jl b/test/runtests.jl index af50cfce..f974ea4a 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -31,6 +31,7 @@ DocMeta.setdocmeta!( SparseConnectivityTracer; ambiguities=false, deps_compat=(ignore=[:Random, :SparseArrays], check_extras=false), + stale_deps=(ignore=[:Requires],), persistent_tasks=false, ) end diff --git a/test/second_order.jl b/test/second_order.jl index b199ca95..1f7028f4 100644 --- a/test/second_order.jl +++ b/test/second_order.jl @@ -3,6 +3,7 @@ using SparseConnectivityTracer: HessianTracer, MissingPrimalError, tracer, trace_input, empty using SparseConnectivityTracer: DuplicateVector, RecursiveSet, SortedVector using ADTypes: hessian_sparsity +using SpecialFunctions: erf, beta using Test const SECOND_ORDER_SET_TYPES = ( @@ -145,6 +146,17 @@ const SECOND_ORDER_SET_TYPES = ( 0 1 0 0 1 ] + # SpecialFunctions + @test hessian_sparsity(x -> erf(x[1]), rand(2), method) == [ + 1 0 + 0 0 + ] + @test hessian_sparsity(x -> beta(x[1], x[2]), rand(3), method) == [ + 1 1 0 + 1 1 0 + 0 0 0 + ] + ## Error handling when applying non-dual tracers to "local" functions with control flow f2(x) = ifelse(x[2] < x[3], x[1] * x[2], x[3] * x[4]) @test_throws MissingPrimalError hessian_sparsity(f2, [1 3 2 4], method)