Skip to content
This repository has been archived by the owner on Sep 1, 2020. It is now read-only.

length for chain (fixes #84) #85

Merged
merged 1 commit into from
Oct 30, 2016
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 20 additions & 4 deletions src/Iterators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,8 @@ if VERSION < v"0.5.0-dev+3305"
error("Do not call this on older versions")
end
else
import Base: iteratorsize, SizeUnknown, IsInfinite, HasLength
import Base: iteratorsize, SizeUnknown, IsInfinite,
HasLength, HasShape
end


Expand Down Expand Up @@ -105,17 +106,32 @@ done(it::RepeatCallForever, state) = false


# Concatenate the output of n iterators

immutable Chain{T<:Tuple}
xss::T
end

iteratorsize{T<:Chain}(::Type{T}) = SizeUnknown()
if VERSION >= v"0.5.0-dev+3305"
iteratorsize{T}(::Type{Chain{T}}) = _chain_is(T)

eltype{T}(::Type{Chain{T}}) = typejoin([eltype(t) for t in T.parameters]...)
@generated function _chain_is{T}(t::Type{T})
for itype in T.types
if iteratorsize(itype) == IsInfinite()
return :(IsInfinite())
elseif iteratorsize(itype) == SizeUnknown()
return :(SizeUnknown())
end
end
return :(HasLength())
end
end

chain(xss...) = Chain(xss)

length(it::Chain{Tuple{}}) = 0
length(it::Chain) = sum(length, it.xss)

eltype{T}(::Type{Chain{T}}) = typejoin([eltype(t) for t in T.parameters]...)

function start(it::Chain)
i = 1
xs_state = nothing
Expand Down
52 changes: 51 additions & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
using Iterators, Base.Test

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Bring back this whitespace.

if VERSION >= v"0.5.0-dev+3305"
import Base: IsInfinite, SizeUnknown, HasLength, iteratorsize, HasShape
end

# count
# -----

Expand Down Expand Up @@ -116,6 +120,53 @@ ch2 = chain(1:0, 1:2:5, 0.2:0.1:1.6)

@test eltype(ch2) == typejoin(Int, Float64)
@test collect(ch2) == [1:2:5; 0.2:0.1:1.6]
@test length(ch2) == length(collect(ch2))
if VERSION >= v"0.5.0-dev+3305"
@test iteratorsize(ch2) == HasLength()
end

ch3 = chain(1:10, 1:10, 1:10)
@test length(ch3) == 30
if VERSION >= v"0.5.0-dev+3305"
@test iteratorsize(ch3) == HasLength()
end

r = countfrom(1)
ch4 = chain(1:10, countfrom(1))
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Deserves an eltype test.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

@test eltype(ch4) == Int
@test_throws MethodError length(ch4)
if VERSION >= v"0.5.0-dev+3305"
@assert iteratorsize(r) == IsInfinite()
@test iteratorsize(ch4) == IsInfinite()
end

ch5 = chain()
@test length(ch5) == 0
if VERSION >= v"0.5.0-dev+3305"
@test iteratorsize(ch5) == HasLength()
end

c = chain(ch1, ch2, ch3)
@test length(c) == length(ch1) + length(ch2) + length(ch3)
@test collect(c) == [collect(ch1); collect(ch2); collect(ch3)]

r = rand(2,2)
c = chain(r, r)
@test length(c) == 8
@test collect(c) == [vec(r); vec(r)]
if VERSION >= v"0.5.0-dev+3305"
@test iteratorsize(r) == HasShape()
@test iteratorsize(c) == HasLength()
end

if VERSION >= v"0.5.0-dev+3305"
r = distinct(collect(1:10))
@test iteratorsize(r) == SizeUnknown() #lazy filtering
c = chain(1:10, r)
@test_throws MethodError length(c)
@test length(collect(c)) == 20
@test iteratorsize(c) == SizeUnknown()
end

# product
# -------
Expand Down Expand Up @@ -557,4 +608,3 @@ end
@test_chain [1,2,3] Any[] ['w', 'x', 'y', 'z']
@test_chain [1,2,3] Union{}[] ['w', 'x', 'y', 'z']
@test_chain [1,2,3] 4 [('w',3), ('x',2), ('y',1), ('z',0)]