Skip to content

Commit

Permalink
Add JacobianTracer (#22)
Browse files Browse the repository at this point in the history
Breaking changes: 

* renamed `Tracer` to `ConnectivityTracer`

* renamed `connectivity(f, x)` to `pattern(f, T, x)`, where `T` is a tracer type

* renamed `connectivity(f!, y, x)` to `pattern(f!, y, T, x)`, where `T` is a tracer type

* `trace_input(x)` has been replaced by `trace_input(T, x)`, where `T` is a tracer type
  • Loading branch information
adrhill authored Apr 23, 2024
1 parent 3830ca0 commit 475e576
Show file tree
Hide file tree
Showing 19 changed files with 421 additions and 328 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "SparseConnectivityTracer"
uuid = "9f842d2f-2579-4b1d-911e-f412cf18a3f5"
authors = ["Adrian Hill <[email protected]>"]
version = "0.1.0"
version = "0.2.0-DEV"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand Down
54 changes: 24 additions & 30 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,7 @@

Fast sparsity detection via operator-overloading.

Will soon include Jacobian sparsity detection ([#19](https://github.com/adrhill/SparseConnectivityTracer.jl/issues/19))
and Hessian sparsity detection ([#20](https://github.com/adrhill/SparseConnectivityTracer.jl/issues/20)).
Will soon include Hessian sparsity detection ([#20](https://github.com/adrhill/SparseConnectivityTracer.jl/issues/20)).

## Installation
To install this package, open the Julia REPL and run
Expand All @@ -26,7 +25,7 @@ julia> x = rand(3);
julia> f(x) = [x[1]^2, 2 * x[1] * x[2]^2, sin(x[3])];
julia> connectivity(f, x)
julia> pattern(f, JacobianTracer, x)
3×3 SparseArrays.SparseMatrixCSC{Bool, UInt64} with 4 stored entries:
1 ⋅ ⋅
1 1 ⋅
Expand All @@ -41,37 +40,32 @@ julia> x = rand(28, 28, 3, 1);
julia> layer = Conv((3, 3), 3 => 8);
julia> connectivity(layer, x)
julia> pattern(layer, JacobianTracer, x)
5408×2352 SparseArrays.SparseMatrixCSC{Bool, UInt64} with 146016 stored entries:
⎡⠙⢶⣄⠀⠀⠀⠈⠻⣦⡀⠀⠀⠀⠘⢷⣄⠀⠀⠀⠀⠀⎤
⎢⠀⠀⠙⢷⣄⠀⠀⠀⠈⠻⣦⡀⠀⠀⠀⠙⢷⣄⠀⠀⠀⎥
⎢⠀⠀⠀⠀⠙⢷⣄⠀⠀⠀⠈⠻⣦⡀⠀⠀⠀⠉⠳⣦⡀⎥
⎢⠙⢷⣄⠀⠀⠀⠉⠻⣦⡀⠀⠀⠈⠙⢷⣄⠀⠀⠀⠈⠁⎥
⎢⠀⠀⠙⢷⣄⠀⠀⠀⠈⠻⣦⡀⠀⠀⠀⠉⠳⣦⡀⠀⠀⎥
⎢⠀⠀⠀⠀⠙⢷⣄⠀⠀⠀⠈⠻⣦⡀⠀⠀⠀⠈⠻⣦⡀⎥
⎢⠙⢷⣄⠀⠀⠀⠉⠻⣦⡀⠀⠀⠈⠙⠷⣤⡀⠀⠀⠈⠁⎥
⎢⠀⠀⠙⢷⣄⠀⠀⠀⠈⠻⣦⡀⠀⠀⠀⠈⠻⣦⡀⠀⠀⎥
⎢⠀⠀⠀⠀⠙⢷⣄⠀⠀⠀⠈⠻⢦⣄⠀⠀⠀⠈⠻⣦⡀⎥
⎢⠙⢷⣄⠀⠀⠀⠉⠻⣦⡀⠀⠀⠀⠉⠻⣦⡀⠀⠀⠈⠁⎥
⎢⠀⠀⠙⢷⣄⠀⠀⠀⠈⠻⢦⣀⠀⠀⠀⠈⠻⣦⡀⠀⠀⎥
⎢⠀⠀⠀⠀⠙⢷⣄⠀⠀⠀⠀⠙⢷⣄⠀⠀⠀⠈⠻⣦⡀⎥
⎢⠙⢷⣄⠀⠀⠀⠈⠻⣦⡀⠀⠀⠀⠈⠻⣦⡀⠀⠀⠀⠀⎥
⎢⠀⠀⠙⢷⣄⠀⠀⠀⠈⠙⢶⣄⠀⠀⠀⠈⠻⣦⡀⠀⠀⎥
⎢⣀⠀⠀⠀⠙⢷⣄⡀⠀⠀⠀⠙⢷⣄⡀⠀⠀⠈⠻⣦⡀⎥
⎢⠙⢷⣄⠀⠀⠀⠈⠛⢶⣄⠀⠀⠀⠈⠻⣦⡀⠀⠀⠀⠀⎥
⎢⠀⠀⠙⢷⣄⠀⠀⠀⠀⠙⢷⣄⠀⠀⠀⠈⠻⣦⡀⠀⠀⎥
⎢⣀⠀⠀⠀⠙⠳⣦⣀⠀⠀⠀⠙⢷⣄⡀⠀⠀⠈⠻⣦⡀⎥
⎢⠙⢷⣄⠀⠀⠀⠀⠙⢷⣄⠀⠀⠀⠈⠻⣦⡀⠀⠀⠀⠀⎥
⎢⠀⠀⠙⠷⣄⡀⠀⠀⠀⠙⢷⣄⠀⠀⠀⠈⠻⣦⡀⠀⠀⎥
⎢⣀⠀⠀⠀⠈⠻⣦⣀⠀⠀⠀⠙⢷⣄⡀⠀⠀⠈⠻⣦⡀⎥
⎢⠙⠷⣄⡀⠀⠀⠀⠙⢷⣄⠀⠀⠀⠈⠻⣦⡀⠀⠀⠀⠀⎥
⎢⠀⠀⠈⠻⣦⡀⠀⠀⠀⠙⢷⣄⠀⠀⠀⠈⠻⣦⡀⠀⠀⎥
⎣⠀⠀⠀⠀⠈⠻⣦⠀⠀⠀⠀⠙⢷⣄⠀⠀⠀⠈⠻⢦⡀⎦
⎡⠙⢦⡀⠀⠀⠘⢷⣄⠀⠀⠈⠻⣦⡀⠀⠀⠀⎤
⎢⠀⠀⠙⢷⣄⠀⠀⠙⠷⣄⠀⠀⠈⠻⣦⡀⠀⎥
⎢⢶⣄⠀⠀⠙⠳⣦⡀⠀⠈⠳⢦⡀⠀⠈⠛⠂⎥
⎢⠀⠙⢷⣄⠀⠀⠈⠻⣦⡀⠀⠀⠙⢦⣄⠀⠀⎥
⎢⣀⡀⠀⠉⠳⣄⡀⠀⠈⠻⣦⣀⠀⠀⠙⢷⡄⎥
⎢⠈⠻⣦⡀⠀⠈⠛⢦⡀⠀⠀⠙⢷⣄⠀⠀⠀⎥
⎢⠀⠀⠈⠻⣦⡀⠀⠀⠙⢷⣄⠀⠀⠙⠷⣄⠀⎥
⎢⠻⣦⡀⠀⠈⠙⢷⣄⠀⠀⠉⠻⣦⡀⠀⠈⠁⎥
⎢⠀⠀⠙⢦⣀⠀⠀⠙⢷⣄⠀⠀⠈⠻⣦⡀⠀⎥
⎢⢤⣄⠀⠀⠙⠳⣄⡀⠀⠉⠳⣤⡀⠀⠈⠛⠂⎥
⎢⠀⠙⢷⣄⠀⠀⠈⠻⣦⡀⠀⠈⠙⢦⡀⠀⠀⎥
⎢⣀⠀⠀⠙⢷⣄⡀⠀⠈⠻⣦⣀⠀⠀⠙⢷⡄⎥
⎢⠈⠳⣦⡀⠀⠈⠻⣦⡀⠀⠀⠙⢷⣄⠀⠀⠀⎥
⎢⠀⠀⠈⠻⣦⡀⠀⠀⠙⢦⣄⠀⠀⠙⢷⣄⠀⎥
⎢⠻⣦⡀⠀⠈⠙⢷⣄⠀⠀⠉⠳⣄⡀⠀⠉⠁⎥
⎢⠀⠈⠛⢦⡀⠀⠀⠙⢷⣄⠀⠀⠈⠻⣦⡀⠀⎥
⎢⢤⣄⠀⠀⠙⠶⣄⠀⠀⠙⠷⣤⡀⠀⠈⠻⠆⎥
⎢⠀⠙⢷⣄⠀⠀⠈⠳⣦⡀⠀⠈⠻⣦⡀⠀⠀⎥
⎣⠀⠀⠀⠙⢷⣄⠀⠀⠈⠻⣦⠀⠀⠀⠙⢦⡀⎦
```

SparseConnectivityTracer enumerates inputs `x` and primal outputs `y=f(x)` and returns a sparse connectivity matrix `C` of size $m \times n$, where `C[i, j]` is `true` if the compute graph connects the $j$-th entry in `x` to the $i$-th entry in `y`.
SparseConnectivityTracer enumerates inputs `x` and primal outputs `y = f(x)` and returns a sparse matrix `C` of size $m \times n$, where `C[i, j]` is `true` if the compute graph connects the $j$-th entry in `x` to the $i$-th entry in `y`.

For more detailled examples, take a look at the [API reference](https://adrianhill.de/SparseConnectivityTracer.jl/dev/api).
For more detailled examples, take a look at the [documentation](https://adrianhill.de/SparseConnectivityTracer.jl/dev).

## Related packages
* [SparseDiffTools.jl](https://github.com/JuliaDiff/SparseDiffTools.jl): automatic sparsity detection via Symbolics.jl and Cassette.jl
Expand Down
15 changes: 11 additions & 4 deletions docs/src/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,19 +10,26 @@ CollapsedDocStrings = true

## Interface
```@docs
connectivity
pattern
TracerSparsityDetector
```

## Internals
SparseConnectivityTracer works by pushing a `Number` type called [`Tracer`](@ref) through generic functions:
SparseConnectivityTracer works by pushing `Number` types called tracers through generic functions.
Currently, two tracer types are provided:

```@docs
JacobianTracer
ConnectivityTracer
```

Utilities to create tracers:
```@docs
Tracer
tracer
trace_input
```

The following utilities can be used to extract input indices from [`Tracer`](@ref)s:
Utility to extract input indices from tracers:
```@docs
inputs
```
13 changes: 8 additions & 5 deletions src/SparseConnectivityTracer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,20 @@ using ADTypes: ADTypes
import Random: rand, AbstractRNG, SamplerType
import SparseArrays: sparse

include("tracer.jl")
abstract type AbstractTracer <: Number end

include("tracers.jl")
include("conversion.jl")
include("operators.jl")
include("overload_tracer.jl")
include("connectivity.jl")
include("overload_connectivity.jl")
include("overload_jacobian.jl")
include("pattern.jl")
include("adtypes.jl")

export Tracer
export JacobianTracer, ConnectivityTracer
export tracer, trace_input
export inputs
export connectivity
export pattern
export TracerSparsityDetector

end # module
5 changes: 3 additions & 2 deletions src/adtypes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,14 @@ julia> ADTypes.jacobian_sparsity(diff, rand(4), TracerSparsityDetector())
struct TracerSparsityDetector <: ADTypes.AbstractSparsityDetector end

function ADTypes.jacobian_sparsity(f, x, ::TracerSparsityDetector)
return connectivity(f, x)
return pattern(f, JacobianTracer, x)
end

function ADTypes.jacobian_sparsity(f!, y, x, ::TracerSparsityDetector)
return connectivity(f!, y, x)
return pattern(f!, y, JacobianTracer, x)
end

function ADTypes.hessian_sparsity(f, x, ::TracerSparsityDetector)
# TODO: return pattern(f, HessianTracer, x)
return error("Hessian sparsity is not yet implemented for `TracerSparsityDetector`.")
end
98 changes: 0 additions & 98 deletions src/connectivity.jl

This file was deleted.

38 changes: 20 additions & 18 deletions src/conversion.jl
Original file line number Diff line number Diff line change
@@ -1,23 +1,25 @@
## Type conversions
Base.promote_rule(::Type{Tracer}, ::Type{N}) where {N<:Number} = Tracer
Base.promote_rule(::Type{N}, ::Type{Tracer}) where {N<:Number} = Tracer
for T in (:JacobianTracer, :ConnectivityTracer)
@eval Base.promote_rule(::Type{$T}, ::Type{N}) where {N<:Number} = $T
@eval Base.promote_rule(::Type{N}, ::Type{$T}) where {N<:Number} = $T

Base.big(::Type{Tracer}) = Tracer
Base.widen(::Type{Tracer}) = Tracer
Base.widen(t::Tracer) = t
@eval Base.big(::Type{$T}) = $T
@eval Base.widen(::Type{$T}) = $T
@eval Base.widen(t::$T) = t

Base.convert(::Type{Tracer}, x::Number) = EMPTY_TRACER
Base.convert(::Type{Tracer}, t::Tracer) = t
Base.convert(::Type{<:Number}, t::Tracer) = t
@eval Base.convert(::Type{$T}, x::Number) = empty($T)
@eval Base.convert(::Type{$T}, t::$T) = t
@eval Base.convert(::Type{<:Number}, t::$T) = t

## Array constructors
Base.zero(::Type{Tracer}) = EMPTY_TRACER
Base.one(::Type{Tracer}) = EMPTY_TRACER
## Array constructors
@eval Base.zero(::Type{$T}) = empty($T)
@eval Base.one(::Type{$T}) = empty($T)

Base.similar(a::Array{Tracer,1}) = zeros(Tracer, size(a, 1))
Base.similar(a::Array{Tracer,2}) = zeros(Tracer, size(a, 1), size(a, 2))
Base.similar(a::Array{T,1}, ::Type{Tracer}) where {T} = zeros(Tracer, size(a, 1))
Base.similar(a::Array{T,2}, ::Type{Tracer}) where {T} = zeros(Tracer, size(a, 1), size(a, 2))
Base.similar(::Array{Tracer}, m::Int) = zeros(Tracer, m)
Base.similar(::Array, ::Type{Tracer}, dims::Dims{N}) where {N} = zeros(Tracer, dims)
Base.similar(::Array{Tracer}, dims::Dims{N}) where {N} = zeros(Tracer, dims)
@eval Base.similar(a::Array{$T,1}) = zeros($T, size(a, 1))
@eval Base.similar(a::Array{$T,2}) = zeros($T, size(a, 1), size(a, 2))
@eval Base.similar(a::Array{A,1}, ::Type{$T}) where {A} = zeros($T, size(a, 1))
@eval Base.similar(a::Array{A,2}, ::Type{$T}) where {A} = zeros($T, size(a, 1), size(a, 2))
@eval Base.similar(::Array{$T}, m::Int) = zeros($T, m)
@eval Base.similar(::Array, ::Type{$T}, dims::Dims{N}) where {N} = zeros($T, dims)
@eval Base.similar(::Array{$T}, dims::Dims{N}) where {N} = zeros($T, dims)
end
32 changes: 32 additions & 0 deletions src/overload_connectivity.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
for fn in union(ops_1_to_1_s, ops_1_to_1_f, ops_1_to_1_z)
@eval Base.$fn(t::ConnectivityTracer) = t
end

for fn in ops_1_to_1_const
@eval Base.$fn(::ConnectivityTracer) = EMPTY_CONNECTIVITY_TRACER
end

for fn in ops_1_to_2
@eval Base.$fn(t::ConnectivityTracer) = (t, t)
end

for fn in ops_2_to_1
@eval Base.$fn(a::ConnectivityTracer, b::ConnectivityTracer) = uniontracer(a, b)
@eval Base.$fn(t::ConnectivityTracer, ::Number) = t
@eval Base.$fn(::Number, t::ConnectivityTracer) = t
end

# Extra types required for exponent
Base.:^(a::ConnectivityTracer, b::ConnectivityTracer) = uniontracer(a, b)
for T in (:Real, :Integer, :Rational)
@eval Base.:^(t::ConnectivityTracer, ::$T) = t
@eval Base.:^(::$T, t::ConnectivityTracer) = t
end
Base.:^(t::ConnectivityTracer, ::Irrational{:ℯ}) = t
Base.:^(::Irrational{:ℯ}, t::ConnectivityTracer) = t

## Rounding
Base.round(t::ConnectivityTracer, ::RoundingMode; kwargs...) = t

## Random numbers
rand(::AbstractRNG, ::SamplerType{ConnectivityTracer}) = EMPTY_CONNECTIVITY_TRACER
Loading

0 comments on commit 475e576

Please sign in to comment.