Skip to content

Commit

Permalink
Merge pull request JuliaLang#94 from JuliaStats/sjk/broadcast
Browse files Browse the repository at this point in the history
Implement broadcasting for DataArrays/PooledDataArrays
  • Loading branch information
simonster committed Jun 16, 2014
2 parents 791040b + 661e219 commit 3baae6b
Show file tree
Hide file tree
Showing 7 changed files with 458 additions and 14 deletions.
4 changes: 3 additions & 1 deletion benchmark/operators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,8 @@ macro perf(fn, replications, idx...)
df = compare([()->$fn for i=$idx], $replications)
gc_enable()
gc()
df["Function"] = TEST_NAMES[$idx]
df[:Function] = TEST_NAMES[$idx]
df[:Relative] = df[:Average]./df[1, :Average]
println(df)
end
end
Expand All @@ -61,6 +62,7 @@ const Bool2 = make_test_types(make_bools, 1000)
@perf isequal(Float1[i], Float2[i]) 10000
@perf .==(Float1[i], Float2[i]) 100
@perf +(Float1[i], Float2[i]) 100
@perf .+(Float1[i], Float2[i]) 100
@perf .*(Float1[i], Float2[i]) 100
@perf ./(Float1[i], Float2[i]) 50
@perf *(Float1[i], Float2[i]) 10 div(length(Float1), 2)+1:length(Float1)
Expand Down
1 change: 1 addition & 0 deletions src/DataArrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ module DataArrays
include("datamatrix.jl")
include("linalg.jl")
include("operators.jl")
include("broadcast.jl")
include("extras.jl")
include("grouping.jl")
include("statistics.jl")
Expand Down
321 changes: 321 additions & 0 deletions src/broadcast.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,321 @@
using DataArrays, Base.Cartesian, Base.@get!
using Base.Broadcast: bitcache_chunks, bitcache_size, dumpbitcache,
promote_eltype, broadcast_shape, eltype_plus, type_minus, type_div,
type_pow

# Validate that every argument in `As` can be broadcast to `shape`.
#
# Differs from the Base version in that it also reports exact matches: the
# return value is `true` only if every dimension of every argument equals the
# corresponding entry of `shape`, and `false` if any singleton dimension has
# to be expanded.
#
# Raises an error when an argument has more dimensions than `shape`, or when
# one of its dimensions differs from `shape` without being of size 1.
function check_broadcast_shape(shape::Dims, As::Union(AbstractArray,Number)...)
    nd = length(shape)
    exact = true
    for A in As
        ndims(A) <= nd || error("cannot broadcast array to have fewer dimensions")
        for k in 1:nd
            want = shape[k]
            have = size(A, k)
            if want != have
                # Not an exact match; only a singleton dimension may be expanded
                exact = false
                have == 1 || error("array could not be broadcast to match destination")
            end
        end
    end
    exact
end

# Store `val` as element `idx` of a DataArray that is passed as its two raw
# parts: `data` (the values) and `na_chunks` (the chunks of its NA bit mask).
# No bounds checking is done here.
#
# NA value: set the mask bit for `idx`; `data` is left untouched.
function _unsafe_dasetindex!(data, na_chunks, val::NAtype, idx::Int)
    bit = idx - 1
    na_chunks[Base.@_div64(bit)+1] |= (uint64(1) << Base.@_mod64(bit))
end

# Any other value: write it into `data`; the mask is left untouched.
function _unsafe_dasetindex!(data, na_chunks, val, idx::Int)
    setindex!(data, val, idx)
end

# Return the ref to store in a PooledDataArray for `val`.
#
# `Bpool` is the pool being built and `Brefdict` maps pool values to their
# refs. NA always maps to ref 0 and never touches the pool.
function _unsafe_pdaref!(Bpool, Brefdict::Dict, val::NAtype)
    0
end

# Non-NA value: reuse the ref already recorded in `Brefdict`; otherwise append
# `val` to `Bpool` and record its position there (converted to the ref type
# `V`) as the new ref.
function _unsafe_pdaref!{K,V}(Bpool, Brefdict::Dict{K,V}, val)
    @get!(Brefdict, val, begin
        push!(Bpool, val)
        convert(V, length(Bpool))
    end)
end

# Generate a branch for each possible combination of NA/not NA. This
# gives good performance at the cost of 2^narrays branches.
#
# Returns a quoted expression: a nested if/else tree that tests `isna_k` for
# every DataArray/PooledDataArray argument in turn. Each leaf calls `f` on the
# element values, with NA substituted for the arguments flagged as NA on that
# path, and stores the result in the output.
#
# Arguments:
#   f       - the function (expression) called in the generated code
#   nd      - number of dimensions, used to index `Brefs` via @nref
#   arrtype - per-argument category: DataArray, PooledDataArray or AbstractArray
#   outtype - DataArray or PooledDataArray; selects how the result is stored
#   daidx   - positions in `arrtype` that are not AbstractArray
#   pos     - position in `daidx` handled by this level of the recursion
#   isna    - NA decisions (Bools) made so far, one per handled entry of `daidx`
#
# The generated code refers to variables that the caller's generated code is
# expected to define: isna_k, v_k, data_k, state_k_0, pool_k, r_k, and the
# output variables Bdata/Bc/ind or Brefs/Bpool/Brefdict.
function gen_na_conds(f, nd, arrtype, outtype, daidx=find([arrtype...] .!= AbstractArray), pos=1, isna=())
if pos > length(daidx)
# Leaf: every NA decision has been made. Argument k is `v_k`, or the
# literal NA when this path assumed it to be NA.
args = Any[symbol("v_$(k)") for k = 1:length(arrtype)]
for i = 1:length(daidx)
if isna[i]
args[daidx[i]] = NA
end
end

# Needs to be gensymmed so that the compiler won't box it
val = gensym("val")
quote
$val = $(Expr(:call, f, args...))
$(if outtype == DataArray
:(@inbounds _unsafe_dasetindex!(Bdata, Bc, $val, ind))
elseif outtype == PooledDataArray
:(@inbounds (@nref $nd Brefs i) = _unsafe_pdaref!(Bpool, Brefdict, $val))
end)
end
else
# Interior node: branch on the NA status of argument k. The value is only
# loaded on the not-NA side (from data_k for a DataArray, from pool_k
# otherwise).
k = daidx[pos]
quote
if $(symbol("isna_$(k)"))
$(gen_na_conds(f, nd, arrtype, outtype, daidx, pos+1, tuple(isna..., true)))
else
$(if arrtype[k] == DataArray
:(@inbounds $(symbol("v_$(k)")) = $(symbol("data_$(k)"))[$(symbol("state_$(k)_0"))])
else
:(@inbounds $(symbol("v_$(k)")) = $(symbol("pool_$(k)"))[$(symbol("r_$(k)"))])
end)
$(gen_na_conds(f, nd, arrtype, outtype, daidx, pos+1, tuple(isna..., false)))
end
end
end
end

# Broadcast implementation for DataArrays
#
# Builds (with @eval) and returns a function `_F_(B, A_1, ..., A_n)` that
# applies `f` elementwise to the inputs and stores the results in `B`.
#
# Arguments:
#   nd      - number of dimensions of the destination; the generated function
#             asserts ndims(B) == nd
#   arrtype - one entry per input: DataArray, PooledDataArray or AbstractArray
#             (callers in this file pass the result of `datype`)
#   outtype - type of the destination `B` (DataArray or PooledDataArray)
#   f       - the function to apply
#
# Names used by the generated code, for input k and dimension d:
#   na_k, data_k     - NA mask chunks and values of a DataArray input
#   refs_k, pool_k   - refs and pool of a PooledDataArray input
#   state_k_0        - linear index of the current element of a DataArray
#                      input (indexes both data_k and na_k)
#   state_k_d        - linear index saved at loop level d
#   skip_k_d         - true when a DataArray input has size 1 along d
#   j_k_d            - index along d with singleton dimensions pinned to 1;
#                      used to index PooledDataArray refs and ordinary arrays
#   isna_k, r_k, v_k - NA flag, pool ref and value of the current element
#
# TODO: Fall back on faster implementation for same-sized inputs when
# it is safe to do so.
function gen_broadcast_dataarray(nd::Int, arrtype::(DataType...), outtype, f::Function)
# Quote `f` so that the function object can be embedded in the generated code
F = Expr(:quote, f)
narrays = length(arrtype)
As = [symbol("A_$(i)") for i = 1:narrays]
dataarrays = find([arrtype...] .== DataArray)
# NOTE(review): `abstractdataarrays` and `have_fastpath` are not referenced
# again in this function; presumably reserved for the fast path in the TODO.
abstractdataarrays = find([arrtype...] .!= AbstractArray)
have_fastpath = outtype == DataArray && all(x->!(x <: PooledDataArray), arrtype)

@eval begin
local _F_
function _F_(B::$(outtype), $(As...))
@assert ndims(B) == $nd

# Set up input DataArray/PooledDataArrays
$(Expr(:block, [
arrtype[k] == DataArray ? quote
$(symbol("na_$(k)")) = $(symbol("A_$(k)")).na.chunks
$(symbol("data_$(k)")) = $(symbol("A_$(k)")).data
$(symbol("state_$(k)_0")) = $(symbol("state_$(k)_$(nd)")) = 1
@nexprs $nd d->($(symbol("skip_$(k)_d")) = size($(symbol("data_$(k)")), d) == 1)
end : arrtype[k] == PooledDataArray ? quote
$(symbol("refs_$(k)")) = $(symbol("A_$(k)")).refs
$(symbol("pool_$(k)")) = $(symbol("A_$(k)")).pool
end : nothing
for k = 1:narrays]...))

# Set up output DataArray/PooledDataArray
# (a DataArray destination has its NA mask cleared; a PooledDataArray
# destination gets an empty pool and a fresh value => ref dictionary)
$(if outtype == DataArray
quote
Bc = B.na.chunks
fill!(Bc, 0)
Bdata = B.data
ind = 1
end
elseif outtype == PooledDataArray
quote
Bpool = B.pool = similar(B.pool, 0)
Brefs = B.refs
Brefdict = Dict{eltype(Bpool),eltype(Brefs)}()
end
end)

@nloops($nd, i, $(outtype == DataArray ? (:Bdata) : (:Brefs)),
# pre
d->($(Expr(:block, [
arrtype[k] == DataArray ? quote
$(symbol("state_$(k)_")){d-1} = $(symbol("state_$(k)_d"));
$(symbol("j_$(k)_d")) = $(symbol("skip_$(k)_d")) ? 1 : i_d
end : quote
$(symbol("j_$(k)_d")) = size($(symbol("A_$(k)")), d) == 1 ? 1 : i_d
end
for k = 1:narrays]...))),

# post
d->($(Expr(:block, [quote
$(symbol("skip_$(k)_d")) || ($(symbol("state_$(k)_d")) = $(symbol("state_$(k)_0")))
end for k in dataarrays]...))),

# body
begin
# Advance iterators for DataArray and determine NA status
$(Expr(:block, [
arrtype[k] == DataArray ? quote
@inbounds $(symbol("isna_$(k)")) = Base.unsafe_bitgetindex($(symbol("na_$(k)")), $(symbol("state_$(k)_0")))
end : arrtype[k] == PooledDataArray ? quote
@inbounds $(symbol("r_$(k)")) = @nref $nd $(symbol("refs_$(k)")) d->$(symbol("j_$(k)_d"))
$(symbol("isna_$(k)")) = $(symbol("r_$(k)")) == 0
end : nothing
for k = 1:narrays]...))

# Extract values for ordinary AbstractArrays
$(Expr(:block, [
:(@inbounds $(symbol("v_$(k)")) = @nref $nd $(symbol("A_$(k)")) d->$(symbol("j_$(k)_d")))
for k = find([arrtype...] .== AbstractArray)]...))

# Compute and store return value
$(gen_na_conds(F, nd, arrtype, outtype))

# Increment state
$(Expr(:block, [:($(symbol("state_$(k)_0")) += 1) for k in dataarrays]...))
$(if outtype == DataArray
:(ind += 1)
end)
end)
end
# The value of the @eval block, and so of this function, is the generated
# function
_F_
end
end

# Classify each argument into the category the code generator distinguishes:
# PooledDataArray, DataArray, or AbstractArray for anything else. Returns a
# tuple with one entry per argument, in argument order.
datype() = ()
datype(A_1, As...) = (AbstractArray, datype(As...)...)
datype(A_1::DataArray, As...) = (DataArray, datype(As...)...)
datype(A_1::PooledDataArray, As...) = (PooledDataArray, datype(As...)...)

# Encode the categories of the arguments as a Uint64, two bits per argument
# with the first argument in the lowest bits:
#   2 = PooledDataArray, 1 = DataArray, 0 = anything else.
function datype_int()
    uint64(0)
end

function datype_int(A_1, As...)
    datype_int(As...) << 2
end

function datype_int(A_1::DataArray, As...)
    uint64(1) | (datype_int(As...) << 2)
end

function datype_int(A_1::PooledDataArray, As...)
    uint64(2) | (datype_int(As...) << 2)
end

# Define `map!` and `broadcast!` for DataArray and PooledDataArray
# destinations.
#
# Every (bsig, asig) pair gets its own `cache`, nested as
#   function => datype_int code of the arguments => ndims(B) => generated function
# so that gen_broadcast_dataarray only runs the first time a combination is
# seen.
#
# NOTE(review): the two `asig` variants (concrete Base types, then Any) are
# presumably there to avoid dispatch ambiguities with Base methods - confirm.
for bsig in (DataArray, PooledDataArray), asig in (Union(Array,BitArray,Number), Any)
@eval let cache = Dict{Function,Dict{Uint64,Dict{Int,Function}}}()
# map!: same machinery as broadcast! below, but every argument must have
# exactly the size of B.
# NOTE(review): `f` may be any Base.Callable here, while `cache` is keyed by
# Function and gen_broadcast_dataarray requires f::Function - confirm that
# non-Function callables are not expected to work.
function Base.map!(f::Base.Callable, B::$bsig, As::$asig...)
nd = ndims(B)
length(As) <= 8 || error("too many arguments")
# Errors if an argument is not broadcast compatible with B; returns false
# if any argument would need a singleton dimension expanded.
samesize = check_broadcast_shape(size(B), As...)
samesize || error("dimensions must match")
arrtype = datype_int(As...)

# Look up the generated function, creating it on first use
cache_f = @get! cache f Dict{Uint64,Dict{Int,Function}}()
cache_f_na = @get! cache_f arrtype Dict{Int,Function}()
func = @get! cache_f_na nd gen_broadcast_dataarray(nd, datype(As...), $bsig, f)

func(B, As...)
B
end
# ambiguity
# (single Range argument, forwarded to the vararg method above; the
# signature does not depend on `asig`, so for each `bsig` the definition
# from the later `asig` iteration replaces the earlier one)
Base.map!(f::Base.Callable, B::$bsig, r::Range) =
invoke(Base.map!, (Base.Callable, $bsig, $asig), f, B, r)
# broadcast!: arguments only need to be broadcast compatible with B.
function Base.broadcast!(f::Function, B::$bsig, As::$asig...)
nd = ndims(B)
length(As) <= 8 || error("too many arguments")
# Called for its error checks; `samesize` itself is not used below
samesize = check_broadcast_shape(size(B), As...)
arrtype = datype_int(As...)

# Look up the generated function, creating it on first use
cache_f = @get! cache f Dict{Uint64,Dict{Int,Function}}()
cache_f_na = @get! cache_f arrtype Dict{Int,Function}()
func = @get! cache_f_na nd gen_broadcast_dataarray(nd, datype(As...), $bsig, f)

# println(code_typed(func, typeof(tuple(B, As...))))
func(B, As...)
B
end
end
end

# Broadcast `f` over `As` into a newly allocated DataArray whose element type
# is promote_eltype(As...) and whose shape is broadcast_shape(As...).
function databroadcast(f::Function, As...)
    dest = DataArray(promote_eltype(As...), broadcast_shape(As...))
    broadcast!(f, dest, As...)
end

# Same as databroadcast, but the destination is a PooledDataArray.
function pdabroadcast(f::Function, As...)
    dest = PooledDataArray(promote_eltype(As...), broadcast_shape(As...))
    broadcast!(f, dest, As...)
end

# Replace, in place, every argument of `ex` (at any nesting depth) that is
# `==` to `search` with the elements of the collection `rep`, and return `ex`.
#
# The items spliced in are not scanned again, so `rep` may itself contain
# `search` (as it does when a vararg `As...` is expanded to
# `A_1, ..., A_n, As...`).
#
# Previously the scan stopped after the first match found in an argument
# list, which left later occurrences in that list (and inside its later
# arguments) untouched.
function exreplace!(ex::Expr, search, rep)
    args = ex.args
    i = 1
    while i <= length(args)
        if args[i] == search
            splice!(args, i, rep)
            # Continue after the inserted items rather than re-examining them
            i += length(rep)
        else
            exreplace!(args[i], search, rep)
            i += 1
        end
    end
    ex
end

# Anything that is not an Expr has no arguments to rewrite.
exreplace!(ex, search, rep) = ex

# Expand one vararg method definition `name(args..., As...) = body` into a
# family of methods that apply when a DataArray or PooledDataArray is among
# the first four array arguments.
#
# For every n in 1:4 and aa in 0:n-1 a copy of the definition is emitted in
# which the trailing `As...` of the signature becomes
#   A_1::AbstractArray, ..., A_aa::AbstractArray,
#   A_{aa+1}::Union(DataArray, PooledDataArray), ..., A_n::Union(...),
#   As::AbstractArray...
# and every `As...` that exreplace! rewrites in the body becomes
# `A_1, ..., A_n, As...`.
macro da_broadcast_vararg(func)
    wellformed = (func.head == :function || func.head == :(=)) &&
                 func.args[1].head == :call &&
                 isa(func.args[1].args[end], Expr) &&
                 func.args[1].args[end].head == :...
    wellformed || error("@da_broadcast_vararg may only be applied to vararg functions")

    va = func.args[1].args[end]
    datypes = Union(DataArray, PooledDataArray)
    defs = {}
    for n = 1:4, aa = 0:n-1
        def = deepcopy(func)

        # Body: pass the named arguments ahead of the remaining varargs
        callargs = Any[symbol("A_$(i)") for i = 1:n]
        push!(callargs, va)
        exreplace!(def.args[2], va, callargs)

        # Signature: the first `aa` arguments are plain arrays, the rest of
        # the named arguments are data arrays
        sigargs = cell(n+1)
        for i = 1:n
            T = i <= aa ? AbstractArray : datypes
            sigargs[i] = Expr(:(::), symbol("A_$i"), T)
        end
        sigargs[n+1] = Expr(:..., Expr(:(::), va.args[1], AbstractArray))
        exreplace!(def.args[1], va, sigargs)

        push!(defs, def)
    end
    esc(Expr(:block, defs...))
end

# Expand one two-argument method definition `f(A, B) = body` into the three
# methods in which at least one argument is a DataArray or PooledDataArray:
#   (data array, data array), (data array, AbstractArray),
#   (AbstractArray, data array)
# all sharing `body`.
#
# The result is wrapped in `esc`, like @da_broadcast_vararg, so that the
# function name, the argument names and the body are used exactly as written
# at the call site instead of going through macro hygiene. The argument types
# are interpolated as type objects, so the call site does not need them in
# scope.
#
# Raises an error if `func` is not a function definition with exactly two
# arguments.
macro da_broadcast_binary(func)
    if (func.head != :function && func.head != :(=)) ||
       func.args[1].head != :call || length(func.args[1].args) != 3
        error("@da_broadcast_binary may only be applied to two-argument functions")
    end
    (f, A, B) = func.args[1].args
    body = func.args[2]
    datypes = Union(DataArray, PooledDataArray)
    esc(quote
        $f($A::$(datypes), $B::$(datypes)) = $(body)
        $f($A::$(datypes), $B::$(AbstractArray)) = $(body)
        $f($A::$(AbstractArray), $B::$(datypes)) = $(body)
    end)
end

# Broadcasting DataArrays returns a DataArray
@da_broadcast_vararg Base.broadcast(f::Function, As...) = databroadcast(f, As...)

# Definitions for the elementwise operators. Each one allocates a DataArray
# destination and fills it with broadcast!; operators that have an element
# type helper imported from Base.Broadcast (eltype_plus, type_minus, type_div,
# type_pow) use it to choose the element type, the others go through
# databroadcast.
#
# NOTE(review): the explicit BitArray / Bool methods are presumably needed to
# resolve dispatch ambiguities with existing methods - confirm.
Base.(:(.*))(A::BitArray, B::Union(DataArray{Bool}, PooledDataArray{Bool})) = databroadcast(*, A, B)
Base.(:(.*))(A::Union(DataArray{Bool}, PooledDataArray{Bool}), B::BitArray) = databroadcast(*, A, B)
@da_broadcast_vararg Base.(:(.*))(As...) = databroadcast(*, As...)
@da_broadcast_binary Base.(:(.%))(A, B) = databroadcast(%, A, B)
@da_broadcast_vararg Base.(:(.+))(As...) = broadcast!(+, DataArray(eltype_plus(As...), broadcast_shape(As...)), As...)
@da_broadcast_binary Base.(:(.-))(A, B) = broadcast!(-, DataArray(type_minus(eltype(A), eltype(B)), broadcast_shape(A,B)), A, B)
@da_broadcast_binary Base.(:(./))(A, B) = broadcast!(/, DataArray(type_div(eltype(A), eltype(B)), broadcast_shape(A, B)), A, B)
@da_broadcast_binary Base.(:(.\))(A, B) = broadcast!(\, DataArray(type_div(eltype(A), eltype(B)), broadcast_shape(A, B)), A, B)
# .^ between Bool arrays is computed with >= (for Bools, a^b is false only
# when a is false and b is true, which is exactly when a >= b is false).
Base.(:(.^))(A::Union(DataArray{Bool}, PooledDataArray{Bool}), B::Union(DataArray{Bool}, PooledDataArray{Bool})) = databroadcast(>=, A, B)
Base.(:(.^))(A::BitArray, B::Union(DataArray{Bool}, PooledDataArray{Bool})) = databroadcast(>=, A, B)
Base.(:(.^))(A::AbstractArray{Bool}, B::Union(DataArray{Bool}, PooledDataArray{Bool})) = databroadcast(>=, A, B)
Base.(:(.^))(A::Union(DataArray{Bool}, PooledDataArray{Bool}), B::BitArray) = databroadcast(>=, A, B)
Base.(:(.^))(A::Union(DataArray{Bool}, PooledDataArray{Bool}), B::AbstractArray{Bool}) = databroadcast(>=, A, B)
@da_broadcast_binary Base.(:(.^))(A, B) = broadcast!(^, DataArray(type_pow(eltype(A), eltype(B)), broadcast_shape(A, B)), A, B)

# Methods for the case where every argument is a PooledDataArray. With the
# exception noted below, these allocate a PooledDataArray destination.
# XXX is a PDA the right return type for these?
Base.broadcast(f::Function, As::PooledDataArray...) = pdabroadcast(f, As...)
Base.(:(.*))(As::PooledDataArray...) = pdabroadcast(*, As...)
Base.(:(.%))(A::PooledDataArray, B::PooledDataArray) = pdabroadcast(%, A, B)
Base.(:(.+))(As::PooledDataArray...) = broadcast!(+, PooledDataArray(eltype_plus(As...), broadcast_shape(As...)), As...)
Base.(:(.-))(A::PooledDataArray, B::PooledDataArray) =
broadcast!(-, PooledDataArray(type_minus(eltype(A), eltype(B)), broadcast_shape(A,B)), A, B)
Base.(:(./))(A::PooledDataArray, B::PooledDataArray) =
broadcast!(/, PooledDataArray(type_div(eltype(A), eltype(B)), broadcast_shape(A, B)), A, B)
Base.(:(.\))(A::PooledDataArray, B::PooledDataArray) =
broadcast!(\, PooledDataArray(type_div(eltype(A), eltype(B)), broadcast_shape(A, B)), A, B)
# NOTE(review): unlike its neighbours, the Bool method below goes through
# databroadcast and so returns a DataArray, not a PooledDataArray - confirm
# that this is intended.
Base.(:(.^))(A::PooledDataArray{Bool}, B::PooledDataArray{Bool}) = databroadcast(>=, A, B)
Base.(:(.^))(A::PooledDataArray, B::PooledDataArray) =
broadcast!(^, PooledDataArray(type_pow(eltype(A), eltype(B)), broadcast_shape(A, B)), A, B)

# Define each array comparison operator `vf` for data arrays by broadcasting
# its scalar counterpart `sf` into a DataArray of Bool.
# `scalar_comparison_operators` and `array_comparison_operators` are defined
# outside this file and are paired up element by element.
for (sf, vf) in zip(scalar_comparison_operators, array_comparison_operators)
@eval begin
# ambiguity
# (explicit methods for Bool element types, same body as the generic
# definition below)
$(vf)(A::Union(PooledDataArray{Bool},DataArray{Bool}), B::Union(PooledDataArray{Bool},DataArray{Bool})) =
broadcast!($sf, DataArray(Bool, broadcast_shape(A, B)), A, B)
$(vf)(A::Union(PooledDataArray{Bool},DataArray{Bool}), B::AbstractArray{Bool}) =
broadcast!($sf, DataArray(Bool, broadcast_shape(A, B)), A, B)
$(vf)(A::AbstractArray{Bool}, B::Union(PooledDataArray{Bool},DataArray{Bool})) =
broadcast!($sf, DataArray(Bool, broadcast_shape(A, B)), A, B)

# Generic case: at least one argument is a DataArray or PooledDataArray
@da_broadcast_binary $(vf)(A, B) = broadcast!($sf, DataArray(Bool, broadcast_shape(A, B)), A, B)
end
end
12 changes: 1 addition & 11 deletions src/operators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -659,7 +659,6 @@ for (sf,vf) in zip(scalar_comparison_operators, array_comparison_operators)
@swappable ($(sf))(::NAtype, b) = NA

@dataarray_binary_scalar $(vf) $(sf) Bool true
@dataarray_binary_array $(vf) $(sf) Bool
end
end

Expand Down Expand Up @@ -767,11 +766,7 @@ Base.(:^)(::NAtype, ::Integer) = NA
Base.(:^)(::NAtype, ::Number) = NA

for (vf, sf) in ((:(Base.(:+)), :(Base.(:+))),
(:(Base.(:.+)), :(Base.(:+))),
(:(Base.(:-)), :(Base.(:-))),
(:(Base.(:.-)), :(Base.(:-))),
(:(Base.(:.*)), :(Base.(:*))),
(:(Base.(:.^)), :(Base.(:^))))
(:(Base.(:-)), :(Base.(:-))))
@eval begin
# Necessary to avoid ambiguity warnings
@swappable ($vf)(A::BitArray, B::AbstractDataArray{Bool}) = ($vf)(bitunpack(A), B)
Expand All @@ -781,9 +776,6 @@ for (vf, sf) in ((:(Base.(:+)), :(Base.(:+))),
end
end

@swappable Base.(:./)(A::BitArray, B::AbstractDataArray{Bool}) = ./(bitunpack(A), B)
@swappable Base.(:./)(A::BitArray, B::DataArray{Bool}) = ./(bitunpack(A), B)

# / and ./ are defined separately since they promote to floating point
for f in ((:(Base.(:/)), :(Base.(:./))))
@eval begin
Expand All @@ -800,8 +792,6 @@ Base.(:/){T,N}(b::AbstractArray{T,N}, ::NAtype) =
DataArray(Array(T, size(b)), trues(size(b)))
@dataarray_binary_scalar Base.(:./) Base.(:/) eltype(a) <: FloatingPoint || typeof(b) <: FloatingPoint ?
promote_type(eltype(a), typeof(b)) : Float64 true
@dataarray_binary_array Base.(:./) Base.(:/) eltype(a) <: FloatingPoint || eltype(b) <: FloatingPoint ?
promote_type(eltype(a), eltype(b)) : Float64

for f in biscalar_operators
@eval begin
Expand Down
Loading

0 comments on commit 3baae6b

Please sign in to comment.