Skip to content

Commit

Permalink
Merge pull request #9098 from JuliaLang/teh/ngenerate
Browse files Browse the repository at this point in the history
Expunging @ngenerate and @nsplat
  • Loading branch information
timholy committed Feb 7, 2015
2 parents 49a1f2e + def80f7 commit 4fd7fef
Show file tree
Hide file tree
Showing 6 changed files with 542 additions and 674 deletions.
64 changes: 37 additions & 27 deletions base/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -239,38 +239,48 @@ broadcast!_function(f::Function) = (B, As...) -> broadcast!(f, B, As...)
broadcast_function(f::Function) = (As...) -> broadcast(f, As...)

broadcast_getindex(src::AbstractArray, I::AbstractArray...) = broadcast_getindex!(Array(eltype(src), broadcast_shape(I...)), src, I...)
@ngenerate N typeof(dest) function broadcast_getindex!(dest::AbstractArray, src::AbstractArray, I::NTuple{N, AbstractArray}...)
check_broadcast_shape(size(dest), I...) # unnecessary if this function is never called directly
checkbounds(src, I...)
@nloops N i dest d->(@nexprs N k->(j_d_k = size(I_k, d) == 1 ? 1 : i_d)) begin
@nexprs N k->(@inbounds J_k = @nref N I_k d->j_d_k)
@inbounds (@nref N dest i) = (@nref N src J)
stagedfunction broadcast_getindex!(dest::AbstractArray, src::AbstractArray, I::AbstractArray...)
N = length(I)
Isplat = Expr[:(I[$d]) for d = 1:N]
quote
@nexprs $N d->(I_d = I[d])
check_broadcast_shape(size(dest), $(Isplat...)) # unnecessary if this function is never called directly
checkbounds(src, $(Isplat...))
@nloops $N i dest d->(@nexprs $N k->(j_d_k = size(I_k, d) == 1 ? 1 : i_d)) begin
@nexprs $N k->(@inbounds J_k = @nref $N I_k d->j_d_k)
@inbounds (@nref $N dest i) = (@nref $N src J)
end
dest
end
dest
end

@ngenerate N typeof(A) function broadcast_setindex!(A::AbstractArray, x, I::NTuple{N, AbstractArray}...)
checkbounds(A, I...)
shape = broadcast_shape(I...)
@nextract N shape d->(length(shape) < d ? 1 : shape[d])
if !isa(x, AbstractArray)
@nloops N i d->(1:shape_d) d->(@nexprs N k->(j_d_k = size(I_k, d) == 1 ? 1 : i_d)) begin
@nexprs N k->(@inbounds J_k = @nref N I_k d->j_d_k)
@inbounds (@nref N A J) = x
end
else
X = x
# To call setindex_shape_check, we need to create fake 1-d indexes of the proper size
@nexprs N d->(fakeI_d = 1:shape_d)
Base.setindex_shape_check(X, (@ntuple N fakeI)...)
k = 1
@nloops N i d->(1:shape_d) d->(@nexprs N k->(j_d_k = size(I_k, d) == 1 ? 1 : i_d)) begin
@nexprs N k->(@inbounds J_k = @nref N I_k d->j_d_k)
@inbounds (@nref N A J) = X[k]
k += 1
stagedfunction broadcast_setindex!(A::AbstractArray, x, I::AbstractArray...)
N = length(I)
Isplat = Expr[:(I[$d]) for d = 1:N]
quote
@nexprs $N d->(I_d = I[d])
checkbounds(A, $(Isplat...))
shape = broadcast_shape($(Isplat...))
@nextract $N shape d->(length(shape) < d ? 1 : shape[d])
if !isa(x, AbstractArray)
@nloops $N i d->(1:shape_d) d->(@nexprs $N k->(j_d_k = size(I_k, d) == 1 ? 1 : i_d)) begin
@nexprs $N k->(@inbounds J_k = @nref $N I_k d->j_d_k)
@inbounds (@nref $N A J) = x
end
else
X = x
# To call setindex_shape_check, we need to create fake 1-d indexes of the proper size
@nexprs $N d->(fakeI_d = 1:shape_d)
Base.setindex_shape_check(X, (@ntuple $N fakeI)...)
k = 1
@nloops $N i d->(1:shape_d) d->(@nexprs $N k->(j_d_k = size(I_k, d) == 1 ? 1 : i_d)) begin
@nexprs $N k->(@inbounds J_k = @nref $N I_k d->j_d_k)
@inbounds (@nref $N A J) = X[k]
k += 1
end
end
A
end
A
end

## elementwise operators ##
Expand Down
271 changes: 1 addition & 270 deletions base/cartesian.jl
Original file line number Diff line number Diff line change
@@ -1,275 +1,6 @@
module Cartesian

export @ngenerate, @nsplat, @nloops, @nref, @ncall, @nexprs, @nextract, @nall, @ntuple, @nif, ngenerate

const CARTESIAN_DIMS = 4

### @ngenerate, for auto-generation of separate versions of functions for different dimensionalities
# Examples (deliberately trivial):
# @ngenerate N returntype myndims{T,N}(A::Array{T,N}) = N
# or alternatively
# function gen_body(N::Int)
# quote
# return $N
# end
# end
# eval(ngenerate(:N, returntypeexpr, :(myndims{T,N}(A::Array{T,N})), gen_body))
# The latter allows you to use a single gen_body function for both ngenerate and
# when your function maintains its own method cache (e.g., reduction or broadcasting).
#
# Special syntax for function prototypes:
# @ngenerate N returntype function myfunction(A::AbstractArray, I::NTuple{N, Int}...)
# for N = 3 translates to
# function myfunction(A::AbstractArray, I_1::Int, I_2::Int, I_3::Int)
# and for the generic (cached) case as
# function myfunction(A::AbstractArray, I::Int...)
# @nextract N I I
# with N = length(I). N should _not_ be listed as a parameter of the function unless
# earlier arguments use it that way.
# To avoid ambiguity, it would be preferable to have some specific syntax for this, such as
# myfunction(A::AbstractArray, I::Int...N)
# where N can be an integer or symbol. Currently T...N generates a parser error.
macro ngenerate(itersym, returntypeexpr, funcexpr)
if isa(funcexpr, Expr) && funcexpr.head == :macrocall && funcexpr.args[1] == symbol("@inline")
funcexpr = Base._inline(funcexpr.args[2])
end
isfuncexpr(funcexpr) || throw(ArgumentError("requires a function expression"))
esc(ngenerate(itersym, returntypeexpr, funcexpr.args[1], N->sreplace!(copy(funcexpr.args[2]), itersym, N)))
end

# @nsplat takes an expression like
# @nsplat N 2:3 myfunction(A, I::NTuple{N,Real}...) = getindex(A, I...)
# and generates
# myfunction(A, I_1::Real, I_2::Real) = getindex(A, I_1, I_2)
# myfunction(A, I_1::Real, I_2::Real, I_3::Real) = getindex(A, I_1, I_2, I_3)
# myfunction(A, I::Real...) = getindex(A, I...)
# An @nsplat function _cannot_ have any other Cartesian macros in it.
# If you omit the range, it uses 1:CARTESIAN_DIMS.
macro nsplat(itersym, args...)
local rng
if length(args) == 1
rng = 1:CARTESIAN_DIMS
funcexpr = args[1]
elseif length(args) == 2
rangeexpr = args[1]
funcexpr = args[2]
if !isa(rangeexpr, Expr) || rangeexpr.head != :(:) || length(rangeexpr.args) != 2
throw(ArgumentError("first argument must be a from:to expression"))
end
rng = rangeexpr.args[1]:rangeexpr.args[2]
else
throw(ArgumentError("wrong number of arguments"))
end
if isa(funcexpr, Expr) && funcexpr.head == :macrocall && funcexpr.args[1] == symbol("@inline")
funcexpr = Base._inline(funcexpr.args[2])
end
isfuncexpr(funcexpr) || throw(ArgumentError("second argument must be a function expression"))
prototype = funcexpr.args[1]
body = funcexpr.args[2]
varname, T = get_splatinfo(prototype, itersym)
isempty(varname) && throw(ArgumentError("last argument must be a splat"))
explicit = [Expr(:function, resolvesplat!(copy(prototype), varname, T, N),
resolvesplats!(copy(body), varname, N)) for N in rng]
protosplat = resolvesplat!(copy(prototype), varname, T, 0)
protosplat.args[end] = Expr(:..., protosplat.args[end])
splat = Expr(:function, protosplat, body)
esc(Expr(:block, explicit..., splat))
end

generate1(itersym, prototype, bodyfunc, N::Int, varname, T) =
Expr(:function, spliceint!(sreplace!(resolvesplat!(copy(prototype), varname, T, N), itersym, N)),
resolvesplats!(bodyfunc(N), varname, N))

function ngenerate(itersym, returntypeexpr, prototype, bodyfunc, dims=1:CARTESIAN_DIMS, makecached::Bool = true)
varname, T = get_splatinfo(prototype, itersym)
# Generate versions for specific dimensions
fdim = [generate1(itersym, prototype, bodyfunc, N, varname, T) for N in dims]
if !makecached
return Expr(:block, fdim...)
end
# Generate the generic cache-based version
if isempty(varname)
setitersym, extractvarargs = :(), N -> nothing
else
s = symbol(varname)
setitersym = hasparameter(prototype, itersym) ? (:(@assert $itersym == length($s))) : (:($itersym = length($s)))
extractvarargs = N -> Expr(:block, map(popescape, _nextract(N, s, s).args)...)
end
fsym = funcsym(prototype)
dictname = symbol(fsym,"_cache")
fargs = funcargs(prototype)
if !isempty(varname)
fargs[end] = Expr(:..., fargs[end].args[1])
end
flocal = funcrename(copy(prototype), :_F_)
F = Expr(:function, resolvesplat!(prototype, varname, T), quote
$setitersym
if !haskey($dictname, $itersym)
gen1 = Base.Cartesian.generate1($(symbol(itersym)), $(Expr(:quote, flocal)), $bodyfunc, $itersym, $varname, $T)
$(dictname)[$itersym] = eval(quote
local _F_
$gen1
_F_
end)
end
($(dictname)[$itersym]($(fargs...)))::$returntypeexpr
end)
Expr(:block, fdim..., quote
let $dictname = Dict{Int,Function}()
$F
end
end)
end

isfuncexpr(ex::Expr) =
ex.head == :function || (ex.head == :(=) && typeof(ex.args[1]) == Expr && ex.args[1].head == :call)
isfuncexpr(arg) = false

sreplace!(arg, sym, val) = arg
function sreplace!(ex::Expr, sym, val)
for i = 1:length(ex.args)
ex.args[i] = sreplace!(ex.args[i], sym, val)
end
ex
end
sreplace!(s::Symbol, sym, val) = s == sym ? val : s

# If using the syntax that will need "desplatting",
# myfunction(A::AbstractArray, I::NTuple{N, Int}...)
# return the variable name (as a string) and type
function get_splatinfo(ex::Expr, itersym::Symbol)
if ex.head == :call
a = ex.args[end]
if isa(a, Expr) && a.head == :... && length(a.args) == 1
b = a.args[1]
if isa(b, Expr) && b.head == :(::)
varname = string(b.args[1])
c = b.args[2]
if isa(c, Expr) && c.head == :curly && c.args[1] == :NTuple && c.args[2] == itersym
T = c.args[3]
return varname, T
end
end
end
end
"", Void
end

# Replace splatted with desplatted for a specific number of arguments
function resolvesplat!(prototype, varname, T::Union(Type,Symbol,Expr), N::Int)
if !isempty(varname)
prototype.args[end] = N > 0 ? Expr(:(::), symbol(varname, "_1"), T) :
Expr(:(::), symbol(varname), T)
for i = 2:N
push!(prototype.args, Expr(:(::), symbol(varname, "_", i), T))
end
end
prototype
end

# Return the generic splatting form, e.g.,
# myfunction(A::AbstractArray, I::Int...)
function resolvesplat!(prototype, varname, T::Union(Type,Symbol,Expr))
if !isempty(varname)
svarname = symbol(varname)
prototype.args[end] = Expr(:..., :($svarname::$T))
end
prototype
end

# Desplatting function calls: replace func(a, b, I...) with func(a, b, I_1, I_2, I_3)
resolvesplats!(arg, varname, N) = arg
function resolvesplats!(ex::Expr, varname, N::Int)
if ex.head == :call
for i = 2:length(ex.args)-1
resolvesplats!(ex.args[i], varname, N)
end
a = ex.args[end]
if isa(a, Expr) && a.head == :... && a.args[1] == symbol(varname)
ex.args[end] = symbol(varname, "_1")
for i = 2:N
push!(ex.args, symbol(varname, "_", i))
end
else
resolvesplats!(a, varname, N)
end
else
for i = 1:length(ex.args)
resolvesplats!(ex.args[i], varname, N)
end
end
ex
end

# Remove any function parameters that are integers
function spliceint!(ex::Expr)
if ex.head == :escape
return esc(spliceint!(ex.args[1]))
end
ex.head == :call || throw(ArgumentError("$ex must be a call"))
if isa(ex.args[1], Expr) && ex.args[1].head == :curly
args = ex.args[1].args
for i = length(args):-1:1
if isa(args[i], Int)
deleteat!(args, i)
end
end
end
ex
end

function popescape(ex::Expr)
while ex.head == :escape
ex = ex.args[1]
end
ex
end

# Extract the "function name"
function funcsym(prototype::Expr)
prototype = popescape(prototype)
prototype.head == :call || throw(ArgumentError("$prototype must be a call"))
tmp = prototype.args[1]
if isa(tmp, Expr) && tmp.head == :curly
tmp = tmp.args[1]
end
return tmp
end

function funcrename(prototype::Expr, name::Symbol)
prototype = popescape(prototype)
prototype.head == :call || throw(ArgumentError("$prototype must be a call"))
tmp = prototype.args[1]
if isa(tmp, Expr) && tmp.head == :curly
tmp.args[1] = name
else
prototype.args[1] = name
end
return prototype
end

function hasparameter(prototype::Expr, sym::Symbol)
prototype = popescape(prototype)
prototype.head == :call || throw(ArgumentError("$prototype must be a call"))
tmp = prototype.args[1]
if isa(tmp, Expr) && tmp.head == :curly
for i = 2:length(tmp.args)
if tmp.args[i] == sym
return true
end
end
end
false
end

# Extract the symbols of the function arguments
funcarg(s::Symbol) = s
funcarg(ex::Expr) = ex.args[1]
function funcargs(prototype::Expr)
prototype = popescape(prototype)
prototype.head == :call || throw(ArgumentError("$prototype must be a call"))
map(a->funcarg(a), prototype.args[2:end])
end
export @nloops, @nref, @ncall, @nexprs, @nextract, @nall, @ntuple, @nif

### Cartesian-specific macros

Expand Down
Loading

0 comments on commit 4fd7fef

Please sign in to comment.