From 28922e0857a2fcc8ef53ffbfebb76cc011f9000b Mon Sep 17 00:00:00 2001
From: Yichao Yu
Date: Sun, 11 Sep 2016 07:45:18 -0400
Subject: [PATCH] Inline `invoke` (take 3)

---
 base/inference.jl | 188 ++++++++++++++++++++++++++++++++++++++--------
 src/gf.c          |  51 ++++++++++++-
 2 files changed, 206 insertions(+), 33 deletions(-)

diff --git a/base/inference.jl b/base/inference.jl
index 529ec03b2d5ab..e44b60c518a03 100644
--- a/base/inference.jl
+++ b/base/inference.jl
@@ -632,9 +632,9 @@ function invoke_tfunc(f::ANY, types::ANY, argtype::ANY, sv::InferenceState)
     if is(entry, nothing)
         return Any
     end
-    meth = entry.func
-    (ti, env) = ccall(:jl_match_method, Any, (Any, Any, Any),
-                      argtype, meth.sig, meth.tvars)::SimpleVector
+    meth = (entry::TypeMapEntry).func
+    (ti, env) = ccall(:jl_match_method, Ref{SimpleVector}, (Any, Any, Any),
+                      argtype, meth.sig, meth.tvars)
     return typeinf_edge(meth::Method, ti, env, sv)[2]
 end
 
@@ -2335,10 +2335,25 @@ end
 
 #### post-inference optimizations ####
 
-function inline_as_constant(val::ANY, argexprs, linfo::LambdaInfo)
+function inline_as_constant(val::ANY, argexprs, linfo::LambdaInfo,
+                            invoke_data::ANY)
     # check if any arguments aren't effect_free and need to be kept around
-    stmts = Any[]
-    for i = 1:length(argexprs)
+    starti = 1  # first index of argexprs still to be scanned for side effects
+    stmts = if invoke_data === nothing
+        []
+    else
+        invoke_data = invoke_data::InvokeData
+        if invoke_data.texpr === nothing
+            Any[invoke_data.fexpr]
+        else
+            starti = 2  # argexprs[1] is placed explicitly between fexpr and texpr
+            arg0 = argexprs[1]
+            Any[invoke_data.fexpr,
+                effect_free(arg0, linfo, false) ? nothing : arg0,
+                invoke_data.texpr]
+        end
+    end
+    for i = starti:length(argexprs)
         arg = argexprs[i]
         if !effect_free(arg, linfo, false)
             push!(stmts, arg)
@@ -2357,8 +2372,28 @@ function countunionsplit(atypes::Vector{Any})
     return nu
 end
 
+immutable InvokeData
+    mt::MethodTable  # method table taken from the invoked function's type (ft.name.mt)
+    entry::TypeMapEntry  # result of jl_gf_invoke_lookup for types0
+    types0  # Tuple{ft, invoke signature parameters...}
+    fexpr  # expression of `invoke` itself, or nothing when it is effect-free
+    texpr  # expression of the types argument, or nothing when it is effect-free
+end
+
+function get_spec_lambda(atypes::ANY, invoke_data::ANY)
+    if invoke_data === nothing
+        return ccall(:jl_get_spec_lambda, Any, (Any,), atypes)
+    else
+        invoke_data = invoke_data::InvokeData
+        # TODO compute intersection and throw an error
+        atypes <: invoke_data.types0 || return nothing
+        return ccall(:jl_get_invoke_lambda, Any, (Any, Any, Any),
+                     invoke_data.mt, invoke_data.entry, atypes)
+    end
+end
+
 function invoke_NF(argexprs, etype::ANY, atypes, sv, enclosing,
-                   atype_unlimited::ANY)
+                   atype_unlimited::ANY, invoke_data::ANY)
     # converts a :call to :invoke
     nu = countunionsplit(atypes)
     nu > MAX_UNION_SPLITTING && return NF
@@ -2373,6 +2408,7 @@ function invoke_NF(argexprs, etype::ANY, atypes, sv, enclosing,
         ex.typ = etype
         stmts = []
         arg_hoisted = false
+        arg0_hoisted = false  # whether argument 1 was hoisted into stmts
         for i = length(atypes):-1:1
             ti = atypes[i]
             if arg_hoisted || isa(ti, Union)
@@ -2382,13 +2418,21 @@ function invoke_NF(argexprs, etype::ANY, atypes, sv, enclosing,
                     newvar = newvar!(sv, ti)
                     insert!(stmts, 1, :($newvar = $aei))
                     ex.args[i] = newvar
+                    if i == 1
+                        arg0_hoisted = true
+                    end
                 end
             end
         end
+        if !(invoke_data === nothing)
+            invoke_data = invoke_data::InvokeData
+            insert!(stmts, 1, invoke_data.fexpr)
+            insert!(stmts, 2 + arg0_hoisted, invoke_data.texpr)  # after fexpr and a hoisted argument 1
+        end
         function splitunion(atypes::Vector{Any}, i::Int)
             if i == 0
                 local sig = argtypes_to_type(atypes)
-                local li = ccall(:jl_get_spec_lambda, Any, (Any,), sig)
+                local li = get_spec_lambda(sig, invoke_data)
                 li === nothing && return false
                 local stmt = []
                 push!(stmt, Expr(:(=), linfo_var, li))
@@ -2455,13 +2499,30 @@ function invoke_NF(argexprs, etype::ANY, atypes, sv, enclosing,
             return (ret_var, stmts)
         end
     else
-        local cache_linfo = ccall(:jl_get_spec_lambda, Any, (Any,), atype_unlimited)
+        local cache_linfo = get_spec_lambda(atype_unlimited, invoke_data)
         cache_linfo === nothing && return NF
         unshift!(argexprs, cache_linfo)
         ex = Expr(:invoke)
         ex.args = argexprs
         ex.typ = etype
-        return ex
+        if invoke_data === nothing
+            return ex
+        else
+            invoke_data = invoke_data::InvokeData
+            fexpr = invoke_data.fexpr
+            texpr = invoke_data.texpr
+            if texpr === nothing
+                if fexpr === nothing
+                    return ex
+                else
+                    return ex, Any[invoke_data.fexpr]
+                end
+            end
+            newvar = newvar!(sv, atypes[1])
+            stmts = Any[fexpr, :($newvar = $(argexprs[2])), texpr]  # argexprs[1] is cache_linfo after the unshift!
+            argexprs[2] = newvar  # the callee (typed atypes[1]) now sits at index 2
+            return ex, stmts
+        end
     end
     return NF
 end
@@ -2515,14 +2576,52 @@ function inlineable(f::ANY, ft::ANY, e::Expr, atypes::Vector{Any}, sv::Inference
             end
         end
     end
-    if isa(f, IntrinsicFunction) || ft ⊑ IntrinsicFunction ||
+    invoke_data = nothing  # stays nothing unless this call is a Core.invoke handled below
+    if f === Core.invoke && length(atypes) >= 3
+        ft = widenconst(atypes[2])
+        invoke_tt = widenconst(atypes[3])
+        if !isleaftype(ft) || !isleaftype(invoke_tt) || !isType(invoke_tt)
+            return NF
+        end
+        if !(isa(invoke_tt.parameters[1], Type) &&
+             invoke_tt.parameters[1] <: Tuple)
+            return NF
+        end
+        invoke_tt_params = invoke_tt.parameters[1].parameters
+        invoke_types = Tuple{ft, invoke_tt_params...}
+        invoke_entry = ccall(:jl_gf_invoke_lookup, Any, (Any,), invoke_types)
+        invoke_entry === nothing && return NF
+        fexpr = argexprs[1]
+        texpr = argexprs[3]
+        if effect_free(fexpr, enclosing, false)
+            fexpr = nothing
+        end
+        if effect_free(texpr, enclosing, false)
+            texpr = nothing  # was `fexpr = nothing`, which discarded fexpr and never cleared texpr
+        end
+        invoke_data = InvokeData(ft.name.mt, invoke_entry,
+                                 invoke_types, fexpr, texpr)
+        atype0 = atypes[2]
+        argexpr0 = argexprs[2]
+        atypes = atypes[4:end]
+        argexprs = argexprs[4:end]
+        unshift!(atypes, atype0)  # atypes/argexprs now describe the call f(args...)
+        unshift!(argexprs, argexpr0)
+        f = isdefined(ft, :instance) ? ft.instance : nothing
+    elseif isa(f, IntrinsicFunction) || ft ⊑ IntrinsicFunction ||
         isa(f, Builtin) || ft ⊑ Builtin
         return NF
     end
 
-    local atype_unlimited = argtypes_to_type(atypes)
+    atype_unlimited = argtypes_to_type(atypes)
+    if !(invoke_data === nothing)
+        invoke_data = invoke_data::InvokeData
+        # TODO emit a type check and proceed for this case
+        atype_unlimited <: invoke_data.types0 || return NF
+    end
     if !sv.inlining
-        return invoke_NF(argexprs, e.typ, atypes, sv, enclosing, atype_unlimited)
+        return invoke_NF(argexprs, e.typ, atypes, sv, enclosing,
+                         atype_unlimited, invoke_data)
     end
 
     if length(atype_unlimited.parameters) - 1 > MAX_TUPLETYPE_LEN
@@ -2530,32 +2629,46 @@ function inlineable(f::ANY, ft::ANY, e::Expr, atypes::Vector{Any}, sv::Inference
     else
         atype = atype_unlimited
     end
-    meth = _methods_by_ftype(atype, 1)
-    if meth === false || length(meth) != 1
-        return invoke_NF(argexprs, e.typ, atypes, sv, enclosing, atype_unlimited)
+    if invoke_data === nothing
+        meth = _methods_by_ftype(atype, 1)
+        if meth === false || length(meth) != 1
+            return invoke_NF(argexprs, e.typ, atypes, sv, enclosing,
+                             atype_unlimited, invoke_data)
+        end
+        meth = meth[1]::SimpleVector
+        metharg = meth[1]::Type
+        methsp = meth[2]::SimpleVector
+        method = meth[3]::Method
+    else
+        invoke_data = invoke_data::InvokeData
+        method = invoke_data.entry.func  # use the method selected by the invoke lookup
+        (metharg, methsp) = ccall(:jl_match_method, Ref{SimpleVector},
+                                  (Any, Any, Any),
+                                  atype_unlimited, method.sig, method.tvars)
+        methsp = methsp::SimpleVector
     end
-    meth = meth[1]::SimpleVector
-    metharg = meth[1]::Type
-    methsp = meth[2]
-    method = meth[3]::Method
     # check whether call can be inlined to just a quoted constant value
     if isa(f, widenconst(ft)) && !method.isstaged && (method.lambda_template.pure || f === return_type) &&
         (isType(e.typ) || isa(e.typ,Const))
         if isType(e.typ)
             if !has_typevars(e.typ.parameters[1])
-                return inline_as_constant(e.typ.parameters[1], argexprs, enclosing)
+                return inline_as_constant(e.typ.parameters[1], argexprs,
+                                          enclosing, invoke_data)
             end
         else
             assert(isa(e.typ,Const))
-            return inline_as_constant(e.typ.val, argexprs, enclosing)
+            return inline_as_constant(e.typ.val, argexprs, enclosing,
+                                      invoke_data)
         end
     end
 
     methsig = method.sig
     if !(atype <: metharg)
-        return invoke_NF(argexprs, e.typ, atypes, sv, enclosing, atype_unlimited)
+        return invoke_NF(argexprs, e.typ, atypes, sv, enclosing,
+                         atype_unlimited, invoke_data)
     end
 
+    argexprs0 = argexprs  # argument list passed to the invoke_NF fallbacks below
     na = method.lambda_template.nargs
     # check for vararg function
     isva = false
@@ -2582,18 +2695,22 @@ function inlineable(f::ANY, ft::ANY, e::Expr, atypes::Vector{Any}, sv::Inference
 
     (linfo, ty, inferred) = typeinf(method, metharg, methsp, false)
     if linfo === nothing || !inferred
-        return invoke_NF(e.args, e.typ, atypes, sv, enclosing, atype_unlimited)
+        return invoke_NF(argexprs0, e.typ, atypes, sv, enclosing,
+                         atype_unlimited, invoke_data)
     end
     if linfo !== nothing && linfo.jlcall_api == 2
         # in this case function can be inlined to a constant
-        return inline_as_constant(linfo.constval, argexprs, enclosing)
+        return inline_as_constant(linfo.constval, argexprs, enclosing,
+                                  invoke_data)
     elseif linfo !== nothing && !linfo.inlineable
-        return invoke_NF(e.args, e.typ, atypes, sv, enclosing, atype_unlimited)
+        return invoke_NF(argexprs0, e.typ, atypes, sv, enclosing,
+                         atype_unlimited, invoke_data)
     elseif linfo === nothing || linfo.code === nothing
         (linfo, ty, inferred) = typeinf(method, metharg, methsp, true)
     end
     if linfo === nothing || !inferred || !linfo.inlineable || (ast = linfo.code) === nothing
-        return invoke_NF(e.args, e.typ, atypes, sv, enclosing, atype_unlimited)
+        return invoke_NF(argexprs0, e.typ, atypes, sv, enclosing,
+                         atype_unlimited, invoke_data)
     end
 
     spvals = Any[]
@@ -2607,8 +2724,7 @@ function inlineable(f::ANY, ft::ANY, e::Expr, atypes::Vector{Any}, sv::Inference
         end
     end
 
-    methargs = metharg.parameters
-    nm = length(methargs)
+    nm = length(metharg.parameters)
 
     if !isa(ast, Array{Any,1})
         ast = ccall(:jl_uncompress_ast, Any, (Any,Any), linfo, ast)
@@ -2622,14 +2738,19 @@ function inlineable(f::ANY, ft::ANY, e::Expr, atypes::Vector{Any}, sv::Inference
     propagate_inbounds = linfo.propagate_inbounds
 
     # see if each argument occurs only once in the body expression
-    stmts = Any[]
-    prelude_stmts = Any[]
+    stmts = []
+    prelude_stmts = []
     stmts_free = true # true = all entries of stmts are effect_free
 
     for i=na:-1:1 # stmts_free needs to be calculated in reverse-argument order
         #args_i = args[i]
         aei = argexprs[i]
         aeitype = argtype = widenconst(exprtype(aei, enclosing))
+        if i == 1 && !(invoke_data === nothing)
+            invoke_data = invoke_data::InvokeData
+            invoke_data.texpr === nothing || unshift!(prelude_stmts,
+                                                      invoke_data.texpr)
+        end
 
         # ok for argument to occur more than once if the actual argument
         # is a symbol or constant, or is not affected by previous statements
@@ -2663,6 +2784,11 @@ function inlineable(f::ANY, ft::ANY, e::Expr, atypes::Vector{Any}, sv::Inference
             end
         end
     end
+    if !(invoke_data === nothing)
+        invoke_data = invoke_data::InvokeData
+        invoke_data.fexpr === nothing || unshift!(prelude_stmts,
+                                                  invoke_data.fexpr)
+    end
 
     # re-number the SSAValues and copy their type-info to the new ast
     ssavalue_types = linfo.ssavaluetypes
diff --git a/src/gf.c b/src/gf.c
index b067083609e10..8efa99632dacd 100644
--- a/src/gf.c
+++ b/src/gf.c
@@ -1967,11 +1967,10 @@ JL_DLLEXPORT jl_value_t *jl_gf_invoke_lookup(jl_datatype_t *types)
 jl_value_t *jl_gf_invoke(jl_tupletype_t *types0, jl_value_t **args, size_t nargs)
 {
     jl_svec_t *tpenv = jl_emptysvec;
-    jl_tupletype_t *newsig = NULL;
     jl_tupletype_t *tt = NULL;
     jl_tupletype_t *types = NULL;
     jl_tupletype_t *sig = NULL;
-    JL_GC_PUSH5(&types, &tpenv, &newsig, &sig, &tt);
+    JL_GC_PUSH4(&types, &tpenv, &sig, &tt);
     jl_value_t *gf = args[0];
     types = (jl_datatype_t*)jl_argtype_with_function(gf, (jl_tupletype_t*)types0);
     jl_methtable_t *mt = jl_gf_mtable(gf);
@@ -2022,6 +2021,54 @@ jl_value_t *jl_gf_invoke(jl_tupletype_t *types0, jl_value_t **args, size_t nargs
     return jl_call_method_internal(mfunc, args, nargs);
 }
 
+JL_DLLEXPORT jl_value_t *jl_get_invoke_lambda(jl_methtable_t *mt,
+                                              jl_typemap_entry_t *entry,
+                                              jl_tupletype_t *tt)
+{
+    if (!jl_is_leaf_type((jl_value_t*)tt))
+        return jl_nothing;
+    jl_method_t *method = entry->func.method;
+    jl_typemap_entry_t *tm = NULL;
+    if (method->invokes.unknown != NULL) {
+        tm = jl_typemap_assoc_by_type(method->invokes, tt, NULL, 0, 1,
+                                      jl_cachearg_offset(mt));
+        if (tm) {
+            return (jl_value_t*)tm->func.linfo;
+        }
+    }
+
+    JL_LOCK(&method->writelock);
+    if (method->invokes.unknown != NULL) { // repeat the lookup now that writelock is held
+        tm = jl_typemap_assoc_by_type(method->invokes, tt, NULL, 0, 1,
+                                      jl_cachearg_offset(mt));
+        if (tm) {
+            jl_lambda_info_t *mfunc = tm->func.linfo;
+            JL_UNLOCK(&method->writelock);
+            return (jl_value_t*)mfunc;
+        }
+    }
+    jl_svec_t *tpenv = jl_emptysvec;
+    jl_tupletype_t *sig = NULL;
+    JL_GC_PUSH2(&tpenv, &sig);
+    if (entry->tvars != jl_emptysvec) {
+        jl_value_t *ti =
+            jl_lookup_match((jl_value_t*)tt, (jl_value_t*)entry->sig, &tpenv, entry->tvars);
+        assert(ti != (jl_value_t*)jl_bottom_type);
+        (void)ti;
+    }
+    sig = join_tsig(tt, entry->sig);
+    jl_method_t *func = entry->func.method;
+
+    if (func->invokes.unknown == NULL)
+        func->invokes.unknown = jl_nothing;
+
+    jl_lambda_info_t *mfunc = cache_method(mt, &func->invokes, entry->func.value,
+                                           sig, tt, entry, tpenv, 1);
+    JL_GC_POP();
+    JL_UNLOCK(&method->writelock);
+    return (jl_value_t*)mfunc;
+}
+
 static jl_function_t *jl_new_generic_function_with_supertype(jl_sym_t *name, jl_module_t *module, jl_datatype_t *st, int iskw)
 {
     // type name is function name prefixed with #