From 59be7fd374d1d76c421c36cc2f6b84ea938a08d5 Mon Sep 17 00:00:00 2001 From: pabloferz Date: Fri, 17 Jun 2016 12:09:45 +0200 Subject: [PATCH] Treat non-indexable types as scalars in broadcast --- base/abstractarray.jl | 16 ++++++++++++---- base/broadcast.jl | 22 +++++++++++++++------- base/float.jl | 7 +++++++ base/number.jl | 4 ++++ base/parse.jl | 2 ++ base/promotion.jl | 5 +++-- test/broadcast.jl | 2 +- 7 files changed, 44 insertions(+), 14 deletions(-) diff --git a/base/abstractarray.jl b/base/abstractarray.jl index e937e9ca4208a..9b0823f0f2f2a 100644 --- a/base/abstractarray.jl +++ b/base/abstractarray.jl @@ -151,7 +151,6 @@ immutable IndicesList <: IndicesBehavior end # indices like (:cat, :dog, indicesbehavior(A::AbstractArray) = indicesbehavior(typeof(A)) indicesbehavior{T<:AbstractArray}(::Type{T}) = IndicesStartAt1() -indicesbehavior(::Number) = IndicesStartAt1() abstract IndicesPerformance immutable IndicesFast1D <: IndicesPerformance end # indices(A, d) is fast @@ -412,8 +411,9 @@ end promote_indices(a::AbstractArray, b::AbstractArray) = _promote_indices(indicesbehavior(a), indicesbehavior(b), a, b) _promote_indices(::IndicesStartAt1, ::IndicesStartAt1, a, b) = a _promote_indices(::IndicesBehavior, ::IndicesBehavior, a, b) = throw(ArgumentError("types $(typeof(a)) and $(typeof(b)) do not have promote_indices defined")) -promote_indices(a::Number, b::AbstractArray) = b -promote_indices(a::AbstractArray, b::Number) = a +promote_indices(a, b::AbstractArray) = b +promote_indices(a::AbstractArray, b) = a +promote_indices(a, b) = a # Strip off the index-changing container---this assumes that `parent` # performs such an operation. TODO: since few things in Base need this, it @@ -1459,10 +1459,18 @@ end promote_eltype_op(::Any) = (@_pure_meta; Bottom) promote_eltype_op{T}(op, ::AbstractArray{T}) = (@_pure_meta; promote_op(op, T)) promote_eltype_op{T}(op, ::T ) = (@_pure_meta; promote_op(op, T)) +promote_eltype_op{T}(op, Ts::AbstractArray{DataType}, ::AbstractArray{T}) = typejoin((promote_op(op, S, T) for S in Ts)...) +promote_eltype_op{T}(op, Ts::AbstractArray{DataType}, ::T ) = typejoin((promote_op(op, S, T) for S in Ts)...) promote_eltype_op{R,S}(op, ::AbstractArray{R}, ::AbstractArray{S}) = (@_pure_meta; promote_op(op, R, S)) promote_eltype_op{R,S}(op, ::AbstractArray{R}, ::S) = (@_pure_meta; promote_op(op, R, S)) promote_eltype_op{R,S}(op, ::R, ::AbstractArray{S}) = (@_pure_meta; promote_op(op, R, S)) -promote_eltype_op(op, A, B, C, D...) = (@_pure_meta; promote_op(op, eltype(A), promote_eltype_op(op, B, C, D...))) +promote_eltype_op{R,S}(op, ::AbstractArray{R}, ::Type{S}) = (@_pure_meta; promote_op(op, R, S)) +promote_eltype_op{R,S}(op, ::Type{R}, ::AbstractArray{S}) = (@_pure_meta; promote_op(op, R, S)) +promote_eltype_op{R,S}(op, ::Type{R}, ::Type{S}) = (@_pure_meta; promote_op(op, R, S)) +promote_eltype_op{R,S}(op, ::Type{R}, ::S) = (@_pure_meta; promote_op(op, R, S)) +promote_eltype_op{R,S}(op, ::R, ::Type{S}) = (@_pure_meta; promote_op(op, R, S)) +promote_eltype_op{R,S}(op, ::R, ::S) = (@_pure_meta; promote_op(op, R, S)) +promote_eltype_op(op, A, B, C, D...) = promote_eltype_op(op, A, promote_eltype_op(op, B, C, D...)) ## 1 argument map!{F}(f::F, A::AbstractArray) = map!(f, A, A) diff --git a/base/broadcast.jl b/base/broadcast.jl index 0c0165fac300e..5eb245b7832a5 100644 --- a/base/broadcast.jl +++ b/base/broadcast.jl @@ -13,8 +13,9 @@ export broadcast_getindex, broadcast_setindex! ## Calculate the broadcast shape of the arguments, or error if incompatible # array inputs broadcast_shape() = () -broadcast_shape(A) = shape(A) -@inline broadcast_shape(A, B...) = broadcast_shape((), shape(A), map(shape, B)...) +broadcast_shape(A) = () +broadcast_shape(A::AbstractArray) = shape(A) +@inline broadcast_shape(A, B...) = broadcast_shape((), broadcast_shape(A), map(broadcast_shape, B)...) # shape inputs broadcast_shape(shape::Tuple) = shape @inline broadcast_shape(shape::Tuple, shape1::Tuple, shapes::Tuple...) = broadcast_shape(_bcs((), shape, shape1), shapes...) @@ -40,7 +41,7 @@ _bcsm(a::Number, b::Number) = a == b || b == 1 ## Check that all arguments are broadcast compatible with shape # comparing one input against a shape check_broadcast_shape(shp) = nothing -check_broadcast_shape(shp, A) = check_broadcast_shape(shp, shape(A)) +check_broadcast_shape(shp, A) = check_broadcast_shape(shp, broadcast_shape(A)) check_broadcast_shape(::Tuple{}, ::Tuple{}) = nothing check_broadcast_shape(shp, ::Tuple{}) = nothing check_broadcast_shape(::Tuple{}, Ashp::Tuple) = throw(DimensionMismatch("cannot broadcast array to have fewer dimensions")) @@ -63,8 +64,8 @@ end @inline _newindex(out, I) = out # can truncate if indexmap is shorter than I @inline _newindex(out, I, keep::Bool, indexmap...) = _newindex((out..., ifelse(keep, I[1], 1)), tail(I), indexmap...) -newindexer(sz, x::Number) = () -@inline newindexer(sz, A) = _newindexer(sz, size(A)) +newindexer(sz, x) = () +@inline newindexer(sz, A::AbstractArray) = _newindexer(sz, size(A)) @inline _newindexer(sz, szA::Tuple{}) = () @inline _newindexer(sz, szA) = (sz[1] == szA[1], _newindexer(tail(sz), tail(szA))...) @@ -79,6 +80,10 @@ const bitcache_size = 64 * bitcache_chunks # do not change this dumpbitcache(Bc::Vector{UInt64}, bind::Int, C::Vector{Bool}) = Base.copy_to_bitarray_chunks!(Bc, ((bind - 1) << 6) + 1, C, 1, min(bitcache_size, (length(Bc)-bind+1) << 6)) +# Since we can't make T[1] return T, use this inside `_broadcast!` +@inline _broadcast_getvals(A, I) = A +@inline _broadcast_getvals(A::AbstractArray, I) = A[I] + ## Broadcasting core # nargs encodes the number of As arguments (which matches the number # of indexmaps). The first two type parameters are to ensure specialization. @@ -92,7 +97,7 @@ dumpbitcache(Bc::Vector{UInt64}, bind::Int, C::Vector{Bool}) = # reverse-broadcast the indices @nexprs $nargs i->(I_i = newindex(I, imap_i)) # extract array values - @nexprs $nargs i->(@inbounds val_i = A_i[I_i]) + @nexprs $nargs i->(@inbounds val_i = _broadcast_getvals(A_i, I_i)) # call the function and store the result @inbounds B[I] = @ncall $nargs f val end @@ -140,7 +145,10 @@ end B end -@inline broadcast(f, As...) = broadcast!(f, allocate_for(Array{promote_eltype_op(f, As...)}, As, broadcast_shape(As...)), As...) +@inline _broadcast(::Type{Val{false}}, f, a...) = f(a...) +@inline _broadcast(::Type{Val{true}}, f, As...) = broadcast!(f, allocate_for(Array{promote_eltype_op(f, As...)}, As, broadcast_shape(As...)), As...) + +@inline broadcast(f, As...) = (b = any(isa(T,AbstractArray) for T in As); _broadcast(Val{b}, f, As...)) @inline bitbroadcast(f, As...) = broadcast!(f, allocate_for(BitArray, As, broadcast_shape(As...)), As...) diff --git a/base/float.jl b/base/float.jl index 6d507f2d6b34b..232e3fff2cf41 100644 --- a/base/float.jl +++ b/base/float.jl @@ -199,6 +199,13 @@ promote_rule(::Type{Float64}, ::Type{Float32}) = Float64 widen(::Type{Float16}) = Float32 widen(::Type{Float32}) = Float64 +promote_op{Op<:typeof(trunc),T<:Union{Float32,Float64}}(::Op, ::Type{Signed}, ::Type{T}) = Int +promote_op{Op<:typeof(trunc),T<:Union{Float32,Float64}}(::Op, ::Type{Unsigned}, ::Type{T}) = UInt +promote_op{Op<:typeof(trunc),R,S}(::Op, ::Type{R}, ::Type{S}) = R +for f in (ceil, floor, round) + @eval promote_op{Op<:$(typeof(f)),R,S}(::Op, ::Type{R}, ::Type{S}) = promote_op($trunc, R, S) +end + ## floating point arithmetic ## -(x::Float32) = box(Float32,neg_float(unbox(Float32,x))) -(x::Float64) = box(Float64,neg_float(unbox(Float64,x))) diff --git a/base/number.jl b/base/number.jl index ab33caf0ebd25..87c8089d340ad 100644 --- a/base/number.jl +++ b/base/number.jl @@ -64,3 +64,7 @@ one(x::Number) = oftype(x,1) one{T<:Number}(::Type{T}) = convert(T,1) factorial(x::Number) = gamma(x + 1) # fallback for x not Integer + +promote_op{T<:Number}(op, ::Type{T}) = typeof(op(one(T))) +promote_op{R,S<:Number}(op::Type{R}, ::Type{S}) = R # to handle ambiguities +promote_op{R<:Number,S<:Number}(op, ::Type{R}, ::Type{S}) = typeof(op(one(R), one(S))) diff --git a/base/parse.jl b/base/parse.jl index 1da9b742457e3..5387ee55d26e7 100644 --- a/base/parse.jl +++ b/base/parse.jl @@ -194,3 +194,5 @@ function parse(str::AbstractString; raise::Bool=true) end return ex end + +promote_op{Op<:typeof(parse),R,S}(::Op, ::Type{R}, ::Type{S}) = R diff --git a/base/promotion.jl b/base/promotion.jl index 7b6ebf0f406cb..b5c091ec6812a 100644 --- a/base/promotion.jl +++ b/base/promotion.jl @@ -222,9 +222,10 @@ minmax(x::Real, y::Real) = minmax(promote(x, y)...) # for the multiplication of two types, # promote_op{R<:MyType,S<:MyType}(::typeof(*), ::Type{R}, ::Type{S}) = MyType{multype(R,S)} promote_op(::Any) = (@_pure_meta; Bottom) -promote_op(::Any, T) = (@_pure_meta; T) +promote_op(::Any, T) = (@_pure_meta; Any) promote_op{T}(::Type{T}, ::Any) = (@_pure_meta; T) -promote_op{R,S}(::Any, ::Type{R}, ::Type{S}) = (@_pure_meta; promote_type(R, S)) +promote_op{R,S}(::Any, ::Type{R}, ::Type{S}) = (@_pure_meta; Any) +promote_op{Op<:typeof(convert),R,S}(::Op, ::Type{R}, ::Type{S}) = (@_pure_meta; R) promote_op(op, T, S, U, V...) = (@_pure_meta; promote_op(op, T, promote_op(op, S, U, V...))) ## catch-alls to prevent infinite recursion when definitions are missing ## diff --git a/test/broadcast.jl b/test/broadcast.jl index 2484a43d29701..b17716eaf8a9c 100644 --- a/test/broadcast.jl +++ b/test/broadcast.jl @@ -165,7 +165,7 @@ m = [1:2;]' @test @inferred([0,1.2].+reshape([0,-2],1,1,2)) == reshape([0 -2; 1.2 -0.8],2,1,2) rt = Base.return_types(.+, Tuple{Array{Float64, 3}, Array{Int, 1}}) @test length(rt) == 1 && rt[1] == Array{Float64, 3} -rt = Base.return_types(broadcast, Tuple{Function, Array{Float64, 3}, Array{Int, 1}}) +rt = Base.return_types(broadcast, Tuple{typeof(+), Array{Float64, 3}, Array{Int, 1}}) @test length(rt) == 1 && rt[1] == Array{Float64, 3} rt = Base.return_types(broadcast!, Tuple{Function, Array{Float64, 3}, Array{Float64, 3}, Array{Int, 1}}) @test length(rt) == 1 && rt[1] == Array{Float64, 3}