Skip to content

Commit

Permalink
Support Hessians on non-tracer outputs (#99)
Browse files Browse the repository at this point in the history
  • Loading branch information
gdalle authored May 23, 2024
1 parent bfc8ee3 commit 5598602
Show file tree
Hide file tree
Showing 4 changed files with 16 additions and 0 deletions.
10 changes: 10 additions & 0 deletions src/pattern.jl
Original file line number Diff line number Diff line change
Expand Up @@ -389,3 +389,13 @@ function hessian_pattern_to_mat(
) where {P1,P2,T<:HessianTracer,D1<:Dual{P1,T},D2<:Dual{P2,T}}
return hessian_pattern_to_mat(tracer.(xt), tracer(yt))
end

function hessian_pattern_to_mat(xt::AbstractArray{T}, yt::Number) where {T<:HessianTracer}
return hessian_pattern_to_mat(xt, empty(T))
end

function hessian_pattern_to_mat(
xt::AbstractArray{D1}, yt::Number
) where {P1,T<:HessianTracer,D1<:Dual{P1,T}}
return hessian_pattern_to_mat(tracer.(xt), empty(T))
end
2 changes: 2 additions & 0 deletions test/test_connectivity.jl
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ NNLIB_ACTIVATIONS = union(NNLIB_ACTIVATIONS_S, NNLIB_ACTIVATIONS_F)
@test connectivity_pattern(x -> x^ℯ, 1, G) [1;;]
@test connectivity_pattern(x ->^x, 1, G) [1;;]
@test connectivity_pattern(x -> round(x, RoundNearestTiesUp), 1, G) [1;;]
@test connectivity_pattern(x -> 0, 1, G) [0;;]

# SpecialFunctions extension
@test connectivity_pattern(x -> erf(x[1]), rand(2), G) == [1 0]
Expand All @@ -87,5 +88,6 @@ end
@test local_connectivity_pattern(
x -> ifelse(x[2] < x[3], x[1] + x[2], x[3] * x[4]), [1 3 2 4], G
) == [0 0 1 1]
@test local_connectivity_pattern(x -> 0, 1, G) [0;;]
end
end
2 changes: 2 additions & 0 deletions test/test_gradient.jl
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ NNLIB_ACTIVATIONS = union(NNLIB_ACTIVATIONS_S, NNLIB_ACTIVATIONS_F)
@test jacobian_sparsity(x -> x^ℯ, 1, method) [1;;]
@test jacobian_sparsity(x ->^x, 1, method) [1;;]
@test jacobian_sparsity(x -> round(x, RoundNearestTiesUp), 1, method) [0;;]
@test jacobian_sparsity(x -> 0, 1, method) [0;;]

# Linear Algebra
@test jacobian_sparsity(x -> dot(x[1:2], x[4:5]), rand(5), method) == [1 1 0 1 1]
Expand Down Expand Up @@ -128,6 +129,7 @@ end
@test jacobian_sparsity(x -> x^ℯ, 1, method) [1;;]
@test jacobian_sparsity(x ->^x, 1, method) [1;;]
@test jacobian_sparsity(x -> round(x, RoundNearestTiesUp), 1, method) [0;;]
@test jacobian_sparsity(x -> 0, 1, method) [0;;]

# Linear algebra
@test jacobian_sparsity(logdet, [1.0 -1.0; 2.0 2.0], method) == [1 1 1 1] # (#68)
Expand Down
2 changes: 2 additions & 0 deletions test/test_hessian.jl
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ const SECOND_ORDER_SET_TYPES = (
@test hessian_sparsity(x -> x^ℯ, 1, method) [1;;]
@test hessian_sparsity(x ->^x, 1, method) [1;;]
@test hessian_sparsity(x -> round(x, RoundNearestTiesUp), 1, method) [0;;]
@test hessian_sparsity(x -> 0, 1, method) [0;;]

h = hessian_sparsity(x -> x[1] / x[2] + x[3] / 1 + 1 / x[4], rand(4), method)
@test h == [
Expand Down Expand Up @@ -211,5 +212,6 @@ end
@test hessian_sparsity(x -> (2//3)^x, 1, method) [1;;]
@test hessian_sparsity(x -> x^ℯ, 1, method) [1;;]
@test hessian_sparsity(x ->^x, 1, method) [1;;]
@test hessian_sparsity(x -> 0, 1, method) [0;;]
end
end

0 comments on commit 5598602

Please sign in to comment.