Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

RFC: Change iteratorsize trait of product(itr1, itr2) #16437

Merged
merged 5 commits into from
May 29, 2016
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 49 additions & 6 deletions base/iterator.jl
Original file line number Diff line number Diff line change
Expand Up @@ -298,10 +298,49 @@ done(it::Repeated, state) = false

repeated(x, n::Int) = take(repeated(x), n)

# product

# Product -- cartesian product of iterators

abstract AbstractProdIterator

length(p::AbstractProdIterator) = prod(size(p))
size(p::AbstractProdIterator) = _prod_size(p.a, p.b, iteratorsize(p.a), iteratorsize(p.b))
ndims(p::AbstractProdIterator) = length(size(p))

# generic methods to handle size of Prod* types
_prod_size(a, ::HasShape) = size(a)
_prod_size(a, ::HasLength) = (length(a), )
_prod_size(a, A) =
throw(ArgumentError("Cannot compute size for object of type $(typeof(a))"))
_prod_size(a, b, ::HasLength, ::HasLength) = (length(a), length(b))
_prod_size(a, b, ::HasLength, ::HasShape) = (length(a), size(b)...)
_prod_size(a, b, ::HasShape, ::HasLength) = (size(a)..., length(b))
_prod_size(a, b, ::HasShape, ::HasShape) = (size(a)..., size(b)...)
_prod_size(a, b, A, ::Union{HasShape, HasLength}) =
throw(ArgumentError("Cannot compute size for object of type $(typeof(a))"))
_prod_size(a, b, ::Union{HasShape, HasLength}, B) =
throw(ArgumentError("Cannot compute size for object of type $(typeof(b))"))

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This one is ambiguous

# one iterator
immutable Prod1{I} <: AbstractProdIterator
a::I
end
product(a) = Prod1(a)

eltype{I}(::Type{Prod1{I}}) = Tuple{eltype(I)}
size(p::Prod1) = _prod_size(p.a, iteratorsize(p.a))

@inline start(p::Prod1) = start(p.a)
@inline function next(p::Prod1, st)
n, st = next(p.a, st)
(n, ), st
end
@inline done(p::Prod1, st) = done(p.a, st)

iteratoreltype{I}(::Type{Prod1{I}}) = iteratoreltype(I)
iteratorsize{I}(::Type{Prod1{I}}) = iteratorsize(I)

# two iterators
immutable Prod2{I1, I2} <: AbstractProdIterator
a::I1
b::I2
Expand All @@ -323,11 +362,11 @@ changes the fastest. Example:
(1,5)
(2,5)
"""
product(a) = Zip1(a)
product(a, b) = Prod2(a, b)

eltype{I1,I2}(::Type{Prod2{I1,I2}}) = Tuple{eltype(I1), eltype(I2)}

iteratoreltype{I1,I2}(::Type{Prod2{I1,I2}}) = and_iteratoreltype(iteratoreltype(I1),iteratoreltype(I2))
length(p::AbstractProdIterator) = length(p.a)*length(p.b)
iteratorsize{I1,I2}(::Type{Prod2{I1,I2}}) = prod_iteratorsize(iteratorsize(I1),iteratorsize(I2))

function start(p::AbstractProdIterator)
Expand Down Expand Up @@ -355,13 +394,15 @@ end
@inline next(p::Prod2, st) = prod_next(p, st)
@inline done(p::AbstractProdIterator, st) = st[4]

# n iterators
immutable Prod{I1, I2<:AbstractProdIterator} <: AbstractProdIterator
a::I1
b::I2
end

product(a, b, c...) = Prod(a, product(b, c...))

eltype{I1,I2}(::Type{Prod{I1,I2}}) = tuple_type_cons(eltype(I1), eltype(I2))

iteratoreltype{I1,I2}(::Type{Prod{I1,I2}}) = and_iteratoreltype(iteratoreltype(I1),iteratoreltype(I2))
iteratorsize{I1,I2}(::Type{Prod{I1,I2}}) = prod_iteratorsize(iteratorsize(I1),iteratorsize(I2))

Expand All @@ -370,8 +411,10 @@ iteratorsize{I1,I2}(::Type{Prod{I1,I2}}) = prod_iteratorsize(iteratorsize(I1),it
((x[1][1],x[1][2]...), x[2])
end

prod_iteratorsize(::Union{HasLength,HasShape}, ::Union{HasLength,HasShape}) = HasLength()
prod_iteratorsize(a, ::IsInfinite) = IsInfinite() # products can have an infinite last iterator (which moves slowest)
prod_iteratorsize(::Union{HasLength,HasShape}, ::Union{HasLength,HasShape}) = HasShape()
# products can have an infinite iterator
prod_iteratorsize(a, ::IsInfinite) = IsInfinite()
prod_iteratorsize(::IsInfinite, b) = IsInfinite()
prod_iteratorsize(a, b) = SizeUnknown()
Copy link
Contributor

@mschauer mschauer May 27, 2016

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This one too, but this should be removed, because a or b might be zero and 0*infty=0 in terms of taking products.

Copy link
Contributor Author

@gasagna gasagna May 28, 2016

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

But that would be a special case. How would you handle the majority of cases where you have (finite nonzero size) * (infinite size) ?


_size(p::Prod2) = (length(p.a), length(p.b))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not needed anymore?

Expand Down
179 changes: 171 additions & 8 deletions test/functional.jl
Original file line number Diff line number Diff line change
Expand Up @@ -177,14 +177,177 @@ end
# product
# -------

@test isempty(Base.product(1:2,1:0))
@test isempty(Base.product(1:2,1:0,1:10))
@test isempty(Base.product(1:2,1:10,1:0))
@test isempty(Base.product(1:0,1:2,1:10))
@test collect(Base.product(1:2,3:4)) == [(1,3),(2,3),(1,4),(2,4)]
@test isempty(collect(Base.product(1:0,1:2)))
@test length(Base.product(1:2,1:10,4:6)) == 60
@test Base.iteratorsize(Base.product(1:2, countfrom(1))) == Base.IsInfinite()
# empty?
for itr in [Base.product(1:0),
Base.product(1:2, 1:0),
Base.product(1:0, 1:2),
Base.product(1:0, 1:1, 1:2),
Base.product(1:1, 1:0, 1:2),
Base.product(1:1, 1:2 ,1:0)]
@test isempty(itr)
@test isempty(collect(itr))
end

# collect a product - first iterators runs faster
@test collect(Base.product(1:2)) == [(i,) for i=1:2]
@test collect(Base.product(1:2, 3:4)) == [(i, j) for i=1:2, j=3:4]
@test collect(Base.product(1:2, 3:4, 5:6)) == [(i, j, k) for i=1:2, j=3:4, k=5:6]

# iteration order
let
expected = [(1,3,5), (2,3,5), (1,4,5), (2,4,5), (1,3,6), (2,3,6), (1,4,6), (2,4,6)]
actual = Base.product(1:2, 3:4, 5:6)
for (exp, act) in zip(expected, actual)
@test exp == act
end
end

# collect multidimensional array
let
a, b = 1:3, [4 6;
5 7]
p = Base.product(a, b)
@test size(p) == (3, 2, 2)
@test length(p) == 12
@test ndims(p) == 3
@test eltype(p) == NTuple{2, Int}
cp = collect(p)
for i = 1:3
@test cp[i, :, :] == [(i, 4) (i, 6);
(i, 5) (i, 7)]
end
end

# with 1D inputs
let
a, b, c = 1:2, 1.0:10.0, Int32(1):Int32(0)

# length
@test length(Base.product(a)) == 2
@test length(Base.product(a, b)) == 20
@test length(Base.product(a, b, c)) == 0

# size
@test size(Base.product(a)) == (2, )
@test size(Base.product(a, b)) == (2, 10)
@test size(Base.product(a, b, c)) == (2, 10, 0)

# eltype
@test eltype(Base.product(a)) == Tuple{Int}
@test eltype(Base.product(a, b)) == Tuple{Int, Float64}
@test eltype(Base.product(a, b, c)) == Tuple{Int, Float64, Int32}

# ndims
@test ndims(Base.product(a)) == 1
@test ndims(Base.product(a, b)) == 2
@test ndims(Base.product(a, b, c)) == 3
end

# with multidimensional inputs
let
a, b, c = randn(4, 4), randn(3, 3, 3), randn(2, 2, 2, 2)
args = Any[(a,),
(a, a),
(a, b),
(a, a, a),
(a, b, c)]
sizes = Any[(4, 4),
(4, 4, 4, 4),
(4, 4, 3, 3, 3),
(4, 4, 4, 4, 4, 4),
(4, 4, 3, 3, 3, 2, 2, 2, 2)]
for (method, fun) in zip([size, ndims, length], [x->x, length, prod])
for i in 1:length(args)
@test method(Base.product(args[i]...)) == method(collect(Base.product(args[i]...))) == fun(sizes[i])
end
end
end

# more tests on product with iterators of various type
let
iters = (1:2,
rand(2, 2, 2),
take(1:4, 2),
Base.product(1:2, 1:3),
Base.product(rand(2, 2), rand(1, 1, 1))
)
for method in [size, length, ndims, eltype]
for i = 1:length(iters)
args = iters[i]
@test method(Base.product(args...)) == method(collect(Base.product(args...)))
for j = 1:length(iters)
args = iters[i], iters[j]
@test method(Base.product(args...)) == method(collect(Base.product(args...)))
for k = 1:length(iters)
args = iters[i], iters[j], iters[k]
@test method(Base.product(args...)) == method(collect(Base.product(args...)))
end
end
end
end
end

# product of finite length and infinite length iterators
let
a = 1:2
b = countfrom(1)
ab = Base.product(a, b)
ba = Base.product(b, a)
abexp = [(1, 1), (2, 1), (1, 2), (2, 2), (1, 3), (2, 3)]
baexp = [(1, 1), (2, 1), (3, 1), (4, 1), (5, 1), (6, 1)]
for (expected, actual) in zip([abexp, baexp], [ab, ba])
for (i, el) in enumerate(actual)
@test el == expected[i]
i == length(expected) && break
end
@test_throws ArgumentError length(actual)
@test_throws ArgumentError size(actual)
@test_throws ArgumentError ndims(actual)
end

# size infinite or unknown raises an error
for itr in Any[countfrom(1), Filter(i->0, 1:10)]
@test_throws ArgumentError length(Base.product(itr))
@test_throws ArgumentError size(Base.product(itr))
@test_throws ArgumentError ndims(Base.product(itr))
end
end

# iteratorsize trait business
let f1 = Filter(i->i>0, 1:10)
@test Base.iteratorsize(Base.product(f1)) == Base.SizeUnknown()
@test Base.iteratorsize(Base.product(1:2, f1)) == Base.SizeUnknown()
@test Base.iteratorsize(Base.product(f1, 1:2)) == Base.SizeUnknown()
@test Base.iteratorsize(Base.product(f1, f1)) == Base.SizeUnknown()
@test Base.iteratorsize(Base.product(f1, countfrom(1))) == Base.IsInfinite()
@test Base.iteratorsize(Base.product(countfrom(1), f1)) == Base.IsInfinite()
end
@test Base.iteratorsize(Base.product(1:2, countfrom(1))) == Base.IsInfinite()
@test Base.iteratorsize(Base.product(countfrom(1), 1:2)) == Base.IsInfinite()
@test Base.iteratorsize(Base.product(1:2)) == Base.HasShape()
@test Base.iteratorsize(Base.product(1:2, 1:2)) == Base.HasShape()
@test Base.iteratorsize(Base.product(take(1:2, 1), take(1:2, 1))) == Base.HasShape()
@test Base.iteratorsize(Base.product(take(1:2, 2))) == Base.HasLength()
@test Base.iteratorsize(Base.product([1 2; 3 4])) == Base.HasShape()

# iteratoreltype trait business
let f1 = Filter(i->i>0, 1:10)
@test Base.iteratoreltype(Base.product(f1)) == Base.HasEltype() # FIXME? eltype(f1) is Any
@test Base.iteratoreltype(Base.product(1:2, f1)) == Base.HasEltype() # FIXME? eltype(f1) is Any
@test Base.iteratoreltype(Base.product(f1, 1:2)) == Base.HasEltype() # FIXME? eltype(f1) is Any
@test Base.iteratoreltype(Base.product(f1, f1)) == Base.HasEltype() # FIXME? eltype(f1) is Any
@test Base.iteratoreltype(Base.product(f1, countfrom(1))) == Base.HasEltype() # FIXME? eltype(f1) is Any
@test Base.iteratoreltype(Base.product(countfrom(1), f1)) == Base.HasEltype() # FIXME? eltype(f1) is Any
end
@test Base.iteratoreltype(Base.product(1:2, countfrom(1))) == Base.HasEltype()
@test Base.iteratoreltype(Base.product(countfrom(1), 1:2)) == Base.HasEltype()
@test Base.iteratoreltype(Base.product(1:2)) == Base.HasEltype()
@test Base.iteratoreltype(Base.product(1:2, 1:2)) == Base.HasEltype()
@test Base.iteratoreltype(Base.product(take(1:2, 1), take(1:2, 1))) == Base.HasEltype()
@test Base.iteratoreltype(Base.product(take(1:2, 2))) == Base.HasEltype()
@test Base.iteratoreltype(Base.product([1 2; 3 4])) == Base.HasEltype()



# flatten
# -------
Expand Down