From d31a1398046c82b16180d87c49c11a46b623a163 Mon Sep 17 00:00:00 2001 From: adrhill Date: Tue, 1 Oct 2024 15:45:06 +0200 Subject: [PATCH 1/3] Allow 2-to-1 overloads on arbitrary types --- src/overloads/gradient_tracer.jl | 34 +++++++++++++++----------------- src/overloads/hessian_tracer.jl | 26 ++++++++++++++---------- 2 files changed, 32 insertions(+), 28 deletions(-) diff --git a/src/overloads/gradient_tracer.jl b/src/overloads/gradient_tracer.jl index 40ae1d1b..be401b38 100644 --- a/src/overloads/gradient_tracer.jl +++ b/src/overloads/gradient_tracer.jl @@ -109,7 +109,11 @@ function gradient_tracer_2_to_1_inner( end end -function generate_code_gradient_2_to_1(M::Symbol, f) +function generate_code_gradient_2_to_1( + M::Symbol, # Symbol indicating Module of f, usually `:Base` + f::Function, # function to overload + Z::Type=Real, # external non-tracer-type to overload on +) fname = nameof(f) is_der1_arg1_zero_g = is_der1_arg1_zero_global(f) is_der1_arg2_zero_g = is_der1_arg2_zero_global(f) @@ -122,11 +126,11 @@ function generate_code_gradient_2_to_1(M::Symbol, f) ) end - function $M.$fname(tx::$SCT.GradientTracer, ::Real) + function $M.$fname(tx::$SCT.GradientTracer, ::$Z) return $SCT.gradient_tracer_1_to_1(tx, $is_der1_arg1_zero_g) end - function $M.$fname(::Real, ty::$SCT.GradientTracer) + function $M.$fname(::$Z, ty::$SCT.GradientTracer) return $SCT.gradient_tracer_1_to_1(ty, $is_der1_arg2_zero_g) end end @@ -158,20 +162,16 @@ function generate_code_gradient_2_to_1(M::Symbol, f) end end end - expr_dual_real = if is_der1_arg1_zero_g + expr_dual_nondual = if is_der1_arg1_zero_g quote - function $M.$fname( - dx::D, y::Real - ) where {P,T<:$SCT.GradientTracer,D<:$SCT.Dual{P,T}} + function $M.$fname(dx::D, y::$Z) where {P,T<:$SCT.GradientTracer,D<:$SCT.Dual{P,T}} x = $SCT.primal(dx) return $M.$fname(x, y) end end else quote - function $M.$fname( - dx::D, y::Real - ) where {P,T<:$SCT.GradientTracer,D<:$SCT.Dual{P,T}} + function $M.$fname(dx::D, y::$Z) where {P,T<:$SCT.GradientTracer,D<:$SCT.Dual{P,T}} x = $SCT.primal(dx) p_out = $M.$fname(x, y) @@ -182,20 +182,16 @@ function generate_code_gradient_2_to_1(M::Symbol, f) end end end - expr_real_dual = if is_der1_arg2_zero_g + expr_nondual_dual = if is_der1_arg2_zero_g quote - function $M.$fname( - x::Real, dy::D - ) where {P,T<:$SCT.GradientTracer,D<:$SCT.Dual{P,T}} + function $M.$fname(x::$Z, dy::D) where {P,T<:$SCT.GradientTracer,D<:$SCT.Dual{P,T}} y = $SCT.primal(dy) return $M.$fname(x, y) end end else quote - function $M.$fname( - x::Real, dy::D - ) where {P,T<:$SCT.GradientTracer,D<:$SCT.Dual{P,T}} + function $M.$fname(x::$Z, dy::D) where {P,T<:$SCT.GradientTracer,D<:$SCT.Dual{P,T}} y = $SCT.primal(dy) p_out = $M.$fname(x, y) @@ -207,7 +203,9 @@ function generate_code_gradient_2_to_1(M::Symbol, f) end end - return Expr(:block, expr_gradienttracer, expr_dual_dual, expr_dual_real, expr_real_dual) + return Expr( + :block, expr_gradienttracer, expr_dual_dual, expr_dual_nondual, expr_nondual_dual + ) end ## 1-to-2 diff --git a/src/overloads/hessian_tracer.jl b/src/overloads/hessian_tracer.jl index 79e61968..4f07ae03 100644 --- a/src/overloads/hessian_tracer.jl +++ b/src/overloads/hessian_tracer.jl @@ -175,7 +175,11 @@ function hessian_tracer_2_to_1_inner( return P(g_out, h_out) # return pattern end -function generate_code_hessian_2_to_1(M::Symbol, f) +function generate_code_hessian_2_to_1( + M::Symbol, # Symbol indicating Module of f, usually `:Base` + f::Function, # function to overload + Z::Type=Real, # external non-tracer-type to overload on +) fname = nameof(f) is_der1_arg1_zero_g = is_der1_arg1_zero_global(f) is_der2_arg1_zero_g = is_der2_arg1_zero_global(f) @@ -197,11 +201,11 @@ function generate_code_hessian_2_to_1(M::Symbol, f) ) end - function $M.$fname(tx::$SCT.HessianTracer, y::Real) + function $M.$fname(tx::$SCT.HessianTracer, y::$Z) return $SCT.hessian_tracer_1_to_1(tx, $is_der1_arg1_zero_g, $is_der2_arg1_zero_g) end - function $M.$fname(x::Real, ty::$SCT.HessianTracer) + function $M.$fname(x::$Z, ty::$SCT.HessianTracer) return $SCT.hessian_tracer_1_to_1(ty, $is_der1_arg2_zero_g, $is_der2_arg2_zero_g) end end @@ -251,16 +255,16 @@ function generate_code_hessian_2_to_1(M::Symbol, f) end end end - expr_dual_real = if is_der1_arg1_zero_g && is_der2_arg1_zero_g + expr_dual_nondual = if is_der1_arg1_zero_g && is_der2_arg1_zero_g quote - function $M.$fname(dx::D, y::Real) where {P,T<:$SCT.HessianTracer,D<:$SCT.Dual{P,T}} + function $M.$fname(dx::D, y::$Z) where {P,T<:$SCT.HessianTracer,D<:$SCT.Dual{P,T}} x = $SCT.primal(dx) return $M.$fname(x, y) end end else quote - function $M.$fname(dx::D, y::Real) where {P,T<:$SCT.HessianTracer,D<:$SCT.Dual{P,T}} + function $M.$fname(dx::D, y::$Z) where {P,T<:$SCT.HessianTracer,D<:$SCT.Dual{P,T}} x = $SCT.primal(dx) p_out = $M.$fname(x, y) @@ -272,16 +276,16 @@ function generate_code_hessian_2_to_1(M::Symbol, f) end end end - expr_real_dual = if is_der1_arg2_zero_g && is_der2_arg2_zero_g + expr_nondual_dual = if is_der1_arg2_zero_g && is_der2_arg2_zero_g quote - function $M.$fname(x::Real, dy::D) where {P,T<:$SCT.HessianTracer,D<:$SCT.Dual{P,T}} + function $M.$fname(x::$Z, dy::D) where {P,T<:$SCT.HessianTracer,D<:$SCT.Dual{P,T}} y = $SCT.primal(dy) return $M.$fname(x, y) end end else quote - function $M.$fname(x::Real, dy::D) where {P,T<:$SCT.HessianTracer,D<:$SCT.Dual{P,T}} + function $M.$fname(x::$Z, dy::D) where {P,T<:$SCT.HessianTracer,D<:$SCT.Dual{P,T}} y = $SCT.primal(dy) p_out = $M.$fname(x, y) @@ -294,7 +298,9 @@ function generate_code_hessian_2_to_1(M::Symbol, f) end end - return Expr(:block, expr_hessiantracer, expr_dual_dual, expr_dual_real, expr_real_dual) + return Expr( + :block, expr_hessiantracer, expr_dual_dual, expr_dual_nondual, expr_nondual_dual + ) end ## 1-to-2 From 94fecaeed700df4dc1393d85f40ecaa5fd6206cc Mon Sep 17 00:00:00 2001 From: adrhill Date: Tue, 1 Oct 2024 16:21:41 +0200 Subject: [PATCH 2/3] Update utils to be able to pass type --- src/overloads/utils.jl | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/src/overloads/utils.jl b/src/overloads/utils.jl index e9118e27..9d20a71d 100644 --- a/src/overloads/utils.jl +++ b/src/overloads/utils.jl @@ -27,6 +27,25 @@ for d in dims end end +# Overloads of 2-argument functions on arbitrary types +function generate_code_2_to_1(M::Symbol, f, Z::Type) + expr_g = generate_code_gradient_2_to_1(M, f, Z) + expr_h = generate_code_hessian_2_to_1(M, f, Z) + return Expr(:block, expr_g, expr_h) +end +function generate_code_2_to_1(M::Symbol, ops::Union{AbstractVector,Tuple}, Z::Type) + exprs = [generate_code_2_to_1(M, op, Z) for op in ops] + return Expr(:block, exprs...) +end +function generate_code_gradient_2_to_1(M::Symbol, ops::Union{AbstractVector,Tuple}, Z::Type) + exprs = [generate_code_gradient_2_to_1(M, op, Z) for op in ops] + return Expr(:block, exprs...) +end +function generate_code_hessian_2_to_1(M::Symbol, ops::Union{AbstractVector,Tuple}, Z::Type) + exprs = [generate_code_hessian_2_to_1(M, op, Z) for op in ops] + return Expr(:block, exprs...) +end + ## Overload operators eval(generate_code_1_to_1(:Base, ops_1_to_1)) eval(generate_code_2_to_1(:Base, ops_2_to_1)) From eecc5d6a1f8682f2b013e54a04715e5e2181fdba Mon Sep 17 00:00:00 2001 From: adrhill Date: Tue, 1 Oct 2024 16:21:51 +0200 Subject: [PATCH 3/3] Refactor ambiguities --- src/SparseConnectivityTracer.jl | 2 +- src/overloads/ambiguities.jl | 41 ++++----------------------------- src/overloads/dual.jl | 3 --- 3 files changed, 5 insertions(+), 41 deletions(-) diff --git a/src/SparseConnectivityTracer.jl b/src/SparseConnectivityTracer.jl index 221fa1ef..8a44b8ba 100644 --- a/src/SparseConnectivityTracer.jl +++ b/src/SparseConnectivityTracer.jl @@ -27,12 +27,12 @@ include("operators.jl") include("overloads/conversion.jl") include("overloads/gradient_tracer.jl") include("overloads/hessian_tracer.jl") -include("overloads/ambiguities.jl") include("overloads/special_cases.jl") include("overloads/ifelse_global.jl") include("overloads/dual.jl") include("overloads/arrays.jl") include("overloads/utils.jl") +include("overloads/ambiguities.jl") include("trace_functions.jl") include("adtypes_interface.jl") diff --git a/src/overloads/ambiguities.jl b/src/overloads/ambiguities.jl index 23c5cdab..2e0ab3d5 100644 --- a/src/overloads/ambiguities.jl +++ b/src/overloads/ambiguities.jl @@ -1,38 +1,5 @@ ## Special overloads to avoid ambiguity errors -for S in (Integer, Rational, Irrational{:ℯ}) - Base.:^(t::T, ::S) where {T<:GradientTracer} = t - Base.:^(::S, t::T) where {T<:GradientTracer} = t - Base.:^(t::T, ::S) where {T<:HessianTracer} = hessian_tracer_1_to_1(t, false, false) - Base.:^(::S, t::T) where {T<:HessianTracer} = hessian_tracer_1_to_1(t, false, false) - - function Base.:^(d::D, y::S) where {P,T<:GradientTracer,D<:Dual{P,T}} - x = primal(d) - t = gradient_tracer_1_to_1(tracer(d), false) - return Dual(x^y, t) - end - function Base.:^(x::S, d::D) where {P,T<:GradientTracer,D<:Dual{P,T}} - y = primal(d) - t = gradient_tracer_1_to_1(tracer(d), false) - return Dual(x^y, t) - end - - function Base.:^(d::D, y::S) where {P,T<:HessianTracer,D<:Dual{P,T}} - x = primal(d) - t = hessian_tracer_1_to_1(tracer(d), false, false) - return Dual(x^y, t) - end - function Base.:^(x::S, d::D) where {P,T<:HessianTracer,D<:Dual{P,T}} - y = primal(d) - t = hessian_tracer_1_to_1(tracer(d), false, false) - return Dual(x^y, t) - end -end - -for TT in (GradientTracer, HessianTracer) - function Base.isless(dx::D, y::AbstractFloat) where {P<:Real,T<:TT,D<:Dual{P,T}} - return isless(primal(dx), y) - end - function Base.isless(x::AbstractFloat, dy::D) where {P<:Real,T<:TT,D<:Dual{P,T}} - return isless(x, primal(dy)) - end -end +eval(generate_code_2_to_1(:Base, ^, Integer)) +eval(generate_code_2_to_1(:Base, ^, Rational)) +eval(generate_code_2_to_1(:Base, ^, Irrational{:ℯ})) +eval(generate_code_2_to_1(:Base, isless, AbstractFloat)) diff --git a/src/overloads/dual.jl b/src/overloads/dual.jl index 4c504bd5..461e2456 100644 --- a/src/overloads/dual.jl +++ b/src/overloads/dual.jl @@ -18,6 +18,3 @@ for fn in ( throw(MissingPrimalError($fn, t)) end end - -# In some cases, more specialized methods are needed -Base.isless(dx::D, y::AbstractFloat) where {D<:Dual} = isless(primal(dx), y)