From 1a859d6c9e87294b0113937a549a3df58bffde36 Mon Sep 17 00:00:00 2001 From: Keno Fischer Date: Mon, 10 Jun 2019 15:10:23 -0400 Subject: [PATCH] Add backedges for Cassette (#32237) Adds an extra field in `CodeInfo` that allows users (Cassette, et al.) to specify dependencies (as forward edges to `MethodInstance`s) that should be turned into backedges once the CodeInfo is passed over by inference. The test includes a minimal implementation of the Cassette mechansim to exercise this code path. --- base/compiler/inferencestate.jl | 5 +- base/compiler/typeinfer.jl | 9 ++ src/jltypes.c | 8 +- src/julia.h | 1 + src/method.c | 1 + stdlib/Serialization/src/Serialization.jl | 11 ++- test/choosetests.jl | 2 +- test/compiler/contextual.jl | 110 ++++++++++++++++++++++ 8 files changed, 139 insertions(+), 8 deletions(-) create mode 100644 test/compiler/contextual.jl diff --git a/base/compiler/inferencestate.jl b/base/compiler/inferencestate.jl index 2d6d2373e2361..4caf6b61ec810 100644 --- a/base/compiler/inferencestate.jl +++ b/base/compiler/inferencestate.jl @@ -87,8 +87,9 @@ mutable struct InferenceState inmodule = linfo.def::Module end - min_valid = UInt(1) - max_valid = get_world_counter() + min_valid = src.min_world + max_valid = src.max_world == typemax(UInt) ? + get_world_counter() : src.max_world frame = new( params, result, linfo, sp, slottypes, inmodule, 0, diff --git a/base/compiler/typeinfer.jl b/base/compiler/typeinfer.jl index 93851e5ba37bc..0d7ee43cf618c 100644 --- a/base/compiler/typeinfer.jl +++ b/base/compiler/typeinfer.jl @@ -191,6 +191,15 @@ function store_backedges(frame::InferenceState) end end end + edges = frame.src.edges + if edges !== nothing + edges = edges::Vector{MethodInstance} + for edge in edges + @assert isa(edge, MethodInstance) + ccall(:jl_method_instance_add_backedge, Cvoid, (Any, Any), edge, caller) + end + frame.src.edges = nothing + end end end diff --git a/src/jltypes.c b/src/jltypes.c index a1f9d7459825d..422fcc7af2f37 100644 --- a/src/jltypes.c +++ b/src/jltypes.c @@ -2070,7 +2070,7 @@ void jl_init_types(void) JL_GC_DISABLED jl_code_info_type = jl_new_datatype(jl_symbol("CodeInfo"), core, jl_any_type, jl_emptysvec, - jl_perm_symsvec(17, + jl_perm_symsvec(18, "code", "codelocs", "ssavaluetypes", @@ -2082,13 +2082,14 @@ void jl_init_types(void) JL_GC_DISABLED "slottypes", "rettype", "parent", + "edges", "min_world", "max_world", "inferred", "inlineable", "propagate_inbounds", "pure"), - jl_svec(17, + jl_svec(18, jl_array_any_type, jl_any_type, jl_any_type, @@ -2100,13 +2101,14 @@ void jl_init_types(void) JL_GC_DISABLED jl_any_type, jl_any_type, jl_any_type, + jl_any_type, jl_ulong_type, jl_ulong_type, jl_bool_type, jl_bool_type, jl_bool_type, jl_bool_type), - 0, 1, 17); + 0, 1, 18); jl_method_type = jl_new_datatype(jl_symbol("Method"), core, diff --git a/src/julia.h b/src/julia.h index e0c3028a08656..cb4a0dd3e877a 100644 --- a/src/julia.h +++ b/src/julia.h @@ -254,6 +254,7 @@ typedef struct _jl_code_info_t { jl_value_t *slottypes; // inferred types of slots jl_value_t *rettype; jl_method_instance_t *parent; // context (optionally, if available, otherwise nothing) + jl_value_t *edges; // forward edges to method instances that must be invalidated size_t min_world; size_t max_world; // various boolean properties: diff --git a/src/method.c b/src/method.c index 42021a9a73a7e..7f9189be0d152 100644 --- a/src/method.c +++ b/src/method.c @@ -320,6 +320,7 @@ JL_DLLEXPORT jl_code_info_t *jl_new_code_info_uninit(void) src->inlineable = 0; src->propagate_inbounds = 0; src->pure = 0; + src->edges = jl_nothing; return src; } diff --git a/stdlib/Serialization/src/Serialization.jl b/stdlib/Serialization/src/Serialization.jl index 8336874404d1f..05d54acc49bcb 100644 --- a/stdlib/Serialization/src/Serialization.jl +++ b/stdlib/Serialization/src/Serialization.jl @@ -1015,8 +1015,15 @@ function deserialize(s::AbstractSerializer, ::Type{CodeInfo}) ci.slottypes = deserialize(s) ci.rettype = deserialize(s) ci.parent = deserialize(s) - ci.min_world = reinterpret(UInt, deserialize(s)) - ci.max_world = reinterpret(UInt, deserialize(s)) + world_or_edges = deserialize(s) + pre_13 = isa(world_or_edges, Integer) + if pre_13 + ci.min_world = world_or_edges + else + ci.edges = world_or_edges + ci.min_world = reinterpret(UInt, deserialize(s)) + ci.max_world = reinterpret(UInt, deserialize(s)) + end end ci.inferred = deserialize(s) ci.inlineable = deserialize(s) diff --git a/test/choosetests.jl b/test/choosetests.jl index 7e641426d340d..50cd3d39311e6 100644 --- a/test/choosetests.jl +++ b/test/choosetests.jl @@ -108,7 +108,7 @@ function choosetests(choices = []) end compilertests = ["compiler/inference", "compiler/validation", "compiler/ssair", "compiler/irpasses", - "compiler/codegen", "compiler/inline"] + "compiler/codegen", "compiler/inline", "compiler/contextual"] if "compiler" in skip_tests filter!(x -> (x != "compiler" && !(x in compilertests)), tests) diff --git a/test/compiler/contextual.jl b/test/compiler/contextual.jl new file mode 100644 index 0000000000000..bbbdb1eb041fb --- /dev/null +++ b/test/compiler/contextual.jl @@ -0,0 +1,110 @@ +module MiniCassette + # A minimal demonstration of the cassette mechanism. Doesn't support all the + # fancy features, but sufficient to exercise this code path in the compiler. + + using Core.Compiler: method_instances, retrieve_code_info, CodeInfo, + MethodInstance, SSAValue, GotoNode, Slot, SlotNumber, quoted, + signature_type + using Base: _methods_by_ftype + using Base.Meta: isexpr + using Test + + export Ctx, overdub + + struct Ctx; end + + # A no-op cassette-like transform + function transform_expr(expr, map_slot_number, map_ssa_value, sparams) + transform(expr) = transform_expr(expr, map_slot_number, map_ssa_value, sparams) + if isexpr(expr, :call) + return Expr(:call, overdub, SlotNumber(2), map(transform, expr.args)...) + elseif isexpr(expr, :gotoifnot) + return Expr(:gotoifnot, transform(expr.args[1]), map_ssa_value(SSAValue(expr.args[2])).id) + elseif isexpr(expr, :static_parameter) + return quoted(sparams[expr.args[1]]) + elseif isa(expr, Expr) + return Expr(expr.head, map(transform, expr.args)...) + elseif isa(expr, GotoNode) + return GotoNode(map_ssa_value(SSAValue(expr.label)).id) + elseif isa(expr, Slot) + return map_slot_number(expr.id) + elseif isa(expr, SSAValue) + return map_ssa_value(expr) + else + return expr + end + end + + function transform!(ci, nargs, sparams) + code = ci.code + ci.slotnames = Symbol[Symbol("#self#"), :ctx, :f, :args, ci.slotnames[nargs+1:end]...] + ci.slotflags = UInt8[(0x00 for i = 1:4)..., ci.slotflags[nargs+1:end]...] + # Insert one SSAValue for every argument statement + prepend!(code, [Expr(:call, getfield, SlotNumber(4), i) for i = 1:nargs]) + prepend!(ci.codelocs, [0 for i = 1:nargs]) + ci.ssavaluetypes += nargs + function map_slot_number(slot) + if slot == 1 + # self in the original function is now `f` + return SlotNumber(3) + elseif 2 <= slot <= nargs + 1 + # Arguments get inserted as ssa values at the top of the function + return SSAValue(slot - 1) + else + # The first non-argument slot will be 5 + return SlotNumber(slot - (nargs + 1) + 4) + end + end + map_ssa_value(ssa::SSAValue) = SSAValue(ssa.id + nargs) + for i = (nargs+1:length(code)) + code[i] = transform_expr(code[i], map_slot_number, map_ssa_value, sparams) + end + end + + function overdub_generator(self, c, f, args) + if f <: Core.Builtin || !isdefined(f, :instance) + return :(return f(args...)) + end + + tt = Tuple{f, args...} + mthds = _methods_by_ftype(tt, -1, typemax(UInt)) + @assert length(mthds) == 1 + mtypes, msp, m = mthds[1] + mi = ccall(:jl_specializations_get_linfo, Ref{MethodInstance}, (Any, Any, Any), m, mtypes, msp) + # Unsupported in this mini-cassette + @assert !mi.def.isva + code_info = retrieve_code_info(mi) + @assert isa(code_info, CodeInfo) + code_info = copy(code_info) + if isdefined(code_info, :edges) + code_info.edges = MethodInstance[mi] + end + transform!(code_info, length(args), msp) + code_info + end + + @eval function overdub(c::Ctx, f, args...) + $(Expr(:meta, :generated_only)) + $(Expr(:meta, + :generated, + Expr(:new, + Core.GeneratedFunctionStub, + :overdub_generator, + Any[:overdub, :ctx, :f, :args], + Any[], + @__LINE__, + QuoteNode(Symbol(@__FILE__)), + true))) + end +end + +using .MiniCassette + +# Test #265 for Cassette +f() = 1 +@test overdub(Ctx(), f) === 1 +f() = 2 +@test overdub(Ctx(), f) === 2 + +# Test that MiniCassette is at least somewhat capable by overdubbing gcd +@test overdub(Ctx(), gcd, 10, 20) === gcd(10, 20)