Skip to content

Commit

Permalink
[ci skip]
Browse files Browse the repository at this point in the history
Address comments.

Reimplement and generalize all-scalar optimization.

Fix allocation tests for sparse broadcast!.
  • Loading branch information
tkoolen committed Dec 15, 2017
1 parent 6ebf703 commit 35638c7
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 16 deletions.
37 changes: 26 additions & 11 deletions base/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -442,21 +442,36 @@ Note that `dest` is only used to store the result, and does not supply
arguments to `f` unless it is also listed in the `As`,
as in `broadcast!(f, A, A, B)` to perform `A[:] = broadcast(f, A, B)`.
"""
broadcast!(f, dest, As...) = broadcast!(f, dest, combine_styles(As...), As...)
broadcast!(f, dest, ::BroadcastStyle, As...) = broadcast!(f, dest, nothing, As...)
@inline function broadcast!(f, C, ::Void, A, Bs::Vararg{Any,N}) where N
if isa(f, typeof(identity)) && N == 0
if isa(A, Number)
return fill!(C, A)
elseif isa(C, AbstractArray) && isa(A, AbstractArray) && Base.axes(C) == Base.axes(A)
return copy!(C, A)
@inline broadcast!(f, dest, As...) = broadcast!(f, dest, combine_styles(As...), As...)
@inline broadcast!(f, dest, ::BroadcastStyle, As...) = broadcast!(f, dest, nothing, As...)

# Default behavior (separated out so that it can be called by users who want to extend broadcast!).
@inline function broadcast!(f, dest, ::Void, As::Vararg{Any, N}) where N
if f isa typeof(identity) && N == 1
A = As[1]
if A isa AbstractArray && Base.axes(dest) == Base.axes(A)
return copy!(dest, A)
end
end
return _broadcast!(f, C, A, Bs...)
return _broadcast!(f, dest, As...)
end

# This indirection allows size-dependent implementations (e.g., see the copying `identity`
# specialization above)
# Optimization for the all-Scalar case.
@inline function broadcast!(f, dest, ::Scalar, As::Vararg{Any, N}) where N
if dest isa AbstractArray
if f isa typeof(identity) && N == 1
return fill!(dest, As[1])
else
@inbounds for I in eachindex(dest)
dest[I] = f(As...)
end
return dest
end
end
return _broadcast!(f, dest, As...)
end

# This indirection allows size-dependent implementations.
@inline function _broadcast!(f, C, A, Bs::Vararg{Any,N}) where N
shape = broadcast_indices(C)
@boundscheck check_broadcast_indices(shape, A, Bs...)
Expand Down
15 changes: 10 additions & 5 deletions base/sparse/higherorderfns.jl
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ end
broadcast(f::Tf, A::SparseVector) where {Tf} = _noshapecheck_map(f, A)
broadcast(f::Tf, A::SparseMatrixCSC) where {Tf} = _noshapecheck_map(f, A)

function broadcast!(f::Tf, C::SparseVecOrMat, ::Void) where Tf
@inline function broadcast!(f::Tf, C::SparseVecOrMat, ::BroadcastStyle) where Tf
isempty(C) && return _finishempty!(C)
fofnoargs = f()
if _iszero(fofnoargs) # f() is zero, so empty C
Expand All @@ -107,11 +107,16 @@ function broadcast!(f::Tf, C::SparseVecOrMat, ::Void) where Tf
end
return C
end
function broadcast!(f, dest::SparseVecOrMat, ::Void, A, Bs::Vararg{Any,N}) where N
if isa(f, typeof(identity)) && N == 0 && isa(A, Number)
return fill!(dest, A)
@inline function broadcast!(f::Tf, dest::SparseVecOrMat, style::BroadcastStyle, As::Vararg{Any,N}) where {Tf,N}
if f isa typeof(identity) && N == 1
A = As[1]
if A isa Number
return fill!(dest, A)
elseif A isa AbstractArray && Base.axes(dest) == Base.axes(A)
return copy!(dest, A)
end
end
return spbroadcast_args!(f, dest, Broadcast.combine_styles(A, Bs...), A, Bs...)
return spbroadcast_args!(f, dest, style, As...)
end

# the following three similar defs are necessary for type stability in the mixed vector/matrix case
Expand Down

0 comments on commit 35638c7

Please sign in to comment.