Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Faster cache lookup in broadcast! via nested Dicts and get! macro #6107

Merged
merged 3 commits into from
Mar 12, 2014
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
91 changes: 25 additions & 66 deletions base/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ module Broadcast

using ..Cartesian
import Base.promote_eltype
import Base.@get!
import Base.num_bit_chunks, Base.@_msk_end, Base.getindex_unchecked
import Base.(.+), Base.(.-), Base.(.*), Base.(./), Base.(.\)
import Base.(.==), Base.(.<), Base.(.!=), Base.(.<=)
Expand Down Expand Up @@ -203,73 +204,31 @@ function gen_broadcast_function_tobitarray(genbody::Function, nd::Int, narrays::
end
end

let broadcast_cache = Dict()
global broadcast!
function broadcast!(f::Function, B, As::Union(Array,BitArray)...)
nd = ndims(B)
narrays = length(As)
key = (f, nd, narrays)
if !haskey(broadcast_cache, key)
func = gen_broadcast_function(gen_broadcast_body_iter, nd, narrays, f)
broadcast_cache[key] = func
else
func = broadcast_cache[key]
end
func(B, As...)
B
end
end # let broadcast_cache

let broadcast_cache = Dict()
global broadcast!
function broadcast!(f::Function, B::BitArray, As::Union(Array,BitArray)...)
nd = ndims(B)
narrays = length(As)
key = (f, nd, narrays)
if !haskey(broadcast_cache, key)
func = gen_broadcast_function_tobitarray(gen_broadcast_body_iter_tobitarray, nd, narrays, f)
broadcast_cache[key] = func
else
func = broadcast_cache[key]
end
func(B, As...)
B
end
end # let broadcast_cache

let broadcast_cache = Dict()
global broadcast!
function broadcast!(f::Function, B, As...)
nd = ndims(B)
narrays = length(As)
key = (f, nd, narrays)
if !haskey(broadcast_cache, key)
func = gen_broadcast_function(gen_broadcast_body_cartesian, nd, narrays, f)
broadcast_cache[key] = func
else
func = broadcast_cache[key]
end
func(B, As...)
B
end
end # let broadcast_cache

let broadcast_cache = Dict()
global broadcast!
function broadcast!(f::Function, B::BitArray, As...)
nd = ndims(B)
narrays = length(As)
key = (f, nd, narrays)
if !haskey(broadcast_cache, key)
func = gen_broadcast_function_tobitarray(gen_broadcast_body_cartesian_tobitarray, nd, narrays, f)
broadcast_cache[key] = func
else
func = broadcast_cache[key]
end
func(B, As...)
B
for (Bsig, Asig, gbf, gbb) in
((BitArray , Union(Array,BitArray) ,
:gen_broadcast_function_tobitarray, :gen_broadcast_body_iter_tobitarray ),
(Any , Union(Array,BitArray) ,
:gen_broadcast_function , :gen_broadcast_body_iter ),
(BitArray , Any ,
:gen_broadcast_function_tobitarray, :gen_broadcast_body_cartesian_tobitarray),
(Any , Any ,
:gen_broadcast_function , :gen_broadcast_body_cartesian ))

@eval let cache = Dict{Function,Dict{Int,Dict{Int,Function}}}()
global broadcast!
function broadcast!(f::Function, B::$Bsig, As::$Asig...)
nd = ndims(B)
narrays = length(As)

cache_f = @get! cache f Dict{Int,Dict{Int,Function}}()
cache_f_na = @get! cache_f narrays Dict{Int,Function}()
func = @get! cache_f_na nd $gbf($gbb, nd, narrays, f)

func(B, As...)
B
end
end # let broadcast_cache
end
end # let broadcast_cache


broadcast(f::Function, As...) = broadcast!(f, Array(promote_eltype(As...), broadcast_shape(As...)), As...)
Expand Down
20 changes: 20 additions & 0 deletions base/dict.jl
Original file line number Diff line number Diff line change
Expand Up @@ -561,6 +561,26 @@ function get!{K,V}(default::Function, h::Dict{K,V}, key0)
return v
end

# NOTE: this macro is specific to Dict, not Associative, and should
# therefore not be exported as-is: it's for internal use only.
macro get!(h, key0, default)
quote
K, V = eltype($(esc(h)))
key = convert(K, $(esc(key0)))
isequal(key, $(esc(key0))) || error($(esc(key0)), " is not a valid key for type ", K)
idx = ht_keyindex2($(esc(h)), key)
if idx < 0
idx = -idx
v = convert(V, $(esc(default)))
_setindex!($(esc(h)), v, key, idx)
else
@inbounds v = $(esc(h)).vals[idx]
end
v
end
end


function getindex{K,V}(h::Dict{K,V}, key)
index = ht_keyindex(h, key)
return (index<0) ? throw(KeyError(key)) : h.vals[index]::V
Expand Down