From 8b11d3117a3270939fde165dc9fbd9efccf176d7 Mon Sep 17 00:00:00 2001 From: Adrian Hill Date: Tue, 9 Apr 2024 16:42:18 +0200 Subject: [PATCH] Add constant empty tracer to avoid allocating sets (#7) --- src/SparseConnectivityTracer.jl | 4 ++-- src/conversion.jl | 10 +++++----- src/operators.jl | 11 +++++------ src/tracer.jl | 12 ++++++------ 4 files changed, 18 insertions(+), 19 deletions(-) diff --git a/src/SparseConnectivityTracer.jl b/src/SparseConnectivityTracer.jl index b9de0f2e..5a08f71c 100644 --- a/src/SparseConnectivityTracer.jl +++ b/src/SparseConnectivityTracer.jl @@ -7,8 +7,8 @@ include("conversion.jl") include("operators.jl") include("connectivity.jl") -export Tracer, tracer -export trace_input +export Tracer +export tracer, trace_input export inputs, sortedinputs export connectivity diff --git a/src/conversion.jl b/src/conversion.jl index eba952bb..b7ba66be 100644 --- a/src/conversion.jl +++ b/src/conversion.jl @@ -6,15 +6,15 @@ Base.big(::Type{Tracer}) = Tracer Base.widen(::Type{Tracer}) = Tracer Base.widen(t::Tracer) = t -Base.convert(::Type{Tracer}, x::Number) = tracer() +Base.convert(::Type{Tracer}, x::Number) = EMPTY_TRACER Base.convert(::Type{Tracer}, t::Tracer) = t Base.convert(::Type{<:Number}, t::Tracer) = t ## Array constructors -Base.zero(::Tracer) = tracer() -Base.zero(::Type{Tracer}) = tracer() -Base.one(::Tracer) = tracer() -Base.one(::Type{Tracer}) = tracer() +Base.zero(::Tracer) = EMPTY_TRACER +Base.zero(::Type{Tracer}) = EMPTY_TRACER +Base.one(::Tracer) = EMPTY_TRACER +Base.one(::Type{Tracer}) = EMPTY_TRACER Base.similar(a::Array{Tracer,1}) = zeros(Tracer, size(a, 1)) Base.similar(a::Array{Tracer,2}) = zeros(Tracer, size(a, 1), size(a, 2)) diff --git a/src/operators.jl b/src/operators.jl index b05a872c..cde0a784 100644 --- a/src/operators.jl +++ b/src/operators.jl @@ -35,13 +35,12 @@ ops_1_to_1 = ( # exponentials :exp, :exp2, :exp10, :expm1, :log, :log2, :log10, :log1p, - :abs, :abs2, # roots :sqrt, :cbrt, # absolute values :abs, :abs2, # rounding - :floor, :ceil, :trunc, + :round, :floor, :ceil, :trunc, # other :inv, :signbit, :hypot, :sign, :mod2pi ) @@ -65,13 +64,13 @@ for fn in ops_1_to_2 end for fn in ops_2_to_1 - @eval Base.$fn(a::Tracer, b::Tracer) = tracer(a, b) + @eval Base.$fn(a::Tracer, b::Tracer) = uniontracer(a, b) @eval Base.$fn(t::Tracer, ::Number) = t @eval Base.$fn(::Number, t::Tracer) = t end # Extra types required for exponent -Base.:^(a::Tracer, b::Tracer) = tracer(a, b) +Base.:^(a::Tracer, b::Tracer) = uniontracer(a, b) for T in (:Real, :Integer, :Rational) @eval Base.:^(t::Tracer, ::$T) = t @eval Base.:^(::$T, t::Tracer) = t @@ -81,11 +80,11 @@ Base.:^(::Irrational{:ℯ}, t::Tracer) = t ## Precision operators create empty Tracer for fn in (:eps, :nextfloat, :floatmin, :floatmax, :maxintfloat, :typemax) - @eval Base.$fn(::Tracer) = tracer() + @eval Base.$fn(::Tracer) = EMPTY_TRACER end ## Rounding Base.round(t::Tracer, ::RoundingMode; kwargs...) = t ## Random numbers -rand(::AbstractRNG, ::SamplerType{Tracer}) = tracer() +rand(::AbstractRNG, ::SamplerType{Tracer}) = EMPTY_TRACER diff --git a/src/tracer.jl b/src/tracer.jl index 88855ae9..1c8f1a8c 100644 --- a/src/tracer.jl +++ b/src/tracer.jl @@ -77,13 +77,16 @@ struct Tracer <: Number inputs::Set{UInt64} # indices of connected, enumerated inputs end +const EMPTY_TRACER = Tracer(Set{UInt64}()) + # We have to be careful when defining constructors: # Generic code expecting "regular" numbers `x` will sometimes convert them # by calling `T(x)` (instead of `convert(T, x)`), where `T` can be `Tracer`. # When this happens, we create a new empty tracer with no input connectivity. -Tracer(::Number) = tracer() +Tracer(::Number) = EMPTY_TRACER Tracer(t::Tracer) = t -# We therefore exclusively use the lower-case `tracer` for convenience constructors + +uniontracer(a::Tracer, b::Tracer) = Tracer(union(a.inputs, b.inputs)) """ tracer(index) @@ -91,10 +94,7 @@ Tracer(t::Tracer) = t Convenience constructor for [`Tracer`](@ref) from input indices. """ -tracer() = Tracer(Set{UInt64}()) -tracer(a::Tracer, b::Tracer) = Tracer(union(a.inputs, b.inputs)) - -tracer(index::Integer) = Tracer(Set{UInt64}(index)) +tracer(index::Integer) = Tracer(Set{UInt64}(index)) tracer(inds::NTuple{N,<:Integer}) where {N} = Tracer(Set{UInt64}(inds)) tracer(inds...) = tracer(inds)