Skip to content

Commit

Permalink
Refactor high-level API (#32)
Browse files Browse the repository at this point in the history
* Add `*_pattern` API

* Export fewer internals
  • Loading branch information
adrhill authored May 2, 2024
1 parent b98e796 commit 8a9110a
Show file tree
Hide file tree
Showing 8 changed files with 165 additions and 183 deletions.
11 changes: 7 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ julia> x = rand(3);
julia> f(x) = [x[1]^2, 2 * x[1] * x[2]^2, sin(x[3])];
julia> pattern(f, JacobianTracer{BitSet}, x)
julia> jacobian_pattern(f, x)
3×3 SparseArrays.SparseMatrixCSC{Bool, UInt64} with 4 stored entries:
1 ⋅ ⋅
1 1 ⋅
Expand All @@ -46,7 +46,7 @@ julia> x = rand(28, 28, 3, 1);
julia> layer = Conv((3, 3), 3 => 8);
julia> pattern(layer, JacobianTracer{BitSet}, x)
julia> jacobian_pattern(layer, x)
5408×2352 SparseArrays.SparseMatrixCSC{Bool, UInt64} with 146016 stored entries:
⎡⠙⢦⡀⠀⠀⠘⢷⣄⠀⠀⠈⠻⣦⡀⠀⠀⠀⎤
⎢⠀⠀⠙⢷⣄⠀⠀⠙⠷⣄⠀⠀⠈⠻⣦⡀⠀⎥
Expand All @@ -69,6 +69,9 @@ julia> pattern(layer, JacobianTracer{BitSet}, x)
⎣⠀⠀⠀⠙⢷⣄⠀⠀⠈⠻⣦⠀⠀⠀⠙⢦⡀⎦
```

The type of index set `T<:AbstractSet{<:Integer}` that is internally used to keep track of connectivity can be specified via `jacobian_pattern(f, x, T)`, defaulting to `BitSet`.
For high-dimensional functions, `Set{UInt64}` can be more efficient .

### Hessian

For scalar functions `y = f(x)`, the sparsity pattern of the Hessian of $f$ can be obtained
Expand All @@ -79,7 +82,7 @@ julia> x = rand(5);
julia> f(x) = x[1] + x[2]*x[3] + 1/x[4] + 1*x[5];
julia> pattern(f, HessianTracer{BitSet}, x)
julia> hessian_pattern(f, x)
5×5 SparseArrays.SparseMatrixCSC{Bool, UInt64} with 3 stored entries:
⋅ ⋅ ⋅ ⋅ ⋅
⋅ ⋅ 1 ⋅ ⋅
Expand All @@ -89,7 +92,7 @@ julia> pattern(f, HessianTracer{BitSet}, x)
julia> g(x) = f(x) + x[2]^x[5];
julia> pattern(g, HessianTracer{BitSet}, x)
julia> hessian_pattern(g, x)
5×5 SparseArrays.SparseMatrixCSC{Bool, UInt64} with 7 stored entries:
⋅ ⋅ ⋅ ⋅ ⋅
⋅ 1 1 ⋅ 1
Expand Down
1 change: 1 addition & 0 deletions docs/make.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ makedocs(;
assets = String[],
),
pages=["Home" => "index.md", "API Reference" => "api.md"],
warnonly=[:missing_docs],
)

deploydocs(; repo="github.com/adrhill/SparseConnectivityTracer.jl", devbranch="main")
19 changes: 6 additions & 13 deletions docs/src/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,27 +10,20 @@ CollapsedDocStrings = true

## Interface
```@docs
pattern
connectivity_pattern
jacobian_pattern
hessian_pattern
```
```@docs
TracerSparsityDetector
```

## Internals
SparseConnectivityTracer works by pushing `Number` types called tracers through generic functions.
Currently, two tracer types are provided:
Currently, three tracer types are provided:

```@docs
ConnectivityTracer
JacobianTracer
HessianTracer
```

Utilities to create tracers:
```@docs
tracer
trace_input
```

Utility to extract input indices from tracers:
```@docs
inputs
```
10 changes: 4 additions & 6 deletions src/SparseConnectivityTracer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,6 @@ using ADTypes: ADTypes
import SparseArrays: sparse
import Random: rand, AbstractRNG, SamplerType

abstract type AbstractTracer <: Number end

include("tracers.jl")
include("conversion.jl")
include("operators.jl")
Expand All @@ -15,10 +13,10 @@ include("overload_hessian.jl")
include("pattern.jl")
include("adtypes.jl")

export JacobianTracer, ConnectivityTracer, HessianTracer
export tracer, trace_input
export inputs
export pattern
export ConnectivityTracer, connectivity_pattern
export JacobianTracer, jacobian_pattern
export HessianTracer, hessian_pattern

export TracerSparsityDetector

end # module
6 changes: 3 additions & 3 deletions src/adtypes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,17 +35,17 @@ TracerSparsityDetector() = TracerSparsityDetector(BitSet)
function ADTypes.jacobian_sparsity(
f, x, ::TracerSparsityDetector{S}
) where {S<:AbstractIndexSet}
return pattern(f, JacobianTracer{S}, x)
return jacobian_pattern(f, x, S)
end

function ADTypes.jacobian_sparsity(
f!, y, x, ::TracerSparsityDetector{S}
) where {S<:AbstractIndexSet}
return pattern(f!, y, JacobianTracer{S}, x)
return jacobian_pattern(f!, y, x, S)
end

function ADTypes.hessian_sparsity(
f, x, ::TracerSparsityDetector{S}
) where {S<:AbstractIndexSet}
return pattern(f, HessianTracer{S}, x)
return hessian_pattern(f, x, S)
end
132 changes: 78 additions & 54 deletions src/pattern.jl
Original file line number Diff line number Diff line change
@@ -1,41 +1,13 @@
## Enumerate inputs
const DEFAULT_SET_TYPE = BitSet

## Enumerate inputs
"""
trace_input(T, x)
trace_input(T, x)
Enumerates input indices and constructs the specified type `T` of tracer.
Supports [`ConnectivityTracer`](@ref), [`JacobianTracer`](@ref) and [`HessianTracer`](@ref).
## Example
```jldoctest
julia> x = rand(3);
julia> trace_input(ConnectivityTracer{BitSet}, x)
3-element Vector{ConnectivityTracer{BitSet}}:
ConnectivityTracer{BitSet}(1,)
ConnectivityTracer{BitSet}(2,)
ConnectivityTracer{BitSet}(3,)
julia> trace_input(JacobianTracer{BitSet}, x)
3-element Vector{JacobianTracer{BitSet}}:
JacobianTracer{BitSet}(1,)
JacobianTracer{BitSet}(2,)
JacobianTracer{BitSet}(3,)
julia> trace_input(HessianTracer{BitSet}, x)
3-element Vector{HessianTracer{BitSet}}:
HessianTracer{BitSet}(
1 => (),
)
HessianTracer{BitSet}(
2 => (),
)
HessianTracer{BitSet}(
3 => (),
)
```
"""
trace_input(::Type{T}, x) where {T<:AbstractTracer} = trace_input(T, x, 1)
trace_input(::Type{T}, ::Number, i) where {T<:AbstractTracer} = tracer(T, i)
Expand All @@ -46,42 +18,100 @@ end

## Construct sparsity pattern matrix
"""
pattern(f, ConnectivityTracer{S}, x) where {S<:AbstractSet{<:Integer}}
connectivity_pattern(f, x)
connectivity_pattern(f, x, T)
Enumerates inputs `x` and primal outputs `y = f(x)` and returns sparse matrix `C` of size `(m, n)`
where `C[i, j]` is true if the compute graph connects the `i`-th entry in `y` to the `j`-th entry in `x`.
pattern(f, JacobianTracer{S}, x) where {S<:AbstractSet{<:Integer}}
The type of index set `T<:AbstractSet{<:Integer}` can be specified as an optional argument and defaults to `BitSet`.
Computes the sparsity pattern of the Jacobian of `y = f(x)`.
## Example
pattern(f, HessianTracer{S}, x) where {S<:AbstractSet{<:Integer}}
```jldoctest
julia> x = rand(3);
Computes the sparsity pattern of the Hessian of a scalar function `y = f(x)`.
julia> f(x) = [x[1]^2, 2 * x[1] * x[2]^2, sign(x[3])];
julia> connectivity_pattern(f, x)
3×3 SparseArrays.SparseMatrixCSC{Bool, UInt64} with 4 stored entries:
1 ⋅ ⋅
1 1 ⋅
⋅ ⋅ 1
```
"""
connectivity_pattern(f, x, settype::Type{S}=DEFAULT_SET_TYPE) where {S<:AbstractIndexSet} =
pattern(f, ConnectivityTracer{S}, x)

"""
connectivity_pattern(f!, y, x)
connectivity_pattern(f!, y, x, T)
Enumerates inputs `x` and primal outputs `y` after `f!(y, x)` and returns sparse matrix `C` of size `(m, n)`
where `C[i, j]` is true if the compute graph connects the `i`-th entry in `y` to the `j`-th entry in `x`.
The type of index set `T<:AbstractSet{<:Integer}` can be specified as an optional argument and defaults to `BitSet`.
"""
function connectivity_pattern(
f!, y, x, ::Type{S}=DEFAULT_SET_TYPE
) where {S<:AbstractIndexSet}
return pattern(f!, y, ConnectivityTracer{S}, x)
end

"""
jacobian_pattern(f, x)
jacobian_pattern(f, x, T)
Compute the sparsity pattern of the Jacobian of `y = f(x)`.
## Examples
### First order
The type of index set `T<:AbstractSet{<:Integer}` can be specified as an optional argument and defaults to `BitSet`.
## Example
```jldoctest
julia> x = rand(3);
julia> f(x) = [x[1]^2, 2 * x[1] * x[2]^2, sin(x[3])];
julia> f(x) = [x[1]^2, 2 * x[1] * x[2]^2, sign(x[3])];
julia> pattern(f, ConnectivityTracer{BitSet}, x)
3×3 SparseArrays.SparseMatrixCSC{Bool, UInt64} with 4 stored entries:
julia> jacobian_pattern(f, x)
3×3 SparseArrays.SparseMatrixCSC{Bool, UInt64} with 3 stored entries:
1 ⋅ ⋅
1 1 ⋅
⋅ ⋅ 1
⋅ ⋅
```
"""
function jacobian_pattern(f, x, ::Type{S}=DEFAULT_SET_TYPE) where {S<:AbstractIndexSet}
return pattern(f, JacobianTracer{S}, x)
end

### Second order
"""
jacobian_pattern(f!, y, x)
jacobian_pattern(f!, y, x, T)
Compute the sparsity pattern of the Jacobian of `f!(y, x)`.
The type of index set `T<:AbstractSet{<:Integer}` can be specified as an optional argument and defaults to `BitSet`.
"""
function jacobian_pattern(f!, y, x, ::Type{S}=DEFAULT_SET_TYPE) where {S<:AbstractIndexSet}
return pattern(f!, y, JacobianTracer{S}, x)
end

"""
hessian_pattern(f, x)
hessian_pattern(f, x, T)
Computes the sparsity pattern of the Hessian of a scalar function `y = f(x)`.
The type of index set `T<:AbstractSet{<:Integer}` can be specified as an optional argument and defaults to `BitSet`.
## Example
```jldoctest
julia> x = rand(5);
julia> f(x) = x[1] + x[2]*x[3] + 1/x[4] + 1*x[5];
julia> pattern(f, HessianTracer{BitSet}, x)
julia> hessian_pattern(f, x)
5×5 SparseArrays.SparseMatrixCSC{Bool, UInt64} with 3 stored entries:
⋅ ⋅ ⋅ ⋅ ⋅
⋅ ⋅ 1 ⋅ ⋅
Expand All @@ -91,7 +121,7 @@ julia> pattern(f, HessianTracer{BitSet}, x)
julia> g(x) = f(x) + x[2]^x[5];
julia> pattern(g, HessianTracer{BitSet}, x)
julia> hessian_pattern(g, x)
5×5 SparseArrays.SparseMatrixCSC{Bool, UInt64} with 7 stored entries:
⋅ ⋅ ⋅ ⋅ ⋅
⋅ 1 1 ⋅ 1
Expand All @@ -100,22 +130,16 @@ julia> pattern(g, HessianTracer{BitSet}, x)
⋅ 1 ⋅ ⋅ 1
```
"""
function hessian_pattern(f, x, ::Type{S}=DEFAULT_SET_TYPE) where {S<:AbstractIndexSet}
return pattern(f, HessianTracer{S}, x)
end

function pattern(f, ::Type{T}, x) where {T<:AbstractTracer}
xt = trace_input(T, x)
yt = f(xt)
return _pattern(xt, yt)
end

"""
pattern(f!, y, JacobianTracer{S}, x) where {S<:AbstractSet{<:Integer}}
Computes the sparsity pattern of the Jacobian of `f!(y, x)`.
pattern(f!, y, ConnectivityTracer{S}, x) where {S<:AbstractSet{<:Integer}}
Enumerates inputs `x` and primal outputs `y` after `f!(y, x)` and returns sparse matrix `C` of size `(m, n)`
where `C[i, j]` is true if the compute graph connects the `i`-th entry in `y` to the `j`-th entry in `x`.
"""
function pattern(f!, y, ::Type{T}, x) where {T<:AbstractTracer}
xt = trace_input(T, x)
yt = similar(y, T)
Expand Down
Loading

0 comments on commit 8a9110a

Please sign in to comment.