# ---------------------------------------------------------------------------
# Review annotations for this patch. Only these leading `#` lines were added;
# everything from the first `diff --git` header onwards is unchanged.
#
# NOTE(review): the patch text below sits on a handful of very long lines, while
# the hunk headers still declare multi-line hunks (e.g. `+1,245`, `+1,45`).
# Line breaks appear to have been lost somewhere upstream -- confirm against
# the original commit before trying to apply it.
#
# Files touched (as listed by the `diff --git` headers below):
#   base/broadcast.jl    new file, hunk `+1,245`; defines `module Broadcast`
#   base/exports.jl      drops the `bsxfun` export; adds `broadcast`,
#                        `broadcast!`, `broadcast_function`,
#                        `broadcast!_function`, `broadcast_getindex`,
#                        `broadcast_setindex!`
#   base/sysimg.jl       adds `include("broadcast.jl")` and
#                        `importall .Broadcast` after `importall .LinAlg`
#   doc/manual/arrays.rst, doc/stdlib/base.rst
#                        replace the `bsxfun` text with `broadcast` and
#                        document the new functions
#   test/Makefile, test/runtests.jl
#                        register the `broadcast` test
#   test/broadcast.jl    new file, hunk `+1,45`
#
# What `module Broadcast` contains, in order of appearance:
#   broadcast_shape(As...)         common shape of the arguments; a size-1
#                                  dimension takes the other array's size,
#                                  otherwise sizes must agree or `error(...)`
#                                  is raised.
#   check_broadcast_shape(shape, As...)
#                                  errors if an array has more dimensions than
#                                  `shape`, or has a size that is neither 1 nor
#                                  equal to `shape[k]`.
#   calc_loop_strides(shape, As...)
#                                  drops the size-1 dimensions of `shape`,
#                                  builds a (number of arrays) x (kept dims)
#                                  stride matrix using 0 where an array has
#                                  size 1, then subtracts
#                                  `strides[a, k]*loopshape[k]` from
#                                  `strides[a, k+1]` so each entry is the
#                                  increment applied after one pass of the
#                                  next-inner loop.
#   broadcast_args(shape, As)      returns `(loopshape, arrays, offsets,
#                                  strides)`. For `Array` inputs the offsets
#                                  are all 1; in the `StridedArray` method a
#                                  `SubArray` is replaced by `A.parent` with
#                                  offset `A.first_index`.
#   code_inner_loop(...)           builds the expression for a method typed on
#                                  `NTuple{nd, Int}` / `NTuple{narrays,
#                                  StridedArray}`: nested `for` loops that
#                                  advance one linear index per array by the
#                                  loop strides; returns early if any loop
#                                  size is 0.
#   code_inner(...)                builds a method taking `(Int...)` /
#                                  `(StridedArray...)` that `eval`s the
#                                  `code_inner_loop` expression for the actual
#                                  `length(arrays)` and `length(shape)` and
#                                  then calls the result.
#   code_foreach_inner / code_map!_inner
#                                  wrappers over `code_inner`; the `map!`
#                                  variant adds a destination argument and a
#                                  counter that starts at 1 and is incremented
#                                  after every store.
#   code_broadcast / code_broadcast!
#                                  build the user-facing functions; the
#                                  non-mutating one allocates `result`, the
#                                  mutating one first runs
#                                  `check_broadcast_shape(size(dest), As...)`.
#   broadcast_getindex / broadcast_setindex!
#                                  broadcast the index arrays (and `X`) and
#                                  index into `A` element by element.
#   `.+`, `.-`, `.*`, `./`         defined by `eval(code_broadcast(...))`.
#   broadcast_function / broadcast!_function
#                                  look up `op` in the `broadcastfuns` /
#                                  `broadcast!funs` dictionaries and generate
#                                  and store a new function on a miss.
#
# Review notes (to be confirmed, not verified here):
#   * NOTE(review): `code_broadcast` allocates `result` with
#     `promote_type` of the argument element types; the type produced by `op`
#     is not consulted. Presumably an `op` whose result type differs from that
#     (e.g. `/` on integer arrays via `./`) cannot be stored -- confirm, and
#     note that the tests in this patch only exercise `+`.
#   * NOTE(review): `broadcast_setindex!` returns `Xinds[1]` (that is, `X`),
#     not `A` -- confirm this is the intended return value.
#   * NOTE(review): the `StridedArray` method of `broadcast_args` stores
#     parents in `Array(Array, nA)`; this assumes every `SubArray` parent is an
#     `Array` -- TODO confirm.
#   * NOTE(review): doc/stdlib/base.rst states
#     ``broadcast_function(f)(As...) === broadcast(f, As...)``; both sides
#     return a freshly allocated `result`, so presumably `==` (equal contents)
#     is what is meant -- confirm wording.
#   * NOTE(review): test/runtests.jl and test/Makefile keep separate test
#     lists; both are updated here, keep them in sync.
# ---------------------------------------------------------------------------
diff --git a/base/broadcast.jl b/base/broadcast.jl new file mode 100644 index 0000000000000..caa5a2829ea31 --- /dev/null +++ b/base/broadcast.jl @@ -0,0 +1,245 @@ +module Broadcast + +using ..Meta.quot +import Base.(.+), Base.(.-), Base.(.*), Base.(./) +export broadcast, broadcast!, broadcast_function, broadcast!_function +export broadcast_getindex, broadcast_setindex! + + +## Broadcasting utilities ## + +# Calculate the broadcast shape of the arguments, or error if incompatible +broadcast_shape() = () +function broadcast_shape(As::AbstractArray...) + nd = ndims(As[1]) + for i = 2:length(As) + nd = max(nd, ndims(As[i])) + end + bshape = ones(Int, nd) + for A in As + for d = 1:ndims(A) + n = size(A, d) + if n != 1 + if bshape[d] == 1 + bshape[d] = n + elseif bshape[d] != n + error("arrays cannot be broadcast to a common size") + end + end + end + end + return tuple(bshape...) +end + +# Check that all arguments are broadcast compatible with shape +function check_broadcast_shape(shape::Dims, As::AbstractArray...) + for A in As + if ndims(A) > length(shape) + error("cannot broadcast array to have fewer dimensions") + end + for k in 1:ndims(A) + n, nA = shape[k], size(A, k) + if n != nA != 1 + error("array cannot be broadcast to match destination") + end + end + end +end + +# Calculate strides as will be used by the generated inner loops +function calc_loop_strides(shape::Dims, As::AbstractArray...) + # squeeze out singleton dimensions in shape + dims = Array(Int, 0) + loopshape = Array(Int, 0) + nd = length(shape) + sizehint(dims, nd) + sizehint(loopshape, nd) + for i = 1:nd + s = shape[i] + if s != 1 + push!(dims, i) + push!(loopshape, s) + end + end + nd = length(loopshape) + + strides = [(size(A, d) > 1 ? 
stride(A, d) : 0) for A in As, d in dims] + # convert from regular strides to loop strides + for k=(nd-1):-1:1, a=1:length(As) + strides[a, k+1] -= strides[a, k]*loopshape[k] + end + + tuple(loopshape...), strides +end + +function broadcast_args(shape::Dims, As::(Array...)) + loopshape, strides = calc_loop_strides(shape, As...) + (loopshape, As, ones(Int, length(As)), strides) +end +function broadcast_args(shape::Dims, As::(StridedArray...)) + loopshape, strides = calc_loop_strides(shape, As...) + nA = length(As) + offs = Array(Int, nA) + baseAs = Array(Array, nA) + for (k, A) in enumerate(As) + offs[k],baseAs[k] = isa(A,SubArray) ? (A.first_index,A.parent) : (1,A) + end + (loopshape, tuple(baseAs...), offs, strides) +end + + +## Generation of inner loop instances ## + +function code_inner_loop(fname::Symbol, extra_args::Vector, initial, + innermost::Function, narrays::Int, nd::Int) + Asyms = [gensym("A$a") for a=1:narrays] + indsyms = [gensym("k$a") for a=1:narrays] + axissyms = [gensym("i$d") for d=1:nd] + sizesyms = [gensym("n$d") for d=1:nd] + stridesyms = [gensym("s$(a)_$d") for a=1:narrays, d=1:nd] + + loop = innermost([:($arr[$ind]) for (arr, ind) in zip(Asyms, indsyms)]...) + for (d, (axis, n)) in enumerate(zip(axissyms, sizesyms)) + loop = :( + for $axis=1:$n + $loop + $([:($ind += $(stridesyms[a, d])) + for (a, ind) in enumerate(indsyms)]...) + end + ) + end + + @gensym shape arrays offsets strides + quote + function $fname($shape::NTuple{$nd, Int}, + $arrays::NTuple{$narrays, StridedArray}, + $offsets::Vector{Int}, + $strides::Matrix{Int}, $(extra_args...)) + @assert size($strides) == ($narrays, $nd) + ($(sizesyms...),) = $shape + $([:(if $n==0; return; end) for n in sizesyms]...) 
+ ($(Asyms...), ) = $arrays + ($(stridesyms...),) = $strides + ($(indsyms...), ) = $offsets + $initial + $loop + end + end +end + + +## Generation of inner loop staged functions ## + +function code_inner(fname::Symbol, extra_args::Vector, initial, + innermost::Function) + quote + function $fname(shape::(Int...), arrays::(StridedArray...), + offsets::Vector{Int}, strides::Matrix{Int}, + $(extra_args...)) + f = eval(code_inner_loop($(quot(fname)), $(quot(extra_args)), + $(quot(initial)), $(quot(innermost)), + length(arrays), length(shape))) + f(shape, arrays, offsets, strides, $(extra_args...)) + end + end +end + +code_foreach_inner(fname::Symbol, extra_args::Vector, innermost::Function) = + code_inner(fname, extra_args, quote end, innermost) + +function code_map!_inner(fname::Symbol, dest, extra_args::Vector, + innermost::Function) + @gensym k + code_inner(fname, {dest, extra_args...}, :($k=1), + (els...)->quote + $dest[$k] = $(innermost(:($dest[$k]), els...)) + $k += 1 + end) +end + + +## (Generation of) complete broadcast functions ## + +function code_broadcast(fname::Symbol, op) + inner! = gensym("$(fname)_inner!") + innerdef = code_map!_inner(inner!, :(result::Array), [], + (dest, els...) -> :( $op($(els...)) )) + quote + $innerdef + $fname() = $op() + function $fname(As::StridedArray...) + shape = broadcast_shape(As...) + result = Array(promote_type([eltype(A) for A in As]...), shape) + $inner!(broadcast_args(shape, As)..., result) + result + end + end +end + +function code_broadcast!(fname::Symbol, op) + inner! = gensym("$(fname)!_inner!") + innerdef = code_foreach_inner(inner!, [], + (dest, els...) -> :( $dest=$op($(els...)) )) + quote + $innerdef + function $fname(dest::StridedArray, As::StridedArray...) + shape = size(dest) + check_broadcast_shape(shape, As...) + $inner!(broadcast_args(shape, tuple(dest, As...))...) + dest + end + end +end + +eval(code_map!_inner(:broadcast_getindex_inner!, + :(result::Array), [:(A::AbstractArray)], + (dest, inds...) 
-> :( A[$(inds...)] ))) +function broadcast_getindex(A::AbstractArray, + ind1::StridedArray{Int}, + inds::StridedArray{Int}...) + inds = tuple(ind1, inds...) + shape = broadcast_shape(inds...) + result = Array(eltype(A), shape) + broadcast_getindex_inner!(broadcast_args(shape, inds)..., result, A) + result +end + +eval(code_foreach_inner(:broadcast_setindex!_inner!, [:(A::AbstractArray)], + (x, inds...)->:( A[$(inds...)] = $x ))) +function broadcast_setindex!(A::AbstractArray, X::StridedArray, + ind1::StridedArray{Int}, + inds::StridedArray{Int}...) + Xinds = tuple(X, ind1, inds...) + shape = broadcast_shape(Xinds...) + broadcast_setindex!_inner!(broadcast_args(shape, Xinds)..., A) + Xinds[1] +end + + +## actual functions for broadcast and broadcast! ## + +for (fname, op) in {(:.+, +), (:.-, -), (:.*, *), (:./, /)} + eval(code_broadcast(fname, quot(op))) +end + +broadcastfuns = (Function=>Function)[] +function broadcast_function(op::Function) + (haskey(broadcastfuns, op) ? broadcastfuns[op] : + (broadcastfuns[op] = eval(code_broadcast(gensym("broadcast_$(op)"), + quot(op))))) +end +broadcast(op::Function) = op() +broadcast(op::Function, As::StridedArray...) = broadcast_function(op)(As...) + +broadcast!funs = (Function=>Function)[] +function broadcast!_function(op::Function) + (haskey(broadcast!funs, op) ? broadcast!funs[op] : + (broadcast!funs[op] = eval(code_broadcast!(gensym("broadcast!_$(op)"), + quot(op))))) +end +function broadcast!(op::Function, dest::StridedArray, As::StridedArray...) + broadcast!_function(op)(dest, As...) 
+end + + +end # module diff --git a/base/exports.jl b/base/exports.jl index 8fcbe26c75f9a..5ca1524cdcab6 100644 --- a/base/exports.jl +++ b/base/exports.jl @@ -469,7 +469,12 @@ export # arrays mapslices, reducedim, - bsxfun, + broadcast, + broadcast!, + broadcast_function, + broadcast!_function, + broadcast_getindex, + broadcast_setindex!, cartesianmap, cat, cell, diff --git a/base/sysimg.jl b/base/sysimg.jl index 3d93508ea2d8c..871ccb0c87916 100644 --- a/base/sysimg.jl +++ b/base/sysimg.jl @@ -164,6 +164,8 @@ push!(I18n.CALLBACKS, Help.clear_cache) include("sparse.jl") include("linalg.jl") importall .LinAlg +include("broadcast.jl") +importall .Broadcast # signal processing include("fftw.jl") diff --git a/doc/manual/arrays.rst b/doc/manual/arrays.rst index 878764695c6e7..f4a6da9a2c8ab 100644 --- a/doc/manual/arrays.rst +++ b/doc/manual/arrays.rst @@ -329,13 +329,13 @@ vector to the size of the matrix:: 0.848333 1.66714 1.3262 1.26743 1.77988 1.13859 -This is wasteful when dimensions get large, so Julia offers the -MATLAB-inspired ``bsxfun``, which expands singleton dimensions in +This is wasteful when dimensions get large, so Julia offers +``broadcast``, which expands singleton dimensions in array arguments to match the corresponding dimension in the other -array without using extra memory, and applies the given binary -function:: +array without using extra memory, and applies the given +function elementwise:: - julia> bsxfun(+, a, A) + julia> broadcast(+, a, A) 2x3 Float64 Array: 0.848333 1.66714 1.3262 1.26743 1.77988 1.13859 @@ -344,11 +344,13 @@ function:: 1x2 Float64 Array: 0.629799 0.754948 - julia> bsxfun(+, a, b) + julia> broadcast(+, a, b) 2x2 Float64 Array: 1.31849 1.44364 1.56107 1.68622 +Elementwise operators such as ``.+`` and ``.*`` perform broadcasting if necessary. There is also a ``broadcast!`` function to specify an explicit destination, and ``broadcast_getindex`` and ``broadcast_setindex!`` that broadcast the indices before indexing. 
+ Implementation -------------- diff --git a/doc/stdlib/base.rst b/doc/stdlib/base.rst index 797e30d75fb6a..b58d1c9f7ac8b 100644 --- a/doc/stdlib/base.rst +++ b/doc/stdlib/base.rst @@ -2225,28 +2225,48 @@ Mathematical operators and functions All mathematical operations and functions are supported for arrays -.. function:: bsxfun(fn, A, B[, C...]) +.. function:: broadcast(f, As...) - Apply binary function ``fn`` to two or more arrays, with singleton dimensions expanded. + Broadcasts the arrays ``As`` to a common size by expanding singleton dimensions, and returns an array of the results ``f(as...)`` for each position. + +.. function:: broadcast!(f, dest, As...) + + Like ``broadcast``, but store the result in the ``dest`` array. + +.. function:: broadcast_function(f) + + Returns a function ``broadcast_f`` such that ``broadcast_function(f)(As...) === broadcast(f, As...)``. Most useful in the form ``const broadcast_f = broadcast_function(f)``. + +.. function:: broadcast!_function(f) + + Like ``broadcast_function``, but for ``broadcast!``. Indexing, Assignment, and Concatenation ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -.. function:: getindex(A, ind) +.. function:: getindex(A, inds...) - Returns a subset of array ``A`` as specified by ``ind``, which may be an ``Int``, a ``Range``, or a ``Vector``. + Returns a subset of array ``A`` as specified by ``inds``, where each ``ind`` may be an ``Int``, a ``Range``, or a ``Vector``. -.. function:: sub(A, ind) +.. function:: sub(A, inds...) - Returns a SubArray, which stores the input ``A`` and ``ind`` rather than computing the result immediately. Calling ``getindex`` on a SubArray computes the indices on the fly. + Returns a SubArray, which stores the input ``A`` and ``inds`` rather than computing the result immediately. Calling ``getindex`` on a SubArray computes the indices on the fly. .. function:: slicedim(A, d, i) Return all the data of ``A`` where the index for dimension ``d`` equals ``i``. 
Equivalent to ``A[:,:,...,i,:,:,...]`` where ``i`` is in position ``d``. -.. function:: setindex!(A, X, ind) +.. function:: setindex!(A, X, inds...) + + Store values from array ``X`` within some subset of ``A`` as specified by ``inds``. + +.. function:: broadcast_getindex(A, inds...) + + Broadcasts the ``inds`` arrays to a common size like ``broadcast``, and returns an array of the results ``A[ks...]``, where ``ks`` goes over the positions in the broadcast. + +.. function:: broadcast_setindex!(A, X, inds...) - Store values from array ``X`` within some subset of ``A`` as specified by ``ind``. + Broadcasts the ``X`` and ``inds`` arrays to a common size and stores the value from each position in ``X`` at the indices given by the same positions in ``inds``. .. function:: cat(dim, A...) diff --git a/test/Makefile b/test/Makefile index f581618225228..af19f57f4a4b1 100644 --- a/test/Makefile +++ b/test/Makefile @@ -5,7 +5,7 @@ TESTS = all core keywordargs numbers strings unicode corelib hashing \ remote iostring arrayops linalg blas fft dsp sparse bitarray \ random math functional bigint sorting statistics spawn parallel \ arpack bigfloat file git pkg pkg2 suitesparse complex version \ - pollfd mpfr + pollfd mpfr broadcast $(TESTS) :: $(QUIET_JULIA) $(call spawn,$(JULIA_EXECUTABLE)) ./runtests.jl $@ diff --git a/test/broadcast.jl b/test/broadcast.jl new file mode 100644 index 0000000000000..b0706806579d9 --- /dev/null +++ b/test/broadcast.jl @@ -0,0 +1,45 @@ +function as_sub(x::Vector) + y = Array(eltype(x), length(x)*2) + y = sub(y, 2:2:length(y)) + y[:] = x[:] + y +end +function as_sub(x::Matrix) + y = Array(eltype(x), tuple(([size(x)...]*2)...)) + y = sub(y, 2:2:size(y,1), 2:2:size(y,2)) + for j=1:size(x,2) + for i=1:size(x,1) + y[i,j] = x[i,j] + end + end + y +end + +for arr in (identity, as_sub) + @test broadcast(+, arr(eye(2)), arr([1, 4])) == [2 1; 4 5] + @test broadcast(+, arr(eye(2)), arr([1 4])) == [2 4; 1 5] + @test broadcast(+, arr([1 0]), arr([1, 4])) == [2 
1; 5 4] + @test broadcast(+, arr([1, 0]), arr([1 4])) == [2 5; 1 4] + @test broadcast(+, arr([1, 0]), arr([1, 4])) == [2, 4] + @test broadcast(+) == 0 + @test broadcast(*) == 1 + + @test arr(eye(2)) .+ arr([1, 4]) == arr([2 1; 4 5]) + @test arr(eye(2)) .+ arr([1 4]) == arr([2 4; 1 5]) + @test arr([1 0]) .+ arr([1, 4]) == arr([2 1; 5 4]) + @test arr([1, 0]) .+ arr([1 4]) == arr([2 5; 1 4]) + @test arr([1, 0]) .+ arr([1, 4]) == arr([2, 4]) + @test arr([1]) .+ arr([]) == arr([]) + + A = arr(eye(2)); @test broadcast!(+, A, A, arr([1, 4])) == arr([2 1; 4 5]) + A = arr(eye(2)); @test broadcast!(+, A, A, arr([1 4])) == arr([2 4; 1 5]) + A = arr([1 0]); @test_fails broadcast!(+, A, A, arr([1, 4])) + A = arr([1 0]); @test broadcast!(+, A, A, arr([1 4])) == arr([2 4]) + + M = arr([11 12; 21 22]) + @test broadcast_getindex(M, eye(Int, 2)+1,arr([1, 2])) == [21 11; 12 22] + + A = arr(zeros(2,2)) + broadcast_setindex!(A, arr([21 11; 12 22]), eye(Int, 2)+1,arr([1, 2])) + @test A == M +end diff --git a/test/runtests.jl b/test/runtests.jl index ed09fbcfa061e..c83c277c7e0a8 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -4,7 +4,7 @@ testnames = ["core", "keywordargs", "numbers", "strings", "unicode", "random", "math", "functional", "bigint", "sorting", "statistics", "spawn", "parallel", "priorityqueue", "arpack", "file", "perf", "suitesparse", "version", - "pollfd", "mpfr"] + "pollfd", "mpfr", "broadcast"] # Disabled: "complex"