Skip to content

Commit

Permalink
Add SortedVector set type (#38)
Browse files Browse the repository at this point in the history
  • Loading branch information
gdalle authored May 3, 2024
1 parent 3001626 commit 79cf7b9
Show file tree
Hide file tree
Showing 12 changed files with 155 additions and 36 deletions.
6 changes: 6 additions & 0 deletions docs/src/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -27,3 +27,9 @@ ConnectivityTracer
JacobianTracer
HessianTracer
```

We also define a custom alternative to sets that can deliver faster `union`:

```@docs
SparseConnectivityTracer.SortedVector
```
1 change: 1 addition & 0 deletions src/SparseConnectivityTracer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ include("overload_jacobian.jl")
include("overload_hessian.jl")
include("pattern.jl")
include("adtypes.jl")
include("sortedvector.jl")

export ConnectivityTracer, connectivity_pattern
export JacobianTracer, jacobian_pattern
Expand Down
86 changes: 86 additions & 0 deletions src/sortedvector.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
"""
SortedVector
A wrapper for sorted vectors, designed for fast unions.
# Constructor
SortedVector(data::AbstractVector; already_sorted=false)
# Example
```jldoctest
x = SortedVector([3, 4, 2])
x = SortedVector([1, 3, 5]; already_sorted=true)
z = union(x, y)
# output
SortedVector([1, 2, 3, 4, 5])
````
"""
struct SortedVector{T<:Number} <: AbstractVector{T}
data::Vector{T}

function SortedVector{T}(data::AbstractVector{T}) where {T}
return new{T}(convert(Vector{T}, data))
end

function SortedVector{T}(x::Number) where {T}
return new{T}([convert(T, x)])
end

function SortedVector{T}() where {T}
return new{T}(T[])
end
end

function SortedVector(data::AbstractVector{T}; already_sorted=false) where {T}
sorted_data = ifelse(already_sorted, data, sort(data))
return SortedVector{T}(sorted_data)
end

function Base.convert(::Type{SortedVector{T}}, v::Vector{T}) where {T}
return SortedVector(v; already_sorted=false)
end

Base.eltype(::SortedVector{T}) where {T} = T
Base.size(v::SortedVector) = size(v.data)
Base.getindex(v::SortedVector, i) = v.data[i]
Base.IndexStyle(::Type{SortedVector{T}}) where {T} = IndexStyle(Vector{T})
Base.show(io::IO, v::SortedVector) = print(io, "SortedVector($(v.data))")

function Base.union(v1::SortedVector{T}, v2::SortedVector{T}) where {T}
left, right = v1.data, v2.data
result = similar(left, length(left) + length(right))
left_index, right_index, result_index = 1, 1, 1
# common part of left and right
@inbounds while (
left_index in eachindex(left) &&
right_index in eachindex(right) &&
result_index in eachindex(result)
)
left_item = left[left_index]
right_item = right[right_index]
left_smaller = left_item <= right_item
right_smaller = right_item <= left_item
result_item = ifelse(left_smaller, left_item, right_item)
result[result_index] = result_item
result_index += 1
left_index = ifelse(left_smaller, left_index + 1, left_index)
right_index = ifelse(right_smaller, right_index + 1, right_index)
end
# either left or right has reached its end at this point
@inbounds while left_index in eachindex(left) && result_index in eachindex(result)
result[result_index] = left[left_index]
left_index += 1
result_index += 1
end
@inbounds while right_index in eachindex(right) && result_index in eachindex(result)
result[result_index] = right[right_index]
right_index += 1
result_index += 1
end
resize!(result, result_index - 1)
return SortedVector(result; already_sorted=true)
end
14 changes: 9 additions & 5 deletions src/tracers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ const SET_TYPE_MESSAGE = """
The provided index set type `S` has to satisfy the following conditions:
- it is an iterable with `<:Integer` element type
- it implements methods `union`, `union!` and `push!`
- it implements `union`
Subtypes of `AbstractSet{<:Integer}` are a natural choice, like `BitSet` or `Set{UInt64}`.
"""
Expand All @@ -30,7 +30,9 @@ struct ConnectivityTracer{S} <: AbstractTracer
end

function Base.show(io::IO, t::ConnectivityTracer{S}) where {S}
return Base.show_delim_array(io, inputs(t), "ConnectivityTracer{$S}(", ',', ')', true)
return Base.show_delim_array(
io, convert.(Int, inputs(t)), "ConnectivityTracer{$S}(", ',', ')', true
)
end

empty(::Type{ConnectivityTracer{S}}) where {S} = ConnectivityTracer(S())
Expand Down Expand Up @@ -78,7 +80,9 @@ struct JacobianTracer{S} <: AbstractTracer
end

function Base.show(io::IO, t::JacobianTracer{S}) where {S}
return Base.show_delim_array(io, inputs(t), "JacobianTracer{$S}(", ',', ')', true)
return Base.show_delim_array(
io, convert.(Int, inputs(t)), "JacobianTracer{$S}(", ',', ')', true
)
end

empty(::Type{JacobianTracer{S}}) where {S} = JacobianTracer(S())
Expand Down Expand Up @@ -122,8 +126,8 @@ end
function Base.show(io::IO, t::HessianTracer{S}) where {S}
println(io, "HessianTracer{", S, "}(")
for key in keys(t.inputs)
print(io, " ", key, " => ")
Base.show_delim_array(io, collect(t.inputs[key]), "(", ',', ')', true)
print(io, " ", Int(key), " => ")
Base.show_delim_array(io, convert.(Int, t.inputs[key]), "(", ',', ')', true)
println(io, ",")
end
return print(io, ")")
Expand Down
17 changes: 11 additions & 6 deletions test/benchmark.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
using BenchmarkTools
using SparseConnectivityTracer
using SparseConnectivityTracer: SortedVector
using Symbolics: Symbolics
using NNlib: conv

Expand All @@ -18,8 +19,10 @@ function benchmark_brusselator(N::Integer, method=:tracer)
du = similar(u)
f!(du, u) = brusselator_2d_loop(du, u, p, nothing)

if method == :tracer
return @benchmark pattern($f!, $du, $u)
if method == :tracer_bitset
return @benchmark jacobian_pattern($f!, $du, $u, BitSet)
elseif method == :tracer_sortedvector
return @benchmark jacobian_pattern($f!, $du, $u, SortedVector{UInt64})
elseif method == :symbolics
return @benchmark Symbolics.jacobian_sparsity($f!, $du, $u)
end
Expand All @@ -30,16 +33,18 @@ function benchmark_conv(method=:tracer)
w = rand(5, 5, 3, 16) # corresponds to Conv((5, 5), 3 => 16)
f(x) = conv(x, w)

if method == :tracer
return @benchmark pattern($f, $x)
if method == :tracer_bitset
return @benchmark jacobian_pattern($f, $x, BitSet)
elseif method == :tracer_sortedvector
return @benchmark jacobian_pattern($f, $x, SortedVector{UInt64})
elseif method == :symbolics
return @benchmark Symbolics.jacobian_sparsity($f, $x)
end
end

## Run Brusselator benchmarks
for N in (6, 24)
for method in (:tracer, :symbolics)
for N in (6, 24, 100)
for method in (:tracer_bitset, :tracer_sortedvector, :symbolics)
@info "Benchmarking Brusselator of size $N with $method..."
b = benchmark_brusselator(N, method)
display(b)
Expand Down
2 changes: 1 addition & 1 deletion test/references/show/ConnectivityTracer_Set{UInt64}.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
ConnectivityTracer{Set{UInt64}}(0x0000000000000002,)
ConnectivityTracer{Set{UInt64}}(2,)
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
ConnectivityTracer{SortedVector{UInt64}}(2,)
3 changes: 3 additions & 0 deletions test/references/show/HessianTracer_SortedVector{UInt64}.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
HessianTracer{SortedVector{UInt64}}(
2 => (),
)
2 changes: 1 addition & 1 deletion test/references/show/JacobianTracer_Set{UInt64}.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
JacobianTracer{Set{UInt64}}(0x0000000000000002,)
JacobianTracer{Set{UInt64}}(2,)
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
JacobianTracer{SortedVector{UInt64}}(2,)
29 changes: 6 additions & 23 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -40,11 +40,14 @@ DocMeta.setdocmeta!(
@testset "Doctests" begin
Documenter.doctest(SparseConnectivityTracer)
end
@testset "SortedVector" begin
include("sortedvector.jl")
end
@testset "Classification of operators by diff'ability" begin
include("test_differentiability.jl")
end
@testset "First order" begin
for S in (BitSet, Set{UInt64})
for S in (BitSet, Set{UInt64}, SortedVector{UInt64})
@testset "Set type $S" begin
CT = ConnectivityTracer{S}
JT = JacobianTracer{S}
Expand Down Expand Up @@ -92,19 +95,6 @@ DocMeta.setdocmeta!(
@test jacobian_pattern(x ->^x, 1, S) [1;;]
@test jacobian_pattern(x -> round(x, RoundNearestTiesUp), 1, S) [0;;]

@test rand(CT) == empty(CT)
@test rand(JT) == empty(JT)

t = tracer(CT, 2)
@test ConnectivityTracer(t) == t
@test empty(t) == empty(CT)
@test CT(1) == empty(CT)

t = tracer(JT, 2)
@test JacobianTracer(t) == t
@test empty(t) == empty(JT)
@test JT(1) == empty(JT)

# Base.show
@test_reference "references/show/ConnectivityTracer_$S.txt" repr(
"text/plain", tracer(CT, 2)
Expand All @@ -116,7 +106,7 @@ DocMeta.setdocmeta!(
end
end
@testset "Second order" begin
for S in (BitSet, Set{UInt64})
for S in (BitSet, Set{UInt64}, SortedVector{UInt64})
@testset "Set type $S" begin
HT = HessianTracer{S}

Expand All @@ -136,13 +126,6 @@ DocMeta.setdocmeta!(
@test hessian_pattern(x ->^x, 1) [1;;]
@test hessian_pattern(x -> round(x, RoundNearestTiesUp), 1) [0;;]

@test rand(HT) == empty(HT)

t = tracer(HT, 2)
@test HessianTracer(t) == t
@test empty(t) == empty(HT)
@test HT(1) == empty(HT)

H = hessian_pattern(x -> x[1] / x[2] + x[3] / 1 + 1 / x[4], rand(4), S)
@test H [
0 1 0 0
Expand Down Expand Up @@ -258,7 +241,7 @@ DocMeta.setdocmeta!(
end
@testset "Real-world tests" begin
include("brusselator.jl")
for S in (BitSet, Set{UInt64})
for S in (BitSet, Set{UInt64}, SortedVector{UInt64})
@testset "Set type $S" begin
@testset "Brusselator" begin
N = 6
Expand Down
29 changes: 29 additions & 0 deletions test/sortedvector.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
using ADTypes
using SparseArrays
using SparseConnectivityTracer
using SparseConnectivityTracer: SortedVector
using Test

@testset "Correctness" begin
@testset "$T - ($k1, $k2)" for T in (Int32, Int64),
k1 in (0, 10, 100, 1000),
k2 in (0, 10, 100, 1000)

for _ in 1:100
x = SortedVector(rand(T(1):T(1000), k1); already_sorted=false)
y = SortedVector(sort(rand(T(1):T(1000), k2)); already_sorted=true)
z = union(x, y)
@test eltype(z) == T
@test issorted(z.data)
@test Set(z.data) == union(Set(x.data), Set(y.data))
if k1 > 0 && k2 > 0
@test z[1] == min(x[1], y[1])
@test z[end] == max(x[end], y[end])
end
end
end
end;

sd = TracerSparsityDetector(SortedVector{UInt})
@test ADTypes.jacobian_sparsity(diff, rand(10), sd) isa SparseMatrixCSC
@test ADTypes.hessian_sparsity(x -> sum(abs2, diff(x)), rand(10), sd) isa SparseMatrixCSC

0 comments on commit 79cf7b9

Please sign in to comment.