Skip to content

Commit

Permalink
Make broadcasting closures inferable again
Browse files Browse the repository at this point in the history
  • Loading branch information
pabloferz committed Dec 17, 2016
1 parent 363ecad commit f6e65af
Show file tree
Hide file tree
Showing 5 changed files with 29 additions and 11 deletions.
19 changes: 14 additions & 5 deletions base/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
module Broadcast

using Base.Cartesian
using Base: @pure, promote_eltype_op, _promote_op, linearindices, tail, OneTo, to_shape,
using Base: promote_eltype_op, linearindices, tail, OneTo, to_shape,
_msk_end, unsafe_bitgetindex, bitcache_chunks, bitcache_size, dumpbitcache
import Base: .+, .-, .*, ./, .\, .//, .==, .<, .!=, .<=, , .%, .<<, .>>, .^
import Base: broadcast
Expand Down Expand Up @@ -257,13 +257,22 @@ end
@inline broadcast_elwise_op(f, As...) =
broadcast!(f, similar(Array{promote_eltype_op(f, As...)}, broadcast_indices(As...)), As...)

@pure typestuple(a) = Tuple{eltype(a)}
@pure typestuple(T::Type) = Tuple{Type{T}}
@pure typestuple(a, b...) = Tuple{typestuple(a).types..., typestuple(b...).types...}
ftype(f, A) = typeof(f)
ftype(f, A...) = typeof(a -> f(a...))
ftype(T::Type, A...) = Type{T}
typestuple(a) = (Base.@_pure_meta; Tuple{eltype(a)})
typestuple(T::Type) = (Base.@_pure_meta; Tuple{Type{T}})
typestuple(a, b...) = (Base.@_pure_meta; Tuple{typestuple(a).types..., typestuple(b...).types...})
ziptype(A) = typestuple(A)
ziptype(A, B) = (Base.@_pure_meta; Iterators.Zip2{typestuple(A), typestuple(B)})
@inline ziptype(A, B, C, D...) = Iterators.Zip{typestuple(A), ziptype(B, C, D...)}

_broadcast_type(f, T::Type, As...) = Base._return_type(f, typestuple(T, As...))
_broadcast_type(f, A, Bs...) = Base._default_eltype(Base.Generator{ziptype(A, Bs...), ftype(f, A, Bs...)})

# broadcast methods that dispatch on the type of the final container
@inline function broadcast_c(f, ::Type{Array}, A, Bs...)
T = _promote_op(f, typestuple(A, Bs...))
T = _broadcast_type(f, A, Bs...)
shape = broadcast_indices(A, Bs...)
iter = CartesianRange(shape)
if isleaftype(T)
Expand Down
10 changes: 6 additions & 4 deletions base/promotion.jl
Original file line number Diff line number Diff line change
Expand Up @@ -223,21 +223,23 @@ minmax(x::Real, y::Real) = minmax(promote(x, y)...)
_default_type(T::Type) = (@_pure_meta; T)

if isdefined(Core, :Inference)
_promote_op(f::ANY, t::ANY) = Core.Inference.return_type(f, t)
_return_type(f::ANY, t::ANY) = Core.Inference.return_type(f, t)
else
_promote_op(f::ANY, t::ANY) = Any
_return_type(f::ANY, t::ANY) = Any
end

promote_op(::Any...) = (@_pure_meta; Any)
function promote_op{S}(f, ::Type{S})
@_inline_meta
T = _promote_op(f, Tuple{_default_type(S)})
Z = Tuple{_default_type(S)}
T = _default_eltype(Generator{Z, typeof(f)})
isleaftype(S) && return isleaftype(T) ? T : Any
return typejoin(S, T)
end
function promote_op{R,S}(f, ::Type{R}, ::Type{S})
@_inline_meta
T = _promote_op(f, Tuple{_default_type(R), _default_type(S)})
Z = Iterators.Zip2{Tuple{_default_type(R)}, Tuple{_default_type(S)}}
T = _default_eltype(Generator{Z, typeof(a -> f(a...))})
isleaftype(R) && isleaftype(S) && return isleaftype(T) ? T : Any
return typejoin(R, S, T)
end
Expand Down
2 changes: 1 addition & 1 deletion base/sparse/sparse.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ import Base: @get!, acos, acosd, acot, acotd, acsch, asech, asin, asind, asinh,
rotl90, rotr90, round, scale!, setindex!, similar, size, transpose, tril,
triu, vec, permute!, map, map!

import Base.Broadcast: broadcast_indices
import Base.Broadcast: _broadcast_type, broadcast_indices

export AbstractSparseArray, AbstractSparseMatrix, AbstractSparseVector,
SparseMatrixCSC, SparseVector, blkdiag, dense, droptol!, dropzeros!, dropzeros,
Expand Down
1 change: 0 additions & 1 deletion base/sparse/sparsematrix.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1464,7 +1464,6 @@ _maxnnzfrom(Cm, Cn, A) = nnz(A) * div(Cm, A.m) * div(Cn, A.n)
@inline _maxnnzfrom_each(Cm, Cn, As) = (_maxnnzfrom(Cm, Cn, first(As)), _maxnnzfrom_each(Cm, Cn, tail(As))...)
@inline _unchecked_maxnnzbcres(Cm, Cn, As) = min(Cm * Cn, sum(_maxnnzfrom_each(Cm, Cn, As)))
@inline _checked_maxnnzbcres(Cm, Cn, As...) = Cm != 0 && Cn != 0 ? _unchecked_maxnnzbcres(Cm, Cn, As) : 0
_broadcast_type(f, As...) = Base._promote_op(f, Base.Broadcast.typestuple(As...))

# _map_zeropres!/_map_notzeropres! specialized for a single sparse matrix
"Stores only the nonzero entries of `map(f, Matrix(A))` in `C`."
Expand Down
8 changes: 8 additions & 0 deletions test/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -355,3 +355,11 @@ StrangeType18623(x,y) = (x,y)

# 19419
@test @inferred(broadcast(round, Int, [1])) == [1]

# https://discourse.julialang.org/t/towards-broadcast-over-combinations-of-sparse-matrices-and-scalars/910
let
f(A, n) = broadcast(x -> +(x, n), A)
@test @inferred(f([1.0], 1)) == [2.0]
g() = (a = 1; Base.Broadcast._broadcast_type(x -> x + a, 1.0))
@test @inferred(g()) === Float64
end

0 comments on commit f6e65af

Please sign in to comment.