Skip to content

Commit

Permalink
Merge pull request #3100 from toivoh/staged_bsxfun
Browse files Browse the repository at this point in the history
Staged bsxfun and other broadcast operations
  • Loading branch information
StefanKarpinski committed May 21, 2013
2 parents 2639665 + ca90e4b commit 9718d26
Show file tree
Hide file tree
Showing 8 changed files with 336 additions and 17 deletions.
245 changes: 245 additions & 0 deletions base/broadcast.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,245 @@
module Broadcast

using ..Meta.quot
import Base.(.+), Base.(.-), Base.(.*), Base.(./)
export broadcast, broadcast!, broadcast_function, broadcast!_function
export broadcast_getindex, broadcast_setindex!


## Broadcasting utilities ##

# Compute the common shape that all arguments broadcast to.
# Every dimension of the result takes the single non-singleton extent found
# among the arguments (or 1 if all of them are singleton there); two different
# non-singleton extents in the same dimension are an error.
broadcast_shape() = ()
function broadcast_shape(As::AbstractArray...)
    # The result has as many dimensions as the highest-rank argument.
    rank = 0
    for A in As
        rank = max(rank, ndims(A))
    end

    # Start from all-singleton extents and widen them as arguments demand.
    extents = ones(Int, rank)
    for A in As
        for d = 1:ndims(A)
            len = size(A, d)
            if len == 1
                # singleton dimensions never constrain the result
                continue
            end
            if extents[d] == 1
                extents[d] = len
            elseif extents[d] != len
                error("arrays cannot be broadcast to a common size")
            end
        end
    end
    return tuple(extents...)
end

# Verify that every array in `As` can be broadcast to `shape`.
# An array is acceptable when it has no more dimensions than `shape` and each
# of its extents is either 1 or equal to the corresponding entry of `shape`.
# Raises an error otherwise; returns nothing on success.
function check_broadcast_shape(shape::Dims, As::AbstractArray...)
    rank = length(shape)
    for A in As
        ndims(A) <= rank || error("cannot broadcast array to have fewer dimensions")
        for k = 1:ndims(A)
            extent = size(A, k)
            if extent != 1 && extent != shape[k]
                error("array cannot be broadcast to match destination")
            end
        end
    end
end

# Calculate strides as will be used by the generated inner loops.
#
# Arguments:
#   shape - the broadcast shape the loops iterate over
#   As    - the arrays taking part in the broadcast
#
# Returns `(loopshape, strides)`, where `loopshape` is `shape` with its
# singleton dimensions removed (as a tuple) and `strides` is a
# `length(As)` x `length(loopshape)` `Matrix{Int}` of loop strides: the amount
# to add to an array's linear index after each iteration of the loop over the
# given dimension.
function calc_loop_strides(shape::Dims, As::AbstractArray...)
    # squeeze out singleton dimensions in shape
    dims = Array(Int, 0)
    loopshape = Array(Int, 0)
    nd = length(shape)
    sizehint(dims, nd)
    sizehint(loopshape, nd)
    for i = 1:nd
        s = shape[i]
        if s != 1
            push!(dims, i)
            push!(loopshape, s)
        end
    end
    nd = length(loopshape)

    # An array does not advance along a dimension where it is singleton, so its
    # stride there is 0. The comprehension is explicitly typed so the result is
    # a Matrix{Int} regardless of what type inference concludes -- in particular
    # when it is empty because every dimension of `shape` is singleton. The
    # generated inner loops require `strides::Matrix{Int}`.
    strides = Int[(size(A, d) > 1 ? stride(A, d) : 0) for A in As, d in dims]

    # convert from regular strides to loop strides: when the loop over
    # dimension k+1 advances, the loop over dimension k has already moved the
    # index by strides[a, k]*loopshape[k], so that amount is subtracted.
    # Going from the last dimension down keeps strides[a, k] unmodified until
    # it has been used.
    for k=(nd-1):-1:1, a=1:length(As)
        strides[a, k+1] -= strides[a, k]*loopshape[k]
    end

    return tuple(loopshape...), strides
end

# Assemble the argument tuple `(loopshape, arrays, offsets, strides)` that the
# generated inner loops take.
#
# Plain arrays are passed through as they are and all start at linear index 1.
function broadcast_args(shape::Dims, As::(Array...))
    loopshape, strides = calc_loop_strides(shape, As...)
    offsets = ones(Int, length(As))
    return (loopshape, As, offsets, strides)
end
# For general strided arrays, each SubArray is replaced by its parent array
# together with the SubArray's first index as starting offset; any other
# array is kept unchanged with offset 1.
function broadcast_args(shape::Dims, As::(StridedArray...))
    loopshape, strides = calc_loop_strides(shape, As...)
    count = length(As)
    starts = Array(Int, count)
    parents = Array(Array, count)
    for k = 1:count
        A = As[k]
        if isa(A, SubArray)
            starts[k] = A.first_index
            parents[k] = A.parent
        else
            starts[k] = 1
            parents[k] = A
        end
    end
    return (loopshape, tuple(parents...), starts, strides)
end


## Generation of inner loop instances ##

# Build the expression defining a method of `fname` that runs the broadcast
# loop nest for exactly `narrays` arrays and `nd` loop dimensions.
#
# Arguments:
#   fname      - name of the function to define a method for
#   extra_args - argument expressions appended to the generated signature
#   initial    - expression placed in the generated method just before the loops
#   innermost  - function called with one indexing expression (`A[k]`) per
#                array; returns the expression used as the innermost loop body
#   narrays    - number of arrays the generated method accepts
#   nd         - number of loop dimensions of the generated method
#
# Returns the quoted method definition (it is not evaluated here).
function code_inner_loop(fname::Symbol, extra_args::Vector, initial,
innermost::Function, narrays::Int, nd::Int)
# Fresh symbols for the arrays, their running linear indices, the loop
# counters, the loop extents and the per-array, per-dimension loop strides.
Asyms = [gensym("A$a") for a=1:narrays]
indsyms = [gensym("k$a") for a=1:narrays]
axissyms = [gensym("i$d") for d=1:nd]
sizesyms = [gensym("n$d") for d=1:nd]
stridesyms = [gensym("s$(a)_$d") for a=1:narrays, d=1:nd]

# Start from the innermost body and wrap it in one `for` loop per dimension,
# so dimension 1 ends up innermost and dimension `nd` outermost. After each
# iteration of the loop over dimension d, every array's linear index is
# advanced by its loop stride for that dimension.
loop = innermost([:($arr[$ind]) for (arr, ind) in zip(Asyms, indsyms)]...)
for (d, (axis, n)) in enumerate(zip(axissyms, sizesyms))
loop = :(
for $axis=1:$n
$loop
$([:($ind += $(stridesyms[a, d]))
for (a, ind) in enumerate(indsyms)]...)
end
)
end

# Fresh names for the fixed leading parameters of the generated method.
@gensym shape arrays offsets strides
quote
function $fname($shape::NTuple{$nd, Int},
$arrays::NTuple{$narrays, StridedArray},
$offsets::Vector{Int},
$strides::Matrix{Int}, $(extra_args...))
@assert size($strides) == ($narrays, $nd)
# Unpack the loop extents; nothing to do if any of them is zero.
($(sizesyms...),) = $shape
$([:(if $n==0; return; end) for n in sizesyms]...)
# Unpack the arrays, the loop strides (one variable per entry of the
# narrays x nd stride matrix) and the initial linear indices.
($(Asyms...), ) = $arrays
($(stridesyms...),) = $strides
($(indsyms...), ) = $offsets
$initial
$loop
end
end
end


## Generation of inner loop staged functions ##

# Build the expression defining the generic entry point `fname` for an inner
# loop function.
#
# The generated method accepts tuples of any length for `shape` and `arrays`.
# When called, it evaluates the definition produced by `code_inner_loop` for
# the actual number of arrays and loop dimensions, which adds a method of the
# same name with a signature fixed to those counts, and then calls the
# resulting function with the same arguments.
# NOTE(review): subsequent calls with the same counts presumably dispatch
# straight to the generated, more specific method -- confirm.
#
# Arguments:
#   fname      - name of the function to define
#   extra_args - argument expressions appended to the signature and forwarded
#   initial    - expression passed on to `code_inner_loop`
#   innermost  - innermost-body generator passed on to `code_inner_loop`
#
# Returns the quoted method definition (it is not evaluated here).
function code_inner(fname::Symbol, extra_args::Vector, initial,
innermost::Function)
quote
function $fname(shape::(Int...), arrays::(StridedArray...),
offsets::Vector{Int}, strides::Matrix{Int},
$(extra_args...))
# The generator's own arguments are embedded into this call via `quot`;
# only the array and dimension counts are taken from the actual call.
f = eval(code_inner_loop($(quot(fname)), $(quot(extra_args)),
$(quot(initial)), $(quot(innermost)),
length(arrays), length(shape)))
f(shape, arrays, offsets, strides, $(extra_args...))
end
end
end

# Build the definition of an inner loop function `fname` that just evaluates
# the expression produced by `innermost` at every position of the broadcast;
# no setup code is run before the loops.
function code_foreach_inner(fname::Symbol, extra_args::Vector, innermost::Function)
    no_setup = quote end
    return code_inner(fname, extra_args, no_setup, innermost)
end

# Build the definition of an inner loop function `fname` that stores the value
# of the expression produced by `innermost` at every position of the broadcast
# into consecutive elements of `dest`, starting at linear index 1.
#
# `dest` is the argument expression for the destination; it becomes the first
# extra argument of the generated function, followed by `extra_args`.
# `innermost` is called with the destination element expression followed by
# one element expression per input array.
function code_map!_inner(fname::Symbol, dest, extra_args::Vector,
                         innermost::Function)
    # running linear index into the destination
    pos = gensym("k")

    # innermost loop body: store the computed value, then advance the index
    function store_next(els...)
        value = innermost(:($dest[$pos]), els...)
        quote
            $dest[$pos] = $value
            $pos += 1
        end
    end

    return code_inner(fname, {dest, extra_args...}, :($pos=1), store_next)
end


## (Generation of) complete broadcast functions ##

# Build the definitions that make `fname` a broadcasting version of `op`.
#
# The returned expression defines a freshly named inner loop function, a
# zero-argument method `fname()` that just calls `op()`, and a method taking
# any number of strided arrays. The latter allocates a result array of the
# broadcast shape whose element type is the promotion of the arguments'
# element types, fills it with `op` applied at each position, and returns it.
function code_broadcast(fname::Symbol, op)
    kernel = gensym("$(fname)_inner!")
    # the value stored at each position is `op` applied to the input elements
    apply_op(dest, els...) = :( $op($(els...)) )
    kerneldef = code_map!_inner(kernel, :(result::Array), [], apply_op)
    quote
        $kerneldef
        $fname() = $op()
        function $fname(As::StridedArray...)
            shape = broadcast_shape(As...)
            T = promote_type([eltype(A) for A in As]...)
            result = Array(T, shape)
            $kernel(broadcast_args(shape, As)..., result)
            return result
        end
    end
end

# Build the definitions that make `fname` an in-place broadcasting version of
# `op`.
#
# The returned expression defines a freshly named inner loop function and a
# method `fname(dest, As...)`. The broadcast shape is the size of `dest`; the
# arrays `As` are checked against it, `op` applied to their elements is
# assigned to the corresponding element of `dest`, and `dest` is returned.
function code_broadcast!(fname::Symbol, op)
    kernel = gensym("$(fname)!_inner!")
    # the destination is the first array of the loop; assign into its element
    assign_op(dest, els...) = :( $dest=$op($(els...)) )
    kerneldef = code_foreach_inner(kernel, [], assign_op)
    quote
        $kerneldef
        function $fname(dest::StridedArray, As::StridedArray...)
            shape = size(dest)
            check_broadcast_shape(shape, As...)
            operands = tuple(dest, As...)
            $kernel(broadcast_args(shape, operands)...)
            return dest
        end
    end
end

# Inner loop for broadcast_getindex: the index arrays are the broadcast
# operands, and `A` indexed by their elements is stored into `result`.
eval(code_map!_inner(:broadcast_getindex_inner!,
                     :(result::Array), [:(A::AbstractArray)],
                     (dest, inds...) -> :( A[$(inds...)] )))

# Broadcast the index arrays to a common shape and return an array of that
# shape, with the element type of `A`, holding `A` indexed at each position
# by the corresponding elements of the index arrays.
function broadcast_getindex(A::AbstractArray,
                            ind1::StridedArray{Int},
                            inds::StridedArray{Int}...)
    allinds = tuple(ind1, inds...)
    shape = broadcast_shape(allinds...)
    result = Array(eltype(A), shape)
    broadcast_getindex_inner!(broadcast_args(shape, allinds)..., result, A)
    return result
end

# Inner loop for broadcast_setindex!: the first broadcast operand supplies the
# value and the remaining ones the indices at which it is stored in `A`.
eval(code_foreach_inner(:broadcast_setindex!_inner!, [:(A::AbstractArray)],
                        (x, inds...)->:( A[$(inds...)] = $x )))

# Broadcast `X` and the index arrays to a common shape and, for each position,
# store the element of `X` into `A` at the indices given by the index arrays.
# Returns `X`.
function broadcast_setindex!(A::AbstractArray, X::StridedArray,
                             ind1::StridedArray{Int},
                             inds::StridedArray{Int}...)
    operands = tuple(X, ind1, inds...)
    shape = broadcast_shape(operands...)
    broadcast_setindex!_inner!(broadcast_args(shape, operands)..., A)
    return X
end


## actual functions for broadcast and broadcast! ##

# Define each elementwise operator as the broadcasting version of the
# corresponding arithmetic function.
for (fname, op) in {(:.+, +), (:.-, -), (:.*, *), (:./, /)}
    opexpr = quot(op)
    eval(code_broadcast(fname, opexpr))
end

# Cache of the broadcasting functions generated so far, keyed by the function
# they wrap. Declared `const` because the binding is never reassigned (only
# the dictionary's contents change), which lets the functions below rely on
# its type.
const broadcastfuns = (Function=>Function)[]

# Return the broadcasting version of `op`, generating and caching it on first
# use so that repeated requests for the same `op` return the same function.
function broadcast_function(op::Function)
    if haskey(broadcastfuns, op)
        return broadcastfuns[op]
    end
    bf = eval(code_broadcast(gensym("broadcast_$(op)"), quot(op)))
    broadcastfuns[op] = bf
    return bf
end

# With no arrays there is nothing to broadcast over: just call `op`.
broadcast(op::Function) = op()
# Apply `op` elementwise to the arrays `As` after broadcasting them to a
# common shape, returning a newly allocated array.
broadcast(op::Function, As::StridedArray...) = broadcast_function(op)(As...)

# Cache of the in-place broadcasting functions generated so far, keyed by the
# function they wrap. Declared `const` because the binding is never reassigned
# (only the dictionary's contents change), which lets the functions below rely
# on its type.
const broadcast!funs = (Function=>Function)[]

# Return the in-place broadcasting version of `op`, generating and caching it
# on first use so that repeated requests for the same `op` return the same
# function.
function broadcast!_function(op::Function)
    if haskey(broadcast!funs, op)
        return broadcast!funs[op]
    end
    bf! = eval(code_broadcast!(gensym("broadcast!_$(op)"), quot(op)))
    broadcast!funs[op] = bf!
    return bf!
end

# Apply `op` elementwise to the arrays `As`, broadcast to the size of `dest`,
# storing the results in `dest`; returns `dest`.
function broadcast!(op::Function, dest::StridedArray, As::StridedArray...)
    return broadcast!_function(op)(dest, As...)
end


end # module
7 changes: 6 additions & 1 deletion base/exports.jl
Original file line number Diff line number Diff line change
Expand Up @@ -452,7 +452,12 @@ export
# arrays
mapslices,
reducedim,
bsxfun,
broadcast,
broadcast!,
broadcast_function,
broadcast!_function,
broadcast_getindex,
broadcast_setindex!,
cartesianmap,
cat,
cell,
Expand Down
2 changes: 2 additions & 0 deletions base/sysimg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,8 @@ push!(I18n.CALLBACKS, Help.clear_cache)
include("sparse.jl")
include("linalg.jl")
importall .LinAlg
include("broadcast.jl")
importall .Broadcast

# signal processing
include("fftw.jl")
Expand Down
14 changes: 8 additions & 6 deletions doc/manual/arrays.rst
Original file line number Diff line number Diff line change
Expand Up @@ -329,13 +329,13 @@ vector to the size of the matrix::
0.848333 1.66714 1.3262
1.26743 1.77988 1.13859

This is wasteful when dimensions get large, so Julia offers the
MATLAB-inspired ``bsxfun``, which expands singleton dimensions in
This is wasteful when dimensions get large, so Julia offers
``broadcast``, which expands singleton dimensions in
array arguments to match the corresponding dimension in the other
array without using extra memory, and applies the given binary
function::
array without using extra memory, and applies the given
function elementwise::

julia> bsxfun(+, a, A)
julia> broadcast(+, a, A)
2x3 Float64 Array:
0.848333 1.66714 1.3262
1.26743 1.77988 1.13859
Expand All @@ -344,11 +344,13 @@ function::
1x2 Float64 Array:
0.629799 0.754948

julia> bsxfun(+, a, b)
julia> broadcast(+, a, b)
2x2 Float64 Array:
1.31849 1.44364
1.56107 1.68622

Elementwise operators such as ``.+`` and ``.*`` perform broadcasting if necessary. There is also a ``broadcast!`` function to specify an explicit destination, and ``broadcast_getindex`` and ``broadcast_setindex!`` that broadcast the indices before indexing.

Implementation
--------------

Expand Down
36 changes: 28 additions & 8 deletions doc/stdlib/base.rst
Original file line number Diff line number Diff line change
Expand Up @@ -2233,28 +2233,48 @@ Mathematical operators and functions

All mathematical operations and functions are supported for arrays

.. function:: bsxfun(fn, A, B[, C...])
.. function:: broadcast(f, As...)

Apply binary function ``fn`` to two or more arrays, with singleton dimensions expanded.
Broadcasts the arrays ``As`` to a common size by expanding singleton dimensions, and returns an array of the results ``f(as...)`` for each position.

.. function:: broadcast!(f, dest, As...)

Like ``broadcast``, but store the result in the ``dest`` array.

.. function:: broadcast_function(f)

Returns a function ``broadcast_f`` such that ``broadcast_function(f)(As...) == broadcast(f, As...)``. Most useful in the form ``const broadcast_f = broadcast_function(f)``.

.. function:: broadcast!_function(f)

Like ``broadcast_function``, but for ``broadcast!``.

Indexing, Assignment, and Concatenation
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. function:: getindex(A, ind)
.. function:: getindex(A, inds...)

Returns a subset of array ``A`` as specified by ``ind``, which may be an ``Int``, a ``Range``, or a ``Vector``.
Returns a subset of array ``A`` as specified by ``inds``, where each ``ind`` may be an ``Int``, a ``Range``, or a ``Vector``.

.. function:: sub(A, ind)
.. function:: sub(A, inds...)

Returns a SubArray, which stores the input ``A`` and ``ind`` rather than computing the result immediately. Calling ``getindex`` on a SubArray computes the indices on the fly.
Returns a SubArray, which stores the input ``A`` and ``inds`` rather than computing the result immediately. Calling ``getindex`` on a SubArray computes the indices on the fly.

.. function:: slicedim(A, d, i)

Return all the data of ``A`` where the index for dimension ``d`` equals ``i``. Equivalent to ``A[:,:,...,i,:,:,...]`` where ``i`` is in position ``d``.

.. function:: setindex!(A, X, ind)
.. function:: setindex!(A, X, inds...)

Store values from array ``X`` within some subset of ``A`` as specified by ``inds``.

.. function:: broadcast_getindex(A, inds...)

Broadcasts the ``inds`` arrays to a common size like ``broadcast``, and returns an array of the results ``A[ks...]``, where ``ks`` goes over the positions in the broadcast.

.. function:: broadcast_setindex!(A, X, inds...)

Store values from array ``X`` within some subset of ``A`` as specified by ``ind``.
Broadcasts the ``X`` and ``inds`` arrays to a common size and stores the value from each position in ``X`` at the indices given by the same positions in ``inds``.

.. function:: cat(dim, A...)

Expand Down
2 changes: 1 addition & 1 deletion test/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ TESTS = all core keywordargs numbers strings unicode corelib hashing \
remote iostring arrayops linalg blas fft dsp sparse bitarray \
random math functional bigint sorting statistics spawn parallel \
arpack bigfloat file git pkg pkg2 suitesparse complex version \
pollfd mpfr
pollfd mpfr broadcast

$(TESTS) ::
$(QUIET_JULIA) $(call spawn,$(JULIA_EXECUTABLE)) ./runtests.jl $@
Expand Down
Loading

0 comments on commit 9718d26

Please sign in to comment.