diff --git a/base/inference.jl b/base/inference.jl index 3390e178a43a2..59651d25c0165 100644 --- a/base/inference.jl +++ b/base/inference.jl @@ -72,7 +72,6 @@ type InferenceState # whereas backedges is optimized for iteration edges::ObjectIdDict # a Dict{InferenceState, Vector{LineNum}} backedges::Vector{Tuple{InferenceState, Vector{LineNum}}} - li_edges::Vector{Any} # a Set{MethodInstance} # iteration fixed-point detection fixedpoint::Bool inworkq::Bool @@ -167,8 +166,8 @@ type InferenceState end if cached && !toplevel - min_valid = min_age(linfo) - max_valid = max_age(linfo) + min_valid = min_age(linfo.def) + max_valid = max_age(linfo.def) else min_valid = UInt(0) max_valid = UInt(0) @@ -181,7 +180,6 @@ type InferenceState ssavalue_uses, ssavalue_init, ObjectIdDict(), # Dict{InferenceState, Vector{LineNum}}(), Vector{Tuple{InferenceState, Vector{LineNum}}}(), - Vector{Any}(), # Set{MethodInstance}() false, false, optimize, inlining, cached, false, false) push!(active, frame) nactive[] += 1 @@ -211,26 +209,6 @@ function get_staged(li::MethodInstance) return src end -# TODO: track the worlds for which this InferenceState -# is being used, and split it if the WIP requires it -function update_valid_ages!(sv::InferenceState) - if isdefined(sv.linfo, :backedges) - min_valid = sv.min_valid - max_valid = sv.max_valid - for li in sv.li_edges - li = li::MethodInstance - min_valid = max(min_valid, min_age(li)) - max_valid = min(max_valid, max_age(li)) - end - sv.min_valid = min_valid - sv.max_valid = max_valid - end -end -function update_valid_age!(li::MethodInstance, sv::InferenceState) - sv.min_valid = max(sv.min_valid, min_age(li)) - sv.max_valid = min(sv.max_valid, max_age(li)) -end - #### current global inference state #### @@ -1543,15 +1521,45 @@ end inlining_enabled() = (JLOptions().can_inline == 1) coverage_enabled() = (JLOptions().code_coverage != 0) +# TODO: track the worlds for which this InferenceState +# is being used, and split it if the WIP requires it? +function converge_valid_age!(sv::InferenceState) + # push the validity range of sv into its fixedpoint callers + # recursing as needed to cover the graph + for (i, _) in sv.backedges + if i.fixedpoint + updated = false + if i.min_valid < sv.min_valid + i.min_valid = sv.min_valid + updated = true + end + if i.max_valid > sv.max_valid + i.max_valid = sv.max_valid + updated = true + end + if updated + converge_valid_age!(i) + end + end + end + nothing +end +function update_valid_age!(edge::InferenceState, sv::InferenceState) + sv.min_valid = max(edge.min_valid, sv.min_valid) + sv.max_valid = min(edge.max_valid, sv.max_valid) + nothing +end +function update_valid_age!(li::MethodInstance, sv::InferenceState) + sv.min_valid = max(sv.min_valid, min_age(li)) + sv.max_valid = min(sv.max_valid, max_age(li)) + nothing +end function add_backedge(li::MethodInstance, sv::InferenceState) caller = sv.linfo isdefined(caller, :def) || return # don't add backedges to toplevel exprs isdefined(li, :backedges) || (li.backedges = []) # lazy-init the backedges array in(caller, li.backedges) || push!(li.backedges, caller) # add a backedge from callee to caller update_valid_age!(li, sv) - if li.inInference - in(li, sv.li_edges) || push!(sv.li_edges, li) # add a forward edge from caller to callee - end nothing end function add_mt_backedge(mt::MethodTable, sv::InferenceState) @@ -1675,8 +1683,6 @@ function typeinf_edge(method::Method, atypes::ANY, sparams::SimpleVector, needtr end ccall(:jl_typeinf_begin, Void, ()) - # XXX: the following logic is likely subtly broken if code.code was nothing, - # although it seems unlikely something bad (infinite recursion) will happen as a result if linfo.inInference # inference on this signature may be in progress, # find the corresponding frame in the active list @@ -1691,7 +1697,7 @@ function typeinf_edge(method::Method, atypes::ANY, sparams::SimpleVector, needtr # TODO: this assertion seems iffy assert(frame !== nothing) else - # TODO: verify again here that linfo wasn't just inferred + # XXX: the following logic needs to repeat the test for linfo.inferred now that it hold the lock # inference not started yet, make a new frame for a new lambda if method.isstaged try @@ -1712,6 +1718,7 @@ function typeinf_edge(method::Method, atypes::ANY, sparams::SimpleVector, needtr # if we were called from inside inference, the caller will be the InferenceState object # for which the edge was required caller = caller::InferenceState + update_valid_age!(frame, caller) if !caller.inferred if haskey(caller.edges, frame) Ws = caller.edges[frame]::Vector{Int} @@ -1727,7 +1734,6 @@ function typeinf_edge(method::Method, atypes::ANY, sparams::SimpleVector, needtr end end typeinf_loop(frame) - isa(caller, InferenceState) && add_backedge(linfo, caller::InferenceState) ccall(:jl_typeinf_end, Void, ()) return (frame.src, widenconst(frame.bestguess), frame.inferred && frame.optimized) end @@ -1804,8 +1810,12 @@ function typeinf_loop(frame) i.inworkq = true end end + for i in fplist + # push valid ages from each node across the graph + converge_valid_age!(i::InferenceState) + end for i in length(fplist):-1:1 - finish(fplist[i]) # this may add incomplete work to active + finish(fplist[i]::InferenceState) # this may add incomplete work to active end end end @@ -1972,8 +1982,8 @@ function typeinf_frame(frame) else # fixedpoint propagation for (i, _) in frame.edges i = i::InferenceState - update_valid_ages!(i) # converge age at the same time if !i.fixedpoint + update_valid_age!(i, frame) # work towards converging age at the same time if !i.inworkq push!(workq, i) i.inworkq = true @@ -1994,7 +2004,7 @@ function unmark_fixedpoint(frame::InferenceState) # based upon (recursively) assuming that frame was stuck if frame.fixedpoint frame.fixedpoint = false - for (i,_) in frame.backedges + for (i, _) in frame.backedges unmark_fixedpoint(i) end end @@ -2028,8 +2038,9 @@ end # inference completed on `me` # update the MethodInstance and notify the edges function finish(me::InferenceState) - for (i,_) in me.edges - @assert (i::InferenceState).fixedpoint + for (i, _) in me.edges + i = i::InferenceState + @assert i.fixedpoint end # below may call back into inference and # see this InferenceState is in an incomplete state @@ -2119,7 +2130,6 @@ function finish(me::InferenceState) inferred.code = ccall(:jl_compress_ast, Any, (Any, Any), me.linfo.def, inferred.code) end end - update_valid_ages!(me) ccall(:jl_specialization_set_world, Void, (Any, UInt, UInt), me.linfo, me.min_valid, me.max_valid) end @@ -2127,25 +2137,26 @@ function finish(me::InferenceState) ccall(:jl_set_lambda_rettype, Void, (Any, Any, Any, Any), me.linfo, widenconst(me.bestguess), const_api, inferred) end - me.src.inferred = true - me.linfo.inInference = false - # finalize and record the linfo result - me.optimized = true - # lazy-delete the item from active for several reasons: # efficiency, correctness, and recursion-safety nactive[] -= 1 active[findlast(active, me)] = nothing # update all of the callers by traversing the backedges - for (i,_) in me.backedges + for (i, _) in me.backedges if !me.fixedpoint || !i.fixedpoint # wake up each backedge, unless both me and it already reached a fixed-point (cycle resolution stage) delete!(i.edges, me) i.inworkq || push!(workq, i) i.inworkq = true end + add_backedge(me.linfo, i) # add the real backedge now end + + # finalize and record the linfo result + me.src.inferred = true + me.linfo.inInference = false + me.optimized = true nothing end diff --git a/src/gf.c b/src/gf.c index 312866fe41741..20519cc416d7e 100644 --- a/src/gf.c +++ b/src/gf.c @@ -141,8 +141,12 @@ JL_DLLEXPORT jl_method_instance_t *jl_specializations_get_linfo(jl_method_t *m, JL_GC_PUSH1(&li); // TODO: fuse lookup and insert steps assert(world >= m->min_world && world <= m->max_world); + size_t max_world = (jl_world_counter == world ? m->max_world : world); + li->min_world = world; + li->max_world = max_world; + assert(world <= max_world); jl_typemap_insert(&m->specializations, (jl_value_t*)m, type, jl_emptysvec, NULL, jl_emptysvec, (jl_value_t*)li, 0, &tfunc_cache, - world, world, NULL); + world, max_world, NULL); JL_UNLOCK(&m->writelock); JL_GC_POP(); return li; @@ -152,11 +156,7 @@ static int update_valid_world(jl_typemap_entry_t *entry, void *closure) { if (entry->func.value == closure) { jl_method_instance_t *li = (jl_method_instance_t*)closure; - // these asserts are based on how it is used currently from inference.jl - // they aren't really correct in general (well, in general, this method will corrupt the system state) - assert(entry->min_world == entry->max_world); - assert(li->min_world <= entry->min_world); - assert(li->max_world >= entry->max_world); + assert(li->min_world <= entry->min_world); // entry min_world should have been the same as li min_world before jl_specialization_set_world entry->min_world = li->min_world; entry->max_world = li->max_world; } @@ -165,12 +165,15 @@ static int update_valid_world(jl_typemap_entry_t *entry, void *closure) JL_DLLEXPORT void jl_specialization_set_world(jl_method_instance_t *li, size_t min_world, size_t max_world) { assert(min_world <= max_world); - assert(li->min_world <= min_world); - assert(li->max_world >= max_world); + // this assert is based on how it is used currently from inference.jl + // it is really not correct in general (well, in general, this method will corrupt the system state) + assert(min_world <= li->min_world); // may widen lower bound, but don't expect to be narrowing it + // could be widening or narrowing upper bound li->min_world = min_world; li->max_world = max_world; - if (li->def != NULL) + if (li->def != NULL) { jl_typemap_visitor(li->def->specializations, update_valid_world, (void*)li); + } } JL_DLLEXPORT jl_value_t *jl_specializations_lookup(jl_method_t *m, jl_tupletype_t *type, size_t world) @@ -854,7 +857,7 @@ static jl_method_instance_t *cache_method(jl_methtable_t *mt, union jl_typemap_t } jl_typemap_insert(cache, parent, origtype, jl_emptysvec, type, guardsigs, (jl_value_t*)newmeth, jl_cachearg_offset(mt), &lambda_cache, - m->min_world, m->max_world, NULL); + newmeth->min_world, newmeth->max_world, NULL); if (definition->traced && jl_method_tracer && allow_exec) jl_call_tracer(jl_method_tracer, (jl_value_t*)newmeth);