Skip to content

Commit

Permalink
Avoid overwriting methods in 2-to-1 functions
Browse files Browse the repository at this point in the history
  • Loading branch information
adrhill committed Oct 1, 2024
1 parent 9ce962f commit ce26ffd
Show file tree
Hide file tree
Showing 4 changed files with 74 additions and 60 deletions.
8 changes: 4 additions & 4 deletions src/overloads/ambiguities.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
## Special overloads to avoid ambiguity errors
eval(generate_code_2_to_1(:Base, ^, Integer))
eval(generate_code_2_to_1(:Base, ^, Rational))
eval(generate_code_2_to_1(:Base, ^, Irrational{:ℯ}))
eval(generate_code_2_to_1(:Base, isless, AbstractFloat))
eval(generate_code_2_to_1_typed(:Base, ^, Integer))
eval(generate_code_2_to_1_typed(:Base, ^, Rational))
eval(generate_code_2_to_1_typed(:Base, ^, Irrational{:ℯ}))
eval(generate_code_2_to_1_typed(:Base, isless, AbstractFloat))
56 changes: 32 additions & 24 deletions src/overloads/gradient_tracer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ function gradient_tracer_1_to_1_inner(
end
end

function generate_code_gradient_1_to_1(M::Symbol, f)
function generate_code_gradient_1_to_1(M::Symbol, f::Function)
fname = nameof(f)
is_der1_zero_g = is_der1_zero_global(f)

Expand Down Expand Up @@ -109,33 +109,19 @@ function gradient_tracer_2_to_1_inner(
end
end

function generate_code_gradient_2_to_1(
M::Symbol, # Symbol indicating Module of f, usually `:Base`
f::Function, # function to overload
Z::Type=Real, # external non-tracer-type to overload on
)
function generate_code_gradient_2_to_1(M::Symbol, f::Function)
fname = nameof(f)
is_der1_arg1_zero_g = is_der1_arg1_zero_global(f)
is_der1_arg2_zero_g = is_der1_arg2_zero_global(f)

## GradientTracer
expr_gradienttracer = quote
expr_tracer_tracer = quote
function $M.$fname(tx::T, ty::T) where {T<:$SCT.GradientTracer}
return $SCT.gradient_tracer_2_to_1(
tx, ty, $is_der1_arg1_zero_g, $is_der1_arg2_zero_g
)
end

function $M.$fname(tx::$SCT.GradientTracer, ::$Z)
return $SCT.gradient_tracer_1_to_1(tx, $is_der1_arg1_zero_g)
end

function $M.$fname(::$Z, ty::$SCT.GradientTracer)
return $SCT.gradient_tracer_1_to_1(ty, $is_der1_arg2_zero_g)
end
end

## Dual
expr_dual_dual = if is_der1_arg1_zero_g && is_der1_arg2_zero_g
quote
function $M.$fname(dx::D, dy::D) where {P,T<:$SCT.GradientTracer,D<:$SCT.Dual{P,T}}
Expand All @@ -162,7 +148,32 @@ function generate_code_gradient_2_to_1(
end
end
end
expr_dual_nondual = if is_der1_arg1_zero_g

exprs_typed = generate_code_gradient_2_to_1_typed(M, f, Real)
return Expr(:block, expr_tracer_tracer, expr_dual_dual, exprs_typed)
end

function generate_code_gradient_2_to_1_typed(
M::Symbol, # Symbol indicating Module of f, usually `:Base`
f::Function, # function to overload
Z::Type, # external non-tracer-type to overload on
)
fname = nameof(f)
is_der1_arg1_zero_g = is_der1_arg1_zero_global(f)
is_der1_arg2_zero_g = is_der1_arg2_zero_global(f)

expr_tracer_type = quote
function $M.$fname(tx::$SCT.GradientTracer, ::$Z)
return $SCT.gradient_tracer_1_to_1(tx, $is_der1_arg1_zero_g)
end
end
expr_type_tracer = quote
function $M.$fname(::$Z, ty::$SCT.GradientTracer)
return $SCT.gradient_tracer_1_to_1(ty, $is_der1_arg2_zero_g)
end
end

expr_dual_type = if is_der1_arg1_zero_g
quote
function $M.$fname(dx::D, y::$Z) where {P,T<:$SCT.GradientTracer,D<:$SCT.Dual{P,T}}
x = $SCT.primal(dx)
Expand All @@ -182,7 +193,7 @@ function generate_code_gradient_2_to_1(
end
end
end
expr_nondual_dual = if is_der1_arg2_zero_g
expr_type_dual = if is_der1_arg2_zero_g
quote
function $M.$fname(x::$Z, dy::D) where {P,T<:$SCT.GradientTracer,D<:$SCT.Dual{P,T}}
y = $SCT.primal(dy)
Expand All @@ -202,10 +213,7 @@ function generate_code_gradient_2_to_1(
end
end
end

return Expr(
:block, expr_gradienttracer, expr_dual_dual, expr_dual_nondual, expr_nondual_dual
)
return Expr(:block, expr_tracer_type, expr_type_tracer, expr_dual_type, expr_type_dual)
end

## 1-to-2
Expand All @@ -222,7 +230,7 @@ end
end
end

function generate_code_gradient_1_to_2(M::Symbol, f)
function generate_code_gradient_1_to_2(M::Symbol, f::Function)
fname = nameof(f)
is_der1_out1_zero_g = is_der1_out1_zero_global(f)
is_der1_out2_zero_g = is_der1_out2_zero_global(f)
Expand Down
52 changes: 33 additions & 19 deletions src/overloads/hessian_tracer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ function hessian_tracer_1_to_1_inner(
return P(g_out, h_out) # return pattern
end

function generate_code_hessian_1_to_1(M::Symbol, f)
function generate_code_hessian_1_to_1(M::Symbol, f::Function)
fname = nameof(f)
is_der1_zero_g = is_der1_zero_global(f)
is_der2_zero_g = is_der2_zero_global(f)
Expand Down Expand Up @@ -187,8 +187,7 @@ function generate_code_hessian_2_to_1(
is_der2_arg2_zero_g = is_der2_arg2_zero_global(f)
is_der_cross_zero_g = is_der_cross_zero_global(f)

## HessianTracer
expr_hessiantracer = quote
expr_tracer_tracer = quote
function $M.$fname(tx::T, ty::T) where {T<:$SCT.HessianTracer}
return $SCT.hessian_tracer_2_to_1(
tx,
Expand All @@ -200,17 +199,8 @@ function generate_code_hessian_2_to_1(
$is_der_cross_zero_g,
)
end

function $M.$fname(tx::$SCT.HessianTracer, y::$Z)
return $SCT.hessian_tracer_1_to_1(tx, $is_der1_arg1_zero_g, $is_der2_arg1_zero_g)
end

function $M.$fname(x::$Z, ty::$SCT.HessianTracer)
return $SCT.hessian_tracer_1_to_1(ty, $is_der1_arg2_zero_g, $is_der2_arg2_zero_g)
end
end

## Dual
expr_dual_dual =
if is_der1_arg1_zero_g &&
is_der2_arg1_zero_g &&
Expand Down Expand Up @@ -255,7 +245,34 @@ function generate_code_hessian_2_to_1(
end
end
end
expr_dual_nondual = if is_der1_arg1_zero_g && is_der2_arg1_zero_g

exprs_typed = generate_code_hessian_2_to_1_typed(M, f, Real)
return Expr(:block, expr_tracer_tracer, expr_dual_dual, exprs_typed)
end

function generate_code_hessian_2_to_1_typed(
M::Symbol, # Symbol indicating Module of f, usually `:Base`
f::Function, # function to overload
Z::Type, # external non-tracer-type to overload on
)
fname = nameof(f)
is_der1_arg1_zero_g = is_der1_arg1_zero_global(f)
is_der2_arg1_zero_g = is_der2_arg1_zero_global(f)
is_der1_arg2_zero_g = is_der1_arg2_zero_global(f)
is_der2_arg2_zero_g = is_der2_arg2_zero_global(f)

expr_tracer_type = quote
function $M.$fname(tx::$SCT.HessianTracer, y::$Z)
return $SCT.hessian_tracer_1_to_1(tx, $is_der1_arg1_zero_g, $is_der2_arg1_zero_g)
end
end
expr_type_tracer = quote
function $M.$fname(x::$Z, ty::$SCT.HessianTracer)
return $SCT.hessian_tracer_1_to_1(ty, $is_der1_arg2_zero_g, $is_der2_arg2_zero_g)
end
end

expr_dual_type = if is_der1_arg1_zero_g && is_der2_arg1_zero_g
quote
function $M.$fname(dx::D, y::$Z) where {P,T<:$SCT.HessianTracer,D<:$SCT.Dual{P,T}}
x = $SCT.primal(dx)
Expand All @@ -276,7 +293,7 @@ function generate_code_hessian_2_to_1(
end
end
end
expr_nondual_dual = if is_der1_arg2_zero_g && is_der2_arg2_zero_g
expr_type_dual = if is_der1_arg2_zero_g && is_der2_arg2_zero_g
quote
function $M.$fname(x::$Z, dy::D) where {P,T<:$SCT.HessianTracer,D<:$SCT.Dual{P,T}}
y = $SCT.primal(dy)
Expand All @@ -297,10 +314,7 @@ function generate_code_hessian_2_to_1(
end
end
end

return Expr(
:block, expr_hessiantracer, expr_dual_dual, expr_dual_nondual, expr_nondual_dual
)
return Expr(:block, expr_tracer_type, expr_type_tracer, expr_dual_type, expr_type_dual)
end

## 1-to-2
Expand All @@ -321,7 +335,7 @@ end
end
end

function generate_code_hessian_1_to_2(M::Symbol, f)
function generate_code_hessian_1_to_2(M::Symbol, f::Function)
fname = nameof(f)
is_der1_out1_zero_g = is_der1_out1_zero_global(f)
is_der2_out1_zero_g = is_der2_out1_zero_global(f)
Expand Down
18 changes: 5 additions & 13 deletions src/overloads/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,21 +28,13 @@ for d in dims
end

# Overloads of 2-argument functions on arbitrary types
function generate_code_2_to_1(M::Symbol, f, Z::Type)
expr_g = generate_code_gradient_2_to_1(M, f, Z)
expr_h = generate_code_hessian_2_to_1(M, f, Z)
function generate_code_2_to_1_typed(M::Symbol, f, Z::Type)
expr_g = generate_code_gradient_2_to_1_typed(M, f, Z)
expr_h = generate_code_hessian_2_to_1_typed(M, f, Z)
return Expr(:block, expr_g, expr_h)
end
function generate_code_2_to_1(M::Symbol, ops::Union{AbstractVector,Tuple}, Z::Type)
exprs = [generate_code_2_to_1(M, op, Z) for op in ops]
return Expr(:block, exprs...)
end
function generate_code_gradient_2_to_1(M::Symbol, ops::Union{AbstractVector,Tuple}, Z::Type)
exprs = [generate_code_gradient_2_to_1(M, op, Z) for op in ops]
return Expr(:block, exprs...)
end
function generate_code_hessian_2_to_1(M::Symbol, ops::Union{AbstractVector,Tuple}, Z::Type)
exprs = [generate_code_hessian_2_to_1(M, op, Z) for op in ops]
function generate_code_2_to_1_typed(M::Symbol, ops::Union{AbstractVector,Tuple}, Z::Type)
exprs = [generate_code_2_to_1_typed(M, op, Z) for op in ops]
return Expr(:block, exprs...)
end

Expand Down

0 comments on commit ce26ffd

Please sign in to comment.