Skip to content

Commit

Permalink
Use containertype to determine array type for array broadcast (JuliaL…
Browse files Browse the repository at this point in the history
  • Loading branch information
andreasnoack authored and tkelman committed Dec 30, 2016
1 parent 7ba6ad6 commit ab984a5
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 13 deletions.
13 changes: 1 addition & 12 deletions base/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,20 +11,9 @@ export bitbroadcast, dotview
export broadcast_getindex, broadcast_setindex!

## Broadcasting utilities ##

broadcast_array_type() = Array
broadcast_array_type(A, As...) =
if is_nullable_array(A) || broadcast_array_type(As...) === Array{Nullable}
Array{Nullable}
else
Array
end

# fallbacks for some special cases
@inline broadcast(f, x::Number...) = f(x...)
@inline broadcast{N}(f, t::NTuple{N}, ts::Vararg{NTuple{N}}) = map(f, t, ts...)
@inline broadcast(f, As::AbstractArray...) =
broadcast_c(f, broadcast_array_type(As...), As...)

# special cases for "X .= ..." (broadcast!) assignments
broadcast!(::typeof(identity), X::AbstractArray, x::Number) = fill!(X, x)
Expand Down Expand Up @@ -313,7 +302,7 @@ ziptype{T}(::Type{T}, A) = typestuple(T, A)
ziptype{T}(::Type{T}, A, B) = (Base.@_pure_meta; Iterators.Zip2{typestuple(T, A), typestuple(T, B)})
@inline ziptype{T}(::Type{T}, A, B, C, D...) = Iterators.Zip{typestuple(T, A), ziptype(T, B, C, D...)}

_broadcast_type{S}(::Type{S}, f, T::Type, As...) = Base._return_type(S, typestuple(S, T, As...))
_broadcast_type{S}(::Type{S}, f, T::Type, As...) = Base._return_type(f, typestuple(S, T, As...))
_broadcast_type{T}(::Type{T}, f, A, Bs...) = Base._default_eltype(Base.Generator{ziptype(T, A, Bs...), ftype(f, A, Bs...)})

# broadcast methods that dispatch on the type of the final container
Expand Down
35 changes: 34 additions & 1 deletion test/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -363,7 +363,7 @@ StrangeType18623(x,y) = (x,y)
let
f(A, n) = broadcast(x -> +(x, n), A)
@test @inferred(f([1.0], 1)) == [2.0]
g() = (a = 1; Base.Broadcast._broadcast_type(x -> x + a, 1.0))
g() = (a = 1; Base.Broadcast._broadcast_type(Any, x -> x + a, 1.0))
@test @inferred(g()) === Float64
end

Expand All @@ -376,3 +376,36 @@ end

# Check that broadcast!(f, A) populates A via independent calls to f (#12277, #19722).
@test let z = 1; A = broadcast!(() -> z += 1, zeros(2)); A[1] != A[2]; end

# broadcasting for custom AbstractArray
immutable Array19745{T,N} <: AbstractArray{T,N}
data::Array{T,N}
end
Base.getindex(A::Array19745, i::Integer...) = A.data[i...]
Base.size(A::Array19745) = size(A.data)

Base.Broadcast.containertype{T<:Array19745}(::Type{T}) = Array19745

Base.Broadcast.promote_containertype(::Type{Array19745}, ::Type{Array19745}) = Array19745
Base.Broadcast.promote_containertype(::Type{Array19745}, ::Type{Array}) = Array19745
Base.Broadcast.promote_containertype(::Type{Array19745}, ct) = Array19745
Base.Broadcast.promote_containertype(::Type{Array}, ::Type{Array19745}) = Array19745
Base.Broadcast.promote_containertype(ct, ::Type{Array19745}) = Array19745

Base.Broadcast.broadcast_indices(::Type{Array19745}, A) = indices(A)
Base.Broadcast.broadcast_indices(::Type{Array19745}, A::Ref) = ()

getfield19745(x::Array19745) = x.data
getfield19745(x) = x

Base.Broadcast.broadcast_c(f, ::Type{Array19745}, A, Bs...) =
Array19745(Base.Broadcast.broadcast_c(f, Array, getfield19745(A), map(getfield19745, Bs)...))

@testset "broadcasting for custom AbstractArray" begin
a = randn(10)
aa = Array19745(a)
@test a .+ 1 == @inferred(aa .+ 1)
@test a .* a' == @inferred(aa .* aa')
@test isa(aa .+ 1, Array19745)
@test isa(aa .* aa', Array19745)
end

0 comments on commit ab984a5

Please sign in to comment.