Skip to content

Commit

Permalink
Allow more generic hessian set types (#103)
Browse files Browse the repository at this point in the history
* Add two variants of `connectivity_tracer_i_to_j`, `gradient_tracer_i_to_j` and `hessian_tracer_i_to_j`: one for `AbstractTracer` and one for `AbstractSet`

* Use `gradient_tracer_i_to_j` inside `hessian_tracer_i_to_j` for separation of concerns

* Allow more generic eltype in the various custom sets, so that they can hold `Tuple{Int,Int}`

* Start implementing clever `union!`, `product` and `union_product!`
  • Loading branch information
gdalle authored May 31, 2024
1 parent d3313e5 commit ee6c15d
Show file tree
Hide file tree
Showing 16 changed files with 307 additions and 177 deletions.
60 changes: 42 additions & 18 deletions src/overload_connectivity.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,18 @@
"""
    connectivity_tracer_1_to_1(t::ConnectivityTracer, is_influence_zero::Bool)

Tracer-level entry point for 1-to-1 operators.

Unwraps the input set of `t` via `inputs`, forwards it to the set-based method
of `connectivity_tracer_1_to_1`, and wraps the resulting set in a new tracer of
the same type `T`.
"""
function connectivity_tracer_1_to_1(
    t::T, is_influence_zero::Bool
) where {T<:ConnectivityTracer}
    # All the set logic lives in the `AbstractSet` method; this only (un)wraps.
    return T(connectivity_tracer_1_to_1(inputs(t), is_influence_zero))
end

"""
    connectivity_tracer_1_to_1(s::AbstractSet{<:Integer}, is_influence_zero::Bool)

Set-level rule for 1-to-1 operators.

Returns `myempty(S)` when `is_influence_zero` is `true`; otherwise returns the
input set `s` itself (no copy is made).
"""
function connectivity_tracer_1_to_1(
    s::S, is_influence_zero::Bool
) where {S<:AbstractSet{<:Integer}}
    # Only `s` and `S` are bound in this method: the leftover tracer-based
    # returns (`myempty(T)` / `t`) referenced undefined names and were removed.
    if is_influence_zero
        return myempty(S)
    else
        return s
    end
end

Expand Down Expand Up @@ -38,18 +46,24 @@ end
"""
    connectivity_tracer_2_to_1(tx::T, ty::T, is_influence_arg1_zero, is_influence_arg2_zero)

Tracer-level entry point for 2-to-1 operators.

Unwraps the input sets of `tx` and `ty` via `inputs`, delegates the combination
to the set-based method of `connectivity_tracer_2_to_1`, and wraps the result in
a tracer of type `T`.
"""
function connectivity_tracer_2_to_1(
    tx::T, ty::T, is_influence_arg1_zero::Bool, is_influence_arg2_zero::Bool
) where {T<:ConnectivityTracer}
    # The former nested-`if` body was superseded by delegation to the set
    # method; keeping both left the function syntactically unbalanced.
    sx, sy = inputs(tx), inputs(ty)
    s_out = connectivity_tracer_2_to_1(
        sx, sy, is_influence_arg1_zero, is_influence_arg2_zero
    )
    return T(s_out)
end

"""
    connectivity_tracer_2_to_1(sx::S, sy::S, is_influence_arg1_zero, is_influence_arg2_zero)

Set-level rule for 2-to-1 operators.

| arg1 zero | arg2 zero | result                 |
|:----------|:----------|:-----------------------|
| `true`    | `true`    | `myempty(S)`           |
| `false`   | `true`    | `sx`                   |
| `true`    | `false`   | `sy`                   |
| `false`   | `false`   | `clever_union(sx, sy)` |

When a single argument has influence, that argument's set is returned as-is.
"""
function connectivity_tracer_2_to_1(
    sx::S, sy::S, is_influence_arg1_zero::Bool, is_influence_arg2_zero::Bool
) where {S<:AbstractSet{<:Integer}}
    if is_influence_arg1_zero
        # First argument has no influence: result depends on the second only.
        return is_influence_arg2_zero ? myempty(S) : sy
    end
    # First argument has influence.
    return is_influence_arg2_zero ? sx : clever_union(sx, sy)
end

Expand Down Expand Up @@ -124,9 +138,19 @@ end
"""
    connectivity_tracer_1_to_2(t::T, is_influence_out1_zero, is_influence_out2_zero)

Tracer-level entry point for 1-to-2 operators.

Unwraps the input set of `t` via `inputs`, delegates to the set-based method of
`connectivity_tracer_1_to_2`, and returns a 2-tuple of tracers of type `T`, one
per output.
"""
function connectivity_tracer_1_to_2(
    t::T, is_influence_out1_zero::Bool, is_influence_out2_zero::Bool
) where {T<:ConnectivityTracer}
    # The earlier tracer-level body returned before this delegation could run;
    # it was dropped so the set-based implementation is actually used.
    s = inputs(t)
    (s_out1, s_out2) = connectivity_tracer_1_to_2(
        s, is_influence_out1_zero, is_influence_out2_zero
    )
    return (T(s_out1), T(s_out2))
end

"""
    connectivity_tracer_1_to_2(s::AbstractSet{<:Integer}, is_influence_out1_zero, is_influence_out2_zero)

Set-level rule for 1-to-2 operators.

Applies the set-based `connectivity_tracer_1_to_1` to `s` once per output flag
and returns the two resulting sets as a tuple `(s1, s2)`.
"""
function connectivity_tracer_1_to_2(
    s::S, is_influence_out1_zero::Bool, is_influence_out2_zero::Bool
) where {S<:AbstractSet{<:Integer}}
    return (
        connectivity_tracer_1_to_1(s, is_influence_out1_zero),
        connectivity_tracer_1_to_1(s, is_influence_out2_zero),
    )
end

function overload_connectivity_1_to_2(M, op)
Expand All @@ -149,7 +173,7 @@ function overload_connectivity_1_to_2_dual(M, op)
x = $SCT.primal(d)
p1_out, p2_out = $M.$op(x)
t1_out, t2_out = $SCT.connectivity_tracer_1_to_2(
t,
$SCT.tracer(d), # TODO: add test, this was buggy
$SCT.is_influence_out1_zero_local($M.$op, x),
$SCT.is_influence_out2_zero_local($M.$op, x),
)
Expand Down
55 changes: 38 additions & 17 deletions src/overload_gradient.jl
Original file line number Diff line number Diff line change
@@ -1,10 +1,18 @@
## 1-to-1

"""
    gradient_tracer_1_to_1(t::GradientTracer, is_firstder_zero::Bool)

Tracer-level entry point for 1-to-1 operators.

Unwraps the gradient set of `t` via `gradient`, forwards it to the set-based
method of `gradient_tracer_1_to_1`, and wraps the resulting set in a new tracer
of the same type `T`.
"""
function gradient_tracer_1_to_1(
    t::T, is_firstder_zero::Bool
) where {T<:GradientTracer}
    # All the set logic lives in the `AbstractSet` method; this only (un)wraps.
    return T(gradient_tracer_1_to_1(gradient(t), is_firstder_zero))
end

"""
    gradient_tracer_1_to_1(s::AbstractSet{<:Integer}, is_firstder_zero::Bool)

Set-level rule for 1-to-1 operators.

Returns `myempty(S)` when `is_firstder_zero` is `true`; otherwise returns the
input set `s` itself (no copy is made).
"""
function gradient_tracer_1_to_1(
    s::S, is_firstder_zero::Bool
) where {S<:AbstractSet{<:Integer}}
    # Only `s` and `S` are bound in this method: the leftover tracer-based
    # returns (`myempty(T)` / `t`) referenced undefined names and were removed.
    if is_firstder_zero
        return myempty(S)
    else
        return s
    end
end

Expand Down Expand Up @@ -36,18 +44,22 @@ end
"""
    gradient_tracer_2_to_1(tx::T, ty::T, is_firstder_arg1_zero, is_firstder_arg2_zero)

Tracer-level entry point for 2-to-1 operators.

Unwraps the gradient sets of `tx` and `ty` via `gradient`, delegates the
combination to the set-based method of `gradient_tracer_2_to_1`, and wraps the
result in a tracer of type `T`.
"""
function gradient_tracer_2_to_1(
    tx::T, ty::T, is_firstder_arg1_zero::Bool, is_firstder_arg2_zero::Bool
) where {T<:GradientTracer}
    # The former nested-`if` body was superseded by delegation to the set
    # method; keeping both left the function syntactically unbalanced.
    sx, sy = gradient(tx), gradient(ty)
    s_out = gradient_tracer_2_to_1(sx, sy, is_firstder_arg1_zero, is_firstder_arg2_zero)
    return T(s_out)
end

"""
    gradient_tracer_2_to_1(sx::S, sy::S, is_firstder_arg1_zero, is_firstder_arg2_zero)

Set-level rule for 2-to-1 operators.

| arg1 zero | arg2 zero | result                 |
|:----------|:----------|:-----------------------|
| `true`    | `true`    | `myempty(S)`           |
| `false`   | `true`    | `sx`                   |
| `true`    | `false`   | `sy`                   |
| `false`   | `false`   | `clever_union(sx, sy)` |

When a single first derivative is non-zero, that argument's set is returned
as-is.
"""
function gradient_tracer_2_to_1(
    sx::S, sy::S, is_firstder_arg1_zero::Bool, is_firstder_arg2_zero::Bool
) where {S<:AbstractSet{<:Integer}}
    if is_firstder_arg1_zero
        # No contribution from the first argument.
        return is_firstder_arg2_zero ? myempty(S) : sy
    end
    # The first argument contributes.
    return is_firstder_arg2_zero ? sx : clever_union(sx, sy)
end

Expand Down Expand Up @@ -76,6 +88,7 @@ function overload_gradient_2_to_1(M, op)
end
end
end

function overload_gradient_2_to_1_dual(M, op)
SCT = SparseConnectivityTracer
return quote
Expand Down Expand Up @@ -117,9 +130,17 @@ end
"""
    gradient_tracer_1_to_2(t::T, is_firstder_out1_zero, is_firstder_out2_zero)

Tracer-level entry point for 1-to-2 operators.

Unwraps the gradient set of `t` via `gradient`, delegates to the set-based
method of `gradient_tracer_1_to_2`, and returns a 2-tuple of tracers of type
`T`, one per output.
"""
function gradient_tracer_1_to_2(
    t::T, is_firstder_out1_zero::Bool, is_firstder_out2_zero::Bool
) where {T<:GradientTracer}
    # The earlier tracer-level body returned before this delegation could run;
    # it was dropped so the set-based implementation is actually used.
    s = gradient(t)
    s_out1, s_out2 = gradient_tracer_1_to_2(s, is_firstder_out1_zero, is_firstder_out2_zero)
    return (T(s_out1), T(s_out2))
end

"""
    gradient_tracer_1_to_2(s::AbstractSet{<:Integer}, is_firstder_out1_zero, is_firstder_out2_zero)

Set-level rule for 1-to-2 operators.

Applies the set-based `gradient_tracer_1_to_1` to `s` once per output flag and
returns the two resulting sets as a tuple `(s_out1, s_out2)`.
"""
function gradient_tracer_1_to_2(
    s::S, is_firstder_out1_zero::Bool, is_firstder_out2_zero::Bool
) where {S<:AbstractSet{<:Integer}}
    return (
        gradient_tracer_1_to_1(s, is_firstder_out1_zero),
        gradient_tracer_1_to_1(s, is_firstder_out2_zero),
    )
end

function overload_gradient_1_to_2(M, op)
Expand Down
138 changes: 93 additions & 45 deletions src/overload_hessian.jl
Original file line number Diff line number Diff line change
@@ -1,21 +1,27 @@
## 1-to-1

"""
    hessian_tracer_1_to_1(t::HessianTracer, is_firstder_zero::Bool, is_secondder_zero::Bool)

Tracer-level entry point for 1-to-1 operators.

Unwraps the gradient and Hessian sets of `t` via `gradient` and `hessian`,
delegates to the set-based method of `hessian_tracer_1_to_1`, and rebuilds a
tracer of type `T` from the returned pair of sets.
"""
function hessian_tracer_1_to_1(
    t::T, is_firstder_zero::Bool, is_secondder_zero::Bool
) where {T<:HessianTracer}
    # A single signature and where-clause: the duplicated older header
    # (`is_seconder_zero`, `HessianTracer{G,H}`) and its partial body were removed.
    sg, sh = gradient(t), hessian(t)
    sg_out, sh_out = hessian_tracer_1_to_1(sg, sh, is_firstder_zero, is_secondder_zero)
    return T(sg_out, sh_out)
end

"""
    hessian_tracer_1_to_1(sg, sh, is_firstder_zero::Bool, is_secondder_zero::Bool)

Set-level rule for 1-to-1 operators, given a gradient set `sg` and a Hessian
set `sh` (a set of index pairs).

Returns the tuple `(sg_out, sh_out)` where

- `sg_out = gradient_tracer_1_to_1(sg, is_firstder_zero)`;
- `sh_out` is selected from the two flags:

| first-der zero | second-der zero | `sh_out`                    |
|:---------------|:----------------|:----------------------------|
| `true`         | `true`          | `myempty(SH)`               |
| `false`        | `true`          | `sh`                        |
| `true`         | `false`         | `product(sg, sg)`           |
| `false`        | `false`         | `union_product(sh, sg, sg)` |
"""
function hessian_tracer_1_to_1(
    sg::SG, sh::SH, is_firstder_zero::Bool, is_secondder_zero::Bool
) where {I,SG<:AbstractSet{<:I},SH<:AbstractSet{<:Tuple{I,I}}}
    # NOTE(review): the 2-to-1 and 1-to-2 set methods constrain `SG` as
    # `AbstractSet{I}`; confirm the looser `AbstractSet{<:I}` here is intended.
    sg_out = gradient_tracer_1_to_1(sg, is_firstder_zero)
    # Only set-level names are bound in this method: the leftover tracer-based
    # branch that built `T(...)` from `t` referenced undefined names and was removed.
    sh_out = if is_firstder_zero && is_secondder_zero
        myempty(SH)
    elseif !is_firstder_zero && is_secondder_zero
        sh
    elseif is_firstder_zero && !is_secondder_zero
        product(sg, sg)
    else
        union_product(sh, sg, sg)
    end
    return (sg_out, sh_out)
end

function overload_hessian_1_to_1(M, op)
Expand Down Expand Up @@ -50,34 +56,50 @@ end
## 2-to-1

"""
    hessian_tracer_2_to_1(
        tx::T, ty::T,
        is_firstder_arg1_zero, is_secondder_arg1_zero,
        is_firstder_arg2_zero, is_secondder_arg2_zero,
        is_crossder_zero,
    )

Tracer-level entry point for 2-to-1 operators.

Unwraps the gradient and Hessian sets of `tx` and `ty` via `gradient` and
`hessian`, delegates to the set-based method of `hessian_tracer_2_to_1` with
the same flags in the same order, and rebuilds a tracer of type `T` from the
returned pair of sets.
"""
function hessian_tracer_2_to_1(
    tx::T,
    ty::T,
    is_firstder_arg1_zero::Bool,
    is_secondder_arg1_zero::Bool,
    is_firstder_arg2_zero::Bool,
    is_secondder_arg2_zero::Bool,
    is_crossder_zero::Bool,
) where {T<:HessianTracer}
    # One parameter list and one where-clause: the duplicated older parameters
    # (`a`, `b`, `is_seconder_*`) and the in-place accumulation body were removed.
    sgx, shx = gradient(tx), hessian(tx)
    sgy, shy = gradient(ty), hessian(ty)
    sg_out, sh_out = hessian_tracer_2_to_1(
        sgx,
        shx,
        sgy,
        shy,
        is_firstder_arg1_zero,
        is_secondder_arg1_zero,
        is_firstder_arg2_zero,
        is_secondder_arg2_zero,
        is_crossder_zero,
    )
    return T(sg_out, sh_out)
end

"""
    hessian_tracer_2_to_1(
        sgx, shx, sgy, shy,
        is_firstder_arg1_zero, is_secondder_arg1_zero,
        is_firstder_arg2_zero, is_secondder_arg2_zero,
        is_crossder_zero,
    )

Set-level rule for 2-to-1 operators, given the gradient sets (`sgx`, `sgy`) and
Hessian sets (`shx`, `shy`) of both arguments.

Returns the tuple `(sg_out, sh_out)`:

- `sg_out` comes from `gradient_tracer_2_to_1` applied to the gradient sets and
  the two first-derivative flags;
- `sh_out` starts as `myempty(SH)` and is extended, in order, by `shx`, `shy`,
  the product `sgx`/`sgx`, the product `sgy`/`sgy`, and both cross products,
  each only when the corresponding flag is `false`.
"""
function hessian_tracer_2_to_1(
    sgx::SG,
    shx::SH,
    sgy::SG,
    shy::SH,
    is_firstder_arg1_zero::Bool,
    is_secondder_arg1_zero::Bool,
    is_firstder_arg2_zero::Bool,
    is_secondder_arg2_zero::Bool,
    is_crossder_zero::Bool,
) where {I,SG<:AbstractSet{I},SH<:AbstractSet{<:Tuple{I,I}}}
    sg_out = gradient_tracer_2_to_1(
        sgx, sgy, is_firstder_arg1_zero, is_firstder_arg2_zero
    )

    # The return values of `union!` / `union_product!` are discarded below, so
    # this relies on them updating `sh_out` in place.
    sh_out = myempty(SH)
    if !is_firstder_arg1_zero
        union!(sh_out, shx) # Hessian of the first argument
    end
    if !is_firstder_arg2_zero
        union!(sh_out, shy) # Hessian of the second argument
    end
    if !is_secondder_arg1_zero
        union_product!(sh_out, sgx, sgx) # product for the first argument
    end
    if !is_secondder_arg2_zero
        union_product!(sh_out, sgy, sgy) # product for the second argument
    end
    if !is_crossder_zero
        # Both orderings of the cross term are added.
        union_product!(sh_out, sgx, sgy)
        union_product!(sh_out, sgy, sgx)
    end
    return (sg_out, sh_out)
end

function overload_hessian_2_to_1(M, op)
Expand Down Expand Up @@ -165,9 +187,33 @@ function hessian_tracer_1_to_2(
is_firstder_out2_zero::Bool,
is_seconder_out2_zero::Bool,
) where {T<:HessianTracer}
t1 = hessian_tracer_1_to_1(t, is_firstder_out1_zero, is_seconder_out1_zero)
t2 = hessian_tracer_1_to_1(t, is_firstder_out2_zero, is_seconder_out2_zero)
return (t1, t2)
sg, sh = gradient(t), hessian(t)
(sg_out1, sh_out1), (sg_out2, sh_out2) = hessian_tracer_1_to_2(
sg,
sh,
is_firstder_out1_zero,
is_seconder_out1_zero,
is_firstder_out2_zero,
is_seconder_out2_zero,
)
return (T(sg_out1, sh_out1), T(sg_out2, sh_out2))
end

"""
    hessian_tracer_1_to_2(
        sg, sh,
        is_firstder_out1_zero, is_secondder_out1_zero,
        is_firstder_out2_zero, is_secondder_out2_zero,
    )

Set-level rule for 1-to-2 operators.

Applies the set-based `hessian_tracer_1_to_1` to `(sg, sh)` once per output,
using that output's pair of flags, and returns
`((sg_out1, sh_out1), (sg_out2, sh_out2))`.
"""
function hessian_tracer_1_to_2(
    sg::SG,
    sh::SH,
    is_firstder_out1_zero::Bool,
    is_secondder_out1_zero::Bool,
    is_firstder_out2_zero::Bool,
    is_secondder_out2_zero::Bool,
) where {I,SG<:AbstractSet{I},SH<:AbstractSet{<:Tuple{I,I}}}
    # Each call yields a `(gradient set, Hessian set)` pair for one output.
    (g1, h1) = hessian_tracer_1_to_1(sg, sh, is_firstder_out1_zero, is_secondder_out1_zero)
    (g2, h2) = hessian_tracer_1_to_1(sg, sh, is_firstder_out2_zero, is_secondder_out2_zero)
    return ((g1, h1), (g2, h2))
end

function overload_hessian_1_to_2(M, op)
Expand Down Expand Up @@ -208,20 +254,22 @@ end
## Exponent (requires extra types)
# Define `^` for every exponent/base type `S` in the tuple below, for both plain
# `HessianTracer`s and `Dual`s wrapping one. In every method the resulting tracer
# keeps the gradient set of its traced argument and gets
# `union_product(hessian, gradient, gradient)` as its Hessian set
# (presumably `hessian ∪ (gradient × gradient)` — confirm against `union_product`).
for S in (Integer, Rational, Irrational{:ℯ})
    # tracer ^ constant
    function Base.:^(tx::T, y::S) where {T<:HessianTracer}
        return T(gradient(tx), union_product(hessian(tx), gradient(tx), gradient(tx)))
    end
    # constant ^ tracer
    function Base.:^(x::S, ty::T) where {T<:HessianTracer}
        return T(gradient(ty), union_product(hessian(ty), gradient(ty), gradient(ty)))
    end

    # dual ^ constant: the primal value is computed with the real `^`
    function Base.:^(dx::D, y::S) where {P,T<:HessianTracer,D<:Dual{P,T}}
        return Dual(
            primal(dx)^y,
            T(gradient(dx), union_product(hessian(dx), gradient(dx), gradient(dx))),
        )
    end
    # constant ^ dual
    function Base.:^(x::S, dy::D) where {P,T<:HessianTracer,D<:Dual{P,T}}
        return Dual(
            x^primal(dy),
            T(gradient(dy), union_product(hessian(dy), gradient(dy), gradient(dy))),
        )
    end
end
Expand Down
10 changes: 7 additions & 3 deletions src/overload_ifelse.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,14 @@
end

## output union on scalar outputs
# Union of two `ConnectivityTracer`s of the same type: a new tracer holding the
# union of both input sets.
function output_union(tx::C, ty::C) where {C<:ConnectivityTracer}
    return C(union(inputs(tx), inputs(ty)))
end

# Union of two `GradientTracer`s of the same type: a new tracer holding the
# union of both gradient sets.
function output_union(tx::G, ty::G) where {G<:GradientTracer}
    return G(union(gradient(tx), gradient(ty)))
end
# Union of two `HessianTracer`s of the same type: gradient sets and Hessian sets
# are unioned separately and wrapped in a new tracer.
function output_union(tx::H, ty::H) where {H<:HessianTracer}
    return H(union(gradient(tx), gradient(ty)), union(hessian(tx), hessian(ty)))
end

# Fallback when only the first argument is a tracer: `y` is ignored and `tx` is
# returned unchanged.
function output_union(tx::AbstractTracer, y)
    return tx
end
Expand Down
Loading

0 comments on commit ee6c15d

Please sign in to comment.