Skip to content

Commit

Permalink
Simplify broadcast's eltype promotion mechanism and make it handle mo…
Browse files Browse the repository at this point in the history
…re cases.

Re-simplify broadcast's eltype promotion mechanism as in JuliaLang#19421. With benefit of JuliaLang#19667, this simplified mechanism should handle additional cases (e.g. closures accepting more than two arguments). Also rename the mechanism more precisely (_broadcast_type -> _broadcast_eltype).
  • Loading branch information
Sacha0 committed Dec 25, 2016
1 parent 4f23cbb commit b78c9af
Show file tree
Hide file tree
Showing 5 changed files with 16 additions and 21 deletions.
18 changes: 5 additions & 13 deletions base/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -269,22 +269,14 @@ end
@inline broadcast_elwise_op(f, As...) =
broadcast!(f, similar(Array{promote_eltype_op(f, As...)}, broadcast_indices(As...)), As...)

ftype(f, A) = typeof(f)
ftype(f, A...) = typeof(a -> f(a...))
ftype(T::Type, A...) = Type{T}
typestuple(a) = (Base.@_pure_meta; Tuple{eltype(a)})
typestuple(T::Type) = (Base.@_pure_meta; Tuple{Type{T}})
typestuple(a, b...) = (Base.@_pure_meta; Tuple{typestuple(a).types..., typestuple(b...).types...})
ziptype(A) = typestuple(A)
ziptype(A, B) = (Base.@_pure_meta; Iterators.Zip2{typestuple(A), typestuple(B)})
@inline ziptype(A, B, C, D...) = Iterators.Zip{typestuple(A), ziptype(B, C, D...)}

_broadcast_type(f, T::Type, As...) = Base._return_type(f, typestuple(T, As...))
_broadcast_type(f, A, Bs...) = Base._default_eltype(Base.Generator{ziptype(A, Bs...), ftype(f, A, Bs...)})
eltypestuple(a) = (Base.@_pure_meta; Tuple{eltype(a)})
eltypestuple(T::Type) = (Base.@_pure_meta; Tuple{Type{T}})
eltypestuple(a, b...) = (Base.@_pure_meta; Tuple{eltypestuple(a).types..., eltypestuple(b...).types...})
_broadcast_eltype(f, A, Bs...) = Base._return_type(f, eltypestuple(A, Bs...))

# broadcast methods that dispatch on the type of the final container
@inline function broadcast_c(f, ::Type{Array}, A, Bs...)
T = _broadcast_type(f, A, Bs...)
T = _broadcast_eltype(f, A, Bs...)
shape = broadcast_indices(A, Bs...)
iter = CartesianRange(shape)
if isleaftype(T)
Expand Down
6 changes: 2 additions & 4 deletions base/promotion.jl
Original file line number Diff line number Diff line change
Expand Up @@ -231,15 +231,13 @@ end
promote_op(::Any...) = (@_pure_meta; Any)
function promote_op{S}(f, ::Type{S})
@_inline_meta
Z = Tuple{_default_type(S)}
T = _default_eltype(Generator{Z, typeof(f)})
T = _return_type(f, Tuple{_default_type(S)})
isleaftype(S) && return isleaftype(T) ? T : Any
return typejoin(S, T)
end
function promote_op{R,S}(f, ::Type{R}, ::Type{S})
@_inline_meta
Z = Iterators.Zip2{Tuple{_default_type(R)}, Tuple{_default_type(S)}}
T = _default_eltype(Generator{Z, typeof(a -> f(a...))})
T = _return_type(f, Tuple{_default_type(R), _default_type(S)})
isleaftype(R) && isleaftype(S) && return isleaftype(T) ? T : Any
return typejoin(R, S, T)
end
Expand Down
2 changes: 1 addition & 1 deletion base/sparse/sparse.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ import Base: @get!, acos, acosd, acot, acotd, acsch, asech, asin, asind, asinh,
rotl90, rotr90, round, scale!, setindex!, similar, size, transpose, tril,
triu, vec, permute!, map, map!

import Base.Broadcast: _broadcast_type, broadcast_indices
import Base.Broadcast: _broadcast_eltype, broadcast_indices

export AbstractSparseArray, AbstractSparseMatrix, AbstractSparseVector,
SparseMatrixCSC, SparseVector, blkdiag, dense, droptol!, dropzeros!, dropzeros,
Expand Down
4 changes: 2 additions & 2 deletions base/sparse/sparsematrix.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1413,7 +1413,7 @@ function map{Tf,N}(f::Tf, A::SparseMatrixCSC, Bs::Vararg{SparseMatrixCSC,N})
fofzeros = f(_zeros_eltypes(A, Bs...)...)
fpreszeros = fofzeros == zero(fofzeros)
maxnnzC = fpreszeros ? min(length(A), _sumnnzs(A, Bs...)) : length(A)
entrytypeC = _broadcast_type(f, A, Bs...)
entrytypeC = _broadcast_eltype(f, A, Bs...)
indextypeC = _promote_indtype(A, Bs...)
Ccolptr = Vector{indextypeC}(A.n + 1)
Crowval = Vector{indextypeC}(maxnnzC)
Expand Down Expand Up @@ -1443,7 +1443,7 @@ function broadcast{Tf,N}(f::Tf, A::SparseMatrixCSC, Bs::Vararg{SparseMatrixCSC,N
fofzeros = f(_zeros_eltypes(A, Bs...)...)
fpreszeros = fofzeros == zero(fofzeros)
indextypeC = _promote_indtype(A, Bs...)
entrytypeC = _broadcast_type(f, A, Bs...)
entrytypeC = _broadcast_eltype(f, A, Bs...)
Cm, Cn = Base.to_shape(Base.Broadcast.broadcast_indices(A, Bs...))
maxnnzC = fpreszeros ? _checked_maxnnzbcres(Cm, Cn, A, Bs...) : (Cm * Cn)
Ccolptr = Vector{indextypeC}(Cn + 1)
Expand Down
7 changes: 6 additions & 1 deletion test/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -363,7 +363,7 @@ StrangeType18623(x,y) = (x,y)
let
f(A, n) = broadcast(x -> +(x, n), A)
@test @inferred(f([1.0], 1)) == [2.0]
g() = (a = 1; Base.Broadcast._broadcast_type(x -> x + a, 1.0))
g() = (a = 1; Base.Broadcast._broadcast_eltype(x -> x + a, 1.0))
@test @inferred(g()) === Float64
end

Expand All @@ -373,3 +373,8 @@ end
@test (+).(Ref(1), Ref(2)) == fill(3)
@test (+).([[0,2], [1,3]], [1,-1]) == [[1,3], [0,2]]
@test (+).([[0,2], [1,3]], Ref{Vector{Int}}([1,-1])) == [[1,1], [2,2]]

# Test that broadcast's promotion mechanism handles closures accepting more than one argument
let f() = (a = 1; Base.Broadcast._broadcast_eltype((x, y) -> x + y + a, 1.0, 1.0))
@test @inferred(f()) == Float64
end

0 comments on commit b78c9af

Please sign in to comment.