diff --git a/base/inference.jl b/base/inference.jl index ef77009eb3b89..e7d9eae93e11f 100644 --- a/base/inference.jl +++ b/base/inference.jl @@ -681,8 +681,8 @@ function invoke_tfunc(f::ANY, types::ANY, argtype::ANY, sv::InferenceState) return Any end meth = entry.func - (ti, env) = ccall(:jl_match_method, Any, (Any, Any, Any), - argtype, meth.sig, meth.tvars)::SimpleVector + (ti, env) = ccall(:jl_match_method, Ref{SimpleVector}, (Any, Any, Any), + argtype, meth.sig, meth.tvars) return typeinf_edge(meth::Method, ti, env, sv) end @@ -2363,14 +2363,34 @@ end #### post-inference optimizations #### -function inline_as_constant(val::ANY, argexprs, sv::InferenceState) +immutable InvokeData + mt::MethodTable + entry::TypeMapEntry + types0 + fexpr + texpr +end + +function inline_as_constant(val::ANY, argexprs, sv::InferenceState, + invoke_data::ANY) + if invoke_data === nothing + invoke_fexpr = nothing + invoke_texpr = nothing + else + invoke_data = invoke_data::InvokeData + invoke_fexpr = invoke_data.fexpr + invoke_texpr = invoke_data.texpr + end # check if any arguments aren't effect_free and need to be kept around - stmts = Any[] + stmts = invoke_fexpr === nothing ? [] : Any[invoke_fexpr] for i = 1:length(argexprs) arg = argexprs[i] if !effect_free(arg, sv.src, sv.mod, false) push!(stmts, arg) end + if i == 1 && !(invoke_texpr === nothing) + push!(stmts, invoke_texpr) + end end return (QuoteNode(val), stmts) end @@ -2385,10 +2405,31 @@ function countunionsplit(atypes::Vector{Any}) return nu end -function invoke_NF(argexprs, etype::ANY, atypes, sv, atype_unlimited::ANY) +function get_spec_lambda(atypes::ANY, invoke_data::ANY) + if invoke_data === nothing + return ccall(:jl_get_spec_lambda, Any, (Any,), atypes) + else + invoke_data = invoke_data::InvokeData + # TODO compute intersection and throws an error + atypes <: invoke_data.types0 || return nothing + return ccall(:jl_get_invoke_lambda, Any, (Any, Any, Any), + invoke_data.mt, invoke_data.entry, atypes) + end +end + +function invoke_NF(argexprs, etype::ANY, atypes, sv, atype_unlimited::ANY, + invoke_data::ANY) # converts a :call to :invoke nu = countunionsplit(atypes) nu > MAX_UNION_SPLITTING && return NF + if invoke_data === nothing + invoke_fexpr = nothing + invoke_texpr = nothing + else + invoke_data = invoke_data::InvokeData + invoke_fexpr = invoke_data.fexpr + invoke_texpr = invoke_data.texpr + end if nu > 1 spec_hit = nothing @@ -2400,7 +2441,12 @@ function invoke_NF(argexprs, etype::ANY, atypes, sv, atype_unlimited::ANY) ex.typ = etype stmts = [] arg_hoisted = false + arg0_hoisted = false for i = length(atypes):-1:1 + if i == 1 && !(invoke_texpr === nothing) + unshift!(stmts, invoke_texpr) + arg_hoisted = true + end ti = atypes[i] if arg_hoisted || isa(ti, Union) aei = ex.args[i] @@ -2409,13 +2455,17 @@ function invoke_NF(argexprs, etype::ANY, atypes, sv, atype_unlimited::ANY) newvar = newvar!(sv, ti) unshift!(stmts, :($newvar = $aei)) ex.args[i] = newvar + if i == 1 + arg0_hoisted = true + end end end end + invoke_fexpr === nothing || unshift!(stmts, invoke_fexpr) function splitunion(atypes::Vector{Any}, i::Int) if i == 0 local sig = argtypes_to_type(atypes) - local li = ccall(:jl_get_spec_lambda, Any, (Any,), sig) + local li = get_spec_lambda(sig, invoke_data) li === nothing && return false local stmt = [] push!(stmt, Expr(:(=), linfo_var, li)) @@ -2483,13 +2533,24 @@ function invoke_NF(argexprs, etype::ANY, atypes, sv, atype_unlimited::ANY) return (ret_var, stmts) end else - local cache_linfo = ccall(:jl_get_spec_lambda, Any, (Any,), atype_unlimited) + local cache_linfo = get_spec_lambda(atype_unlimited, invoke_data) cache_linfo === nothing && return NF unshift!(argexprs, cache_linfo) ex = Expr(:invoke) ex.args = argexprs ex.typ = etype - return ex + if invoke_texpr === nothing + if invoke_fexpr === nothing + return ex + else + return ex, Any[invoke_fexpr] + end + end + newvar = newvar!(sv, atypes[1]) + stmts = Any[invoke_fexpr, :($newvar = $(argexprs[1])), + invoke_texpr] + argexprs[1] = newvar + return ex, stmts end return NF end @@ -2543,14 +2604,54 @@ function inlineable(f::ANY, ft::ANY, e::Expr, atypes::Vector{Any}, sv::Inference end end end - if isa(f, IntrinsicFunction) || ft ⊑ IntrinsicFunction || + invoke_data = nothing + invoke_fexpr = nothing + invoke_texpr = nothing + if f === Core.invoke && length(atypes) >= 3 + ft = widenconst(atypes[2]) + invoke_tt = widenconst(atypes[3]) + if !isleaftype(ft) || !isleaftype(invoke_tt) || !isType(invoke_tt) + return NF + end + if !(isa(invoke_tt.parameters[1], Type) && + invoke_tt.parameters[1] <: Tuple) + return NF + end + invoke_tt_params = invoke_tt.parameters[1].parameters + invoke_types = Tuple{ft, invoke_tt_params...} + invoke_entry = ccall(:jl_gf_invoke_lookup, Any, (Any,), invoke_types) + invoke_entry === nothing && return NF + invoke_fexpr = argexprs[1] + invoke_texpr = argexprs[3] + if effect_free(invoke_fexpr, sv.src, sv.mod, false) + invoke_fexpr = nothing + end + if effect_free(invoke_texpr, sv.src, sv.mod, false) + invoke_fexpr = nothing + end + invoke_data = InvokeData(ft.name.mt, invoke_entry, + invoke_types, invoke_fexpr, invoke_texpr) + atype0 = atypes[2] + argexpr0 = argexprs[2] + atypes = atypes[4:end] + argexprs = argexprs[4:end] + unshift!(atypes, atype0) + unshift!(argexprs, argexpr0) + f = isdefined(ft, :instance) ? ft.instance : nothing + elseif isa(f, IntrinsicFunction) || ft ⊑ IntrinsicFunction || isa(f, Builtin) || ft ⊑ Builtin return NF end - local atype_unlimited = argtypes_to_type(atypes) + atype_unlimited = argtypes_to_type(atypes) + if !(invoke_data === nothing) + invoke_data = invoke_data::InvokeData + # TODO emit a type check and proceed for this case + atype_unlimited <: invoke_data.types0 || return NF + end if !sv.inlining - return invoke_NF(argexprs, e.typ, atypes, sv, atype_unlimited) + return invoke_NF(argexprs, e.typ, atypes, sv, atype_unlimited, + invoke_data) end if length(atype_unlimited.parameters) - 1 > MAX_TUPLETYPE_LEN @@ -2558,26 +2659,39 @@ function inlineable(f::ANY, ft::ANY, e::Expr, atypes::Vector{Any}, sv::Inference else atype = atype_unlimited end - meth = _methods_by_ftype(atype, 1) - if meth === false || length(meth) != 1 - return invoke_NF(argexprs, e.typ, atypes, sv, atype_unlimited) + if invoke_data === nothing + meth = _methods_by_ftype(atype, 1) + if meth === false || length(meth) != 1 + return invoke_NF(argexprs, e.typ, atypes, sv, + atype_unlimited, invoke_data) + end + meth = meth[1]::SimpleVector + metharg = meth[1]::Type + methsp = meth[2]::SimpleVector + method = meth[3]::Method + else + invoke_data = invoke_data::InvokeData + method = invoke_data.entry.func + (metharg, methsp) = ccall(:jl_match_method, Ref{SimpleVector}, + (Any, Any, Any), + atype_unlimited, method.sig, method.tvars) + methsp = methsp::SimpleVector end - meth = meth[1]::SimpleVector - metharg = meth[1]::Type - methsp = meth[2] - method = meth[3]::Method # check whether call can be inlined to just a quoted constant value if isa(f, widenconst(ft)) && !method.isstaged && (method.source.pure || f === return_type) if isconstType(e.typ,false) - return inline_as_constant(e.typ.parameters[1], argexprs, sv) + return inline_as_constant(e.typ.parameters[1], argexprs, sv, + invoke_data) elseif isa(e.typ,Const) - return inline_as_constant(e.typ.val, argexprs, sv) + return inline_as_constant(e.typ.val, argexprs, sv, + invoke_data) end end methsig = method.sig if !(atype <: metharg) - return invoke_NF(argexprs, e.typ, atypes, sv, atype_unlimited) + return invoke_NF(argexprs, e.typ, atypes, sv, atype_unlimited, + invoke_data) end argexprs0 = argexprs @@ -2653,11 +2767,12 @@ function inlineable(f::ANY, ft::ANY, e::Expr, atypes::Vector{Any}, sv::Inference if isa(linfo, MethodInstance) && linfo.jlcall_api == 2 # in this case function can be inlined to a constant - return inline_as_constant(linfo.inferred, argexprs, sv) + return inline_as_constant(linfo.inferred, argexprs, sv, invoke_data) end if !isa(src, CodeInfo) || !src.inferred || !src.inlineable - return invoke_NF(argexprs0, e.typ, atypes, sv, atype_unlimited) + return invoke_NF(argexprs0, e.typ, atypes, sv, atype_unlimited, + invoke_data) end ast = src.code rettype = linfo.rettype @@ -2673,8 +2788,7 @@ function inlineable(f::ANY, ft::ANY, e::Expr, atypes::Vector{Any}, sv::Inference end end - methargs = metharg.parameters - nm = length(methargs) + nm = length(metharg.parameters) if !isa(ast, Array{Any,1}) ast = ccall(:jl_uncompress_ast, Any, (Any, Any), method, ast) @@ -2688,14 +2802,17 @@ function inlineable(f::ANY, ft::ANY, e::Expr, atypes::Vector{Any}, sv::Inference propagate_inbounds = src.propagate_inbounds # see if each argument occurs only once in the body expression - stmts = Any[] - prelude_stmts = Any[] + stmts = [] + prelude_stmts = [] stmts_free = true # true = all entries of stmts are effect_free for i=na:-1:1 # stmts_free needs to be calculated in reverse-argument order #args_i = args[i] aei = argexprs[i] aeitype = argtype = widenconst(exprtype(aei, sv.src, sv.mod)) + if i == 1 && !(invoke_texpr === nothing) + unshift!(prelude_stmts, invoke_texpr) + end # ok for argument to occur more than once if the actual argument # is a symbol or constant, or is not affected by previous statements @@ -2729,6 +2846,7 @@ function inlineable(f::ANY, ft::ANY, e::Expr, atypes::Vector{Any}, sv::Inference end end end + invoke_fexpr === nothing || unshift!(prelude_stmts, invoke_fexpr) # re-number the SSAValues and copy their type-info to the new ast ssavalue_types = src.ssavaluetypes diff --git a/src/gf.c b/src/gf.c index e8bbec81d8de1..684048a46210c 100644 --- a/src/gf.c +++ b/src/gf.c @@ -1875,11 +1875,10 @@ JL_DLLEXPORT jl_value_t *jl_gf_invoke_lookup(jl_datatype_t *types) jl_value_t *jl_gf_invoke(jl_tupletype_t *types0, jl_value_t **args, size_t nargs) { jl_svec_t *tpenv = jl_emptysvec; - jl_tupletype_t *newsig = NULL; jl_tupletype_t *tt = NULL; jl_tupletype_t *types = NULL; jl_tupletype_t *sig = NULL; - JL_GC_PUSH5(&types, &tpenv, &newsig, &sig, &tt); + JL_GC_PUSH4(&types, &tpenv, &sig, &tt); jl_value_t *gf = args[0]; types = (jl_datatype_t*)jl_argtype_with_function(gf, (jl_tupletype_t*)types0); jl_methtable_t *mt = jl_gf_mtable(gf); @@ -1930,6 +1929,87 @@ jl_value_t *jl_gf_invoke(jl_tupletype_t *types0, jl_value_t **args, size_t nargs return jl_call_method_internal(mfunc, args, nargs); } +typedef struct _tupletype_stack_t { + struct _tupletype_stack_t *parent; + jl_tupletype_t *tt; +} tupletype_stack_t; + +static int tupletype_on_stack(jl_tupletype_t *tt, tupletype_stack_t *stack) +{ + while (stack) { + if (tt == stack->tt) + return 1; + stack = stack->parent; + } + return 0; +} + +static int tupletype_has_datatype(jl_tupletype_t *tt, tupletype_stack_t *stack) +{ + for (int i = 0; i < jl_nparams(tt); i++) { + jl_value_t *ti = jl_tparam(tt, i); + if (ti == (jl_value_t*)jl_datatype_type) + return 1; + if (jl_is_tuple_type(ti)) { + jl_tupletype_t *tt1 = (jl_tupletype_t*)ti; + if (!tupletype_on_stack(tt1, stack) && + tupletype_has_datatype(tt1, stack)) { + return 1; + } + } + } + return 0; +} + +JL_DLLEXPORT jl_value_t *jl_get_invoke_lambda(jl_methtable_t *mt, + jl_typemap_entry_t *entry, + jl_tupletype_t *tt) +{ + if (!jl_is_leaf_type((jl_value_t*)tt) || tupletype_has_datatype(tt, NULL)) + return jl_nothing; + + jl_method_t *method = entry->func.method; + jl_typemap_entry_t *tm = NULL; + if (method->invokes.unknown != NULL) { + tm = jl_typemap_assoc_by_type(method->invokes, tt, NULL, 0, 1, + jl_cachearg_offset(mt)); + if (tm) { + return (jl_value_t*)tm->func.linfo; + } + } + + JL_LOCK(&method->writelock); + if (method->invokes.unknown != NULL) { + tm = jl_typemap_assoc_by_type(method->invokes, tt, NULL, 0, 1, + jl_cachearg_offset(mt)); + if (tm) { + jl_method_instance_t *mfunc = tm->func.linfo; + JL_UNLOCK(&method->writelock); + return (jl_value_t*)mfunc; + } + } + jl_svec_t *tpenv = jl_emptysvec; + jl_tupletype_t *sig = NULL; + JL_GC_PUSH2(&tpenv, &sig); + if (entry->tvars != jl_emptysvec) { + jl_value_t *ti = + jl_lookup_match((jl_value_t*)tt, (jl_value_t*)entry->sig, &tpenv, entry->tvars); + assert(ti != (jl_value_t*)jl_bottom_type); + (void)ti; + } + sig = join_tsig(tt, entry->sig); + jl_method_t *func = entry->func.method; + + if (func->invokes.unknown == NULL) + func->invokes.unknown = jl_nothing; + + jl_method_instance_t *mfunc = cache_method(mt, &func->invokes, entry->func.value, + sig, tt, entry, tpenv, 1); + JL_GC_POP(); + JL_UNLOCK(&method->writelock); + return (jl_value_t*)mfunc; +} + static jl_function_t *jl_new_generic_function_with_supertype(jl_sym_t *name, jl_module_t *module, jl_datatype_t *st, int iskw) { // type name is function name prefixed with #