Skip to content

Commit

Permalink
Generalize to IterableStatePairs
Browse files Browse the repository at this point in the history
  • Loading branch information
jakobnissen committed Oct 28, 2023
1 parent 8a81cb3 commit 8d2699c
Show file tree
Hide file tree
Showing 5 changed files with 50 additions and 48 deletions.
36 changes: 35 additions & 1 deletion base/iterators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ using .Base:
SizeUnknown, HasLength, HasShape, IsInfinite, EltypeUnknown, HasEltype, OneTo,
@propagate_inbounds, @isdefined, @boundscheck, @inbounds, Generator,
AbstractRange, AbstractUnitRange, UnitRange, LinearIndices, TupleOrBottom,
(:), |, +, -, *, !==, !, ==, !=, <=, <, >, >=, missing,
(:), |, +, -, *, !==, !, ==, !=, <=, <, >, >=, =>, missing,
any, _counttuple, eachindex, ntuple, zero, prod, reduce, in, firstindex, lastindex,
tail, fieldtypes, min, max, minimum, zero, oneunit, promote, promote_shape
using Core: @doc
Expand Down Expand Up @@ -1573,4 +1573,38 @@ only(x::NamedTuple) = throw(
ArgumentError("NamedTuple contains $(length(x)) elements, must contain exactly 1 element")
)

"""
    IterableStatePairs(x)

This internal type is returned by [`pairs`](@ref) when the key is the same as
the state of `iterate`. This allows the iterator to determine the `key => value`
pairs by only calling `iterate` on the values.
"""
struct IterableStatePairs{T}
    # The wrapped collection whose values are iterated.
    x::T
end

# The size trait and the length are forwarded from the wrapped collection.
IteratorSize(::Type{<:IterableStatePairs{T}}) where T = IteratorSize(T)
length(x::IterableStatePairs) = length(x.x)

# The key type cannot be derived from the wrapped type `T` alone, so the
# element type is declared unknown. Without this trait the default `eltype`
# of `Any` applies and `collect` builds a `Vector{Any}`; with it, `collect`
# narrows the element type from the pairs that are actually produced.
Base.IteratorEltype(::Type{<:IterableStatePairs}) = EltypeUnknown()

# Forward iteration: the state handed to `iterate` on the wrapped collection
# is also used as the key of the element produced by that step.
function iterate(x::IterableStatePairs, state=first(keys(x.x)))
    step = iterate(x.x, state)
    if step === nothing
        return nothing
    end
    value, next_state = step
    return (state => value, next_state)
end

# Reversing wraps the underlying values with `Iterators.reverse`.
function reverse(x::IterableStatePairs)
    return IterableStatePairs(Iterators.reverse(x.x))
end

# Reversing an already reversed pair iterator unwraps the original values
# instead of nesting a second `Reverse`.
function reverse(x::IterableStatePairs{<:Iterators.Reverse})
    return IterableStatePairs(x.x.itr)
end

# Reverse iteration: start from the last key of the un-reversed collection and
# use the state passed to the reversed iterator as the key of each element.
function iterate(x::IterableStatePairs{<:Iterators.Reverse}, state=last(keys(x.x.itr)))
    step = iterate(x.x, state)
    if step === nothing
        return nothing
    end
    value, next_state = step
    return (state => value, next_state)
end

# The docs of iterate(::AbstractString) require the iteration state to be the
# same as the keys of the string, which makes this optimization valid
# (see #51631).
function pairs(s::AbstractString)
    return IterableStatePairs(s)
end

end
2 changes: 1 addition & 1 deletion base/strings/string.jl
Original file line number Diff line number Diff line change
Expand Up @@ -412,7 +412,7 @@ is_valid_continuation(c) = c & 0xc0 == 0x80
b = @inbounds codeunit(s, i)
u = UInt32(b) << 24
between(b, 0x80, 0xf7) || return reinterpret(Char, u), i+1
return iterate_continued(s, i, u)
return @noinline iterate_continued(s, i, u)
end

# duck-type s so that external UTF-8 string packages like StringViews can hook in
Expand Down
35 changes: 0 additions & 35 deletions base/strings/util.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1102,38 +1102,3 @@ function Base.rest(s::AbstractString, st...)
end
return String(take!(io))
end

"""
    StringPairs{T}(x::AbstractString)

This internal type is an iterator over the `key => value` pairs of a string.
"""
struct StringPairs{T<:AbstractString}
    # The string whose index => character pairs are iterated.
    x::T
end

# Outer constructor: take the type parameter from the argument's type.
function StringPairs(x)
    return StringPairs{typeof(x)}(x)
end

# The size trait and the length are forwarded from the wrapped string.
IteratorSize(::Type{StringPairs{T}}) where {T} = IteratorSize(T)

function length(x::StringPairs)
    return length(x.x)
end

# `pairs` on any string returns the dedicated pair iterator.
function pairs(x::AbstractString)
    return StringPairs(x)
end

# Generic fallback: index the string directly and advance with `nextind`.
function iterate(x::StringPairs, i=firstindex(x.x))
    str = x.x
    if i > ncodeunits(str)
        return nothing
    end
    return (i => str[i], nextind(str, i))
end

# For these string types the state of string iteration is the index, so the
# incoming state can be used directly as the key of the produced pair.
function iterate(
    x::StringPairs{<:Union{String, SubString{String}}},
    state::Int=firstindex(x.x)
)
    step = iterate(x.x, state)
    step === nothing && return nothing
    char, next_state = step
    return (state => char, next_state)
end

# At this moment, Reverse{<:AbstractString} is inefficient, so this simple
# implementation is not easily optimised
function iterate(x::Iterators.Reverse{<:StringPairs}, i=lastindex(x.itr.x))
    str = x.itr.x
    i < firstindex(str) && return nothing
    return (i => str[i], prevind(str, i))
end
14 changes: 14 additions & 0 deletions test/iterators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1009,3 +1009,17 @@ end
@testset "collect partition substring" begin
    # Partition the result of `lstrip` into chunks of two characters.
    stripped = lstrip("01111", '0')
    chunks = Iterators.partition(stripped, 2)
    @test collect(chunks) == ["11", "11"]
end

# Checks that `pairs` on strings yields the same index => character pairs as
# zipping keys with values, in forward, reversed and double-reversed order.
# The testset is named after the type under test, `IterableStatePairs`.
@testset "IterableStatePairs" begin
    for s in ["", "a", "abcde", "γ", "∋γa"]
        for T in (String, SubString, GenericString)
            sT = T(s)
            # Reference result, built once and independently of `pairs`.
            expected = [k => v for (k, v) in zip(keys(sT), sT)]
            p = pairs(sT)
            @test length(p) == length(sT)
            @test collect(p) == expected
            rv = Iterators.reverse(p)
            @test collect(rv) == reverse(expected)
            # Reversing twice must restore the forward order.
            rrv = Iterators.reverse(rv)
            @test collect(rrv) == expected
        end
    end
end
11 changes: 0 additions & 11 deletions test/strings/util.jl
Original file line number Diff line number Diff line change
Expand Up @@ -724,14 +724,3 @@ end
@test endswith(A, split(B, ' ')[end])
@test endswith(A, 'g')
end

# Compares `pairs` on several string types against zipping keys with values,
# both in forward and in reversed order.
@testset "pairs" begin
    for s in ["", "a", "abcde", "γ", "∋γa"], T in (String, SubString, GenericString)
        sT = T(s)
        expected = [k => v for (k, v) in zip(keys(sT), sT)]
        @test collect(pairs(sT)) == expected
        rv = Iterators.reverse(pairs(sT))
        @test collect(rv) == reverse(expected)
    end
end

0 comments on commit 8d2699c

Please sign in to comment.