From 4c076c80af9d9c8439cfa20e2efd5c884d88b64d Mon Sep 17 00:00:00 2001 From: matthias314 <56549971+matthias314@users.noreply.github.com> Date: Mon, 28 Oct 2024 18:02:41 -0400 Subject: [PATCH] improved `eltype` for `flatten` with tuple argument (#55946) We have always had ``` julia> t = (Int16[1,2], Int32[3,4]); eltype(Iterators.flatten(t)) Any ``` With this PR, the result is `Signed` (`promote_typejoin` applied to the element types of the tuple elements). The same applies to `NamedTuple`: ``` julia> nt = (a = [1,2], b = (3,4)); eltype(Iterators.flatten(nt)) Any # old Int64 # new ``` --- base/iterators.jl | 3 ++- test/iterators.jl | 2 ++ 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/base/iterators.jl b/base/iterators.jl index 8bd30991319b6..1a0d42ed7447f 100644 --- a/base/iterators.jl +++ b/base/iterators.jl @@ -30,7 +30,7 @@ end import .Base: first, last, isempty, length, size, axes, ndims, - eltype, IteratorSize, IteratorEltype, + eltype, IteratorSize, IteratorEltype, promote_typejoin, haskey, keys, values, pairs, getindex, setindex!, get, iterate, popfirst!, isdone, peek, intersect @@ -1213,6 +1213,7 @@ julia> [(x,y) for x in 0:1 for y in 'a':'c'] # collects generators involving It flatten(itr) = Flatten(itr) eltype(::Type{Flatten{I}}) where {I} = eltype(eltype(I)) +eltype(::Type{Flatten{I}}) where {I<:Union{Tuple,NamedTuple}} = promote_typejoin(map(eltype, fieldtypes(I))...) eltype(::Type{Flatten{Tuple{}}}) = eltype(Tuple{}) IteratorEltype(::Type{Flatten{I}}) where {I} = _flatteneltype(I, IteratorEltype(I)) IteratorEltype(::Type{Flatten{Tuple{}}}) = IteratorEltype(Tuple{}) diff --git a/test/iterators.jl b/test/iterators.jl index 0df4d9afd371a..d1e7525c43465 100644 --- a/test/iterators.jl +++ b/test/iterators.jl @@ -513,6 +513,8 @@ end @test collect(flatten(Any[flatten(Any[1:2, 6:5]), flatten(Any[6:7, 8:9])])) == Any[1,2,6,7,8,9] @test collect(flatten(Any[2:1])) == Any[] @test eltype(flatten(UnitRange{Int8}[1:2, 3:4])) == Int8 +@test eltype(flatten(([1, 2], [3.0, 4.0]))) == Real +@test eltype(flatten((a = [1, 2], b = Int8[3, 4]))) == Signed @test length(flatten(zip(1:3, 4:6))) == 6 @test length(flatten(1:6)) == 6 @test collect(flatten(Any[])) == Any[]