Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Generate 2-to-1 overloads on arbitrary types #197

Merged
merged 4 commits into from
Oct 1, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/SparseConnectivityTracer.jl
Original file line number Diff line number Diff line change
@@ -27,12 +27,12 @@ include("operators.jl")
include("overloads/conversion.jl")
include("overloads/gradient_tracer.jl")
include("overloads/hessian_tracer.jl")
include("overloads/ambiguities.jl")
include("overloads/special_cases.jl")
include("overloads/ifelse_global.jl")
include("overloads/dual.jl")
include("overloads/arrays.jl")
include("overloads/utils.jl")
include("overloads/ambiguities.jl")

include("trace_functions.jl")
include("adtypes_interface.jl")
41 changes: 4 additions & 37 deletions src/overloads/ambiguities.jl
Original file line number Diff line number Diff line change
@@ -1,38 +1,5 @@
## Special overloads to avoid ambiguity errors
for S in (Integer, Rational, Irrational{:ℯ})
Base.:^(t::T, ::S) where {T<:GradientTracer} = t
Base.:^(::S, t::T) where {T<:GradientTracer} = t
Base.:^(t::T, ::S) where {T<:HessianTracer} = hessian_tracer_1_to_1(t, false, false)
Base.:^(::S, t::T) where {T<:HessianTracer} = hessian_tracer_1_to_1(t, false, false)

function Base.:^(d::D, y::S) where {P,T<:GradientTracer,D<:Dual{P,T}}
x = primal(d)
t = gradient_tracer_1_to_1(tracer(d), false)
return Dual(x^y, t)
end
function Base.:^(x::S, d::D) where {P,T<:GradientTracer,D<:Dual{P,T}}
y = primal(d)
t = gradient_tracer_1_to_1(tracer(d), false)
return Dual(x^y, t)
end

function Base.:^(d::D, y::S) where {P,T<:HessianTracer,D<:Dual{P,T}}
x = primal(d)
t = hessian_tracer_1_to_1(tracer(d), false, false)
return Dual(x^y, t)
end
function Base.:^(x::S, d::D) where {P,T<:HessianTracer,D<:Dual{P,T}}
y = primal(d)
t = hessian_tracer_1_to_1(tracer(d), false, false)
return Dual(x^y, t)
end
end

for TT in (GradientTracer, HessianTracer)
function Base.isless(dx::D, y::AbstractFloat) where {P<:Real,T<:TT,D<:Dual{P,T}}
return isless(primal(dx), y)
end
function Base.isless(x::AbstractFloat, dy::D) where {P<:Real,T<:TT,D<:Dual{P,T}}
return isless(x, primal(dy))
end
end
eval(generate_code_2_to_1(:Base, ^, Integer))
eval(generate_code_2_to_1(:Base, ^, Rational))
eval(generate_code_2_to_1(:Base, ^, Irrational{:ℯ}))
eval(generate_code_2_to_1(:Base, isless, AbstractFloat))
3 changes: 0 additions & 3 deletions src/overloads/dual.jl
Original file line number Diff line number Diff line change
@@ -18,6 +18,3 @@ for fn in (
throw(MissingPrimalError($fn, t))
end
end

# In some cases, more specialized methods are needed
Base.isless(dx::D, y::AbstractFloat) where {D<:Dual} = isless(primal(dx), y)
34 changes: 16 additions & 18 deletions src/overloads/gradient_tracer.jl
Original file line number Diff line number Diff line change
@@ -109,7 +109,11 @@ function gradient_tracer_2_to_1_inner(
end
end

function generate_code_gradient_2_to_1(M::Symbol, f)
function generate_code_gradient_2_to_1(
M::Symbol, # Symbol indicating Module of f, usually `:Base`
f::Function, # function to overload
Z::Type=Real, # external non-tracer-type to overload on
)
fname = nameof(f)
is_der1_arg1_zero_g = is_der1_arg1_zero_global(f)
is_der1_arg2_zero_g = is_der1_arg2_zero_global(f)
@@ -122,11 +126,11 @@ function generate_code_gradient_2_to_1(M::Symbol, f)
)
end

function $M.$fname(tx::$SCT.GradientTracer, ::Real)
function $M.$fname(tx::$SCT.GradientTracer, ::$Z)
return $SCT.gradient_tracer_1_to_1(tx, $is_der1_arg1_zero_g)
end

function $M.$fname(::Real, ty::$SCT.GradientTracer)
function $M.$fname(::$Z, ty::$SCT.GradientTracer)
return $SCT.gradient_tracer_1_to_1(ty, $is_der1_arg2_zero_g)
end
end
@@ -158,20 +162,16 @@ function generate_code_gradient_2_to_1(M::Symbol, f)
end
end
end
expr_dual_real = if is_der1_arg1_zero_g
expr_dual_nondual = if is_der1_arg1_zero_g
quote
function $M.$fname(
dx::D, y::Real
) where {P,T<:$SCT.GradientTracer,D<:$SCT.Dual{P,T}}
function $M.$fname(dx::D, y::$Z) where {P,T<:$SCT.GradientTracer,D<:$SCT.Dual{P,T}}
x = $SCT.primal(dx)
return $M.$fname(x, y)
end
end
else
quote
function $M.$fname(
dx::D, y::Real
) where {P,T<:$SCT.GradientTracer,D<:$SCT.Dual{P,T}}
function $M.$fname(dx::D, y::$Z) where {P,T<:$SCT.GradientTracer,D<:$SCT.Dual{P,T}}
x = $SCT.primal(dx)
p_out = $M.$fname(x, y)

@@ -182,20 +182,16 @@ function generate_code_gradient_2_to_1(M::Symbol, f)
end
end
end
expr_real_dual = if is_der1_arg2_zero_g
expr_nondual_dual = if is_der1_arg2_zero_g
quote
function $M.$fname(
x::Real, dy::D
) where {P,T<:$SCT.GradientTracer,D<:$SCT.Dual{P,T}}
function $M.$fname(x::$Z, dy::D) where {P,T<:$SCT.GradientTracer,D<:$SCT.Dual{P,T}}
y = $SCT.primal(dy)
return $M.$fname(x, y)
end
end
else
quote
function $M.$fname(
x::Real, dy::D
) where {P,T<:$SCT.GradientTracer,D<:$SCT.Dual{P,T}}
function $M.$fname(x::$Z, dy::D) where {P,T<:$SCT.GradientTracer,D<:$SCT.Dual{P,T}}
y = $SCT.primal(dy)
p_out = $M.$fname(x, y)

@@ -207,7 +203,9 @@ function generate_code_gradient_2_to_1(M::Symbol, f)
end
end

return Expr(:block, expr_gradienttracer, expr_dual_dual, expr_dual_real, expr_real_dual)
return Expr(
:block, expr_gradienttracer, expr_dual_dual, expr_dual_nondual, expr_nondual_dual
)
end

## 1-to-2
26 changes: 16 additions & 10 deletions src/overloads/hessian_tracer.jl
Original file line number Diff line number Diff line change
@@ -175,7 +175,11 @@ function hessian_tracer_2_to_1_inner(
return P(g_out, h_out) # return pattern
end

function generate_code_hessian_2_to_1(M::Symbol, f)
function generate_code_hessian_2_to_1(
M::Symbol, # Symbol indicating Module of f, usually `:Base`
f::Function, # function to overload
Z::Type=Real, # external non-tracer-type to overload on
)
fname = nameof(f)
is_der1_arg1_zero_g = is_der1_arg1_zero_global(f)
is_der2_arg1_zero_g = is_der2_arg1_zero_global(f)
@@ -197,11 +201,11 @@ function generate_code_hessian_2_to_1(M::Symbol, f)
)
end

function $M.$fname(tx::$SCT.HessianTracer, y::Real)
function $M.$fname(tx::$SCT.HessianTracer, y::$Z)
return $SCT.hessian_tracer_1_to_1(tx, $is_der1_arg1_zero_g, $is_der2_arg1_zero_g)
end

function $M.$fname(x::Real, ty::$SCT.HessianTracer)
function $M.$fname(x::$Z, ty::$SCT.HessianTracer)
return $SCT.hessian_tracer_1_to_1(ty, $is_der1_arg2_zero_g, $is_der2_arg2_zero_g)
end
end
@@ -251,16 +255,16 @@ function generate_code_hessian_2_to_1(M::Symbol, f)
end
end
end
expr_dual_real = if is_der1_arg1_zero_g && is_der2_arg1_zero_g
expr_dual_nondual = if is_der1_arg1_zero_g && is_der2_arg1_zero_g
quote
function $M.$fname(dx::D, y::Real) where {P,T<:$SCT.HessianTracer,D<:$SCT.Dual{P,T}}
function $M.$fname(dx::D, y::$Z) where {P,T<:$SCT.HessianTracer,D<:$SCT.Dual{P,T}}
x = $SCT.primal(dx)
return $M.$fname(x, y)
end
end
else
quote
function $M.$fname(dx::D, y::Real) where {P,T<:$SCT.HessianTracer,D<:$SCT.Dual{P,T}}
function $M.$fname(dx::D, y::$Z) where {P,T<:$SCT.HessianTracer,D<:$SCT.Dual{P,T}}
x = $SCT.primal(dx)
p_out = $M.$fname(x, y)

@@ -272,16 +276,16 @@ function generate_code_hessian_2_to_1(M::Symbol, f)
end
end
end
expr_real_dual = if is_der1_arg2_zero_g && is_der2_arg2_zero_g
expr_nondual_dual = if is_der1_arg2_zero_g && is_der2_arg2_zero_g
quote
function $M.$fname(x::Real, dy::D) where {P,T<:$SCT.HessianTracer,D<:$SCT.Dual{P,T}}
function $M.$fname(x::$Z, dy::D) where {P,T<:$SCT.HessianTracer,D<:$SCT.Dual{P,T}}
y = $SCT.primal(dy)
return $M.$fname(x, y)
end
end
else
quote
function $M.$fname(x::Real, dy::D) where {P,T<:$SCT.HessianTracer,D<:$SCT.Dual{P,T}}
function $M.$fname(x::$Z, dy::D) where {P,T<:$SCT.HessianTracer,D<:$SCT.Dual{P,T}}
y = $SCT.primal(dy)
p_out = $M.$fname(x, y)

@@ -294,7 +298,9 @@ function generate_code_hessian_2_to_1(M::Symbol, f)
end
end

return Expr(:block, expr_hessiantracer, expr_dual_dual, expr_dual_real, expr_real_dual)
return Expr(
:block, expr_hessiantracer, expr_dual_dual, expr_dual_nondual, expr_nondual_dual
)
end

## 1-to-2
19 changes: 19 additions & 0 deletions src/overloads/utils.jl
Original file line number Diff line number Diff line change
@@ -27,6 +27,25 @@ for d in dims
end
end

# Overloads of 2-argument functions on arbitrary types
function generate_code_2_to_1(M::Symbol, f, Z::Type)
expr_g = generate_code_gradient_2_to_1(M, f, Z)
expr_h = generate_code_hessian_2_to_1(M, f, Z)
return Expr(:block, expr_g, expr_h)
end
function generate_code_2_to_1(M::Symbol, ops::Union{AbstractVector,Tuple}, Z::Type)
exprs = [generate_code_2_to_1(M, op, Z) for op in ops]
return Expr(:block, exprs...)
end
function generate_code_gradient_2_to_1(M::Symbol, ops::Union{AbstractVector,Tuple}, Z::Type)
exprs = [generate_code_gradient_2_to_1(M, op, Z) for op in ops]
return Expr(:block, exprs...)
end
function generate_code_hessian_2_to_1(M::Symbol, ops::Union{AbstractVector,Tuple}, Z::Type)
exprs = [generate_code_hessian_2_to_1(M, op, Z) for op in ops]
return Expr(:block, exprs...)
end

## Overload operators
eval(generate_code_1_to_1(:Base, ops_1_to_1))
eval(generate_code_2_to_1(:Base, ops_2_to_1))