Skip to content

Commit

Permalink
merge specializations and tfunc arrays
Browse files Browse the repository at this point in the history
  • Loading branch information
vtjnash committed Jun 10, 2016
1 parent 466da65 commit 1ec5092
Show file tree
Hide file tree
Showing 7 changed files with 121 additions and 125 deletions.
117 changes: 68 additions & 49 deletions base/inference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,6 @@ immutable Const
end

type InferenceState
atypes #::Type # type sig
sp::SimpleVector # static parameters
label_counter::Int # index of the current highest label for this function
fedbackvars::Dict{SSAValue, Bool}
Expand Down Expand Up @@ -78,6 +77,7 @@ type InferenceState

function InferenceState(linfo::LambdaInfo, atypes::ANY, sparams::SimpleVector, optimize::Bool)
@assert isa(linfo.code,Array{Any,1})
linfo.inInference = true
nslots = length(linfo.slotnames)
nl = label_counter(linfo.code)+1

Expand Down Expand Up @@ -156,7 +156,7 @@ type InferenceState

inmodule = isdefined(linfo, :def) ? linfo.def.module : current_module() # toplevel thunks are inferred in the current module
frame = new(
atypes, sp, nl, Dict{SSAValue, Bool}(), inmodule, 0, false,
sp, nl, Dict{SSAValue, Bool}(), inmodule, 0, false,
linfo, linfo, la, s, Union{}, W, n,
cur_hand, handler_at, n_handlers,
ssavalue_uses, ssavalue_init,
Expand Down Expand Up @@ -768,7 +768,7 @@ function abstract_call_gf_by_type(f::ANY, argtype::ANY, sv)
limitlength = false
for (callee, _) in sv.edges
callee = callee::InferenceState
if method === callee.linfo.def && ls > length(callee.atypes.parameters)
if method === callee.linfo.def && ls > length(callee.linfo.specTypes.parameters)
limitlength = true
break
end
Expand All @@ -781,16 +781,16 @@ function abstract_call_gf_by_type(f::ANY, argtype::ANY, sv)
infstate = infstate::InferenceState
if isdefined(infstate.linfo, :def) && method === infstate.linfo.def
td = type_depth(sig)
if ls > length(infstate.atypes.parameters)
if ls > length(infstate.linfo.specTypes.parameters)
limitlength = true
end
if td > type_depth(infstate.atypes)
if td > type_depth(infstate.linfo.specTypes)
# impose limit if we recur and the argument types grow beyond MAX_TYPE_DEPTH
if td > MAX_TYPE_DEPTH
sig = limit_type_depth(sig, 0, true, [])
break
else
p1, p2 = sig.parameters, infstate.atypes.parameters
p1, p2 = sig.parameters, infstate.linfo.specTypes.parameters
if length(p2) == ls
limitdepth = false
newsig = Array{Any}(ls)
Expand Down Expand Up @@ -1117,7 +1117,7 @@ function abstract_eval(e::ANY, vtypes::VarTable, sv::InferenceState)
end
elseif isleaftype(t)
t = Type{t}
elseif isleaftype(sv.atypes)
elseif isleaftype(sv.linfo.specTypes)
if isa(t,TypeVar)
t = Type{t.ub}
else
Expand Down Expand Up @@ -1386,7 +1386,8 @@ end

# create a specialized LambdaInfo from a method
function specialize_method(method::Method, types::ANY, sp::SimpleVector)
li = ccall(:jl_get_specialized, Any, (Any, Any, Any), method, types, sp)::LambdaInfo
li = ccall(:jl_get_specialized, Ref{LambdaInfo}, (Any, Any, Any), method, types, sp)
return li
end

# create copies of any field that type-inference might modify
Expand All @@ -1408,38 +1409,54 @@ function unshare_linfo!(li::LambdaInfo)
end

#### entry points for inferring a LambdaInfo given a type signature ####

function typeinf_edge(method::Method, atypes::ANY, sparams::SimpleVector, needtree::Bool, optimize::Bool, cached::Bool, caller)
local code = nothing
local frame = nothing
offs = 0
# check cached t-functions
# aggregate all saved type inference data there
if cached && !is(method.tfunc, nothing)
code = ccall(:jl_tfunc_cache_lookup, Any, (Any, Any, Int8), method, atypes, offs)
if isa(code, InferenceState)
# inference on this signature is in progress
frame = code
if isa(caller, LambdaInfo)
# record the LambdaInfo where this result should be cached when it is finished
@assert frame.destination === frame.linfo || frame.destination === caller
frame.destination = caller
end
elseif isa(code, Type)
# sometimes just a return type is stored here. if a full AST
# is not needed, we can return it.
if !needtree
return (nothing, code, true)
end
elseif isa(code,LambdaInfo)
@assert code.inferred
return (code, code.rettype, true)
else
# otherwise this is an InferenceState from a different bootstrap stage's
# copy of the inference code; ignore it.
# check cached specializations
# for an existing result stored there
if cached
if !is(method.specializations, nothing)
code = ccall(:jl_specializations_lookup, Any, (Any, Any), method, atypes)
if isa(code, Void)
# something completely new
elseif isa(code, LambdaInfo)
# something existing
if code.inferred
return (code, code.rettype, true)
end
else
# sometimes just a return type is stored here. if a full AST
# is not needed, we can return it.
typeassert(code, Type)
if !needtree
return (nothing, code, true)
end
code = nothing
end
end

if isa(code, LambdaInfo) && code.inInference
# inference on this signature may be in progress,
# find the corresponding frame in the active list
for infstate in active
infstate === nothing && continue
infstate = infstate::InferenceState
if code === infstate.linfo
frame = infstate
break
end
end
end
end

if isa(caller, LambdaInfo)
code = caller
end

if frame === nothing
# inference not started yet, make a new frame for a new lambda
# add lam to be inferred and record the edge

if caller === nothing && needtree && in_typeinf_loop
# if the caller needed the ast, but we are already in the typeinf loop
# then just return early -- we can't fulfill this request
Expand Down Expand Up @@ -1468,9 +1485,11 @@ function typeinf_edge(method::Method, atypes::ANY, sparams::SimpleVector, needtr
return (nothing, Union{}, false)
end
end
# add lam to be inferred and record the edge
if isa(caller, LambdaInfo)
linfo = caller

if isa(code, LambdaInfo)
# reuse the existing code object
linfo = code
@assert typeseq(linfo.specTypes, atypes)
elseif method.isstaged
if !isleaftype(atypes)
# don't call staged functions on abstract types.
Expand All @@ -1488,16 +1507,18 @@ function typeinf_edge(method::Method, atypes::ANY, sparams::SimpleVector, needtr
linfo = specialize_method(method, atypes, sparams)
end
# our stack frame inference context
frame = InferenceState(unshare_linfo!(linfo), atypes, sparams, optimize)

frame = InferenceState(unshare_linfo!(linfo::LambdaInfo), atypes, sparams, optimize)
if cached
tfunc_bp = ccall(:jl_tfunc_cache_insert, Ref{TypeMapEntry}, (Any, Any, Any, Int8), method, atypes, frame, offs)
frame.tfunc_bp = tfunc_bp
frame.tfunc_bp = ccall(:jl_specializations_insert, Ref{TypeMapEntry}, (Any, Any, Any), method, atypes, linfo)
end
end
frame = frame::InferenceState

if !isa(caller, Void) && !isa(caller, LambdaInfo)
@assert isa(caller, InferenceState)
# if we were called from inside inference,
# the caller will be the InferenceState object
# for which the edge was required
caller = caller::InferenceState
if haskey(caller.edges, frame)
Ws = caller.edges[frame]::Vector{Int}
if !(caller.currpc in Ws)
Expand Down Expand Up @@ -1869,10 +1890,6 @@ end
# inference completed on `me`
# update the LambdaInfo and notify the edges
function finish(me::InferenceState)
# lazy-delete the item from active for several reasons:
# efficiency, correctness, and recursion-safety
nactive[] -= 1
active[findlast(active, me)] = nothing
for (i,_) in me.edges
@assert (i::InferenceState).fixedpoint
end
Expand Down Expand Up @@ -1941,9 +1958,11 @@ function finish(me::InferenceState)
out.pure = me.linfo.pure
out.inlineable = me.linfo.inlineable
end
if me.tfunc_bp !== nothing
me.tfunc_bp.func = me.linfo
end

# lazy-delete the item from active for several reasons:
# efficiency, correctness, and recursion-safety
nactive[] -= 1
active[findlast(active, me)] = nothing

# update all of the callers by traversing the backedges
for (i,_) in me.backedges
Expand Down
3 changes: 1 addition & 2 deletions src/alloc.c
Original file line number Diff line number Diff line change
Expand Up @@ -566,14 +566,13 @@ JL_DLLEXPORT jl_method_t *jl_new_method_uninit(void)
jl_method_t *m =
(jl_method_t*)newobj((jl_value_t*)jl_method_type,
NWORDS(sizeof(jl_method_t)));
m->tfunc.unknown = jl_nothing;
m->specializations.unknown = jl_nothing;
m->sig = NULL;
m->tvars = NULL;
m->ambig = NULL;
m->roots = NULL;
m->module = jl_current_module;
m->lambda_template = NULL;
m->specializations = NULL;
m->name = NULL;
m->file = null_sym;
m->line = 0;
Expand Down
15 changes: 6 additions & 9 deletions src/dump.c
Original file line number Diff line number Diff line change
Expand Up @@ -725,7 +725,7 @@ static void jl_serialize_value_(ios_t *s, jl_value_t *v)
arraylist_push(&reinit_list, (void*)pos);
arraylist_push(&reinit_list, (void*)3);
}
if (jl_is_method(v) && jl_typeof(((jl_method_t*)v)->tfunc.unknown) == (jl_value_t*)jl_typemap_level_type) {
if (jl_is_method(v) && jl_typeof(((jl_method_t*)v)->specializations.unknown) == (jl_value_t*)jl_typemap_level_type) {
arraylist_push(&reinit_list, (void*)pos);
arraylist_push(&reinit_list, (void*)4);
}
Expand Down Expand Up @@ -835,7 +835,7 @@ static void jl_serialize_value_(ios_t *s, jl_value_t *v)
else if (jl_is_method(v)) {
writetag(s, jl_method_type);
jl_method_t *m = (jl_method_t*)v;
union jl_typemap_t *tf = &m->tfunc;
union jl_typemap_t *tf = &m->specializations;
if (tf->unknown && tf->unknown != jl_nothing) {
// go through the t-func cache, replacing ASTs with just return
// types for abstract argument types. these ASTs are generally
Expand All @@ -844,7 +844,6 @@ static void jl_serialize_value_(ios_t *s, jl_value_t *v)
}
jl_serialize_value(s, tf->unknown);
jl_serialize_value(s, (jl_value_t*)m->name);
jl_serialize_value(s, (jl_value_t*)m->specializations);
write_int8(s, m->isstaged);
jl_serialize_value(s, (jl_value_t*)m->file);
write_int32(s, m->line);
Expand Down Expand Up @@ -1445,12 +1444,10 @@ static jl_value_t *jl_deserialize_value_(ios_t *s, jl_value_t *vtag, jl_value_t
NWORDS(sizeof(jl_method_t)));
if (usetable)
arraylist_push(&backref_list, m);
m->tfunc.unknown = jl_deserialize_value(s, (jl_value_t**)&m->tfunc);
jl_gc_wb(m, m->tfunc.unknown);
m->specializations.unknown = jl_deserialize_value(s, (jl_value_t**)&m->specializations);
jl_gc_wb(m, m->specializations.unknown);
m->name = (jl_sym_t*)jl_deserialize_value(s, NULL);
jl_gc_wb(m, m->name);
m->specializations = (jl_array_t*)jl_deserialize_value(s, (jl_value_t**)&m->specializations);
if (m->specializations) jl_gc_wb(m, m->specializations);
m->isstaged = read_int8(s);
m->file = (jl_sym_t*)jl_deserialize_value(s, NULL);
m->line = read_int32(s);
Expand Down Expand Up @@ -1834,9 +1831,9 @@ static void jl_reinit_item(ios_t *f, jl_value_t *v, int how, arraylist_t *tracee
arraylist_push(tracee_list, mt);
break;
}
case 4: { // rehash tfunc
case 4: { // rehash specializations tfunc
jl_method_t *m = (jl_method_t*)v;
jl_typemap_rehash(m->tfunc, 0);
jl_typemap_rehash(m->specializations, 0);
break;
}
default:
Expand Down
Loading

0 comments on commit 1ec5092

Please sign in to comment.