Skip to content

Commit

Permalink
Add constant empty tracer to avoid allocating sets (#7)
Browse files Browse the repository at this point in the history
  • Loading branch information
adrhill authored Apr 9, 2024
1 parent 2b29fbd commit 8b11d31
Show file tree
Hide file tree
Showing 4 changed files with 18 additions and 19 deletions.
4 changes: 2 additions & 2 deletions src/SparseConnectivityTracer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@ include("conversion.jl")
include("operators.jl")
include("connectivity.jl")

export Tracer, tracer
export trace_input
export Tracer
export tracer, trace_input
export inputs, sortedinputs
export connectivity

Expand Down
10 changes: 5 additions & 5 deletions src/conversion.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,15 @@ Base.big(::Type{Tracer}) = Tracer
Base.widen(::Type{Tracer}) = Tracer
Base.widen(t::Tracer) = t

Base.convert(::Type{Tracer}, x::Number) = tracer()
Base.convert(::Type{Tracer}, x::Number) = EMPTY_TRACER
Base.convert(::Type{Tracer}, t::Tracer) = t
Base.convert(::Type{<:Number}, t::Tracer) = t

## Array constructors
Base.zero(::Tracer) = tracer()
Base.zero(::Type{Tracer}) = tracer()
Base.one(::Tracer) = tracer()
Base.one(::Type{Tracer}) = tracer()
Base.zero(::Tracer) = EMPTY_TRACER
Base.zero(::Type{Tracer}) = EMPTY_TRACER
Base.one(::Tracer) = EMPTY_TRACER
Base.one(::Type{Tracer}) = EMPTY_TRACER

Base.similar(a::Array{Tracer,1}) = zeros(Tracer, size(a, 1))
Base.similar(a::Array{Tracer,2}) = zeros(Tracer, size(a, 1), size(a, 2))
Expand Down
11 changes: 5 additions & 6 deletions src/operators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,13 +35,12 @@ ops_1_to_1 = (
# exponentials
:exp, :exp2, :exp10, :expm1,
:log, :log2, :log10, :log1p,
:abs, :abs2,
# roots
:sqrt, :cbrt,
# absolute values
:abs, :abs2,
# rounding
:floor, :ceil, :trunc,
:round, :floor, :ceil, :trunc,
# other
:inv, :signbit, :hypot, :sign, :mod2pi
)
Expand All @@ -65,13 +64,13 @@ for fn in ops_1_to_2
end

for fn in ops_2_to_1
@eval Base.$fn(a::Tracer, b::Tracer) = tracer(a, b)
@eval Base.$fn(a::Tracer, b::Tracer) = uniontracer(a, b)
@eval Base.$fn(t::Tracer, ::Number) = t
@eval Base.$fn(::Number, t::Tracer) = t
end

# Extra types required for exponent
Base.:^(a::Tracer, b::Tracer) = tracer(a, b)
Base.:^(a::Tracer, b::Tracer) = uniontracer(a, b)
for T in (:Real, :Integer, :Rational)
@eval Base.:^(t::Tracer, ::$T) = t
@eval Base.:^(::$T, t::Tracer) = t
Expand All @@ -81,11 +80,11 @@ Base.:^(::Irrational{:ℯ}, t::Tracer) = t

## Precision operators create empty Tracer
for fn in (:eps, :nextfloat, :floatmin, :floatmax, :maxintfloat, :typemax)
@eval Base.$fn(::Tracer) = tracer()
@eval Base.$fn(::Tracer) = EMPTY_TRACER
end

## Rounding
Base.round(t::Tracer, ::RoundingMode; kwargs...) = t

## Random numbers
rand(::AbstractRNG, ::SamplerType{Tracer}) = tracer()
rand(::AbstractRNG, ::SamplerType{Tracer}) = EMPTY_TRACER
12 changes: 6 additions & 6 deletions src/tracer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -77,24 +77,24 @@ struct Tracer <: Number
inputs::Set{UInt64} # indices of connected, enumerated inputs
end

const EMPTY_TRACER = Tracer(Set{UInt64}())

# We have to be careful when defining constructors:
# Generic code expecting "regular" numbers `x` will sometimes convert them
# by calling `T(x)` (instead of `convert(T, x)`), where `T` can be `Tracer`.
# When this happens, we create a new empty tracer with no input connectivity.
Tracer(::Number) = tracer()
Tracer(::Number) = EMPTY_TRACER
Tracer(t::Tracer) = t
# We therefore exclusively use the lower-case `tracer` for convenience constructors

uniontracer(a::Tracer, b::Tracer) = Tracer(union(a.inputs, b.inputs))

"""
tracer(index)
tracer(indices)
Convenience constructor for [`Tracer`](@ref) from input indices.
"""
tracer() = Tracer(Set{UInt64}())
tracer(a::Tracer, b::Tracer) = Tracer(union(a.inputs, b.inputs))

tracer(index::Integer) = Tracer(Set{UInt64}(index))
tracer(index::Integer) = Tracer(Set{UInt64}(index))
tracer(inds::NTuple{N,<:Integer}) where {N} = Tracer(Set{UInt64}(inds))
tracer(inds...) = tracer(inds)

Expand Down

0 comments on commit 8b11d31

Please sign in to comment.