Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add HessianTracer #24

Merged
merged 10 commits into from
Apr 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 35 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,17 +6,20 @@
[![Coverage](https://codecov.io/gh/adrhill/SparseConnectivityTracer.jl/branch/main/graph/badge.svg)](https://codecov.io/gh/adrhill/SparseConnectivityTracer.jl)
[![Aqua](https://raw.githubusercontent.com/JuliaTesting/Aqua.jl/master/badge.svg)](https://github.com/JuliaTesting/Aqua.jl)

Fast sparsity detection via operator-overloading.

Will soon include Hessian sparsity detection ([#20](https://github.com/adrhill/SparseConnectivityTracer.jl/issues/20)).
Fast Jacobian and Hessian sparsity detection via operator-overloading.

## Installation
To install this package, open the Julia REPL and run

```julia-repl
julia> ]add SparseConnectivityTracer
```

## Examples
### Jacobian

For functions `y = f(x)` and `f!(y, x)`, the sparsity pattern of the Jacobian of $f$ can be obtained
by computing a single forward-pass through `f`:

```julia-repl
julia> using SparseConnectivityTracer
Expand Down Expand Up @@ -63,10 +66,37 @@ julia> pattern(layer, JacobianTracer, x)
⎣⠀⠀⠀⠙⢷⣄⠀⠀⠈⠻⣦⠀⠀⠀⠙⢦⡀⎦
```

SparseConnectivityTracer enumerates inputs `x` and primal outputs `y = f(x)` and returns a sparse matrix `C` of size $m \times n$, where `C[i, j]` is `true` if the compute graph connects the $j$-th entry in `x` to the $i$-th entry in `y`.
### Hessian

For scalar functions `y = f(x)`, the sparsity pattern of the Hessian of $f$ can be obtained
by computing a single forward-pass through `f`:

```julia-repl
julia> x = rand(5);

julia> f(x) = x[1] + x[2]*x[3] + 1/x[4] + 1*x[5];

julia> pattern(f, HessianTracer, x)
5×5 SparseArrays.SparseMatrixCSC{Bool, UInt64} with 3 stored entries:
⋅ ⋅ ⋅ ⋅ ⋅
⋅ ⋅ 1 ⋅ ⋅
⋅ 1 ⋅ ⋅ ⋅
⋅ ⋅ ⋅ 1 ⋅
⋅ ⋅ ⋅ ⋅ ⋅

julia> g(x) = f(x) + x[2]^x[5];

julia> pattern(g, HessianTracer, x)
5×5 SparseArrays.SparseMatrixCSC{Bool, UInt64} with 7 stored entries:
⋅ ⋅ ⋅ ⋅ ⋅
⋅ 1 1 ⋅ 1
⋅ 1 ⋅ ⋅ ⋅
⋅ ⋅ ⋅ 1 ⋅
⋅ 1 ⋅ ⋅ 1
```

For more detailled examples, take a look at the [documentation](https://adrianhill.de/SparseConnectivityTracer.jl/dev).

## Related packages
* [SparseDiffTools.jl](https://github.com/JuliaDiff/SparseDiffTools.jl): automatic sparsity detection via Symbolics.jl and Cassette.jl
* [SparsityTracing.jl](https://github.com/PALEOtoolkit/SparsityTracing.jl): automatic Jacobian sparsity detection using an algorithm based on SparsLinC by Bischof et al. (1996)
* [SparsityTracing.jl](https://github.com/PALEOtoolkit/SparsityTracing.jl): automatic Jacobian sparsity detection using an algorithm based on SparsLinC by Bischof et al. (1996)
3 changes: 2 additions & 1 deletion docs/src/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,9 @@ SparseConnectivityTracer works by pushing `Number` types called tracers through
Currently, two tracer types are provided:

```@docs
JacobianTracer
ConnectivityTracer
JacobianTracer
HessianTracer
```

Utilities to create tracers:
Expand Down
5 changes: 3 additions & 2 deletions src/SparseConnectivityTracer.jl
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
module SparseConnectivityTracer

using ADTypes: ADTypes
import Random: rand, AbstractRNG, SamplerType
import SparseArrays: sparse
import Random: rand, AbstractRNG, SamplerType

abstract type AbstractTracer <: Number end

Expand All @@ -11,10 +11,11 @@ include("conversion.jl")
include("operators.jl")
include("overload_connectivity.jl")
include("overload_jacobian.jl")
include("overload_hessian.jl")
include("pattern.jl")
include("adtypes.jl")

export JacobianTracer, ConnectivityTracer
export JacobianTracer, ConnectivityTracer, HessianTracer
export tracer, trace_input
export inputs
export pattern
Expand Down
3 changes: 1 addition & 2 deletions src/adtypes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,5 @@ function ADTypes.jacobian_sparsity(f!, y, x, ::TracerSparsityDetector)
end

function ADTypes.hessian_sparsity(f, x, ::TracerSparsityDetector)
# TODO: return pattern(f, HessianTracer, x)
return error("Hessian sparsity is not yet implemented for `TracerSparsityDetector`.")
return pattern(f, HessianTracer, x)
end
11 changes: 7 additions & 4 deletions src/conversion.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
## Type conversions
for T in (:JacobianTracer, :ConnectivityTracer)
for T in (:JacobianTracer, :ConnectivityTracer, :HessianTracer)
@eval Base.promote_rule(::Type{$T}, ::Type{N}) where {N<:Number} = $T
@eval Base.promote_rule(::Type{N}, ::Type{$T}) where {N<:Number} = $T

Expand All @@ -11,10 +11,13 @@ for T in (:JacobianTracer, :ConnectivityTracer)
@eval Base.convert(::Type{$T}, t::$T) = t
@eval Base.convert(::Type{<:Number}, t::$T) = t

## Array constructors
@eval Base.zero(::Type{$T}) = empty($T)
@eval Base.one(::Type{$T}) = empty($T)
## Constants
@eval Base.zero(::Type{$T}) = empty($T)
@eval Base.one(::Type{$T}) = empty($T)
@eval Base.typemin(::Type{$T}) = empty($T)
@eval Base.typemax(::Type{$T}) = empty($T)

## Array constructors
@eval Base.similar(a::Array{$T,1}) = zeros($T, size(a, 1))
@eval Base.similar(a::Array{$T,2}) = zeros($T, size(a, 1), size(a, 2))
@eval Base.similar(a::Array{A,1}, ::Type{$T}) where {A} = zeros($T, size(a, 1))
Expand Down
7 changes: 4 additions & 3 deletions src/operators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ ops_1_to_1_s = (
# ∂²f/∂x² == 0
ops_1_to_1_f = (
:+, :-,
:identity,
:abs, :hypot,
:deg2rad, :rad2deg,
:mod2pi, :prevfloat, :nextfloat,
Expand All @@ -67,8 +68,8 @@ ops_1_to_1_z = (
ops_1_to_1_const = (
:zero, :one,
:eps,
:typemax,
# :floatmin, :floatmax, :maxintfloat,
:typemin, :typemax,
:floatmin, :floatmax, :maxintfloat,
)

ops_1_to_1 = union(
Expand All @@ -89,7 +90,7 @@ ops_1_to_1 = union(
# ∂²f/∂y² != 0
# ∂²f/∂x∂y != 0
ops_2_to_1_ssc = (
:hypot,
:^, :hypot
)

# ops_2_to_1_ssz:
Expand Down
1 change: 0 additions & 1 deletion src/overload_connectivity.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@ for fn in ops_2_to_1
end

# Extra types required for exponent
Base.:^(a::ConnectivityTracer, b::ConnectivityTracer) = uniontracer(a, b)
for T in (:Real, :Integer, :Rational)
@eval Base.:^(t::ConnectivityTracer, ::$T) = t
@eval Base.:^(::$T, t::ConnectivityTracer) = t
Expand Down
128 changes: 128 additions & 0 deletions src/overload_hessian.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
## 1-to-1
for fn in ops_1_to_1_s
@eval Base.$fn(t::HessianTracer) = promote_order(t)
end
for fn in ops_1_to_1_f
@eval Base.$fn(t::HessianTracer) = t
end

for fn in union(ops_1_to_1_z, ops_1_to_1_const)
@eval Base.$fn(::HessianTracer) = EMPTY_HESSIAN_TRACER
end

## 2-to-1
# Including second-order only
for fn in ops_2_to_1_ssc
@eval function Base.$fn(a::HessianTracer, b::HessianTracer)
a = promote_order(a)
b = promote_order(b)
return distributive_merge(a, b)
end
@eval Base.$fn(t::HessianTracer, ::Number) = promote_order(t)
@eval Base.$fn(::Number, t::HessianTracer) = promote_order(t)
end

for fn in ops_2_to_1_ssz
@eval function Base.$fn(a::HessianTracer, b::HessianTracer)
a = promote_order(a)
b = promote_order(b)
return additive_merge(a, b)
end
@eval Base.$fn(t::HessianTracer, ::Number) = promote_order(t)
@eval Base.$fn(::Number, t::HessianTracer) = promote_order(t)
end

# Including second- and first-order
for fn in ops_2_to_1_sfc
@eval function Base.$fn(a::HessianTracer, b::HessianTracer)
a = promote_order(a)
return distributive_merge(a, b)
end
@eval Base.$fn(t::HessianTracer, ::Number) = promote_order(t)
@eval Base.$fn(::Number, t::HessianTracer) = t
end

for fn in ops_2_to_1_sfz
@eval function Base.$fn(a::HessianTracer, b::HessianTracer)
a = promote_order(a)
return additive_merge(a, b)
end
@eval Base.$fn(t::HessianTracer, ::Number) = promote_order(t)
@eval Base.$fn(::Number, t::HessianTracer) = t
end

for fn in ops_2_to_1_fsc
@eval function Base.$fn(a::HessianTracer, b::HessianTracer)
b = promote_order(b)
return distributive_merge(a, b)
end
@eval Base.$fn(t::HessianTracer, ::Number) = t
@eval Base.$fn(::Number, t::HessianTracer) = promote_order(t)
end

for fn in ops_2_to_1_fsz
@eval function Base.$fn(a::HessianTracer, b::HessianTracer)
b = promote_order(b)
return additive_merge(a, b)
end
@eval Base.$fn(t::HessianTracer, ::Number) = t
@eval Base.$fn(::Number, t::HessianTracer) = promote_order(t)
end

# Including first-order only
for fn in ops_2_to_1_ffc
@eval Base.$fn(a::HessianTracer, b::HessianTracer) = distributive_merge(a, b)
@eval Base.$fn(t::HessianTracer, ::Number) = t
@eval Base.$fn(::Number, t::HessianTracer) = t
end

for fn in ops_2_to_1_ffz
@eval Base.$fn(a::HessianTracer, b::HessianTracer) = additive_merge(a, b)
@eval Base.$fn(t::HessianTracer, ::Number) = t
@eval Base.$fn(::Number, t::HessianTracer) = t
end

# Including zero-order
for fn in ops_2_to_1_szz
@eval Base.$fn(t::HessianTracer, ::HessianTracer) = promote_order(t)
@eval Base.$fn(t::HessianTracer, ::Number) = promote_order(t)
@eval Base.$fn(::Number, t::HessianTracer) = EMPTY_HESSIAN_TRACER
end

for fn in ops_2_to_1_zsz
@eval Base.$fn(::HessianTracer, t::HessianTracer) = promote_order(t)
@eval Base.$fn(::HessianTracer, ::Number) = EMPTY_HESSIAN_TRACER
@eval Base.$fn(::Number, t::HessianTracer) = promote_order(t)
end

for fn in ops_2_to_1_fzz
@eval Base.$fn(t::HessianTracer, ::HessianTracer) = t
@eval Base.$fn(t::HessianTracer, ::Number) = t
@eval Base.$fn(::Number, t::HessianTracer) = EMPTY_HESSIAN_TRACER
end

for fn in ops_2_to_1_zfz
@eval Base.$fn(::HessianTracer, t::HessianTracer) = t
@eval Base.$fn(::HessianTracer, ::Number) = EMPTY_HESSIAN_TRACER
@eval Base.$fn(::Number, t::HessianTracer) = t
end

for fn in ops_2_to_1_zzz
@eval Base.$fn(::HessianTracer, t::HessianTracer) = EMPTY_HESSIAN_TRACER
@eval Base.$fn(::HessianTracer, ::Number) = EMPTY_HESSIAN_TRACER
@eval Base.$fn(::Number, t::HessianTracer) = EMPTY_HESSIAN_TRACER
end

# Extra types required for exponent
for T in (:Real, :Integer, :Rational)
@eval Base.:^(t::HessianTracer, ::$T) = promote_order(t)
@eval Base.:^(::$T, t::HessianTracer) = promote_order(t)
end
Base.:^(t::HessianTracer, ::Irrational{:ℯ}) = promote_order(t)
Base.:^(::Irrational{:ℯ}, t::HessianTracer) = promote_order(t)

## Rounding
Base.round(t::HessianTracer, ::RoundingMode; kwargs...) = EMPTY_HESSIAN_TRACER

## Random numbers
rand(::AbstractRNG, ::SamplerType{HessianTracer}) = EMPTY_HESSIAN_TRACER
1 change: 0 additions & 1 deletion src/overload_jacobian.jl
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,6 @@ for fn in ops_1_to_2_zz
end

# Extra types required for exponent
Base.:^(a::JacobianTracer, b::JacobianTracer) = uniontracer(a, b)
for T in (:Real, :Integer, :Rational)
@eval Base.:^(t::JacobianTracer, ::$T) = t
@eval Base.:^(::$T, t::JacobianTracer) = t
Expand Down
Loading
Loading