Skip to content

Commit

Permalink
improve handling of world age transitions
Browse files Browse the repository at this point in the history
  • Loading branch information
vtjnash committed Sep 13, 2016
1 parent bc729aa commit 0490a9e
Show file tree
Hide file tree
Showing 2 changed files with 67 additions and 53 deletions.
97 changes: 54 additions & 43 deletions base/inference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,6 @@ type InferenceState
# whereas backedges is optimized for iteration
edges::ObjectIdDict # a Dict{InferenceState, Vector{LineNum}}
backedges::Vector{Tuple{InferenceState, Vector{LineNum}}}
li_edges::Vector{Any} # a Set{MethodInstance}
# iteration fixed-point detection
fixedpoint::Bool
inworkq::Bool
Expand Down Expand Up @@ -167,8 +166,8 @@ type InferenceState
end

if cached && !toplevel
min_valid = min_age(linfo)
max_valid = max_age(linfo)
min_valid = min_age(linfo.def)
max_valid = max_age(linfo.def)
else
min_valid = UInt(0)
max_valid = UInt(0)
Expand All @@ -181,7 +180,6 @@ type InferenceState
ssavalue_uses, ssavalue_init,
ObjectIdDict(), # Dict{InferenceState, Vector{LineNum}}(),
Vector{Tuple{InferenceState, Vector{LineNum}}}(),
Vector{Any}(), # Set{MethodInstance}()
false, false, optimize, inlining, cached, false, false)
push!(active, frame)
nactive[] += 1
Expand Down Expand Up @@ -211,26 +209,6 @@ function get_staged(li::MethodInstance)
return src
end

# TODO: track the worlds for which this InferenceState
# is being used, and split it if the WIP requires it
function update_valid_ages!(sv::InferenceState)
if isdefined(sv.linfo, :backedges)
min_valid = sv.min_valid
max_valid = sv.max_valid
for li in sv.li_edges
li = li::MethodInstance
min_valid = max(min_valid, min_age(li))
max_valid = min(max_valid, max_age(li))
end
sv.min_valid = min_valid
sv.max_valid = max_valid
end
end
function update_valid_age!(li::MethodInstance, sv::InferenceState)
sv.min_valid = max(sv.min_valid, min_age(li))
sv.max_valid = min(sv.max_valid, max_age(li))
end


#### current global inference state ####

Expand Down Expand Up @@ -1543,15 +1521,45 @@ end
inlining_enabled() = (JLOptions().can_inline == 1)
coverage_enabled() = (JLOptions().code_coverage != 0)

# TODO: track the worlds for which this InferenceState
# is being used, and split it if the WIP requires it?
function converge_valid_age!(sv::InferenceState)
# push the validity range of sv into its fixedpoint callers
# recursing as needed to cover the graph
for (i, _) in sv.backedges
if i.fixedpoint
updated = false
if i.min_valid < sv.min_valid
i.min_valid = sv.min_valid
updated = true
end
if i.max_valid > sv.max_valid
i.max_valid = sv.max_valid
updated = true
end
if updated
converge_valid_age!(i)
end
end
end
nothing
end
function update_valid_age!(edge::InferenceState, sv::InferenceState)
sv.min_valid = max(edge.min_valid, sv.min_valid)
sv.max_valid = min(edge.max_valid, sv.max_valid)
nothing
end
function update_valid_age!(li::MethodInstance, sv::InferenceState)
sv.min_valid = max(sv.min_valid, min_age(li))
sv.max_valid = min(sv.max_valid, max_age(li))
nothing
end
function add_backedge(li::MethodInstance, sv::InferenceState)
caller = sv.linfo
isdefined(caller, :def) || return # don't add backedges to toplevel exprs
isdefined(li, :backedges) || (li.backedges = []) # lazy-init the backedges array
in(caller, li.backedges) || push!(li.backedges, caller) # add a backedge from callee to caller
update_valid_age!(li, sv)
if li.inInference
in(li, sv.li_edges) || push!(sv.li_edges, li) # add a forward edge from caller to callee
end
nothing
end
function add_mt_backedge(mt::MethodTable, sv::InferenceState)
Expand Down Expand Up @@ -1675,8 +1683,6 @@ function typeinf_edge(method::Method, atypes::ANY, sparams::SimpleVector, needtr
end

ccall(:jl_typeinf_begin, Void, ())
# XXX: the following logic is likely subtly broken if code.code was nothing,
# although it seems unlikely something bad (infinite recursion) will happen as a result
if linfo.inInference
# inference on this signature may be in progress,
# find the corresponding frame in the active list
Expand All @@ -1691,7 +1697,7 @@ function typeinf_edge(method::Method, atypes::ANY, sparams::SimpleVector, needtr
# TODO: this assertion seems iffy
assert(frame !== nothing)
else
# TODO: verify again here that linfo wasn't just inferred
# XXX: the following logic needs to repeat the test for linfo.inferred now that it hold the lock
# inference not started yet, make a new frame for a new lambda
if method.isstaged
try
Expand All @@ -1712,6 +1718,7 @@ function typeinf_edge(method::Method, atypes::ANY, sparams::SimpleVector, needtr
# if we were called from inside inference, the caller will be the InferenceState object
# for which the edge was required
caller = caller::InferenceState
update_valid_age!(frame, caller)
if !caller.inferred
if haskey(caller.edges, frame)
Ws = caller.edges[frame]::Vector{Int}
Expand All @@ -1727,7 +1734,6 @@ function typeinf_edge(method::Method, atypes::ANY, sparams::SimpleVector, needtr
end
end
typeinf_loop(frame)
isa(caller, InferenceState) && add_backedge(linfo, caller::InferenceState)
ccall(:jl_typeinf_end, Void, ())
return (frame.src, widenconst(frame.bestguess), frame.inferred && frame.optimized)
end
Expand Down Expand Up @@ -1804,8 +1810,12 @@ function typeinf_loop(frame)
i.inworkq = true
end
end
for i in fplist
# push valid ages from each node across the graph
converge_valid_age!(i::InferenceState)
end
for i in length(fplist):-1:1
finish(fplist[i]) # this may add incomplete work to active
finish(fplist[i]::InferenceState) # this may add incomplete work to active
end
end
end
Expand Down Expand Up @@ -1972,8 +1982,8 @@ function typeinf_frame(frame)
else # fixedpoint propagation
for (i, _) in frame.edges
i = i::InferenceState
update_valid_ages!(i) # converge age at the same time
if !i.fixedpoint
update_valid_age!(i, frame) # work towards converging age at the same time
if !i.inworkq
push!(workq, i)
i.inworkq = true
Expand All @@ -1994,7 +2004,7 @@ function unmark_fixedpoint(frame::InferenceState)
# based upon (recursively) assuming that frame was stuck
if frame.fixedpoint
frame.fixedpoint = false
for (i,_) in frame.backedges
for (i, _) in frame.backedges
unmark_fixedpoint(i)
end
end
Expand Down Expand Up @@ -2028,8 +2038,9 @@ end
# inference completed on `me`
# update the MethodInstance and notify the edges
function finish(me::InferenceState)
for (i,_) in me.edges
@assert (i::InferenceState).fixedpoint
for (i, _) in me.edges
i = i::InferenceState
@assert i.fixedpoint
end
# below may call back into inference and
# see this InferenceState is in an incomplete state
Expand Down Expand Up @@ -2119,33 +2130,33 @@ function finish(me::InferenceState)
inferred.code = ccall(:jl_compress_ast, Any, (Any, Any), me.linfo.def, inferred.code)
end
end
update_valid_ages!(me)
ccall(:jl_specialization_set_world, Void, (Any, UInt, UInt),
me.linfo, me.min_valid, me.max_valid)
end
# TODO: check that mutating the lambda info is OK first?
ccall(:jl_set_lambda_rettype, Void, (Any, Any, Any, Any), me.linfo, widenconst(me.bestguess), const_api, inferred)
end

me.src.inferred = true
me.linfo.inInference = false
# finalize and record the linfo result
me.optimized = true

# lazy-delete the item from active for several reasons:
# efficiency, correctness, and recursion-safety
nactive[] -= 1
active[findlast(active, me)] = nothing

# update all of the callers by traversing the backedges
for (i,_) in me.backedges
for (i, _) in me.backedges
if !me.fixedpoint || !i.fixedpoint
# wake up each backedge, unless both me and it already reached a fixed-point (cycle resolution stage)
delete!(i.edges, me)
i.inworkq || push!(workq, i)
i.inworkq = true
end
add_backedge(me.linfo, i) # add the real backedge now
end

# finalize and record the linfo result
me.src.inferred = true
me.linfo.inInference = false
me.optimized = true
nothing
end

Expand Down
23 changes: 13 additions & 10 deletions src/gf.c
Original file line number Diff line number Diff line change
Expand Up @@ -141,8 +141,12 @@ JL_DLLEXPORT jl_method_instance_t *jl_specializations_get_linfo(jl_method_t *m,
JL_GC_PUSH1(&li);
// TODO: fuse lookup and insert steps
assert(world >= m->min_world && world <= m->max_world);
size_t max_world = (jl_world_counter == world ? m->max_world : world);
li->min_world = world;
li->max_world = max_world;
assert(world <= max_world);
jl_typemap_insert(&m->specializations, (jl_value_t*)m, type, jl_emptysvec, NULL, jl_emptysvec, (jl_value_t*)li, 0, &tfunc_cache,
world, world, NULL);
world, max_world, NULL);
JL_UNLOCK(&m->writelock);
JL_GC_POP();
return li;
Expand All @@ -152,11 +156,7 @@ static int update_valid_world(jl_typemap_entry_t *entry, void *closure)
{
if (entry->func.value == closure) {
jl_method_instance_t *li = (jl_method_instance_t*)closure;
// these asserts are based on how it is used currently from inference.jl
// they aren't really correct in general (well, in general, this method will corrupt the system state)
assert(entry->min_world == entry->max_world);
assert(li->min_world <= entry->min_world);
assert(li->max_world >= entry->max_world);
assert(li->min_world <= entry->min_world); // entry min_world should have been the same as li min_world before jl_specialization_set_world
entry->min_world = li->min_world;
entry->max_world = li->max_world;
}
Expand All @@ -165,12 +165,15 @@ static int update_valid_world(jl_typemap_entry_t *entry, void *closure)
JL_DLLEXPORT void jl_specialization_set_world(jl_method_instance_t *li, size_t min_world, size_t max_world)
{
assert(min_world <= max_world);
assert(li->min_world <= min_world);
assert(li->max_world >= max_world);
// this assert is based on how it is used currently from inference.jl
// it is really not correct in general (well, in general, this method will corrupt the system state)
assert(min_world <= li->min_world); // may widen lower bound, but don't expect to be narrowing it
// could be widening or narrowing upper bound
li->min_world = min_world;
li->max_world = max_world;
if (li->def != NULL)
if (li->def != NULL) {
jl_typemap_visitor(li->def->specializations, update_valid_world, (void*)li);
}
}

JL_DLLEXPORT jl_value_t *jl_specializations_lookup(jl_method_t *m, jl_tupletype_t *type, size_t world)
Expand Down Expand Up @@ -854,7 +857,7 @@ static jl_method_instance_t *cache_method(jl_methtable_t *mt, union jl_typemap_t
}

jl_typemap_insert(cache, parent, origtype, jl_emptysvec, type, guardsigs, (jl_value_t*)newmeth, jl_cachearg_offset(mt), &lambda_cache,
m->min_world, m->max_world, NULL);
newmeth->min_world, newmeth->max_world, NULL);

if (definition->traced && jl_method_tracer && allow_exec)
jl_call_tracer(jl_method_tracer, (jl_value_t*)newmeth);
Expand Down

0 comments on commit 0490a9e

Please sign in to comment.