Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Initial functionality #2

Merged
merged 16 commits into from
Apr 1, 2024
6 changes: 6 additions & 0 deletions .JuliaFormatter.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
style = "blue"
align_assignment = true
align_struct_field = true
align_conditional = true
align_pair_arrow = true
align_matrix = true
8 changes: 6 additions & 2 deletions .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ jobs:
fail-fast: false
matrix:
version:
- '1.6'
- '1'
- 'nightly'
os:
- ubuntu-latest
Expand All @@ -37,10 +37,14 @@ jobs:
arch: ${{ matrix.arch }}
- uses: julia-actions/cache@v1
- uses: julia-actions/julia-buildpkg@v1
continue-on-error: ${{ matrix.version == 'nightly' }}
- uses: julia-actions/julia-runtest@v1
continue-on-error: ${{ matrix.version == 'nightly' }}
- uses: julia-actions/julia-processcoverage@v1
- uses: codecov/codecov-action@v3
- uses: codecov/codecov-action@v4
with:
token: ${{ secrets.CODECOV_TOKEN }}
slug: adrhill/SparseConnectivityTracer.jl
files: lcov.info
docs:
name: Documentation
Expand Down
4 changes: 4 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,5 +3,9 @@ uuid = "9f842d2f-2579-4b1d-911e-f412cf18a3f5"
authors = ["Adrian Hill <[email protected]>"]
version = "1.0.0-DEV"

[deps]
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"

[compat]
julia = "1.6"
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# SparseConnectivityTracer
# SparseConnectivityTracer.jl

[![Stable](https://img.shields.io/badge/docs-stable-blue.svg)](https://adrhill.github.io/SparseConnectivityTracer.jl/stable/)
<!-- [![Stable](https://img.shields.io/badge/docs-stable-blue.svg)](https://adrhill.github.io/SparseConnectivityTracer.jl/stable/) -->
[![Dev](https://img.shields.io/badge/docs-dev-blue.svg)](https://adrhill.github.io/SparseConnectivityTracer.jl/dev/)
[![Build Status](https://github.com/adrhill/SparseConnectivityTracer.jl/actions/workflows/CI.yml/badge.svg?branch=main)](https://github.com/adrhill/SparseConnectivityTracer.jl/actions/workflows/CI.yml?query=branch%3Amain)
[![Coverage](https://codecov.io/gh/adrhill/SparseConnectivityTracer.jl/branch/main/graph/badge.svg)](https://codecov.io/gh/adrhill/SparseConnectivityTracer.jl)
Expand Down
22 changes: 11 additions & 11 deletions docs/make.jl
Original file line number Diff line number Diff line change
@@ -1,23 +1,23 @@
using SparseConnectivityTracer
using Documenter

DocMeta.setdocmeta!(SparseConnectivityTracer, :DocTestSetup, :(using SparseConnectivityTracer); recursive=true)
DocMeta.setdocmeta!(
SparseConnectivityTracer,
:DocTestSetup,
:(using SparseConnectivityTracer);
recursive=true,
)

makedocs(;
modules=[SparseConnectivityTracer],
authors="Adrian Hill <[email protected]>",
sitename="SparseConnectivityTracer.jl",
format=Documenter.HTML(;
canonical="https://adrhill.github.io/SparseConnectivityTracer.jl",
edit_link="main",
assets=String[],
canonical = "https://adrhill.github.io/SparseConnectivityTracer.jl",
edit_link = "main",
assets = String[],
),
pages=[
"Home" => "index.md",
],
pages=["Home" => "index.md"],
)

deploydocs(;
repo="github.com/adrhill/SparseConnectivityTracer.jl",
devbranch="main",
)
deploydocs(; repo="github.com/adrhill/SparseConnectivityTracer.jl", devbranch="main")
18 changes: 16 additions & 2 deletions docs/src/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,20 @@ Documentation for [SparseConnectivityTracer](https://github.com/adrhill/SparseCo
```@index
```

```@autodocs
Modules = [SparseConnectivityTracer]
## API reference
SparseConnectivityTracer works by pushing a `Number` type called [`Tracer`](@ref) through generic functions:
```@docs
Tracer
tracer
```

The resulting connectivity matrix can be extracted using [`connectivity`](@ref):
```@docs
connectivity
```

or manually from individual [`Tracer`](@ref) outputs:
```@docs
inputs
sortedinputs
```
12 changes: 10 additions & 2 deletions src/SparseConnectivityTracer.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,13 @@
module SparseConnectivityTracer
import Random: rand, AbstractRNG, SamplerType
import SparseArrays: sparse

# Write your package code here.
include("tracer.jl")
include("conversion.jl")
include("operators.jl")
include("connectivity.jl")

end
export Tracer, tracer, inputs, sortedinputs
export connectivity

end # module
63 changes: 63 additions & 0 deletions src/connectivity.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
## Enumerate inputs
trace_input(x) = trace_input(x, 1)
trace_input(::Number, i) = tracer(i)
function trace_input(x::AbstractArray, i)
indices = (i - 1) .+ reshape(1:length(x), size(x))
return tracer.(indices)
end

## Construct connectivity matrix
"""
connectivity(f, x)

Enumerates inputs `x` and primal outputs `y=f(x)` and returns sparse connectivity matrix `C` of size `(m, n)`
where `C[i, j]` is true if the compute graph connects the `i`-th entry in `y` to the `j`-th entry in `x`.
"""
function connectivity(f::Function, x)
xt = trace_input(x)
yt = f(xt)
return _connectivity(xt, yt)
end

_connectivity(xt::Tracer, yt::Number) = _connectivity([xt], [yt])
_connectivity(xt::Tracer, yt::AbstractArray{Number}) = _connectivity([xt], yt)
_connectivity(xt::AbstractArray{Tracer}, yt::Number) = _connectivity(xt, [yt])
function _connectivity(xt::AbstractArray{Tracer}, yt::AbstractArray{<:Number})
return connectivity_sparsematrixcsc(xt, yt)
end

function connectivity_sparsematrixcsc(
xt::AbstractArray{Tracer}, yt::AbstractArray{<:Number}
)
# Construct connectivity matrix of size (ouput_dim, input_dim)
n, m = length(xt), length(yt)
I = UInt64[]
J = UInt64[]
V = Bool[]
for (i, y) in enumerate(yt)
if y isa Tracer
for j in inputs(y)
push!(I, i)
push!(J, j)
push!(V, true)
end
end
end
return sparse(I, J, V, m, n)
end

function connectivity_bitmatrix(xt::AbstractArray{Tracer}, yt::AbstractArray{<:Number})
# Construct connectivity matrix of size (ouput_dim, input_dim)
n, m = length(xt), length(yt)
C = BitArray(undef, m, n)
for i in axes(C, 1)
if yt[i] isa Tracer
for j in axes(C, 2)
C[i, j] = j ∈ yt[i].inputs
end
else
C[i, :] .= 0
end
end
return C
end
21 changes: 21 additions & 0 deletions src/conversion.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
## Type conversions
Base.promote_rule(::Type{Tracer}, ::Type{N}) where {N<:Number} = Tracer
Base.promote_rule(::Type{N}, ::Type{Tracer}) where {N<:Number} = Tracer

Base.convert(::Type{Tracer}, x::Number) = tracer()
Base.convert(::Type{Tracer}, t::Tracer) = t
Base.convert(::Type{<:Number}, t::Tracer) = t

## Array constructors
Base.zero(::Tracer) = tracer()
Base.zero(::Type{Tracer}) = tracer()
Base.one(::Tracer) = tracer()
Base.one(::Type{Tracer}) = tracer()

Base.similar(a::Array{Tracer,1}) = zeros(Tracer, size(a, 1))
Base.similar(a::Array{Tracer,2}) = zeros(Tracer, size(a, 1), size(a, 2))
Base.similar(a::Array{T,1}, ::Type{Tracer}) where {T} = zeros(Tracer, size(a, 1))
Base.similar(a::Array{T,2}, ::Type{Tracer}) where {T} = zeros(Tracer, size(a, 1), size(a, 2))
Base.similar(::Array{Tracer}, m::Int) = zeros(Tracer, m)
Base.similar(::Array, ::Type{Tracer}, dims::Dims{N}) where {N} = zeros(Tracer, dims)
Base.similar(::Array{Tracer}, dims::Dims{N}) where {N} = zeros(Tracer, dims)
53 changes: 53 additions & 0 deletions src/operators.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
## Extent Base operators
for fn in (:+, :-, :*, :/)
@eval Base.$fn(a::Tracer, b::Tracer) = tracer(a, b)
for T in (:Number,)
@eval Base.$fn(t::Tracer, ::$T) = t
@eval Base.$fn(::$T, t::Tracer) = t
end
end

Base.:^(a::Tracer, b::Tracer) = tracer(a, b)
for T in (:Number, :Integer, :Rational)
@eval Base.:^(t::Tracer, ::$T) = t
@eval Base.:^(::$T, t::Tracer) = t
end
Base.:^(t::Tracer, ::Irrational{:ℯ}) = t
Base.:^(::Irrational{:ℯ}, t::Tracer) = t

## Two-argument functions
for fn in (:div, :fld, :cld)
@eval Base.$fn(a::Tracer, b::Tracer) = tracer(a, b)
@eval Base.$fn(t::Tracer, ::Number) = t
@eval Base.$fn(::Number, t::Tracer) = t
end

## Single-argument functions

#! format: off
scalar_operations = (
:exp2, :deg2rad, :rad2deg,
:cos, :cosd, :cosh, :cospi, :cosc,
:sin, :sind, :sinh, :sinpi, :sinc,
:tan, :tand, :tanh,
:csc, :cscd, :csch,
:sec, :secd, :sech,
:cot, :cotd, :coth,
:acos, :acosd, :acosh,
:asin, :asind, :asinh,
:atan, :atand, :atanh,
:asec, :asech,
:acsc, :acsch,
:acot, :acoth,
:exp, :expm1, :exp10,
:frexp, :ldexp,
:abs, :abs2, :sqrt
)
#! format: on

for fn in scalar_operations
@eval Base.$fn(t::Tracer) = t
end

## Random numbers
rand(::AbstractRNG, ::SamplerType{Tracer}) = tracer()
97 changes: 97 additions & 0 deletions src/tracer.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
"""
Tracer(indexset) <: Number

Number type keeping track of input indices of previous computations.

See also the convenience constructor [`tracer`](@ref).

## Examples
```julia-repl
julia> x = tracer(1, 2, 3)
Tracer(1, 2, 3)

julia> sin(x)
Tracer(1, 2, 3)

julia> 2 * x^3
Tracer(1, 2, 3)

julia> 0 * x # Note: Tracer is strictly operator overloading...
Tracer(1, 2, 3)

julia> zero(x) # ...this can be overloaded
Tracer()

julia> y = tracer(3, 5)
Tracer(3, 5)

julia> x + y
Tracer(1, 2, 3, 5)

julia> x ^ y
Tracer(1, 2, 3, 5)

julia> M = rand(Tracer, 3, 2)
3×2 Matrix{Tracer}:
Tracer() Tracer()
Tracer() Tracer()
Tracer() Tracer()

julia> similar(M)
3×2 Matrix{Tracer}:
Tracer() Tracer()
Tracer() Tracer()
Tracer() Tracer()

julia> M * [x, y]
3-element Vector{Tracer}:
Tracer(1, 2, 3, 5)
Tracer(1, 2, 3, 5)
Tracer(1, 2, 3, 5)
```
"""
struct Tracer <: Number
inputs::Set{UInt64} # indices of connected, enumerated inputs
end

# We have to be careful when defining constructors:
# Generic code expecting "regular" numbers `x` will sometimes convert them
# by calling `T(x)` (instead of `convert(T, x)`), where `T` can be `Tracer`.
# When this happens, we create a new empty tracer with no input connectivity.
Tracer(::Number) = tracer()
Tracer(t::Tracer) = t
# We therefore exclusively use the lower-case `tracer` for convenience constructors

"""
tracer(index)
tracer(indices)

Convenience constructor for [`Tracer`](@ref) from input indices.
"""
tracer() = Tracer(Set{UInt64}())
tracer(a::Tracer, b::Tracer) = Tracer(union(a.inputs, b.inputs))

tracer(index::Integer) = Tracer(Set{UInt64}(index))
tracer(inds::NTuple{N,<:Integer}) where {N} = Tracer(Set{UInt64}(inds))
tracer(inds...) = tracer(inds)

# Utilities for accessing input indices
"""
inputs(tracer)

Return raw `UInt64` input indices of a [`Tracer`](@ref).
"""
inputs(t::Tracer) = collect(keys(t.inputs.dict))

"""
sortedinputs(tracer)
sortedinputs([T=Int], tracer)

Return sorted input indices of a [`Tracer`](@ref).
"""
sortedinputs(t::Tracer) = sortedinputs(Int, t)
sortedinputs(T::Type, t::Tracer) = convert.(T, sort!(inputs(t)))

function Base.show(io::IO, t::Tracer)
return Base.show_delim_array(io, sortedinputs(Int, t), "Tracer(", ',', ')', true)
end
Loading
Loading