Skip to content

Commit

Permalink
Improve HessianTracer performance (#45)
Browse files Browse the repository at this point in the history
* Sort typed `SortedVector` by default

* Couple `HessianTracer` index to set `eltype`

* Add `empty` constructors for common `SortedVector`s

* Fix NNlib benchmark
  • Loading branch information
adrhill authored May 6, 2024
1 parent bf32fee commit f20d411
Show file tree
Hide file tree
Showing 6 changed files with 101 additions and 71 deletions.
2 changes: 1 addition & 1 deletion src/SparseConnectivityTracer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ using ADTypes: ADTypes
import SparseArrays: sparse
import Random: rand, AbstractRNG, SamplerType

include("sortedvector.jl")
include("tracers.jl")
include("conversion.jl")
include("operators.jl")
Expand All @@ -12,7 +13,6 @@ include("overload_jacobian.jl")
include("overload_hessian.jl")
include("pattern.jl")
include("adtypes.jl")
include("sortedvector.jl")

export ConnectivityTracer, connectivity_pattern
export JacobianTracer, jacobian_pattern
Expand Down
4 changes: 2 additions & 2 deletions src/pattern.jl
Original file line number Diff line number Diff line change
Expand Up @@ -194,8 +194,8 @@ function hessian_pattern(f, x, ::Type{S}=DEFAULT_SET_TYPE) where {S}
end

function hessian_pattern_to_mat(
xt::AbstractArray{HessianTracer{S}}, yt::HessianTracer{S}
) where {S}
xt::AbstractArray{HessianTracer{S,T}}, yt::HessianTracer{S,T}
) where {S,T}
# Allocate Hessian matrix
n = length(xt)
I = UInt64[] # row indices
Expand Down
18 changes: 7 additions & 11 deletions src/sortedvector.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,13 @@ A wrapper for sorted vectors, designed for fast unions.
# Constructor
SortedVector(data::AbstractVector; already_sorted=false)
SortedVector(data::AbstractVector; sorted=false)
# Example
```jldoctest
x = SortedVector([3, 4, 2])
y = SortedVector([1, 3, 5]; already_sorted=true)
y = SortedVector([1, 3, 5]; sorted=true)
z = union(x, y)
# output
Expand All @@ -22,8 +22,9 @@ SortedVector([1, 2, 3, 4, 5])
struct SortedVector{T<:Number} <: AbstractVector{T}
data::Vector{T}

function SortedVector{T}(data::AbstractVector{T}) where {T}
return new{T}(convert(Vector{T}, data))
# Inner constructor: store `data` as a `Vector{T}`, sorting it first unless the caller
# passes `sorted=true` to state that it is already sorted.
function SortedVector{T}(data::AbstractVector{T}; sorted=false) where {T}
    # `ifelse` is an ordinary function, so both of its arguments are evaluated before
    # the call: `ifelse(sorted, data, sort(data))` sorts (and allocates) even when
    # `sorted=true`. The ternary only evaluates the branch that is selected.
    sorted_data = sorted ? data : sort(data)
    return new{T}(convert(Vector{T}, sorted_data))
end

function SortedVector{T}(x::Number) where {T}
Expand All @@ -35,13 +36,8 @@ struct SortedVector{T<:Number} <: AbstractVector{T}
end
end

# Outer constructor: infer the element type `T` from `data` and sort the input unless
# the caller passes `already_sorted=true`.
function SortedVector(data::AbstractVector{T}; already_sorted=false) where {T}
    # `ifelse` evaluates both of its arguments, so `ifelse(already_sorted, data,
    # sort(data))` would sort even when `already_sorted=true`. The ternary skips the
    # sort in that case.
    sorted_data = already_sorted ? data : sort(data)
    return SortedVector{T}(sorted_data)
end

function Base.convert(::Type{SortedVector{T}}, v::Vector{T}) where {T}
return SortedVector(v; already_sorted=false)
return SortedVector{T}(v; sorted=false)
end

Base.eltype(::SortedVector{T}) where {T} = T
Expand Down Expand Up @@ -82,5 +78,5 @@ function Base.union(v1::SortedVector{T}, v2::SortedVector{T}) where {T}
result_index += 1
end
resize!(result, result_index - 1)
return SortedVector(result; already_sorted=true)
return SortedVector{T}(result; sorted=true)
end
125 changes: 78 additions & 47 deletions src/tracers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -38,17 +38,23 @@ end
empty(::Type{ConnectivityTracer{S}}) where {S} = ConnectivityTracer(S())

# Performance can be gained by not re-allocating empty tracers
const EMPTY_CONNECTIVITY_TRACER_BITSET = ConnectivityTracer(BitSet())
const EMPTY_CONNECTIVITY_TRACER_SET_UINT8 = ConnectivityTracer(Set{UInt8}())
const EMPTY_CONNECTIVITY_TRACER_SET_UINT16 = ConnectivityTracer(Set{UInt16}())
const EMPTY_CONNECTIVITY_TRACER_SET_UINT32 = ConnectivityTracer(Set{UInt32}())
const EMPTY_CONNECTIVITY_TRACER_SET_UINT64 = ConnectivityTracer(Set{UInt64}())

empty(::Type{ConnectivityTracer{BitSet}}) = EMPTY_CONNECTIVITY_TRACER_BITSET
empty(::Type{ConnectivityTracer{Set{UInt8}}}) = EMPTY_CONNECTIVITY_TRACER_SET_UINT8
empty(::Type{ConnectivityTracer{Set{UInt16}}}) = EMPTY_CONNECTIVITY_TRACER_SET_UINT16
empty(::Type{ConnectivityTracer{Set{UInt32}}}) = EMPTY_CONNECTIVITY_TRACER_SET_UINT32
empty(::Type{ConnectivityTracer{Set{UInt64}}}) = EMPTY_CONNECTIVITY_TRACER_SET_UINT64
const EMPTY_CONNECTIVITY_TRACER_BITSET = ConnectivityTracer(BitSet())
const EMPTY_CONNECTIVITY_TRACER_SET_U8 = ConnectivityTracer(Set{UInt8}())
const EMPTY_CONNECTIVITY_TRACER_SET_U16 = ConnectivityTracer(Set{UInt16}())
const EMPTY_CONNECTIVITY_TRACER_SET_U32 = ConnectivityTracer(Set{UInt32}())
const EMPTY_CONNECTIVITY_TRACER_SET_U64 = ConnectivityTracer(Set{UInt64}())
const EMPTY_CONNECTIVITY_TRACER_SV_U16 = ConnectivityTracer(SortedVector{UInt16}())
const EMPTY_CONNECTIVITY_TRACER_SV_U32 = ConnectivityTracer(SortedVector{UInt32}())
const EMPTY_CONNECTIVITY_TRACER_SV_U64 = ConnectivityTracer(SortedVector{UInt64}())

empty(::Type{ConnectivityTracer{BitSet}}) = EMPTY_CONNECTIVITY_TRACER_BITSET
empty(::Type{ConnectivityTracer{Set{UInt8}}}) = EMPTY_CONNECTIVITY_TRACER_SET_U8
empty(::Type{ConnectivityTracer{Set{UInt16}}}) = EMPTY_CONNECTIVITY_TRACER_SET_U16
empty(::Type{ConnectivityTracer{Set{UInt32}}}) = EMPTY_CONNECTIVITY_TRACER_SET_U32
empty(::Type{ConnectivityTracer{Set{UInt64}}}) = EMPTY_CONNECTIVITY_TRACER_SET_U64
empty(::Type{ConnectivityTracer{SortedVector{UInt16}}}) = EMPTY_CONNECTIVITY_TRACER_SV_U16
empty(::Type{ConnectivityTracer{SortedVector{UInt32}}}) = EMPTY_CONNECTIVITY_TRACER_SV_U32
empty(::Type{ConnectivityTracer{SortedVector{UInt64}}}) = EMPTY_CONNECTIVITY_TRACER_SV_U64

# We have to be careful when defining constructors:
# Generic code expecting "regular" numbers `x` will sometimes convert them
Expand Down Expand Up @@ -88,17 +94,23 @@ end
empty(::Type{JacobianTracer{S}}) where {S} = JacobianTracer(S())

# Performance can be gained by not re-allocating empty tracers
const EMPTY_JACOBIAN_TRACER_BITSET = JacobianTracer(BitSet())
const EMPTY_JACOBIAN_TRACER_SET_UINT8 = JacobianTracer(Set{UInt8}())
const EMPTY_JACOBIAN_TRACER_SET_UINT16 = JacobianTracer(Set{UInt16}())
const EMPTY_JACOBIAN_TRACER_SET_UINT32 = JacobianTracer(Set{UInt32}())
const EMPTY_JACOBIAN_TRACER_SET_UINT64 = JacobianTracer(Set{UInt64}())

empty(::Type{JacobianTracer{BitSet}}) = EMPTY_JACOBIAN_TRACER_BITSET
empty(::Type{JacobianTracer{Set{UInt8}}}) = EMPTY_JACOBIAN_TRACER_SET_UINT8
empty(::Type{JacobianTracer{Set{UInt16}}}) = EMPTY_JACOBIAN_TRACER_SET_UINT16
empty(::Type{JacobianTracer{Set{UInt32}}}) = EMPTY_JACOBIAN_TRACER_SET_UINT32
empty(::Type{JacobianTracer{Set{UInt64}}}) = EMPTY_JACOBIAN_TRACER_SET_UINT64
const EMPTY_JACOBIAN_TRACER_BITSET = JacobianTracer(BitSet())
const EMPTY_JACOBIAN_TRACER_SET_U8 = JacobianTracer(Set{UInt8}())
const EMPTY_JACOBIAN_TRACER_SET_U16 = JacobianTracer(Set{UInt16}())
const EMPTY_JACOBIAN_TRACER_SET_U32 = JacobianTracer(Set{UInt32}())
const EMPTY_JACOBIAN_TRACER_SET_U64 = JacobianTracer(Set{UInt64}())
const EMPTY_JACOBIAN_TRACER_SV_U16 = JacobianTracer(SortedVector{UInt16}())
const EMPTY_JACOBIAN_TRACER_SV_U32 = JacobianTracer(SortedVector{UInt32}())
const EMPTY_JACOBIAN_TRACER_SV_U64 = JacobianTracer(SortedVector{UInt64}())

empty(::Type{JacobianTracer{BitSet}}) = EMPTY_JACOBIAN_TRACER_BITSET
empty(::Type{JacobianTracer{Set{UInt8}}}) = EMPTY_JACOBIAN_TRACER_SET_U8
empty(::Type{JacobianTracer{Set{UInt16}}}) = EMPTY_JACOBIAN_TRACER_SET_U16
empty(::Type{JacobianTracer{Set{UInt32}}}) = EMPTY_JACOBIAN_TRACER_SET_U32
empty(::Type{JacobianTracer{Set{UInt64}}}) = EMPTY_JACOBIAN_TRACER_SET_U64
empty(::Type{JacobianTracer{SortedVector{UInt16}}}) = EMPTY_JACOBIAN_TRACER_SV_U16
empty(::Type{JacobianTracer{SortedVector{UInt32}}}) = EMPTY_JACOBIAN_TRACER_SV_U32
empty(::Type{JacobianTracer{SortedVector{UInt64}}}) = EMPTY_JACOBIAN_TRACER_SV_U64

JacobianTracer{S}(::Number) where {S} = empty(JacobianTracer{S})
JacobianTracer(t::JacobianTracer) = t
Expand All @@ -120,8 +132,8 @@ $SET_TYPE_MESSAGE
For a higher-level interface, refer to [`hessian_pattern`](@ref).
"""
struct HessianTracer{S} <: AbstractTracer
inputs::Dict{UInt64,S}
# `S` is the set type stored per key and `I<:Integer` is the key type of `inputs`.
# The `tracer` constructors in this file build instances with `I = eltype(S)`.
struct HessianTracer{S,I<:Integer} <: AbstractTracer
# dictionary from a key of type `I` to a set of type `S`
inputs::Dict{I,S}
end
function Base.show(io::IO, t::HessianTracer{S}) where {S}
println(io, "HessianTracer{", S, "}(")
Expand All @@ -133,31 +145,45 @@ function Base.show(io::IO, t::HessianTracer{S}) where {S}
return print(io, ")")
end

function empty(::Type{HessianTracer{S}}) where {S}
return HessianTracer(Dict{UInt64,S}())
# Generic method: wrap a newly allocated, empty `Dict{I,S}` in a `HessianTracer`.
function empty(::Type{HessianTracer{S,I}}) where {S,I}
    no_inputs = Dict{I,S}()
    return HessianTracer(no_inputs)
end

# Performance can be gained by not re-allocating empty tracers
const EMPTY_HESSIAN_TRACER_BITSET = HessianTracer(Dict{UInt64,BitSet}())
const EMPTY_HESSIAN_TRACER_SET_UINT8 = HessianTracer(Dict{UInt64,Set{UInt8}}())
const EMPTY_HESSIAN_TRACER_SET_UINT16 = HessianTracer(Dict{UInt64,Set{UInt16}}())
const EMPTY_HESSIAN_TRACER_SET_UINT32 = HessianTracer(Dict{UInt64,Set{UInt32}}())
const EMPTY_HESSIAN_TRACER_SET_UINT64 = HessianTracer(Dict{UInt64,Set{UInt64}}())

empty(::Type{HessianTracer{BitSet}}) = EMPTY_HESSIAN_TRACER_BITSET
empty(::Type{HessianTracer{Set{UInt8}}}) = EMPTY_HESSIAN_TRACER_SET_UINT8
empty(::Type{HessianTracer{Set{UInt16}}}) = EMPTY_HESSIAN_TRACER_SET_UINT16
empty(::Type{HessianTracer{Set{UInt32}}}) = EMPTY_HESSIAN_TRACER_SET_UINT32
empty(::Type{HessianTracer{Set{UInt64}}}) = EMPTY_HESSIAN_TRACER_SET_UINT64

HessianTracer{S}(::Number) where {S} = empty(HessianTracer{S})
const EMPTY_HESSIAN_TRACER_BITSET = HessianTracer(Dict{Int,BitSet}())
const EMPTY_HESSIAN_TRACER_SET_U8 = HessianTracer(Dict{UInt8,Set{UInt8}}())
const EMPTY_HESSIAN_TRACER_SET_U16 = HessianTracer(Dict{UInt16,Set{UInt16}}())
const EMPTY_HESSIAN_TRACER_SET_U32 = HessianTracer(Dict{UInt32,Set{UInt32}}())
const EMPTY_HESSIAN_TRACER_SET_U64 = HessianTracer(Dict{UInt64,Set{UInt64}}())
const EMPTY_HESSIAN_TRACER_SV_U16 = HessianTracer(Dict{UInt16,SortedVector{UInt16}}())
const EMPTY_HESSIAN_TRACER_SV_U32 = HessianTracer(Dict{UInt32,SortedVector{UInt32}}())
const EMPTY_HESSIAN_TRACER_SV_U64 = HessianTracer(Dict{UInt64,SortedVector{UInt64}}())

empty(::Type{HessianTracer{BitSet,Int}}) = EMPTY_HESSIAN_TRACER_BITSET
empty(::Type{HessianTracer{Set{UInt8},UInt8}}) = EMPTY_HESSIAN_TRACER_SET_U8
empty(::Type{HessianTracer{Set{UInt16},UInt16}}) = EMPTY_HESSIAN_TRACER_SET_U16
empty(::Type{HessianTracer{Set{UInt32},UInt32}}) = EMPTY_HESSIAN_TRACER_SET_U32
empty(::Type{HessianTracer{Set{UInt64},UInt64}}) = EMPTY_HESSIAN_TRACER_SET_U64
empty(::Type{HessianTracer{SortedVector{UInt16},UInt16}}) = EMPTY_HESSIAN_TRACER_SV_U16
empty(::Type{HessianTracer{SortedVector{UInt32},UInt32}}) = EMPTY_HESSIAN_TRACER_SV_U32
empty(::Type{HessianTracer{SortedVector{UInt64},UInt64}}) = EMPTY_HESSIAN_TRACER_SV_U64

HessianTracer{S,I}(::Number) where {S,I} = empty(HessianTracer{S,I})
HessianTracer(t::HessianTracer) = t

# Collect the keys of `d` into a set of type `S`.
# Method for `AbstractSet` types whose element type is a subtype of the key type `I`.
function keys2set(::Type{S}, d::Dict{I}) where {I<:Integer,S<:AbstractSet{<:I}}
    index_keys = keys(d)
    return S(index_keys)
end
# Collect the keys of `d` into a `SortedVector{I}`.
# The keys are materialized into a vector and passed with `sorted=false`, so the
# constructor is asked to sort them.
function keys2set(::Type{S}, d::Dict{I}) where {I<:Integer,S<:SortedVector{I}}
    unsorted_keys = collect(keys(d))
    return S(unsorted_keys; sorted=false)
end

# Turn first-order interactions into second-order interactions
function promote_order(t::HessianTracer)
function promote_order(t::HessianTracer{S}) where {S}
d = deepcopy(t.inputs)
s = keys2set(S, d)
for (k, v) in pairs(d)
d[k] = union(v, keys(d)) # works by not being clever with symmetry
d[k] = union(v, s) # ignores symmetry
end
return HessianTracer(d)
end
Expand All @@ -168,15 +194,18 @@ function additive_merge(a::HessianTracer, b::HessianTracer)
end

# Merge first- and second-order terms in a "distributive" fashion
function distributive_merge(a::HessianTracer, b::HessianTracer)
function distributive_merge(a::HessianTracer{S}, b::HessianTracer{S}) where {S}
da = deepcopy(a.inputs)
db = deepcopy(b.inputs)
# add second-order interaction term, works by not being clever with symmetry
sa = keys2set(S, da)
sb = keys2set(S, db)

# add second-order interaction term by ignoring symmetry
for (ka, va) in pairs(da)
da[ka] = union(va, keys(db))
da[ka] = union(va, sb)
end
for (kb, vb) in pairs(db)
db[kb] = union(vb, keys(da))
db[kb] = union(vb, sa)
end
return HessianTracer(merge(da, db))
end
Expand Down Expand Up @@ -204,7 +233,8 @@ function tracer(::Type{ConnectivityTracer{S}}, index::Integer) where {S}
return ConnectivityTracer(S(index))
end
function tracer(::Type{HessianTracer{S}}, index::Integer) where {S}
return HessianTracer(Dict{UInt64,S}(index => S()))
I = eltype(S)
return HessianTracer{S,I}(Dict{I,S}(index => S()))
end

function tracer(::Type{JacobianTracer{S}}, inds::NTuple{N,<:Integer}) where {N,S}
Expand All @@ -214,5 +244,6 @@ function tracer(::Type{ConnectivityTracer{S}}, inds::NTuple{N,<:Integer}) where
return ConnectivityTracer{S}(S(inds))
end
function tracer(::Type{HessianTracer{S}}, inds::NTuple{N,<:Integer}) where {N,S}
return HessianTracer{S}(Dict{UInt64,S}(i => S() for i in inds))
I = eltype(S)
return HessianTracer{S,I}(Dict{I,S}(i => S() for i in inds))
end
19 changes: 11 additions & 8 deletions test/benchmark.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ using NNlib: conv

include("brusselator_definition.jl")

function benchmark_brusselator(N::Integer, method=:tracer)
function benchmark_brusselator(N::Integer, method=:tracer_bitset)
dims = (N, N, 2)
A = 1.0
B = 1.0
Expand All @@ -28,9 +28,9 @@ function benchmark_brusselator(N::Integer, method=:tracer)
end
end

function benchmark_conv(method=:tracer)
x = rand(28, 28, 3, 1) # WHCN image
w = rand(5, 5, 3, 16) # corresponds to Conv((5, 5), 3 => 16)
function benchmark_conv(N, method=:tracer_bitset)
x = rand(N, N, 3, 1) # WHCN image
w = rand(5, 5, 3, 2) # corresponds to Conv((5, 5), 3 => 2)
f(x) = conv(x, w)

if method == :tracer_bitset
Expand All @@ -52,7 +52,10 @@ for N in (6, 24, 100)
end

## Run conv benchmarks
@info "Benchmarking NNlib.conv with tracer..."
# Symbolics fails on this example
b = benchmark_conv(:tracer)
display(b)
# Benchmark `NNlib.conv` for each image size and each tracer method.
for N in (28, 224)
    # Symbolics fails on this example, so only the tracer methods are run
    for method in (:tracer_bitset, :tracer_sortedvector)
        @info "Benchmarking NNlib.conv on image of size ($N, $N, 3) with $method..."
        b = benchmark_conv(N, method)
        display(b)
    end
end
4 changes: 2 additions & 2 deletions test/sortedvector.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@ using Test
k2 in (0, 10, 100, 1000)

for _ in 1:100
x = SortedVector(rand(T(1):T(1000), k1); already_sorted=false)
y = SortedVector(sort(rand(T(1):T(1000), k2)); already_sorted=true)
x = SortedVector{T}(rand(T(1):T(1000), k1); sorted=false)
y = SortedVector{T}(sort(rand(T(1):T(1000), k2)); sorted=true)
z = union(x, y)
@test eltype(z) == T
@test issorted(z.data)
Expand Down

0 comments on commit f20d411

Please sign in to comment.