From 0ff7b71d2f42f7a320bb2534acbe7c3f4e11f486 Mon Sep 17 00:00:00 2001 From: adrhill Date: Thu, 15 Aug 2024 15:05:38 +0200 Subject: [PATCH 01/20] Return only `primal` on non-diff `Dual` methods --- src/overloads/conversion.jl | 30 +++++++++++++++++++----------- src/overloads/gradient_tracer.jl | 12 +++--------- src/overloads/hessian_tracer.jl | 12 +++--------- 3 files changed, 25 insertions(+), 29 deletions(-) diff --git a/src/overloads/conversion.jl b/src/overloads/conversion.jl index ceff378..4509cbe 100644 --- a/src/overloads/conversion.jl +++ b/src/overloads/conversion.jl @@ -1,5 +1,9 @@ #! format: off +##===============# +# AbstractTracer # +#================# + ## Type conversions (non-dual) Base.promote_rule(::Type{T}, ::Type{N}) where {T<:AbstractTracer,N<:Real} = T Base.promote_rule(::Type{N}, ::Type{T}) where {T<:AbstractTracer,N<:Real} = T @@ -24,7 +28,10 @@ Base.floatmin(::Type{T}) where {T<:AbstractTracer} = myempty(T) Base.floatmax(::Type{T}) where {T<:AbstractTracer} = myempty(T) Base.maxintfloat(::Type{T}) where {T<:AbstractTracer} = myempty(T) -## Duals +##======# +# Duals # +#=======# + function Base.promote_rule(::Type{Dual{P1, T}}, ::Type{Dual{P2, T}}) where {P1,P2,T} PP = Base.promote_type(P1, P2) # TODO: possible method call error? return Dual{PP,T} @@ -59,15 +66,16 @@ for T in (:Int, :Integer, :Float64, :Float32) end ## Constants -# These are methods defined on types. Methods on variables are in operators.jl -Base.zero(::Type{D}) where {P,T,D<:Dual{P,T}} = D(zero(P), myempty(T)) -Base.one(::Type{D}) where {P,T,D<:Dual{P,T}} = D(one(P), myempty(T)) -Base.oneunit(::Type{D}) where {P,T,D<:Dual{P,T}} = D(oneunit(P), myempty(T)) -Base.typemin(::Type{D}) where {P,T,D<:Dual{P,T}} = D(typemin(P), myempty(T)) -Base.typemax(::Type{D}) where {P,T,D<:Dual{P,T}} = D(typemax(P), myempty(T)) -Base.eps(::Type{D}) where {P,T,D<:Dual{P,T}} = D(eps(P), myempty(T)) -Base.floatmin(::Type{D}) where {P,T,D<:Dual{P,T}} = D(floatmin(P), myempty(T)) -Base.floatmax(::Type{D}) where {P,T,D<:Dual{P,T}} = D(floatmax(P), myempty(T)) -Base.maxintfloat(::Type{D}) where {P,T,D<:Dual{P,T}} = D(maxintfloat(P), myempty(T)) +# These are methods defined on types. Methods on variables are in operators.jl +# TODO: only return primal on methods on variable +Base.zero(::Type{D}) where {P,T,D<:Dual{P,T}} = zero(P) +Base.one(::Type{D}) where {P,T,D<:Dual{P,T}} = one(P) +Base.oneunit(::Type{D}) where {P,T,D<:Dual{P,T}} = oneunit(P) +Base.typemin(::Type{D}) where {P,T,D<:Dual{P,T}} = typemin(P) +Base.typemax(::Type{D}) where {P,T,D<:Dual{P,T}} = typemax(P) +Base.eps(::Type{D}) where {P,T,D<:Dual{P,T}} = eps(P) +Base.floatmin(::Type{D}) where {P,T,D<:Dual{P,T}} = floatmin(P) +Base.floatmax(::Type{D}) where {P,T,D<:Dual{P,T}} = floatmax(P) +Base.maxintfloat(::Type{D}) where {P,T,D<:Dual{P,T}} = maxintfloat(P) #! format: on diff --git a/src/overloads/gradient_tracer.jl b/src/overloads/gradient_tracer.jl index fdc5b21..02b7e8d 100644 --- a/src/overloads/gradient_tracer.jl +++ b/src/overloads/gradient_tracer.jl @@ -226,17 +226,13 @@ Base.round(::T, ::RoundingMode; kwargs...) where {T<:GradientTracer} = myempty(T function Base.round( d::D, mode::RoundingMode; kwargs... ) where {P,T<:GradientTracer,D<:Dual{P,T}} - p = round(primal(d), mode; kwargs...) - t = myempty(T) - return Dual(p, t) + return round(primal(d), mode; kwargs...) # only return primal end for RR in (Real, Integer, Bool) Base.round(::Type{R}, ::T) where {R<:RR,T<:GradientTracer} = myempty(T) function Base.round(::Type{R}, d::D) where {R<:RR,P,T<:GradientTracer,D<:Dual{P,T}} - p = round(R, primal(d)) - t = myempty(T) - return Dual(p, t) + return round(R, primal(d)) # only return primal end end @@ -245,7 +241,5 @@ Base.rand(::AbstractRNG, ::SamplerType{T}) where {T<:GradientTracer} = myempty(T function Base.rand( rng::AbstractRNG, ::SamplerType{D} ) where {P,T<:GradientTracer,D<:Dual{P,T}} - p = rand(rng, P) - t = myempty(T) - return Dual(p, t) + return rand(rng, P) # only return primal end diff --git a/src/overloads/hessian_tracer.jl b/src/overloads/hessian_tracer.jl index 01f7671..0c853a6 100644 --- a/src/overloads/hessian_tracer.jl +++ b/src/overloads/hessian_tracer.jl @@ -332,17 +332,13 @@ Base.round(t::T, ::RoundingMode; kwargs...) where {T<:HessianTracer} = myempty(T function Base.round( d::D, mode::RoundingMode; kwargs... ) where {P,T<:HessianTracer,D<:Dual{P,T}} - p = round(primal(d), mode; kwargs...) - t = myempty(T) - return Dual(p, t) + return round(primal(d), mode; kwargs...) # only return primal end for RR in (Real, Integer, Bool) Base.round(::Type{R}, ::T) where {R<:RR,T<:HessianTracer} = myempty(T) function Base.round(::Type{R}, d::D) where {R<:RR,P,T<:HessianTracer,D<:Dual{P,T}} - p = round(R, primal(d)) - t = myempty(T) - return Dual(p, t) + return round(R, primal(d)) # only return primal end end @@ -351,7 +347,5 @@ Base.rand(::AbstractRNG, ::SamplerType{T}) where {T<:HessianTracer} = myempty(T) function Base.rand( rng::AbstractRNG, ::SamplerType{D} ) where {P,T<:HessianTracer,D<:Dual{P,T}} - p = rand(rng, P) - t = myempty(T) - return Dual(p, t) + return rand(rng, P) # only return primal end From 43a01e0a005e59db1b83ebc08391b64553fa9d9c Mon Sep 17 00:00:00 2001 From: adrhill Date: Thu, 15 Aug 2024 15:48:41 +0200 Subject: [PATCH 02/20] Update tests --- test/test_constructors.jl | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/test/test_constructors.jl b/test/test_constructors.jl index fca3eff..6337db3 100644 --- a/test/test_constructors.jl +++ b/test/test_constructors.jl @@ -28,10 +28,9 @@ function test_constant_functions(::Type{D}) where {P,T,D<:Dual{P,T}} @testset "$f" for f in ( zero, one, oneunit, typemin, typemax, eps, floatmin, floatmax, maxintfloat ) - d = f(D) - @test isa(d, D) - @test isemptytracer(d) - @test primal(d) == f(P) + out = f(D) + @test out isa P + @test out == f(P) end end From 00a72ba5c4cfe2d56406753094e7e794700435a4 Mon Sep 17 00:00:00 2001 From: adrhill Date: Thu, 15 Aug 2024 17:58:39 +0200 Subject: [PATCH 03/20] Remove dead code --- test/runtests.jl | 5 ----- 1 file changed, 5 deletions(-) diff --git a/test/runtests.jl b/test/runtests.jl index 4b19a51..2156d4c 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -5,17 +5,12 @@ Pkg.develop(; using SparseConnectivityTracer -using Compat using Test -using ReferenceTests using JuliaFormatter using Aqua using JET using Documenter -using LinearAlgebra -using Random - DocMeta.setdocmeta!( SparseConnectivityTracer, :DocTestSetup, From 5a253d5c4fa507f274c4e4f55ba4de431015ffc1 Mon Sep 17 00:00:00 2001 From: adrhill Date: Thu, 15 Aug 2024 17:58:15 +0200 Subject: [PATCH 04/20] Make comparisons regular operators --- src/operators.jl | 2 ++ src/overloads/dual.jl | 6 ------ src/overloads/ifelse_global.jl | 14 -------------- 3 files changed, 2 insertions(+), 20 deletions(-) diff --git a/src/operators.jl b/src/operators.jl index 5dc97b3..562482c 100644 --- a/src/operators.jl +++ b/src/operators.jl @@ -374,6 +374,8 @@ end ops_2_to_1_zzz = ( # division div, fld, fld1, cld, + # comparisons + isequal, isapprox, isless, ==, <, >, <=, >=, ) for op in ops_2_to_1_zzz T = typeof(op) diff --git a/src/overloads/dual.jl b/src/overloads/dual.jl index 87844d7..4c504bd 100644 --- a/src/overloads/dual.jl +++ b/src/overloads/dual.jl @@ -19,11 +19,5 @@ for fn in ( end end -for fn in (:isequal, :isapprox, :isless, :(==), :(<), :(>), :(<=), :(>=)) - @eval Base.$fn(dx::D, dy::D) where {D<:Dual} = $fn(primal(dx), primal(dy)) - @eval Base.$fn(dx::D, y::Real) where {D<:Dual} = $fn(primal(dx), y) - @eval Base.$fn(x::Real, dy::D) where {D<:Dual} = $fn(x, primal(dy)) -end - # In some cases, more specialized methods are needed Base.isless(dx::D, y::AbstractFloat) where {D<:Dual} = isless(primal(dx), y) diff --git a/src/overloads/ifelse_global.jl b/src/overloads/ifelse_global.jl index 54244ac..3af235f 100644 --- a/src/overloads/ifelse_global.jl +++ b/src/overloads/ifelse_global.jl @@ -54,17 +54,3 @@ return ty end end - -# Overload only on AbstractTracer, not Dual -for op in (isequal, isapprox, isless, ==, <, >, <=, >=) - T = typeof(op) - @eval is_der1_arg1_zero_global(::$T) = true - @eval is_der2_arg1_zero_global(::$T) = true - @eval is_der1_arg2_zero_global(::$T) = true - @eval is_der2_arg2_zero_global(::$T) = true - @eval is_der_cross_zero_global(::$T) = true - - op_symb = nameof(op) - SparseConnectivityTracer.eval(overload_gradient_2_to_1(:Base, op_symb)) - SparseConnectivityTracer.eval(overload_hessian_2_to_1(:Base, op_symb)) -end From f89694c23f1f43b8aceb9ba5fa394f6f34d1a57d Mon Sep 17 00:00:00 2001 From: adrhill Date: Thu, 15 Aug 2024 17:22:40 +0200 Subject: [PATCH 05/20] Fewer overloads --- src/overloads/gradient_tracer.jl | 21 ++++++--------------- src/overloads/hessian_tracer.jl | 21 ++++++--------------- src/overloads/overload_all.jl | 21 --------------------- 3 files changed, 12 insertions(+), 51 deletions(-) diff --git a/src/overloads/gradient_tracer.jl b/src/overloads/gradient_tracer.jl index 02b7e8d..9b6609c 100644 --- a/src/overloads/gradient_tracer.jl +++ b/src/overloads/gradient_tracer.jl @@ -31,16 +31,13 @@ end function overload_gradient_1_to_1(M, op) SCT = SparseConnectivityTracer return quote + ## GradientTracer function $M.$op(t::$SCT.GradientTracer) is_der1_zero = $SCT.is_der1_zero_global($M.$op) return $SCT.gradient_tracer_1_to_1(t, is_der1_zero) end - end -end -function overload_gradient_1_to_1_dual(M, op) - SCT = SparseConnectivityTracer - return quote + ## Dual function $M.$op(d::D) where {P,T<:$SCT.GradientTracer,D<:$SCT.Dual{P,T}} x = $SCT.primal(d) p_out = $M.$op(x) @@ -102,6 +99,7 @@ end function overload_gradient_2_to_1(M, op) SCT = SparseConnectivityTracer return quote + ## GradientTracer function $M.$op(tx::T, ty::T) where {T<:$SCT.GradientTracer} is_der1_arg1_zero = $SCT.is_der1_arg1_zero_global($M.$op) is_der1_arg2_zero = $SCT.is_der1_arg2_zero_global($M.$op) @@ -117,12 +115,8 @@ function overload_gradient_2_to_1(M, op) is_der1_arg2_zero = $SCT.is_der1_arg2_zero_global($M.$op) return $SCT.gradient_tracer_1_to_1(ty, is_der1_arg2_zero) end - end -end -function overload_gradient_2_to_1_dual(M, op) - SCT = SparseConnectivityTracer - return quote + ## Dual function $M.$op(dx::D, dy::D) where {P,T<:$SCT.GradientTracer,D<:$SCT.Dual{P,T}} x = $SCT.primal(dx) y = $SCT.primal(dy) @@ -177,17 +171,14 @@ end function overload_gradient_1_to_2(M, op) SCT = SparseConnectivityTracer return quote + ## GradientTracer function $M.$op(t::$SCT.GradientTracer) is_der1_out1_zero = $SCT.is_der1_out1_zero_global($M.$op) is_der1_out2_zero = $SCT.is_der1_out2_zero_global($M.$op) return $SCT.gradient_tracer_1_to_2(t, is_der1_out1_zero, is_der1_out2_zero) end - end -end -function overload_gradient_1_to_2_dual(M, op) - SCT = SparseConnectivityTracer - return quote + ## Dual function $M.$op(d::D) where {P,T<:$SCT.GradientTracer,D<:$SCT.Dual{P,T}} x = $SCT.primal(d) p_out1, p_out2 = $M.$op(x) diff --git a/src/overloads/hessian_tracer.jl b/src/overloads/hessian_tracer.jl index 0c853a6..feae676 100644 --- a/src/overloads/hessian_tracer.jl +++ b/src/overloads/hessian_tracer.jl @@ -58,17 +58,14 @@ end function overload_hessian_1_to_1(M, op) SCT = SparseConnectivityTracer return quote + ## HessianTracer function $M.$op(t::$SCT.HessianTracer) is_der1_zero = $SCT.is_der1_zero_global($M.$op) is_der2_zero = $SCT.is_der2_zero_global($M.$op) return $SCT.hessian_tracer_1_to_1(t, is_der1_zero, is_der2_zero) end - end -end -function overload_hessian_1_to_1_dual(M, op) - SCT = SparseConnectivityTracer - return quote + ## Dual function $M.$op(d::D) where {P,T<:$SCT.HessianTracer,D<:$SCT.Dual{P,T}} x = $SCT.primal(d) p_out = $M.$op(x) @@ -166,6 +163,7 @@ end function overload_hessian_2_to_1(M, op) SCT = SparseConnectivityTracer return quote + ## HessianTracer function $M.$op(tx::T, ty::T) where {T<:$SCT.HessianTracer} is_der1_arg1_zero = $SCT.is_der1_arg1_zero_global($M.$op) is_der2_arg1_zero = $SCT.is_der2_arg1_zero_global($M.$op) @@ -194,12 +192,8 @@ function overload_hessian_2_to_1(M, op) is_der2_arg2_zero = $SCT.is_der2_arg2_zero_global($M.$op) return $SCT.hessian_tracer_1_to_1(ty, is_der1_arg2_zero, is_der2_arg2_zero) end - end -end -function overload_hessian_2_to_1_dual(M, op) - SCT = SparseConnectivityTracer - return quote + ## Dual function $M.$op(dx::D, dy::D) where {P,T<:$SCT.HessianTracer,D<:$SCT.Dual{P,T}} x = $SCT.primal(dx) y = $SCT.primal(dy) @@ -269,6 +263,7 @@ end function overload_hessian_1_to_2(M, op) SCT = SparseConnectivityTracer return quote + ## HessianTracer function $M.$op(t::$SCT.HessianTracer) is_der1_out1_zero = $SCT.is_der1_out1_zero_global($M.$op) is_der2_out1_zero = $SCT.is_der2_out1_zero_global($M.$op) @@ -282,12 +277,8 @@ function overload_hessian_1_to_2(M, op) is_der2_out2_zero, ) end - end -end -function overload_hessian_1_to_2_dual(M, op) - SCT = SparseConnectivityTracer - return quote + ## Dual function $M.$op(d::D) where {P,T<:$SCT.HessianTracer,D<:$SCT.Dual{P,T}} x = $SCT.primal(d) p_out1, p_out2 = $M.$op(x) diff --git a/src/overloads/overload_all.jl b/src/overloads/overload_all.jl index 737fed2..5639fab 100644 --- a/src/overloads/overload_all.jl +++ b/src/overloads/overload_all.jl @@ -5,43 +5,22 @@ function overload_all(M) $(overload_hessian_1_to_1(M, op)) end for op in nameof.(list_operators_1_to_1(Val(M))) ] - exprs_1_to_1_dual = [ - quote - $(overload_gradient_1_to_1_dual(M, op)) - $(overload_hessian_1_to_1_dual(M, op)) - end for op in nameof.(list_operators_1_to_1(Val(M))) - ] exprs_2_to_1 = [ quote $(overload_gradient_2_to_1(M, op)) $(overload_hessian_2_to_1(M, op)) end for op in nameof.(list_operators_2_to_1(Val(M))) ] - exprs_2_to_1_dual = [ - quote - $(overload_gradient_2_to_1_dual(M, op)) - $(overload_hessian_2_to_1_dual(M, op)) - end for op in nameof.(list_operators_2_to_1(Val(M))) - ] exprs_1_to_2 = [ quote $(overload_gradient_1_to_2(M, op)) $(overload_hessian_1_to_2(M, op)) end for op in nameof.(list_operators_1_to_2(Val(M))) ] - exprs_1_to_2_dual = [ - quote - $(overload_gradient_1_to_2_dual(M, op)) - $(overload_hessian_1_to_2_dual(M, op)) - end for op in nameof.(list_operators_1_to_2(Val(M))) - ] return quote $(exprs_1_to_1...) - $(exprs_1_to_1_dual...) $(exprs_2_to_1...) - $(exprs_2_to_1_dual...) $(exprs_1_to_2...) - $(exprs_1_to_2_dual...) end end From 0fb2e50c755dcf253b4af95624b549220d1465d3 Mon Sep 17 00:00:00 2001 From: adrhill Date: Thu, 15 Aug 2024 18:05:49 +0200 Subject: [PATCH 06/20] Revert random sampling of `Dual` --- src/overloads/gradient_tracer.jl | 4 +++- src/overloads/hessian_tracer.jl | 4 +++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/src/overloads/gradient_tracer.jl b/src/overloads/gradient_tracer.jl index 9b6609c..6253dab 100644 --- a/src/overloads/gradient_tracer.jl +++ b/src/overloads/gradient_tracer.jl @@ -232,5 +232,7 @@ Base.rand(::AbstractRNG, ::SamplerType{T}) where {T<:GradientTracer} = myempty(T function Base.rand( rng::AbstractRNG, ::SamplerType{D} ) where {P,T<:GradientTracer,D<:Dual{P,T}} - return rand(rng, P) # only return primal + p = rand(rng, P) + t = myempty(T) + return Dual(p, t) end diff --git a/src/overloads/hessian_tracer.jl b/src/overloads/hessian_tracer.jl index feae676..3604ed9 100644 --- a/src/overloads/hessian_tracer.jl +++ b/src/overloads/hessian_tracer.jl @@ -338,5 +338,7 @@ Base.rand(::AbstractRNG, ::SamplerType{T}) where {T<:HessianTracer} = myempty(T) function Base.rand( rng::AbstractRNG, ::SamplerType{D} ) where {P,T<:HessianTracer,D<:Dual{P,T}} - return rand(rng, P) # only return primal + p = rand(rng, P) + t = myempty(T) + return Dual(p, t) end From f3f037d012513bf8bcedc10188f7040189dfd77d Mon Sep 17 00:00:00 2001 From: adrhill Date: Thu, 15 Aug 2024 19:11:17 +0200 Subject: [PATCH 07/20] Add more named testsets --- test/test_gradient.jl | 123 ++++++++------- test/test_hessian.jl | 348 ++++++++++++++++++++++-------------------- 2 files changed, 253 insertions(+), 218 deletions(-) diff --git a/test/test_gradient.jl b/test/test_gradient.jl index 0b4abc3..86cb30f 100644 --- a/test/test_gradient.jl +++ b/test/test_gradient.jl @@ -41,77 +41,96 @@ NNLIB_ACTIVATIONS = union(NNLIB_ACTIVATIONS_S, NNLIB_ACTIVATIONS_F) REAL_TYPES = (Float64, Int, Bool, UInt8, Float16, Rational{Int}) +# These exists to be able to quickly run tests in the REPL. +# NOTE: J gets overwritten inside the testsets. +method = TracerSparsityDetector() +J(f, x) = jacobian_sparsity(f, x, method) + @testset "Jacobian Global" begin @testset "$P" for P in GRADIENT_PATTERNS T = GradientTracer{P} method = TracerSparsityDetector(; gradient_tracer_type=T) J(f, x) = jacobian_sparsity(f, x, method) - f(x) = [x[1]^2, 2 * x[1] * x[2]^2, sin(x[3])] - @test J(f, rand(3)) == [1 0 0; 1 1 0; 0 0 1] - @test J(identity, rand()) ≈ [1;;] - @test J(Returns(1), 1) ≈ [0;;] + @testset "Trivial examples" begin + f(x) = [x[1]^2, 2 * x[1] * x[2]^2, sin(x[3])] + @test J(f, rand(3)) == [1 0 0; 1 1 0; 0 0 1] + @test J(identity, rand()) ≈ [1;;] + @test J(Returns(1), 1) ≈ [0;;] + end # Test GradientTracer on functions with zero derivatives - x = rand(2) - g(x) = [x[1] * x[2], ceil(x[1] * x[2]), x[1] * round(x[2])] - @test J(g, x) == [1 1; 0 0; 1 0] - @test J(!, true) ≈ [0;;] + @testset "Zero derivatives" begin + x = rand(2) + g(x) = [x[1] * x[2], ceil(x[1] * x[2]), x[1] * round(x[2])] + @test J(g, x) == [1 1; 0 0; 1 0] + @test J(!, true) ≈ [0;;] + end # Code coverage - @test J(x -> [sincos(x)...], 1) ≈ [1; 1] - @test J(typemax, 1) ≈ [0;;] - @test J(x -> x^(2//3), 1) ≈ [1;;] - @test J(x -> (2//3)^x, 1) ≈ [1;;] - @test J(x -> x^ℯ, 1) ≈ [1;;] - @test J(x -> ℯ^x, 1) ≈ [1;;] - @test J(x -> 0, 1) ≈ [0;;] - - # Test special cases on empty tracer - @test J(x -> zero(x)^(2//3), 1) ≈ [0;;] - @test J(x -> (2//3)^zero(x), 1) ≈ [0;;] - @test J(x -> zero(x)^ℯ, 1) ≈ [0;;] - @test J(x -> ℯ^zero(x), 1) ≈ [0;;] + @testset "Miscellaneous" begin + @test J(x -> [sincos(x)...], 1) ≈ [1; 1] + @test J(typemax, 1) ≈ [0;;] + @test J(x -> x^(2//3), 1) ≈ [1;;] + @test J(x -> (2//3)^x, 1) ≈ [1;;] + @test J(x -> x^ℯ, 1) ≈ [1;;] + @test J(x -> ℯ^x, 1) ≈ [1;;] + @test J(x -> 0, 1) ≈ [0;;] + + # Test special cases on empty tracer + @test J(x -> zero(x)^(2//3), 1) ≈ [0;;] + @test J(x -> (2//3)^zero(x), 1) ≈ [0;;] + @test J(x -> zero(x)^ℯ, 1) ≈ [0;;] + @test J(x -> ℯ^zero(x), 1) ≈ [0;;] + end # Conversions - @testset "Conversion to $T" for T in REAL_TYPES - @test J(x -> convert(T, x), 1.0) ≈ [1;;] + @testset "Conversion" begin + @testset "to $T" for T in REAL_TYPES + @test J(x -> convert(T, x), 1.0) ≈ [1;;] + end end - # Round - @test J(round, 1.1) ≈ [0;;] - @test J(x -> round(Int, x), 1.1) ≈ [0;;] - @test J(x -> round(Bool, x), 1.1) ≈ [0;;] - @test J(x -> round(Float16, x), 1.1) ≈ [0;;] - @test J(x -> round(x, RoundNearestTiesAway), 1.1) ≈ [0;;] - @test J(x -> round(x; digits=3, base=2), 1.1) ≈ [0;;] + @testset "Round" begin + @test J(round, 1.1) ≈ [0;;] + @test J(x -> round(Int, x), 1.1) ≈ [0;;] + @test J(x -> round(Bool, x), 1.1) ≈ [0;;] + @test J(x -> round(Float16, x), 1.1) ≈ [0;;] + @test J(x -> round(x, RoundNearestTiesAway), 1.1) ≈ [0;;] + @test J(x -> round(x; digits=3, base=2), 1.1) ≈ [0;;] + end - # Random - @test J(x -> rand(typeof(x)), 1) ≈ [0;;] - @test J(x -> rand(GLOBAL_RNG, typeof(x)), 1) ≈ [0;;] + @testset "Random" begin + @test J(x -> rand(typeof(x)), 1) ≈ [0;;] + @test J(x -> rand(GLOBAL_RNG, typeof(x)), 1) ≈ [0;;] + end - # Linear Algebra - @test J(x -> dot(x[1:2], x[4:5]), rand(5)) == [1 1 0 1 1] + @testset "LinearAlgebra" begin + @test J(x -> dot(x[1:2], x[4:5]), rand(5)) == [1 1 0 1 1] + end - # SpecialFunctions extension - @test J(x -> erf(x[1]), rand(2)) == [1 0] - @test J(x -> beta(x[1], x[2]), rand(3)) == [1 1 0] + @testset "SpecialFunctions extension" begin + @test J(x -> erf(x[1]), rand(2)) == [1 0] + @test J(x -> beta(x[1], x[2]), rand(3)) == [1 1 0] + end # Missing primal errors - @testset "MissingPrimalError on $f" for f in ( - iseven, - isfinite, - isinf, - isinteger, - ismissing, - isnan, - isnothing, - isodd, - isone, - isreal, - iszero, - ) - @test_throws MissingPrimalError J(f, rand()) + @testset "MissingPrimalError" begin + @testset "$f" for f in ( + iseven, + isfinite, + isinf, + isinteger, + ismissing, + isnan, + isnothing, + isodd, + isone, + isreal, + iszero, + ) + @test_throws MissingPrimalError J(f, rand()) + end end # NNlib extension diff --git a/test/test_hessian.jl b/test/test_hessian.jl index 9fb0d5d..ff677e9 100644 --- a/test/test_hessian.jl +++ b/test/test_hessian.jl @@ -26,121 +26,152 @@ D = Dual{Int,T} method = TracerSparsityDetector(; hessian_tracer_type=T) H(f, x) = hessian_sparsity(f, x, method) - @test H(identity, rand()) ≈ [0;;] - @test H(sqrt, rand()) ≈ [1;;] + @testset "Trivial examples" begin + @test H(identity, rand()) ≈ [0;;] + @test H(sqrt, rand()) ≈ [1;;] - @test H(x -> 1 * x, rand()) ≈ [0;;] - @test H(x -> x * 1, rand()) ≈ [0;;] + @test H(x -> 1 * x, rand()) ≈ [0;;] + @test H(x -> x * 1, rand()) ≈ [0;;] + end # Code coverage - @test H(sign, 1) ≈ [0;;] - @test H(typemax, 1) ≈ [0;;] - @test H(x -> x^(2//3), 1) ≈ [1;;] - @test H(x -> (2//3)^x, 1) ≈ [1;;] - @test H(x -> x^ℯ, 1) ≈ [1;;] - @test H(x -> ℯ^x, 1) ≈ [1;;] - @test H(x -> 0, 1) ≈ [0;;] + @testset "Miscellaneous" begin + @test H(sign, 1) ≈ [0;;] + @test H(typemax, 1) ≈ [0;;] + @test H(x -> x^(2//3), 1) ≈ [1;;] + @test H(x -> (2//3)^x, 1) ≈ [1;;] + @test H(x -> x^ℯ, 1) ≈ [1;;] + @test H(x -> ℯ^x, 1) ≈ [1;;] + @test H(x -> 0, 1) ≈ [0;;] + end # Conversions - @testset "Conversion to $T" for T in REAL_TYPES - @test H(x -> convert(T, x), 1.0) ≈ [0;;] - @test H(x -> convert(T, x^2), 1.0) ≈ [1;;] - @test H(x -> convert(T, x)^2, 1.0) ≈ [1;;] + @testset "Conversion" begin + @testset "to $T" for T in REAL_TYPES + @test H(x -> convert(T, x), 1.0) ≈ [0;;] + @test H(x -> convert(T, x^2), 1.0) ≈ [1;;] + @test H(x -> convert(T, x)^2, 1.0) ≈ [1;;] + end end - # Round - @test H(round, 1.1) ≈ [0;;] - @test H(x -> round(Int, x), 1.1) ≈ [0;;] - @test H(x -> round(Bool, x), 1.1) ≈ [0;;] - @test H(x -> round(Float16, x), 1.1) ≈ [0;;] - @test H(x -> round(x, RoundNearestTiesAway), 1.1) ≈ [0;;] - @test H(x -> round(x; digits=3, base=2), 1.1) ≈ [0;;] - - # Random - @test H(x -> rand(typeof(x)), 1) ≈ [0;;] - @test H(x -> rand(GLOBAL_RNG, typeof(x)), 1) ≈ [0;;] - - @test H(x -> x[1] / x[2] + x[3] / 1 + 1 / x[4], rand(4)) == [ - 0 1 0 0 - 1 1 0 0 - 0 0 0 0 - 0 0 0 1 - ] - - @test H(x -> x[1] * x[2] + x[3] * 1 + 1 * x[4], rand(4)) == [ - 0 1 0 0 - 1 0 0 0 - 0 0 0 0 - 0 0 0 0 - ] - - @test H(x -> (x[1] * x[2]) * (x[3] * x[4]), rand(4)) == [ - 0 1 1 1 - 1 0 1 1 - 1 1 0 1 - 1 1 1 0 - ] - - @test H(x -> (x[1] + x[2]) * (x[3] + x[4]), rand(4)) == [ - 0 0 1 1 - 0 0 1 1 - 1 1 0 0 - 1 1 0 0 - ] - - @test H(x -> (x[1] + x[2] + x[3] + x[4])^2, rand(4)) == [ - 1 1 1 1 - 1 1 1 1 - 1 1 1 1 - 1 1 1 1 - ] - - @test H(x -> 1 / (x[1] + x[2] + x[3] + x[4]), rand(4)) == [ - 1 1 1 1 - 1 1 1 1 - 1 1 1 1 - 1 1 1 1 - ] + @testset "Round" begin + @test H(round, 1.1) ≈ [0;;] + @test H(x -> round(Int, x), 1.1) ≈ [0;;] + @test H(x -> round(Bool, x), 1.1) ≈ [0;;] + @test H(x -> round(Float16, x), 1.1) ≈ [0;;] + @test H(x -> round(x, RoundNearestTiesAway), 1.1) ≈ [0;;] + @test H(x -> round(x; digits=3, base=2), 1.1) ≈ [0;;] + end - @test H(x -> (x[1] - x[2]) + (x[3] - 1) + (1 - x[4]), rand(4)) == [ - 0 0 0 0 - 0 0 0 0 - 0 0 0 0 - 0 0 0 0 - ] + @testset "Random" begin + @test H(x -> rand(typeof(x)), 1) ≈ [0;;] + @test H(x -> rand(GLOBAL_RNG, typeof(x)), 1) ≈ [0;;] + end - h = H(x -> copysign(x[1] * x[2], x[3] * x[4]), rand(4)) - if Bool(shared(T)) - @test h == [ + @testset "Basic operators" begin + @test H(x -> x[1] / x[2] + x[3] / 1 + 1 / x[4], rand(4)) == [ 0 1 0 0 - 1 0 0 0 + 1 1 0 0 + 0 0 0 0 0 0 0 1 - 0 0 1 0 ] - else - @test h == [ + + @test H(x -> x[1] * x[2] + x[3] * 1 + 1 * x[4], rand(4)) == [ 0 1 0 0 1 0 0 0 0 0 0 0 0 0 0 0 ] - end - h = H(x -> div(x[1] * x[2], x[3] * x[4]), rand(4)) - if Bool(shared(T)) - @test Matrix(h) == [ - 0 1 0 0 - 1 0 0 0 - 0 0 0 1 - 0 0 1 0 + @test H(x -> (x[1] * x[2]) * (x[3] * x[4]), rand(4)) == [ + 0 1 1 1 + 1 0 1 1 + 1 1 0 1 + 1 1 1 0 ] - else - @test h == [ + + @test H(x -> (x[1] + x[2]) * (x[3] + x[4]), rand(4)) == [ + 0 0 1 1 + 0 0 1 1 + 1 1 0 0 + 1 1 0 0 + ] + + @test H(x -> (x[1] + x[2] + x[3] + x[4])^2, rand(4)) == [ + 1 1 1 1 + 1 1 1 1 + 1 1 1 1 + 1 1 1 1 + ] + + @test H(x -> 1 / (x[1] + x[2] + x[3] + x[4]), rand(4)) == [ + 1 1 1 1 + 1 1 1 1 + 1 1 1 1 + 1 1 1 1 + ] + + @test H(x -> (x[1] - x[2]) + (x[3] - 1) + (1 - x[4]), rand(4)) == [ 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 ] + + x = rand(5) + foo(x) = x[1] + x[2] * x[3] + 1 / x[4] + 1 * x[5] + @test H(foo, x) == [ + 0 0 0 0 0 + 0 0 1 0 0 + 0 1 0 0 0 + 0 0 0 1 0 + 0 0 0 0 0 + ] + + bar(x) = foo(x) + x[2]^x[5] + @test H(bar, x) == [ + 0 0 0 0 0 + 0 1 1 0 1 + 0 1 0 0 0 + 0 0 0 1 0 + 0 1 0 0 1 + ] + end + + @testset "Zero derivatives" begin + h = H(x -> copysign(x[1] * x[2], x[3] * x[4]), rand(4)) + if Bool(shared(T)) + @test h == [ + 0 1 0 0 + 1 0 0 0 + 0 0 0 1 + 0 0 1 0 + ] + else + @test h == [ + 0 1 0 0 + 1 0 0 0 + 0 0 0 0 + 0 0 0 0 + ] + end + + h = H(x -> div(x[1] * x[2], x[3] * x[4]), rand(4)) + if Bool(shared(T)) + @test Matrix(h) == [ + 0 1 0 0 + 1 0 0 0 + 0 0 0 1 + 0 0 1 0 + ] + else + @test h == [ + 0 0 0 0 + 0 0 0 0 + 0 0 0 0 + 0 0 0 0 + ] + end end @test H(x -> sum(sincosd(x)), 1.0) ≈ [1;;] @@ -152,25 +183,6 @@ D = Dual{Int,T} 0 0 1 1 ] - x = rand(5) - foo(x) = x[1] + x[2] * x[3] + 1 / x[4] + 1 * x[5] - @test H(foo, x) == [ - 0 0 0 0 0 - 0 0 1 0 0 - 0 1 0 0 0 - 0 0 0 1 0 - 0 0 0 0 0 - ] - - bar(x) = foo(x) + x[2]^x[5] - @test H(bar, x) == [ - 0 0 0 0 0 - 0 1 1 0 1 - 0 1 0 0 0 - 0 0 0 1 0 - 0 1 0 0 1 - ] - # Shared Hessian function dead_end(x) z = x[1] * x[2] @@ -194,68 +206,72 @@ D = Dual{Int,T} end # Missing primal errors - @testset "MissingPrimalError on $f" for f in ( - iseven, - isfinite, - isinf, - isinteger, - ismissing, - isnan, - isnothing, - isodd, - isone, - isreal, - iszero, - ) - @test_throws MissingPrimalError H(f, rand()) + @testset "MissingPrimalError" begin + @testset "$f" for f in ( + iseven, + isfinite, + isinf, + isinteger, + ismissing, + isnan, + isnothing, + isodd, + isone, + isreal, + iszero, + ) + @test_throws MissingPrimalError H(f, rand()) + end end # ifelse and comparisons - if VERSION >= v"1.8" - @test H(x -> ifelse(x[1], x[1]^x[2], x[3] * x[4]), rand(4)) == [ - 1 1 0 0 - 1 1 0 0 - 0 0 0 1 - 0 0 1 0 + @testset "ifelse and comparisons" begin + if VERSION >= v"1.8" + @test H(x -> ifelse(x[1], x[1]^x[2], x[3] * x[4]), rand(4)) == [ + 1 1 0 0 + 1 1 0 0 + 0 0 0 1 + 0 0 1 0 + ] + + @test H(x -> ifelse(x[1], x[1]^x[2], 1.0), rand(4)) == [ + 1 1 0 0 + 1 1 0 0 + 0 0 0 0 + 0 0 0 0 + ] + + @test H(x -> ifelse(x[1], 1.0, x[3] * x[4]), rand(4)) == [ + 0 0 0 0 + 0 0 0 0 + 0 0 0 1 + 0 0 1 0 + ] + end + + function f_ampgo07(x) + return (x[1] <= 0) * convert(eltype(x), Inf) + + sin(x[1]) + + sin(10//3 * x[1]) + + log(abs(x[1])) - 84//100 * x[1] + 3 + end + @test H(f_ampgo07, [1.0]) ≈ [1;;] + + # Error handling when applying non-dual tracers to "local" functions with control flow + # TypeError: non-boolean (SparseConnectivityTracer.GradientTracer{BitSet}) used in boolean context + @test_throws TypeError H(x -> x[1] > x[2] ? x[1]^x[2] : x[3] * x[4], rand(4)) + + # SpecialFunctions + @test H(x -> erf(x[1]), rand(2)) == [ + 1 0 + 0 0 ] - - @test H(x -> ifelse(x[1], x[1]^x[2], 1.0), rand(4)) == [ - 1 1 0 0 - 1 1 0 0 - 0 0 0 0 - 0 0 0 0 - ] - - @test H(x -> ifelse(x[1], 1.0, x[3] * x[4]), rand(4)) == [ - 0 0 0 0 - 0 0 0 0 - 0 0 0 1 - 0 0 1 0 + @test H(x -> beta(x[1], x[2]), rand(3)) == [ + 1 1 0 + 1 1 0 + 0 0 0 ] end - - function f_ampgo07(x) - return (x[1] <= 0) * convert(eltype(x), Inf) + - sin(x[1]) + - sin(10//3 * x[1]) + - log(abs(x[1])) - 84//100 * x[1] + 3 - end - @test H(f_ampgo07, [1.0]) ≈ [1;;] - - # Error handling when applying non-dual tracers to "local" functions with control flow - # TypeError: non-boolean (SparseConnectivityTracer.GradientTracer{BitSet}) used in boolean context - @test_throws TypeError H(x -> x[1] > x[2] ? x[1]^x[2] : x[3] * x[4], rand(4)) - - # SpecialFunctions - @test H(x -> erf(x[1]), rand(2)) == [ - 1 0 - 0 0 - ] - @test H(x -> beta(x[1], x[2]), rand(3)) == [ - 1 1 0 - 1 1 0 - 0 0 0 - ] yield() end end From 06b7d7287c6ab2b3eabc6598560a8afe0218ac8c Mon Sep 17 00:00:00 2001 From: adrhill Date: Thu, 15 Aug 2024 19:15:23 +0200 Subject: [PATCH 08/20] Moduleless overloads --- src/overloads/gradient_tracer.jl | 9 +++-- src/overloads/hessian_tracer.jl | 9 +++-- src/overloads/overload_all.jl | 59 ++++++++++++++++++-------------- 3 files changed, 45 insertions(+), 32 deletions(-) diff --git a/src/overloads/gradient_tracer.jl b/src/overloads/gradient_tracer.jl index 6253dab..7cb78c4 100644 --- a/src/overloads/gradient_tracer.jl +++ b/src/overloads/gradient_tracer.jl @@ -28,7 +28,8 @@ function gradient_tracer_1_to_1_inner( end end -function overload_gradient_1_to_1(M, op) +function overload_gradient_1_to_1(op) + M = parentmodule(op) SCT = SparseConnectivityTracer return quote ## GradientTracer @@ -96,7 +97,8 @@ function gradient_tracer_2_to_1_inner( end end -function overload_gradient_2_to_1(M, op) +function overload_gradient_2_to_1(op) + M = parentmodule(op) SCT = SparseConnectivityTracer return quote ## GradientTracer @@ -168,7 +170,8 @@ end end end -function overload_gradient_1_to_2(M, op) +function overload_gradient_1_to_2(op) + M = parentmodule(op) SCT = SparseConnectivityTracer return quote ## GradientTracer diff --git a/src/overloads/hessian_tracer.jl b/src/overloads/hessian_tracer.jl index 3604ed9..4909df0 100644 --- a/src/overloads/hessian_tracer.jl +++ b/src/overloads/hessian_tracer.jl @@ -55,7 +55,8 @@ function hessian_tracer_1_to_1_inner( return P(g_out, h_out) # return pattern end -function overload_hessian_1_to_1(M, op) +function overload_hessian_1_to_1(op) + M = parentmodule(op) SCT = SparseConnectivityTracer return quote ## HessianTracer @@ -160,7 +161,8 @@ function hessian_tracer_2_to_1_inner( return P(g_out, h_out) # return pattern end -function overload_hessian_2_to_1(M, op) +function overload_hessian_2_to_1(op) + M = parentmodule(op) SCT = SparseConnectivityTracer return quote ## HessianTracer @@ -260,7 +262,8 @@ end end end -function overload_hessian_1_to_2(M, op) +function overload_hessian_1_to_2(op) + M = parentmodule(op) SCT = SparseConnectivityTracer return quote ## HessianTracer diff --git a/src/overloads/overload_all.jl b/src/overloads/overload_all.jl index 5639fab..73fd8ab 100644 --- a/src/overloads/overload_all.jl +++ b/src/overloads/overload_all.jl @@ -1,27 +1,34 @@ -function overload_all(M) - exprs_1_to_1 = [ - quote - $(overload_gradient_1_to_1(M, op)) - $(overload_hessian_1_to_1(M, op)) - end for op in nameof.(list_operators_1_to_1(Val(M))) - ] - exprs_2_to_1 = [ - quote - $(overload_gradient_2_to_1(M, op)) - $(overload_hessian_2_to_1(M, op)) - end for op in nameof.(list_operators_2_to_1(Val(M))) - ] - exprs_1_to_2 = [ - quote - $(overload_gradient_1_to_2(M, op)) - $(overload_hessian_1_to_2(M, op)) - end for op in nameof.(list_operators_1_to_2(Val(M))) - ] - return quote - $(exprs_1_to_1...) - $(exprs_2_to_1...) - $(exprs_1_to_2...) - end -end +# for overload in ( +# overload_gradient_1_to_1, +# overload_gradient_2_to_1, +# overload_gradient_1_to_2, +# overload_hessian_1_to_1, +# overload_hessian_2_to_1, +# overload_hessian_1_to_2, +# ) +# @eval function $overload(ops::Tuple) +# for op in ops +# $overload(op) +# end +# end +# end + +# overload_gradient_1_to_1(ops_1_to_1) +# overload_gradient_2_to_1(ops_2_to_1) +# overload_gradient_1_to_2(ops_1_to_2) +# overload_hessian_1_to_1(ops_1_to_1) +# overload_hessian_2_to_1(ops_2_to_1) +# overload_hessian_1_to_2(ops_1_to_2) -eval(overload_all(:Base)) +for op in ops_1_to_1 + overload_gradient_1_to_1(op) + overload_hessian_1_to_1(op) +end +for op in ops_2_to_1 + overload_gradient_2_to_1(op) + overload_hessian_2_to_1(op) +end +for op in ops_1_to_2 + overload_gradient_1_to_2(op) + overload_hessian_1_to_2(op) +end From 4bcd08c187dad0a0b81dd178d60531cdbd0d5668 Mon Sep 17 00:00:00 2001 From: adrhill Date: Thu, 15 Aug 2024 19:18:19 +0200 Subject: [PATCH 09/20] Use `@eval` instead of quotes --- src/overloads/gradient_tracer.jl | 6 +++--- src/overloads/hessian_tracer.jl | 6 +++--- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/src/overloads/gradient_tracer.jl b/src/overloads/gradient_tracer.jl index 7cb78c4..3a98cb9 100644 --- a/src/overloads/gradient_tracer.jl +++ b/src/overloads/gradient_tracer.jl @@ -31,7 +31,7 @@ end function overload_gradient_1_to_1(op) M = parentmodule(op) SCT = SparseConnectivityTracer - return quote + @eval begin ## GradientTracer function $M.$op(t::$SCT.GradientTracer) is_der1_zero = $SCT.is_der1_zero_global($M.$op) @@ -100,7 +100,7 @@ end function overload_gradient_2_to_1(op) M = parentmodule(op) SCT = SparseConnectivityTracer - return quote + @eval begin ## GradientTracer function $M.$op(tx::T, ty::T) where {T<:$SCT.GradientTracer} is_der1_arg1_zero = $SCT.is_der1_arg1_zero_global($M.$op) @@ -173,7 +173,7 @@ end function overload_gradient_1_to_2(op) M = parentmodule(op) SCT = SparseConnectivityTracer - return quote + @eval begin ## GradientTracer function $M.$op(t::$SCT.GradientTracer) is_der1_out1_zero = $SCT.is_der1_out1_zero_global($M.$op) diff --git a/src/overloads/hessian_tracer.jl b/src/overloads/hessian_tracer.jl index 4909df0..2ddb7ae 100644 --- a/src/overloads/hessian_tracer.jl +++ b/src/overloads/hessian_tracer.jl @@ -58,7 +58,7 @@ end function overload_hessian_1_to_1(op) M = parentmodule(op) SCT = SparseConnectivityTracer - return quote + @eval begin ## HessianTracer function $M.$op(t::$SCT.HessianTracer) is_der1_zero = $SCT.is_der1_zero_global($M.$op) @@ -164,7 +164,7 @@ end function overload_hessian_2_to_1(op) M = parentmodule(op) SCT = SparseConnectivityTracer - return quote + @eval begin ## HessianTracer function $M.$op(tx::T, ty::T) where {T<:$SCT.HessianTracer} is_der1_arg1_zero = $SCT.is_der1_arg1_zero_global($M.$op) @@ -265,7 +265,7 @@ end function overload_hessian_1_to_2(op) M = parentmodule(op) SCT = SparseConnectivityTracer - return quote + @eval begin ## HessianTracer function $M.$op(t::$SCT.HessianTracer) is_der1_out1_zero = $SCT.is_der1_out1_zero_global($M.$op) From 11953aba5f0582326f7e04a80c6f60c301f20977 Mon Sep 17 00:00:00 2001 From: adrhill Date: Fri, 16 Aug 2024 15:34:58 +0200 Subject: [PATCH 10/20] More named testsets --- test/test_gradient.jl | 299 +++++++++++++++++++------------------ test/test_hessian.jl | 333 ++++++++++++++++++++++-------------------- 2 files changed, 328 insertions(+), 304 deletions(-) diff --git a/test/test_gradient.jl b/test/test_gradient.jl index 86cb30f..9b9c11e 100644 --- a/test/test_gradient.jl +++ b/test/test_gradient.jl @@ -114,7 +114,6 @@ J(f, x) = jacobian_sparsity(f, x, method) @test J(x -> beta(x[1], x[2]), rand(3)) == [1 1 0] end - # Missing primal errors @testset "MissingPrimalError" begin @testset "$f" for f in ( iseven, @@ -133,31 +132,34 @@ J(f, x) = jacobian_sparsity(f, x, method) end end - # NNlib extension - for f in NNLIB_ACTIVATIONS - @test J(f, 1) ≈ [1;;] + @testset "NNlib" begin + @testset "$f" for f in NNLIB_ACTIVATIONS + @test J(f, 1) ≈ [1;;] + end end - # ifelse and comparisons - if VERSION >= v"1.8" - @test J(x -> ifelse(x[2] < x[3], x[1] + x[2], x[3] * x[4]), [1 2 3 4]) == - [1 1 1 1] - @test J(x -> ifelse(x[2] < x[3], x[1] + x[2], 1.0), [1 2 3 4]) == [1 1 0 0] - @test J(x -> ifelse(x[2] < x[3], 1.0, x[3] * x[4]), [1 2 3 4]) == [0 0 1 1] - end + @testset "ifelse and comparisons" begin + if VERSION >= v"1.8" + @test J(x -> ifelse(x[2] < x[3], x[1] + x[2], x[3] * x[4]), [1 2 3 4]) == + [1 1 1 1] + @test J(x -> ifelse(x[2] < x[3], x[1] + x[2], 1.0), [1 2 3 4]) == [1 1 0 0] + @test J(x -> ifelse(x[2] < x[3], 1.0, x[3] * x[4]), [1 2 3 4]) == [0 0 1 1] + end - function f_ampgo07(x) - return (x[1] <= 0) * convert(eltype(x), Inf) + - sin(x[1]) + - sin(10//3 * x[1]) + - log(abs(x[1])) - 84//100 * x[1] + 3 - end - @test J(f_ampgo07, [1.0]) ≈ [1;;] + function f_ampgo07(x) + return (x[1] <= 0) * convert(eltype(x), Inf) + + sin(x[1]) + + sin(10//3 * x[1]) + + log(abs(x[1])) - 84//100 * x[1] + 3 + end + @test J(f_ampgo07, [1.0]) ≈ [1;;] - ## Error handling when applying non-dual tracers to "local" functions with control flow - # TypeError: non-boolean (SparseConnectivityTracer.GradientTracer{BitSet}) used in boolean context - @test_throws TypeError J(x -> x[1] > x[2] ? x[3] : x[4], [1.0, 2.0, 3.0, 4.0]) == - [0 0 1 1;] + # Error handling when applying non-dual tracers to "local" functions with control flow + # TypeError: non-boolean (SparseConnectivityTracer.GradientTracer{BitSet}) used in boolean context + @test_throws TypeError J( + x -> x[1] > x[2] ? x[3] : x[4], [1.0, 2.0, 3.0, 4.0] + ) == [0 0 1 1;] + end yield() end end @@ -168,136 +170,147 @@ end method = TracerLocalSparsityDetector(; gradient_tracer_type=T) J(f, x) = jacobian_sparsity(f, x, method) - # Multiplication - @test J(x -> x[1] * x[2], [1.0, 1.0]) == [1 1;] - @test J(x -> x[1] * x[2], [1.0, 0.0]) == [0 1;] - @test J(x -> x[1] * x[2], [0.0, 1.0]) == [1 0;] - @test J(x -> x[1] * x[2], [0.0, 0.0]) == [0 0;] - - # Division - @test J(x -> x[1] / x[2], [1.0, 1.0]) == [1 1;] - @test J(x -> x[1] / x[2], [0.0, 0.0]) == [1 0;] - - # Maximum - @test J(x -> max(x[1], x[2]), [1.0, 2.0]) == [0 1;] - @test J(x -> max(x[1], x[2]), [2.0, 1.0]) == [1 0;] - @test J(x -> max(x[1], x[2]), [1.0, 1.0]) == [1 1;] + @testset "Trivial examples" begin - # Minimum - @test J(x -> min(x[1], x[2]), [1.0, 2.0]) == [1 0;] - @test J(x -> min(x[1], x[2]), [2.0, 1.0]) == [0 1;] - @test J(x -> min(x[1], x[2]), [1.0, 1.0]) == [1 1;] + # Multiplication + @test J(x -> x[1] * x[2], [1.0, 1.0]) == [1 1;] + @test J(x -> x[1] * x[2], [1.0, 0.0]) == [0 1;] + @test J(x -> x[1] * x[2], [0.0, 1.0]) == [1 0;] + @test J(x -> x[1] * x[2], [0.0, 0.0]) == [0 0;] + + # Division + @test J(x -> x[1] / x[2], [1.0, 1.0]) == [1 1;] + @test J(x -> x[1] / x[2], [0.0, 0.0]) == [1 0;] + + # Maximum + @test J(x -> max(x[1], x[2]), [1.0, 2.0]) == [0 1;] + @test J(x -> max(x[1], x[2]), [2.0, 1.0]) == [1 0;] + @test J(x -> max(x[1], x[2]), [1.0, 1.0]) == [1 1;] + + # Minimum + @test J(x -> min(x[1], x[2]), [1.0, 2.0]) == [1 0;] + @test J(x -> min(x[1], x[2]), [2.0, 1.0]) == [0 1;] + @test J(x -> min(x[1], x[2]), [1.0, 1.0]) == [1 1;] + end # Comparisons - @test J(x -> x[1] > x[2] ? x[3] : x[4], [1.0, 2.0, 3.0, 4.0]) == [0 0 0 1;] - @test J(x -> x[1] > x[2] ? x[3] : x[4], [2.0, 1.0, 3.0, 4.0]) == [0 0 1 0;] - @test J(x -> x[1] < x[2] ? x[3] : x[4], [1.0, 2.0, 3.0, 4.0]) == [0 0 1 0;] - @test J(x -> x[1] < x[2] ? x[3] : x[4], [2.0, 1.0, 3.0, 4.0]) == [0 0 0 1;] - - @test J(x -> x[1] >= x[2] ? x[1] : x[2], [1.0, 2.0]) == [0 1;] - @test J(x -> x[1] >= x[2] ? x[1] : x[2], [2.0, 1.0]) == [1 0;] - @test J(x -> x[1] >= x[2] ? x[1] : x[2], [1.0, 1.0]) == [1 0;] - - @test J(x -> x[1] >= x[2] ? x[1] : x[2], [1.0, 2.0]) == [0 1;] - @test J(x -> x[1] >= x[2] ? x[1] : x[2], [2.0, 1.0]) == [1 0;] - @test J(x -> x[1] >= x[2] ? x[1] : x[2], [1.0, 1.0]) == [1 0;] - - @test J(x -> x[1] <= x[2] ? x[1] : x[2], [1.0, 2.0]) == [1 0;] - @test J(x -> x[1] <= x[2] ? x[1] : x[2], [2.0, 1.0]) == [0 1;] - @test J(x -> x[1] <= x[2] ? x[1] : x[2], [1.0, 1.0]) == [1 0;] - - @test J(x -> x[1] == x[2] ? x[1] : x[2], [1.0, 2.0]) == [0 1;] - @test J(x -> x[1] == x[2] ? x[1] : x[2], [2.0, 1.0]) == [0 1;] - @test J(x -> x[1] == x[2] ? x[1] : x[2], [1.0, 1.0]) == [1 0;] - - @testset "Comparison with $T" for T in REAL_TYPES - _1 = oneunit(T) - @test J(x -> x[1] > _1 ? x[1] : x[2], [0.0, 2.0]) == [0 1;] - @test J(x -> x[1] > _1 ? x[1] : x[2], [2.0, 0.0]) == [1 0;] - @test J(x -> x[1] >= _1 ? x[1] : x[2], [0.0, 2.0]) == [0 1;] - @test J(x -> x[1] >= _1 ? x[1] : x[2], [2.0, 0.0]) == [1 0;] - @test J(x -> x[1] < _1 ? x[1] : x[2], [0.0, 2.0]) == [1 0;] - @test J(x -> x[1] < _1 ? x[1] : x[2], [2.0, 0.0]) == [0 1;] - @test J(x -> isless(x[1], _1) ? x[1] : x[2], [0.0, 2.0]) == [1 0;] - @test J(x -> isless(x[1], _1) ? x[1] : x[2], [2.0, 0.0]) == [0 1;] - @test J(x -> x[1] <= _1 ? x[1] : x[2], [0.0, 2.0]) == [1 0;] - @test J(x -> x[1] <= _1 ? x[1] : x[2], [2.0, 0.0]) == [0 1;] - @test J(x -> _1 > x[2] ? x[1] : x[2], [0.0, 2.0]) == [0 1;] - @test J(x -> _1 > x[2] ? x[1] : x[2], [2.0, 0.0]) == [1 0;] - @test J(x -> _1 >= x[2] ? x[1] : x[2], [0.0, 2.0]) == [0 1;] - @test J(x -> _1 >= x[2] ? x[1] : x[2], [2.0, 0.0]) == [1 0;] - @test J(x -> _1 < x[2] ? x[1] : x[2], [0.0, 2.0]) == [1 0;] - @test J(x -> _1 < x[2] ? x[1] : x[2], [2.0, 0.0]) == [0 1;] - @test J(x -> _1 <= x[2] ? x[1] : x[2], [0.0, 2.0]) == [1 0;] - @test J(x -> _1 <= x[2] ? x[1] : x[2], [2.0, 0.0]) == [0 1;] + @testset "Comparisons" begin + @test J(x -> x[1] > x[2] ? x[3] : x[4], [1.0, 2.0, 3.0, 4.0]) == [0 0 0 1;] + @test J(x -> x[1] > x[2] ? x[3] : x[4], [2.0, 1.0, 3.0, 4.0]) == [0 0 1 0;] + @test J(x -> x[1] < x[2] ? x[3] : x[4], [1.0, 2.0, 3.0, 4.0]) == [0 0 1 0;] + @test J(x -> x[1] < x[2] ? x[3] : x[4], [2.0, 1.0, 3.0, 4.0]) == [0 0 0 1;] + + @test J(x -> x[1] >= x[2] ? x[1] : x[2], [1.0, 2.0]) == [0 1;] + @test J(x -> x[1] >= x[2] ? x[1] : x[2], [2.0, 1.0]) == [1 0;] + @test J(x -> x[1] >= x[2] ? x[1] : x[2], [1.0, 1.0]) == [1 0;] + + @test J(x -> x[1] >= x[2] ? x[1] : x[2], [1.0, 2.0]) == [0 1;] + @test J(x -> x[1] >= x[2] ? x[1] : x[2], [2.0, 1.0]) == [1 0;] + @test J(x -> x[1] >= x[2] ? x[1] : x[2], [1.0, 1.0]) == [1 0;] + + @test J(x -> x[1] <= x[2] ? x[1] : x[2], [1.0, 2.0]) == [1 0;] + @test J(x -> x[1] <= x[2] ? x[1] : x[2], [2.0, 1.0]) == [0 1;] + @test J(x -> x[1] <= x[2] ? x[1] : x[2], [1.0, 1.0]) == [1 0;] + + @test J(x -> x[1] == x[2] ? x[1] : x[2], [1.0, 2.0]) == [0 1;] + @test J(x -> x[1] == x[2] ? x[1] : x[2], [2.0, 1.0]) == [0 1;] + @test J(x -> x[1] == x[2] ? x[1] : x[2], [1.0, 1.0]) == [1 0;] + + @testset "Comparison with $T" for T in REAL_TYPES + _1 = oneunit(T) + @test J(x -> x[1] > _1 ? x[1] : x[2], [0.0, 2.0]) == [0 1;] + @test J(x -> x[1] > _1 ? x[1] : x[2], [2.0, 0.0]) == [1 0;] + @test J(x -> x[1] >= _1 ? x[1] : x[2], [0.0, 2.0]) == [0 1;] + @test J(x -> x[1] >= _1 ? x[1] : x[2], [2.0, 0.0]) == [1 0;] + @test J(x -> x[1] < _1 ? x[1] : x[2], [0.0, 2.0]) == [1 0;] + @test J(x -> x[1] < _1 ? x[1] : x[2], [2.0, 0.0]) == [0 1;] + @test J(x -> isless(x[1], _1) ? x[1] : x[2], [0.0, 2.0]) == [1 0;] + @test J(x -> isless(x[1], _1) ? x[1] : x[2], [2.0, 0.0]) == [0 1;] + @test J(x -> x[1] <= _1 ? x[1] : x[2], [0.0, 2.0]) == [1 0;] + @test J(x -> x[1] <= _1 ? x[1] : x[2], [2.0, 0.0]) == [0 1;] + @test J(x -> _1 > x[2] ? x[1] : x[2], [0.0, 2.0]) == [0 1;] + @test J(x -> _1 > x[2] ? x[1] : x[2], [2.0, 0.0]) == [1 0;] + @test J(x -> _1 >= x[2] ? x[1] : x[2], [0.0, 2.0]) == [0 1;] + @test J(x -> _1 >= x[2] ? x[1] : x[2], [2.0, 0.0]) == [1 0;] + @test J(x -> _1 < x[2] ? x[1] : x[2], [0.0, 2.0]) == [1 0;] + @test J(x -> _1 < x[2] ? x[1] : x[2], [2.0, 0.0]) == [0 1;] + @test J(x -> _1 <= x[2] ? x[1] : x[2], [0.0, 2.0]) == [1 0;] + @test J(x -> _1 <= x[2] ? x[1] : x[2], [2.0, 0.0]) == [0 1;] + end end - # Code coverage - @test J(x -> [sincos(x)...], 1) ≈ [1; 1] - @test J(typemax, 1) ≈ [0;;] - @test J(x -> x^(2//3), 1) ≈ [1;;] - @test J(x -> (2//3)^x, 1) ≈ [1;;] - @test J(x -> x^ℯ, 1) ≈ [1;;] - @test J(x -> ℯ^x, 1) ≈ [1;;] - @test J(x -> 0, 1) ≈ [0;;] + @testset "Miscellaneous" begin + @test J(x -> [sincos(x)...], 1) ≈ [1; 1] + @test J(typemax, 1) ≈ [0;;] + @test J(x -> x^(2//3), 1) ≈ [1;;] + @test J(x -> (2//3)^x, 1) ≈ [1;;] + @test J(x -> x^ℯ, 1) ≈ [1;;] + @test J(x -> ℯ^x, 1) ≈ [1;;] + @test J(x -> 0, 1) ≈ [0;;] + end # Conversions - @testset "Conversion to $T" for T in REAL_TYPES - @test J(x -> convert(T, x), 1.0) ≈ [1;;] + @testset "Conversion" begin + @testset "Conversion to $T" for T in REAL_TYPES + @test J(x -> convert(T, x), 1.0) ≈ [1;;] + end + end + @testset "Round" begin + @test J(round, 1.1) ≈ [0;;] + @test J(x -> round(Int, x), 1.1) ≈ [0;;] + @test J(x -> round(Bool, x), 1.1) ≈ [0;;] + @test J(x -> round(x, RoundNearestTiesAway), 1.1) ≈ [0;;] + @test J(x -> round(x; digits=3, base=2), 1.1) ≈ [0;;] end - # Round - @test J(round, 1.1) ≈ [0;;] - @test J(x -> round(Int, x), 1.1) ≈ [0;;] - @test J(x -> round(Bool, x), 1.1) ≈ [0;;] - @test J(x -> round(x, RoundNearestTiesAway), 1.1) ≈ [0;;] - @test J(x -> round(x; digits=3, base=2), 1.1) ≈ [0;;] - - # Random - @test J(x -> rand(typeof(x)), 1) ≈ [0;;] - @test J(x -> rand(GLOBAL_RNG, typeof(x)), 1) ≈ [0;;] - - # Linear algebra - @test J(logdet, [1.0 -1.0; 2.0 2.0]) == [1 1 1 1] # (#68) - @test J(x -> log(det(x)), [1.0 -1.0; 2.0 2.0]) == [1 1 1 1] - @test J(x -> dot(x[1:2], x[4:5]), [0, 1, 0, 1, 0]) == [1 0 0 0 1] - - # NNlib extension - @test J(NNlib.relu, -1) ≈ [0;;] - @test J(NNlib.relu, 1) ≈ [1;;] - @test J(NNlib.elu, -1) ≈ [1;;] - @test J(NNlib.elu, 1) ≈ [1;;] - @test J(NNlib.celu, -1) ≈ [1;;] - @test J(NNlib.celu, 1) ≈ [1;;] - @test J(NNlib.selu, -1) ≈ [1;;] - @test J(NNlib.selu, 1) ≈ [1;;] - - @test J(NNlib.relu6, -1) ≈ [0;;] - @test J(NNlib.relu6, 1) ≈ [1;;] - @test J(NNlib.relu6, 7) ≈ [0;;] - - @test J(NNlib.trelu, 0.9) ≈ [0;;] - @test J(NNlib.trelu, 1.1) ≈ [1;;] - - @test J(NNlib.swish, -5) ≈ [1;;] - @test J(NNlib.swish, 0) ≈ [1;;] - @test J(NNlib.swish, 5) ≈ [1;;] - - @test J(NNlib.hardswish, -5) ≈ [0;;] - @test J(NNlib.hardswish, 0) ≈ [1;;] - @test J(NNlib.hardswish, 5) ≈ [1;;] - - @test J(NNlib.hardσ, -4) ≈ [0;;] - @test J(NNlib.hardσ, 0) ≈ [1;;] - @test J(NNlib.hardσ, 4) ≈ [0;;] - - @test J(NNlib.hardtanh, -2) ≈ [0;;] - @test J(NNlib.hardtanh, 0) ≈ [1;;] - @test J(NNlib.hardtanh, 2) ≈ [0;;] - - @test J(NNlib.softshrink, -1) ≈ [1;;] - @test J(NNlib.softshrink, 0) ≈ [0;;] - @test J(NNlib.softshrink, 1) ≈ [1;;] + @testset "Random" begin + @test J(x -> rand(typeof(x)), 1) ≈ [0;;] + @test J(x -> rand(GLOBAL_RNG, typeof(x)), 1) ≈ [0;;] + end + + @testset "LinearAlgebra." begin + @test J(logdet, [1.0 -1.0; 2.0 2.0]) == [1 1 1 1] # (#68) + @test J(x -> log(det(x)), [1.0 -1.0; 2.0 2.0]) == [1 1 1 1] + @test J(x -> dot(x[1:2], x[4:5]), [0, 1, 0, 1, 0]) == [1 0 0 0 1] + end + + @testset "NNlib" begin + @test J(NNlib.relu, -1) ≈ [0;;] + @test J(NNlib.relu, 1) ≈ [1;;] + @test J(NNlib.elu, -1) ≈ [1;;] + @test J(NNlib.elu, 1) ≈ [1;;] + @test J(NNlib.celu, -1) ≈ [1;;] + @test J(NNlib.celu, 1) ≈ [1;;] + @test J(NNlib.selu, -1) ≈ [1;;] + @test J(NNlib.selu, 1) ≈ [1;;] + + @test J(NNlib.relu6, -1) ≈ [0;;] + @test J(NNlib.relu6, 1) ≈ [1;;] + @test J(NNlib.relu6, 7) ≈ [0;;] + + @test J(NNlib.trelu, 0.9) ≈ [0;;] + @test J(NNlib.trelu, 1.1) ≈ [1;;] + + @test J(NNlib.swish, -5) ≈ [1;;] + @test J(NNlib.swish, 0) ≈ [1;;] + @test J(NNlib.swish, 5) ≈ [1;;] + + @test J(NNlib.hardswish, -5) ≈ [0;;] + @test J(NNlib.hardswish, 0) ≈ [1;;] + @test J(NNlib.hardswish, 5) ≈ [1;;] + + @test J(NNlib.hardσ, -4) ≈ [0;;] + @test J(NNlib.hardσ, 0) ≈ [1;;] + @test J(NNlib.hardσ, 4) ≈ [0;;] + + @test J(NNlib.hardtanh, -2) ≈ [0;;] + @test J(NNlib.hardtanh, 0) ≈ [1;;] + @test J(NNlib.hardtanh, 2) ≈ [0;;] + + @test J(NNlib.softshrink, -1) ≈ [1;;] + @test J(NNlib.softshrink, 0) ≈ [0;;] + @test J(NNlib.softshrink, 1) ≈ [1;;] + end yield() end end diff --git a/test/test_hessian.jl b/test/test_hessian.jl index ff677e9..38a5fbc 100644 --- a/test/test_hessian.jl +++ b/test/test_hessian.jl @@ -43,6 +43,15 @@ D = Dual{Int,T} @test H(x -> x^ℯ, 1) ≈ [1;;] @test H(x -> ℯ^x, 1) ≈ [1;;] @test H(x -> 0, 1) ≈ [0;;] + + @test H(x -> sum(sincosd(x)), 1.0) ≈ [1;;] + + @test H(x -> sum(diff(x) .^ 3), rand(4)) == [ + 1 1 0 0 + 1 1 1 0 + 0 1 1 1 + 0 0 1 1 + ] end # Conversions @@ -174,35 +183,28 @@ D = Dual{Int,T} end end - @test H(x -> sum(sincosd(x)), 1.0) ≈ [1;;] - - @test H(x -> sum(diff(x) .^ 3), rand(4)) == [ - 1 1 0 0 - 1 1 1 0 - 0 1 1 1 - 0 0 1 1 - ] + @testset "shared Hessian" begin + function dead_end(x) + z = x[1] * x[2] + return x[3] * x[4] + end + h = H(dead_end, rand(4)) - # Shared Hessian - function dead_end(x) - z = x[1] * x[2] - return x[3] * x[4] - end - h = H(dead_end, rand(4)) - if Bool(shared(T)) - @test h == [ - 0 1 0 0 - 1 0 0 0 - 0 0 0 1 - 0 0 1 0 - ] - else - @test h == [ - 0 0 0 0 - 0 0 0 0 - 0 0 0 1 - 0 0 1 0 - ] + if Bool(shared(T)) + @test h == [ + 0 1 0 0 + 1 0 0 0 + 0 0 0 1 + 0 0 1 0 + ] + else + @test h == [ + 0 0 0 0 + 0 0 0 0 + 0 0 0 1 + 0 0 1 0 + ] + end end # Missing primal errors @@ -224,7 +226,6 @@ D = Dual{Int,T} end end - # ifelse and comparisons @testset "ifelse and comparisons" begin if VERSION >= v"1.8" @test H(x -> ifelse(x[1], x[1]^x[2], x[3] * x[4]), rand(4)) == [ @@ -260,8 +261,9 @@ D = Dual{Int,T} # Error handling when applying non-dual tracers to "local" functions with control flow # TypeError: non-boolean (SparseConnectivityTracer.GradientTracer{BitSet}) used in boolean context @test_throws TypeError H(x -> x[1] > x[2] ? x[1]^x[2] : x[3] * x[4], rand(4)) + end - # SpecialFunctions + @testset "SpecialFunctions.jl" begin @test H(x -> erf(x[1]), rand(2)) == [ 1 0 0 0 @@ -282,149 +284,158 @@ end method = TracerLocalSparsityDetector(; hessian_tracer_type=T) H(f, x) = hessian_sparsity(f, x, method) - f1(x) = x[1] + x[2] * x[3] + 1 / x[4] + x[2] * max(x[1], x[5]) - @test H(f1, [1.0 3.0 5.0 1.0 2.0]) == [ - 0 0 0 0 0 - 0 0 1 0 1 - 0 1 0 0 0 - 0 0 0 1 0 - 0 1 0 0 0 - ] - - @test H(f1, [4.0 3.0 5.0 1.0 2.0]) == [ - 0 1 0 0 0 - 1 0 1 0 0 - 0 1 0 0 0 - 0 0 0 1 0 - 0 0 0 0 0 - ] - - f2(x) = ifelse(x[2] < x[3], x[1] * x[2], x[3] * x[4]) - h = H(f2, [1 2 3 4]) - if Bool(shared(T)) - @test h == [ - 0 1 0 0 - 1 0 0 0 - 0 0 0 1 - 0 0 1 0 + @testset "Trivial examples" begin + f1(x) = x[1] + x[2] * x[3] + 1 / x[4] + x[2] * max(x[1], x[5]) + @test H(f1, [1.0 3.0 5.0 1.0 2.0]) == [ + 0 0 0 0 0 + 0 0 1 0 1 + 0 1 0 0 0 + 0 0 0 1 0 + 0 1 0 0 0 ] - else - @test h == [ - 0 1 0 0 - 1 0 0 0 - 0 0 0 0 - 0 0 0 0 + + @test H(f1, [4.0 3.0 5.0 1.0 2.0]) == [ + 0 1 0 0 0 + 1 0 1 0 0 + 0 1 0 0 0 + 0 0 0 1 0 + 0 0 0 0 0 ] + + f2(x) = ifelse(x[2] < x[3], x[1] * x[2], x[3] * x[4]) + h = H(f2, [1 2 3 4]) + if Bool(shared(T)) + @test h == [ + 0 1 0 0 + 1 0 0 0 + 0 0 0 1 + 0 0 1 0 + ] + else + @test h == [ + 0 1 0 0 + 1 0 0 0 + 0 0 0 0 + 0 0 0 0 + ] + end + + h = H(f2, [1 3 2 4]) + if Bool(shared(T)) + @test h == [ + 0 1 0 0 + 1 0 0 0 + 0 0 0 1 + 0 0 1 0 + ] + else + @test h == [ + 0 0 0 0 + 0 0 0 0 + 0 0 0 1 + 0 0 1 0 + ] + end end - h = H(f2, [1 3 2 4]) - if Bool(shared(T)) - @test h == [ - 0 1 0 0 - 1 0 0 0 - 0 0 0 1 - 0 0 1 0 - ] - else - @test h == [ - 0 0 0 0 - 0 0 0 0 - 0 0 0 1 - 0 0 1 0 - ] + @testset "Shared Hessian" begin + function dead_end(x) + z = x[1] * x[2] + return x[3] * x[4] + end + h = H(dead_end, rand(4)) + if Bool(shared(T)) + @test h == [ + 0 1 0 0 + 1 0 0 0 + 0 0 0 1 + 0 0 1 0 + ] + else + @test h == [ + 0 0 0 0 + 0 0 0 0 + 0 0 0 1 + 0 0 1 0 + ] + end end - # Shared Hessian - function dead_end(x) - z = x[1] * x[2] - return x[3] * x[4] + @testset "Miscellaneous" begin + @test H(sign, 1) ≈ [0;;] + @test H(typemax, 1) ≈ [0;;] + @test H(x -> x^(2//3), 1) ≈ [1;;] + @test H(x -> (2//3)^x, 1) ≈ [1;;] + @test H(x -> x^ℯ, 1) ≈ [1;;] + @test H(x -> ℯ^x, 1) ≈ [1;;] + @test H(x -> 0, 1) ≈ [0;;] + + # Test special cases on empty tracer + + @test H(x -> zero(x)^(2//3), 1) ≈ [0;;] + @test H(x -> (2//3)^zero(x), 1) ≈ [0;;] + @test H(x -> zero(x)^ℯ, 1) ≈ [0;;] + @test H(x -> ℯ^zero(x), 1) ≈ [0;;] end - h = H(dead_end, rand(4)) - if Bool(shared(T)) - @test h == [ - 0 1 0 0 - 1 0 0 0 - 0 0 0 1 - 0 0 1 0 - ] - else - @test h == [ - 0 0 0 0 - 0 0 0 0 - 0 0 0 1 - 0 0 1 0 - ] + + @testset "Conversion" begin + @testset "to $T" for T in REAL_TYPES + @test H(x -> convert(T, x), 1.0) ≈ [0;;] + @test H(x -> convert(T, x^2), 1.0) ≈ [1;;] + @test H(x -> convert(T, x)^2, 1.0) ≈ [1;;] + end end - # Code coverage - @test H(sign, 1) ≈ [0;;] - @test H(typemax, 1) ≈ [0;;] - @test H(x -> x^(2//3), 1) ≈ [1;;] - @test H(x -> (2//3)^x, 1) ≈ [1;;] - @test H(x -> x^ℯ, 1) ≈ [1;;] - @test H(x -> ℯ^x, 1) ≈ [1;;] - @test H(x -> 0, 1) ≈ [0;;] + @testset "Round" begin + @test H(round, 1.1) ≈ [0;;] + @test H(x -> round(Int, x), 1.1) ≈ [0;;] + @test H(x -> round(Bool, x), 1.1) ≈ [0;;] + @test H(x -> round(x, RoundNearestTiesAway), 1.1) ≈ [0;;] + @test H(x -> round(x; digits=3, base=2), 1.1) ≈ [0;;] + end - # Conversions - @testset "Conversion to $T" for T in REAL_TYPES - @test H(x -> convert(T, x), 1.0) ≈ [0;;] - @test H(x -> convert(T, x^2), 1.0) ≈ [1;;] - @test H(x -> convert(T, x)^2, 1.0) ≈ [1;;] + @testset "Random" begin + @test H(x -> rand(typeof(x)), 1) ≈ [0;;] + @test H(x -> rand(GLOBAL_RNG, typeof(x)), 1) ≈ [0;;] end - # Round - @test H(round, 1.1) ≈ [0;;] - @test H(x -> round(Int, x), 1.1) ≈ [0;;] - @test H(x -> round(Bool, x), 1.1) ≈ [0;;] - @test H(x -> round(x, RoundNearestTiesAway), 1.1) ≈ [0;;] - @test H(x -> round(x; digits=3, base=2), 1.1) ≈ [0;;] - - # Random - @test H(x -> rand(typeof(x)), 1) ≈ [0;;] - @test H(x -> rand(GLOBAL_RNG, typeof(x)), 1) ≈ [0;;] - - # Test special cases on empty tracer - @test H(x -> zero(x)^(2//3), 1) ≈ [0;;] - @test H(x -> (2//3)^zero(x), 1) ≈ [0;;] - @test H(x -> zero(x)^ℯ, 1) ≈ [0;;] - @test H(x -> ℯ^zero(x), 1) ≈ [0;;] - - # NNlib extension - @test H(NNlib.relu, -1) ≈ [0;;] - @test H(NNlib.relu, 1) ≈ [0;;] - @test H(NNlib.elu, -1) ≈ [1;;] - @test H(NNlib.elu, 1) ≈ [0;;] - @test H(NNlib.celu, -1) ≈ [1;;] - @test H(NNlib.celu, 1) ≈ [0;;] - @test H(NNlib.selu, -1) ≈ [1;;] - @test H(NNlib.selu, 1) ≈ [0;;] - - @test H(NNlib.relu6, -1) ≈ [0;;] - @test H(NNlib.relu6, 1) ≈ [0;;] - @test H(NNlib.relu6, 7) ≈ [0;;] - - @test H(NNlib.trelu, 0.9) ≈ [0;;] - @test H(NNlib.trelu, 1.1) ≈ [0;;] - - @test H(NNlib.swish, -5) ≈ [1;;] - @test H(NNlib.swish, 0) ≈ [1;;] - @test H(NNlib.swish, 5) ≈ [1;;] - - @test H(NNlib.hardswish, -5) ≈ [0;;] - @test H(NNlib.hardswish, 0) ≈ [1;;] - @test H(NNlib.hardswish, 5) ≈ [0;;] - - @test H(NNlib.hardσ, -4) ≈ [0;;] - @test H(NNlib.hardσ, 0) ≈ [0;;] - @test H(NNlib.hardσ, 4) ≈ [0;;] - - @test H(NNlib.hardtanh, -2) ≈ [0;;] - @test H(NNlib.hardtanh, 0) ≈ [0;;] - @test H(NNlib.hardtanh, 2) ≈ [0;;] - - @test H(NNlib.softshrink, -1) ≈ [0;;] - @test H(NNlib.softshrink, 0) ≈ [0;;] - @test H(NNlib.softshrink, 1) ≈ [0;;] + @testset "NNlib" begin + @test H(NNlib.relu, -1) ≈ [0;;] + @test H(NNlib.relu, 1) ≈ [0;;] + @test H(NNlib.elu, -1) ≈ [1;;] + @test H(NNlib.elu, 1) ≈ [0;;] + @test H(NNlib.celu, -1) ≈ [1;;] + @test H(NNlib.celu, 1) ≈ [0;;] + @test H(NNlib.selu, -1) ≈ [1;;] + @test H(NNlib.selu, 1) ≈ [0;;] + + @test H(NNlib.relu6, -1) ≈ [0;;] + @test H(NNlib.relu6, 1) ≈ [0;;] + @test H(NNlib.relu6, 7) ≈ [0;;] + + @test H(NNlib.trelu, 0.9) ≈ [0;;] + @test H(NNlib.trelu, 1.1) ≈ [0;;] + + @test H(NNlib.swish, -5) ≈ [1;;] + @test H(NNlib.swish, 0) ≈ [1;;] + @test H(NNlib.swish, 5) ≈ [1;;] + + @test H(NNlib.hardswish, -5) ≈ [0;;] + @test H(NNlib.hardswish, 0) ≈ [1;;] + @test H(NNlib.hardswish, 5) ≈ [0;;] + + @test H(NNlib.hardσ, -4) ≈ [0;;] + @test H(NNlib.hardσ, 0) ≈ [0;;] + @test H(NNlib.hardσ, 4) ≈ [0;;] + + @test H(NNlib.hardtanh, -2) ≈ [0;;] + @test H(NNlib.hardtanh, 0) ≈ [0;;] + @test H(NNlib.hardtanh, 2) ≈ [0;;] + + @test H(NNlib.softshrink, -1) ≈ [0;;] + @test H(NNlib.softshrink, 0) ≈ [0;;] + @test H(NNlib.softshrink, 1) ≈ [0;;] + end yield() end end From 01e7b3a687f3ff80d0294606d2fbf78af97b3884 Mon Sep 17 00:00:00 2001 From: adrhill Date: Fri, 16 Aug 2024 15:53:05 +0200 Subject: [PATCH 11/20] Rewrite overloads --- src/SparseConnectivityTracer.jl | 3 +- src/overloads/gradient_tracer.jl | 205 +++++++++++++++----------- src/overloads/hessian_tracer.jl | 239 ++++++++++++++++--------------- src/overloads/utils.jl | 10 ++ 4 files changed, 251 insertions(+), 206 deletions(-) create mode 100644 src/overloads/utils.jl diff --git a/src/SparseConnectivityTracer.jl b/src/SparseConnectivityTracer.jl index 2de22ca..092f12f 100644 --- a/src/SparseConnectivityTracer.jl +++ b/src/SparseConnectivityTracer.jl @@ -25,13 +25,14 @@ include("tracers.jl") include("exceptions.jl") include("operators.jl") +include("overloads/utils.jl") include("overloads/conversion.jl") include("overloads/gradient_tracer.jl") include("overloads/hessian_tracer.jl") include("overloads/ifelse_global.jl") include("overloads/dual.jl") -include("overloads/overload_all.jl") include("overloads/arrays.jl") +include("overloads/overload_all.jl") include("interface.jl") include("adtypes.jl") diff --git a/src/overloads/gradient_tracer.jl b/src/overloads/gradient_tracer.jl index 3a98cb9..fb9f2b8 100644 --- a/src/overloads/gradient_tracer.jl +++ b/src/overloads/gradient_tracer.jl @@ -28,28 +28,57 @@ function gradient_tracer_1_to_1_inner( end end -function overload_gradient_1_to_1(op) - M = parentmodule(op) +function overload_gradient_1_to_1(f) + is_der1_zero_g = is_der1_zero_global(f) + SCT = SparseConnectivityTracer - @eval begin - ## GradientTracer - function $M.$op(t::$SCT.GradientTracer) - is_der1_zero = $SCT.is_der1_zero_global($M.$op) - return $SCT.gradient_tracer_1_to_1(t, is_der1_zero) - end + M = nameofrootmodule(f) + fname = nameof(f) + + ## GradientTracer + @eval function $M.$fname(t::$SCT.GradientTracer) + return $SCT.gradient_tracer_1_to_1(t, $is_der1_zero_g) + end + + ## Dual + @eval function $M.$fname(d::D) where {P,T<:$SCT.GradientTracer,D<:$SCT.Dual{P,T}} + x = $SCT.primal(d) + p_out = $M.$fname(x) + + t = $SCT.tracer(d) + is_der1_zero = $SCT.is_der1_zero_local($M.$fname, x) + t_out = $SCT.gradient_tracer_1_to_1(t, is_der1_zero) + return $SCT.Dual(p_out, t_out) + end +end + +## Drafty McDraftface +function foo(f) + M = nameofrootmodule(f) + fname = nameof(f) + is_der1_zero = true - ## Dual - function $M.$op(d::D) where {P,T<:$SCT.GradientTracer,D<:$SCT.Dual{P,T}} - x = $SCT.primal(d) - p_out = $M.$op(x) + SCT = LinearAlgebra - t = $SCT.tracer(d) - is_der1_zero = $SCT.is_der1_zero_local($op, x) - t_out = $SCT.gradient_tracer_1_to_1(t, is_der1_zero) - return $SCT.Dual(p_out, t_out) + ## GradientTracer + return quote + function $M.$fname(t::$SCT.Diagonal) + return $SCT.gradient_tracer_1_to_1(t, $is_der1_zero) end end end +function bar(f) + M = nameofrootmodule(f) + fname = nameof(f) + is_der1_zero = true + + SCT = LinearAlgebra + + ## GradientTracer + @eval function $M.$fname(t::$SCT.Diagonal) + return $SCT.gradient_tracer_1_to_1(t, $is_der1_zero) + end +end ## 2-to-1 @@ -97,62 +126,65 @@ function gradient_tracer_2_to_1_inner( end end -function overload_gradient_2_to_1(op) - M = parentmodule(op) +function overload_gradient_2_to_1(f) + is_der1_arg1_zero = is_der1_arg1_zero_global(f) + is_der1_arg2_zero = is_der1_arg2_zero_global(f) + SCT = SparseConnectivityTracer - @eval begin - ## GradientTracer - function $M.$op(tx::T, ty::T) where {T<:$SCT.GradientTracer} - is_der1_arg1_zero = $SCT.is_der1_arg1_zero_global($M.$op) - is_der1_arg2_zero = $SCT.is_der1_arg2_zero_global($M.$op) - return $SCT.gradient_tracer_2_to_1(tx, ty, is_der1_arg1_zero, is_der1_arg2_zero) - end + M = nameofrootmodule(f) + fname = nameof(f) - function $M.$op(tx::$SCT.GradientTracer, ::Real) - is_der1_arg1_zero = $SCT.is_der1_arg1_zero_global($M.$op) - return $SCT.gradient_tracer_1_to_1(tx, is_der1_arg1_zero) - end + ## GradientTracer + @eval function $M.$fname(tx::T, ty::T) where {T<:$SCT.GradientTracer} + return $SCT.gradient_tracer_2_to_1(tx, ty, $is_der1_arg1_zero, $is_der1_arg2_zero) + end - function $M.$op(::Real, ty::$SCT.GradientTracer) - is_der1_arg2_zero = $SCT.is_der1_arg2_zero_global($M.$op) - return $SCT.gradient_tracer_1_to_1(ty, is_der1_arg2_zero) - end + @eval function $M.$fname(tx::$SCT.GradientTracer, ::Real) + return $SCT.gradient_tracer_1_to_1(tx, $is_der1_arg1_zero) + end - ## Dual - function $M.$op(dx::D, dy::D) where {P,T<:$SCT.GradientTracer,D<:$SCT.Dual{P,T}} - x = $SCT.primal(dx) - y = $SCT.primal(dy) - p_out = $M.$op(x, y) - - tx = $SCT.tracer(dx) - ty = $SCT.tracer(dy) - is_der1_arg1_zero = $SCT.is_der1_arg1_zero_local($M.$op, x, y) - is_der1_arg2_zero = $SCT.is_der1_arg2_zero_local($M.$op, x, y) - t_out = $SCT.gradient_tracer_2_to_1( - tx, ty, is_der1_arg1_zero, is_der1_arg2_zero - ) - return $SCT.Dual(p_out, t_out) - end + @eval function $M.$fname(::Real, ty::$SCT.GradientTracer) + return $SCT.gradient_tracer_1_to_1(ty, $is_der1_arg2_zero) + end - function $M.$op(dx::D, y::Real) where {P,T<:$SCT.GradientTracer,D<:$SCT.Dual{P,T}} - x = $SCT.primal(dx) - p_out = $M.$op(x, y) + ## Dual + @eval function $M.$fname( + dx::D, dy::D + ) where {P,T<:$SCT.GradientTracer,D<:$SCT.Dual{P,T}} + x = $SCT.primal(dx) + y = $SCT.primal(dy) + p_out = $M.$fname(x, y) - tx = $SCT.tracer(dx) - is_der1_arg1_zero = $SCT.is_der1_arg1_zero_local($M.$op, x, y) - t_out = $SCT.gradient_tracer_1_to_1(tx, is_der1_arg1_zero) - return $SCT.Dual(p_out, t_out) - end + tx = $SCT.tracer(dx) + ty = $SCT.tracer(dy) + is_der1_arg1_zero = $SCT.is_der1_arg1_zero_local($M.$fname, x, y) + is_der1_arg2_zero = $SCT.is_der1_arg2_zero_local($M.$fname, x, y) + t_out = $SCT.gradient_tracer_2_to_1(tx, ty, is_der1_arg1_zero, is_der1_arg2_zero) + return $SCT.Dual(p_out, t_out) + end - function $M.$op(x::Real, dy::D) where {P,T<:$SCT.GradientTracer,D<:$SCT.Dual{P,T}} - y = $SCT.primal(dy) - p_out = $M.$op(x, y) + @eval function $M.$fname( + dx::D, y::Real + ) where {P,T<:$SCT.GradientTracer,D<:$SCT.Dual{P,T}} + x = $SCT.primal(dx) + p_out = $M.$fname(x, y) - ty = $SCT.tracer(dy) - is_der1_arg2_zero = $SCT.is_der1_arg2_zero_local($M.$op, x, y) - t_out = $SCT.gradient_tracer_1_to_1(ty, is_der1_arg2_zero) - return $SCT.Dual(p_out, t_out) - end + tx = $SCT.tracer(dx) + is_der1_arg1_zero = $SCT.is_der1_arg1_zero_local($M.$fname, x, y) + t_out = $SCT.gradient_tracer_1_to_1(tx, is_der1_arg1_zero) + return $SCT.Dual(p_out, t_out) + end + + @eval function $M.$fname( + x::Real, dy::D + ) where {P,T<:$SCT.GradientTracer,D<:$SCT.Dual{P,T}} + y = $SCT.primal(dy) + p_out = $M.$fname(x, y) + + ty = $SCT.tracer(dy) + is_der1_arg2_zero = $SCT.is_der1_arg2_zero_local($M.$fname, x, y) + t_out = $SCT.gradient_tracer_1_to_1(ty, is_der1_arg2_zero) + return $SCT.Dual(p_out, t_out) end end @@ -170,30 +202,31 @@ end end end -function overload_gradient_1_to_2(op) - M = parentmodule(op) +function overload_gradient_1_to_2(f) + is_der1_out1_zero_g = is_der1_out1_zero_global(f) + is_der1_out2_zero_g = is_der1_out2_zero_global(f) + SCT = SparseConnectivityTracer - @eval begin - ## GradientTracer - function $M.$op(t::$SCT.GradientTracer) - is_der1_out1_zero = $SCT.is_der1_out1_zero_global($M.$op) - is_der1_out2_zero = $SCT.is_der1_out2_zero_global($M.$op) - return $SCT.gradient_tracer_1_to_2(t, is_der1_out1_zero, is_der1_out2_zero) - end + M = nameofrootmodule(f) + fname = nameof(f) - ## Dual - function $M.$op(d::D) where {P,T<:$SCT.GradientTracer,D<:$SCT.Dual{P,T}} - x = $SCT.primal(d) - p_out1, p_out2 = $M.$op(x) - - t = $SCT.tracer(d) - is_der1_out2_zero = $SCT.is_der1_out2_zero_local($M.$op, x) - is_der1_out1_zero = $SCT.is_der1_out1_zero_local($M.$op, x) - t_out1, t_out2 = $SCT.gradient_tracer_1_to_2( - t, is_der1_out1_zero, is_der1_out2_zero - ) - return ($SCT.Dual(p_out1, t_out1), $SCT.Dual(p_out2, t_out2)) # TODO: this was wrong, add test - end + ## GradientTracer + @eval function $M.$fname(t::$SCT.GradientTracer) + return $SCT.gradient_tracer_1_to_2(t, $is_der1_out1_zero_g, $is_der1_out2_zero_g) + end + + ## Dual + @eval function $M.$fname(d::D) where {P,T<:$SCT.GradientTracer,D<:$SCT.Dual{P,T}} + x = $SCT.primal(d) + p_out1, p_out2 = $M.$fname(x) + + t = $SCT.tracer(d) + is_der1_out2_zero = $SCT.is_der1_out2_zero_local($M.$fname, x) + is_der1_out1_zero = $SCT.is_der1_out1_zero_local($M.$fname, x) + t_out1, t_out2 = $SCT.gradient_tracer_1_to_2( + t, is_der1_out1_zero, is_der1_out2_zero + ) + return ($SCT.Dual(p_out1, t_out1), $SCT.Dual(p_out2, t_out2)) # TODO: this was wrong, add test end end diff --git a/src/overloads/hessian_tracer.jl b/src/overloads/hessian_tracer.jl index 2ddb7ae..e8770f6 100644 --- a/src/overloads/hessian_tracer.jl +++ b/src/overloads/hessian_tracer.jl @@ -55,25 +55,28 @@ function hessian_tracer_1_to_1_inner( return P(g_out, h_out) # return pattern end -function overload_hessian_1_to_1(op) - M = parentmodule(op) +function overload_hessian_1_to_1(f) + is_der1_zero_g = is_der1_zero_global(f) + is_der2_zero_g = is_der2_zero_global(f) + SCT = SparseConnectivityTracer + M = nameofrootmodule(f) + fname = nameof(f) + @eval begin ## HessianTracer - function $M.$op(t::$SCT.HessianTracer) - is_der1_zero = $SCT.is_der1_zero_global($M.$op) - is_der2_zero = $SCT.is_der2_zero_global($M.$op) - return $SCT.hessian_tracer_1_to_1(t, is_der1_zero, is_der2_zero) + function $M.$fname(t::$SCT.HessianTracer) + return $SCT.hessian_tracer_1_to_1(t, $is_der1_zero_g, $is_der2_zero_g) end ## Dual - function $M.$op(d::D) where {P,T<:$SCT.HessianTracer,D<:$SCT.Dual{P,T}} + function $M.$fname(d::D) where {P,T<:$SCT.HessianTracer,D<:$SCT.Dual{P,T}} x = $SCT.primal(d) - p_out = $M.$op(x) + p_out = $M.$fname(x) t = $SCT.tracer(d) - is_der1_zero = $SCT.is_der1_zero_local($M.$op, x) - is_der2_zero = $SCT.is_der2_zero_local($M.$op, x) + is_der1_zero = $SCT.is_der1_zero_local($M.$fname, x) + is_der2_zero = $SCT.is_der2_zero_local($M.$fname, x) t_out = $SCT.hessian_tracer_1_to_1(t, is_der1_zero, is_der2_zero) return $SCT.Dual(p_out, t_out) end @@ -161,86 +164,87 @@ function hessian_tracer_2_to_1_inner( return P(g_out, h_out) # return pattern end -function overload_hessian_2_to_1(op) - M = parentmodule(op) - SCT = SparseConnectivityTracer - @eval begin - ## HessianTracer - function $M.$op(tx::T, ty::T) where {T<:$SCT.HessianTracer} - is_der1_arg1_zero = $SCT.is_der1_arg1_zero_global($M.$op) - is_der2_arg1_zero = $SCT.is_der2_arg1_zero_global($M.$op) - is_der1_arg2_zero = $SCT.is_der1_arg2_zero_global($M.$op) - is_der2_arg2_zero = $SCT.is_der2_arg2_zero_global($M.$op) - is_der_cross_zero = $SCT.is_der_cross_zero_global($M.$op) - return $SCT.hessian_tracer_2_to_1( - tx, - ty, - is_der1_arg1_zero, - is_der2_arg1_zero, - is_der1_arg2_zero, - is_der2_arg2_zero, - is_der_cross_zero, - ) - end - - function $M.$op(tx::$SCT.HessianTracer, y::Real) - is_der1_arg1_zero = $SCT.is_der1_arg1_zero_global($M.$op) - is_der2_arg1_zero = $SCT.is_der2_arg1_zero_global($M.$op) - return $SCT.hessian_tracer_1_to_1(tx, is_der1_arg1_zero, is_der2_arg1_zero) - end +function overload_hessian_2_to_1(f) + is_der1_arg1_zero_g = is_der1_arg1_zero_global(f) + is_der2_arg1_zero_g = is_der2_arg1_zero_global(f) + is_der1_arg2_zero_g = is_der1_arg2_zero_global(f) + is_der2_arg2_zero_g = is_der2_arg2_zero_global(f) + is_der_cross_zero_g = is_der_cross_zero_global(f) - function $M.$op(x::Real, ty::$SCT.HessianTracer) - is_der1_arg2_zero = $SCT.is_der1_arg2_zero_global($M.$op) - is_der2_arg2_zero = $SCT.is_der2_arg2_zero_global($M.$op) - return $SCT.hessian_tracer_1_to_1(ty, is_der1_arg2_zero, is_der2_arg2_zero) - end + SCT = SparseConnectivityTracer + M = nameofrootmodule(f) + fname = nameof(f) + + ## HessianTracer + @eval function $M.$fname(tx::T, ty::T) where {T<:$SCT.HessianTracer} + return $SCT.hessian_tracer_2_to_1( + tx, + ty, + $is_der1_arg1_zero_g, + $is_der2_arg1_zero_g, + $is_der1_arg2_zero_g, + $is_der2_arg2_zero_g, + $is_der_cross_zero_g, + ) + end - ## Dual - function $M.$op(dx::D, dy::D) where {P,T<:$SCT.HessianTracer,D<:$SCT.Dual{P,T}} - x = $SCT.primal(dx) - y = $SCT.primal(dy) - p_out = $M.$op(x, y) - - tx = $SCT.tracer(dx) - ty = $SCT.tracer(dy) - is_der1_arg1_zero = $SCT.is_der1_arg1_zero_local($M.$op, x, y) - is_der2_arg1_zero = $SCT.is_der2_arg1_zero_local($M.$op, x, y) - is_der1_arg2_zero = $SCT.is_der1_arg2_zero_local($M.$op, x, y) - is_der2_arg2_zero = $SCT.is_der2_arg2_zero_local($M.$op, x, y) - is_der_cross_zero = $SCT.is_der_cross_zero_local($M.$op, x, y) - t_out = $SCT.hessian_tracer_2_to_1( - tx, - ty, - is_der1_arg1_zero, - is_der2_arg1_zero, - is_der1_arg2_zero, - is_der2_arg2_zero, - is_der_cross_zero, - ) - return $SCT.Dual(p_out, t_out) - end + @eval function $M.$fname(tx::$SCT.HessianTracer, y::Real) + return $SCT.hessian_tracer_1_to_1(tx, $is_der1_arg1_zero_g, $is_der2_arg1_zero_g) + end - function $M.$op(dx::D, y::Real) where {P,T<:$SCT.HessianTracer,D<:$SCT.Dual{P,T}} - x = $SCT.primal(dx) - p_out = $M.$op(x, y) + @eval function $M.$fname(x::Real, ty::$SCT.HessianTracer) + return $SCT.hessian_tracer_1_to_1(ty, $is_der1_arg2_zero_g, $is_der2_arg2_zero_g) + end - tx = $SCT.tracer(dx) - is_der1_arg1_zero = $SCT.is_der1_arg1_zero_local($M.$op, x, y) - is_der2_arg1_zero = $SCT.is_der2_arg1_zero_local($M.$op, x, y) - t_out = $SCT.hessian_tracer_1_to_1(tx, is_der1_arg1_zero, is_der2_arg1_zero) - return $SCT.Dual(p_out, t_out) - end + ## Dual + @eval function $M.$fname(dx::D, dy::D) where {P,T<:$SCT.HessianTracer,D<:$SCT.Dual{P,T}} + x = $SCT.primal(dx) + y = $SCT.primal(dy) + p_out = $M.$fname(x, y) + + tx = $SCT.tracer(dx) + ty = $SCT.tracer(dy) + is_der1_arg1_zero = $SCT.is_der1_arg1_zero_local($M.$fname, x, y) + is_der2_arg1_zero = $SCT.is_der2_arg1_zero_local($M.$fname, x, y) + is_der1_arg2_zero = $SCT.is_der1_arg2_zero_local($M.$fname, x, y) + is_der2_arg2_zero = $SCT.is_der2_arg2_zero_local($M.$fname, x, y) + is_der_cross_zero = $SCT.is_der_cross_zero_local($M.$fname, x, y) + t_out = $SCT.hessian_tracer_2_to_1( + tx, + ty, + is_der1_arg1_zero, + is_der2_arg1_zero, + is_der1_arg2_zero, + is_der2_arg2_zero, + is_der_cross_zero, + ) + return $SCT.Dual(p_out, t_out) + end - function $M.$op(x::Real, dy::D) where {P,T<:$SCT.HessianTracer,D<:$SCT.Dual{P,T}} - y = $SCT.primal(dy) - p_out = $M.$op(x, y) + @eval function $M.$fname( + dx::D, y::Real + ) where {P,T<:$SCT.HessianTracer,D<:$SCT.Dual{P,T}} + x = $SCT.primal(dx) + p_out = $M.$fname(x, y) + + tx = $SCT.tracer(dx) + is_der1_arg1_zero = $SCT.is_der1_arg1_zero_local($M.$fname, x, y) + is_der2_arg1_zero = $SCT.is_der2_arg1_zero_local($M.$fname, x, y) + t_out = $SCT.hessian_tracer_1_to_1(tx, is_der1_arg1_zero, is_der2_arg1_zero) + return $SCT.Dual(p_out, t_out) + end - ty = $SCT.tracer(dy) - is_der1_arg2_zero = $SCT.is_der1_arg2_zero_local($M.$op, x, y) - is_der2_arg2_zero = $SCT.is_der2_arg2_zero_local($M.$op, x, y) - t_out = $SCT.hessian_tracer_1_to_1(ty, is_der1_arg2_zero, is_der2_arg2_zero) - return $SCT.Dual(p_out, t_out) - end + @eval function $M.$fname( + x::Real, dy::D + ) where {P,T<:$SCT.HessianTracer,D<:$SCT.Dual{P,T}} + y = $SCT.primal(dy) + p_out = $M.$fname(x, y) + + ty = $SCT.tracer(dy) + is_der1_arg2_zero = $SCT.is_der1_arg2_zero_local($M.$fname, x, y) + is_der2_arg2_zero = $SCT.is_der2_arg2_zero_local($M.$fname, x, y) + t_out = $SCT.hessian_tracer_1_to_1(ty, is_der1_arg2_zero, is_der2_arg2_zero) + return $SCT.Dual(p_out, t_out) end end @@ -262,43 +266,40 @@ end end end -function overload_hessian_1_to_2(op) - M = parentmodule(op) +function overload_hessian_1_to_2(f) + is_der1_out1_zero_g = is_der1_out1_zero_global(f) + is_der2_out1_zero_g = is_der2_out1_zero_global(f) + is_der1_out2_zero_g = is_der1_out2_zero_global(f) + is_der2_out2_zero_g = is_der2_out2_zero_global(f) + SCT = SparseConnectivityTracer - @eval begin - ## HessianTracer - function $M.$op(t::$SCT.HessianTracer) - is_der1_out1_zero = $SCT.is_der1_out1_zero_global($M.$op) - is_der2_out1_zero = $SCT.is_der2_out1_zero_global($M.$op) - is_der1_out2_zero = $SCT.is_der1_out2_zero_global($M.$op) - is_der2_out2_zero = $SCT.is_der2_out2_zero_global($M.$op) - return $SCT.hessian_tracer_1_to_2( - t, - is_der1_out1_zero, - is_der2_out1_zero, - is_der1_out2_zero, - is_der2_out2_zero, - ) - end + M = nameofrootmodule(f) + fname = nameof(f) + + ## HessianTracer + @eval function $M.$fname(t::$SCT.HessianTracer) + return $SCT.hessian_tracer_1_to_2( + t, + $is_der1_out1_zero_g, + $is_der2_out1_zero_g, + $is_der1_out2_zero_g, + $is_der2_out2_zero_g, + ) + end - ## Dual - function $M.$op(d::D) where {P,T<:$SCT.HessianTracer,D<:$SCT.Dual{P,T}} - x = $SCT.primal(d) - p_out1, p_out2 = $M.$op(x) - - is_der1_out1_zero = $SCT.is_der1_out1_zero_local($M.$op, x) - is_der2_out1_zero = $SCT.is_der2_out1_zero_local($M.$op, x) - is_der1_out2_zero = $SCT.is_der1_out2_zero_local($M.$op, x) - is_der2_out2_zero = $SCT.is_der2_out2_zero_local($M.$op, x) - t_out1, t_out2 = $SCT.hessian_tracer_1_to_2( - d, - is_der1_out1_zero, - is_der2_out1_zero, - is_der1_out2_zero, - is_der2_out2_zero, - ) - return ($SCT.Dual(p_out1, t_out1), $SCT.Dual(p_out2, t_out2)) - end + ## Dual + @eval function $M.$fname(d::D) where {P,T<:$SCT.HessianTracer,D<:$SCT.Dual{P,T}} + x = $SCT.primal(d) + p_out1, p_out2 = $M.$fname(x) + + is_der1_out1_zero = $SCT.is_der1_out1_zero_local($M.$fname, x) + is_der2_out1_zero = $SCT.is_der2_out1_zero_local($M.$fname, x) + is_der1_out2_zero = $SCT.is_der1_out2_zero_local($M.$fname, x) + is_der2_out2_zero = $SCT.is_der2_out2_zero_local($M.$fname, x) + t_out1, t_out2 = $SCT.hessian_tracer_1_to_2( + d, is_der1_out1_zero, is_der2_out1_zero, is_der1_out2_zero, is_der2_out2_zero + ) + return ($SCT.Dual(p_out1, t_out1), $SCT.Dual(p_out2, t_out2)) end end diff --git a/src/overloads/utils.jl b/src/overloads/utils.jl new file mode 100644 index 0000000..734da8c --- /dev/null +++ b/src/overloads/utils.jl @@ -0,0 +1,10 @@ +# symb(f) = Symbol(parentmodule(f), '.', nameof(f)) +function rootmodule(x) + parent = parentmodule(x) + if parent == x + return parent + else + return rootmodule(parent) + end +end +nameofrootmodule(x) = nameof(rootmodule(x)) From 23a74622ecea02949386abe2aca6727d0b3b776c Mon Sep 17 00:00:00 2001 From: adrhill Date: Fri, 16 Aug 2024 16:10:36 +0200 Subject: [PATCH 12/20] Refactor overloads --- ext/SparseConnectivityTracerNNlibExt.jl | 11 +--- ...seConnectivityTracerSpecialFunctionsExt.jl | 11 ++-- src/operators.jl | 4 -- src/overloads/overload_all.jl | 64 +++++++++---------- 4 files changed, 39 insertions(+), 51 deletions(-) diff --git a/ext/SparseConnectivityTracerNNlibExt.jl b/ext/SparseConnectivityTracerNNlibExt.jl index 5761e47..c978b82 100644 --- a/ext/SparseConnectivityTracerNNlibExt.jl +++ b/ext/SparseConnectivityTracerNNlibExt.jl @@ -82,14 +82,9 @@ SCT.is_der1_zero_local(::typeof(softshrink), x) = x > -0.5 && x < 0.5 ops_1_to_1 = union(ops_1_to_1_s, ops_1_to_1_f) -## Lists +## Overload -SCT.list_operators_1_to_1(::Val{:NNlib}) = ops_1_to_1 -SCT.list_operators_2_to_1(::Val{:NNlib}) = () -SCT.list_operators_1_to_2(::Val{:NNlib}) = () - -## Overloads - -eval(SCT.overload_all(:NNlib)) +SCT.overload_gradient_1_to_1(ops_1_to_1) +SCT.overload_hessian_1_to_1(ops_1_to_1) end diff --git a/ext/SparseConnectivityTracerSpecialFunctionsExt.jl b/ext/SparseConnectivityTracerSpecialFunctionsExt.jl index 22f88c1..2494683 100644 --- a/ext/SparseConnectivityTracerSpecialFunctionsExt.jl +++ b/ext/SparseConnectivityTracerSpecialFunctionsExt.jl @@ -110,14 +110,11 @@ end ops_2_to_1 = ops_2_to_1_ssc -## Lists - -SCT.list_operators_1_to_1(::Val{:SpecialFunctions}) = ops_1_to_1 -SCT.list_operators_2_to_1(::Val{:SpecialFunctions}) = ops_2_to_1 -SCT.list_operators_1_to_2(::Val{:SpecialFunctions}) = () - ## Overloads -eval(SCT.overload_all(:SpecialFunctions)) +SCT.overload_gradient_1_to_1(ops_1_to_1) +SCT.overload_gradient_2_to_1(ops_2_to_1) +SCT.overload_hessian_1_to_1(ops_1_to_1) +SCT.overload_hessian_2_to_1(ops_2_to_1) end diff --git a/src/operators.jl b/src/operators.jl index 562482c..82d813a 100644 --- a/src/operators.jl +++ b/src/operators.jl @@ -584,7 +584,3 @@ ops_1_to_2 = union( ops_1_to_2_zz, ) #! format: on - -list_operators_1_to_1(::Val{:Base}) = ops_1_to_1 -list_operators_2_to_1(::Val{:Base}) = ops_2_to_1 -list_operators_1_to_2(::Val{:Base}) = ops_1_to_2 diff --git a/src/overloads/overload_all.jl b/src/overloads/overload_all.jl index 73fd8ab..f72161f 100644 --- a/src/overloads/overload_all.jl +++ b/src/overloads/overload_all.jl @@ -1,34 +1,34 @@ -# for overload in ( -# overload_gradient_1_to_1, -# overload_gradient_2_to_1, -# overload_gradient_1_to_2, -# overload_hessian_1_to_1, -# overload_hessian_2_to_1, -# overload_hessian_1_to_2, -# ) -# @eval function $overload(ops::Tuple) -# for op in ops -# $overload(op) -# end -# end -# end +for overload in ( + :overload_gradient_1_to_1, + :overload_gradient_2_to_1, + :overload_gradient_1_to_2, + :overload_hessian_1_to_1, + :overload_hessian_2_to_1, + :overload_hessian_1_to_2, +) + @eval function $overload(ops::Union{AbstractVector,Tuple}) + for op in ops + $overload(op) + end + end +end -# overload_gradient_1_to_1(ops_1_to_1) -# overload_gradient_2_to_1(ops_2_to_1) -# overload_gradient_1_to_2(ops_1_to_2) -# overload_hessian_1_to_1(ops_1_to_1) -# overload_hessian_2_to_1(ops_2_to_1) -# overload_hessian_1_to_2(ops_1_to_2) +overload_gradient_1_to_1(ops_1_to_1) +overload_gradient_2_to_1(ops_2_to_1) +overload_gradient_1_to_2(ops_1_to_2) +overload_hessian_1_to_1(ops_1_to_1) +overload_hessian_2_to_1(ops_2_to_1) +overload_hessian_1_to_2(ops_1_to_2) -for op in ops_1_to_1 - overload_gradient_1_to_1(op) - overload_hessian_1_to_1(op) -end -for op in ops_2_to_1 - overload_gradient_2_to_1(op) - overload_hessian_2_to_1(op) -end -for op in ops_1_to_2 - overload_gradient_1_to_2(op) - overload_hessian_1_to_2(op) -end +# for op in ops_1_to_1 +# overload_gradient_1_to_1(op) +# overload_hessian_1_to_1(op) +# end +# for op in ops_2_to_1 +# overload_gradient_2_to_1(op) +# overload_hessian_2_to_1(op) +# end +# for op in ops_1_to_2 +# overload_gradient_1_to_2(op) +# overload_hessian_1_to_2(op) +# end From 6ada1fcdb5914a4064e4ec2c236ac1b8426e973d Mon Sep 17 00:00:00 2001 From: adrhill Date: Fri, 16 Aug 2024 16:33:24 +0200 Subject: [PATCH 13/20] Revert to quote-based approach --- ext/SparseConnectivityTracerNNlibExt.jl | 4 +- ...seConnectivityTracerSpecialFunctionsExt.jl | 8 +- src/SparseConnectivityTracer.jl | 1 - src/overloads/gradient_tracer.jl | 185 ++++++++--------- src/overloads/hessian_tracer.jl | 195 +++++++++--------- src/overloads/overload_all.jl | 32 +-- src/overloads/utils.jl | 10 - 7 files changed, 198 insertions(+), 237 deletions(-) delete mode 100644 src/overloads/utils.jl diff --git a/ext/SparseConnectivityTracerNNlibExt.jl b/ext/SparseConnectivityTracerNNlibExt.jl index c978b82..c7bff08 100644 --- a/ext/SparseConnectivityTracerNNlibExt.jl +++ b/ext/SparseConnectivityTracerNNlibExt.jl @@ -84,7 +84,7 @@ ops_1_to_1 = union(ops_1_to_1_s, ops_1_to_1_f) ## Overload -SCT.overload_gradient_1_to_1(ops_1_to_1) -SCT.overload_hessian_1_to_1(ops_1_to_1) +eval(SCT.overload_gradient_1_to_1(:NNlib, ops_1_to_1)) +eval(SCT.overload_hessian_1_to_1(:NNlib, ops_1_to_1)) end diff --git a/ext/SparseConnectivityTracerSpecialFunctionsExt.jl b/ext/SparseConnectivityTracerSpecialFunctionsExt.jl index 2494683..eaefda7 100644 --- a/ext/SparseConnectivityTracerSpecialFunctionsExt.jl +++ b/ext/SparseConnectivityTracerSpecialFunctionsExt.jl @@ -112,9 +112,9 @@ ops_2_to_1 = ops_2_to_1_ssc ## Overloads -SCT.overload_gradient_1_to_1(ops_1_to_1) -SCT.overload_gradient_2_to_1(ops_2_to_1) -SCT.overload_hessian_1_to_1(ops_1_to_1) -SCT.overload_hessian_2_to_1(ops_2_to_1) +eval(SCT.overload_gradient_1_to_1(:SpecialFunctions, ops_1_to_1)) +eval(SCT.overload_gradient_2_to_1(:SpecialFunctions, ops_2_to_1)) +eval(SCT.overload_hessian_1_to_1(:SpecialFunctions, ops_1_to_1)) +eval(SCT.overload_hessian_2_to_1(:SpecialFunctions, ops_2_to_1)) end diff --git a/src/SparseConnectivityTracer.jl b/src/SparseConnectivityTracer.jl index 092f12f..d5747a6 100644 --- a/src/SparseConnectivityTracer.jl +++ b/src/SparseConnectivityTracer.jl @@ -25,7 +25,6 @@ include("tracers.jl") include("exceptions.jl") include("operators.jl") -include("overloads/utils.jl") include("overloads/conversion.jl") include("overloads/gradient_tracer.jl") include("overloads/hessian_tracer.jl") diff --git a/src/overloads/gradient_tracer.jl b/src/overloads/gradient_tracer.jl index fb9f2b8..2223363 100644 --- a/src/overloads/gradient_tracer.jl +++ b/src/overloads/gradient_tracer.jl @@ -28,55 +28,28 @@ function gradient_tracer_1_to_1_inner( end end -function overload_gradient_1_to_1(f) +function overload_gradient_1_to_1(M::Symbol, f) is_der1_zero_g = is_der1_zero_global(f) SCT = SparseConnectivityTracer - M = nameofrootmodule(f) fname = nameof(f) - ## GradientTracer - @eval function $M.$fname(t::$SCT.GradientTracer) - return $SCT.gradient_tracer_1_to_1(t, $is_der1_zero_g) - end - - ## Dual - @eval function $M.$fname(d::D) where {P,T<:$SCT.GradientTracer,D<:$SCT.Dual{P,T}} - x = $SCT.primal(d) - p_out = $M.$fname(x) - - t = $SCT.tracer(d) - is_der1_zero = $SCT.is_der1_zero_local($M.$fname, x) - t_out = $SCT.gradient_tracer_1_to_1(t, is_der1_zero) - return $SCT.Dual(p_out, t_out) - end -end - -## Drafty McDraftface -function foo(f) - M = nameofrootmodule(f) - fname = nameof(f) - is_der1_zero = true - - SCT = LinearAlgebra - - ## GradientTracer return quote - function $M.$fname(t::$SCT.Diagonal) - return $SCT.gradient_tracer_1_to_1(t, $is_der1_zero) + ## GradientTracer + function $M.$fname(t::$SCT.GradientTracer) + return $SCT.gradient_tracer_1_to_1(t, $is_der1_zero_g) end - end -end -function bar(f) - M = nameofrootmodule(f) - fname = nameof(f) - is_der1_zero = true - SCT = LinearAlgebra + ## Dual + function $M.$fname(d::D) where {P,T<:$SCT.GradientTracer,D<:$SCT.Dual{P,T}} + x = $SCT.primal(d) + p_out = $M.$fname(x) - ## GradientTracer - @eval function $M.$fname(t::$SCT.Diagonal) - return $SCT.gradient_tracer_1_to_1(t, $is_der1_zero) + t = $SCT.tracer(d) + is_der1_zero = $SCT.is_der1_zero_local($M.$fname, x) + t_out = $SCT.gradient_tracer_1_to_1(t, is_der1_zero) + return $SCT.Dual(p_out, t_out) + end end end @@ -126,65 +99,68 @@ function gradient_tracer_2_to_1_inner( end end -function overload_gradient_2_to_1(f) +function overload_gradient_2_to_1(M::Symbol, f) is_der1_arg1_zero = is_der1_arg1_zero_global(f) is_der1_arg2_zero = is_der1_arg2_zero_global(f) SCT = SparseConnectivityTracer - M = nameofrootmodule(f) fname = nameof(f) ## GradientTracer - @eval function $M.$fname(tx::T, ty::T) where {T<:$SCT.GradientTracer} - return $SCT.gradient_tracer_2_to_1(tx, ty, $is_der1_arg1_zero, $is_der1_arg2_zero) - end + return quote + function $M.$fname(tx::T, ty::T) where {T<:$SCT.GradientTracer} + return $SCT.gradient_tracer_2_to_1( + tx, ty, $is_der1_arg1_zero, $is_der1_arg2_zero + ) + end - @eval function $M.$fname(tx::$SCT.GradientTracer, ::Real) - return $SCT.gradient_tracer_1_to_1(tx, $is_der1_arg1_zero) - end + function $M.$fname(tx::$SCT.GradientTracer, ::Real) + return $SCT.gradient_tracer_1_to_1(tx, $is_der1_arg1_zero) + end - @eval function $M.$fname(::Real, ty::$SCT.GradientTracer) - return $SCT.gradient_tracer_1_to_1(ty, $is_der1_arg2_zero) - end + function $M.$fname(::Real, ty::$SCT.GradientTracer) + return $SCT.gradient_tracer_1_to_1(ty, $is_der1_arg2_zero) + end - ## Dual - @eval function $M.$fname( - dx::D, dy::D - ) where {P,T<:$SCT.GradientTracer,D<:$SCT.Dual{P,T}} - x = $SCT.primal(dx) - y = $SCT.primal(dy) - p_out = $M.$fname(x, y) - - tx = $SCT.tracer(dx) - ty = $SCT.tracer(dy) - is_der1_arg1_zero = $SCT.is_der1_arg1_zero_local($M.$fname, x, y) - is_der1_arg2_zero = $SCT.is_der1_arg2_zero_local($M.$fname, x, y) - t_out = $SCT.gradient_tracer_2_to_1(tx, ty, is_der1_arg1_zero, is_der1_arg2_zero) - return $SCT.Dual(p_out, t_out) - end + ## Dual + function $M.$fname(dx::D, dy::D) where {P,T<:$SCT.GradientTracer,D<:$SCT.Dual{P,T}} + x = $SCT.primal(dx) + y = $SCT.primal(dy) + p_out = $M.$fname(x, y) + + tx = $SCT.tracer(dx) + ty = $SCT.tracer(dy) + is_der1_arg1_zero = $SCT.is_der1_arg1_zero_local($M.$fname, x, y) + is_der1_arg2_zero = $SCT.is_der1_arg2_zero_local($M.$fname, x, y) + t_out = $SCT.gradient_tracer_2_to_1( + tx, ty, is_der1_arg1_zero, is_der1_arg2_zero + ) + return $SCT.Dual(p_out, t_out) + end - @eval function $M.$fname( - dx::D, y::Real - ) where {P,T<:$SCT.GradientTracer,D<:$SCT.Dual{P,T}} - x = $SCT.primal(dx) - p_out = $M.$fname(x, y) + function $M.$fname( + dx::D, y::Real + ) where {P,T<:$SCT.GradientTracer,D<:$SCT.Dual{P,T}} + x = $SCT.primal(dx) + p_out = $M.$fname(x, y) - tx = $SCT.tracer(dx) - is_der1_arg1_zero = $SCT.is_der1_arg1_zero_local($M.$fname, x, y) - t_out = $SCT.gradient_tracer_1_to_1(tx, is_der1_arg1_zero) - return $SCT.Dual(p_out, t_out) - end + tx = $SCT.tracer(dx) + is_der1_arg1_zero = $SCT.is_der1_arg1_zero_local($M.$fname, x, y) + t_out = $SCT.gradient_tracer_1_to_1(tx, is_der1_arg1_zero) + return $SCT.Dual(p_out, t_out) + end - @eval function $M.$fname( - x::Real, dy::D - ) where {P,T<:$SCT.GradientTracer,D<:$SCT.Dual{P,T}} - y = $SCT.primal(dy) - p_out = $M.$fname(x, y) + function $M.$fname( + x::Real, dy::D + ) where {P,T<:$SCT.GradientTracer,D<:$SCT.Dual{P,T}} + y = $SCT.primal(dy) + p_out = $M.$fname(x, y) - ty = $SCT.tracer(dy) - is_der1_arg2_zero = $SCT.is_der1_arg2_zero_local($M.$fname, x, y) - t_out = $SCT.gradient_tracer_1_to_1(ty, is_der1_arg2_zero) - return $SCT.Dual(p_out, t_out) + ty = $SCT.tracer(dy) + is_der1_arg2_zero = $SCT.is_der1_arg2_zero_local($M.$fname, x, y) + t_out = $SCT.gradient_tracer_1_to_1(ty, is_der1_arg2_zero) + return $SCT.Dual(p_out, t_out) + end end end @@ -202,31 +178,34 @@ end end end -function overload_gradient_1_to_2(f) +function overload_gradient_1_to_2(M::Symbol, f) is_der1_out1_zero_g = is_der1_out1_zero_global(f) is_der1_out2_zero_g = is_der1_out2_zero_global(f) SCT = SparseConnectivityTracer - M = nameofrootmodule(f) fname = nameof(f) - ## GradientTracer - @eval function $M.$fname(t::$SCT.GradientTracer) - return $SCT.gradient_tracer_1_to_2(t, $is_der1_out1_zero_g, $is_der1_out2_zero_g) - end - - ## Dual - @eval function $M.$fname(d::D) where {P,T<:$SCT.GradientTracer,D<:$SCT.Dual{P,T}} - x = $SCT.primal(d) - p_out1, p_out2 = $M.$fname(x) + return quote + ## GradientTracer + function $M.$fname(t::$SCT.GradientTracer) + return $SCT.gradient_tracer_1_to_2( + t, $is_der1_out1_zero_g, $is_der1_out2_zero_g + ) + end - t = $SCT.tracer(d) - is_der1_out2_zero = $SCT.is_der1_out2_zero_local($M.$fname, x) - is_der1_out1_zero = $SCT.is_der1_out1_zero_local($M.$fname, x) - t_out1, t_out2 = $SCT.gradient_tracer_1_to_2( - t, is_der1_out1_zero, is_der1_out2_zero - ) - return ($SCT.Dual(p_out1, t_out1), $SCT.Dual(p_out2, t_out2)) # TODO: this was wrong, add test + ## Dual + function $M.$fname(d::D) where {P,T<:$SCT.GradientTracer,D<:$SCT.Dual{P,T}} + x = $SCT.primal(d) + p_out1, p_out2 = $M.$fname(x) + + t = $SCT.tracer(d) + is_der1_out2_zero = $SCT.is_der1_out2_zero_local($M.$fname, x) + is_der1_out1_zero = $SCT.is_der1_out1_zero_local($M.$fname, x) + t_out1, t_out2 = $SCT.gradient_tracer_1_to_2( + t, is_der1_out1_zero, is_der1_out2_zero + ) + return ($SCT.Dual(p_out1, t_out1), $SCT.Dual(p_out2, t_out2)) # TODO: this was wrong, add test + end end end diff --git a/src/overloads/hessian_tracer.jl b/src/overloads/hessian_tracer.jl index e8770f6..01237bc 100644 --- a/src/overloads/hessian_tracer.jl +++ b/src/overloads/hessian_tracer.jl @@ -55,15 +55,14 @@ function hessian_tracer_1_to_1_inner( return P(g_out, h_out) # return pattern end -function overload_hessian_1_to_1(f) +function overload_hessian_1_to_1(M::Symbol, f) is_der1_zero_g = is_der1_zero_global(f) is_der2_zero_g = is_der2_zero_global(f) SCT = SparseConnectivityTracer - M = nameofrootmodule(f) fname = nameof(f) - @eval begin + return quote ## HessianTracer function $M.$fname(t::$SCT.HessianTracer) return $SCT.hessian_tracer_1_to_1(t, $is_der1_zero_g, $is_der2_zero_g) @@ -164,7 +163,7 @@ function hessian_tracer_2_to_1_inner( return P(g_out, h_out) # return pattern end -function overload_hessian_2_to_1(f) +function overload_hessian_2_to_1(M::Symbol, f) is_der1_arg1_zero_g = is_der1_arg1_zero_global(f) is_der2_arg1_zero_g = is_der2_arg1_zero_global(f) is_der1_arg2_zero_g = is_der1_arg2_zero_global(f) @@ -172,79 +171,80 @@ function overload_hessian_2_to_1(f) is_der_cross_zero_g = is_der_cross_zero_global(f) SCT = SparseConnectivityTracer - M = nameofrootmodule(f) fname = nameof(f) - ## HessianTracer - @eval function $M.$fname(tx::T, ty::T) where {T<:$SCT.HessianTracer} - return $SCT.hessian_tracer_2_to_1( - tx, - ty, - $is_der1_arg1_zero_g, - $is_der2_arg1_zero_g, - $is_der1_arg2_zero_g, - $is_der2_arg2_zero_g, - $is_der_cross_zero_g, - ) - end + return quote + ## HessianTracer + function $M.$fname(tx::T, ty::T) where {T<:$SCT.HessianTracer} + return $SCT.hessian_tracer_2_to_1( + tx, + ty, + $is_der1_arg1_zero_g, + $is_der2_arg1_zero_g, + $is_der1_arg2_zero_g, + $is_der2_arg2_zero_g, + $is_der_cross_zero_g, + ) + end - @eval function $M.$fname(tx::$SCT.HessianTracer, y::Real) - return $SCT.hessian_tracer_1_to_1(tx, $is_der1_arg1_zero_g, $is_der2_arg1_zero_g) - end + function $M.$fname(tx::$SCT.HessianTracer, y::Real) + return $SCT.hessian_tracer_1_to_1( + tx, $is_der1_arg1_zero_g, $is_der2_arg1_zero_g + ) + end - @eval function $M.$fname(x::Real, ty::$SCT.HessianTracer) - return $SCT.hessian_tracer_1_to_1(ty, $is_der1_arg2_zero_g, $is_der2_arg2_zero_g) - end + function $M.$fname(x::Real, ty::$SCT.HessianTracer) + return $SCT.hessian_tracer_1_to_1( + ty, $is_der1_arg2_zero_g, $is_der2_arg2_zero_g + ) + end - ## Dual - @eval function $M.$fname(dx::D, dy::D) where {P,T<:$SCT.HessianTracer,D<:$SCT.Dual{P,T}} - x = $SCT.primal(dx) - y = $SCT.primal(dy) - p_out = $M.$fname(x, y) - - tx = $SCT.tracer(dx) - ty = $SCT.tracer(dy) - is_der1_arg1_zero = $SCT.is_der1_arg1_zero_local($M.$fname, x, y) - is_der2_arg1_zero = $SCT.is_der2_arg1_zero_local($M.$fname, x, y) - is_der1_arg2_zero = $SCT.is_der1_arg2_zero_local($M.$fname, x, y) - is_der2_arg2_zero = $SCT.is_der2_arg2_zero_local($M.$fname, x, y) - is_der_cross_zero = $SCT.is_der_cross_zero_local($M.$fname, x, y) - t_out = $SCT.hessian_tracer_2_to_1( - tx, - ty, - is_der1_arg1_zero, - is_der2_arg1_zero, - is_der1_arg2_zero, - is_der2_arg2_zero, - is_der_cross_zero, - ) - return $SCT.Dual(p_out, t_out) - end + ## Dual + function $M.$fname(dx::D, dy::D) where {P,T<:$SCT.HessianTracer,D<:$SCT.Dual{P,T}} + x = $SCT.primal(dx) + y = $SCT.primal(dy) + p_out = $M.$fname(x, y) + + tx = $SCT.tracer(dx) + ty = $SCT.tracer(dy) + is_der1_arg1_zero = $SCT.is_der1_arg1_zero_local($M.$fname, x, y) + is_der2_arg1_zero = $SCT.is_der2_arg1_zero_local($M.$fname, x, y) + is_der1_arg2_zero = $SCT.is_der1_arg2_zero_local($M.$fname, x, y) + is_der2_arg2_zero = $SCT.is_der2_arg2_zero_local($M.$fname, x, y) + is_der_cross_zero = $SCT.is_der_cross_zero_local($M.$fname, x, y) + t_out = $SCT.hessian_tracer_2_to_1( + tx, + ty, + is_der1_arg1_zero, + is_der2_arg1_zero, + is_der1_arg2_zero, + is_der2_arg2_zero, + is_der_cross_zero, + ) + return $SCT.Dual(p_out, t_out) + end - @eval function $M.$fname( - dx::D, y::Real - ) where {P,T<:$SCT.HessianTracer,D<:$SCT.Dual{P,T}} - x = $SCT.primal(dx) - p_out = $M.$fname(x, y) - - tx = $SCT.tracer(dx) - is_der1_arg1_zero = $SCT.is_der1_arg1_zero_local($M.$fname, x, y) - is_der2_arg1_zero = $SCT.is_der2_arg1_zero_local($M.$fname, x, y) - t_out = $SCT.hessian_tracer_1_to_1(tx, is_der1_arg1_zero, is_der2_arg1_zero) - return $SCT.Dual(p_out, t_out) - end + function $M.$fname(dx::D, y::Real) where {P,T<:$SCT.HessianTracer,D<:$SCT.Dual{P,T}} + x = $SCT.primal(dx) + p_out = $M.$fname(x, y) + + tx = $SCT.tracer(dx) + is_der1_arg1_zero = $SCT.is_der1_arg1_zero_local($M.$fname, x, y) + is_der2_arg1_zero = $SCT.is_der2_arg1_zero_local($M.$fname, x, y) + t_out = $SCT.hessian_tracer_1_to_1(tx, is_der1_arg1_zero, is_der2_arg1_zero) + return $SCT.Dual(p_out, t_out) + end + + function $M.$fname(x::Real, dy::D) where {P,T<:$SCT.HessianTracer,D<:$SCT.Dual{P,T}} + y = $SCT.primal(dy) + p_out = $M.$fname(x, y) - @eval function $M.$fname( - x::Real, dy::D - ) where {P,T<:$SCT.HessianTracer,D<:$SCT.Dual{P,T}} - y = $SCT.primal(dy) - p_out = $M.$fname(x, y) - - ty = $SCT.tracer(dy) - is_der1_arg2_zero = $SCT.is_der1_arg2_zero_local($M.$fname, x, y) - is_der2_arg2_zero = $SCT.is_der2_arg2_zero_local($M.$fname, x, y) - t_out = $SCT.hessian_tracer_1_to_1(ty, is_der1_arg2_zero, is_der2_arg2_zero) - return $SCT.Dual(p_out, t_out) + ty = $SCT.tracer(dy) + is_der1_arg2_zero = $SCT.is_der1_arg2_zero_local($M.$fname, x, y) + is_der2_arg2_zero = $SCT.is_der2_arg2_zero_local($M.$fname, x, y) + t_out = $SCT.hessian_tracer_1_to_1(ty, is_der1_arg2_zero, is_der2_arg2_zero) + return $SCT.Dual(p_out, t_out) + end end end @@ -266,40 +266,45 @@ end end end -function overload_hessian_1_to_2(f) +function overload_hessian_1_to_2(M::Symbol, f) is_der1_out1_zero_g = is_der1_out1_zero_global(f) is_der2_out1_zero_g = is_der2_out1_zero_global(f) is_der1_out2_zero_g = is_der1_out2_zero_global(f) is_der2_out2_zero_g = is_der2_out2_zero_global(f) SCT = SparseConnectivityTracer - M = nameofrootmodule(f) fname = nameof(f) - ## HessianTracer - @eval function $M.$fname(t::$SCT.HessianTracer) - return $SCT.hessian_tracer_1_to_2( - t, - $is_der1_out1_zero_g, - $is_der2_out1_zero_g, - $is_der1_out2_zero_g, - $is_der2_out2_zero_g, - ) - end + return quote + ## HessianTracer + function $M.$fname(t::$SCT.HessianTracer) + return $SCT.hessian_tracer_1_to_2( + t, + $is_der1_out1_zero_g, + $is_der2_out1_zero_g, + $is_der1_out2_zero_g, + $is_der2_out2_zero_g, + ) + end - ## Dual - @eval function $M.$fname(d::D) where {P,T<:$SCT.HessianTracer,D<:$SCT.Dual{P,T}} - x = $SCT.primal(d) - p_out1, p_out2 = $M.$fname(x) - - is_der1_out1_zero = $SCT.is_der1_out1_zero_local($M.$fname, x) - is_der2_out1_zero = $SCT.is_der2_out1_zero_local($M.$fname, x) - is_der1_out2_zero = $SCT.is_der1_out2_zero_local($M.$fname, x) - is_der2_out2_zero = $SCT.is_der2_out2_zero_local($M.$fname, x) - t_out1, t_out2 = $SCT.hessian_tracer_1_to_2( - d, is_der1_out1_zero, is_der2_out1_zero, is_der1_out2_zero, is_der2_out2_zero - ) - return ($SCT.Dual(p_out1, t_out1), $SCT.Dual(p_out2, t_out2)) + ## Dual + function $M.$fname(d::D) where {P,T<:$SCT.HessianTracer,D<:$SCT.Dual{P,T}} + x = $SCT.primal(d) + p_out1, p_out2 = $M.$fname(x) + + is_der1_out1_zero = $SCT.is_der1_out1_zero_local($M.$fname, x) + is_der2_out1_zero = $SCT.is_der2_out1_zero_local($M.$fname, x) + is_der1_out2_zero = $SCT.is_der1_out2_zero_local($M.$fname, x) + is_der2_out2_zero = $SCT.is_der2_out2_zero_local($M.$fname, x) + t_out1, t_out2 = $SCT.hessian_tracer_1_to_2( + d, + is_der1_out1_zero, + is_der2_out1_zero, + is_der1_out2_zero, + is_der2_out2_zero, + ) + return ($SCT.Dual(p_out1, t_out1), $SCT.Dual(p_out2, t_out2)) + end end end diff --git a/src/overloads/overload_all.jl b/src/overloads/overload_all.jl index f72161f..9fc3eb6 100644 --- a/src/overloads/overload_all.jl +++ b/src/overloads/overload_all.jl @@ -6,29 +6,17 @@ for overload in ( :overload_hessian_2_to_1, :overload_hessian_1_to_2, ) - @eval function $overload(ops::Union{AbstractVector,Tuple}) - for op in ops - $overload(op) + @eval function $overload(M::Symbol, ops::Union{AbstractVector,Tuple}) + exprs = [$overload(M, op) for op in ops] + return quote + $(exprs...) end end end -overload_gradient_1_to_1(ops_1_to_1) -overload_gradient_2_to_1(ops_2_to_1) -overload_gradient_1_to_2(ops_1_to_2) -overload_hessian_1_to_1(ops_1_to_1) -overload_hessian_2_to_1(ops_2_to_1) -overload_hessian_1_to_2(ops_1_to_2) - -# for op in ops_1_to_1 -# overload_gradient_1_to_1(op) -# overload_hessian_1_to_1(op) -# end -# for op in ops_2_to_1 -# overload_gradient_2_to_1(op) -# overload_hessian_2_to_1(op) -# end -# for op in ops_1_to_2 -# overload_gradient_1_to_2(op) -# overload_hessian_1_to_2(op) -# end +eval(overload_gradient_1_to_1(:Base, ops_1_to_1)) +eval(overload_gradient_2_to_1(:Base, ops_2_to_1)) +eval(overload_gradient_1_to_2(:Base, ops_1_to_2)) +eval(overload_hessian_1_to_1(:Base, ops_1_to_1)) +eval(overload_hessian_2_to_1(:Base, ops_2_to_1)) +eval(overload_hessian_1_to_2(:Base, ops_1_to_2)) diff --git a/src/overloads/utils.jl b/src/overloads/utils.jl deleted file mode 100644 index 734da8c..0000000 --- a/src/overloads/utils.jl +++ /dev/null @@ -1,10 +0,0 @@ -# symb(f) = Symbol(parentmodule(f), '.', nameof(f)) -function rootmodule(x) - parent = parentmodule(x) - if parent == x - return parent - else - return rootmodule(parent) - end -end -nameofrootmodule(x) = nameof(rootmodule(x)) From 6e7268a6bc67813aa119e7cd1dc47fced0529661 Mon Sep 17 00:00:00 2001 From: adrhill Date: Fri, 16 Aug 2024 17:03:12 +0200 Subject: [PATCH 14/20] Minor fixes --- src/SparseConnectivityTracer.jl | 2 +- src/overloads/gradient_tracer.jl | 16 ++++++---------- src/overloads/hessian_tracer.jl | 14 +++++--------- src/overloads/overload_all.jl | 4 +--- 4 files changed, 13 insertions(+), 23 deletions(-) diff --git a/src/SparseConnectivityTracer.jl b/src/SparseConnectivityTracer.jl index d5747a6..2de22ca 100644 --- a/src/SparseConnectivityTracer.jl +++ b/src/SparseConnectivityTracer.jl @@ -30,8 +30,8 @@ include("overloads/gradient_tracer.jl") include("overloads/hessian_tracer.jl") include("overloads/ifelse_global.jl") include("overloads/dual.jl") -include("overloads/arrays.jl") include("overloads/overload_all.jl") +include("overloads/arrays.jl") include("interface.jl") include("adtypes.jl") diff --git a/src/overloads/gradient_tracer.jl b/src/overloads/gradient_tracer.jl index 2223363..d244771 100644 --- a/src/overloads/gradient_tracer.jl +++ b/src/overloads/gradient_tracer.jl @@ -1,3 +1,5 @@ +SCT = SparseConnectivityTracer + ## 1-to-1 @noinline function gradient_tracer_1_to_1( @@ -29,10 +31,8 @@ function gradient_tracer_1_to_1_inner( end function overload_gradient_1_to_1(M::Symbol, f) - is_der1_zero_g = is_der1_zero_global(f) - - SCT = SparseConnectivityTracer fname = nameof(f) + is_der1_zero_g = is_der1_zero_global(f) return quote ## GradientTracer @@ -100,14 +100,12 @@ function gradient_tracer_2_to_1_inner( end function overload_gradient_2_to_1(M::Symbol, f) + fname = nameof(f) is_der1_arg1_zero = is_der1_arg1_zero_global(f) is_der1_arg2_zero = is_der1_arg2_zero_global(f) - SCT = SparseConnectivityTracer - fname = nameof(f) - - ## GradientTracer return quote + ## GradientTracer function $M.$fname(tx::T, ty::T) where {T<:$SCT.GradientTracer} return $SCT.gradient_tracer_2_to_1( tx, ty, $is_der1_arg1_zero, $is_der1_arg2_zero @@ -179,12 +177,10 @@ end end function overload_gradient_1_to_2(M::Symbol, f) + fname = nameof(f) is_der1_out1_zero_g = is_der1_out1_zero_global(f) is_der1_out2_zero_g = is_der1_out2_zero_global(f) - SCT = SparseConnectivityTracer - fname = nameof(f) - return quote ## GradientTracer function $M.$fname(t::$SCT.GradientTracer) diff --git a/src/overloads/hessian_tracer.jl b/src/overloads/hessian_tracer.jl index 01237bc..072223a 100644 --- a/src/overloads/hessian_tracer.jl +++ b/src/overloads/hessian_tracer.jl @@ -1,3 +1,5 @@ +SCT = SparseConnectivityTracer + ## 1-to-1 # 𝟙[∇γ] = 𝟙[∂φ]⋅𝟙[∇α] @@ -56,12 +58,10 @@ function hessian_tracer_1_to_1_inner( end function overload_hessian_1_to_1(M::Symbol, f) + fname = nameof(f) is_der1_zero_g = is_der1_zero_global(f) is_der2_zero_g = is_der2_zero_global(f) - SCT = SparseConnectivityTracer - fname = nameof(f) - return quote ## HessianTracer function $M.$fname(t::$SCT.HessianTracer) @@ -164,15 +164,13 @@ function hessian_tracer_2_to_1_inner( end function overload_hessian_2_to_1(M::Symbol, f) + fname = nameof(f) is_der1_arg1_zero_g = is_der1_arg1_zero_global(f) is_der2_arg1_zero_g = is_der2_arg1_zero_global(f) is_der1_arg2_zero_g = is_der1_arg2_zero_global(f) is_der2_arg2_zero_g = is_der2_arg2_zero_global(f) is_der_cross_zero_g = is_der_cross_zero_global(f) - SCT = SparseConnectivityTracer - fname = nameof(f) - return quote ## HessianTracer function $M.$fname(tx::T, ty::T) where {T<:$SCT.HessianTracer} @@ -267,14 +265,12 @@ end end function overload_hessian_1_to_2(M::Symbol, f) + fname = nameof(f) is_der1_out1_zero_g = is_der1_out1_zero_global(f) is_der2_out1_zero_g = is_der2_out1_zero_global(f) is_der1_out2_zero_g = is_der1_out2_zero_global(f) is_der2_out2_zero_g = is_der2_out2_zero_global(f) - SCT = SparseConnectivityTracer - fname = nameof(f) - return quote ## HessianTracer function $M.$fname(t::$SCT.HessianTracer) diff --git a/src/overloads/overload_all.jl b/src/overloads/overload_all.jl index 9fc3eb6..432580e 100644 --- a/src/overloads/overload_all.jl +++ b/src/overloads/overload_all.jl @@ -8,9 +8,7 @@ for overload in ( ) @eval function $overload(M::Symbol, ops::Union{AbstractVector,Tuple}) exprs = [$overload(M, op) for op in ops] - return quote - $(exprs...) - end + return Expr(:block, exprs...) end end From d875b0fd02a7c192266da26c77838c6b3dc455d2 Mon Sep 17 00:00:00 2001 From: adrhill Date: Fri, 16 Aug 2024 17:36:31 +0200 Subject: [PATCH 15/20] Generate different overloads based on type of differentiability of an operator --- src/overloads/gradient_tracer.jl | 201 ++++++++++++++++++++----------- 1 file changed, 129 insertions(+), 72 deletions(-) diff --git a/src/overloads/gradient_tracer.jl b/src/overloads/gradient_tracer.jl index d244771..9dcead1 100644 --- a/src/overloads/gradient_tracer.jl +++ b/src/overloads/gradient_tracer.jl @@ -34,23 +34,33 @@ function overload_gradient_1_to_1(M::Symbol, f) fname = nameof(f) is_der1_zero_g = is_der1_zero_global(f) - return quote - ## GradientTracer + expr_gradienttracer = quote function $M.$fname(t::$SCT.GradientTracer) return $SCT.gradient_tracer_1_to_1(t, $is_der1_zero_g) end + end - ## Dual - function $M.$fname(d::D) where {P,T<:$SCT.GradientTracer,D<:$SCT.Dual{P,T}} - x = $SCT.primal(d) - p_out = $M.$fname(x) - - t = $SCT.tracer(d) - is_der1_zero = $SCT.is_der1_zero_local($M.$fname, x) - t_out = $SCT.gradient_tracer_1_to_1(t, is_der1_zero) - return $SCT.Dual(p_out, t_out) + expr_dual = if is_der1_zero_g + quote + function $M.$fname(d::D) where {P,T<:$SCT.GradientTracer,D<:$SCT.Dual{P,T}} + x = $SCT.primal(d) + return $M.$fname(x) + end + end + else + quote + function $M.$fname(d::D) where {P,T<:$SCT.GradientTracer,D<:$SCT.Dual{P,T}} + x = $SCT.primal(d) + p_out = $M.$fname(x) + + t = $SCT.tracer(d) + is_der1_zero = $SCT.is_der1_zero_local($M.$fname, x) + t_out = $SCT.gradient_tracer_1_to_1(t, is_der1_zero) + return $SCT.Dual(p_out, t_out) + end end end + return Expr(:block, expr_gradienttracer, expr_dual) end ## 2-to-1 @@ -101,65 +111,103 @@ end function overload_gradient_2_to_1(M::Symbol, f) fname = nameof(f) - is_der1_arg1_zero = is_der1_arg1_zero_global(f) - is_der1_arg2_zero = is_der1_arg2_zero_global(f) + is_der1_arg1_zero_g = is_der1_arg1_zero_global(f) + is_der1_arg2_zero_g = is_der1_arg2_zero_global(f) - return quote - ## GradientTracer + ## GradientTracer + expr_gradienttracer = quote function $M.$fname(tx::T, ty::T) where {T<:$SCT.GradientTracer} return $SCT.gradient_tracer_2_to_1( - tx, ty, $is_der1_arg1_zero, $is_der1_arg2_zero + tx, ty, $is_der1_arg1_zero_g, $is_der1_arg2_zero_g ) end function $M.$fname(tx::$SCT.GradientTracer, ::Real) - return $SCT.gradient_tracer_1_to_1(tx, $is_der1_arg1_zero) + return $SCT.gradient_tracer_1_to_1(tx, $is_der1_arg1_zero_g) end function $M.$fname(::Real, ty::$SCT.GradientTracer) - return $SCT.gradient_tracer_1_to_1(ty, $is_der1_arg2_zero) + return $SCT.gradient_tracer_1_to_1(ty, $is_der1_arg2_zero_g) end + end - ## Dual - function $M.$fname(dx::D, dy::D) where {P,T<:$SCT.GradientTracer,D<:$SCT.Dual{P,T}} - x = $SCT.primal(dx) - y = $SCT.primal(dy) - p_out = $M.$fname(x, y) - - tx = $SCT.tracer(dx) - ty = $SCT.tracer(dy) - is_der1_arg1_zero = $SCT.is_der1_arg1_zero_local($M.$fname, x, y) - is_der1_arg2_zero = $SCT.is_der1_arg2_zero_local($M.$fname, x, y) - t_out = $SCT.gradient_tracer_2_to_1( - tx, ty, is_der1_arg1_zero, is_der1_arg2_zero - ) - return $SCT.Dual(p_out, t_out) + ## Dual + expr_dual_dual = if is_der1_arg1_zero_g && is_der1_arg2_zero_g + quote + function $M.$fname(dx::D, dy::D) where {P,T<:$SCT.GradientTracer,D<:$SCT.Dual{P,T}} + x = $SCT.primal(dx) + y = $SCT.primal(dy) + return $M.$fname(x, y) + end end - - function $M.$fname( - dx::D, y::Real - ) where {P,T<:$SCT.GradientTracer,D<:$SCT.Dual{P,T}} - x = $SCT.primal(dx) - p_out = $M.$fname(x, y) - - tx = $SCT.tracer(dx) - is_der1_arg1_zero = $SCT.is_der1_arg1_zero_local($M.$fname, x, y) - t_out = $SCT.gradient_tracer_1_to_1(tx, is_der1_arg1_zero) - return $SCT.Dual(p_out, t_out) + else + quote + function $M.$fname(dx::D, dy::D) where {P,T<:$SCT.GradientTracer,D<:$SCT.Dual{P,T}} + x = $SCT.primal(dx) + y = $SCT.primal(dy) + p_out = $M.$fname(x, y) + + tx = $SCT.tracer(dx) + ty = $SCT.tracer(dy) + is_der1_arg1_zero = $SCT.is_der1_arg1_zero_local($M.$fname, x, y) + is_der1_arg2_zero = $SCT.is_der1_arg2_zero_local($M.$fname, x, y) + t_out = $SCT.gradient_tracer_2_to_1( + tx, ty, is_der1_arg1_zero, is_der1_arg2_zero + ) + return $SCT.Dual(p_out, t_out) + end end - - function $M.$fname( - x::Real, dy::D - ) where {P,T<:$SCT.GradientTracer,D<:$SCT.Dual{P,T}} - y = $SCT.primal(dy) - p_out = $M.$fname(x, y) - - ty = $SCT.tracer(dy) - is_der1_arg2_zero = $SCT.is_der1_arg2_zero_local($M.$fname, x, y) - t_out = $SCT.gradient_tracer_1_to_1(ty, is_der1_arg2_zero) - return $SCT.Dual(p_out, t_out) + end + expr_dual_real = if is_der1_arg1_zero_g + quote + function $M.$fname( + dx::D, y::Real + ) where {P,T<:$SCT.GradientTracer,D<:$SCT.Dual{P,T}} + x = $SCT.primal(dx) + return $M.$fname(x, y) + end + end + else + quote + function $M.$fname( + dx::D, y::Real + ) where {P,T<:$SCT.GradientTracer,D<:$SCT.Dual{P,T}} + x = $SCT.primal(dx) + p_out = $M.$fname(x, y) + + tx = $SCT.tracer(dx) + is_der1_arg1_zero = $SCT.is_der1_arg1_zero_local($M.$fname, x, y) + t_out = $SCT.gradient_tracer_1_to_1(tx, is_der1_arg1_zero) + return $SCT.Dual(p_out, t_out) + end end end + expr_real_dual = if is_der1_arg2_zero_g + quote + function $M.$fname( + x::Real, dy::D + ) where {P,T<:$SCT.GradientTracer,D<:$SCT.Dual{P,T}} + y = $SCT.primal(dy) + return $M.$fname(x, y) + end + end + else + quote + function $M.$fname( + x::Real, dy::D + ) where {P,T<:$SCT.GradientTracer,D<:$SCT.Dual{P,T}} + y = $SCT.primal(dy) + p_out = $M.$fname(x, y) + + ty = $SCT.tracer(dy) + is_der1_arg2_zero = $SCT.is_der1_arg2_zero_local($M.$fname, x, y) + t_out = $SCT.gradient_tracer_1_to_1(ty, is_der1_arg2_zero) + return $SCT.Dual(p_out, t_out) + end + end + end + + return Expr(:block, expr_gradienttracer, expr_dual_dual, expr_dual_real, expr_real_dual) end ## 1-to-2 @@ -181,31 +229,40 @@ function overload_gradient_1_to_2(M::Symbol, f) is_der1_out1_zero_g = is_der1_out1_zero_global(f) is_der1_out2_zero_g = is_der1_out2_zero_global(f) - return quote - ## GradientTracer + expr_gradienttracer = quote function $M.$fname(t::$SCT.GradientTracer) - return $SCT.gradient_tracer_1_to_2( - t, $is_der1_out1_zero_g, $is_der1_out2_zero_g - ) + return $SCT.gradient_tracer_1_to_2(t, $is_der1_out1_zero_g, $is_der1_out2_zero_g) end + end - ## Dual - function $M.$fname(d::D) where {P,T<:$SCT.GradientTracer,D<:$SCT.Dual{P,T}} - x = $SCT.primal(d) - p_out1, p_out2 = $M.$fname(x) - - t = $SCT.tracer(d) - is_der1_out2_zero = $SCT.is_der1_out2_zero_local($M.$fname, x) - is_der1_out1_zero = $SCT.is_der1_out1_zero_local($M.$fname, x) - t_out1, t_out2 = $SCT.gradient_tracer_1_to_2( - t, is_der1_out1_zero, is_der1_out2_zero - ) - return ($SCT.Dual(p_out1, t_out1), $SCT.Dual(p_out2, t_out2)) # TODO: this was wrong, add test + expr_dual = if is_der1_out1_zero_g && is_der1_out2_zero_g + quote + function $M.$fname(d::D) where {P,T<:$SCT.GradientTracer,D<:$SCT.Dual{P,T}} + x = $SCT.primal(d) + return $M.$fname(x) + end + end + else + quote + function $M.$fname(d::D) where {P,T<:$SCT.GradientTracer,D<:$SCT.Dual{P,T}} + x = $SCT.primal(d) + p_out1, p_out2 = $M.$fname(x) + + t = $SCT.tracer(d) + is_der1_out2_zero = $SCT.is_der1_out2_zero_local($M.$fname, x) + is_der1_out1_zero = $SCT.is_der1_out1_zero_local($M.$fname, x) + t_out1, t_out2 = $SCT.gradient_tracer_1_to_2( + t, is_der1_out1_zero, is_der1_out2_zero + ) + return ($SCT.Dual(p_out1, t_out1), $SCT.Dual(p_out2, t_out2)) # TODO: this was wrong, add test + end end end + + return Expr(:block, expr_gradienttracer, expr_dual) end -## Special cases +## Special cases to avoid ambiguity errors ## Exponent (requires extra types) for S in (Integer, Rational, Irrational{:ℯ}) From cd21200b8517fdfa460b393f56ba555a17b28ccd Mon Sep 17 00:00:00 2001 From: adrhill Date: Fri, 16 Aug 2024 17:39:14 +0200 Subject: [PATCH 16/20] Get rid of ambiguity errors on `isless` --- src/overloads/gradient_tracer.jl | 7 +++++-- src/overloads/hessian_tracer.jl | 7 +++++-- 2 files changed, 10 insertions(+), 4 deletions(-) diff --git a/src/overloads/gradient_tracer.jl b/src/overloads/gradient_tracer.jl index 9dcead1..9ff4be1 100644 --- a/src/overloads/gradient_tracer.jl +++ b/src/overloads/gradient_tracer.jl @@ -262,9 +262,8 @@ function overload_gradient_1_to_2(M::Symbol, f) return Expr(:block, expr_gradienttracer, expr_dual) end -## Special cases to avoid ambiguity errors +## Special overloads to avoid ambiguity errors -## Exponent (requires extra types) for S in (Integer, Rational, Irrational{:ℯ}) Base.:^(t::T, ::S) where {T<:GradientTracer} = t Base.:^(::S, t::T) where {T<:GradientTracer} = t @@ -280,6 +279,10 @@ for S in (Integer, Rational, Irrational{:ℯ}) end end +function Base.isless(dx::D, y::AbstractFloat) where {P,T<:GradientTracer,D<:Dual{P,T}} + return isless(primal(dx), y) +end + ## Rounding Base.round(::T, ::RoundingMode; kwargs...) where {T<:GradientTracer} = myempty(T) function Base.round( diff --git a/src/overloads/hessian_tracer.jl b/src/overloads/hessian_tracer.jl index 072223a..41ddc95 100644 --- a/src/overloads/hessian_tracer.jl +++ b/src/overloads/hessian_tracer.jl @@ -304,9 +304,8 @@ function overload_hessian_1_to_2(M::Symbol, f) end end -## Special cases +## Special overloads to avoid ambiguity errors -## Exponent (requires extra types) for S in (Integer, Rational, Irrational{:ℯ}) Base.:^(t::T, ::S) where {T<:HessianTracer} = hessian_tracer_1_to_1(t, false, false) Base.:^(::S, t::T) where {T<:HessianTracer} = hessian_tracer_1_to_1(t, false, false) @@ -323,6 +322,10 @@ for S in (Integer, Rational, Irrational{:ℯ}) end end +function Base.isless(dx::D, y::AbstractFloat) where {P,T<:GradientTracer,D<:Dual{P,T}} + return isless(primal(dx), y) +end + ## Rounding Base.round(t::T, ::RoundingMode; kwargs...) where {T<:HessianTracer} = myempty(T) function Base.round( From f22c3332b2884cde55810857bf334eacb829520f Mon Sep 17 00:00:00 2001 From: adrhill Date: Fri, 16 Aug 2024 17:53:45 +0200 Subject: [PATCH 17/20] Add conditional overloads to HessianTracer --- src/overloads/gradient_tracer.jl | 2 +- src/overloads/hessian_tracer.jl | 232 ++++++++++++++++++++----------- 2 files changed, 150 insertions(+), 84 deletions(-) diff --git a/src/overloads/gradient_tracer.jl b/src/overloads/gradient_tracer.jl index 9ff4be1..e707758 100644 --- a/src/overloads/gradient_tracer.jl +++ b/src/overloads/gradient_tracer.jl @@ -279,7 +279,7 @@ for S in (Integer, Rational, Irrational{:ℯ}) end end -function Base.isless(dx::D, y::AbstractFloat) where {P,T<:GradientTracer,D<:Dual{P,T}} +function Base.isless(dx::D, y::AbstractFloat) where {P<:Real,T<:GradientTracer,D<:Dual{P,T}} return isless(primal(dx), y) end diff --git a/src/overloads/hessian_tracer.jl b/src/overloads/hessian_tracer.jl index 41ddc95..6d4ed29 100644 --- a/src/overloads/hessian_tracer.jl +++ b/src/overloads/hessian_tracer.jl @@ -62,24 +62,36 @@ function overload_hessian_1_to_1(M::Symbol, f) is_der1_zero_g = is_der1_zero_global(f) is_der2_zero_g = is_der2_zero_global(f) - return quote + expr_hessiantracer = quote ## HessianTracer function $M.$fname(t::$SCT.HessianTracer) return $SCT.hessian_tracer_1_to_1(t, $is_der1_zero_g, $is_der2_zero_g) end + end - ## Dual - function $M.$fname(d::D) where {P,T<:$SCT.HessianTracer,D<:$SCT.Dual{P,T}} - x = $SCT.primal(d) - p_out = $M.$fname(x) - - t = $SCT.tracer(d) - is_der1_zero = $SCT.is_der1_zero_local($M.$fname, x) - is_der2_zero = $SCT.is_der2_zero_local($M.$fname, x) - t_out = $SCT.hessian_tracer_1_to_1(t, is_der1_zero, is_der2_zero) - return $SCT.Dual(p_out, t_out) + expr_dual = if is_der1_zero_g && is_der1_zero_g + quote + function $M.$fname(d::D) where {P,T<:$SCT.HessianTracer,D<:$SCT.Dual{P,T}} + x = $SCT.primal(d) + return $M.$fname(x) + end + end + else + quote + function $M.$fname(d::D) where {P,T<:$SCT.HessianTracer,D<:$SCT.Dual{P,T}} + x = $SCT.primal(d) + p_out = $M.$fname(x) + + t = $SCT.tracer(d) + is_der1_zero = $SCT.is_der1_zero_local($M.$fname, x) + is_der2_zero = $SCT.is_der2_zero_local($M.$fname, x) + t_out = $SCT.hessian_tracer_1_to_1(t, is_der1_zero, is_der2_zero) + return $SCT.Dual(p_out, t_out) + end end end + + return Expr(:block, expr_hessiantracer, expr_dual) end ## 2-to-1 @@ -171,8 +183,8 @@ function overload_hessian_2_to_1(M::Symbol, f) is_der2_arg2_zero_g = is_der2_arg2_zero_global(f) is_der_cross_zero_g = is_der_cross_zero_global(f) - return quote - ## HessianTracer + ## HessianTracer + expr_hessiantracer = quote function $M.$fname(tx::T, ty::T) where {T<:$SCT.HessianTracer} return $SCT.hessian_tracer_2_to_1( tx, @@ -186,64 +198,103 @@ function overload_hessian_2_to_1(M::Symbol, f) end function $M.$fname(tx::$SCT.HessianTracer, y::Real) - return $SCT.hessian_tracer_1_to_1( - tx, $is_der1_arg1_zero_g, $is_der2_arg1_zero_g - ) + return $SCT.hessian_tracer_1_to_1(tx, $is_der1_arg1_zero_g, $is_der2_arg1_zero_g) end function $M.$fname(x::Real, ty::$SCT.HessianTracer) - return $SCT.hessian_tracer_1_to_1( - ty, $is_der1_arg2_zero_g, $is_der2_arg2_zero_g - ) + return $SCT.hessian_tracer_1_to_1(ty, $is_der1_arg2_zero_g, $is_der2_arg2_zero_g) end + end - ## Dual - function $M.$fname(dx::D, dy::D) where {P,T<:$SCT.HessianTracer,D<:$SCT.Dual{P,T}} - x = $SCT.primal(dx) - y = $SCT.primal(dy) - p_out = $M.$fname(x, y) - - tx = $SCT.tracer(dx) - ty = $SCT.tracer(dy) - is_der1_arg1_zero = $SCT.is_der1_arg1_zero_local($M.$fname, x, y) - is_der2_arg1_zero = $SCT.is_der2_arg1_zero_local($M.$fname, x, y) - is_der1_arg2_zero = $SCT.is_der1_arg2_zero_local($M.$fname, x, y) - is_der2_arg2_zero = $SCT.is_der2_arg2_zero_local($M.$fname, x, y) - is_der_cross_zero = $SCT.is_der_cross_zero_local($M.$fname, x, y) - t_out = $SCT.hessian_tracer_2_to_1( - tx, - ty, - is_der1_arg1_zero, - is_der2_arg1_zero, - is_der1_arg2_zero, - is_der2_arg2_zero, - is_der_cross_zero, - ) - return $SCT.Dual(p_out, t_out) + ## Dual + expr_dual_dual = + if is_der1_arg1_zero_g && + is_der2_arg1_zero_g && + is_der1_arg2_zero_g && + is_der2_arg2_zero_g && + is_der_cross_zero_g + quote + function $M.$fname( + dx::D, dy::D + ) where {P,T<:$SCT.HessianTracer,D<:$SCT.Dual{P,T}} + x = $SCT.primal(dx) + y = $SCT.primal(dy) + return $M.$fname(x, y) + end + end + else + quote + function $M.$fname( + dx::D, dy::D + ) where {P,T<:$SCT.HessianTracer,D<:$SCT.Dual{P,T}} + x = $SCT.primal(dx) + y = $SCT.primal(dy) + p_out = $M.$fname(x, y) + + tx = $SCT.tracer(dx) + ty = $SCT.tracer(dy) + is_der1_arg1_zero = $SCT.is_der1_arg1_zero_local($M.$fname, x, y) + is_der2_arg1_zero = $SCT.is_der2_arg1_zero_local($M.$fname, x, y) + is_der1_arg2_zero = $SCT.is_der1_arg2_zero_local($M.$fname, x, y) + is_der2_arg2_zero = $SCT.is_der2_arg2_zero_local($M.$fname, x, y) + is_der_cross_zero = $SCT.is_der_cross_zero_local($M.$fname, x, y) + t_out = $SCT.hessian_tracer_2_to_1( + tx, + ty, + is_der1_arg1_zero, + is_der2_arg1_zero, + is_der1_arg2_zero, + is_der2_arg2_zero, + is_der_cross_zero, + ) + return $SCT.Dual(p_out, t_out) + end + end end - - function $M.$fname(dx::D, y::Real) where {P,T<:$SCT.HessianTracer,D<:$SCT.Dual{P,T}} - x = $SCT.primal(dx) - p_out = $M.$fname(x, y) - - tx = $SCT.tracer(dx) - is_der1_arg1_zero = $SCT.is_der1_arg1_zero_local($M.$fname, x, y) - is_der2_arg1_zero = $SCT.is_der2_arg1_zero_local($M.$fname, x, y) - t_out = $SCT.hessian_tracer_1_to_1(tx, is_der1_arg1_zero, is_der2_arg1_zero) - return $SCT.Dual(p_out, t_out) + expr_dual_real = if is_der1_arg1_zero_g && is_der2_arg1_zero_g + quote + function $M.$fname(dx::D, y::Real) where {P,T<:$SCT.HessianTracer,D<:$SCT.Dual{P,T}} + x = $SCT.primal(dx) + return $M.$fname(x, y) + end end - - function $M.$fname(x::Real, dy::D) where {P,T<:$SCT.HessianTracer,D<:$SCT.Dual{P,T}} - y = $SCT.primal(dy) - p_out = $M.$fname(x, y) - - ty = $SCT.tracer(dy) - is_der1_arg2_zero = $SCT.is_der1_arg2_zero_local($M.$fname, x, y) - is_der2_arg2_zero = $SCT.is_der2_arg2_zero_local($M.$fname, x, y) - t_out = $SCT.hessian_tracer_1_to_1(ty, is_der1_arg2_zero, is_der2_arg2_zero) - return $SCT.Dual(p_out, t_out) + else + quote + function $M.$fname(dx::D, y::Real) where {P,T<:$SCT.HessianTracer,D<:$SCT.Dual{P,T}} + x = $SCT.primal(dx) + p_out = $M.$fname(x, y) + + tx = $SCT.tracer(dx) + is_der1_arg1_zero = $SCT.is_der1_arg1_zero_local($M.$fname, x, y) + is_der2_arg1_zero = $SCT.is_der2_arg1_zero_local($M.$fname, x, y) + t_out = $SCT.hessian_tracer_1_to_1(tx, is_der1_arg1_zero, is_der2_arg1_zero) + return $SCT.Dual(p_out, t_out) + end + end + end + expr_real_dual = if is_der1_arg2_zero_g && is_der2_arg2_zero_g + quote + function $M.$fname(x::Real, dy::D) where {P,T<:$SCT.HessianTracer,D<:$SCT.Dual{P,T}} + y = $SCT.primal(dy) + return $M.$fname(x, y) + end + end + else + quote + function $M.$fname(x::Real, dy::D) where {P,T<:$SCT.HessianTracer,D<:$SCT.Dual{P,T}} + y = $SCT.primal(dy) + p_out = $M.$fname(x, y) + + ty = $SCT.tracer(dy) + is_der1_arg2_zero = $SCT.is_der1_arg2_zero_local($M.$fname, x, y) + is_der2_arg2_zero = $SCT.is_der2_arg2_zero_local($M.$fname, x, y) + t_out = $SCT.hessian_tracer_1_to_1(ty, is_der1_arg2_zero, is_der2_arg2_zero) + return $SCT.Dual(p_out, t_out) + end end end + + return Expr(:block, expr_hessiantracer, expr_dual_dual, expr_dual_real, expr_real_dual) end ## 1-to-2 @@ -271,8 +322,7 @@ function overload_hessian_1_to_2(M::Symbol, f) is_der1_out2_zero_g = is_der1_out2_zero_global(f) is_der2_out2_zero_g = is_der2_out2_zero_global(f) - return quote - ## HessianTracer + expr_hessiantracer = quote function $M.$fname(t::$SCT.HessianTracer) return $SCT.hessian_tracer_1_to_2( t, @@ -282,26 +332,42 @@ function overload_hessian_1_to_2(M::Symbol, f) $is_der2_out2_zero_g, ) end + end - ## Dual - function $M.$fname(d::D) where {P,T<:$SCT.HessianTracer,D<:$SCT.Dual{P,T}} - x = $SCT.primal(d) - p_out1, p_out2 = $M.$fname(x) - - is_der1_out1_zero = $SCT.is_der1_out1_zero_local($M.$fname, x) - is_der2_out1_zero = $SCT.is_der2_out1_zero_local($M.$fname, x) - is_der1_out2_zero = $SCT.is_der1_out2_zero_local($M.$fname, x) - is_der2_out2_zero = $SCT.is_der2_out2_zero_local($M.$fname, x) - t_out1, t_out2 = $SCT.hessian_tracer_1_to_2( - d, - is_der1_out1_zero, - is_der2_out1_zero, - is_der1_out2_zero, - is_der2_out2_zero, - ) - return ($SCT.Dual(p_out1, t_out1), $SCT.Dual(p_out2, t_out2)) + expr_dual = + if is_der1_out1_zero_g && + is_der2_out1_zero_g && + is_der1_out2_zero_g && + is_der2_out2_zero_g + quote + function $M.$fname(d::D) where {P,T<:$SCT.HessianTracer,D<:$SCT.Dual{P,T}} + x = $SCT.primal(d) + return $M.$fname(x) + end + end + else + quote + function $M.$fname(d::D) where {P,T<:$SCT.HessianTracer,D<:$SCT.Dual{P,T}} + x = $SCT.primal(d) + p_out1, p_out2 = $M.$fname(x) + + is_der1_out1_zero = $SCT.is_der1_out1_zero_local($M.$fname, x) + is_der2_out1_zero = $SCT.is_der2_out1_zero_local($M.$fname, x) + is_der1_out2_zero = $SCT.is_der1_out2_zero_local($M.$fname, x) + is_der2_out2_zero = $SCT.is_der2_out2_zero_local($M.$fname, x) + t_out1, t_out2 = $SCT.hessian_tracer_1_to_2( + d, + is_der1_out1_zero, + is_der2_out1_zero, + is_der1_out2_zero, + is_der2_out2_zero, + ) + return ($SCT.Dual(p_out1, t_out1), $SCT.Dual(p_out2, t_out2)) + end + end end - end + + return Expr(:block, expr_hessiantracer, expr_dual) end ## Special overloads to avoid ambiguity errors @@ -322,7 +388,7 @@ for S in (Integer, Rational, Irrational{:ℯ}) end end -function Base.isless(dx::D, y::AbstractFloat) where {P,T<:GradientTracer,D<:Dual{P,T}} +function Base.isless(dx::D, y::AbstractFloat) where {P<:Real,T<:HessianTracer,D<:Dual{P,T}} return isless(primal(dx), y) end From e3f1ebe429fe177b85ef032c5382e467237e79ff Mon Sep 17 00:00:00 2001 From: adrhill Date: Fri, 16 Aug 2024 18:17:31 +0200 Subject: [PATCH 18/20] Bring back operator lists for classification tests --- ext/SparseConnectivityTracerNNlibExt.jl | 6 +++++- ...seConnectivityTracerSpecialFunctionsExt.jl | 6 +++++- src/overloads/overload_all.jl | 6 ++++++ test/classification.jl | 21 +++++++++++-------- 4 files changed, 28 insertions(+), 11 deletions(-) diff --git a/ext/SparseConnectivityTracerNNlibExt.jl b/ext/SparseConnectivityTracerNNlibExt.jl index c7bff08..e545749 100644 --- a/ext/SparseConnectivityTracerNNlibExt.jl +++ b/ext/SparseConnectivityTracerNNlibExt.jl @@ -83,8 +83,12 @@ SCT.is_der1_zero_local(::typeof(softshrink), x) = x > -0.5 && x < 0.5 ops_1_to_1 = union(ops_1_to_1_s, ops_1_to_1_f) ## Overload - eval(SCT.overload_gradient_1_to_1(:NNlib, ops_1_to_1)) eval(SCT.overload_hessian_1_to_1(:NNlib, ops_1_to_1)) +## List operators for later testing +SCT.test_operators_1_to_1(::Val{:NNlib}) = ops_1_to_1 +SCT.test_operators_2_to_1(::Val{:NNlib}) = () +SCT.test_operators_1_to_2(::Val{:NNlib}) = () + end diff --git a/ext/SparseConnectivityTracerSpecialFunctionsExt.jl b/ext/SparseConnectivityTracerSpecialFunctionsExt.jl index eaefda7..f25d0d8 100644 --- a/ext/SparseConnectivityTracerSpecialFunctionsExt.jl +++ b/ext/SparseConnectivityTracerSpecialFunctionsExt.jl @@ -111,10 +111,14 @@ end ops_2_to_1 = ops_2_to_1_ssc ## Overloads - eval(SCT.overload_gradient_1_to_1(:SpecialFunctions, ops_1_to_1)) eval(SCT.overload_gradient_2_to_1(:SpecialFunctions, ops_2_to_1)) eval(SCT.overload_hessian_1_to_1(:SpecialFunctions, ops_1_to_1)) eval(SCT.overload_hessian_2_to_1(:SpecialFunctions, ops_2_to_1)) +## List operators for later testing +SCT.test_operators_1_to_1(::Val{:SpecialFunctions}) = ops_1_to_1 +SCT.test_operators_2_to_1(::Val{:SpecialFunctions}) = ops_2_to_1 +SCT.test_operators_1_to_2(::Val{:SpecialFunctions}) = () + end diff --git a/src/overloads/overload_all.jl b/src/overloads/overload_all.jl index 432580e..98e4dc8 100644 --- a/src/overloads/overload_all.jl +++ b/src/overloads/overload_all.jl @@ -12,9 +12,15 @@ for overload in ( end end +## Overload operators eval(overload_gradient_1_to_1(:Base, ops_1_to_1)) eval(overload_gradient_2_to_1(:Base, ops_2_to_1)) eval(overload_gradient_1_to_2(:Base, ops_1_to_2)) eval(overload_hessian_1_to_1(:Base, ops_1_to_1)) eval(overload_hessian_2_to_1(:Base, ops_2_to_1)) eval(overload_hessian_1_to_2(:Base, ops_1_to_2)) + +## List operators for later testing +test_operators_1_to_1(::Val{:Base}) = ops_1_to_1 +test_operators_2_to_1(::Val{:Base}) = ops_2_to_1 +test_operators_1_to_2(::Val{:Base}) = ops_1_to_2 diff --git a/test/classification.jl b/test/classification.jl index 440e657..c1c932e 100644 --- a/test/classification.jl +++ b/test/classification.jl @@ -1,10 +1,9 @@ -using SparseConnectivityTracer: - list_operators_1_to_1, +using SparseConnectivityTracer: # 1-to-1 is_der1_zero_global, is_der2_zero_global, is_der1_zero_local, - is_der2_zero_local, - list_operators_2_to_1, + is_der2_zero_local +using SparseConnectivityTracer: # 2-to-1 is_der1_arg1_zero_global, is_der2_arg1_zero_global, is_der1_arg2_zero_global, @@ -14,8 +13,8 @@ using SparseConnectivityTracer: is_der2_arg1_zero_local, is_der1_arg2_zero_local, is_der2_arg2_zero_local, - is_der_cross_zero_local, - list_operators_1_to_2, + is_der_cross_zero_local +using SparseConnectivityTracer: # 1-to-2 is_der1_out1_zero_global, is_der2_out1_zero_global, is_der1_out2_zero_global, @@ -24,6 +23,10 @@ using SparseConnectivityTracer: is_der1_out1_zero_local, is_der1_out2_zero_local, is_der2_out2_zero_local +using SparseConnectivityTracer: # testing + test_operators_2_to_1, + test_operators_1_to_1, + test_operators_1_to_2 using SpecialFunctions: SpecialFunctions using NNlib: NNlib using Test @@ -88,7 +91,7 @@ end @testset verbose = true "1-to-1" begin @testset "$m" for m in (Base, SpecialFunctions, NNlib) - @testset "$op" for op in list_operators_1_to_1(Val(Symbol(m))) + @testset "$op" for op in test_operators_1_to_1(Val(Symbol(m))) @test all( correct_classification_1_to_1(op, random_input(op); atol=DEFAULT_ATOL) for _ in 1:DEFAULT_TRIALS @@ -131,7 +134,7 @@ end @testset verbose = true "2-to-1" begin @testset "$m" for m in (Base, SpecialFunctions, NNlib) - @testset "$op" for op in list_operators_2_to_1(Val(Symbol(m))) + @testset "$op" for op in test_operators_2_to_1(Val(Symbol(m))) @test all( correct_classification_2_to_1( op, random_first_input(op), random_second_input(op); atol=DEFAULT_ATOL @@ -171,7 +174,7 @@ end @testset verbose = true "1-to-2" begin @testset "$m" for m in (Base, SpecialFunctions, NNlib) - @testset "$op" for op in list_operators_1_to_2(Val(Symbol(m))) + @testset "$op" for op in test_operators_1_to_2(Val(Symbol(m))) @test all( correct_classification_1_to_2(op, random_input(op); atol=DEFAULT_ATOL) for _ in 1:DEFAULT_TRIALS From 61773e1266c03fde26784fa9dacbc594871bf145 Mon Sep 17 00:00:00 2001 From: adrhill Date: Fri, 16 Aug 2024 18:32:51 +0200 Subject: [PATCH 19/20] Fix for 1.6 --- Project.toml | 2 -- src/SparseConnectivityTracer.jl | 1 - test/test_gradient.jl | 1 + 3 files changed, 1 insertion(+), 3 deletions(-) diff --git a/Project.toml b/Project.toml index 15a62dd..2f48c80 100644 --- a/Project.toml +++ b/Project.toml @@ -5,7 +5,6 @@ version = "0.6.2-DEV" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" -Compat = "34da2185-b29b-5c13-b0c7-acf172513d20" DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" @@ -25,7 +24,6 @@ SparseConnectivityTracerSpecialFunctionsExt = "SpecialFunctions" [compat] ADTypes = "1" -Compat = "3,4" DocStringExtensions = "0.9" FillArrays = "1" LinearAlgebra = "<0.0.1, 1" diff --git a/src/SparseConnectivityTracer.jl b/src/SparseConnectivityTracer.jl index 2de22ca..8a0e24b 100644 --- a/src/SparseConnectivityTracer.jl +++ b/src/SparseConnectivityTracer.jl @@ -1,7 +1,6 @@ module SparseConnectivityTracer using ADTypes: ADTypes, jacobian_sparsity, hessian_sparsity -using Compat: Returns using SparseArrays: SparseArrays using SparseArrays: sparse using Random: AbstractRNG, SamplerType diff --git a/test/test_gradient.jl b/test/test_gradient.jl index 9b9c11e..490be48 100644 --- a/test/test_gradient.jl +++ b/test/test_gradient.jl @@ -2,6 +2,7 @@ using SparseConnectivityTracer using SparseConnectivityTracer: GradientTracer, Dual, MissingPrimalError, trace_input using Test +using Compat: Returns using Random: rand, GLOBAL_RNG using LinearAlgebra: det, dot, logdet using SpecialFunctions: erf, beta From 005258c77f8a3e997ff724c3c1cc19d5734592d7 Mon Sep 17 00:00:00 2001 From: adrhill Date: Fri, 16 Aug 2024 18:44:33 +0200 Subject: [PATCH 20/20] Fix for Julia 1.6 --- test/runtests.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/test/runtests.jl b/test/runtests.jl index 2156d4c..10dc3f0 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -5,6 +5,7 @@ Pkg.develop(; using SparseConnectivityTracer +using Compat: pkgversion using Test using JuliaFormatter using Aqua