-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
15 changed files
with
694 additions
and
28 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
style = "blue" | ||
align_assignment = true | ||
align_struct_field = true | ||
align_conditional = true | ||
align_pair_arrow = true | ||
align_matrix = true |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -3,5 +3,9 @@ uuid = "9f842d2f-2579-4b1d-911e-f412cf18a3f5" | |
authors = ["Adrian Hill <[email protected]>"] | ||
version = "1.0.0-DEV" | ||
|
||
[deps] | ||
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" | ||
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" | ||
|
||
[compat] | ||
julia = "1.6" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,23 +1,23 @@ | ||
using SparseConnectivityTracer | ||
using Documenter | ||
|
||
DocMeta.setdocmeta!(SparseConnectivityTracer, :DocTestSetup, :(using SparseConnectivityTracer); recursive=true) | ||
DocMeta.setdocmeta!( | ||
SparseConnectivityTracer, | ||
:DocTestSetup, | ||
:(using SparseConnectivityTracer); | ||
recursive=true, | ||
) | ||
|
||
makedocs(; | ||
modules=[SparseConnectivityTracer], | ||
authors="Adrian Hill <[email protected]>", | ||
sitename="SparseConnectivityTracer.jl", | ||
format=Documenter.HTML(; | ||
canonical="https://adrhill.github.io/SparseConnectivityTracer.jl", | ||
edit_link="main", | ||
assets=String[], | ||
canonical = "https://adrhill.github.io/SparseConnectivityTracer.jl", | ||
edit_link = "main", | ||
assets = String[], | ||
), | ||
pages=[ | ||
"Home" => "index.md", | ||
], | ||
pages=["Home" => "index.md"], | ||
) | ||
|
||
deploydocs(; | ||
repo="github.com/adrhill/SparseConnectivityTracer.jl", | ||
devbranch="main", | ||
) | ||
deploydocs(; repo="github.com/adrhill/SparseConnectivityTracer.jl", devbranch="main") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,13 @@ | ||
module SparseConnectivityTracer | ||
import Random: rand, AbstractRNG, SamplerType | ||
import SparseArrays: sparse | ||
|
||
# Write your package code here. | ||
include("tracer.jl") | ||
include("conversion.jl") | ||
include("operators.jl") | ||
include("connectivity.jl") | ||
|
||
end | ||
export Tracer, tracer, inputs, sortedinputs | ||
export connectivity | ||
|
||
end # module |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,63 @@ | ||
## Enumerate inputs | ||
trace_input(x) = trace_input(x, 1) | ||
trace_input(::Number, i) = tracer(i) | ||
function trace_input(x::AbstractArray, i) | ||
indices = (i - 1) .+ reshape(1:length(x), size(x)) | ||
return tracer.(indices) | ||
end | ||
|
||
## Construct connectivity matrix | ||
""" | ||
connectivity(f, x) | ||
Enumerates inputs `x` and primal outputs `y=f(x)` and returns sparse connectivity matrix `C` of size `(m, n)` | ||
where `C[i, j]` is true if the compute graph connects the `i`-th entry in `y` to the `j`-th entry in `x`. | ||
""" | ||
function connectivity(f::Function, x) | ||
xt = trace_input(x) | ||
yt = f(xt) | ||
return _connectivity(xt, yt) | ||
end | ||
|
||
_connectivity(xt::Tracer, yt::Number) = _connectivity([xt], [yt]) | ||
_connectivity(xt::Tracer, yt::AbstractArray{Number}) = _connectivity([xt], yt) | ||
_connectivity(xt::AbstractArray{Tracer}, yt::Number) = _connectivity(xt, [yt]) | ||
function _connectivity(xt::AbstractArray{Tracer}, yt::AbstractArray{<:Number}) | ||
return connectivity_sparsematrixcsc(xt, yt) | ||
end | ||
|
||
function connectivity_sparsematrixcsc( | ||
xt::AbstractArray{Tracer}, yt::AbstractArray{<:Number} | ||
) | ||
# Construct connectivity matrix of size (ouput_dim, input_dim) | ||
n, m = length(xt), length(yt) | ||
I = UInt64[] | ||
J = UInt64[] | ||
V = Bool[] | ||
for (i, y) in enumerate(yt) | ||
if y isa Tracer | ||
for j in inputs(y) | ||
push!(I, i) | ||
push!(J, j) | ||
push!(V, true) | ||
end | ||
end | ||
end | ||
return sparse(I, J, V, m, n) | ||
end | ||
|
||
function connectivity_bitmatrix(xt::AbstractArray{Tracer}, yt::AbstractArray{<:Number}) | ||
# Construct connectivity matrix of size (ouput_dim, input_dim) | ||
n, m = length(xt), length(yt) | ||
C = BitArray(undef, m, n) | ||
for i in axes(C, 1) | ||
if yt[i] isa Tracer | ||
for j in axes(C, 2) | ||
C[i, j] = j ∈ yt[i].inputs | ||
end | ||
else | ||
C[i, :] .= 0 | ||
end | ||
end | ||
return C | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
## Type conversions | ||
Base.promote_rule(::Type{Tracer}, ::Type{N}) where {N<:Number} = Tracer | ||
Base.promote_rule(::Type{N}, ::Type{Tracer}) where {N<:Number} = Tracer | ||
|
||
Base.convert(::Type{Tracer}, x::Number) = tracer() | ||
Base.convert(::Type{Tracer}, t::Tracer) = t | ||
Base.convert(::Type{<:Number}, t::Tracer) = t | ||
|
||
## Array constructors | ||
Base.zero(::Tracer) = tracer() | ||
Base.zero(::Type{Tracer}) = tracer() | ||
Base.one(::Tracer) = tracer() | ||
Base.one(::Type{Tracer}) = tracer() | ||
|
||
Base.similar(a::Array{Tracer,1}) = zeros(Tracer, size(a, 1)) | ||
Base.similar(a::Array{Tracer,2}) = zeros(Tracer, size(a, 1), size(a, 2)) | ||
Base.similar(a::Array{T,1}, ::Type{Tracer}) where {T} = zeros(Tracer, size(a, 1)) | ||
Base.similar(a::Array{T,2}, ::Type{Tracer}) where {T} = zeros(Tracer, size(a, 1), size(a, 2)) | ||
Base.similar(::Array{Tracer}, m::Int) = zeros(Tracer, m) | ||
Base.similar(::Array, ::Type{Tracer}, dims::Dims{N}) where {N} = zeros(Tracer, dims) | ||
Base.similar(::Array{Tracer}, dims::Dims{N}) where {N} = zeros(Tracer, dims) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,53 @@ | ||
## Extent Base operators | ||
for fn in (:+, :-, :*, :/) | ||
@eval Base.$fn(a::Tracer, b::Tracer) = tracer(a, b) | ||
for T in (:Number,) | ||
@eval Base.$fn(t::Tracer, ::$T) = t | ||
@eval Base.$fn(::$T, t::Tracer) = t | ||
end | ||
end | ||
|
||
Base.:^(a::Tracer, b::Tracer) = tracer(a, b) | ||
for T in (:Number, :Integer, :Rational) | ||
@eval Base.:^(t::Tracer, ::$T) = t | ||
@eval Base.:^(::$T, t::Tracer) = t | ||
end | ||
Base.:^(t::Tracer, ::Irrational{:ℯ}) = t | ||
Base.:^(::Irrational{:ℯ}, t::Tracer) = t | ||
|
||
## Two-argument functions | ||
for fn in (:div, :fld, :cld) | ||
@eval Base.$fn(a::Tracer, b::Tracer) = tracer(a, b) | ||
@eval Base.$fn(t::Tracer, ::Number) = t | ||
@eval Base.$fn(::Number, t::Tracer) = t | ||
end | ||
|
||
## Single-argument functions | ||
|
||
#! format: off | ||
scalar_operations = ( | ||
:exp2, :deg2rad, :rad2deg, | ||
:cos, :cosd, :cosh, :cospi, :cosc, | ||
:sin, :sind, :sinh, :sinpi, :sinc, | ||
:tan, :tand, :tanh, | ||
:csc, :cscd, :csch, | ||
:sec, :secd, :sech, | ||
:cot, :cotd, :coth, | ||
:acos, :acosd, :acosh, | ||
:asin, :asind, :asinh, | ||
:atan, :atand, :atanh, | ||
:asec, :asech, | ||
:acsc, :acsch, | ||
:acot, :acoth, | ||
:exp, :expm1, :exp10, | ||
:frexp, :ldexp, | ||
:abs, :abs2, :sqrt | ||
) | ||
#! format: on | ||
|
||
for fn in scalar_operations | ||
@eval Base.$fn(t::Tracer) = t | ||
end | ||
|
||
## Random numbers | ||
rand(::AbstractRNG, ::SamplerType{Tracer}) = tracer() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,97 @@ | ||
""" | ||
Tracer(indexset) <: Number | ||
Number type keeping track of input indices of previous computations. | ||
See also the convenience constructor [`tracer`](@ref). | ||
## Examples | ||
```julia-repl | ||
julia> x = tracer(1, 2, 3) | ||
Tracer(1, 2, 3) | ||
julia> sin(x) | ||
Tracer(1, 2, 3) | ||
julia> 2 * x^3 | ||
Tracer(1, 2, 3) | ||
julia> 0 * x # Note: Tracer is strictly operator overloading... | ||
Tracer(1, 2, 3) | ||
julia> zero(x) # ...this can be overloaded | ||
Tracer() | ||
julia> y = tracer(3, 5) | ||
Tracer(3, 5) | ||
julia> x + y | ||
Tracer(1, 2, 3, 5) | ||
julia> x ^ y | ||
Tracer(1, 2, 3, 5) | ||
julia> M = rand(Tracer, 3, 2) | ||
3×2 Matrix{Tracer}: | ||
Tracer() Tracer() | ||
Tracer() Tracer() | ||
Tracer() Tracer() | ||
julia> similar(M) | ||
3×2 Matrix{Tracer}: | ||
Tracer() Tracer() | ||
Tracer() Tracer() | ||
Tracer() Tracer() | ||
julia> M * [x, y] | ||
3-element Vector{Tracer}: | ||
Tracer(1, 2, 3, 5) | ||
Tracer(1, 2, 3, 5) | ||
Tracer(1, 2, 3, 5) | ||
``` | ||
""" | ||
struct Tracer <: Number | ||
inputs::Set{UInt64} # indices of connected, enumerated inputs | ||
end | ||
|
||
# We have to be careful when defining constructors: | ||
# Generic code expecting "regular" numbers `x` will sometimes convert them | ||
# by calling `T(x)` (instead of `convert(T, x)`), where `T` can be `Tracer`. | ||
# When this happens, we create a new empty tracer with no input connectivity. | ||
Tracer(::Number) = tracer() | ||
Tracer(t::Tracer) = t | ||
# We therefore exclusively use the lower-case `tracer` for convenience constructors | ||
|
||
""" | ||
tracer(index) | ||
tracer(indices) | ||
Convenience constructor for [`Tracer`](@ref) from input indices. | ||
""" | ||
tracer() = Tracer(Set{UInt64}()) | ||
tracer(a::Tracer, b::Tracer) = Tracer(union(a.inputs, b.inputs)) | ||
|
||
tracer(index::Integer) = Tracer(Set{UInt64}(index)) | ||
tracer(inds::NTuple{N,<:Integer}) where {N} = Tracer(Set{UInt64}(inds)) | ||
tracer(inds...) = tracer(inds) | ||
|
||
# Utilities for accessing input indices | ||
""" | ||
inputs(tracer) | ||
Return raw `UInt64` input indices of a [`Tracer`](@ref). | ||
""" | ||
inputs(t::Tracer) = collect(keys(t.inputs.dict)) | ||
|
||
""" | ||
sortedinputs(tracer) | ||
sortedinputs([T=Int], tracer) | ||
Return sorted input indices of a [`Tracer`](@ref). | ||
""" | ||
sortedinputs(t::Tracer) = sortedinputs(Int, t) | ||
sortedinputs(T::Type, t::Tracer) = convert.(T, sort!(inputs(t))) | ||
|
||
function Base.show(io::IO, t::Tracer) | ||
return Base.show_delim_array(io, sortedinputs(Int, t), "Tracer(", ',', ')', true) | ||
end |
Oops, something went wrong.