Skip to content

Commit

Permalink
Initial functionality (#2)
Browse files Browse the repository at this point in the history
  • Loading branch information
adrhill authored Apr 1, 2024
1 parent bd6b9ec commit 63560b0
Show file tree
Hide file tree
Showing 15 changed files with 694 additions and 28 deletions.
6 changes: 6 additions & 0 deletions .JuliaFormatter.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
style = "blue"
align_assignment = true
align_struct_field = true
align_conditional = true
align_pair_arrow = true
align_matrix = true
8 changes: 6 additions & 2 deletions .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ jobs:
fail-fast: false
matrix:
version:
- '1.6'
- '1'
- 'nightly'
os:
- ubuntu-latest
Expand All @@ -37,10 +37,14 @@ jobs:
arch: ${{ matrix.arch }}
- uses: julia-actions/cache@v1
- uses: julia-actions/julia-buildpkg@v1
continue-on-error: ${{ matrix.version == 'nightly' }}
- uses: julia-actions/julia-runtest@v1
continue-on-error: ${{ matrix.version == 'nightly' }}
- uses: julia-actions/julia-processcoverage@v1
- uses: codecov/codecov-action@v3
- uses: codecov/codecov-action@v4
with:
token: ${{ secrets.CODECOV_TOKEN }}
slug: adrhill/SparseConnectivityTracer.jl
files: lcov.info
docs:
name: Documentation
Expand Down
4 changes: 4 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,5 +3,9 @@ uuid = "9f842d2f-2579-4b1d-911e-f412cf18a3f5"
authors = ["Adrian Hill <[email protected]>"]
version = "1.0.0-DEV"

[deps]
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"

[compat]
julia = "1.6"
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# SparseConnectivityTracer
# SparseConnectivityTracer.jl

[![Stable](https://img.shields.io/badge/docs-stable-blue.svg)](https://adrhill.github.io/SparseConnectivityTracer.jl/stable/)
<!-- [![Stable](https://img.shields.io/badge/docs-stable-blue.svg)](https://adrhill.github.io/SparseConnectivityTracer.jl/stable/) -->
[![Dev](https://img.shields.io/badge/docs-dev-blue.svg)](https://adrhill.github.io/SparseConnectivityTracer.jl/dev/)
[![Build Status](https://github.com/adrhill/SparseConnectivityTracer.jl/actions/workflows/CI.yml/badge.svg?branch=main)](https://github.com/adrhill/SparseConnectivityTracer.jl/actions/workflows/CI.yml?query=branch%3Amain)
[![Coverage](https://codecov.io/gh/adrhill/SparseConnectivityTracer.jl/branch/main/graph/badge.svg)](https://codecov.io/gh/adrhill/SparseConnectivityTracer.jl)
Expand Down
22 changes: 11 additions & 11 deletions docs/make.jl
Original file line number Diff line number Diff line change
@@ -1,23 +1,23 @@
using SparseConnectivityTracer
using Documenter

DocMeta.setdocmeta!(SparseConnectivityTracer, :DocTestSetup, :(using SparseConnectivityTracer); recursive=true)
DocMeta.setdocmeta!(
SparseConnectivityTracer,
:DocTestSetup,
:(using SparseConnectivityTracer);
recursive=true,
)

makedocs(;
modules=[SparseConnectivityTracer],
authors="Adrian Hill <[email protected]>",
sitename="SparseConnectivityTracer.jl",
format=Documenter.HTML(;
canonical="https://adrhill.github.io/SparseConnectivityTracer.jl",
edit_link="main",
assets=String[],
canonical = "https://adrhill.github.io/SparseConnectivityTracer.jl",
edit_link = "main",
assets = String[],
),
pages=[
"Home" => "index.md",
],
pages=["Home" => "index.md"],
)

deploydocs(;
repo="github.com/adrhill/SparseConnectivityTracer.jl",
devbranch="main",
)
deploydocs(; repo="github.com/adrhill/SparseConnectivityTracer.jl", devbranch="main")
18 changes: 16 additions & 2 deletions docs/src/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,20 @@ Documentation for [SparseConnectivityTracer](https://github.com/adrhill/SparseCo
```@index
```

```@autodocs
Modules = [SparseConnectivityTracer]
## API reference
SparseConnectivityTracer works by pushing a `Number` type called [`Tracer`](@ref) through generic functions:
```@docs
Tracer
tracer
```

The resulting connectivity matrix can be extracted using [`connectivity`](@ref):
```@docs
connectivity
```

or manually from individual [`Tracer`](@ref) outputs:
```@docs
inputs
sortedinputs
```
12 changes: 10 additions & 2 deletions src/SparseConnectivityTracer.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,13 @@
module SparseConnectivityTracer
import Random: rand, AbstractRNG, SamplerType
import SparseArrays: sparse

# Write your package code here.
include("tracer.jl")
include("conversion.jl")
include("operators.jl")
include("connectivity.jl")

end
export Tracer, tracer, inputs, sortedinputs
export connectivity

end # module
63 changes: 63 additions & 0 deletions src/connectivity.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
## Enumerate inputs
trace_input(x) = trace_input(x, 1)
trace_input(::Number, i) = tracer(i)
function trace_input(x::AbstractArray, i)
indices = (i - 1) .+ reshape(1:length(x), size(x))
return tracer.(indices)
end

## Construct connectivity matrix
"""
connectivity(f, x)
Enumerates inputs `x` and primal outputs `y=f(x)` and returns sparse connectivity matrix `C` of size `(m, n)`
where `C[i, j]` is true if the compute graph connects the `i`-th entry in `y` to the `j`-th entry in `x`.
"""
function connectivity(f::Function, x)
xt = trace_input(x)
yt = f(xt)
return _connectivity(xt, yt)
end

_connectivity(xt::Tracer, yt::Number) = _connectivity([xt], [yt])
_connectivity(xt::Tracer, yt::AbstractArray{Number}) = _connectivity([xt], yt)
_connectivity(xt::AbstractArray{Tracer}, yt::Number) = _connectivity(xt, [yt])
function _connectivity(xt::AbstractArray{Tracer}, yt::AbstractArray{<:Number})
return connectivity_sparsematrixcsc(xt, yt)
end

function connectivity_sparsematrixcsc(
xt::AbstractArray{Tracer}, yt::AbstractArray{<:Number}
)
# Construct connectivity matrix of size (ouput_dim, input_dim)
n, m = length(xt), length(yt)
I = UInt64[]
J = UInt64[]
V = Bool[]
for (i, y) in enumerate(yt)
if y isa Tracer
for j in inputs(y)
push!(I, i)
push!(J, j)
push!(V, true)
end
end
end
return sparse(I, J, V, m, n)
end

function connectivity_bitmatrix(xt::AbstractArray{Tracer}, yt::AbstractArray{<:Number})
# Construct connectivity matrix of size (ouput_dim, input_dim)
n, m = length(xt), length(yt)
C = BitArray(undef, m, n)
for i in axes(C, 1)
if yt[i] isa Tracer
for j in axes(C, 2)
C[i, j] = j yt[i].inputs
end
else
C[i, :] .= 0
end
end
return C
end
21 changes: 21 additions & 0 deletions src/conversion.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
## Type conversions
Base.promote_rule(::Type{Tracer}, ::Type{N}) where {N<:Number} = Tracer
Base.promote_rule(::Type{N}, ::Type{Tracer}) where {N<:Number} = Tracer

Base.convert(::Type{Tracer}, x::Number) = tracer()
Base.convert(::Type{Tracer}, t::Tracer) = t
Base.convert(::Type{<:Number}, t::Tracer) = t

## Array constructors
Base.zero(::Tracer) = tracer()
Base.zero(::Type{Tracer}) = tracer()
Base.one(::Tracer) = tracer()
Base.one(::Type{Tracer}) = tracer()

Base.similar(a::Array{Tracer,1}) = zeros(Tracer, size(a, 1))
Base.similar(a::Array{Tracer,2}) = zeros(Tracer, size(a, 1), size(a, 2))
Base.similar(a::Array{T,1}, ::Type{Tracer}) where {T} = zeros(Tracer, size(a, 1))
Base.similar(a::Array{T,2}, ::Type{Tracer}) where {T} = zeros(Tracer, size(a, 1), size(a, 2))
Base.similar(::Array{Tracer}, m::Int) = zeros(Tracer, m)
Base.similar(::Array, ::Type{Tracer}, dims::Dims{N}) where {N} = zeros(Tracer, dims)
Base.similar(::Array{Tracer}, dims::Dims{N}) where {N} = zeros(Tracer, dims)
53 changes: 53 additions & 0 deletions src/operators.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
## Extent Base operators
for fn in (:+, :-, :*, :/)
@eval Base.$fn(a::Tracer, b::Tracer) = tracer(a, b)
for T in (:Number,)
@eval Base.$fn(t::Tracer, ::$T) = t
@eval Base.$fn(::$T, t::Tracer) = t
end
end

Base.:^(a::Tracer, b::Tracer) = tracer(a, b)
for T in (:Number, :Integer, :Rational)
@eval Base.:^(t::Tracer, ::$T) = t
@eval Base.:^(::$T, t::Tracer) = t
end
Base.:^(t::Tracer, ::Irrational{:ℯ}) = t
Base.:^(::Irrational{:ℯ}, t::Tracer) = t

## Two-argument functions
for fn in (:div, :fld, :cld)
@eval Base.$fn(a::Tracer, b::Tracer) = tracer(a, b)
@eval Base.$fn(t::Tracer, ::Number) = t
@eval Base.$fn(::Number, t::Tracer) = t
end

## Single-argument functions

#! format: off
scalar_operations = (
:exp2, :deg2rad, :rad2deg,
:cos, :cosd, :cosh, :cospi, :cosc,
:sin, :sind, :sinh, :sinpi, :sinc,
:tan, :tand, :tanh,
:csc, :cscd, :csch,
:sec, :secd, :sech,
:cot, :cotd, :coth,
:acos, :acosd, :acosh,
:asin, :asind, :asinh,
:atan, :atand, :atanh,
:asec, :asech,
:acsc, :acsch,
:acot, :acoth,
:exp, :expm1, :exp10,
:frexp, :ldexp,
:abs, :abs2, :sqrt
)
#! format: on

for fn in scalar_operations
@eval Base.$fn(t::Tracer) = t
end

## Random numbers
rand(::AbstractRNG, ::SamplerType{Tracer}) = tracer()
97 changes: 97 additions & 0 deletions src/tracer.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
"""
Tracer(indexset) <: Number
Number type keeping track of input indices of previous computations.
See also the convenience constructor [`tracer`](@ref).
## Examples
```julia-repl
julia> x = tracer(1, 2, 3)
Tracer(1, 2, 3)
julia> sin(x)
Tracer(1, 2, 3)
julia> 2 * x^3
Tracer(1, 2, 3)
julia> 0 * x # Note: Tracer is strictly operator overloading...
Tracer(1, 2, 3)
julia> zero(x) # ...this can be overloaded
Tracer()
julia> y = tracer(3, 5)
Tracer(3, 5)
julia> x + y
Tracer(1, 2, 3, 5)
julia> x ^ y
Tracer(1, 2, 3, 5)
julia> M = rand(Tracer, 3, 2)
3×2 Matrix{Tracer}:
Tracer() Tracer()
Tracer() Tracer()
Tracer() Tracer()
julia> similar(M)
3×2 Matrix{Tracer}:
Tracer() Tracer()
Tracer() Tracer()
Tracer() Tracer()
julia> M * [x, y]
3-element Vector{Tracer}:
Tracer(1, 2, 3, 5)
Tracer(1, 2, 3, 5)
Tracer(1, 2, 3, 5)
```
"""
struct Tracer <: Number
inputs::Set{UInt64} # indices of connected, enumerated inputs
end

# We have to be careful when defining constructors:
# Generic code expecting "regular" numbers `x` will sometimes convert them
# by calling `T(x)` (instead of `convert(T, x)`), where `T` can be `Tracer`.
# When this happens, we create a new empty tracer with no input connectivity.
Tracer(::Number) = tracer()
Tracer(t::Tracer) = t
# We therefore exclusively use the lower-case `tracer` for convenience constructors

"""
tracer(index)
tracer(indices)
Convenience constructor for [`Tracer`](@ref) from input indices.
"""
tracer() = Tracer(Set{UInt64}())
tracer(a::Tracer, b::Tracer) = Tracer(union(a.inputs, b.inputs))

tracer(index::Integer) = Tracer(Set{UInt64}(index))
tracer(inds::NTuple{N,<:Integer}) where {N} = Tracer(Set{UInt64}(inds))
tracer(inds...) = tracer(inds)

# Utilities for accessing input indices
"""
inputs(tracer)
Return raw `UInt64` input indices of a [`Tracer`](@ref).
"""
inputs(t::Tracer) = collect(keys(t.inputs.dict))

"""
sortedinputs(tracer)
sortedinputs([T=Int], tracer)
Return sorted input indices of a [`Tracer`](@ref).
"""
sortedinputs(t::Tracer) = sortedinputs(Int, t)
sortedinputs(T::Type, t::Tracer) = convert.(T, sort!(inputs(t)))

function Base.show(io::IO, t::Tracer)
return Base.show_delim_array(io, sortedinputs(Int, t), "Tracer(", ',', ')', true)
end
Loading

0 comments on commit 63560b0

Please sign in to comment.