diff --git a/README.md b/README.md index 8a0117a1..439d00d6 100644 --- a/README.md +++ b/README.md @@ -6,17 +6,20 @@ [![Coverage](https://codecov.io/gh/adrhill/SparseConnectivityTracer.jl/branch/main/graph/badge.svg)](https://codecov.io/gh/adrhill/SparseConnectivityTracer.jl) [![Aqua](https://raw.githubusercontent.com/JuliaTesting/Aqua.jl/master/badge.svg)](https://github.com/JuliaTesting/Aqua.jl) -Fast sparsity detection via operator-overloading. - -Will soon include Hessian sparsity detection ([#20](https://github.com/adrhill/SparseConnectivityTracer.jl/issues/20)). +Fast Jacobian and Hessian sparsity detection via operator-overloading. ## Installation To install this package, open the Julia REPL and run + ```julia-repl julia> ]add SparseConnectivityTracer ``` ## Examples +### Jacobian + +For functions `y = f(x)` and `f!(y, x)`, the sparsity pattern of the Jacobian of $f$ can be obtained +by computing a single forward-pass through `f`: ```julia-repl julia> using SparseConnectivityTracer @@ -63,10 +66,37 @@ julia> pattern(layer, JacobianTracer, x) ⎣⠀⠀⠀⠙⢷⣄⠀⠀⠈⠻⣦⠀⠀⠀⠙⢦⡀⎦ ``` -SparseConnectivityTracer enumerates inputs `x` and primal outputs `y = f(x)` and returns a sparse matrix `C` of size $m \times n$, where `C[i, j]` is `true` if the compute graph connects the $j$-th entry in `x` to the $i$-th entry in `y`. +### Hessian + +For scalar functions `y = f(x)`, the sparsity pattern of the Hessian of $f$ can be obtained +by computing a single forward-pass through `f`: + +```julia-repl +julia> x = rand(5); + +julia> f(x) = x[1] + x[2]*x[3] + 1/x[4] + 1*x[5]; + +julia> pattern(f, HessianTracer, x) +5×5 SparseArrays.SparseMatrixCSC{Bool, UInt64} with 3 stored entries: + ⋅ ⋅ ⋅ ⋅ ⋅ + ⋅ ⋅ 1 ⋅ ⋅ + ⋅ 1 ⋅ ⋅ ⋅ + ⋅ ⋅ ⋅ 1 ⋅ + ⋅ ⋅ ⋅ ⋅ ⋅ + +julia> g(x) = f(x) + x[2]^x[5]; + +julia> pattern(g, HessianTracer, x) +5×5 SparseArrays.SparseMatrixCSC{Bool, UInt64} with 7 stored entries: + ⋅ ⋅ ⋅ ⋅ ⋅ + ⋅ 1 1 ⋅ 1 + ⋅ 1 ⋅ ⋅ ⋅ + ⋅ ⋅ ⋅ 1 ⋅ + ⋅ 1 ⋅ ⋅ 1 +``` For more detailled examples, take a look at the [documentation](https://adrianhill.de/SparseConnectivityTracer.jl/dev). ## Related packages * [SparseDiffTools.jl](https://github.com/JuliaDiff/SparseDiffTools.jl): automatic sparsity detection via Symbolics.jl and Cassette.jl -* [SparsityTracing.jl](https://github.com/PALEOtoolkit/SparsityTracing.jl): automatic Jacobian sparsity detection using an algorithm based on SparsLinC by Bischof et al. (1996) \ No newline at end of file +* [SparsityTracing.jl](https://github.com/PALEOtoolkit/SparsityTracing.jl): automatic Jacobian sparsity detection using an algorithm based on SparsLinC by Bischof et al. (1996) diff --git a/docs/src/api.md b/docs/src/api.md index d3e6c2a4..7640ebc0 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -19,8 +19,9 @@ SparseConnectivityTracer works by pushing `Number` types called tracers through Currently, two tracer types are provided: ```@docs -JacobianTracer ConnectivityTracer +JacobianTracer +HessianTracer ``` Utilities to create tracers: diff --git a/src/SparseConnectivityTracer.jl b/src/SparseConnectivityTracer.jl index 039d9896..a7c13f1d 100644 --- a/src/SparseConnectivityTracer.jl +++ b/src/SparseConnectivityTracer.jl @@ -1,8 +1,8 @@ module SparseConnectivityTracer using ADTypes: ADTypes -import Random: rand, AbstractRNG, SamplerType import SparseArrays: sparse +import Random: rand, AbstractRNG, SamplerType abstract type AbstractTracer <: Number end @@ -11,10 +11,11 @@ include("conversion.jl") include("operators.jl") include("overload_connectivity.jl") include("overload_jacobian.jl") +include("overload_hessian.jl") include("pattern.jl") include("adtypes.jl") -export JacobianTracer, ConnectivityTracer +export JacobianTracer, ConnectivityTracer, HessianTracer export tracer, trace_input export inputs export pattern diff --git a/src/adtypes.jl b/src/adtypes.jl index 25a39ea4..cfb78bf9 100644 --- a/src/adtypes.jl +++ b/src/adtypes.jl @@ -26,6 +26,5 @@ function ADTypes.jacobian_sparsity(f!, y, x, ::TracerSparsityDetector) end function ADTypes.hessian_sparsity(f, x, ::TracerSparsityDetector) - # TODO: return pattern(f, HessianTracer, x) - return error("Hessian sparsity is not yet implemented for `TracerSparsityDetector`.") + return pattern(f, HessianTracer, x) end diff --git a/src/conversion.jl b/src/conversion.jl index fcdc5b0d..2856879d 100644 --- a/src/conversion.jl +++ b/src/conversion.jl @@ -1,5 +1,5 @@ ## Type conversions -for T in (:JacobianTracer, :ConnectivityTracer) +for T in (:JacobianTracer, :ConnectivityTracer, :HessianTracer) @eval Base.promote_rule(::Type{$T}, ::Type{N}) where {N<:Number} = $T @eval Base.promote_rule(::Type{N}, ::Type{$T}) where {N<:Number} = $T @@ -11,10 +11,13 @@ for T in (:JacobianTracer, :ConnectivityTracer) @eval Base.convert(::Type{$T}, t::$T) = t @eval Base.convert(::Type{<:Number}, t::$T) = t - ## Array constructors - @eval Base.zero(::Type{$T}) = empty($T) - @eval Base.one(::Type{$T}) = empty($T) + ## Constants + @eval Base.zero(::Type{$T}) = empty($T) + @eval Base.one(::Type{$T}) = empty($T) + @eval Base.typemin(::Type{$T}) = empty($T) + @eval Base.typemax(::Type{$T}) = empty($T) + ## Array constructors @eval Base.similar(a::Array{$T,1}) = zeros($T, size(a, 1)) @eval Base.similar(a::Array{$T,2}) = zeros($T, size(a, 1), size(a, 2)) @eval Base.similar(a::Array{A,1}, ::Type{$T}) where {A} = zeros($T, size(a, 1)) diff --git a/src/operators.jl b/src/operators.jl index d4377825..88089676 100644 --- a/src/operators.jl +++ b/src/operators.jl @@ -47,6 +47,7 @@ ops_1_to_1_s = ( # ∂²f/∂x² == 0 ops_1_to_1_f = ( :+, :-, + :identity, :abs, :hypot, :deg2rad, :rad2deg, :mod2pi, :prevfloat, :nextfloat, @@ -67,8 +68,8 @@ ops_1_to_1_z = ( ops_1_to_1_const = ( :zero, :one, :eps, - :typemax, - # :floatmin, :floatmax, :maxintfloat, + :typemin, :typemax, + :floatmin, :floatmax, :maxintfloat, ) ops_1_to_1 = union( @@ -89,7 +90,7 @@ ops_1_to_1 = union( # ∂²f/∂y² != 0 # ∂²f/∂x∂y != 0 ops_2_to_1_ssc = ( - :hypot, + :^, :hypot ) # ops_2_to_1_ssz: diff --git a/src/overload_connectivity.jl b/src/overload_connectivity.jl index 2bd667b3..406a8c75 100644 --- a/src/overload_connectivity.jl +++ b/src/overload_connectivity.jl @@ -17,7 +17,6 @@ for fn in ops_2_to_1 end # Extra types required for exponent -Base.:^(a::ConnectivityTracer, b::ConnectivityTracer) = uniontracer(a, b) for T in (:Real, :Integer, :Rational) @eval Base.:^(t::ConnectivityTracer, ::$T) = t @eval Base.:^(::$T, t::ConnectivityTracer) = t diff --git a/src/overload_hessian.jl b/src/overload_hessian.jl new file mode 100644 index 00000000..443f658a --- /dev/null +++ b/src/overload_hessian.jl @@ -0,0 +1,128 @@ +## 1-to-1 +for fn in ops_1_to_1_s + @eval Base.$fn(t::HessianTracer) = promote_order(t) +end +for fn in ops_1_to_1_f + @eval Base.$fn(t::HessianTracer) = t +end + +for fn in union(ops_1_to_1_z, ops_1_to_1_const) + @eval Base.$fn(::HessianTracer) = EMPTY_HESSIAN_TRACER +end + +## 2-to-1 +# Including second-order only +for fn in ops_2_to_1_ssc + @eval function Base.$fn(a::HessianTracer, b::HessianTracer) + a = promote_order(a) + b = promote_order(b) + return distributive_merge(a, b) + end + @eval Base.$fn(t::HessianTracer, ::Number) = promote_order(t) + @eval Base.$fn(::Number, t::HessianTracer) = promote_order(t) +end + +for fn in ops_2_to_1_ssz + @eval function Base.$fn(a::HessianTracer, b::HessianTracer) + a = promote_order(a) + b = promote_order(b) + return additive_merge(a, b) + end + @eval Base.$fn(t::HessianTracer, ::Number) = promote_order(t) + @eval Base.$fn(::Number, t::HessianTracer) = promote_order(t) +end + +# Including second- and first-order +for fn in ops_2_to_1_sfc + @eval function Base.$fn(a::HessianTracer, b::HessianTracer) + a = promote_order(a) + return distributive_merge(a, b) + end + @eval Base.$fn(t::HessianTracer, ::Number) = promote_order(t) + @eval Base.$fn(::Number, t::HessianTracer) = t +end + +for fn in ops_2_to_1_sfz + @eval function Base.$fn(a::HessianTracer, b::HessianTracer) + a = promote_order(a) + return additive_merge(a, b) + end + @eval Base.$fn(t::HessianTracer, ::Number) = promote_order(t) + @eval Base.$fn(::Number, t::HessianTracer) = t +end + +for fn in ops_2_to_1_fsc + @eval function Base.$fn(a::HessianTracer, b::HessianTracer) + b = promote_order(b) + return distributive_merge(a, b) + end + @eval Base.$fn(t::HessianTracer, ::Number) = t + @eval Base.$fn(::Number, t::HessianTracer) = promote_order(t) +end + +for fn in ops_2_to_1_fsz + @eval function Base.$fn(a::HessianTracer, b::HessianTracer) + b = promote_order(b) + return additive_merge(a, b) + end + @eval Base.$fn(t::HessianTracer, ::Number) = t + @eval Base.$fn(::Number, t::HessianTracer) = promote_order(t) +end + +# Including first-order only +for fn in ops_2_to_1_ffc + @eval Base.$fn(a::HessianTracer, b::HessianTracer) = distributive_merge(a, b) + @eval Base.$fn(t::HessianTracer, ::Number) = t + @eval Base.$fn(::Number, t::HessianTracer) = t +end + +for fn in ops_2_to_1_ffz + @eval Base.$fn(a::HessianTracer, b::HessianTracer) = additive_merge(a, b) + @eval Base.$fn(t::HessianTracer, ::Number) = t + @eval Base.$fn(::Number, t::HessianTracer) = t +end + +# Including zero-order +for fn in ops_2_to_1_szz + @eval Base.$fn(t::HessianTracer, ::HessianTracer) = promote_order(t) + @eval Base.$fn(t::HessianTracer, ::Number) = promote_order(t) + @eval Base.$fn(::Number, t::HessianTracer) = EMPTY_HESSIAN_TRACER +end + +for fn in ops_2_to_1_zsz + @eval Base.$fn(::HessianTracer, t::HessianTracer) = promote_order(t) + @eval Base.$fn(::HessianTracer, ::Number) = EMPTY_HESSIAN_TRACER + @eval Base.$fn(::Number, t::HessianTracer) = promote_order(t) +end + +for fn in ops_2_to_1_fzz + @eval Base.$fn(t::HessianTracer, ::HessianTracer) = t + @eval Base.$fn(t::HessianTracer, ::Number) = t + @eval Base.$fn(::Number, t::HessianTracer) = EMPTY_HESSIAN_TRACER +end + +for fn in ops_2_to_1_zfz + @eval Base.$fn(::HessianTracer, t::HessianTracer) = t + @eval Base.$fn(::HessianTracer, ::Number) = EMPTY_HESSIAN_TRACER + @eval Base.$fn(::Number, t::HessianTracer) = t +end + +for fn in ops_2_to_1_zzz + @eval Base.$fn(::HessianTracer, t::HessianTracer) = EMPTY_HESSIAN_TRACER + @eval Base.$fn(::HessianTracer, ::Number) = EMPTY_HESSIAN_TRACER + @eval Base.$fn(::Number, t::HessianTracer) = EMPTY_HESSIAN_TRACER +end + +# Extra types required for exponent +for T in (:Real, :Integer, :Rational) + @eval Base.:^(t::HessianTracer, ::$T) = promote_order(t) + @eval Base.:^(::$T, t::HessianTracer) = promote_order(t) +end +Base.:^(t::HessianTracer, ::Irrational{:ℯ}) = promote_order(t) +Base.:^(::Irrational{:ℯ}, t::HessianTracer) = promote_order(t) + +## Rounding +Base.round(t::HessianTracer, ::RoundingMode; kwargs...) = EMPTY_HESSIAN_TRACER + +## Random numbers +rand(::AbstractRNG, ::SamplerType{HessianTracer}) = EMPTY_HESSIAN_TRACER diff --git a/src/overload_jacobian.jl b/src/overload_jacobian.jl index 6b4d9b15..8fad8b31 100644 --- a/src/overload_jacobian.jl +++ b/src/overload_jacobian.jl @@ -53,7 +53,6 @@ for fn in ops_1_to_2_zz end # Extra types required for exponent -Base.:^(a::JacobianTracer, b::JacobianTracer) = uniontracer(a, b) for T in (:Real, :Integer, :Rational) @eval Base.:^(t::JacobianTracer, ::$T) = t @eval Base.:^(::$T, t::JacobianTracer) = t diff --git a/src/pattern.jl b/src/pattern.jl index cde0afbe..84124dcc 100644 --- a/src/pattern.jl +++ b/src/pattern.jl @@ -36,16 +36,22 @@ end ## Construct sparsity pattern matrix """ + pattern(f, ConnectivityTracer, x) + +Enumerates inputs `x` and primal outputs `y = f(x)` and returns sparse matrix `C` of size `(m, n)` +where `C[i, j]` is true if the compute graph connects the `i`-th entry in `y` to the `j`-th entry in `x`. + pattern(f, JacobianTracer, x) Computes the sparsity pattern of the Jacobian of `y = f(x)`. - pattern(f, ConnectivityTracer, x) + pattern(f, HessianTracer, x) -Enumerates inputs `x` and primal outputs `y = f(x)` and returns sparse matrix `C` of size `(m, n)` -where `C[i, j]` is true if the compute graph connects the `i`-th entry in `y` to the `j`-th entry in `x`. +Computes the sparsity pattern of the Hessian of a scalar function `y = f(x)`. + +## Examples +### First order -## Example ```jldoctest julia> x = rand(3); @@ -57,6 +63,32 @@ julia> pattern(f, ConnectivityTracer, x) 1 1 ⋅ ⋅ ⋅ 1 ``` + +### Second order + +```jldoctest +julia> x = rand(5); + +julia> f(x) = x[1] + x[2]*x[3] + 1/x[4] + 1*x[5]; + +julia> pattern(f, HessianTracer, x) +5×5 SparseArrays.SparseMatrixCSC{Bool, UInt64} with 3 stored entries: + ⋅ ⋅ ⋅ ⋅ ⋅ + ⋅ ⋅ 1 ⋅ ⋅ + ⋅ 1 ⋅ ⋅ ⋅ + ⋅ ⋅ ⋅ 1 ⋅ + ⋅ ⋅ ⋅ ⋅ ⋅ + +julia> g(x) = f(x) + x[2]^x[5]; + +julia> pattern(g, HessianTracer, x) +5×5 SparseArrays.SparseMatrixCSC{Bool, UInt64} with 7 stored entries: + ⋅ ⋅ ⋅ ⋅ ⋅ + ⋅ 1 1 ⋅ 1 + ⋅ 1 ⋅ ⋅ ⋅ + ⋅ ⋅ ⋅ 1 ⋅ + ⋅ 1 ⋅ ⋅ 1 +``` """ function pattern(f, ::Type{T}, x) where {T<:AbstractTracer} xt = trace_input(T, x) @@ -89,15 +121,15 @@ function _pattern(xt::AbstractArray{<:AbstractTracer}, yt::AbstractArray{<:Numbe end function _pattern_to_sparsemat( - xt::AbstractArray{<:AbstractTracer}, yt::AbstractArray{<:Number} -) + xt::AbstractArray{T}, yt::AbstractArray{<:Number} +) where {T<:AbstractTracer} # Construct matrix of size (ouput_dim, input_dim) n, m = length(xt), length(yt) - I = UInt64[] - J = UInt64[] - V = Bool[] + I = UInt64[] # row indices + J = UInt64[] # column indices + V = Bool[] # values for (i, y) in enumerate(yt) - if y isa AbstractTracer + if y isa T for j in inputs(y) push!(I, i) push!(J, j) @@ -107,3 +139,26 @@ function _pattern_to_sparsemat( end return sparse(I, J, V, m, n) end + +function _pattern_to_sparsemat( + xt::AbstractArray{HessianTracer}, yt::AbstractArray{HessianTracer} +) + length(yt) != 1 && error("pattern(f, HessianTracer, x) expects scalar output y=f(x).") + y = only(yt) + + # Allocate Hessian matrix + n = length(xt) + I = UInt64[] # row indices + J = UInt64[] # column indices + V = Bool[] # values + + for i in keys(y.inputs) + for j in y.inputs[i] + push!(I, i) + push!(J, j) + push!(V, true) + end + end + H = sparse(I, J, V, n, n) + return H +end diff --git a/src/tracers.jl b/src/tracers.jl index c932e492..b4ddf43d 100644 --- a/src/tracers.jl +++ b/src/tracers.jl @@ -18,7 +18,9 @@ function Base.show(io::IO, t::ConnectivityTracer) return Base.show_delim_array(io, inputs(t), "ConnectivityTracer(", ',', ')', true) end -const EMPTY_CONNECTIVITY_TRACER = ConnectivityTracer(BitSet()) +const EMPTY_CONNECTIVITY_TRACER = ConnectivityTracer(BitSet()) +empty(::ConnectivityTracer) = EMPTY_CONNECTIVITY_TRACER +empty(::Type{ConnectivityTracer}) = EMPTY_CONNECTIVITY_TRACER # We have to be careful when defining constructors: # Generic code expecting "regular" numbers `x` will sometimes convert them @@ -27,6 +29,11 @@ const EMPTY_CONNECTIVITY_TRACER = ConnectivityTracer(BitSet()) ConnectivityTracer(::Number) = EMPTY_CONNECTIVITY_TRACER ConnectivityTracer(t::ConnectivityTracer) = t +## Unions of tracers +function uniontracer(a::ConnectivityTracer, b::ConnectivityTracer) + return ConnectivityTracer(union(a.inputs, b.inputs)) +end + #==========# # Jacobian # #==========# @@ -40,18 +47,95 @@ See also the convenience constructor [`tracer`](@ref). For a higher-level interface, refer to [`pattern`](@ref). """ struct JacobianTracer <: AbstractTracer - inputs::BitSet # indices of connected, enumerated inputs + inputs::BitSet end function Base.show(io::IO, t::JacobianTracer) return Base.show_delim_array(io, inputs(t), "JacobianTracer(", ',', ')', true) end -const EMPTY_JACOBIAN_TRACER = JacobianTracer(BitSet()) +const EMPTY_JACOBIAN_TRACER = JacobianTracer(BitSet()) +empty(::JacobianTracer) = EMPTY_JACOBIAN_TRACER +empty(::Type{JacobianTracer}) = EMPTY_JACOBIAN_TRACER JacobianTracer(::Number) = EMPTY_JACOBIAN_TRACER JacobianTracer(t::JacobianTracer) = t +## Unions of tracers +function uniontracer(a::JacobianTracer, b::JacobianTracer) + return JacobianTracer(union(a.inputs, b.inputs)) +end + +#=========# +# Hessian # +#=========# +const HessianDict = Dict{UInt64,BitSet} +""" + HessianTracer(indexset) <: Number + +Number type keeping track of input indices of previous computations with non-zero first and second derivatives. + +See also the convenience constructor [`tracer`](@ref). +For a higher-level interface, refer to [`pattern`](@ref). +""" +struct HessianTracer <: AbstractTracer + inputs::HessianDict +end +function Base.show(io::IO, t::HessianTracer) + println(io, "HessianTracer(") + for key in keys(t.inputs) + print(io, " ", key, " => ") + Base.show_delim_array(io, collect(t.inputs[key]), "(", ',', ')', true) + println(io, ",") + end + return println(io, ")") +end + +const EMPTY_HESSIAN_TRACER = HessianTracer(HessianDict()) +empty(::HessianTracer) = EMPTY_HESSIAN_TRACER +empty(::Type{HessianTracer}) = EMPTY_HESSIAN_TRACER + +HessianTracer(::Number) = empty(HessianTracer) +HessianTracer(t::HessianTracer) = t + +# Turn first-order interactions into second-order interactions +function promote_order(t::HessianTracer) + d = deepcopy(t.inputs) + for k in keys(d) + union!(d[k], k) + end + return HessianTracer(d) +end + +# Merge first- and second-order terms in an "additive" fashion +function additive_merge(a::HessianTracer, b::HessianTracer) + da = deepcopy(a.inputs) + db = b.inputs + for k in keys(db) + if haskey(da, k) + union!(da[k], db[k]) + else + push!(da, k => db[k]) + end + end + return HessianTracer(da) +end + +# Merge first- and second-order terms in a "distributive" fashion +function distributive_merge(a::HessianTracer, b::HessianTracer) + da = deepcopy(a.inputs) + db = deepcopy(b.inputs) + for ka in keys(da) + for kb in keys(db) + # add second-order interaction term + union!(da[ka], kb) + union!(db[kb], ka) + end + end + merge!(da, db) + return HessianTracer(da) +end + #===========# # Utilities # #===========# @@ -77,21 +161,6 @@ julia> inputs(t) inputs(t::ConnectivityTracer) = collect(t.inputs) inputs(t::JacobianTracer) = collect(t.inputs) -## Unions of tracers -function uniontracer(a::ConnectivityTracer, b::ConnectivityTracer) - return ConnectivityTracer(union(a.inputs, b.inputs)) -end - -function uniontracer(a::JacobianTracer, b::JacobianTracer) - return JacobianTracer(union(a.inputs, b.inputs)) -end - -## Get empty tracer -empty(::JacobianTracer) = EMPTY_JACOBIAN_TRACER -empty(::Type{JacobianTracer}) = EMPTY_JACOBIAN_TRACER -empty(::ConnectivityTracer) = EMPTY_CONNECTIVITY_TRACER -empty(::Type{ConnectivityTracer}) = EMPTY_CONNECTIVITY_TRACER - """ tracer(JacobianTracer, index) tracer(JacobianTracer, indices) @@ -102,6 +171,9 @@ Convenience constructor for [`JacobianTracer`](@ref) [`ConnectivityTracer`](@ref """ tracer(::Type{JacobianTracer}, index::Integer) = JacobianTracer(BitSet(index)) tracer(::Type{ConnectivityTracer}, index::Integer) = ConnectivityTracer(BitSet(index)) +function tracer(::Type{HessianTracer}, index::Integer) + return HessianTracer(Dict{UInt64,BitSet}(index => BitSet())) +end function tracer(::Type{JacobianTracer}, inds::NTuple{N,<:Integer}) where {N} return JacobianTracer(BitSet(inds)) @@ -109,5 +181,8 @@ end function tracer(::Type{ConnectivityTracer}, inds::NTuple{N,<:Integer}) where {N} return ConnectivityTracer(BitSet(inds)) end +function tracer(::Type{HessianTracer}, inds::NTuple{N,<:Integer}) where {N} + return HessianTracer(Dict{UInt64,BitSet}(i => BitSet() for i in inds)) +end tracer(::Type{T}, inds...) where {T<:AbstractTracer} = tracer(T, inds) diff --git a/test/adtypes.jl b/test/adtypes.jl index f7cd0ad4..81628282 100644 --- a/test/adtypes.jl +++ b/test/adtypes.jl @@ -13,4 +13,17 @@ J2 = ADTypes.jacobian_sparsity((y, x) -> y .= diff(x), y, x, sd) @test J1 isa SparseMatrixCSC @test J2 isa SparseMatrixCSC @test nnz(J1) == nnz(J2) == 18 -@test_throws ErrorException ADTypes.hessian_sparsity(sum, x, sd) + +H1 = ADTypes.hessian_sparsity(x -> sum(diff(x)), x, sd) +@test H1 ≈ zeros(10, 10) + +x = rand(5) +f(x) = x[1] + x[2] * x[3] + 1 / x[4] + 1 * x[5] +H2 = ADTypes.hessian_sparsity(f, x, sd) +@test H2 ≈ [ + 0 0 0 0 0 + 0 0 1 0 0 + 0 1 0 0 0 + 0 0 0 1 0 + 0 0 0 0 0 +] diff --git a/test/runtests.jl b/test/runtests.jl index 3217f9d2..154bedbd 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -43,7 +43,7 @@ DocMeta.setdocmeta!( @testset "Classification of operators by diff'ability" begin include("test_differentiability.jl") end - @testset "Connectivity" begin + @testset "First order" begin x = rand(3) xt = trace_input(ConnectivityTracer, x) @@ -75,6 +75,34 @@ DocMeta.setdocmeta!( @test pattern(g, ConnectivityTracer, x) ≈ [1 1; 1 1; 1 1] @test pattern(g, JacobianTracer, x) ≈ [1 1; 0 0; 1 0] end + @testset "Second order" begin + @test pattern(identity, HessianTracer, rand()) ≈ [0;;] + @test pattern(sqrt, HessianTracer, rand()) ≈ [1;;] + + @test pattern(x -> 1 * x, HessianTracer, rand()) ≈ [0;;] + @test pattern(x -> x * 1, HessianTracer, rand()) ≈ [0;;] + + x = rand(5) + f(x) = x[1] + x[2] * x[3] + 1 / x[4] + 1 * x[5] + H = pattern(f, HessianTracer, x) + @test H ≈ [ + 0 0 0 0 0 + 0 0 1 0 0 + 0 1 0 0 0 + 0 0 0 1 0 + 0 0 0 0 0 + ] + + g(x) = f(x) + x[2]^x[5] + H = pattern(g, HessianTracer, x) + @test H ≈ [ + 0 0 0 0 0 + 0 1 1 0 1 + 0 1 0 0 0 + 0 0 0 1 0 + 0 1 0 0 1 + ] + end @testset "Real-world tests" begin @testset "NNlib" begin x = rand(3, 3, 2, 1) # WHCN diff --git a/test/test_differentiability.jl b/test/test_differentiability.jl index fad92e48..1c79dae8 100644 --- a/test/test_differentiability.jl +++ b/test/test_differentiability.jl @@ -97,13 +97,8 @@ const TEST_1_TO_1 = ( ("Second order", ops_1_to_1_s, second_order), ("First order", ops_1_to_1_f, first_order), ("Zero order", ops_1_to_1_z, zero_order), - ("Constant", ops_1_to_1_const, zero_order), ) @testset verbose = true "1-to-1" begin - @testset "All operators covered" begin - all_ops = union([ops for (name, ops, ref_order) in TEST_1_TO_1]...) - @test Set(all_ops) == Set(ops_1_to_1) - end for (name, ops, ref_order) in TEST_1_TO_1 @testset "$name" begin for op in ops