From a6392c509b15fcf29c0263c1dbf446790a733d5c Mon Sep 17 00:00:00 2001 From: denizyuret Date: Wed, 9 Jan 2019 18:30:45 -0500 Subject: [PATCH] fix #30643, correctly propagate iterator traits through Stateful (#30644) --- base/iterators.jl | 5 ++--- test/iterators.jl | 24 ++++++++++++++++++++++++ 2 files changed, 26 insertions(+), 3 deletions(-) diff --git a/base/iterators.jl b/base/iterators.jl index 449e8dbb0663d..118cb9b61c5de 100644 --- a/base/iterators.jl +++ b/base/iterators.jl @@ -1094,10 +1094,9 @@ end @inline peek(s::Stateful, sentinel=nothing) = s.nextvalstate !== nothing ? s.nextvalstate[1] : sentinel @inline iterate(s::Stateful, state=nothing) = s.nextvalstate === nothing ? nothing : (popfirst!(s), nothing) -IteratorSize(::Type{Stateful{VS,T}} where VS) where {T} = - isa(IteratorSize(T), SizeUnknown) ? SizeUnknown() : HasLength() +IteratorSize(::Type{Stateful{T,VS}}) where {T,VS} = IteratorSize(T) isa HasShape ? HasLength() : IteratorSize(T) eltype(::Type{Stateful{T, VS}} where VS) where {T} = eltype(T) -IteratorEltype(::Type{Stateful{VS,T}} where VS) where {T} = IteratorEltype(T) +IteratorEltype(::Type{Stateful{T,VS}}) where {T,VS} = IteratorEltype(T) length(s::Stateful) = length(s.itr) - s.taken end diff --git a/test/iterators.jl b/test/iterators.jl index be5ad0b9398e5..ac33b611f3ed4 100644 --- a/test/iterators.jl +++ b/test/iterators.jl @@ -549,3 +549,27 @@ end @test ps isa Iterators.Pairs @test collect(ps) == [1 => :a, 2 => :b] end + +@testset "Stateful fix #30643" begin + @test Base.IteratorSize(1:10) isa Base.HasShape + a = Iterators.Stateful(1:10) + @test Base.IteratorSize(a) isa Base.HasLength + @test length(a) == 10 + @test length(collect(a)) == 10 + @test length(a) == 0 + b = Iterators.Stateful(Iterators.take(1:10,3)) + @test Base.IteratorSize(b) isa Base.HasLength + @test length(b) == 3 + @test length(collect(b)) == 3 + @test length(b) == 0 + c = Iterators.Stateful(Iterators.countfrom(1)) + @test Base.IteratorSize(c) isa Base.IsInfinite + @test length(Iterators.take(c,3)) == 3 + @test length(collect(Iterators.take(c,3))) == 3 + d = Iterators.Stateful(Iterators.filter(isodd,1:10)) + @test Base.IteratorSize(d) isa Base.SizeUnknown + @test length(collect(Iterators.take(d,3))) == 3 + @test length(collect(d)) == 2 + @test length(collect(d)) == 0 +end +