From 63560b013dc2cf26a8f05011bb52051692a41b56 Mon Sep 17 00:00:00 2001 From: Adrian Hill Date: Mon, 1 Apr 2024 20:06:37 +0200 Subject: [PATCH] Initial functionality (#2) --- .JuliaFormatter.toml | 6 + .github/workflows/CI.yml | 8 +- Project.toml | 4 + README.md | 4 +- docs/make.jl | 22 +- docs/src/index.md | 18 +- src/SparseConnectivityTracer.jl | 12 +- src/connectivity.jl | 63 ++++ src/conversion.jl | 21 ++ src/operators.jl | 53 +++ src/tracer.jl | 97 ++++++ test/Manifest.toml | 350 +++++++++++++++++++- test/Project.toml | 5 + test/references/connectivity/NNlib/conv.txt | 1 + test/runtests.jl | 58 +++- 15 files changed, 694 insertions(+), 28 deletions(-) create mode 100644 .JuliaFormatter.toml create mode 100644 src/connectivity.jl create mode 100644 src/conversion.jl create mode 100644 src/operators.jl create mode 100644 src/tracer.jl create mode 100644 test/references/connectivity/NNlib/conv.txt diff --git a/.JuliaFormatter.toml b/.JuliaFormatter.toml new file mode 100644 index 00000000..cc91afe1 --- /dev/null +++ b/.JuliaFormatter.toml @@ -0,0 +1,6 @@ +style = "blue" +align_assignment = true +align_struct_field = true +align_conditional = true +align_pair_arrow = true +align_matrix = true diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index 52540287..46c2ff53 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -23,7 +23,7 @@ jobs: fail-fast: false matrix: version: - - '1.6' + - '1' - 'nightly' os: - ubuntu-latest @@ -37,10 +37,14 @@ jobs: arch: ${{ matrix.arch }} - uses: julia-actions/cache@v1 - uses: julia-actions/julia-buildpkg@v1 + continue-on-error: ${{ matrix.version == 'nightly' }} - uses: julia-actions/julia-runtest@v1 + continue-on-error: ${{ matrix.version == 'nightly' }} - uses: julia-actions/julia-processcoverage@v1 - - uses: codecov/codecov-action@v3 + - uses: codecov/codecov-action@v4 with: + token: ${{ secrets.CODECOV_TOKEN }} + slug: adrhill/SparseConnectivityTracer.jl files: lcov.info docs: name: Documentation diff --git a/Project.toml b/Project.toml index 2edc8c72..d8685032 100644 --- a/Project.toml +++ b/Project.toml @@ -3,5 +3,9 @@ uuid = "9f842d2f-2579-4b1d-911e-f412cf18a3f5" authors = ["Adrian Hill "] version = "1.0.0-DEV" +[deps] +Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" + [compat] julia = "1.6" diff --git a/README.md b/README.md index c7457845..90815142 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,6 @@ -# SparseConnectivityTracer +# SparseConnectivityTracer.jl -[![Stable](https://img.shields.io/badge/docs-stable-blue.svg)](https://adrhill.github.io/SparseConnectivityTracer.jl/stable/) + [![Dev](https://img.shields.io/badge/docs-dev-blue.svg)](https://adrhill.github.io/SparseConnectivityTracer.jl/dev/) [![Build Status](https://github.com/adrhill/SparseConnectivityTracer.jl/actions/workflows/CI.yml/badge.svg?branch=main)](https://github.com/adrhill/SparseConnectivityTracer.jl/actions/workflows/CI.yml?query=branch%3Amain) [![Coverage](https://codecov.io/gh/adrhill/SparseConnectivityTracer.jl/branch/main/graph/badge.svg)](https://codecov.io/gh/adrhill/SparseConnectivityTracer.jl) diff --git a/docs/make.jl b/docs/make.jl index 096d8021..b21b6e94 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -1,23 +1,23 @@ using SparseConnectivityTracer using Documenter -DocMeta.setdocmeta!(SparseConnectivityTracer, :DocTestSetup, :(using SparseConnectivityTracer); recursive=true) +DocMeta.setdocmeta!( + SparseConnectivityTracer, + :DocTestSetup, + :(using SparseConnectivityTracer); + recursive=true, +) makedocs(; modules=[SparseConnectivityTracer], authors="Adrian Hill ", sitename="SparseConnectivityTracer.jl", format=Documenter.HTML(; - canonical="https://adrhill.github.io/SparseConnectivityTracer.jl", - edit_link="main", - assets=String[], + canonical = "https://adrhill.github.io/SparseConnectivityTracer.jl", + edit_link = "main", + assets = String[], ), - pages=[ - "Home" => "index.md", - ], + pages=["Home" => "index.md"], ) -deploydocs(; - repo="github.com/adrhill/SparseConnectivityTracer.jl", - devbranch="main", -) +deploydocs(; repo="github.com/adrhill/SparseConnectivityTracer.jl", devbranch="main") diff --git a/docs/src/index.md b/docs/src/index.md index 0dbe2afd..b9d2be31 100644 --- a/docs/src/index.md +++ b/docs/src/index.md @@ -9,6 +9,20 @@ Documentation for [SparseConnectivityTracer](https://github.com/adrhill/SparseCo ```@index ``` -```@autodocs -Modules = [SparseConnectivityTracer] +## API reference +SparseConnectivityTracer works by pushing a `Number` type called [`Tracer`](@ref) through generic functions: +```@docs +Tracer +tracer +``` + +The resulting connectivity matrix can be extracted using [`connectivity`](@ref): +```@docs +connectivity +``` + +or manually from individual [`Tracer`](@ref) outputs: +```@docs +inputs +sortedinputs ``` diff --git a/src/SparseConnectivityTracer.jl b/src/SparseConnectivityTracer.jl index 02041487..43da4fa0 100644 --- a/src/SparseConnectivityTracer.jl +++ b/src/SparseConnectivityTracer.jl @@ -1,5 +1,13 @@ module SparseConnectivityTracer +import Random: rand, AbstractRNG, SamplerType +import SparseArrays: sparse -# Write your package code here. +include("tracer.jl") +include("conversion.jl") +include("operators.jl") +include("connectivity.jl") -end +export Tracer, tracer, inputs, sortedinputs +export connectivity + +end # module diff --git a/src/connectivity.jl b/src/connectivity.jl new file mode 100644 index 00000000..b5095c2c --- /dev/null +++ b/src/connectivity.jl @@ -0,0 +1,63 @@ +## Enumerate inputs +trace_input(x) = trace_input(x, 1) +trace_input(::Number, i) = tracer(i) +function trace_input(x::AbstractArray, i) + indices = (i - 1) .+ reshape(1:length(x), size(x)) + return tracer.(indices) +end + +## Construct connectivity matrix +""" + connectivity(f, x) + +Enumerates inputs `x` and primal outputs `y=f(x)` and returns sparse connectivity matrix `C` of size `(m, n)` +where `C[i, j]` is true if the compute graph connects the `i`-th entry in `y` to the `j`-th entry in `x`. +""" +function connectivity(f::Function, x) + xt = trace_input(x) + yt = f(xt) + return _connectivity(xt, yt) +end + +_connectivity(xt::Tracer, yt::Number) = _connectivity([xt], [yt]) +_connectivity(xt::Tracer, yt::AbstractArray{Number}) = _connectivity([xt], yt) +_connectivity(xt::AbstractArray{Tracer}, yt::Number) = _connectivity(xt, [yt]) +function _connectivity(xt::AbstractArray{Tracer}, yt::AbstractArray{<:Number}) + return connectivity_sparsematrixcsc(xt, yt) +end + +function connectivity_sparsematrixcsc( + xt::AbstractArray{Tracer}, yt::AbstractArray{<:Number} +) + # Construct connectivity matrix of size (ouput_dim, input_dim) + n, m = length(xt), length(yt) + I = UInt64[] + J = UInt64[] + V = Bool[] + for (i, y) in enumerate(yt) + if y isa Tracer + for j in inputs(y) + push!(I, i) + push!(J, j) + push!(V, true) + end + end + end + return sparse(I, J, V, m, n) +end + +function connectivity_bitmatrix(xt::AbstractArray{Tracer}, yt::AbstractArray{<:Number}) + # Construct connectivity matrix of size (ouput_dim, input_dim) + n, m = length(xt), length(yt) + C = BitArray(undef, m, n) + for i in axes(C, 1) + if yt[i] isa Tracer + for j in axes(C, 2) + C[i, j] = j ∈ yt[i].inputs + end + else + C[i, :] .= 0 + end + end + return C +end diff --git a/src/conversion.jl b/src/conversion.jl new file mode 100644 index 00000000..0cb3ed37 --- /dev/null +++ b/src/conversion.jl @@ -0,0 +1,21 @@ +## Type conversions +Base.promote_rule(::Type{Tracer}, ::Type{N}) where {N<:Number} = Tracer +Base.promote_rule(::Type{N}, ::Type{Tracer}) where {N<:Number} = Tracer + +Base.convert(::Type{Tracer}, x::Number) = tracer() +Base.convert(::Type{Tracer}, t::Tracer) = t +Base.convert(::Type{<:Number}, t::Tracer) = t + +## Array constructors +Base.zero(::Tracer) = tracer() +Base.zero(::Type{Tracer}) = tracer() +Base.one(::Tracer) = tracer() +Base.one(::Type{Tracer}) = tracer() + +Base.similar(a::Array{Tracer,1}) = zeros(Tracer, size(a, 1)) +Base.similar(a::Array{Tracer,2}) = zeros(Tracer, size(a, 1), size(a, 2)) +Base.similar(a::Array{T,1}, ::Type{Tracer}) where {T} = zeros(Tracer, size(a, 1)) +Base.similar(a::Array{T,2}, ::Type{Tracer}) where {T} = zeros(Tracer, size(a, 1), size(a, 2)) +Base.similar(::Array{Tracer}, m::Int) = zeros(Tracer, m) +Base.similar(::Array, ::Type{Tracer}, dims::Dims{N}) where {N} = zeros(Tracer, dims) +Base.similar(::Array{Tracer}, dims::Dims{N}) where {N} = zeros(Tracer, dims) diff --git a/src/operators.jl b/src/operators.jl new file mode 100644 index 00000000..d541a5a4 --- /dev/null +++ b/src/operators.jl @@ -0,0 +1,53 @@ +## Extent Base operators +for fn in (:+, :-, :*, :/) + @eval Base.$fn(a::Tracer, b::Tracer) = tracer(a, b) + for T in (:Number,) + @eval Base.$fn(t::Tracer, ::$T) = t + @eval Base.$fn(::$T, t::Tracer) = t + end +end + +Base.:^(a::Tracer, b::Tracer) = tracer(a, b) +for T in (:Number, :Integer, :Rational) + @eval Base.:^(t::Tracer, ::$T) = t + @eval Base.:^(::$T, t::Tracer) = t +end +Base.:^(t::Tracer, ::Irrational{:ℯ}) = t +Base.:^(::Irrational{:ℯ}, t::Tracer) = t + +## Two-argument functions +for fn in (:div, :fld, :cld) + @eval Base.$fn(a::Tracer, b::Tracer) = tracer(a, b) + @eval Base.$fn(t::Tracer, ::Number) = t + @eval Base.$fn(::Number, t::Tracer) = t +end + +## Single-argument functions + +#! format: off +scalar_operations = ( + :exp2, :deg2rad, :rad2deg, + :cos, :cosd, :cosh, :cospi, :cosc, + :sin, :sind, :sinh, :sinpi, :sinc, + :tan, :tand, :tanh, + :csc, :cscd, :csch, + :sec, :secd, :sech, + :cot, :cotd, :coth, + :acos, :acosd, :acosh, + :asin, :asind, :asinh, + :atan, :atand, :atanh, + :asec, :asech, + :acsc, :acsch, + :acot, :acoth, + :exp, :expm1, :exp10, + :frexp, :ldexp, + :abs, :abs2, :sqrt +) +#! format: on + +for fn in scalar_operations + @eval Base.$fn(t::Tracer) = t +end + +## Random numbers +rand(::AbstractRNG, ::SamplerType{Tracer}) = tracer() diff --git a/src/tracer.jl b/src/tracer.jl new file mode 100644 index 00000000..6340ec93 --- /dev/null +++ b/src/tracer.jl @@ -0,0 +1,97 @@ +""" + Tracer(indexset) <: Number + +Number type keeping track of input indices of previous computations. + +See also the convenience constructor [`tracer`](@ref). + +## Examples +```julia-repl +julia> x = tracer(1, 2, 3) +Tracer(1, 2, 3) + +julia> sin(x) +Tracer(1, 2, 3) + +julia> 2 * x^3 +Tracer(1, 2, 3) + +julia> 0 * x # Note: Tracer is strictly operator overloading... +Tracer(1, 2, 3) + +julia> zero(x) # ...this can be overloaded +Tracer() + +julia> y = tracer(3, 5) +Tracer(3, 5) + +julia> x + y +Tracer(1, 2, 3, 5) + +julia> x ^ y +Tracer(1, 2, 3, 5) + +julia> M = rand(Tracer, 3, 2) +3×2 Matrix{Tracer}: + Tracer() Tracer() + Tracer() Tracer() + Tracer() Tracer() + +julia> similar(M) +3×2 Matrix{Tracer}: + Tracer() Tracer() + Tracer() Tracer() + Tracer() Tracer() + +julia> M * [x, y] +3-element Vector{Tracer}: + Tracer(1, 2, 3, 5) + Tracer(1, 2, 3, 5) + Tracer(1, 2, 3, 5) +``` +""" +struct Tracer <: Number + inputs::Set{UInt64} # indices of connected, enumerated inputs +end + +# We have to be careful when defining constructors: +# Generic code expecting "regular" numbers `x` will sometimes convert them +# by calling `T(x)` (instead of `convert(T, x)`), where `T` can be `Tracer`. +# When this happens, we create a new empty tracer with no input connectivity. +Tracer(::Number) = tracer() +Tracer(t::Tracer) = t +# We therefore exclusively use the lower-case `tracer` for convenience constructors + +""" + tracer(index) + tracer(indices) + +Convenience constructor for [`Tracer`](@ref) from input indices. +""" +tracer() = Tracer(Set{UInt64}()) +tracer(a::Tracer, b::Tracer) = Tracer(union(a.inputs, b.inputs)) + +tracer(index::Integer) = Tracer(Set{UInt64}(index)) +tracer(inds::NTuple{N,<:Integer}) where {N} = Tracer(Set{UInt64}(inds)) +tracer(inds...) = tracer(inds) + +# Utilities for accessing input indices +""" + inputs(tracer) + +Return raw `UInt64` input indices of a [`Tracer`](@ref). +""" +inputs(t::Tracer) = collect(keys(t.inputs.dict)) + +""" + sortedinputs(tracer) + sortedinputs([T=Int], tracer) + +Return sorted input indices of a [`Tracer`](@ref). +""" +sortedinputs(t::Tracer) = sortedinputs(Int, t) +sortedinputs(T::Type, t::Tracer) = convert.(T, sort!(inputs(t))) + +function Base.show(io::IO, t::Tracer) + return Base.show_delim_array(io, sortedinputs(Int, t), "Tracer(", ',', ')', true) +end diff --git a/test/Manifest.toml b/test/Manifest.toml index 1c64683d..dd5ad7f5 100644 --- a/test/Manifest.toml +++ b/test/Manifest.toml @@ -2,7 +2,17 @@ julia_version = "1.10.2" manifest_format = "2.0" -project_hash = "1f45826613cafa0faebd6512f3b1b9e3efe7eaaf" +project_hash = "1d3f1a9039733a79302ff3b02a00cfd52278c0de" + +[[deps.Adapt]] +deps = ["LinearAlgebra", "Requires"] +git-tree-sha1 = "6a55b747d1812e699320963ffde36f1ebdda4099" +uuid = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" +version = "4.0.4" +weakdeps = ["StaticArrays"] + + [deps.Adapt.extensions] + AdaptStaticArraysExt = "StaticArrays" [[deps.Aqua]] deps = ["Compat", "Pkg", "Test"] @@ -17,32 +27,118 @@ version = "1.1.1" [[deps.Artifacts]] uuid = "56f22d72-fd6d-98f1-02f0-08ddc0907c33" +[[deps.Atomix]] +deps = ["UnsafeAtomics"] +git-tree-sha1 = "c06a868224ecba914baa6942988e2f2aade419be" +uuid = "a9b6321e-bd34-4604-b9c9-b65b8de01458" +version = "0.1.0" + [[deps.Base64]] uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f" +[[deps.CEnum]] +git-tree-sha1 = "389ad5c84de1ae7cf0e28e381131c98ea87d54fc" +uuid = "fa961155-64e5-5f13-b03f-caf6b980ea82" +version = "0.5.0" + +[[deps.CSTParser]] +deps = ["Tokenize"] +git-tree-sha1 = "b544d62417a99d091c569b95109bc9d8c223e9e3" +uuid = "00ebfdb7-1f24-5e51-bd34-a7502290713f" +version = "3.4.2" + +[[deps.ChainRulesCore]] +deps = ["Compat", "LinearAlgebra"] +git-tree-sha1 = "575cd02e080939a33b6df6c5853d14924c08e35b" +uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" +version = "1.23.0" +weakdeps = ["SparseArrays"] + + [deps.ChainRulesCore.extensions] + ChainRulesCoreSparseArraysExt = "SparseArrays" + [[deps.CodeTracking]] deps = ["InteractiveUtils", "UUIDs"] git-tree-sha1 = "c0216e792f518b39b22212127d4a84dc31e4e386" uuid = "da1fd8a2-8d9e-5ec2-8556-3022fb5608a2" version = "1.3.5" +[[deps.ColorTypes]] +deps = ["FixedPointNumbers", "Random"] +git-tree-sha1 = "eb7f0f8307f71fac7c606984ea5fb2817275d6e4" +uuid = "3da002f7-5984-5a60-b8a6-cbb66c0b333f" +version = "0.11.4" + +[[deps.ColorVectorSpace]] +deps = ["ColorTypes", "FixedPointNumbers", "LinearAlgebra", "Requires", "Statistics", "TensorCore"] +git-tree-sha1 = "a1f44953f2382ebb937d60dafbe2deea4bd23249" +uuid = "c3611d14-8923-5661-9e6a-0046d554d3a4" +version = "0.10.0" + + [deps.ColorVectorSpace.extensions] + SpecialFunctionsExt = "SpecialFunctions" + + [deps.ColorVectorSpace.weakdeps] + SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" + +[[deps.Colors]] +deps = ["ColorTypes", "FixedPointNumbers", "Reexport"] +git-tree-sha1 = "fc08e5930ee9a4e03f84bfb5211cb54e7769758a" +uuid = "5ae59095-9a9b-59fe-a467-6f913c188581" +version = "0.12.10" + +[[deps.CommonMark]] +deps = ["Crayons", "JSON", "PrecompileTools", "URIs"] +git-tree-sha1 = "532c4185d3c9037c0237546d817858b23cf9e071" +uuid = "a80b9123-70ca-4bc0-993e-6e3bcb318db6" +version = "0.8.12" + [[deps.Compat]] deps = ["TOML", "UUIDs"] git-tree-sha1 = "c955881e3c981181362ae4088b35995446298b80" uuid = "34da2185-b29b-5c13-b0c7-acf172513d20" version = "4.14.0" +weakdeps = ["Dates", "LinearAlgebra"] [deps.Compat.extensions] CompatLinearAlgebraExt = "LinearAlgebra" - [deps.Compat.weakdeps] - Dates = "ade2ca70-3891-5945-98fb-dc099432e06a" - LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +[[deps.CompilerSupportLibraries_jll]] +deps = ["Artifacts", "Libdl"] +uuid = "e66e0078-7015-5450-92f7-15fbd957f2ae" +version = "1.1.0+0" + +[[deps.Crayons]] +git-tree-sha1 = "249fe38abf76d48563e2f4556bebd215aa317e15" +uuid = "a8cc5b0e-0ffa-5ad4-8c14-923d3ee1735f" +version = "4.1.1" + +[[deps.DataStructures]] +deps = ["Compat", "InteractiveUtils", "OrderedCollections"] +git-tree-sha1 = "0f4b5d62a88d8f59003e43c25a8a90de9eb76317" +uuid = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8" +version = "0.18.18" [[deps.Dates]] deps = ["Printf"] uuid = "ade2ca70-3891-5945-98fb-dc099432e06a" +[[deps.DeepDiffs]] +git-tree-sha1 = "9824894295b62a6a4ab6adf1c7bf337b3a9ca34c" +uuid = "ab62b9b5-e342-54a8-a765-a90f495de1a6" +version = "1.2.0" + +[[deps.Distances]] +deps = ["LinearAlgebra", "Statistics", "StatsAPI"] +git-tree-sha1 = "66c4c81f259586e8f002eacebc177e1fb06363b0" +uuid = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7" +version = "0.10.11" +weakdeps = ["ChainRulesCore", "SparseArrays"] + + [deps.Distances.extensions] + DistancesChainRulesCoreExt = "ChainRulesCore" + DistancesSparseArraysExt = "SparseArrays" + [[deps.Distributed]] deps = ["Random", "Serialization", "Sockets"] uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b" @@ -52,9 +148,44 @@ deps = ["ArgTools", "FileWatching", "LibCURL", "NetworkOptions"] uuid = "f43a241f-c20a-4ad4-852c-f6b1247861c6" version = "1.6.0" +[[deps.FileIO]] +deps = ["Pkg", "Requires", "UUIDs"] +git-tree-sha1 = "82d8afa92ecf4b52d78d869f038ebfb881267322" +uuid = "5789e2e9-d7fb-5bc7-8068-2c6fae9b9549" +version = "1.16.3" + [[deps.FileWatching]] uuid = "7b1f6079-737a-58dc-b8bc-7a2ca5c1b5ee" +[[deps.FixedPointNumbers]] +deps = ["Statistics"] +git-tree-sha1 = "335bfdceacc84c5cdf16aadc768aa5ddfc5383cc" +uuid = "53c48c17-4a7d-5ca2-90c5-79b7896eea93" +version = "0.8.4" + +[[deps.GPUArraysCore]] +deps = ["Adapt"] +git-tree-sha1 = "ec632f177c0d990e64d955ccc1b8c04c485a0950" +uuid = "46192b85-c4d5-4398-a991-12ede77f4527" +version = "0.1.6" + +[[deps.Glob]] +git-tree-sha1 = "97285bbd5230dd766e9ef6749b80fc617126d496" +uuid = "c27321d9-0574-5035-807b-f59d2c89b15c" +version = "1.3.1" + +[[deps.ImageBase]] +deps = ["ImageCore", "Reexport"] +git-tree-sha1 = "eb49b82c172811fd2c86759fa0553a2221feb909" +uuid = "c817782e-172a-44cc-b673-b171935fbb9e" +version = "0.1.7" + +[[deps.ImageCore]] +deps = ["ColorVectorSpace", "Colors", "FixedPointNumbers", "MappedArrays", "MosaicViews", "OffsetArrays", "PaddedViews", "PrecompileTools", "Reexport"] +git-tree-sha1 = "b2a7eaa169c13f5bcae8131a83bc30eff8f71be0" +uuid = "a09fc81d-aa75-5fe9-8630-4744c3626534" +version = "0.10.2" + [[deps.InteractiveUtils]] deps = ["Markdown"] uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240" @@ -65,12 +196,69 @@ git-tree-sha1 = "6ff76fc594051832ce91078686bc0d3def6d42c5" uuid = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b" version = "0.8.29" +[[deps.JLLWrappers]] +deps = ["Artifacts", "Preferences"] +git-tree-sha1 = "7e5d6779a1e09a36db2a7b6cff50942a0a7d0fca" +uuid = "692b3bcd-3c85-4b1f-b108-f13ce0eb3210" +version = "1.5.0" + +[[deps.JSON]] +deps = ["Dates", "Mmap", "Parsers", "Unicode"] +git-tree-sha1 = "31e996f0a15c7b280ba9f76636b3ff9e2ae58c9a" +uuid = "682c06a0-de6a-54ab-a142-c8b1cf79cde6" +version = "0.21.4" + +[[deps.JuliaFormatter]] +deps = ["CSTParser", "CommonMark", "DataStructures", "Glob", "Pkg", "PrecompileTools", "Tokenize"] +git-tree-sha1 = "e07d6fd7db543b11cd90ed764efec53f39851f09" +uuid = "98e50ef6-434e-11e9-1051-2b60c6c9e899" +version = "1.0.54" + [[deps.JuliaInterpreter]] deps = ["CodeTracking", "InteractiveUtils", "Random", "UUIDs"] git-tree-sha1 = "e9648d90370e2d0317f9518c9c6e0841db54a90b" uuid = "aa1ae85d-cabe-5617-a682-6adf51b2e16a" version = "0.9.31" +[[deps.KernelAbstractions]] +deps = ["Adapt", "Atomix", "InteractiveUtils", "LinearAlgebra", "MacroTools", "PrecompileTools", "Requires", "SparseArrays", "StaticArrays", "UUIDs", "UnsafeAtomics", "UnsafeAtomicsLLVM"] +git-tree-sha1 = "ed7167240f40e62d97c1f5f7735dea6de3cc5c49" +uuid = "63c18a36-062a-441e-b654-da1e3ab1ce7c" +version = "0.9.18" + + [deps.KernelAbstractions.extensions] + EnzymeExt = "EnzymeCore" + + [deps.KernelAbstractions.weakdeps] + EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" + +[[deps.LLVM]] +deps = ["CEnum", "LLVMExtra_jll", "Libdl", "Preferences", "Printf", "Requires", "Unicode"] +git-tree-sha1 = "ab01dde107f21aa76144d0771dccc08f152ccac7" +uuid = "929cbde3-209d-540e-8aea-75f648917ca0" +version = "6.6.2" + + [deps.LLVM.extensions] + BFloat16sExt = "BFloat16s" + + [deps.LLVM.weakdeps] + BFloat16s = "ab4f0b2a-ad5b-11e8-123f-65d77653426b" + +[[deps.LLVMExtra_jll]] +deps = ["Artifacts", "JLLWrappers", "LazyArtifacts", "Libdl", "TOML"] +git-tree-sha1 = "88b916503aac4fb7f701bb625cd84ca5dd1677bc" +uuid = "dad2f222-ce93-54a1-a47d-0025e8a3acab" +version = "0.0.29+0" + +[[deps.LazyArtifacts]] +deps = ["Artifacts", "Pkg"] +uuid = "4af54fe1-eca0-43a8-85a7-787d91b784e3" + +[[deps.LazyModules]] +git-tree-sha1 = "a560dd966b386ac9ae60bdd3a3d3a326062d3c3e" +uuid = "8cdb02fc-e678-4876-92c5-9defec4f444e" +version = "0.3.1" + [[deps.LibCURL]] deps = ["LibCURL_jll", "MozillaCACerts_jll"] uuid = "b27032c2-a3e7-50c8-80cd-2d36dbcbfd21" @@ -98,6 +286,10 @@ version = "1.11.0+1" [[deps.Libdl]] uuid = "8f399da3-3557-5675-b5ff-fb832c97cbdb" +[[deps.LinearAlgebra]] +deps = ["Libdl", "OpenBLAS_jll", "libblastrampoline_jll"] +uuid = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" + [[deps.Logging]] uuid = "56ddb016-857b-54e1-b83d-db4d58db5568" @@ -113,6 +305,11 @@ git-tree-sha1 = "2fa9ee3e63fd3a4f7a9a4f4744a52f4856de82df" uuid = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" version = "0.5.13" +[[deps.MappedArrays]] +git-tree-sha1 = "2dab0221fe2b0f2cb6754eaa743cc266339f527e" +uuid = "dbb5928d-eab1-5f90-85c2-b9b0edb7c900" +version = "0.4.2" + [[deps.Markdown]] deps = ["Base64"] uuid = "d6f4376e-aef5-505a-96c1-9c027394607a" @@ -122,19 +319,72 @@ deps = ["Artifacts", "Libdl"] uuid = "c8ffd9c3-330d-5841-b78e-0817d7145fa1" version = "2.28.2+1" +[[deps.Mmap]] +uuid = "a63ad114-7e13-5084-954f-fe012c677804" + +[[deps.MosaicViews]] +deps = ["MappedArrays", "OffsetArrays", "PaddedViews", "StackViews"] +git-tree-sha1 = "7b86a5d4d70a9f5cdf2dacb3cbe6d251d1a61dbe" +uuid = "e94cdb99-869f-56ef-bcf0-1ae2bcbe0389" +version = "0.3.4" + [[deps.MozillaCACerts_jll]] uuid = "14a3606d-f60d-562e-9121-12d972cd8159" version = "2023.1.10" +[[deps.NNlib]] +deps = ["Adapt", "Atomix", "ChainRulesCore", "GPUArraysCore", "KernelAbstractions", "LinearAlgebra", "Pkg", "Random", "Requires", "Statistics"] +git-tree-sha1 = "1fa1a14766c60e66ab22e242d45c1857c83a3805" +uuid = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" +version = "0.9.13" + + [deps.NNlib.extensions] + NNlibAMDGPUExt = "AMDGPU" + NNlibCUDACUDNNExt = ["CUDA", "cuDNN"] + NNlibCUDAExt = "CUDA" + NNlibEnzymeCoreExt = "EnzymeCore" + + [deps.NNlib.weakdeps] + AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e" + CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" + EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" + cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd" + [[deps.NetworkOptions]] uuid = "ca575930-c2e3-43a9-ace4-1e988b2c1908" version = "1.2.0" +[[deps.OffsetArrays]] +git-tree-sha1 = "6a731f2b5c03157418a20c12195eb4b74c8f8621" +uuid = "6fe1bfb0-de20-5000-8ca7-80f57d26f881" +version = "1.13.0" +weakdeps = ["Adapt"] + + [deps.OffsetArrays.extensions] + OffsetArraysAdaptExt = "Adapt" + +[[deps.OpenBLAS_jll]] +deps = ["Artifacts", "CompilerSupportLibraries_jll", "Libdl"] +uuid = "4536629a-c528-5b80-bd46-f80d51c5b363" +version = "0.3.23+4" + [[deps.OrderedCollections]] git-tree-sha1 = "dfdf5519f235516220579f949664f1bf44e741c5" uuid = "bac558e1-5e72-5ebc-8fee-abe8a469f55d" version = "1.6.3" +[[deps.PaddedViews]] +deps = ["OffsetArrays"] +git-tree-sha1 = "0fac6313486baae819364c52b4f483450a9d793f" +uuid = "5432bcbf-9aad-5242-b902-cca2824c8663" +version = "0.5.12" + +[[deps.Parsers]] +deps = ["Dates", "PrecompileTools", "UUIDs"] +git-tree-sha1 = "8489905bcdbcfac64d1daa51ca07c0d8f0283821" +uuid = "69de0a69-1ddd-5017-9359-2bf0b02dc9f0" +version = "2.8.1" + [[deps.Pkg]] deps = ["Artifacts", "Dates", "Downloads", "FileWatching", "LibGit2", "Libdl", "Logging", "Markdown", "Printf", "REPL", "Random", "SHA", "Serialization", "TOML", "Tar", "UUIDs", "p7zip_jll"] uuid = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" @@ -164,6 +414,17 @@ uuid = "3fa0cd96-eef1-5676-8a61-b3b8758bbffb" deps = ["SHA"] uuid = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +[[deps.Reexport]] +git-tree-sha1 = "45e428421666073eab6f2da5c9d310d99bb12f9b" +uuid = "189a3867-3050-52da-a836-e630ba90ab69" +version = "1.2.2" + +[[deps.ReferenceTests]] +deps = ["Colors", "DeepDiffs", "Distances", "FileIO", "ImageCore", "LazyModules", "Random", "SHA", "Test", "XTermColors"] +git-tree-sha1 = "ce6661add3d5a76c6e2e2c56f14b41a50577276c" +uuid = "324d217c-45ce-50fc-942e-d289b448e8cf" +version = "0.10.4" + [[deps.Requires]] deps = ["UUIDs"] git-tree-sha1 = "838a3a4188e2ded87a4f9f184b4b0d78a1e91cb7" @@ -186,6 +447,49 @@ uuid = "9e88b42a-f829-5b0c-bbe9-9e923198166b" [[deps.Sockets]] uuid = "6462fe0b-24de-5631-8697-dd941f90decc" +[[deps.SparseArrays]] +deps = ["Libdl", "LinearAlgebra", "Random", "Serialization", "SuiteSparse_jll"] +uuid = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" +version = "1.10.0" + +[[deps.StackViews]] +deps = ["OffsetArrays"] +git-tree-sha1 = "46e589465204cd0c08b4bd97385e4fa79a0c770c" +uuid = "cae243ae-269e-4f55-b966-ac2d0dc13c15" +version = "0.1.1" + +[[deps.StaticArrays]] +deps = ["LinearAlgebra", "PrecompileTools", "Random", "StaticArraysCore"] +git-tree-sha1 = "bf074c045d3d5ffd956fa0a461da38a44685d6b2" +uuid = "90137ffa-7385-5640-81b9-e52037218182" +version = "1.9.3" +weakdeps = ["ChainRulesCore", "Statistics"] + + [deps.StaticArrays.extensions] + StaticArraysChainRulesCoreExt = "ChainRulesCore" + StaticArraysStatisticsExt = "Statistics" + +[[deps.StaticArraysCore]] +git-tree-sha1 = "36b3d696ce6366023a0ea192b4cd442268995a0d" +uuid = "1e83bf80-4336-4d27-bf5d-d5a4f845583c" +version = "1.4.2" + +[[deps.Statistics]] +deps = ["LinearAlgebra", "SparseArrays"] +uuid = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" +version = "1.10.0" + +[[deps.StatsAPI]] +deps = ["LinearAlgebra"] +git-tree-sha1 = "1ff449ad350c9c4cbc756624d6f8a8c3ef56d3ed" +uuid = "82ae8749-77ed-4fe6-ae5f-f523153014b0" +version = "1.7.0" + +[[deps.SuiteSparse_jll]] +deps = ["Artifacts", "Libdl", "libblastrampoline_jll"] +uuid = "bea87d4a-7f5b-5778-9afe-8cc45184846c" +version = "7.2.1+1" + [[deps.TOML]] deps = ["Dates"] uuid = "fa267f1f-6049-4f14-aa54-33bafae1ed76" @@ -196,10 +500,26 @@ deps = ["ArgTools", "SHA"] uuid = "a4e569a6-e804-4fa4-b0f3-eef7a1d5b13e" version = "1.10.0" +[[deps.TensorCore]] +deps = ["LinearAlgebra"] +git-tree-sha1 = "1feb45f88d133a655e001435632f019a9a1bcdb6" +uuid = "62fd8b95-f654-4bbd-a8a5-9c27f68ccd50" +version = "0.1.1" + [[deps.Test]] deps = ["InteractiveUtils", "Logging", "Random", "Serialization"] uuid = "8dfed614-e22c-5e08-85e1-65c5234f0b40" +[[deps.Tokenize]] +git-tree-sha1 = "5b5a892ba7704c0977013bd0f9c30f5d962181e0" +uuid = "0796e94c-ce3b-5d07-9a54-7f471281c624" +version = "0.5.28" + +[[deps.URIs]] +git-tree-sha1 = "67db6cc7b3821e19ebe75791a9dd19c9b1188f2b" +uuid = "5c2747f8-b7ea-4ff2-ba2e-563bfd36b1d4" +version = "1.5.1" + [[deps.UUIDs]] deps = ["Random", "SHA"] uuid = "cf7118a7-6976-5b1a-9a39-7adc72f591a4" @@ -207,11 +527,33 @@ uuid = "cf7118a7-6976-5b1a-9a39-7adc72f591a4" [[deps.Unicode]] uuid = "4ec0a83e-493e-50e2-b9ac-8f72acf5a8f5" +[[deps.UnsafeAtomics]] +git-tree-sha1 = "6331ac3440856ea1988316b46045303bef658278" +uuid = "013be700-e6cd-48c3-b4a1-df204f14c38f" +version = "0.2.1" + +[[deps.UnsafeAtomicsLLVM]] +deps = ["LLVM", "UnsafeAtomics"] +git-tree-sha1 = "323e3d0acf5e78a56dfae7bd8928c989b4f3083e" +uuid = "d80eeb9a-aca5-4d75-85e5-170c8b632249" +version = "0.1.3" + +[[deps.XTermColors]] +deps = ["Crayons", "ImageBase", "OffsetArrays"] +git-tree-sha1 = "bc27b7622a51f570c57b80bd839d1c0d43605b38" +uuid = "c8c2cc18-de81-4e68-b407-38a3a0c0491f" +version = "0.2.1" + [[deps.Zlib_jll]] deps = ["Libdl"] uuid = "83775a58-1f1d-513f-b197-d71354ab007a" version = "1.2.13+1" +[[deps.libblastrampoline_jll]] +deps = ["Artifacts", "Libdl"] +uuid = "8e850b90-86db-534c-a0d3-1478176c7d93" +version = "5.8.0+1" + [[deps.nghttp2_jll]] deps = ["Artifacts", "Libdl"] uuid = "8e850ede-7688-5339-a07c-302acd2aaf8d" diff --git a/test/Project.toml b/test/Project.toml index 737dfdaf..3c0b9875 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -1,4 +1,9 @@ [deps] Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b" +JuliaFormatter = "98e50ef6-434e-11e9-1051-2b60c6c9e899" +LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" +Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +ReferenceTests = "324d217c-45ce-50fc-942e-d289b448e8cf" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" diff --git a/test/references/connectivity/NNlib/conv.txt b/test/references/connectivity/NNlib/conv.txt new file mode 100644 index 00000000..aaf436e1 --- /dev/null +++ b/test/references/connectivity/NNlib/conv.txt @@ -0,0 +1 @@ +Bool[1 1 0 1 1 0 0 0 0 1 1 0 1 1 0 0 0 0; 0 1 1 0 1 1 0 0 0 0 1 1 0 1 1 0 0 0; 0 0 0 1 1 0 1 1 0 0 0 0 1 1 0 1 1 0; 0 0 0 0 1 1 0 1 1 0 0 0 0 1 1 0 1 1] \ No newline at end of file diff --git a/test/runtests.jl b/test/runtests.jl index 9eff7a85..a3c7bb0e 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,14 +1,62 @@ using SparseConnectivityTracer +using SparseConnectivityTracer: trace_input + using Test +using ReferenceTests +using JuliaFormatter using Aqua using JET +using LinearAlgebra +using Random +using NNlib + @testset "SparseConnectivityTracer.jl" begin - @testset "Code quality (Aqua.jl)" begin - Aqua.test_all(SparseConnectivityTracer) + @testset "Code formatting" begin + @test JuliaFormatter.format( + SparseConnectivityTracer; verbose=false, overwrite=false + ) + end + @testset "Aqua.jl tests" begin + Aqua.test_all( + SparseConnectivityTracer; + ambiguities=false, + deps_compat=(ignore=[:Random, :SparseArrays],), + ) + end + @testset "JET tests" begin + JET.test_package(SparseConnectivityTracer; target_defined_modules=true) + end + + @testset "Connectivity" begin + x = rand(3) + xt = trace_input(x) + + # Matrix multiplication + A = rand(1, 3) + yt = only(A * xt) + @test sortedinputs(yt) == [1, 2, 3] + + @test connectivity(x -> only(A * x), x) ≈ [1 1 1] + + # Custom functions + f(x) = [x[1]^2, 2 * x[1] * x[2]^2, sin(x[3])] + yt = f(xt) + @test sortedinputs(yt[1]) == [1] + @test sortedinputs(yt[2]) == [1, 2] + @test sortedinputs(yt[3]) == [3] + + @test connectivity(f, x) ≈ [1 0 0; 1 1 0; 0 0 1] + + @test connectivity(identity, rand()) ≈ [1;;] + @test connectivity(Returns(1), 1) ≈ [0;;] end - @testset "Code linting (JET.jl)" begin - JET.test_package(SparseConnectivityTracer; target_defined_modules = true) + @testset "Real-world tests" begin + @testset "NNlib" begin + x = rand(3, 3, 2, 1) # WHCN + w = rand(2, 2, 2, 1) # Conv((2, 2), 2 => 1) + C = connectivity(x -> NNlib.conv(x, w), x) + @test_reference "references/connectivity/NNlib/conv.txt" BitMatrix(C) + end end - # Write your tests here. end