Skip to content

Commit

Permalink
show(io, float) optimization
Browse files Browse the repository at this point in the history
  • Loading branch information
green-nsk committed Jul 2, 2021
1 parent 4e6eb22 commit 9dbc3f9
Show file tree
Hide file tree
Showing 3 changed files with 60 additions and 24 deletions.
44 changes: 25 additions & 19 deletions base/ryu/Ryu.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
module Ryu

import .Base: significand_bits, significand_mask, exponent_bits, exponent_mask, exponent_bias, exponent_max, uinttype
import .Base: with_scratch

include("utils.jl")
include("shortest.jl")
Expand Down Expand Up @@ -45,9 +46,10 @@ function writeshortest(x::T,
decchar::UInt8=UInt8('.'),
typed::Bool=false,
compact::Bool=false) where {T <: Base.IEEEFloat}
buf = Base.StringVector(neededdigits(T))
pos = writeshortest(buf, 1, x, plus, space, hash, precision, expchar, padexp, decchar, typed, compact)
return String(resize!(buf, pos - 1))
with_scratch(neededdigits(T)) do buf
pos = writeshortest(buf, 1, x, plus, space, hash, precision, expchar, padexp, decchar, typed, compact)
return String(@inbounds view(buf, 1:pos - 1))
end
end

"""
Expand All @@ -73,9 +75,10 @@ function writefixed(x::T,
hash::Bool=false,
decchar::UInt8=UInt8('.'),
trimtrailingzeros::Bool=false) where {T <: Base.IEEEFloat}
buf = Base.StringVector(precision + neededdigits(T))
pos = writefixed(buf, 1, x, precision, plus, space, hash, decchar, trimtrailingzeros)
return String(resize!(buf, pos - 1))
with_scratch(precision + neededdigits(T)) do buf
pos = writefixed(buf, 1, x, precision, plus, space, hash, decchar, trimtrailingzeros)
return String(@inbounds view(buf, 1:pos - 1))
end
end

"""
Expand Down Expand Up @@ -103,26 +106,29 @@ function writeexp(x::T,
expchar::UInt8=UInt8('e'),
decchar::UInt8=UInt8('.'),
trimtrailingzeros::Bool=false) where {T <: Base.IEEEFloat}
buf = Base.StringVector(precision + neededdigits(T))
pos = writeexp(buf, 1, x, precision, plus, space, hash, expchar, decchar, trimtrailingzeros)
return String(resize!(buf, pos - 1))
with_scratch(precision + neededdigits(T)) do buf
pos = writeexp(buf, 1, x, precision, plus, space, hash, expchar, decchar, trimtrailingzeros)
return String(@inbounds view(buf, 1:pos - 1))
end
end

function Base.show(io::IO, x::T, forceuntyped::Bool=false, fromprint::Bool=false) where {T <: Base.IEEEFloat}
compact = get(io, :compact, false)::Bool
buf = Base.StringVector(neededdigits(T))
typed = !forceuntyped && !compact && get(io, :typeinfo, Any) != typeof(x)
pos = writeshortest(buf, 1, x, false, false, true, -1,
(x isa Float32 && !fromprint) ? UInt8('f') : UInt8('e'), false, UInt8('.'), typed, compact)
write(io, resize!(buf, pos - 1))
return
with_scratch(neededdigits(T)) do buf
typed = !forceuntyped && !compact && get(io, :typeinfo, Any) != typeof(x)
pos = writeshortest(buf, 1, x, false, false, true, -1,
(x isa Float32 && !fromprint) ? UInt8('f') : UInt8('e'), false, UInt8('.'), typed, compact)
write(io, @inbounds view(buf, 1:pos - 1))
return
end
end

function Base.string(x::T) where {T <: Base.IEEEFloat}
buf = Base.StringVector(neededdigits(T))
pos = writeshortest(buf, 1, x, false, false, true, -1,
UInt8('e'), false, UInt8('.'), false, false)
return String(resize!(buf, pos - 1))
with_scratch(neededdigits(T)) do buf
pos = writeshortest(buf, 1, x, false, false, true, -1,
UInt8('e'), false, UInt8('.'), false, false)
return String(@inbounds view(buf, 1:pos - 1))
end
end

Base.print(io::IO, x::Union{Float16, Float32}) = show(io, x, true, true)
Expand Down
2 changes: 1 addition & 1 deletion base/ryu/shortest.jl
Original file line number Diff line number Diff line change
Expand Up @@ -224,7 +224,7 @@ integer. If a `maxsignif` argument is provided, then `b < maxsignif`.
return b, e10
end

function writeshortest(buf::Vector{UInt8}, pos, x::T,
function writeshortest(buf, pos, x::T,
plus=false, space=false, hash=true,
precision=-1, expchar=UInt8('e'), padexp=false, decchar=UInt8('.'),
typed=false, compact=false) where {T}
Expand Down
38 changes: 34 additions & 4 deletions base/scratch.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
const PRINT_SCRATCH_LENGTH = 256
# NOTE: ryu/Ryu.jl: neededdigits(::Type{Float64}) = 309 + 17
const PRINT_SCRATCH_LENGTH = 309 + 17
struct ScratchBuf
data::NTuple{PRINT_SCRATCH_LENGTH,UInt8}

Expand All @@ -11,16 +12,45 @@ struct Scratch
end

function with_scratch(f, n)
buf = Ref(ScratchBuf(undef))
GC.@preserve buf f(Scratch(pointer_from_objref(buf), unsafe_trunc(UInt64, n)))
if n <= PRINT_SCRATCH_LENGTH
buf = Ref(ScratchBuf(undef))
GC.@preserve buf f(Scratch(pointer_from_objref(buf), unsafe_trunc(UInt64, n)))
else
tls = task_local_storage()
buf = get!(tls, :PRINT_SCRATCH) do
Vector{UInt8}(undef, n)
end::Vector{UInt8}
resize!(buf, n)
GC.@preserve buf f(Scratch(pointer(buf), length(buf)))
end
end

@propagate_inbounds function setindex!(a::Scratch, v, i)
@boundscheck (i <= a.length || throw(BoundsError(a, i)))
@boundscheck (1 <= i <= a.length || throw(BoundsError(a, i)))
unsafe_store!(a.p, convert(UInt8, v), i)
a
end

@propagate_inbounds function getindex(a::Scratch, i)
@boundscheck (1 <= i <= a.length || throw(BoundsError(a, i)))
unsafe_load(a.p, i)
end

@propagate_inbounds function view(a::Scratch, index::UnitRange)
@boundscheck ((index.start > 0 && index.stop <= a.length) || throw(BoundsError))
index.start > index.stop && return Scratch(Ptr{UInt8}(), 0)
Scratch(a.p + index.start - 1, unsafe_trunc(UInt64, index.stop - index.start + 1))
end

pointer(a::Scratch) = a.p
pointer(a::Scratch, pos) = a.p + pos - 1
length(a::Scratch) = a.length
@inline function unsafe_copyto!(dst::Scratch, dst_pos, src::Vector{UInt8}, src_pos, N)
GC.@preserve dst, src, begin
unsafe_copyto!(pointer(dst, dst_pos), pointer(src, src_pos), N)
end
end

write(io::IO, a::Scratch) = unsafe_write(io, a.p, a.length)

String(s::Scratch) = unsafe_string(s.p, s.length)

0 comments on commit 9dbc3f9

Please sign in to comment.