Skip to content

Commit

Permalink
Add backedges for Cassette (#32237)
Browse files Browse the repository at this point in the history
Adds an extra field in `CodeInfo` that allows users (Cassette, et al.)
to specify dependencies (as forward edges to `MethodInstance`s) that should
be turned into backedges once the CodeInfo is passed over by inference.

The test includes a minimal implementation of the Cassette mechansim to
exercise this code path.
  • Loading branch information
Keno authored Jun 10, 2019
1 parent aaaa6a1 commit 1a859d6
Show file tree
Hide file tree
Showing 8 changed files with 139 additions and 8 deletions.
5 changes: 3 additions & 2 deletions base/compiler/inferencestate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -87,8 +87,9 @@ mutable struct InferenceState
inmodule = linfo.def::Module
end

min_valid = UInt(1)
max_valid = get_world_counter()
min_valid = src.min_world
max_valid = src.max_world == typemax(UInt) ?
get_world_counter() : src.max_world
frame = new(
params, result, linfo,
sp, slottypes, inmodule, 0,
Expand Down
9 changes: 9 additions & 0 deletions base/compiler/typeinfer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,15 @@ function store_backedges(frame::InferenceState)
end
end
end
edges = frame.src.edges
if edges !== nothing
edges = edges::Vector{MethodInstance}
for edge in edges
@assert isa(edge, MethodInstance)
ccall(:jl_method_instance_add_backedge, Cvoid, (Any, Any), edge, caller)
end
frame.src.edges = nothing
end
end
end

Expand Down
8 changes: 5 additions & 3 deletions src/jltypes.c
Original file line number Diff line number Diff line change
Expand Up @@ -2070,7 +2070,7 @@ void jl_init_types(void) JL_GC_DISABLED
jl_code_info_type =
jl_new_datatype(jl_symbol("CodeInfo"), core,
jl_any_type, jl_emptysvec,
jl_perm_symsvec(17,
jl_perm_symsvec(18,
"code",
"codelocs",
"ssavaluetypes",
Expand All @@ -2082,13 +2082,14 @@ void jl_init_types(void) JL_GC_DISABLED
"slottypes",
"rettype",
"parent",
"edges",
"min_world",
"max_world",
"inferred",
"inlineable",
"propagate_inbounds",
"pure"),
jl_svec(17,
jl_svec(18,
jl_array_any_type,
jl_any_type,
jl_any_type,
Expand All @@ -2100,13 +2101,14 @@ void jl_init_types(void) JL_GC_DISABLED
jl_any_type,
jl_any_type,
jl_any_type,
jl_any_type,
jl_ulong_type,
jl_ulong_type,
jl_bool_type,
jl_bool_type,
jl_bool_type,
jl_bool_type),
0, 1, 17);
0, 1, 18);

jl_method_type =
jl_new_datatype(jl_symbol("Method"), core,
Expand Down
1 change: 1 addition & 0 deletions src/julia.h
Original file line number Diff line number Diff line change
Expand Up @@ -254,6 +254,7 @@ typedef struct _jl_code_info_t {
jl_value_t *slottypes; // inferred types of slots
jl_value_t *rettype;
jl_method_instance_t *parent; // context (optionally, if available, otherwise nothing)
jl_value_t *edges; // forward edges to method instances that must be invalidated
size_t min_world;
size_t max_world;
// various boolean properties:
Expand Down
1 change: 1 addition & 0 deletions src/method.c
Original file line number Diff line number Diff line change
Expand Up @@ -320,6 +320,7 @@ JL_DLLEXPORT jl_code_info_t *jl_new_code_info_uninit(void)
src->inlineable = 0;
src->propagate_inbounds = 0;
src->pure = 0;
src->edges = jl_nothing;
return src;
}

Expand Down
11 changes: 9 additions & 2 deletions stdlib/Serialization/src/Serialization.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1015,8 +1015,15 @@ function deserialize(s::AbstractSerializer, ::Type{CodeInfo})
ci.slottypes = deserialize(s)
ci.rettype = deserialize(s)
ci.parent = deserialize(s)
ci.min_world = reinterpret(UInt, deserialize(s))
ci.max_world = reinterpret(UInt, deserialize(s))
world_or_edges = deserialize(s)
pre_13 = isa(world_or_edges, Integer)
if pre_13
ci.min_world = world_or_edges
else
ci.edges = world_or_edges
ci.min_world = reinterpret(UInt, deserialize(s))
ci.max_world = reinterpret(UInt, deserialize(s))
end
end
ci.inferred = deserialize(s)
ci.inlineable = deserialize(s)
Expand Down
2 changes: 1 addition & 1 deletion test/choosetests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ function choosetests(choices = [])
end

compilertests = ["compiler/inference", "compiler/validation", "compiler/ssair", "compiler/irpasses",
"compiler/codegen", "compiler/inline"]
"compiler/codegen", "compiler/inline", "compiler/contextual"]

if "compiler" in skip_tests
filter!(x -> (x != "compiler" && !(x in compilertests)), tests)
Expand Down
110 changes: 110 additions & 0 deletions test/compiler/contextual.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
module MiniCassette
# A minimal demonstration of the cassette mechanism. Doesn't support all the
# fancy features, but sufficient to exercise this code path in the compiler.

using Core.Compiler: method_instances, retrieve_code_info, CodeInfo,
MethodInstance, SSAValue, GotoNode, Slot, SlotNumber, quoted,
signature_type
using Base: _methods_by_ftype
using Base.Meta: isexpr
using Test

export Ctx, overdub

struct Ctx; end

# A no-op cassette-like transform
function transform_expr(expr, map_slot_number, map_ssa_value, sparams)
transform(expr) = transform_expr(expr, map_slot_number, map_ssa_value, sparams)
if isexpr(expr, :call)
return Expr(:call, overdub, SlotNumber(2), map(transform, expr.args)...)
elseif isexpr(expr, :gotoifnot)
return Expr(:gotoifnot, transform(expr.args[1]), map_ssa_value(SSAValue(expr.args[2])).id)
elseif isexpr(expr, :static_parameter)
return quoted(sparams[expr.args[1]])
elseif isa(expr, Expr)
return Expr(expr.head, map(transform, expr.args)...)
elseif isa(expr, GotoNode)
return GotoNode(map_ssa_value(SSAValue(expr.label)).id)
elseif isa(expr, Slot)
return map_slot_number(expr.id)
elseif isa(expr, SSAValue)
return map_ssa_value(expr)
else
return expr
end
end

function transform!(ci, nargs, sparams)
code = ci.code
ci.slotnames = Symbol[Symbol("#self#"), :ctx, :f, :args, ci.slotnames[nargs+1:end]...]
ci.slotflags = UInt8[(0x00 for i = 1:4)..., ci.slotflags[nargs+1:end]...]
# Insert one SSAValue for every argument statement
prepend!(code, [Expr(:call, getfield, SlotNumber(4), i) for i = 1:nargs])
prepend!(ci.codelocs, [0 for i = 1:nargs])
ci.ssavaluetypes += nargs
function map_slot_number(slot)
if slot == 1
# self in the original function is now `f`
return SlotNumber(3)
elseif 2 <= slot <= nargs + 1
# Arguments get inserted as ssa values at the top of the function
return SSAValue(slot - 1)
else
# The first non-argument slot will be 5
return SlotNumber(slot - (nargs + 1) + 4)
end
end
map_ssa_value(ssa::SSAValue) = SSAValue(ssa.id + nargs)
for i = (nargs+1:length(code))
code[i] = transform_expr(code[i], map_slot_number, map_ssa_value, sparams)
end
end

function overdub_generator(self, c, f, args)
if f <: Core.Builtin || !isdefined(f, :instance)
return :(return f(args...))
end

tt = Tuple{f, args...}
mthds = _methods_by_ftype(tt, -1, typemax(UInt))
@assert length(mthds) == 1
mtypes, msp, m = mthds[1]
mi = ccall(:jl_specializations_get_linfo, Ref{MethodInstance}, (Any, Any, Any), m, mtypes, msp)
# Unsupported in this mini-cassette
@assert !mi.def.isva
code_info = retrieve_code_info(mi)
@assert isa(code_info, CodeInfo)
code_info = copy(code_info)
if isdefined(code_info, :edges)
code_info.edges = MethodInstance[mi]
end
transform!(code_info, length(args), msp)
code_info
end

@eval function overdub(c::Ctx, f, args...)
$(Expr(:meta, :generated_only))
$(Expr(:meta,
:generated,
Expr(:new,
Core.GeneratedFunctionStub,
:overdub_generator,
Any[:overdub, :ctx, :f, :args],
Any[],
@__LINE__,
QuoteNode(Symbol(@__FILE__)),
true)))
end
end

using .MiniCassette

# Test #265 for Cassette
f() = 1
@test overdub(Ctx(), f) === 1
f() = 2
@test overdub(Ctx(), f) === 2

# Test that MiniCassette is at least somewhat capable by overdubbing gcd
@test overdub(Ctx(), gcd, 10, 20) === gcd(10, 20)

0 comments on commit 1a859d6

Please sign in to comment.