diff --git a/base/codevalidation.jl b/base/codevalidation.jl index b13b24a6fb54d..c5f7bef3049a1 100644 --- a/base/codevalidation.jl +++ b/base/codevalidation.jl @@ -10,7 +10,6 @@ const VALID_EXPR_HEADS = ObjectIdDict( :(=) => 2:2, :method => 1:4, :const => 1:1, - :null => 0:0, # TODO from @vtjnash: remove this + any :null handling code in Base :new => 1:typemax(Int), :return => 1:1, :the_exception => 0:0, @@ -25,7 +24,8 @@ const VALID_EXPR_HEADS = ObjectIdDict( :isdefined => 1:1, :simdloop => 0:0, :gc_preserve_begin => 0:typemax(Int), - :gc_preserve_end => 0:typemax(Int) + :gc_preserve_end => 0:typemax(Int), + :thunk => 1:1 ) # @enum isn't defined yet, otherwise I'd use it for this @@ -33,6 +33,7 @@ const INVALID_EXPR_HEAD = "invalid expression head" const INVALID_EXPR_NARGS = "invalid number of expression args" const INVALID_LVALUE = "invalid LHS value" const INVALID_RVALUE = "invalid RHS value" +const INVALID_RETURN = "invalid argument to :return" const INVALID_CALL_ARG = "invalid :call argument" const EMPTY_SLOTNAMES = "slotnames field is empty" const SLOTFLAGS_MISMATCH = "length(slotnames) != length(slotflags)" @@ -41,6 +42,7 @@ const SLOTTYPES_MISMATCH_UNINFERRED = "uninferred CodeInfo slottypes field is no const SSAVALUETYPES_MISMATCH = "not all SSAValues in AST have a type in ssavaluetypes" const SSAVALUETYPES_MISMATCH_UNINFERRED = "uninferred CodeInfo ssavaluetypes field does not equal the number of present SSAValues" const NON_TOP_LEVEL_METHOD = "encountered `Expr` head `:method` in non-top-level code (i.e. `nargs` > 0)" +const NON_TOP_LEVEL_GLOBAL = "encountered `Expr` head `:global` in non-top-level code (i.e. `nargs` > 0)" const SIGNATURE_NARGS_MISMATCH = "method signature does not match number of method arguments" const SLOTNAMES_NARGS_MISMATCH = "CodeInfo for method contains fewer slotnames than the number of method arguments" @@ -57,18 +59,38 @@ InvalidCodeError(kind::AbstractString) = InvalidCodeError(kind, nothing) Validate `c`, logging any violation by pushing an `InvalidCodeError` into `errors`. """ function validate_code!(errors::Vector{>:InvalidCodeError}, c::CodeInfo, is_top_level::Bool = false) + function validate_val!(@nospecialize(x)) + if isa(x, Expr) + if x.head == :call || x.head == :invoke + for arg in x.args + if !is_valid_argument(arg) + push!(errors, InvalidCodeError(INVALID_CALL_ARG, arg)) + else + validate_val!(arg) + end + end + end + elseif isa(x, SSAValue) + id = x.id + 1 # ensures that id > 0 for use with IntSet + !in(id, ssavals) && push!(ssavals, id) + end + end + ssavals = IntSet() lhs_slotnums = IntSet() - walkast(c.code) do x + for x in c.code if isa(x, Expr) - !is_top_level && x.head == :method && push!(errors, InvalidCodeError(NON_TOP_LEVEL_METHOD)) + if !is_top_level + x.head === :method && push!(errors, InvalidCodeError(NON_TOP_LEVEL_METHOD)) + x.head === :global && push!(errors, InvalidCodeError(NON_TOP_LEVEL_GLOBAL)) + end narg_bounds = get(VALID_EXPR_HEADS, x.head, -1:-1) nargs = length(x.args) if narg_bounds == -1:-1 push!(errors, InvalidCodeError(INVALID_EXPR_HEAD, (x.head, x))) elseif !in(nargs, narg_bounds) push!(errors, InvalidCodeError(INVALID_EXPR_NARGS, (x.head, nargs, x))) - elseif x.head == :(=) + elseif x.head === :(=) lhs, rhs = x.args if !is_valid_lvalue(lhs) push!(errors, InvalidCodeError(INVALID_LVALUE, lhs)) @@ -76,25 +98,39 @@ function validate_code!(errors::Vector{>:InvalidCodeError}, c::CodeInfo, is_top_ n = lhs.id push!(lhs_slotnums, n) end - if !is_valid_rvalue(rhs) + if !is_valid_rvalue(lhs, rhs) push!(errors, InvalidCodeError(INVALID_RVALUE, rhs)) end - elseif x.head == :call || x.head == :invoke - for arg in x.args - if !is_valid_rvalue(arg) - push!(errors, InvalidCodeError(INVALID_CALL_ARG, arg)) - end + validate_val!(lhs) + validate_val!(rhs) + elseif x.head === :gotoifnot + if !is_valid_argument(x.args[1]) + push!(errors, InvalidCodeError(INVALID_CALL_ARG, x.args[1])) + end + validate_val!(x.args[1]) + elseif x.head === :return + if !is_valid_return(x.args[1]) + push!(errors, InvalidCodeError(INVALID_RETURN, x.args[1])) end + validate_val!(x.args[1]) + else + validate_val!(x) end - elseif isa(x, SSAValue) - id = x.id + 1 # ensures that id > 0 for use with IntSet - !in(id, ssavals) && push!(ssavals, id) + elseif isa(x, NewvarNode) + elseif isa(x, LabelNode) + elseif isa(x, GotoNode) + elseif x === nothing + elseif isa(x, SlotNumber) + elseif isa(x, GlobalRef) + elseif isa(x, LineNumberNode) + else + push!(errors, InvalidCodeError("invalid statement", x)) end end nslotnames = length(c.slotnames) nslotflags = length(c.slotflags) nssavals = length(ssavals) - nslotnames == 0 && push!(errors, InvalidCodeError(EMPTY_SLOTNAMES)) + !is_top_level && nslotnames == 0 && push!(errors, InvalidCodeError(EMPTY_SLOTNAMES)) nslotnames != nslotflags && push!(errors, InvalidCodeError(SLOTFLAGS_MISMATCH, (nslotnames, nslotflags))) if c.inferred nslottypes = length(c.slottypes) @@ -119,32 +155,50 @@ the `CodeInfo` instance associated with `mi`. """ function validate_code!(errors::Vector{>:InvalidCodeError}, mi::Core.MethodInstance, c::Union{Void,CodeInfo} = Core.Inference.retrieve_code_info(mi)) - m = mi.def::Method - n_sig_params = length(Core.Inference.unwrap_unionall(m.sig).parameters) - if (m.isva ? (n_sig_params < (m.nargs - 1)) : (n_sig_params != m.nargs)) - push!(errors, InvalidCodeError(SIGNATURE_NARGS_MISMATCH, (m.isva, n_sig_params, m.nargs))) + is_top_level = mi.def isa Module + if is_top_level + mnargs = 0 + else + m = mi.def::Method + mnargs = m.nargs + n_sig_params = length(Core.Inference.unwrap_unionall(m.sig).parameters) + if (m.isva ? (n_sig_params < (mnargs - 1)) : (n_sig_params != mnargs)) + push!(errors, InvalidCodeError(SIGNATURE_NARGS_MISMATCH, (m.isva, n_sig_params, mnargs))) + end end if isa(c, CodeInfo) - m.nargs > length(c.slotnames) && push!(errors, InvalidCodeError(SLOTNAMES_NARGS_MISMATCH)) - validate_code!(errors, c, m.nargs == 0) + mnargs > length(c.slotnames) && push!(errors, InvalidCodeError(SLOTNAMES_NARGS_MISMATCH)) + validate_code!(errors, c, is_top_level) end return errors end validate_code(args...) = validate_code!(Vector{InvalidCodeError}(), args...) -function walkast(f, stmts::Array) - for stmt in stmts - f(stmt) - isa(stmt, Expr) && walkast(f, stmt.args) +is_valid_lvalue(x) = isa(x, Slot) || isa(x, SSAValue) || isa(x, GlobalRef) + +function is_valid_argument(x) + if isa(x, Slot) || isa(x, SSAValue) || isa(x, GlobalRef) || isa(x, QuoteNode) || + (isa(x,Expr) && (x.head in (:static_parameter, :boundscheck, :copyast))) || + isa(x, Number) || isa(x, AbstractString) || isa(x, Char) || isa(x, Tuple) || + isa(x, Type) || isa(x, Core.Box) || isa(x, Module) || x === nothing + return true end + # TODO: consider being stricter about what needs to be wrapped with QuoteNode + return !(isa(x,Expr) || isa(x,Symbol) || isa(x,GotoNode) || isa(x,LabelNode) || + isa(x,LineNumberNode) || isa(x,NewvarNode)) end -is_valid_lvalue(x) = isa(x, SlotNumber) || isa(x, SSAValue) || isa(x, GlobalRef) - -function is_valid_rvalue(x) - isa(x, Expr) && return !in(x.head, (:gotoifnot, :line, :const, :meta)) - return !isa(x, GotoNode) && !isa(x, LabelNode) && !isa(x, LineNumberNode) +function is_valid_rvalue(lhs, x) + is_valid_argument(x) && return true + if isa(x, Expr) && x.head in (:new, :the_exception, :isdefined, :call, :invoke, :foreigncall, :gc_preserve_begin) + return true + # TODO: disallow `globalref = call` when .typ field is removed + #return isa(lhs, SSAValue) || isa(lhs, Slot) + end + return false end +is_valid_return(x) = is_valid_rvalue(nothing, x) || (isa(x,Expr) && x.head in (:new, :lambda)) + is_flag_set(byte::UInt8, flag::UInt8) = (byte & flag) == flag diff --git a/base/inference.jl b/base/inference.jl index c8d60eac7c86e..6f3034c8ba6af 100644 --- a/base/inference.jl +++ b/base/inference.jl @@ -288,21 +288,25 @@ mutable struct InferenceState end end -function InferenceState(linfo::MethodInstance, - optimize::Bool, cached::Bool, params::InferenceParams) - # prepare an InferenceState object for inferring lambda - src = retrieve_code_info(linfo) - src === nothing && return nothing +function _validate(linfo::MethodInstance, src::CodeInfo, kind::String) if JLOptions().debug_level == 2 # this is a debug build of julia, so let's validate linfo errors = validate_code(linfo, src) if !isempty(errors) for e in errors - println(STDERR, "WARNING: Encountered invalid lowered code for method ", + println(STDERR, "WARNING: Encountered invalid ", kind, " code for method ", linfo.def, ": ", e) end end end +end + +function InferenceState(linfo::MethodInstance, + optimize::Bool, cached::Bool, params::InferenceParams) + # prepare an InferenceState object for inferring lambda + src = retrieve_code_info(linfo) + src === nothing && return nothing + _validate(linfo, src, "lowered") return InferenceState(linfo, src, optimize, cached, params) end @@ -2446,8 +2450,6 @@ function abstract_eval(@nospecialize(e), vtypes::VarTable, sv::InferenceState) e = e::Expr if e.head === :call t = abstract_eval_call(e, vtypes, sv) - elseif e.head === :null - t = Void elseif e.head === :new t = instanceof_tfunc(abstract_eval(e.args[1], vtypes, sv))[1] for i = 2:length(e.args) @@ -3087,6 +3089,15 @@ end #### entry points for inferring a MethodInstance given a type signature #### +function is_self_quoting(@nospecialize(x)) + return isa(x,Number) || isa(x,AbstractString) || isa(x,Tuple) || isa(x,Type) || isa(x,Char) || x === nothing || + isa(x,Builtin) || isa(x,IntrinsicFunction) +end + +function quoted(@nospecialize(x)) + return is_self_quoting(x) ? x : QuoteNode(x) +end + # compute an inferred AST and return type function typeinf_code(method::Method, @nospecialize(atypes), sparams::SimpleVector, optimize::Bool, cached::Bool, params::InferenceParams) @@ -3107,7 +3118,7 @@ function typeinf_code(linfo::MethodInstance, optimize::Bool, cached::Bool, if linfo.jlcall_api == 2 method = linfo.def::Method tree = ccall(:jl_new_code_info_uninit, Ref{CodeInfo}, ()) - tree.code = Any[ Expr(:return, QuoteNode(linfo.inferred_const)) ] + tree.code = Any[ Expr(:return, quoted(linfo.inferred_const)) ] tree.slotnames = Any[ compiler_temp_sym for i = 1:method.nargs ] tree.slotflags = UInt8[ 0 for i = 1:method.nargs ] tree.slottypes = nothing @@ -3175,6 +3186,7 @@ function typeinf_ext(linfo::MethodInstance, world::UInt) ccall(:jl_typeinf_end, Void, ()) @assert frame.inferred # TODO: deal with this better @assert frame.linfo === linfo + linfo.rettype = widenconst(frame.bestguess) return svec(linfo, frame.src, linfo.rettype) end end @@ -3549,6 +3561,7 @@ function optimize(me::InferenceState) me.src.inlineable = isinlineable(def, me.src, me.mod, me.params, bonus) end me.src.inferred = true + _validate(me.linfo, me.src, "optimized") nothing end @@ -3789,7 +3802,7 @@ function _widen_all_consts!(e::Expr, untypedload::Vector{Bool}, slottypes::Vecto elseif isa(x, TypedSlot) vt = widenconst(x.typ) if !(vt === x.typ) - if slottypes[x.id] <: vt + if slottypes[x.id] ⊑ vt x = SlotNumber(x.id) untypedload[x.id] = true else @@ -3871,9 +3884,7 @@ function substitute!( e = e::Expr head = e.head if head === :static_parameter - sp = spvals[e.args[1]] - is_self_quoting(sp) && return sp - return QuoteNode(sp) + return quoted(spvals[e.args[1]]) elseif head === :foreigncall @assert !isa(spsig, UnionAll) || !isempty(spvals) for i = 1:length(e.args) @@ -4007,7 +4018,7 @@ function effect_free(@nospecialize(e), src::CodeInfo, mod::Module, allow_volatil end if head === :static_parameter # if we aren't certain about the type, it might be an UndefVarError at runtime - return isa(e.typ, DataType) && isleaftype(e.typ) + return (isa(e.typ, DataType) && isleaftype(e.typ)) || isa(e.typ, Const) end if e.typ === Bottom return false @@ -4065,6 +4076,8 @@ function effect_free(@nospecialize(e), src::CodeInfo, mod::Module, allow_volatil # fall-through elseif head === :return # fall-through + elseif head === :isdefined + return allow_volatile elseif head === :the_exception return allow_volatile else @@ -4112,14 +4125,7 @@ function inline_as_constant(@nospecialize(val), argexprs, sv::InferenceState, @n push!(stmts, invoke_texpr) end end - if !is_self_quoting(val) - val = QuoteNode(val) - end - return (val, stmts) -end - -function is_self_quoting(@nospecialize(x)) - return isa(x,Number) || isa(x,AbstractString) || isa(x,Tuple) || isa(x,Type) + return (quoted(val), stmts) end function countunionsplit(atypes) @@ -4221,9 +4227,11 @@ function invoke_NF(argexprs, @nospecialize(etype), atypes::Vector{Any}, sv::Infe local match = splitunion(atypes, i - 1) if match !== false after = genlabel(sv) + isa_var = newvar!(sv, Bool) isa_ty = Expr(:call, GlobalRef(Core, :isa), aei, ty) isa_ty.typ = Bool - unshift!(match, Expr(:gotoifnot, isa_ty, after.label)) + unshift!(match, Expr(:gotoifnot, isa_var, after.label)) + unshift!(match, Expr(:(=), isa_var, isa_ty)) append!(stmts, match) push!(stmts, after) else @@ -4381,11 +4389,12 @@ function inlineable(@nospecialize(f), @nospecialize(ft), e::Expr, atypes::Vector if isa(e.typ, Const) return inline_as_constant(e.typ.val, argexprs, sv, nothing) end - not_is = Expr(:call, GlobalRef(Core.Intrinsics, :not_int), - Expr(:call, GlobalRef(Core, :(===)), argexprs[2], argexprs[3])) + is_var = newvar!(sv, Bool) + stmts = Any[ Expr(:(=), is_var, Expr(:call, GlobalRef(Core, :(===)), argexprs[2], argexprs[3])) ] + stmts[1].args[2].typ = Bool + not_is = Expr(:call, GlobalRef(Core.Intrinsics, :not_int), is_var) not_is.typ = Bool - not_is.args[2].typ = Bool - return (not_is, ()) + return (not_is, stmts) elseif length(atypes) == 3 && istopfunction(topmod, f, :(>:)) # special-case inliner for issupertype # that works, even though inference generally avoids inferring the `>:` Method @@ -4602,6 +4611,12 @@ function inlineable(@nospecialize(f), @nospecialize(ft), e::Expr, atypes::Vector stmts_free = true # true = all entries of stmts are effect_free argexprs = copy(argexprs) + if isva + # move constructed vararg tuple to an ssavalue + varargvar = newvar!(sv, atypes[na]) + push!(prelude_stmts, Expr(:(=), varargvar, argexprs[na])) + argexprs[na] = varargvar + end for i = na:-1:1 # stmts_free needs to be calculated in reverse-argument order #args_i = args[i] aei = argexprs[i] @@ -4979,7 +4994,7 @@ function inlining_pass!(sv::InferenceState, propagate_inbounds::Bool) else boundscheck = :on end - eargs[i] = inlining_pass(ei, sv, stmtbuf, 1, boundscheck) + eargs[i] = inline_expr(ei, sv, stmtbuf, boundscheck) if !isempty(stmtbuf) splice!(eargs, i:(i - 1), stmtbuf) i += length(stmtbuf) @@ -4993,121 +5008,35 @@ end const corenumtype = Union{Int32, Int64, Float32, Float64} -# return inlined replacement for `e`, inserting new needed statements -# at index `ins` in `stmts`. -function inlining_pass(e::Expr, sv::InferenceState, stmts::Vector{Any}, ins, boundscheck::Symbol) - if e.head === :meta - # ignore meta nodes - return e - end - if e.head === :method - # avoid running the inlining pass on function definitions - return e - end - # inliners for special expressions - if e.head === :boundscheck - return e - end - if e.head === :isdefined +function inline_expr(e::Expr, sv::InferenceState, stmts::Vector{Any}, boundscheck::Symbol) + if e.head === :call + return inline_call(e, sv, stmts, boundscheck) + elseif e.head === :isdefined isa(e.typ, Const) && return e.typ.val - return e + elseif e.head === :(=) && isa(e.args[2], Expr) + e.args[2] = inline_expr(e.args[2], sv, stmts, boundscheck) + elseif e.head === :return && isa(e.args[1], Expr) + e.args[1] = inline_expr(e.args[1], sv, stmts, boundscheck) end + return e +end + +function finddef(v::SSAValue, stmts::Vector{Any}) + for s in stmts + if isa(s,Expr) && s.head === :(=) && s.args[1] === v + return s + end + end + return nothing +end +# return inlined replacement for call `e`, inserting new needed statements in `stmts`. +function inline_call(e::Expr, sv::InferenceState, stmts::Vector{Any}, boundscheck::Symbol) eargs = e.args if length(eargs) < 1 return e end arg1 = eargs[1] - isccall = false - i0 = 1 - # don't inline first (global) arguments of ccall, as this needs to be evaluated - # by the interpreter and inlining might put in something it can't handle, - # like another ccall (or try to move the variables out into the function) - if e.head === :foreigncall - # 5 is rewritten to 1 below to handle the callee. - i0 = 5 - isccall = true - elseif is_known_call(e, Core.Intrinsics.llvmcall, sv.src, sv.mod) - i0 = 5 - end - has_stmts = false # needed to preserve order-of-execution - prev_stmts_length = length(stmts) - for _i = length(eargs):-1:i0 - if isccall && _i == 5 - i = 1 - isccallee = true - else - i = _i - isccallee = false - end - ei = eargs[i] - if isa(ei,Expr) - ei = ei::Expr - if ei.head === :& - argloc = ei.args - i = 1 - ei = argloc[1] - if !isa(ei,Expr) - continue - end - ei = ei::Expr - else - argloc = eargs - end - sl0 = length(stmts) - res = inlining_pass(ei, sv, stmts, ins, boundscheck) - ns = length(stmts) - sl0 # number of new statements just added - if isccallee - restype = exprtype(res, sv.src, sv.mod) - if isa(restype, Const) - argloc[i] = restype.val - if !effect_free(res, sv.src, sv.mod, false) - insert!(stmts, ins+ns, res) - end - # Assume this is the last argument to process - break - end - end - if has_stmts && !effect_free(res, sv.src, sv.mod, false) - restype = exprtype(res, sv.src, sv.mod) - vnew = newvar!(sv, restype) - argloc[i] = vnew - insert!(stmts, ins+ns, Expr(:(=), vnew, res)) - else - argloc[i] = res - end - if !has_stmts && ns > 0 && !(_i == i0) - for s = ins:ins+ns-1 - stmt = stmts[s] - if !effect_free(stmt, sv.src, sv.mod, true) - has_stmts = true; break - end - end - end - end - end - if isccall - le = length(eargs) - nccallargs = eargs[5]::Int - ccallargs = ObjectIdDict() - for i in 6:(5 + nccallargs) - ccallargs[eargs[i]] = nothing - end - i = 6 + nccallargs - while i <= le - rootarg = eargs[i] - if haskey(ccallargs, rootarg) - deleteat!(eargs, i) - le -= 1 - elseif i < le - ccallargs[rootarg] = nothing - end - i += 1 - end - end - if e.head !== :call - return e - end ft = exprtype(arg1, sv.src, sv.mod) if isa(ft, Const) @@ -5122,8 +5051,10 @@ function inlining_pass(e::Expr, sv::InferenceState, stmts::Vector{Any}, ins, bou end end - ins += (length(stmts) - prev_stmts_length) + ins = 1 + # TODO: determine whether this is really necessary +#= if sv.params.inlining if isdefined(Main, :Base) && ((isdefined(Main.Base, :^) && f === Main.Base.:^) || @@ -5158,7 +5089,7 @@ function inlining_pass(e::Expr, sv::InferenceState, stmts::Vector{Any}, ins, bou end end end - +=# for ninline = 1:100 ata = Vector{Any}(length(e.args)) ata[1] = ft @@ -5196,18 +5127,15 @@ function inlining_pass(e::Expr, sv::InferenceState, stmts::Vector{Any}, ins, bou aarg = e.args[i] argt = exprtype(aarg, sv.src, sv.mod) t = widenconst(argt) - if isa(aarg, Expr) && (is_known_call(aarg, tuple, sv.src, sv.mod) || is_known_call(aarg, svec, sv.src, sv.mod)) - # apply(f, tuple(x, y, ...)) => f(x, y, ...) - newargs[i - 2] = aarg.args[2:end] - elseif isa(argt, Const) && (isa(argt.val, Tuple) || isa(argt.val, SimpleVector)) && + if isa(argt, Const) && (isa(argt.val, Tuple) || isa(argt.val, SimpleVector)) && effect_free(aarg, sv.src, sv.mod, true) val = argt.val - newargs[i - 2] = Any[ QuoteNode(val[i]) for i in 1:(length(val)::Int) ] # avoid making a tuple Generator here! + newargs[i - 2] = Any[ quoted(val[i]) for i in 1:(length(val)::Int) ] # avoid making a tuple Generator here! elseif isa(aarg, Tuple) || (isa(aarg, QuoteNode) && (isa(aarg.value, Tuple) || isa(aarg.value, SimpleVector))) if isa(aarg, QuoteNode) aarg = aarg.value end - newargs[i - 2] = Any[ QuoteNode(aarg[i]) for i in 1:(length(aarg)::Int) ] # avoid making a tuple Generator here! + newargs[i - 2] = Any[ quoted(aarg[i]) for i in 1:(length(aarg)::Int) ] # avoid making a tuple Generator here! elseif isa(t, DataType) && t.name === Tuple.name && !isvatuple(t) && length(t.parameters) <= sv.params.MAX_TUPLE_SPLAT for k = (effect_free_upto + 1):(i - 3) @@ -5235,7 +5163,28 @@ function inlining_pass(e::Expr, sv::InferenceState, stmts::Vector{Any}, ins, bou else tp = t.parameters end - newargs[i - 2] = Any[ mk_getfield(tmpv, j, tp[j]) for j in 1:(length(tp)::Int) ] + ntp = length(tp)::Int + if isa(aarg,SSAValue) && any(p->(p === DataType || p === UnionAll || p === Union), tp) + # replace element type from Tuple{DataType} with more specific type if possible + def = finddef(aarg, sv.src.code) + if def !== nothing + defex = def.args[2] + if isa(defex, Expr) && is_known_call(defex, tuple, sv.src, sv.mod) + tp = collect(Any, tp) + for j = 1:ntp + specific_type = exprtype(defex.args[j+1], sv.src, sv.mod) + if iskindtype(tp[j]) && specific_type ⊑ tp[j] + tp[j] = specific_type + end + end + end + end + end + fldvars = Any[ newvar!(sv, tp[j]) for j in 1:ntp ] + for j = 1:ntp + push!(newstmts, Expr(:(=), fldvars[j], mk_getfield(tmpv, j, tp[j]))) + end + newargs[i - 2] = fldvars else # not all args expandable return e @@ -5376,10 +5325,10 @@ function get_replacement(table::ObjectIdDict, var::Union{SlotNumber, SSAValue}, if isa(init, Expr) && init.head === :static_parameter # if we aren't certain about the type, it might be an UndefVarError at runtime (!effect_free) # so we need to preserve the original point of assignment - if isa(init.typ, DataType) && isleaftype(init.typ) + if (isa(init.typ, DataType) && isleaftype(init.typ)) || isa(init.typ, Const) return init end - elseif isa(init, corenumtype) || init === () || init === nothing + elseif isa(init, Number) || init === () || init === nothing || isa(init, Type) || isa(init, Char) || isa(init, IntrinsicFunction) || isa(init, Builtin) return init elseif isa(init, Slot) && is_argument(nargs, init::Slot) # the transformation is not ideal if the assignment @@ -5413,6 +5362,10 @@ function get_replacement(table::ObjectIdDict, var::Union{SlotNumber, SSAValue}, end return rep end + elseif isa(init, GlobalRef) + if isdefined(init.mod, init.name) && isconst(init.mod, init.name) + return init + end end return var end @@ -5554,6 +5507,9 @@ function void_use_elim_pass!(sv::InferenceState) if h === :return || h === :(=) || h === :gotoifnot || is_meta_expr_head(h) return true end + if h === :isdefined + return false + end return !effect_free(ex, sv.src, sv.mod, false) elseif (isa(ex, GotoNode) || isa(ex, LineNumberNode) || isa(ex, NewvarNode) || isa(ex, Symbol) || isa(ex, LabelNode)) @@ -5717,11 +5673,7 @@ function _getfield_elim_pass!(e::Expr, sv::InferenceState) end end if isdefined(e1, j) - e1j = getfield(e1, j) - if !is_self_quoting(e1j) - e1j = QuoteNode(e1j) - end - return e1j + return quoted(getfield(e1, j)) end end end diff --git a/base/precompile.jl b/base/precompile.jl index 86bcc216f32c8..440a4c46a7c38 100644 --- a/base/precompile.jl +++ b/base/precompile.jl @@ -866,7 +866,6 @@ precompile(Tuple{typeof(Core.Inference._widen_all_consts!), Expr, Array{Bool, 1} precompile(Tuple{typeof(Core.Inference._delete!), Core.Inference.IntSet, Int64}) precompile(Tuple{typeof(Core.Inference.promote_type), Type{Float16}, Type{Int64}}) precompile(Tuple{typeof(Core.Inference.mk_tuplecall), Array{Any, 1}, Core.Inference.InferenceState}) -precompile(Tuple{typeof(Core.Inference.inlining_pass), Expr, Core.Inference.InferenceState, Array{Any, 1}, Int64}) precompile(Tuple{typeof(Core.Inference.annotate_slot_load!), Expr, Array{Any, 1}, Core.Inference.InferenceState, Array{Bool, 1}}) precompile(Tuple{typeof(Core.Inference.record_slot_assign!), Core.Inference.InferenceState}) precompile(Tuple{typeof(Core.Inference.type_annotate!), Core.Inference.InferenceState}) diff --git a/base/repl/REPLCompletions.jl b/base/repl/REPLCompletions.jl index 1e380b7150149..d7b9a6c96b25e 100644 --- a/base/repl/REPLCompletions.jl +++ b/base/repl/REPLCompletions.jl @@ -316,8 +316,7 @@ function get_type_call(expr::Expr) end # Returns the return type. example: get_type(:(Base.strip("", ' ')), Main) returns (String, true) -function get_type(sym::Expr, fn::Module) - sym = expand(fn, sym) +function try_get_type(sym::Expr, fn::Module) val, found = get_value(sym, fn) found && return Base.typesof(val).parameters[1], found if sym.head === :call @@ -330,9 +329,28 @@ function get_type(sym::Expr, fn::Module) return found ? Base.typesof(val).parameters[1] : Any, found end return get_type_call(sym) + elseif sym.head === :thunk + thk = sym.args[1] + rt = ccall(:jl_infer_thunk, Any, (Any, Any), thk::CodeInfo, fn) + rt !== Any && return (rt, true) + elseif sym.head === :ref + # some simple cases of `expand` + return try_get_type(Expr(:call, GlobalRef(Base, :getindex), sym.args...), fn) + elseif sym.head === :. + return try_get_type(Expr(:call, GlobalRef(Core, :getfield), sym.args...), fn) end return (Any, false) end + +try_get_type(other, fn::Module) = get_type(other, fn) + +function get_type(sym::Expr, fn::Module) + # try to analyze nests of calls. if this fails, try using the expanded form. + val, found = try_get_type(sym, fn) + found && return val, found + return try_get_type(expand(fn, sym), fn) +end + function get_type(sym, fn::Module) val, found = get_value(sym, fn) return found ? Base.typesof(val).parameters[1] : Any, found diff --git a/src/ast.scm b/src/ast.scm index 18d389cc46037..0ded4b57ca57d 100644 --- a/src/ast.scm +++ b/src/ast.scm @@ -234,7 +234,9 @@ (ssavalue? e))) (define (simple-atom? x) - (or (number? x) (string? x) (char? x) (eq? x 'true) (eq? x 'false))) + (or (number? x) (string? x) (char? x) (eq? x 'true) (eq? x 'false) + (and (pair? x) (memq (car x) '(ssavalue null))) + (eq? (typeof x) 'julia_value))) ;; identify some expressions that are safe to repeat (define (effect-free? e) diff --git a/src/julia-syntax.scm b/src/julia-syntax.scm index 7853a213c01d8..62b53ff8b084b 100644 --- a/src/julia-syntax.scm +++ b/src/julia-syntax.scm @@ -2891,7 +2891,7 @@ f(x) = yt(x) (capt (and vi (vinfo:asgn vi) (vinfo:capt vi)))) (if (and (not closed) (not capt) (equal? vt '(core Any))) `(= ,var ,rhs0) - (let* ((rhs1 (if (or (ssavalue? rhs0) (simple-atom? rhs0) + (let* ((rhs1 (if (or (simple-atom? rhs0) (equal? rhs0 '(the_exception))) rhs0 (make-ssavalue))) @@ -3370,7 +3370,7 @@ f(x) = yt(x) ;; pass 5: convert to linear IR ;; with this enabled, all nested calls are assigned to numbered locations -(define *very-linear-mode* #f) +(define *very-linear-mode* #t) (define (linearize e) (cond ((or (not (pair? e)) (quoted? e)) e) @@ -3386,6 +3386,22 @@ f(x) = yt(x) (error msg) (io.write *stderr* msg))) +(define (valid-ir-argument? e) + (or (simple-atom? e) (symbol? e) + (and (pair? e) + (memq (car e) '(quote inert top core globalref outerref + slot static_parameter boundscheck copyast))))) + +(define (valid-ir-rvalue? lhs e) + (or (ssavalue? lhs) + (valid-ir-argument? e) + (and (symbol? lhs) (pair? e) + (memq (car e) '(new the_exception isdefined call invoke foreigncall gc_preserve_begin))))) + +(define (valid-ir-return? e) + ;; returning lambda directly is needed for @generated + (or (valid-ir-rvalue? 'return e) (and (pair? e) (memq (car e) '(new lambda))))) + ;; this pass behaves like an interpreter on the given code. ;; to perform stateful operations, it calls `emit` to record that something ;; needs to be done. in value position, it returns an expression computing @@ -3436,18 +3452,21 @@ f(x) = yt(x) (if rett (convert-for-type-decl y rett) y)) + (define (actually-return x) + (let ((tmp (if (valid-ir-return? x) #f (make-ssavalue)))) + (if tmp (emit `(= ,tmp ,x))) + (emit `(return ,(or tmp x))))) (if (> handler-level 0) - (let ((tmp (cond ((or (simple-atom? x) (ssavalue? x) (equal? x '(null))) - #f) + (let ((tmp (cond ((simple-atom? x) #f) (finally-handler (new-mutable-var)) - (else (make-ssavalue))))) + (else (make-ssavalue))))) (if tmp (emit `(= ,tmp ,x))) (if finally-handler - (leave-finally-block `(return ,(converted (or tmp x)))) + (leave-finally-block `(return ,(or tmp x))) (begin (emit `(leave ,handler-level)) - (emit `(return ,(converted (or tmp x)))))) + (actually-return (converted (or tmp x))))) (or tmp x)) - (emit `(return ,(converted x))))) + (actually-return (converted x)))) (define (emit-break labl) (let ((lvl (caddr labl))) (if (and finally-handler (> (cadddr finally-handler) lvl)) @@ -3469,7 +3488,7 @@ f(x) = yt(x) (and (pair? x) (eq? (car x) 'block)))) e)) (cdr lst)))) - (simple? (every (lambda (x) (or (simple-atom? x) (symbol? x) (ssavalue? x) + (simple? (every (lambda (x) (or (simple-atom? x) (symbol? x) (and (pair? x) (memq (car x) '(quote inert top core globalref outerref copyast))))) lst))) @@ -3481,10 +3500,12 @@ f(x) = yt(x) (aval (compile arg break-labels #t #f linearize))) (loop (cdr lst) (cons (if (and temps? linearize (not simple?) - (not (simple-atom? arg)) (not (ssavalue? arg)) - (not (simple-atom? aval)) (not (ssavalue? aval)) + (not (simple-atom? arg)) + (not (simple-atom? aval)) (not (and (pair? arg) (memq (car arg) '(& quote inert top core globalref outerref copyast)))) + (not (and (symbol? aval) ;; function args are immutable and always assigned + (memq aval (lam:args lam)))) (not (and (symbol? arg) (or (null? (cdr lst)) (null? vals))))) @@ -3496,11 +3517,18 @@ f(x) = yt(x) (define (compile-cond ex break-labels) (let ((cnd (compile ex break-labels #t #f))) (if (and *very-linear-mode* - (not (or (simple-atom? cnd) (ssavalue? cnd) (symbol? cnd)))) + (not (valid-ir-argument? cnd))) (let ((tmp (make-ssavalue))) (emit `(= ,tmp ,cnd)) tmp) cnd))) + (define (emit-assignment lhs rhs) + (if (valid-ir-rvalue? lhs rhs) + (emit `(= ,lhs ,rhs)) + (let ((rr (make-ssavalue))) + (emit `(= ,rr ,rhs)) + (emit `(= ,lhs ,rr)))) + `(null)) ;; the interpreter loop. `break-labels` keeps track of the labels to jump to ;; for all currently closing break-blocks. ;; `value` means we are in a context where a value is required; a meaningful @@ -3555,6 +3583,10 @@ f(x) = yt(x) (equal? (cadr e) '(outerref cglobal)))) (list* (cadr e) (caddr e) (compile-args (cdddr e) break-labels linearize-args))) + ((and (length> e 2) + (or (eq? (cadr e) 'llvmcall) + (equal? (cadr e) '(outerref llvmcall)))) + (cdr e)) (else (compile-args (cdr e) break-labels linearize-args)))) (callex (cons (car e) args))) @@ -3577,7 +3609,7 @@ f(x) = yt(x) (emit `(= ,lhs ,rr)) (if tail (emit-return rr)) rr) - (emit `(= ,lhs ,rhs))))) + (emit-assignment lhs rhs)))) ((block body) (let* ((last-fname filename) (fnm (first-non-meta e)) @@ -3632,14 +3664,14 @@ f(x) = yt(x) (val (if (and value (not tail)) (new-mutable-var) #f))) (emit test) (let ((v1 (compile (caddr e) break-labels value tail))) - (if val (emit `(= ,val ,v1))) + (if val (emit-assignment val v1)) (if (and (not tail) (or (length> e 3) val)) (emit end-jump)) (set-car! (cddr test) (make&mark-label)) (let ((v2 (if (length> e 3) (compile (cadddr e) break-labels value tail) '(null)))) - (if val (emit `(= ,val ,v2))) + (if val (emit-assignment val v2)) (if (not tail) (set-car! (cdr end-jump) (make&mark-label)) (if (length= e 3) @@ -3707,7 +3739,7 @@ f(x) = yt(x) (let* ((v1 (compile (cadr e) break-labels value #f)) (val (if (and value (not tail)) (new-mutable-var) #f))) - (if val (emit `(= ,val ,v1))) + (if val (emit-assignment val v1)) (if tail (begin (emit-return v1) (if (not finally) (set! endl #f))) @@ -3722,7 +3754,7 @@ f(x) = yt(x) 'ccall 1 ,finally-exception) #f)) (let ((v2 (compile (caddr e) break-labels value tail))) - (if val (emit `(= ,val ,v2))))) + (if val (emit-assignment val v2)))) (if endl (mark-label endl)) (if finally (begin (set! finally-handler last-finally-handler) @@ -3734,7 +3766,9 @@ f(x) = yt(x) #f (make-label)))) (if skip - (emit `(gotoifnot (call (core ===) ,finally ,(caar actions)) ,skip))) + (let ((tmp (make-ssavalue))) + (emit `(= ,tmp (call (core ===) ,finally ,(caar actions)))) + (emit `(gotoifnot ,tmp ,skip)))) (let ((ac (cdar actions))) (cond ((eq? (car ac) 'return) (emit-return (cadr ac))) ((eq? (car ac) 'break) (emit-break (cadr ac))) @@ -3747,7 +3781,9 @@ f(x) = yt(x) ((method) (if (length> e 2) (begin (emit `(method ,(or (cadr e) 'false) - ,(compile (caddr e) break-labels #t #f) + ,(with-bindings + ((*very-linear-mode* #f)) + (compile (caddr e) break-labels #t #f)) ,(linearize (cadddr e)) ,(if (car (cddddr e)) 'true 'false))) (if value (compile '(null) break-labels value tail))) diff --git a/src/method.c b/src/method.c index 40c027254743d..73fcf477066cb 100644 --- a/src/method.c +++ b/src/method.c @@ -16,6 +16,7 @@ extern "C" { #endif extern jl_value_t *jl_builtin_getfield; +extern jl_value_t *jl_builtin_tuple; jl_value_t *jl_resolve_globals(jl_value_t *expr, jl_module_t *module, jl_svec_t *sparam_vals) { if (jl_is_symbol(expr)) { @@ -38,42 +39,60 @@ jl_value_t *jl_resolve_globals(jl_value_t *expr, jl_module_t *module, jl_svec_t // ignore these } else { - if (e->head == call_sym && jl_expr_nargs(e) == 3 && - jl_is_quotenode(jl_exprarg(e, 2)) && module != NULL) { - // replace getfield(module_expr, :sym) with GlobalRef - jl_value_t *s = jl_fieldref(jl_exprarg(e, 2), 0); + size_t nargs = jl_expr_nargs(e); + if (e->head == call_sym && nargs > 0) { jl_value_t *fe = jl_exprarg(e, 0); - if (jl_is_symbol(s) && jl_is_globalref(fe)) { + if (jl_is_globalref(fe) && jl_binding_resolved_p(jl_globalref_mod(fe), jl_globalref_name(fe))) { + // look at some known called functions jl_binding_t *b = jl_get_binding(jl_globalref_mod(fe), jl_globalref_name(fe)); - jl_value_t *f = NULL; - if (b && b->constp) { - f = b->value; - } - if (f == jl_builtin_getfield) { - jl_value_t *me = jl_exprarg(e, 1); - jl_module_t *me_mod = NULL; - jl_sym_t *me_sym = NULL; - if (jl_is_globalref(me)) { - me_mod = jl_globalref_mod(me); - me_sym = jl_globalref_name(me); + jl_value_t *f = b && b->constp ? b->value : NULL; + if (f == jl_builtin_getfield && nargs == 3 && + jl_is_quotenode(jl_exprarg(e, 2)) && module != NULL) { + // replace getfield(module_expr, :sym) with GlobalRef + jl_value_t *s = jl_fieldref(jl_exprarg(e, 2), 0); + if (jl_is_symbol(s)) { + jl_value_t *me = jl_exprarg(e, 1); + jl_module_t *me_mod = NULL; + jl_sym_t *me_sym = NULL; + if (jl_is_globalref(me)) { + me_mod = jl_globalref_mod(me); + me_sym = jl_globalref_name(me); + } + else if (jl_is_symbol(me) && jl_binding_resolved_p(module, (jl_sym_t*)me)) { + me_mod = module; + me_sym = (jl_sym_t*)me; + } + if (me_mod && me_sym) { + jl_binding_t *b = jl_get_binding(me_mod, me_sym); + if (b && b->constp) { + jl_value_t *m = b->value; + if (m && jl_is_module(m)) { + return jl_module_globalref((jl_module_t*)m, (jl_sym_t*)s); + } + } + } } - else if (jl_is_symbol(me) && jl_binding_resolved_p(module, (jl_sym_t*)me)) { - me_mod = module; - me_sym = (jl_sym_t*)me; + } + else if (f == jl_builtin_tuple) { + size_t j; + for (j = 1; j < nargs; j++) { + if (!jl_is_quotenode(jl_exprarg(e,j))) + break; } - if (me_mod && me_sym) { - jl_binding_t *b = jl_get_binding(me_mod, me_sym); - if (b && b->constp) { - jl_value_t *m = b->value; - if (m && jl_is_module(m)) { - return jl_module_globalref((jl_module_t*)m, (jl_sym_t*)s); - } + if (j == nargs) { + jl_value_t *val = NULL; + JL_TRY { + val = jl_interpret_toplevel_expr_in(module, (jl_value_t*)e, NULL, sparam_vals); + } + JL_CATCH { } + if (val) + return val; } } } } - size_t i = 0, nargs = jl_array_len(e->args); + size_t i = 0; if (e->head == foreigncall_sym) { JL_NARGSV(ccall method definition, 5); // (fptr, rt, at, cc, narg) jl_value_t *rt = jl_exprarg(e, 1); diff --git a/src/toplevel.c b/src/toplevel.c index 7da07e3a1f55d..cd267d9fb2cf4 100644 --- a/src/toplevel.c +++ b/src/toplevel.c @@ -658,6 +658,15 @@ jl_value_t *jl_toplevel_eval_flex(jl_module_t *m, jl_value_t *e, int fast, int e return result; } +JL_DLLEXPORT jl_value_t *jl_infer_thunk(jl_code_info_t *thk, jl_module_t *m) +{ + jl_method_instance_t *li = jl_new_thunk(thk, m); + JL_GC_PUSH1(&li); + jl_type_infer(&li, jl_get_ptls_states()->world_age, 0); + JL_GC_POP(); + return li->rettype; +} + JL_DLLEXPORT jl_value_t *jl_toplevel_eval(jl_module_t *m, jl_value_t *v) { return jl_toplevel_eval_flex(m, v, 1, 0); diff --git a/test/codevalidation.jl b/test/codevalidation.jl index 511585f5477c2..cb4d947e92d86 100644 --- a/test/codevalidation.jl +++ b/test/codevalidation.jl @@ -22,7 +22,7 @@ c0 = Core.Inference.retrieve_code_info(mi) @testset "INVALID_EXPR_HEAD" begin c = Core.Inference.copy_code_info(c0) - insert!(c.code, 4, Expr(:(=), SlotNumber(2), Expr(:invalid, 1))) + insert!(c.code, 4, Expr(:invalid, 1)) errors = Core.Inference.validate_code(c) @test length(errors) == 1 @test errors[1].kind === Core.Inference.INVALID_EXPR_HEAD @@ -47,25 +47,21 @@ end push!(c.code, Expr(:(=), SlotNumber(2), Expr(h))) end errors = Core.Inference.validate_code(c) - @test length(errors) == 10 + @test length(errors) == 7 @test count(e.kind === Core.Inference.INVALID_RVALUE for e in errors) == 7 - @test count(e.kind === Core.Inference.INVALID_EXPR_NARGS for e in errors) == 2 - @test count(e.kind === Core.Inference.INVALID_EXPR_HEAD for e in errors) == 1 end -@testset "INVALID_CALL_ARG/INVALID_EXPR_NARGS" begin +@testset "INVALID_CALL_ARG" begin c = Core.Inference.copy_code_info(c0) - insert!(c.code, 2, Expr(:(=), SlotNumber(2), Expr(:call, :+, SlotNumber(2), GotoNode(1)))) - insert!(c.code, 4, Expr(:call, :-, Expr(:call, :sin, LabelNode(2)), 3)) + insert!(c.code, 2, Expr(:(=), SlotNumber(2), Expr(:call, GlobalRef(Base,:+), SlotNumber(2), GotoNode(1)))) + insert!(c.code, 4, Expr(:call, GlobalRef(Base,:-), Expr(:call, GlobalRef(Base,:sin), LabelNode(2)), 3)) insert!(c.code, 10, Expr(:call, LineNumberNode(2))) for h in (:gotoifnot, :line, :const, :meta) - push!(c.code, Expr(:call, :f, Expr(h))) + push!(c.code, Expr(:call, GlobalRef(@__MODULE__,:f), Expr(h))) end errors = Core.Inference.validate_code(c) - @test length(errors) == 10 + @test length(errors) == 7 @test count(e.kind === Core.Inference.INVALID_CALL_ARG for e in errors) == 7 - @test count(e.kind === Core.Inference.INVALID_EXPR_NARGS for e in errors) == 2 - @test count(e.kind === Core.Inference.INVALID_EXPR_HEAD for e in errors) == 1 end @testset "EMPTY_SLOTNAMES" begin diff --git a/test/inference.jl b/test/inference.jl index cc1f51411450e..2fa7e6fd22e14 100644 --- a/test/inference.jl +++ b/test/inference.jl @@ -896,7 +896,7 @@ let f, m m.source.ssavaluetypes = 1 m.source.code = Any[ Expr(:(=), SSAValue(0), Expr(:call, GlobalRef(Core, :svec), 1, 2, 3)), - Expr(:return, Expr(:call, Core._apply, :+, SSAValue(0))) + Expr(:return, Expr(:call, Core._apply, GlobalRef(Base, :+), SSAValue(0))) ] @test @inferred(f()) == 6 end diff --git a/test/replcompletions.jl b/test/replcompletions.jl index 0f19d2384f241..505887f8d3135 100644 --- a/test/replcompletions.jl +++ b/test/replcompletions.jl @@ -377,14 +377,11 @@ let s = "CompletionFoo.test4(\"e\",r\" \"," @test s[r] == "CompletionFoo.test4" end -# (As discussed in #19829, the Base.REPLCompletions.get_type function isn't -# powerful enough to analyze general dot calls because it can't handle -# anonymous-function evaluation.) let s = "CompletionFoo.test5(push!(Base.split(\"\",' '),\"\",\"\").==\"\"," c, r, res = test_complete(s) @test !res - @test_broken length(c) == 1 - @test_broken c[1] == string(first(methods(Main.CompletionFoo.test5, Tuple{BitArray{1}}))) + @test length(c) == 1 + @test c[1] == string(first(methods(Main.CompletionFoo.test5, Tuple{BitArray{1}}))) end let s = "CompletionFoo.test4(CompletionFoo.test_y_array[1]()[1], CompletionFoo.test_y_array[1]()[2], "