Skip to content

Commit

Permalink
better varargs-based approach to interception
Browse files Browse the repository at this point in the history
  • Loading branch information
jrevels committed Jan 18, 2018
1 parent 229179b commit ecbe4d9
Show file tree
Hide file tree
Showing 8 changed files with 240 additions and 212 deletions.
5 changes: 2 additions & 3 deletions src/Cassette.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,10 @@ const MAX_ARGS = 20

struct Unused end

include("utilities/metaprogramming.jl")
include("utilities/anonymous.jl")
include("utilities/misc.jl")
include("utilities.jl")

include("contextual/contexts.jl")
include("contextual/anonymous.jl")
include("contextual/metadata.jl")

include("overdub/reflection.jl")
Expand Down
File renamed without changes.
196 changes: 84 additions & 112 deletions src/overdub/execution.jl
Original file line number Diff line number Diff line change
Expand Up @@ -77,136 +77,108 @@ end

Base.show(io::IO, o::Overdub{P}) where {P} = print("Overdub{$(P.name.name)}($(typeof(context(o)).name), $(func(o)))")

##################
# default passes #
##################
####################
# Overdub{Execute} #
####################

@inline (o::Overdub{Execute})(args...) = (hook(o, args...); execute(o, args...))

######################
# Overdub{Intercept} #
######################

# Replace all calls with `Overdub{Execute}` calls.
# Note that this approach emits code in which LHS SSAValues are not
# monotonically increasing. This currently isn't a problem, but in
# the future, valid IR might require monotonically increasing LHS
# SSAValues, in which case we'll have to add an extra SSA-remapping
# pass to this function.
function overdub_calls!(method_body::CodeInfo)
function overdub_pass!(method_body::CodeInfo)
# set up new SSAValues
self_ssa = SSAValue(method_body.ssavaluetypes)
method_body.ssavaluetypes += 1
ctx_ssa = SSAValue(method_body.ssavaluetypes)
greatest_ssa_value = method_body.ssavaluetypes
self = SSAValue(greatest_ssa_value)
new_code = Any[nothing, :($self = $(GlobalRef(Cassette, :func))($(SlotNumber(1))))]
label_map = Dict{Int,Int}()
# Replace calls with overdubbed calls, and replace
# SlotNumber(1) references with the underlying function.
# Also, fix LabelNodes and record the changes in a map that
# we'll use in a future pass to
for i in 2:length(method_body.code)

# set up replacement code
new_code = copy_prelude_code(method_body.code)
prelude_end = length(new_code)
push!(new_code, :($self_ssa = $(GlobalRef(Cassette, :func))($(SlotNumber(1)))))
push!(new_code, :($ctx_ssa = $(GlobalRef(Cassette, :context))($(SlotNumber(1)))))
in_overdub_region = false

# fill in replacement code
for i in (prelude_end + 1):length(method_body.code)
stmnt = method_body.code[i]
replace_match!(s -> self, s -> isa(s, SlotNumber) && s.id == 1, stmnt)
replace_match!(is_call, stmnt) do call
greatest_ssa_value += 1
new_ssa_value = SSAValue(greatest_ssa_value)
new_ssa_stmnt = Expr(:(=), new_ssa_value, Expr(:call, GlobalRef(Cassette, :intercept), SlotNumber(1), call.args[1]))
push!(new_code, new_ssa_stmnt)
call.args[1] = new_ssa_value
return call
end
push!(new_code, stmnt)
if isa(stmnt, LabelNode) && stmnt.label != length(new_code)
new_code[end] = LabelNode(length(new_code))
label_map[stmnt.label] = length(new_code)
end
end
# label positions might be messed up now due to
# the added SSAValues, so we have to fix them using
# the label map we built up during the earlier pass
for i in 2:length(new_code)
stmnt = new_code[i]
if isa(stmnt, GotoNode)
new_code[i] = GotoNode(get(label_map, stmnt.label, stmnt.label))
elseif isa(stmnt, Expr) && stmnt.head == :gotoifnot
stmnt.args[2] = get(label_map, stmnt.args[2], stmnt.args[2])
if stmnt === BEGIN_OVERDUB_REGION
in_overdub_region = true
else
# replace `SlotNumber(1)` references with `self_ssa`
replace_match!(s -> self_ssa, s -> isa(s, SlotNumber) && s.id == 1, stmnt)
if in_overdub_region
# replace calls with overdubbed calls
replace_match!(s -> is_call(s), stmnt) do call
greatest_ssa_value += 1
new_ssa_value = SSAValue(greatest_ssa_value)
new_ssa_stmnt = Expr(:(=), new_ssa_value, Expr(:call, GlobalRef(Cassette, :intercept), SlotNumber(1), call.args[1]))
push!(new_code, new_ssa_stmnt)
call.args[1] = new_ssa_value
return call
end
end
push!(new_code, stmnt)
end
end
method_body.code = new_code
method_body.ssavaluetypes = greatest_ssa_value + 1
return method_body
end

# replace all `new` expressions with calls to `Cassette._newbox`
function overdub_new!(method_body::CodeInfo)
code = method_body.code
ctx_ssa = SSAValue(method_body.ssavaluetypes)
insert!(code, 2, :($ctx_ssa = $(GlobalRef(Cassette, :context))($(SlotNumber(1)))))
method_body.ssavaluetypes += 1
replace_match!(x -> isa(x, Expr) && x.head === :new, code) do x
# replace all `new` expressions with calls to `Cassette._newbox`
replace_match!(x -> isa(x, Expr) && x.head === :new, new_code) do x
return Expr(:call, GlobalRef(Cassette, :_newbox), ctx_ssa, x.args...)
end
for i in eachindex(code)
stmnt = code[i]
if isa(stmnt, GotoNode)
code[i] = GotoNode(stmnt.label + 1)
elseif isa(stmnt, LabelNode)
code[i] = LabelNode(stmnt.label + 1)
elseif isa(stmnt, Expr) && stmnt.head == :gotoifnot
stmnt.args[2] += 1
end
end

method_body.code = fix_labels_and_gotos!(new_code)
method_body.ssavaluetypes = greatest_ssa_value + 1
return method_body
end

############################
# Overdub Call Definitions #
############################

# Overdub{Execute} #
#------------------#

@inline (o::Overdub{Execute})(args...) = (hook(o, args...); execute(o, args...))

# Overdub{Intercept} #
#--------------------#

for N in 0:MAX_ARGS
arg_names = [Symbol("_CASSETTE_$i") for i in 2:(N+1)]
arg_types = [:(unbox(C, $T)) for T in arg_names]
stub_expr = Expr(:new,
Core.GeneratedFunctionStub,
:_overdub_generator,
Any[:f, arg_names...],
Any[:F, :C, :M, :world, :debug, :pass],
@__LINE__,
QuoteNode(Symbol(@__FILE__)),
true)
@eval begin
function _overdub_generator(::Type{F}, ::Type{C}, ::Type{M}, world, debug, pass, f, $(arg_names...)) where {F,C,M}
ftype = unbox(C, F)
atypes = ($(arg_types...),)
signature = Tuple{ftype,atypes...}
try
method_body = lookup_method_body(signature, $arg_names, world, debug)
if isa(method_body, CodeInfo)
if !(pass <: Unused)
method_body = pass(signature, method_body)
end
method_body = overdub_new!(overdub_calls!(method_body))
method_body.inlineable = true
method_body.signature_for_inference_heuristics = Core.svec(ftype, atypes, world)
else
arg_names = $arg_names
method_body = quote
$(Expr(:meta, :inline))
$Cassette.execute(Val(true), f, $(arg_names...))
end
end
debug && Core.println("RETURNING OVERDUBBED METHOD BODY: ", method_body)
return method_body
catch err
errmsg = "ERROR COMPILING $signature OVERDUBBED WITH CONTEXT $C: " * sprint(showerror, err)
Core.println(errmsg) # in case the returned body doesn't get reached
return quote
error($errmsg)
end
function _overdub_generator(::Type{F}, ::Type{C}, ::Type{M}, world, debug, pass, f, args) where {F,C,M}
ftype = unbox(C, F)
atypes = Tuple(unbox(C, T) for T in args)
signature = Tuple{ftype,atypes...}
try
method_body = lookup_method_body(signature, world, debug)
if isa(method_body, CodeInfo)
if !(pass <: Unused)
method_body = pass(signature, method_body)
end
method_body = overdub_pass!(method_body)
method_body.inlineable = true
method_body.signature_for_inference_heuristics = Core.svec(ftype, atypes, world)
debug && Core.println("RETURNING OVERDUBBED CODEINFO: ", sprint(show, method_body))
else
method_body = quote
$(Expr(:meta, :inline))
$Cassette.execute(Val(true), f, $(OVERDUB_ARGS_SYMBOL)...)
end
debug && Core.println("NO CODEINFO FOUND; EXECUTING AS PRIMITIVE")
end
function (f::Overdub{Intercept,F,Settings{C,M,world,debug,pass}})($(arg_names...)) where {F,C,M,world,debug,pass}
$(Expr(:meta, :generated, stub_expr))
return method_body
catch err
errmsg = "ERROR COMPILING $signature IN CONTEXT $C: " * sprint(showerror, err)
Core.println(errmsg) # in case the returned body doesn't get reached
return quote
error($errmsg)
end
end
end

@eval function (f::Overdub{Intercept,F,Settings{C,M,world,debug,pass}})($(OVERDUB_ARGS_SYMBOL)...) where {F,C,M,world,debug,pass}
$(Expr(:meta,
:generated,
Expr(:new,
Core.GeneratedFunctionStub,
:_overdub_generator,
Any[:f, OVERDUB_ARGS_SYMBOL],
Any[:F, :C, :M, :world, :debug, :pass],
@__LINE__,
QuoteNode(Symbol(@__FILE__)),
true)))
end
78 changes: 40 additions & 38 deletions src/overdub/reflection.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
const OVERDUB_ARGS_SYMBOL = gensym("cassette_overdub_arguments")

const BEGIN_OVERDUB_REGION = gensym("cassette_begin_overdub_region")

# Return the CodeInfo method body for signature `S` and `world`,
# if it exists in the method table. Otherwise, return `nothing`.
function lookup_method_body(::Type{S}, arg_names::Vector,
function lookup_method_body(::Type{S},
world::UInt = typemax(UInt),
debug::Bool = false) where {S<:Tuple}
if debug
Expand All @@ -10,16 +14,7 @@ function lookup_method_body(::Type{S}, arg_names::Vector,
Core.println("\tWORLD: ", world)
end
S.parameters[1].name.module === Core.Compiler && return nothing
results = _lookup_method_body(S, arg_names, world)
results === nothing && return nothing
method, code_info = results
debug && Core.println("LOOKED UP METHOD: ", method)
debug && Core.println("LOOKED UP CODEINFO: ", code_info)
return code_info
end

function _lookup_method_body(::Type{S}, arg_names::Vector,
world::UInt = typemax(UInt)) where {S<:Tuple}
# retrieve initial Method + CodeInfo
_methods = Base._methods_by_ftype(S, -1, world)
length(_methods) == 1 || return nothing
Expand All @@ -29,38 +24,45 @@ function _lookup_method_body(::Type{S}, arg_names::Vector,
static_params = Any[raw_static_params...]
code_info = Core.Compiler.retrieve_code_info(method_instance)
isa(code_info, CodeInfo) || return nothing
code_info = Core.Compiler.copy_code_info(code_info)
debug && Core.println("FOUND METHOD: ", sprint(show, method))
debug && Core.println("FOUND CODEINFO: ", sprint(show, code_info))

# substitute static parameters/varargs
# substitute static parameters
body = Expr(:block)
body.args = code_info.code
Core.Compiler.substitute!(body, 0, Any[], method_signature, static_params, 0, :propagate)

# construct new slotnames/slotflags, offset non-self slotnumbers
code_info.slotnames = Any[code_info.slotnames[1], OVERDUB_ARGS_SYMBOL, code_info.slotnames[2:end]...]
code_info.slotflags = Any[code_info.slotflags[1], 0x00, code_info.slotflags[2:end]...]
replace_match!(s -> SlotNumber(s.id + 1), s -> isa(s, SlotNumber) && s.id > 1, code_info.code)

# construct new `code_info.code` in which original arguments are properly destructured
new_code = copy_prelude_code(code_info.code)
prelude_end = length(new_code)
n_actual_args = fieldcount(S) - 1
n_method_args = Int64(method.nargs) - 1
for i in 1:n_method_args
slot = i + 2
actual_argument = Expr(:call, GlobalRef(Core, :getfield), SlotNumber(2), i)
push!(new_code, :($(SlotNumber(slot)) = $actual_argument))
code_info.slotflags[slot] |= 0x01 << 0x01 # make sure the "assigned" bitflag is set
end
if method.isva
nargs = Int64(method.nargs)
new_nargs = length(arg_names) + 1
if new_nargs < nargs # then the final varargs argument is an empty tuple
offset = 0
new_slots = Any[SlotNumber(i) for i in 1:(nargs - 1)]
push!(new_slots, Expr(:call, GlobalRef(Core, :tuple)))
Base.Core.Compiler.substitute!(body, nargs, new_slots, method_signature, static_params, offset, :propagate)
else
new_slotnames = code_info.slotnames[1:(nargs - 1)]
new_slotflags = code_info.slotflags[1:(nargs - 1)]
for i in nargs:new_nargs
push!(new_slotnames, arg_names[i - 1])
push!(new_slotflags, 0x00)
end
append!(new_slotnames, code_info.slotnames[(nargs + 1):end])
append!(new_slotflags, code_info.slotflags[(nargs + 1):end])
offset = new_nargs - nargs
vararg_tuple = Expr(:call, GlobalRef(Core, :tuple), [SlotNumber(i) for i in nargs:new_nargs]...)
new_slots = Any[SlotNumber(i) for i in 1:(nargs - 1)]
push!(new_slots, vararg_tuple)
Base.Core.Compiler.substitute!(body, new_nargs, new_slots, method_signature, static_params, offset, :propagate)
code_info.slotnames = new_slotnames
code_info.slotflags = new_slotflags
isempty(new_code) || pop!(new_code) # remove the slot reassignment that we're replacing
final_arguments = Expr(:call, GlobalRef(Core, :tuple))
for i in n_method_args:n_actual_args
ssaval = SSAValue(code_info.ssavaluetypes)
actual_argument = Expr(:call, GlobalRef(Core, :getfield), SlotNumber(2), i)
push!(new_code, :($ssaval = $actual_argument))
push!(final_arguments.args, ssaval)
code_info.ssavaluetypes += 1
end
else
Base.Core.Compiler.substitute!(body, 0, Any[], method_signature, static_params, 0, :propagate)
push!(new_code, :($(SlotNumber(n_method_args + 2)) = $final_arguments))
end

return method, code_info
push!(new_code, BEGIN_OVERDUB_REGION)
append!(new_code, code_info.code[(prelude_end + 1):end])
code_info.code = fix_labels_and_gotos!(new_code)
return code_info
end
Loading

0 comments on commit ecbe4d9

Please sign in to comment.