Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve HessianTracer performance #45

Merged
merged 5 commits into from
May 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/SparseConnectivityTracer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ using ADTypes: ADTypes
import SparseArrays: sparse
import Random: rand, AbstractRNG, SamplerType

include("sortedvector.jl")
include("tracers.jl")
include("conversion.jl")
include("operators.jl")
Expand All @@ -12,7 +13,6 @@ include("overload_jacobian.jl")
include("overload_hessian.jl")
include("pattern.jl")
include("adtypes.jl")
include("sortedvector.jl")

export ConnectivityTracer, connectivity_pattern
export JacobianTracer, jacobian_pattern
Expand Down
4 changes: 2 additions & 2 deletions src/pattern.jl
Original file line number Diff line number Diff line change
Expand Up @@ -194,8 +194,8 @@ function hessian_pattern(f, x, ::Type{S}=DEFAULT_SET_TYPE) where {S}
end

function hessian_pattern_to_mat(
xt::AbstractArray{HessianTracer{S}}, yt::HessianTracer{S}
) where {S}
xt::AbstractArray{HessianTracer{S,T}}, yt::HessianTracer{S,T}
) where {S,T}
gdalle marked this conversation as resolved.
Show resolved Hide resolved
# Allocate Hessian matrix
n = length(xt)
I = UInt64[] # row indices
Expand Down
18 changes: 7 additions & 11 deletions src/sortedvector.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,13 @@ A wrapper for sorted vectors, designed for fast unions.

# Constructor

SortedVector(data::AbstractVector; already_sorted=false)
SortedVector(data::AbstractVector; sorted=false)
gdalle marked this conversation as resolved.
Show resolved Hide resolved

# Example

```jldoctest
x = SortedVector([3, 4, 2])
x = SortedVector([1, 3, 5]; already_sorted=true)
y = SortedVector([1, 3, 5]; sorted=true)
z = union(x, y)

# output
Expand All @@ -22,8 +22,9 @@ SortedVector([1, 2, 3, 4, 5])
struct SortedVector{T<:Number} <: AbstractVector{T}
data::Vector{T}

function SortedVector{T}(data::AbstractVector{T}) where {T}
return new{T}(convert(Vector{T}, data))
# Inner constructor: wrap `data` in a `SortedVector{T}`.
#
# Keyword `sorted`: pass `true` when `data` is already sorted, which skips the
# sort. With the default `false`, a sorted copy of `data` is stored.
function SortedVector{T}(data::AbstractVector{T}; sorted=false) where {T}
    # `ifelse` is an ordinary function, so it would evaluate `sort(data)` even
    # when `sorted=true`. The ternary only sorts when it is actually required.
    sorted_data = sorted ? data : sort(data)
    return new{T}(convert(Vector{T}, sorted_data))
end
gdalle marked this conversation as resolved.
Show resolved Hide resolved

function SortedVector{T}(x::Number) where {T}
Expand All @@ -35,13 +36,8 @@ struct SortedVector{T<:Number} <: AbstractVector{T}
end
end

function SortedVector(data::AbstractVector{T}; already_sorted=false) where {T}
sorted_data = ifelse(already_sorted, data, sort(data))
return SortedVector{T}(sorted_data)
end

function Base.convert(::Type{SortedVector{T}}, v::Vector{T}) where {T}
return SortedVector(v; already_sorted=false)
return SortedVector{T}(v; sorted=false)
end

Base.eltype(::SortedVector{T}) where {T} = T
Expand Down Expand Up @@ -82,5 +78,5 @@ function Base.union(v1::SortedVector{T}, v2::SortedVector{T}) where {T}
result_index += 1
end
resize!(result, result_index - 1)
return SortedVector(result; already_sorted=true)
return SortedVector{T}(result; sorted=true)
end
125 changes: 78 additions & 47 deletions src/tracers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -38,17 +38,23 @@ end
empty(::Type{ConnectivityTracer{S}}) where {S} = ConnectivityTracer(S())

# Performance can be gained by not re-allocating empty tracers
const EMPTY_CONNECTIVITY_TRACER_BITSET = ConnectivityTracer(BitSet())
const EMPTY_CONNECTIVITY_TRACER_SET_UINT8 = ConnectivityTracer(Set{UInt8}())
const EMPTY_CONNECTIVITY_TRACER_SET_UINT16 = ConnectivityTracer(Set{UInt16}())
const EMPTY_CONNECTIVITY_TRACER_SET_UINT32 = ConnectivityTracer(Set{UInt32}())
const EMPTY_CONNECTIVITY_TRACER_SET_UINT64 = ConnectivityTracer(Set{UInt64}())

empty(::Type{ConnectivityTracer{BitSet}}) = EMPTY_CONNECTIVITY_TRACER_BITSET
empty(::Type{ConnectivityTracer{Set{UInt8}}}) = EMPTY_CONNECTIVITY_TRACER_SET_UINT8
empty(::Type{ConnectivityTracer{Set{UInt16}}}) = EMPTY_CONNECTIVITY_TRACER_SET_UINT16
empty(::Type{ConnectivityTracer{Set{UInt32}}}) = EMPTY_CONNECTIVITY_TRACER_SET_UINT32
empty(::Type{ConnectivityTracer{Set{UInt64}}}) = EMPTY_CONNECTIVITY_TRACER_SET_UINT64
# One pre-built empty ConnectivityTracer per set type listed here
# (BitSet, Set{UInt8/16/32/64}, SortedVector{UInt16/32/64}).
# NOTE(review): each constant is a single shared instance; callers presumably
# must not mutate the wrapped set in place — confirm.
const EMPTY_CONNECTIVITY_TRACER_BITSET = ConnectivityTracer(BitSet())
const EMPTY_CONNECTIVITY_TRACER_SET_U8 = ConnectivityTracer(Set{UInt8}())
const EMPTY_CONNECTIVITY_TRACER_SET_U16 = ConnectivityTracer(Set{UInt16}())
const EMPTY_CONNECTIVITY_TRACER_SET_U32 = ConnectivityTracer(Set{UInt32}())
const EMPTY_CONNECTIVITY_TRACER_SET_U64 = ConnectivityTracer(Set{UInt64}())
const EMPTY_CONNECTIVITY_TRACER_SV_U16 = ConnectivityTracer(SortedVector{UInt16}())
const EMPTY_CONNECTIVITY_TRACER_SV_U32 = ConnectivityTracer(SortedVector{UInt32}())
const EMPTY_CONNECTIVITY_TRACER_SV_U64 = ConnectivityTracer(SortedVector{UInt64}())

# Specialized `empty` methods: for these concrete tracer types, return the
# matching constant instead of constructing a new tracer on every call.
empty(::Type{ConnectivityTracer{BitSet}}) = EMPTY_CONNECTIVITY_TRACER_BITSET
empty(::Type{ConnectivityTracer{Set{UInt8}}}) = EMPTY_CONNECTIVITY_TRACER_SET_U8
empty(::Type{ConnectivityTracer{Set{UInt16}}}) = EMPTY_CONNECTIVITY_TRACER_SET_U16
empty(::Type{ConnectivityTracer{Set{UInt32}}}) = EMPTY_CONNECTIVITY_TRACER_SET_U32
empty(::Type{ConnectivityTracer{Set{UInt64}}}) = EMPTY_CONNECTIVITY_TRACER_SET_U64
empty(::Type{ConnectivityTracer{SortedVector{UInt16}}}) = EMPTY_CONNECTIVITY_TRACER_SV_U16
empty(::Type{ConnectivityTracer{SortedVector{UInt32}}}) = EMPTY_CONNECTIVITY_TRACER_SV_U32
empty(::Type{ConnectivityTracer{SortedVector{UInt64}}}) = EMPTY_CONNECTIVITY_TRACER_SV_U64

# We have to be careful when defining constructors:
# Generic code expecting "regular" numbers `x` will sometimes convert them
Expand Down Expand Up @@ -88,17 +94,23 @@ end
empty(::Type{JacobianTracer{S}}) where {S} = JacobianTracer(S())

# Performance can be gained by not re-allocating empty tracers
const EMPTY_JACOBIAN_TRACER_BITSET = JacobianTracer(BitSet())
const EMPTY_JACOBIAN_TRACER_SET_UINT8 = JacobianTracer(Set{UInt8}())
const EMPTY_JACOBIAN_TRACER_SET_UINT16 = JacobianTracer(Set{UInt16}())
const EMPTY_JACOBIAN_TRACER_SET_UINT32 = JacobianTracer(Set{UInt32}())
const EMPTY_JACOBIAN_TRACER_SET_UINT64 = JacobianTracer(Set{UInt64}())

empty(::Type{JacobianTracer{BitSet}}) = EMPTY_JACOBIAN_TRACER_BITSET
empty(::Type{JacobianTracer{Set{UInt8}}}) = EMPTY_JACOBIAN_TRACER_SET_UINT8
empty(::Type{JacobianTracer{Set{UInt16}}}) = EMPTY_JACOBIAN_TRACER_SET_UINT16
empty(::Type{JacobianTracer{Set{UInt32}}}) = EMPTY_JACOBIAN_TRACER_SET_UINT32
empty(::Type{JacobianTracer{Set{UInt64}}}) = EMPTY_JACOBIAN_TRACER_SET_UINT64
# One pre-built empty JacobianTracer per set type listed here
# (BitSet, Set{UInt8/16/32/64}, SortedVector{UInt16/32/64}).
# NOTE(review): each constant is a single shared instance; callers presumably
# must not mutate the wrapped set in place — confirm.
const EMPTY_JACOBIAN_TRACER_BITSET = JacobianTracer(BitSet())
const EMPTY_JACOBIAN_TRACER_SET_U8 = JacobianTracer(Set{UInt8}())
const EMPTY_JACOBIAN_TRACER_SET_U16 = JacobianTracer(Set{UInt16}())
const EMPTY_JACOBIAN_TRACER_SET_U32 = JacobianTracer(Set{UInt32}())
const EMPTY_JACOBIAN_TRACER_SET_U64 = JacobianTracer(Set{UInt64}())
const EMPTY_JACOBIAN_TRACER_SV_U16 = JacobianTracer(SortedVector{UInt16}())
const EMPTY_JACOBIAN_TRACER_SV_U32 = JacobianTracer(SortedVector{UInt32}())
const EMPTY_JACOBIAN_TRACER_SV_U64 = JacobianTracer(SortedVector{UInt64}())

# Specialized `empty` methods: for these concrete tracer types, return the
# matching constant instead of constructing a new tracer on every call.
empty(::Type{JacobianTracer{BitSet}}) = EMPTY_JACOBIAN_TRACER_BITSET
empty(::Type{JacobianTracer{Set{UInt8}}}) = EMPTY_JACOBIAN_TRACER_SET_U8
empty(::Type{JacobianTracer{Set{UInt16}}}) = EMPTY_JACOBIAN_TRACER_SET_U16
empty(::Type{JacobianTracer{Set{UInt32}}}) = EMPTY_JACOBIAN_TRACER_SET_U32
empty(::Type{JacobianTracer{Set{UInt64}}}) = EMPTY_JACOBIAN_TRACER_SET_U64
empty(::Type{JacobianTracer{SortedVector{UInt16}}}) = EMPTY_JACOBIAN_TRACER_SV_U16
empty(::Type{JacobianTracer{SortedVector{UInt32}}}) = EMPTY_JACOBIAN_TRACER_SV_U32
empty(::Type{JacobianTracer{SortedVector{UInt64}}}) = EMPTY_JACOBIAN_TRACER_SV_U64

JacobianTracer{S}(::Number) where {S} = empty(JacobianTracer{S})
JacobianTracer(t::JacobianTracer) = t
Expand All @@ -120,8 +132,8 @@ $SET_TYPE_MESSAGE

For a higher-level interface, refer to [`hessian_pattern`](@ref).
"""
struct HessianTracer{S} <: AbstractTracer
inputs::Dict{UInt64,S}
struct HessianTracer{S,I<:Integer} <: AbstractTracer
gdalle marked this conversation as resolved.
Show resolved Hide resolved
gdalle marked this conversation as resolved.
Show resolved Hide resolved
inputs::Dict{I,S}
gdalle marked this conversation as resolved.
Show resolved Hide resolved
end
function Base.show(io::IO, t::HessianTracer{S}) where {S}
println(io, "HessianTracer{", S, "}(")
Expand All @@ -133,31 +145,45 @@ function Base.show(io::IO, t::HessianTracer{S}) where {S}
return print(io, ")")
end

function empty(::Type{HessianTracer{S}}) where {S}
return HessianTracer(Dict{UInt64,S}())
function empty(::Type{HessianTracer{S,I}}) where {S,I}
return HessianTracer(Dict{I,S}())
end

# Performance can be gained by not re-allocating empty tracers
const EMPTY_HESSIAN_TRACER_BITSET = HessianTracer(Dict{UInt64,BitSet}())
const EMPTY_HESSIAN_TRACER_SET_UINT8 = HessianTracer(Dict{UInt64,Set{UInt8}}())
const EMPTY_HESSIAN_TRACER_SET_UINT16 = HessianTracer(Dict{UInt64,Set{UInt16}}())
const EMPTY_HESSIAN_TRACER_SET_UINT32 = HessianTracer(Dict{UInt64,Set{UInt32}}())
const EMPTY_HESSIAN_TRACER_SET_UINT64 = HessianTracer(Dict{UInt64,Set{UInt64}}())

empty(::Type{HessianTracer{BitSet}}) = EMPTY_HESSIAN_TRACER_BITSET
empty(::Type{HessianTracer{Set{UInt8}}}) = EMPTY_HESSIAN_TRACER_SET_UINT8
empty(::Type{HessianTracer{Set{UInt16}}}) = EMPTY_HESSIAN_TRACER_SET_UINT16
empty(::Type{HessianTracer{Set{UInt32}}}) = EMPTY_HESSIAN_TRACER_SET_UINT32
empty(::Type{HessianTracer{Set{UInt64}}}) = EMPTY_HESSIAN_TRACER_SET_UINT64

HessianTracer{S}(::Number) where {S} = empty(HessianTracer{S})
# One pre-built empty HessianTracer per (set type, key type) pair listed here.
# In every pair the Dict key type equals the element type of the set type
# (Int for BitSet, UIntN for Set{UIntN} and SortedVector{UIntN}).
# NOTE(review): each constant is a single shared instance; callers presumably
# must not mutate the wrapped Dict in place — confirm.
const EMPTY_HESSIAN_TRACER_BITSET = HessianTracer(Dict{Int,BitSet}())
const EMPTY_HESSIAN_TRACER_SET_U8 = HessianTracer(Dict{UInt8,Set{UInt8}}())
const EMPTY_HESSIAN_TRACER_SET_U16 = HessianTracer(Dict{UInt16,Set{UInt16}}())
const EMPTY_HESSIAN_TRACER_SET_U32 = HessianTracer(Dict{UInt32,Set{UInt32}}())
const EMPTY_HESSIAN_TRACER_SET_U64 = HessianTracer(Dict{UInt64,Set{UInt64}}())
const EMPTY_HESSIAN_TRACER_SV_U16 = HessianTracer(Dict{UInt16,SortedVector{UInt16}}())
const EMPTY_HESSIAN_TRACER_SV_U32 = HessianTracer(Dict{UInt32,SortedVector{UInt32}}())
const EMPTY_HESSIAN_TRACER_SV_U64 = HessianTracer(Dict{UInt64,SortedVector{UInt64}}())

# Specialized `empty` methods: for these concrete tracer types, return the
# matching constant instead of constructing a new tracer on every call.
empty(::Type{HessianTracer{BitSet,Int}}) = EMPTY_HESSIAN_TRACER_BITSET
empty(::Type{HessianTracer{Set{UInt8},UInt8}}) = EMPTY_HESSIAN_TRACER_SET_U8
empty(::Type{HessianTracer{Set{UInt16},UInt16}}) = EMPTY_HESSIAN_TRACER_SET_U16
empty(::Type{HessianTracer{Set{UInt32},UInt32}}) = EMPTY_HESSIAN_TRACER_SET_U32
empty(::Type{HessianTracer{Set{UInt64},UInt64}}) = EMPTY_HESSIAN_TRACER_SET_U64
empty(::Type{HessianTracer{SortedVector{UInt16},UInt16}}) = EMPTY_HESSIAN_TRACER_SV_U16
empty(::Type{HessianTracer{SortedVector{UInt32},UInt32}}) = EMPTY_HESSIAN_TRACER_SV_U32
empty(::Type{HessianTracer{SortedVector{UInt64},UInt64}}) = EMPTY_HESSIAN_TRACER_SV_U64

HessianTracer{S,I}(::Number) where {S,I} = empty(HessianTracer{S,I})
HessianTracer(t::HessianTracer) = t

# Gather the keys of `d` into a freshly constructed set of type `S`.
keys2set(::Type{S}, d::Dict{I}) where {I<:Integer,S<:AbstractSet{<:I}} = S(keys(d))
# Gather the keys of `d` into a `SortedVector`. The keys are materialized into a
# vector and handed over with `sorted=false`, since `Dict` iteration order is
# not sorted.
function keys2set(::Type{S}, d::Dict{I}) where {I<:Integer,S<:SortedVector{I}}
    unsorted_keys = collect(keys(d))
    return S(unsorted_keys; sorted=false)
end
gdalle marked this conversation as resolved.
Show resolved Hide resolved

# Turn first-order interactions into second-order interactions
function promote_order(t::HessianTracer)
function promote_order(t::HessianTracer{S}) where {S}
d = deepcopy(t.inputs)
s = keys2set(S, d)
for (k, v) in pairs(d)
d[k] = union(v, keys(d)) # works by not being clever with symmetry
d[k] = union(v, s) # ignores symmetry
end
return HessianTracer(d)
end
Expand All @@ -168,15 +194,18 @@ function additive_merge(a::HessianTracer, b::HessianTracer)
end

# Merge first- and second-order terms in a "distributive" fashion
function distributive_merge(a::HessianTracer, b::HessianTracer)
function distributive_merge(a::HessianTracer{S}, b::HessianTracer{S}) where {S}
da = deepcopy(a.inputs)
db = deepcopy(b.inputs)
# add second-order interaction term, works by not being clever with symmetry
sa = keys2set(S, da)
sb = keys2set(S, db)

# add second-order interaction term by ignoring symmetry
for (ka, va) in pairs(da)
da[ka] = union(va, keys(db))
da[ka] = union(va, sb)
end
for (kb, vb) in pairs(db)
db[kb] = union(vb, keys(da))
db[kb] = union(vb, sa)
gdalle marked this conversation as resolved.
Show resolved Hide resolved
end
return HessianTracer(merge(da, db))
end
Expand Down Expand Up @@ -204,7 +233,8 @@ function tracer(::Type{ConnectivityTracer{S}}, index::Integer) where {S}
return ConnectivityTracer(S(index))
end
function tracer(::Type{HessianTracer{S}}, index::Integer) where {S}
return HessianTracer(Dict{UInt64,S}(index => S()))
I = eltype(S)
return HessianTracer{S,I}(Dict{I,S}(index => S()))
end

function tracer(::Type{JacobianTracer{S}}, inds::NTuple{N,<:Integer}) where {N,S}
Expand All @@ -214,5 +244,6 @@ function tracer(::Type{ConnectivityTracer{S}}, inds::NTuple{N,<:Integer}) where
return ConnectivityTracer{S}(S(inds))
end
function tracer(::Type{HessianTracer{S}}, inds::NTuple{N,<:Integer}) where {N,S}
return HessianTracer{S}(Dict{UInt64,S}(i => S() for i in inds))
I = eltype(S)
return HessianTracer{S,I}(Dict{I,S}(i => S() for i in inds))
end
19 changes: 11 additions & 8 deletions test/benchmark.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ using NNlib: conv

include("brusselator.jl")

function benchmark_brusselator(N::Integer, method=:tracer)
function benchmark_brusselator(N::Integer, method=:tracer_bitset)
gdalle marked this conversation as resolved.
Show resolved Hide resolved
dims = (N, N, 2)
A = 1.0
B = 1.0
Expand All @@ -28,9 +28,9 @@ function benchmark_brusselator(N::Integer, method=:tracer)
end
end

function benchmark_conv(method=:tracer)
x = rand(28, 28, 3, 1) # WHCN image
w = rand(5, 5, 3, 16) # corresponds to Conv((5, 5), 3 => 16)
function benchmark_conv(N, method=:tracer_bitset)
x = rand(N, N, 3, 1) # WHCN image
w = rand(5, 5, 3, 2) # corresponds to Conv((5, 5), 3 => 2)
f(x) = conv(x, w)

if method == :tracer_bitset
Expand All @@ -52,7 +52,10 @@ for N in (6, 24, 100)
end

## Run conv benchmarks
@info "Benchmarking NNlib.conv with tracer..."
# Symbolics fails on this example
b = benchmark_conv(:tracer)
display(b)
# Run the NNlib.conv benchmark for each image size and each tracer set type,
# logging which configuration is running and displaying its result.
for N in (28, 224)
    for method in (:tracer_bitset, :tracer_sortedvector) # Symbolics fails on this example
        @info "Benchmarking NNlib.conv on image of size ($N, $N, 3) with $method..."
        b = benchmark_conv(N, method)
        display(b)
    end
end
4 changes: 2 additions & 2 deletions test/sortedvector.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@ using Test
k2 in (0, 10, 100, 1000)

for _ in 1:100
x = SortedVector(rand(T(1):T(1000), k1); already_sorted=false)
y = SortedVector(sort(rand(T(1):T(1000), k2)); already_sorted=true)
x = SortedVector{T}(rand(T(1):T(1000), k1); sorted=false)
y = SortedVector{T}(sort(rand(T(1):T(1000), k2)); sorted=true)
z = union(x, y)
@test eltype(z) == T
@test issorted(z.data)
Expand Down
Loading