diff --git a/base/Base.jl b/base/Base.jl index 057c512887c6c..71229a1aba48c 100644 --- a/base/Base.jl +++ b/base/Base.jl @@ -222,6 +222,9 @@ using .Libc: getpid, gethostname, time include("env.jl") +# YAKC +include("yakc.jl") + # Concurrency include("linked_list.jl") include("condition.jl") diff --git a/base/boot.jl b/base/boot.jl index e653a82399ba5..f531e3215dcd1 100644 --- a/base/boot.jl +++ b/base/boot.jl @@ -785,4 +785,14 @@ Integer(x::Union{Float32, Float64}) = Int(x) # The internal jl_parse which will call into Core._parse if not `nothing`. _parse = nothing +# YAKC Definition +#= +struct YAKC{A <: Tuple, R} + env::Any + ci::CodeInfo + fptr1 + fptr +end +=# + ccall(:jl_set_istopmod, Cvoid, (Any, Bool), Core, true) diff --git a/base/compiler/abstractinterpretation.jl b/base/compiler/abstractinterpretation.jl index d4f92a3b8176f..79e4163e802fd 100644 --- a/base/compiler/abstractinterpretation.jl +++ b/base/compiler/abstractinterpretation.jl @@ -89,7 +89,7 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f), edges = Any[] nonbot = 0 # the index of the only non-Bottom inference result if > 0 seen = 0 # number of signatures actually inferred - istoplevel = sv.linfo.def isa Module + istoplevel = sv.linfo !== nothing && sv.linfo.def isa Module multiple_matches = napplicable > 1 if f !== nothing && napplicable == 1 && is_method_pure(applicable[1]::MethodMatch) @@ -337,7 +337,7 @@ function abstract_call_method(interp::AbstractInterpreter, method::Method, @nosp sv_method2 isa Method || (sv_method2 = nothing) # Union{Method, Nothing} while !(infstate === nothing) infstate = infstate::InferenceState - if method === infstate.linfo.def + if infstate.linfo !== nothing && method === infstate.linfo.def if infstate.linfo.specTypes == sig # avoid widening when detecting self-recursion # TODO: merge call cycle and return right away @@ -773,6 +773,24 @@ function argtype_tail(argtypes::Vector{Any}, i::Int) return argtypes[i:n] end +function 
_yakc_tfunc(@nospecialize(arg), @nospecialize(lb), @nospecialize(ub), + @nospecialize(env), @nospecialize(ci), linfo::MethodInstance) + argt, argt_exact = instanceof_tfunc(arg) + lbt, lb_exact = instanceof_tfunc(lb) + if !lb_exact + lbt = Union{} + end + + ubt, ub_exact = instanceof_tfunc(ub) + + t = Core.YAKC{argt_exact ? argt : <:argt} + t = t{(lbt == ubt && ub_exact) ? ubt : T} where lbt<:T<:ubt + + isa(ci, Const) || return t + + PartialYAKC(t, env, linfo, ci.val) +end + function abstract_call_builtin(interp::AbstractInterpreter, f::Builtin, fargs::Union{Nothing,Vector{Any}}, argtypes::Vector{Any}, sv::InferenceState, max_methods::Int) la = length(argtypes) @@ -790,6 +808,10 @@ function abstract_call_builtin(interp::AbstractInterpreter, f::Builtin, fargs::U ty = typeintersect(ty, cnd.elsetype) end return tmerge(tx, ty) + elseif f === Core._yakc + la == 6 || return Union{} + return _yakc_tfunc(argtypes[2], argtypes[3], argtypes[4], + argtypes[5], argtypes[6], sv.linfo) end rt = builtin_tfunction(interp, f, argtypes[2:end], sv) if f === getfield && isa(fargs, Vector{Any}) && la == 3 && isa(argtypes[3], Const) && isa(argtypes[3].val, Int) && argtypes[2] ⊑ Tuple @@ -1003,6 +1025,27 @@ function abstract_call_known(interp::AbstractInterpreter, @nospecialize(f), return abstract_call_gf_by_type(interp, f, argtypes, atype, sv, max_methods) end +function abstract_call_yakc(interp::AbstractInterpreter, yakc::PartialYAKC, argtypes::Vector{Any}, sv::InferenceState) + if isa(yakc.ci, CodeInfo) + nargtypes = argtypes[2:end] + pushfirst!(nargtypes, yakc.env) + result = InferenceResult(Core.YAKC, nargtypes) + state = InferenceState(result, copy(yakc.ci), false, interp) + typeinf_local(interp, state) + finish(state, interp) + yakc.ci = result.src + return CallMeta(result.result, false) + elseif isa(yakc.ci, OptimizationState) + return CallMeta(yakc.ci.src.rettype, nothing) + else + nargtypes = argtypes[2:end] + pushfirst!(nargtypes, Core.YAKC) + sig = argtypes_to_type(nargtypes) 
+ rt, edge = abstract_call_method(interp, yakc.ci::Method, sig, Core.svec(), false, sv) + return CallMeta(rt, edge) + end +end + # call where the function is any lattice element function abstract_call(interp::AbstractInterpreter, fargs::Union{Nothing,Vector{Any}}, argtypes::Vector{Any}, sv::InferenceState, max_methods::Int = InferenceParams(interp).MAX_METHODS) @@ -1014,6 +1057,8 @@ function abstract_call(interp::AbstractInterpreter, fargs::Union{Nothing,Vector{ f = ft.parameters[1] elseif isa(ft, DataType) && isdefined(ft, :instance) f = ft.instance + elseif isa(ft, PartialYAKC) + return abstract_call_yakc(interp, ft, argtypes, sv) else # non-constant function, but the number of arguments is known # and the ft is not a Builtin or IntrinsicFunction @@ -1326,14 +1371,14 @@ function typeinf_local(interp::AbstractInterpreter, frame::InferenceState) elseif isa(stmt, ReturnNode) pc´ = n + 1 rt = widenconditional(abstract_eval_value(interp, stmt.val, s[pc], frame)) - if !isa(rt, Const) && !isa(rt, Type) && !isa(rt, PartialStruct) + if !isa(rt, Const) && !isa(rt, Type) && !isa(rt, PartialStruct) && !isa(rt, PartialYAKC) # only propagate information we know we can store # and is valid inter-procedurally rt = widenconst(rt) end if tchanged(rt, frame.bestguess) # new (wider) return type for frame - frame.bestguess = tmerge(frame.bestguess, rt) + frame.bestguess = frame.bestguess === NOT_FOUND ? 
rt : tmerge(frame.bestguess, rt) for (caller, caller_pc) in frame.cycle_backedges # notify backedges of updated type information typeassert(caller.stmt_types[caller_pc], VarTable) # we must have visited this statement before diff --git a/base/compiler/inferencestate.jl b/base/compiler/inferencestate.jl index 2d5fce04c0454..915db7b0b4894 100644 --- a/base/compiler/inferencestate.jl +++ b/base/compiler/inferencestate.jl @@ -5,7 +5,7 @@ const LineNum = Int mutable struct InferenceState params::InferenceParams result::InferenceResult # remember where to put the result - linfo::MethodInstance + linfo::Union{MethodInstance, Nothing} sptypes::Vector{Any} # types of static parameter slottypes::Vector{Any} mod::Module @@ -37,6 +37,8 @@ mutable struct InferenceState callers_in_cycle::Vector{InferenceState} parent::Union{Nothing, InferenceState} + has_yakcs::Bool + # TODO: move these to InferenceResult / Params? cached::Bool limited::Bool @@ -57,9 +59,23 @@ mutable struct InferenceState cached::Bool, interp::AbstractInterpreter) linfo = result.linfo code = src.code::Array{Any,1} - toplevel = !isa(linfo.def, Method) - sp = sptypes_from_meth_instance(linfo::MethodInstance) + if !isa(linfo, Nothing) + toplevel = !isa(linfo.def, Method) + sp = sptypes_from_meth_instance(linfo::MethodInstance) + if !toplevel + meth = linfo.def + inmodule = meth.module + else + inmodule = linfo.def::Module + end + else + linfo = nothing + toplevel = true + inmodule = Core + sp = Any[] + end + code = src.code::Array{Any,1} nssavalues = src.ssavaluetypes::Int src.ssavaluetypes = Any[ NOT_FOUND for i = 1:nssavalues ] @@ -93,13 +109,6 @@ mutable struct InferenceState W = BitSet() push!(W, 1) #initial pc to visit - if !toplevel - meth = linfo.def - inmodule = meth.module - else - inmodule = linfo.def::Module - end - valid_worlds = WorldRange(src.min_world, src.max_world == typemax(UInt) ? 
get_world_counter() : src.max_world) frame = new( @@ -112,7 +121,7 @@ mutable struct InferenceState ssavalue_uses, throw_blocks, Vector{Tuple{InferenceState,LineNum}}(), # cycle_backedges Vector{InferenceState}(), # callers_in_cycle - #=parent=#nothing, + #=parent=#nothing, #= has_yakcs =# false, cached, false, false, false, CachedMethodTable(method_table(interp)), interp) @@ -242,6 +251,7 @@ end # temporarily accumulate our edges to later add as backedges in the callee function add_backedge!(li::MethodInstance, caller::InferenceState) + caller.linfo !== nothing || return # don't add backends to yakcs isa(caller.linfo.def, Method) || return # don't add backedges to toplevel exprs if caller.stmt_edges[caller.currpc] === nothing caller.stmt_edges[caller.currpc] = [] diff --git a/base/compiler/optimize.jl b/base/compiler/optimize.jl index d53b8193e639a..a89af8c607e45 100644 --- a/base/compiler/optimize.jl +++ b/base/compiler/optimize.jl @@ -34,7 +34,7 @@ struct InliningState{S <: Union{EdgeTracker, Nothing}, T <: Union{InferenceCache end mutable struct OptimizationState - linfo::MethodInstance + linfo::Union{MethodInstance, Nothing} src::CodeInfo stmt_info::Vector{Any} mod::Module @@ -311,6 +311,13 @@ function statement_cost(ex::Expr, line::Int, src::CodeInfo, sptypes::Vector{Any} ftyp = argextype(farg, src, sptypes, slottypes) end end + # Give calls to YAKCs zero cost. The plan is for these to be a single + # indirect call so have very little cost. On the other hand, there + # is enormous benefit to inlining these into a function where we can + # see the definition of the YAKC. 
Perhaps this should even be negative + if widenconst(ftyp) <: Core.YAKC + return 0 + end f = singleton_type(ftyp) if isa(f, IntrinsicFunction) iidx = Int(reinterpret(Int32, f::IntrinsicFunction)) + 1 diff --git a/base/compiler/ssair/driver.jl b/base/compiler/ssair/driver.jl index 465102e82e155..391467edcd539 100644 --- a/base/compiler/ssair/driver.jl +++ b/base/compiler/ssair/driver.jl @@ -33,7 +33,24 @@ function normalize(@nospecialize(stmt), meta::Vector{Any}) return stmt end -function convert_to_ircode(ci::CodeInfo, code::Vector{Any}, coverage::Bool, nargs::Int, sv::OptimizationState) +function add_yakc_argtypes!(argtypes, t) + dt = unwrap_unionall(t) + dt1 = unwrap_unionall(dt.parameters[1]) + if isa(dt1, TypeVar) || isa(dt1.parameters[1], TypeVar) + push!(argtypes, Any) + else + TT = dt1.parameters[1] + if isa(TT, Union) + TT = tuplemerge(TT.a, TT.b) + end + for p in TT.parameters + push!(argtypes, rewrap_unionall(p, t)) + end + end +end + + +function convert_to_ircode(ci::CodeInfo, code::Vector{Any}, coverage::Bool, nargs::Int, sv::OptimizationState, slottypes=sv.slottypes, stmtinfo=sv.stmt_info) # Go through and add an unreachable node after every # Union{} call. Then reindex labels. idx = 1 @@ -41,7 +58,8 @@ function convert_to_ircode(ci::CodeInfo, code::Vector{Any}, coverage::Bool, narg changemap = fill(0, length(code)) labelmap = coverage ? 
fill(0, length(code)) : changemap prevloc = zero(eltype(ci.codelocs)) - stmtinfo = sv.stmt_info + stmtinfo = copy(stmtinfo) + yakcs = IRCode[] while idx <= length(code) codeloc = ci.codelocs[idx] if coverage && codeloc != prevloc && codeloc != 0 @@ -57,7 +75,22 @@ function convert_to_ircode(ci::CodeInfo, code::Vector{Any}, coverage::Bool, narg idx += 1 prevloc = codeloc end - if code[idx] isa Expr && ci.ssavaluetypes[idx] === Union{} + stmt = code[idx] + if isexpr(stmt, :(=)) + stmt = stmt.args[2] + end + ssat = ci.ssavaluetypes[idx] + if isa(ssat, PartialYAKC) && isexpr(stmt, :call) + ft = argextype(stmt.args[1], ci, sv.sptypes) + # Pre-convert any YAKC objects + if isa(ft, Const) && ft.val === Core._yakc && isa(ssat.ci, OptimizationState) + yakc_ir = make_ir(ssat.ci.src, 0, ssat.ci) + push!(yakcs, yakc_ir) + stmt.head = :new_yakc + push!(stmt.args, length(yakcs)) + end + end + if stmt isa Expr && ssat === Union{} if !(idx < length(code) && isa(code[idx + 1], ReturnNode) && !isdefined((code[idx + 1]::ReturnNode), :val)) # insert unreachable in the same basic block after the current instruction (splitting it) insert!(code, idx + 1, ReturnNode()) @@ -105,7 +138,7 @@ function convert_to_ircode(ci::CodeInfo, code::Vector{Any}, coverage::Bool, narg cfg = compute_basic_blocks(code) types = Any[] stmts = InstructionStream(code, types, stmtinfo, ci.codelocs, flags) - ir = IRCode(stmts, cfg, collect(LineInfoNode, ci.linetable), sv.slottypes, meta, sv.sptypes) + ir = IRCode(stmts, cfg, collect(LineInfoNode, ci.linetable), slottypes, meta, sv.sptypes, yakcs) return ir end @@ -117,24 +150,37 @@ function slot2reg(ir::IRCode, ci::CodeInfo, nargs::Int, sv::OptimizationState) return ir end -function run_passes(ci::CodeInfo, nargs::Int, sv::OptimizationState) - preserve_coverage = coverage_enabled(sv.mod) - ir = convert_to_ircode(ci, copy_exprargs(ci.code), preserve_coverage, nargs, sv) +function compact_all!(ir::IRCode) + length(ir.stmts) == 0 && return ir + for i in 
1:length(ir.yakcs) + ir.yakcs[i] = compact_all!(ir.yakcs[i]) + end + compact!(ir) +end + +function make_ir(ci::CodeInfo, nargs::Int, sv::OptimizationState) + ir = convert_to_ircode(ci, copy_exprargs(ci.code), coverage_enabled(sv.mod), nargs, sv) ir = slot2reg(ir, ci, nargs, sv) + ir +end + +function run_passes(ci::CodeInfo, nargs::Int, sv::OptimizationState) + ir = make_ir(ci, nargs, sv) #@Base.show ("after_construct", ir) # TODO: Domsorting can produce an updated domtree - no need to recompute here @timeit "compact 1" ir = compact!(ir) @timeit "Inlining" ir = ssa_inlining_pass!(ir, ir.linetable, sv.inlining, ci.propagate_inbounds) #@timeit "verify 2" verify_ir(ir) - ir = compact!(ir) + ir = compact_all!(ir) #@Base.show ("before_sroa", ir) @timeit "SROA" ir = getfield_elim_pass!(ir) + ir = yakc_optim_pass!(ir) #@Base.show ir.new_nodes #@Base.show ("after_sroa", ir) ir = adce_pass!(ir) #@Base.show ("after_adce", ir) @timeit "type lift" ir = type_lift_pass!(ir) - @timeit "compact 3" ir = compact!(ir) + @timeit "compact 3" ir = compact_all!(ir) #@Base.show ir if JLOptions().debug_level == 2 @timeit "verify 3" (verify_ir(ir); verify_linetable(ir.linetable)) diff --git a/base/compiler/ssair/inlining.jl b/base/compiler/ssair/inlining.jl index 4d23cc586262c..48abb993aed40 100644 --- a/base/compiler/ssair/inlining.jl +++ b/base/compiler/ssair/inlining.jl @@ -40,7 +40,7 @@ end struct InliningTodo # The MethodInstance to be inlined - mi::MethodInstance + mi::Union{MethodInstance, Nothing} spec::Union{ResolvedInliningSpec, DelayedInliningSpec} end @@ -65,6 +65,9 @@ end function ssa_inlining_pass!(ir::IRCode, linetable::Vector{LineInfoNode}, state::InliningState, propagate_inbounds::Bool) # Go through the function, performing simple ininlingin (e.g. replacing call by constants # and analyzing legality of inlining). 
+ for (idx, ir′) in enumerate(ir.yakcs) + ir.yakcs[idx] = ssa_inlining_pass!(ir′, ir′.linetable, state, propagate_inbounds) + end @timeit "analysis" todo = assemble_inline_todo!(ir, state) isempty(todo) && return ir # Do the actual inlining for every call we identified @@ -309,11 +312,17 @@ function ir_inline_item!(compact::IncrementalCompact, idx::Int, argexprs::Vector push!(linetable, LineInfoNode(entry.module, entry.method, entry.file, entry.line, (entry.inlined_at > 0 ? entry.inlined_at + linetable_offset : inlined_at))) end - nargs_def = item.mi.def.nargs - isva = nargs_def > 0 && item.mi.def.isva - if isva - vararg = mk_tuplecall!(compact, argexprs[nargs_def:end], compact.result[idx][:line]) - argexprs = Any[argexprs[1:(nargs_def - 1)]..., vararg] + sparam_vals = Core.svec() + sig = Tuple + if item.mi !== nothing + nargs_def = item.mi.def.nargs + isva = nargs_def > 0 && item.mi.def.isva + if isva + vararg = mk_tuplecall!(compact, argexprs[nargs_def:end], compact.result[idx][:line]) + argexprs = Any[argexprs[1:(nargs_def - 1)]..., vararg] + end + sig = item.mi.def.sig + sparam_vals = item.mi.sparam_vals end flag = compact.result[idx][:flag] boundscheck_idx = boundscheck @@ -335,7 +344,7 @@ function ir_inline_item!(compact::IncrementalCompact, idx::Int, argexprs::Vector # face of rename_arguments! mutating in place - should figure out # something better eventually. 
inline_compact[idx′] = nothing - stmt′ = ssa_substitute!(idx′, stmt′, argexprs, item.mi.def.sig, item.mi.sparam_vals, linetable_offset, boundscheck_idx, compact) + stmt′ = ssa_substitute!(idx′, stmt′, argexprs, sig, sparam_vals, linetable_offset, boundscheck_idx, compact) if isa(stmt′, ReturnNode) isa(stmt′.val, SSAValue) && (compact.used_ssas[stmt′.val.id] += 1) return_value = SSAValue(idx′) @@ -345,6 +354,8 @@ function ir_inline_item!(compact::IncrementalCompact, idx::Int, argexprs::Vector compact_exprtype(compact, stmt′.val) : compact_exprtype(inline_compact, stmt′.val) break + elseif isexpr(stmt′, :new_yakc) + stmt′ = Expr(:new_yakc, stmt′.args[1], stmt′.args[2], stmt′.args[3], stmt′.args[4] + length(compact.ir.yakcs)) end inline_compact[idx′] = stmt′ end @@ -362,7 +373,7 @@ function ir_inline_item!(compact::IncrementalCompact, idx::Int, argexprs::Vector inline_compact = IncrementalCompact(compact, spec.ir, compact.result_idx) for ((_, idx′), stmt′) in inline_compact inline_compact[idx′] = nothing - stmt′ = ssa_substitute!(idx′, stmt′, argexprs, item.mi.def.sig, item.mi.sparam_vals, linetable_offset, boundscheck_idx, compact) + stmt′ = ssa_substitute!(idx′, stmt′, argexprs, sig, sparam_vals, linetable_offset, boundscheck_idx, compact) if isa(stmt′, ReturnNode) if isdefined(stmt′, :val) val = stmt′.val @@ -382,7 +393,6 @@ function ir_inline_item!(compact::IncrementalCompact, idx::Int, argexprs::Vector push!(pn.values, val) stmt′ = GotoNode(post_bb_id) end - end elseif isa(stmt′, GotoNode) stmt′ = GotoNode(stmt′.label + bb_offset) @@ -392,6 +402,8 @@ function ir_inline_item!(compact::IncrementalCompact, idx::Int, argexprs::Vector stmt′ = GotoIfNot(stmt′.cond, stmt′.dest + bb_offset) elseif isa(stmt′, PhiNode) stmt′ = PhiNode(Int32[edge+bb_offset for edge in stmt′.edges], stmt′.values) + elseif isexpr(stmt′, :new_yakc) + stmt′ = Expr(:new_yakc, stmt′.args[1], stmt′.args[2], stmt′.args[3], stmt′.args[4] + length(compact.ir.yakcs)) end inline_compact[idx′] = stmt′ 
end @@ -410,6 +422,7 @@ function ir_inline_item!(compact::IncrementalCompact, idx::Int, argexprs::Vector return_value = insert_node_here!(compact, pn, compact_exprtype(compact, SSAValue(idx)), compact.result[idx][:line]) end end + append!(compact.ir.yakcs, spec.ir.yakcs) return_value end @@ -533,6 +546,12 @@ function batch_inline!(todo::Vector{Pair{Int, Any}}, ir::IRCode, linetable::Vect for ((old_idx, idx), stmt) in compact if old_idx == inline_idx argexprs = copy(stmt.args) + if isa(item, InliningTodo) && item.mi === nothing + # For yakc, the `self` argument is the object passed as + # the environment + @assert isa(argexprs[1], SSAValue) + argexprs[1] = compact[argexprs[1]].args[5] + end refinish = false if compact.result_idx == first(compact.result_bbs[compact.active_result_bb].stmts) compact.active_result_bb -= 1 @@ -994,6 +1013,19 @@ function inline_invoke!(ir::IRCode, idx::Int, sig::Signature, invoke_data::Invok return nothing end +function narrow_yakc!(stmt, calltype) + if isa(calltype, PartialYAKC) + # Narrow yakc type + if isa(calltype.ci, OptimizationState) + stmt.args[3] = stmt.args[4] = widenconst(calltype.ci.src.rettype) + elseif isa(calltype.ci, Method) + m = calltype.ci + stmt.args[6] = m + stmt.args[3] = stmt.args[4] = widenconst(m.unspecialized.cache.rettype) + end + end +end + # Handles all analysis and inlining of intrinsics and builtins. In particular, # this method does not access the method table or otherwise process generic # functions. 
@@ -1003,6 +1035,9 @@ function process_simple!(ir::IRCode, todo::Vector{Pair{Int, Any}}, idx::Int, sta if stmt.head === :splatnew inline_splatnew!(ir, idx) return nothing + elseif stmt.head === :new_yakc + narrow_yakc!(stmt, ir.stmts[idx][:type]) + return nothing end stmt.head === :call || return nothing @@ -1022,6 +1057,12 @@ function process_simple!(ir::IRCode, todo::Vector{Pair{Int, Any}}, idx::Int, sta return nothing end + # Handle _yakc (if not already handled above) + if sig.f === Core._yakc + narrow_yakc!(stmt, calltype) + return + end + # Handle invoke invoke_data = nothing if sig.f === Core.invoke && length(sig.atypes) >= 3 @@ -1033,6 +1074,16 @@ function process_simple!(ir::IRCode, todo::Vector{Pair{Int, Any}}, idx::Int, sta return nothing end + # Inlining YAKC that don't have any specializations + if sig.ft ⊑ Core.YAKC + callee = stmt.args[1] + if isa(callee, SSAValue) && isexpr(ir.stmts[callee.id][:inst], :new_yakc) && length(ir.stmts[callee.id][:inst].args) == 7 + ir′ = ir.yakcs[ir.stmts[callee.id][:inst].args[end]::Int]::IRCode + push!(todo, idx=>InliningTodo(nothing, ResolvedInliningSpec(ir′, linear_inline_eligible(ir′)))) + return nothing + end + end + sig = with_atype(sig) # In :invoke, make sure that the arguments we're passing are a subtype of the @@ -1165,6 +1216,10 @@ function assemble_inline_todo!(ir::IRCode, state::InliningState) ir.stmts[idx][:inst] = quoted(calltype.val) continue end + # Refuse to inline YAKCs we can't see otherwise, to preserve the + # possibility of functions higher in the call stack seeing this + # and performing the inlining. + continue end # Ok, now figure out what method to call diff --git a/base/compiler/ssair/ir.jl b/base/compiler/ssair/ir.jl index fb987b3bc3aa0..23d9786ce60b8 100644 --- a/base/compiler/ssair/ir.jl +++ b/base/compiler/ssair/ir.jl @@ -260,12 +260,16 @@ struct IRCode cfg::CFG new_nodes::NewNodeStream meta::Vector{Any} + # For easier reference, we store all yakcs. 
In general passes will want to + # just recurse into these. Passes that are yakc aware may do more fancy + # optimizations + yakcs::Vector{IRCode} - function IRCode(stmts::InstructionStream, cfg::CFG, linetable::Vector{LineInfoNode}, argtypes::Vector{Any}, meta::Vector{Any}, sptypes::Vector{Any}) - return new(stmts, argtypes, sptypes, linetable, cfg, NewNodeStream(), meta) + function IRCode(stmts::InstructionStream, cfg::CFG, linetable::Vector{LineInfoNode}, argtypes::Vector{Any}, meta::Vector{Any}, sptypes::Vector{Any}, yakcs::Vector{IRCode}) + return new(stmts, argtypes, sptypes, linetable, cfg, NewNodeStream(), meta, yakcs) end function IRCode(ir::IRCode, stmts::InstructionStream, cfg::CFG, new_nodes::NewNodeStream) - return new(stmts, ir.argtypes, ir.sptypes, ir.linetable, cfg, new_nodes, ir.meta) + return new(stmts, ir.argtypes, ir.sptypes, ir.linetable, cfg, new_nodes, ir.meta, ir.yakcs) end global copy copy(ir::IRCode) = new(copy(ir.stmts), copy(ir.argtypes), copy(ir.sptypes), @@ -379,7 +383,7 @@ function getindex(x::UseRef) end function is_relevant_expr(e::Expr) - return e.head in (:call, :invoke, :new, :splatnew, :(=), :(&), + return e.head in (:call, :invoke, :new, :new_yakc, :splatnew, :(=), :(&), :gc_preserve_begin, :gc_preserve_end, :foreigncall, :isdefined, :copyast, :undefcheck, :throw_undef_if_not, diff --git a/base/compiler/ssair/legacy.jl b/base/compiler/ssair/legacy.jl index 1fa847734359b..23ca5654036e3 100644 --- a/base/compiler/ssair/legacy.jl +++ b/base/compiler/ssair/legacy.jl @@ -13,6 +13,7 @@ end function inflate_ir(ci::CodeInfo, sptypes::Vector{Any}, argtypes::Vector{Any}) code = copy_exprargs(ci.code) # TODO: this is a huge hot-spot cfg = compute_basic_blocks(code) + yakcs = IRCode[] for i = 1:length(code) stmt = code[i] # Translate statement edges to bb_edges @@ -25,6 +26,20 @@ function inflate_ir(ci::CodeInfo, sptypes::Vector{Any}, argtypes::Vector{Any}) elseif isa(stmt, Expr) && stmt.head === :enter stmt.args[1] = block_for_inst(cfg, 
stmt.args[1]) code[i] = stmt + elseif isa(stmt, Expr) && stmt.head == :new + # Pre-convert any YAKC objects + if length(stmt.args) == 3 && isa(stmt.args[3], CodeInfo) + t = widenconst(argextype(stmt.args[1], ci, sptypes)) + if t <: Type{<:Core.YAKC} + argtypes′ = Any[argextype(stmt.args[2], ci, sptypes)] + add_yakc_argtypes!(argtypes′, t) + yakc_ir = inflate_ir(stmt.args[3], Any[], argtypes′) + push!(yakcs, yakc_ir) + stmt.head = :new_yakc + push!(stmt.args, length(yakcs)) + end + end + code[i] = stmt else code[i] = stmt end @@ -33,7 +48,7 @@ function inflate_ir(ci::CodeInfo, sptypes::Vector{Any}, argtypes::Vector{Any}) nstmts = length(code) ssavaluetypes = ci.ssavaluetypes isa Vector{Any} ? copy(ci.ssavaluetypes) : Any[ Any for i = 1:(ci.ssavaluetypes::Int) ] stmts = InstructionStream(code, ssavaluetypes, Any[nothing for i = 1:nstmts], copy(ci.codelocs), copy(ci.ssaflags)) - ir = IRCode(stmts, cfg, collect(LineInfoNode, ci.linetable), argtypes, Any[], sptypes) + ir = IRCode(stmts, cfg, collect(LineInfoNode, ci.linetable), argtypes, Any[], sptypes, yakcs) return ir end @@ -62,6 +77,13 @@ function replace_code_newstyle!(ci::CodeInfo, ir::IRCode, nargs::Int) stmt = PhiNode(Int32[last(ir.cfg.blocks[edge].stmts) for edge in stmt.edges], stmt.values) elseif isa(stmt, Expr) && stmt.head === :enter stmt.args[1] = first(ir.cfg.blocks[stmt.args[1]::Int].stmts) + elseif isa(stmt, Expr) && stmt.head == :new_yakc + ci′ = copy(stmt.args[end-1]) + ir′ = ir.yakcs[stmt.args[end]] + replace_code_newstyle!(ci′, ir′, length(ir′.argtypes)-1) + pop!(stmt.args) + stmt.args[end] = ci′ + stmt.head = :call end ci.code[i] = stmt end diff --git a/base/compiler/ssair/passes.jl b/base/compiler/ssair/passes.jl index 4af13d81b76d0..ead54073185a7 100644 --- a/base/compiler/ssair/passes.jl +++ b/base/compiler/ssair/passes.jl @@ -1194,3 +1194,41 @@ function cfg_simplify!(ir::IRCode) compact.active_result_bb = length(bb_starts) return finish(compact) end + +function analyze_env_uses(yakc_ir) + uses 
= BitSet() + for idx in 1:length(yakc_ir.stmts) + stmt = yakc_ir.stmts[idx][:inst] + isexpr(stmt, :call) || continue + if is_known_call(stmt, getfield, yakc_ir, Any[]) + if stmt.args[2] == Argument(1) + push!(uses, stmt.args[3]) + end + end + end + uses +end + +function yakc_optim_pass!(ir::IRCode) + if isempty(ir.yakcs) + return ir + end + + # For any yakcs being co-optimized, optimize the capture environment + uses = BitSet[] + for ir′ in ir.yakcs + push!(uses, analyze_env_uses(ir′)) + end + + compact = IncrementalCompact(ir) + for ((_, idx), stmt) in compact + isexpr(stmt, :new_yakc) || continue + this_yakc_uses = uses[stmt.args[end]] + if isempty(this_yakc_uses) + stmt.args[5] = nothing + end + end + + # TODO: We could hoise code here + return finish(compact) +end diff --git a/base/compiler/ssair/queries.jl b/base/compiler/ssair/queries.jl index 6a6ac89c91e7c..dec1765883b2d 100644 --- a/base/compiler/ssair/queries.jl +++ b/base/compiler/ssair/queries.jl @@ -42,14 +42,15 @@ function stmt_effect_free(@nospecialize(stmt), @nospecialize(rt), src, sptypes:: isexact || return false isconcretedispatch(typ) || return false typ = typ::DataType - fieldcount(typ) >= length(ea) - 1 || return false - for fld_idx in 1:(length(ea) - 1) + nargs = length(ea) - (head === :new ? 1 : 2) + fieldcount(typ) >= nargs || return false + for fld_idx in 1:nargs eT = argextype(ea[fld_idx + 1], src, sptypes) fT = fieldtype(typ, fld_idx) eT ⊑ fT || return false end return true - elseif head === :isdefined || head === :the_exception || head === :copyast || head === :inbounds || head === :boundscheck + elseif head === :isdefined || head === :the_exception || head === :copyast || head === :inbounds || head === :boundscheck || head === :new_yakc return true else # e.g. 
:loopinfo diff --git a/base/compiler/ssair/verify.jl b/base/compiler/ssair/verify.jl index 40cc8731ce477..9343bd91a05b5 100644 --- a/base/compiler/ssair/verify.jl +++ b/base/compiler/ssair/verify.jl @@ -14,13 +14,13 @@ end function check_op(ir::IRCode, domtree::DomTree, @nospecialize(op), use_bb::Int, use_idx::Int, print::Bool) if isa(op, SSAValue) if op.id > length(ir.stmts) - def_bb = block_for_inst(ir.cfg, ir.new_nodes[op.id - length(ir.stmts)].pos) + def_bb = block_for_inst(ir.cfg, ir.new_nodes.info[op.id - length(ir.stmts)].pos) else def_bb = block_for_inst(ir.cfg, op.id) end if (def_bb == use_bb) if op.id > length(ir.stmts) - @assert ir.new_nodes[op.id - length(ir.stmts)].pos <= use_idx + @assert ir.new_nodes.info[op.id - length(ir.stmts)].pos <= use_idx else if op.id >= use_idx @verify_error "Def ($(op.id)) does not dominate use ($(use_idx)) in same BB" diff --git a/base/compiler/stmtinfo.jl b/base/compiler/stmtinfo.jl index 214fd89a17078..b510160944ff2 100644 --- a/base/compiler/stmtinfo.jl +++ b/base/compiler/stmtinfo.jl @@ -79,3 +79,13 @@ This info is illegal on any statement that is not an `_apply_iterate` call. struct UnionSplitApplyCallInfo infos::Vector{ApplyCallInfo} end + +""" + struct YAKCCallInfo + +The call was to a YAKC of known provenance. A MethodInstance was created to +hold the optimized code for this YAKC. 
+""" +struct YAKCCallInfo + mi::MethodInstance +end diff --git a/base/compiler/typeinfer.jl b/base/compiler/typeinfer.jl index 2cd89d0442fdb..31b1ecba4ff4e 100644 --- a/base/compiler/typeinfer.jl +++ b/base/compiler/typeinfer.jl @@ -96,6 +96,9 @@ function CodeInstance(result::InferenceResult, @nospecialize(inferred_result::An elseif isa(result.result, PartialStruct) rettype_const = (result.result::PartialStruct).fields const_flags = 0x2 + elseif isa(result.result, PartialYAKC) + rettype_const = result.result + const_flags = 0x2 else rettype_const = nothing const_flags = 0x00 @@ -173,6 +176,48 @@ function cache_result!(interp::AbstractInterpreter, result::InferenceResult, val nothing end +function finish_yakc!(yakc::PartialYAKC, mod::Module, interp::AbstractInterpreter) + # Infer with the most general types possible, so that optimization has + # access to the result. + + argt = unwrap_unionall(yakc.t).parameters[1] + + argtypes = Any[argt.parameters...] + pushfirst!(argtypes, yakc.env) + if isdispatchtuple(argt) + # If we don't need to track specializations, just infer this here + # right now. + result = InferenceResult(Core.YAKC, argtypes) + state = InferenceState(result, copy(yakc.ci), false, interp) + typeinf_local(interp, state) + finish(state, interp) + yakc.ci = result.src + else + # Otherwise infer via the method instance cache. + m = ccall(:jl_mk_yakc_method, Any, (Any,), mod)::Method + m.nargs = Int32(length(argtypes)) + m.source = yakc.ci + argtypes = Any[argt.parameters...] 
+ pushfirst!(argtypes, Core.YAKC) + m.unspecialized = specialize_method(m, argtypes_to_type(argtypes), Core.svec()) + + mi = m.unspecialized + + lock_mi_inference(interp, mi) + result = InferenceResult(mi) + frame = InferenceState(result, #=cached=#true, interp) + if frame === nothing + # can't get the source for this, so we know nothing + unlock_mi_inference(interp, mi) + return + end + typeinf(interp, frame) + unlock_mi_inference(interp, mi) + + yakc.ci = m + end +end + # inference completed on `me` # update the MethodInstance function finish(me::InferenceState, interp::AbstractInterpreter) @@ -186,9 +231,15 @@ function finish(me::InferenceState, interp::AbstractInterpreter) else # annotate fulltree with type information type_annotate!(me) + if isa(me.bestguess, PartialYAKC) + mod = isa(me.linfo.def, Module) ? me.linfo.def : me.linfo.def.module + finish_yakc!(me.bestguess, mod, interp) + end me.result.src = OptimizationState(me, OptimizationParams(interp), interp) end - me.result.result = me.bestguess + if me.result !== nothing + me.result.result = me.bestguess + end nothing end @@ -508,6 +559,8 @@ function typeinf_edge(interp::AbstractInterpreter, method::Method, @nospecialize if isdefined(code, :rettype_const) if isa(code.rettype_const, Vector{Any}) && !(Vector{Any} <: code.rettype) return PartialStruct(code.rettype, code.rettype_const), mi + elseif isa(code.rettype_const, PartialYAKC) && code.rettype <: Core.YAKC + return code.rettype_const, mi else return Const(code.rettype_const), mi end @@ -552,7 +605,7 @@ function typeinf_edge(interp::AbstractInterpreter, method::Method, @nospecialize end function widenconst_bestguess(bestguess) - !isa(bestguess, Const) && !isa(bestguess, PartialStruct) && !isa(bestguess, Type) && return widenconst(bestguess) + !isa(bestguess, Const) && !isa(bestguess, PartialStruct) && !isa(bestguess, PartialYAKC) && !isa(bestguess, Type) && return widenconst(bestguess) return bestguess end diff --git a/base/compiler/typelattice.jl 
b/base/compiler/typelattice.jl index 17a444e840b77..0d4f637aeafe5 100644 --- a/base/compiler/typelattice.jl +++ b/base/compiler/typelattice.jl @@ -54,6 +54,15 @@ struct PartialTypeVar PartialTypeVar(tv::TypeVar, lb_certain::Bool, ub_certain::Bool) = new(tv, lb_certain, ub_certain) end +mutable struct PartialYAKC + t::Type + env::Any + parent::MethodInstance + ci::Any + # TODO: Where do we cache these results? +end +widenconst(py::PartialYAKC) = py.t + # Wraps a type and represents that the value may also be undef at this point. # (only used in optimize, not abstractinterpret) struct MaybeUndef @@ -162,6 +171,9 @@ function ⊑(@nospecialize(a), @nospecialize(b)) end return false end + if isa(a, PartialYAKC) + return widenconst(a) <: widenconst(b) + end if isa(a, Const) if isa(b, Const) return a.val === b.val diff --git a/base/compiler/types.jl b/base/compiler/types.jl index 45a5bfcb12169..b531d5ac04d59 100644 --- a/base/compiler/types.jl +++ b/base/compiler/types.jl @@ -23,7 +23,7 @@ abstract type AbstractInterpreter; end A type that represents the result of running type inference on a chunk of code. """ mutable struct InferenceResult - linfo::MethodInstance + linfo::Union{Nothing, MethodInstance} argtypes::Vector{Any} overridden_by_const::BitVector result # ::Type, or InferenceState if WIP @@ -32,6 +32,9 @@ mutable struct InferenceResult argtypes, overridden_by_const = matching_cache_argtypes(linfo, given_argtypes) return new(linfo, argtypes, overridden_by_const, Any, nothing) end + function InferenceResult(::Type{Core.YAKC}, argtypes::Vector{Any}) + new(nothing, argtypes, BitVector(), Any, nothing) + end end diff --git a/base/yakc.jl b/base/yakc.jl new file mode 100644 index 0000000000000..fd301b61a5553 --- /dev/null +++ b/base/yakc.jl @@ -0,0 +1,6 @@ +@noinline function (y::Core.YAKC{A, R})(args...) 
where {A,R} + typeassert(args, A) + ccall(y.fptr1, Any, (Any, Ptr{Any}, Int), y, Any[args...], length(args))::R +end + +# YAKC macro goes here diff --git a/src/Makefile b/src/Makefile index 578677b3e1b9b..467e1d3e55bb4 100644 --- a/src/Makefile +++ b/src/Makefile @@ -45,7 +45,7 @@ RUNTIME_SRCS := \ simplevector runtime_intrinsics precompile \ threading partr stackwalk gc gc-debug gc-pages gc-stacks method \ jlapi signal-handling safepoint timing subtype \ - crc32c APInt-C processor ircode + crc32c APInt-C processor ircode yakc SRCS := jloptions runtime_ccall rtutils LLVMLINK := diff --git a/src/builtin_proto.h b/src/builtin_proto.h index 8021c404bd5e7..c75364738d662 100644 --- a/src/builtin_proto.h +++ b/src/builtin_proto.h @@ -35,6 +35,7 @@ DECLARE_BUILTIN(apply_type); DECLARE_BUILTIN(applicable); DECLARE_BUILTIN(invoke); DECLARE_BUILTIN(_expr); DECLARE_BUILTIN(typeassert); DECLARE_BUILTIN(ifelse); DECLARE_BUILTIN(_typevar); DECLARE_BUILTIN(_typebody); +DECLARE_BUILTIN(_yakc); JL_CALLABLE(jl_f_invoke_kwsorter); JL_CALLABLE(jl_f__structtype); diff --git a/src/builtins.c b/src/builtins.c index 6d5f3f2779a12..b431cf2f2c076 100644 --- a/src/builtins.c +++ b/src/builtins.c @@ -1395,6 +1395,19 @@ JL_CALLABLE(jl_f__equiv_typedef) return equiv_type(args[0], args[1]) ? 
jl_true : jl_false; } +// yakc +JL_CALLABLE(jl_f__yakc) +{ + JL_NARGS(_yakc, 5, 5) + JL_TYPECHK(_yakc, type, args[0]); + JL_TYPECHK(_yakc, type, args[1]); + JL_TYPECHK(_yakc, type, args[2]); + if (!jl_is_method(args[4]) && !jl_is_code_info(args[4])) { + jl_error("Invalid YAKC source"); + } +return jl_new_yakc(args[0], args[1], args[2], args[3], args[4]); +} + // IntrinsicFunctions --------------------------------------------------------- static void (*runtime_fp[num_intrinsics])(void); @@ -1554,6 +1567,7 @@ void jl_init_primitives(void) JL_GC_DISABLED jl_builtin__apply_iterate = add_builtin_func("_apply_iterate", jl_f__apply_iterate); jl_builtin__expr = add_builtin_func("_expr", jl_f__expr); jl_builtin_svec = add_builtin_func("svec", jl_f_svec); + jl_builtin__yakc = add_builtin_func("_yakc", jl_f__yakc); add_builtin_func("_apply_pure", jl_f__apply_pure); add_builtin_func("_apply_latest", jl_f__apply_latest); add_builtin_func("_apply_in_world", jl_f__apply_in_world); @@ -1604,6 +1618,7 @@ void jl_init_primitives(void) JL_GC_DISABLED add_builtin("Ptr", (jl_value_t*)jl_pointer_type); add_builtin("LLVMPtr", (jl_value_t*)jl_llvmpointer_type); add_builtin("Task", (jl_value_t*)jl_task_type); + add_builtin("YAKC", (jl_value_t*)jl_yakc_type); add_builtin("AbstractArray", (jl_value_t*)jl_abstractarray_type); add_builtin("DenseArray", (jl_value_t*)jl_densearray_type); diff --git a/src/datatype.c b/src/datatype.c index 269076ef2819b..fbbac7767bede 100644 --- a/src/datatype.c +++ b/src/datatype.c @@ -927,8 +927,21 @@ static void init_struct_tail(jl_datatype_t *type, jl_value_t *jv, size_t na) JL_DLLEXPORT jl_value_t *jl_new_structv(jl_datatype_t *type, jl_value_t **args, uint32_t na) { jl_ptls_t ptls = jl_get_ptls_states(); - if (!jl_is_datatype(type) || type->layout == NULL) + if (!jl_is_datatype(type) || type->layout == NULL) { + // As a special case we're allowed to have unionalls over yakc, + // where each typevar is replaced by its upper bound. 
+ if (jl_is_unionall(type)) { + jl_datatype_t *dt = jl_unwrap_unionall(type); + if (jl_is_yakc_type(dt)) { + while (jl_is_unionall(type)) { + jl_tvar_t *tv = ((jl_unionall_t*)type)->var; + type = jl_instantiate_unionall(type, tv->ub); + } + return jl_new_structv(type, args, na); + } + } jl_type_error("new", (jl_value_t*)jl_datatype_type, (jl_value_t*)type); + } if (type->ninitialized > na || na > jl_datatype_nfields(type)) jl_error("invalid struct allocation"); for (size_t i = 0; i < na; i++) { diff --git a/src/dump.c b/src/dump.c index 4425ada1ad268..3e6ee8c1700b4 100644 --- a/src/dump.c +++ b/src/dump.c @@ -2554,7 +2554,6 @@ void jl_init_serializer(void) jl_box_int64(12), jl_box_int64(13), jl_box_int64(14), jl_box_int64(15), jl_box_int64(16), jl_box_int64(17), jl_box_int64(18), jl_box_int64(19), jl_box_int64(20), - jl_box_int64(21), jl_bool_type, jl_linenumbernode_type, jl_pinode_type, jl_upsilonnode_type, jl_type_type, jl_bottom_type, jl_ref_type, @@ -2570,6 +2569,7 @@ void jl_init_serializer(void) jl_namedtuple_type, jl_array_int32_type, jl_typedslot_type, jl_uint32_type, jl_uint64_type, jl_type_type_mt, jl_nonfunction_mt, + jl_yakc_type, ptls->root_task, @@ -2638,6 +2638,7 @@ void jl_init_serializer(void) arraylist_push(&builtin_typenames, jl_tuple_typename); arraylist_push(&builtin_typenames, jl_vararg_typename); arraylist_push(&builtin_typenames, jl_namedtuple_typename); + arraylist_push(&builtin_typenames, jl_yakc_typename); } #ifdef __cplusplus diff --git a/src/gf.c b/src/gf.c index 322222798e8de..7c993e3bea55f 100644 --- a/src/gf.c +++ b/src/gf.c @@ -205,7 +205,6 @@ JL_DLLEXPORT jl_value_t *jl_methtable_lookup(jl_methtable_t *mt, jl_value_t *typ // ----- MethodInstance specialization instantiation ----- // -JL_DLLEXPORT jl_method_t *jl_new_method_uninit(jl_module_t*); JL_DLLEXPORT jl_code_instance_t* jl_new_codeinst( jl_method_instance_t *mi, jl_value_t *rettype, jl_value_t *inferred_const, jl_value_t *inferred, @@ -2381,8 +2380,6 @@ JL_DLLEXPORT 
jl_value_t *jl_gf_invoke_lookup_worlds(jl_value_t *types, size_t wo return (jl_value_t*)matc->method; } -static jl_value_t *jl_gf_invoke_by_method(jl_method_t *method, jl_value_t *gf, jl_value_t **args, size_t nargs); - // invoke() // this does method dispatch with a set of types to match other than the // types of the actual arguments. this means it sometimes does NOT call the @@ -2412,7 +2409,7 @@ jl_value_t *jl_gf_invoke(jl_value_t *types0, jl_value_t *gf, jl_value_t **args, return jl_gf_invoke_by_method(method, gf, args, nargs); } -static jl_value_t *jl_gf_invoke_by_method(jl_method_t *method, jl_value_t *gf, jl_value_t **args, size_t nargs) +jl_value_t *jl_gf_invoke_by_method(jl_method_t *method, jl_value_t *gf, jl_value_t **args, size_t nargs) { jl_method_instance_t *mfunc = NULL; jl_typemap_entry_t *tm = NULL; diff --git a/src/interpreter.c b/src/interpreter.c index ba97321a922ee..f178b04c5b814 100644 --- a/src/interpreter.c +++ b/src/interpreter.c @@ -651,6 +651,39 @@ jl_value_t *NOINLINE jl_fptr_interpret_call(jl_value_t *f, jl_value_t **args, ui return r; } +jl_value_t *jl_interpret_yakc(jl_yakc_t *yakc, jl_value_t **args, size_t nargs) +{ + jl_code_info_t *source = NULL; + jl_value_t *code = yakc->code; + if (jl_is_method(code)) { + source = (jl_code_info_t*)yakc->method->source; + } + else { + source = yakc->source; + } + jl_array_t *stmts = source->code; + assert(jl_typeis(stmts, jl_array_any_type)); + interpreter_state *s; + unsigned nroots = jl_source_nslots(source) + jl_source_nssavalues(source) + 2; + JL_GC_PUSHFRAME(s, nroots); + jl_value_t **locals = (jl_value_t**)&s[1] + 3; + locals[0] = (jl_value_t*)yakc; + locals[1] = (jl_value_t*)stmts; + locals[2] = (jl_value_t*)yakc->env; + s->locals = locals + 2; + s->src = source; + s->module = NULL; + s->sparam_vals = NULL; + s->preevaluation = 0; + s->continue_at = 0; + s->mi = NULL; + for (int i = 0; i < nargs; ++i) + s->locals[1 + i] = args[i]; + jl_value_t *r = eval_body(stmts, s, 0, 0); + 
JL_GC_POP(); + return r; +} + jl_value_t *NOINLINE jl_interpret_toplevel_thunk(jl_module_t *m, jl_code_info_t *src) { interpreter_state *s; diff --git a/src/jltypes.c b/src/jltypes.c index dd36776f1bac1..35fae8c7a7370 100644 --- a/src/jltypes.c +++ b/src/jltypes.c @@ -52,6 +52,7 @@ jl_datatype_t *jl_unionall_type; jl_datatype_t *jl_datatype_type; jl_datatype_t *jl_function_type; jl_datatype_t *jl_builtin_type; +jl_unionall_t *jl_yakc_type; jl_datatype_t *jl_typeofbottom_type; jl_value_t *jl_bottom_type; @@ -84,6 +85,7 @@ JL_DLLEXPORT jl_value_t *jl_false; jl_unionall_t *jl_array_type; jl_typename_t *jl_array_typename; +jl_typename_t *jl_yakc_typename; jl_value_t *jl_array_uint8_type; jl_value_t *jl_array_any_type; jl_value_t *jl_array_symbol_type; @@ -2392,7 +2394,6 @@ void jl_init_types(void) JL_GC_DISABLED jl_perm_symsvec(4, "spec_types", "sparams", "method", "fully_covers"), jl_svec(4, jl_type_type, jl_simplevector_type, jl_method_type, jl_bool_type), 0, 0, 4); - // all Kinds share the Type method table (not the nonfunction one) jl_unionall_type->name->mt = jl_uniontype_type->name->mt = jl_datatype_type->name->mt = jl_type_type_mt; @@ -2467,8 +2468,17 @@ void jl_init_types(void) JL_GC_DISABLED jl_value_t *listt = jl_new_struct(jl_uniontype_type, jl_task_type, jl_nothing_type); jl_svecset(jl_task_type->types, 0, listt); - // complete builtin type metadata jl_value_t *pointer_void = jl_apply_type1((jl_value_t*)jl_pointer_type, (jl_value_t*)jl_nothing_type); + + tv = jl_svec2(tvar("A"), tvar("R")); + jl_yakc_type = (jl_unionall_t*)jl_new_datatype(jl_symbol("YAKC"), core, jl_any_type, tv, + jl_perm_symsvec(4, "env", "ci", "fptr1", "fptr"), + jl_svec(4, jl_any_type, jl_code_info_type, pointer_void, pointer_void), 0, 0, 4)->name->wrapper; + jl_yakc_typename = ((jl_datatype_t*)jl_unwrap_unionall((jl_value_t*)jl_yakc_type))->name; + jl_compute_field_offsets((jl_datatype_t*)jl_unwrap_unionall((jl_value_t*)jl_yakc_type)); + + + // complete builtin type metadata 
jl_voidpointer_type = (jl_datatype_t*)pointer_void; jl_uint8pointer_type = (jl_datatype_t*)jl_apply_type1((jl_value_t*)jl_pointer_type, (jl_value_t*)jl_uint8_type); jl_svecset(jl_datatype_type->types, 6, jl_voidpointer_type); diff --git a/src/julia.h b/src/julia.h index eacc76eb0e0f2..de7d722fb268b 100644 --- a/src/julia.h +++ b/src/julia.h @@ -351,6 +351,19 @@ struct _jl_method_instance_t { uint8_t inInference; // flags to tell if inference is running on this object }; +// YAKC - Yet another kind of closure. +typedef struct jl_yakc_t { + JL_DATA_TYPE + jl_value_t *env; + union { + jl_value_t *code; + jl_code_info_t *source; + jl_method_t *method; + }; + jl_fptr_args_t fptr1; + void *fptr; +} jl_yakc_t; + // This type represents an executable operation typedef struct _jl_code_instance_t { JL_DATA_TYPE @@ -624,6 +637,8 @@ extern JL_DLLEXPORT jl_unionall_t *jl_vararg_type JL_GLOBALLY_ROOTED; extern JL_DLLEXPORT jl_typename_t *jl_vararg_typename JL_GLOBALLY_ROOTED; extern JL_DLLEXPORT jl_datatype_t *jl_function_type JL_GLOBALLY_ROOTED; extern JL_DLLEXPORT jl_datatype_t *jl_builtin_type JL_GLOBALLY_ROOTED; +extern JL_DLLEXPORT jl_unionall_t *jl_yakc_type JL_GLOBALLY_ROOTED; +extern JL_DLLEXPORT jl_typename_t *jl_yakc_typename JL_GLOBALLY_ROOTED; extern JL_DLLEXPORT jl_value_t *jl_bottom_type JL_GLOBALLY_ROOTED; extern JL_DLLEXPORT jl_datatype_t *jl_method_instance_type JL_GLOBALLY_ROOTED; @@ -1203,6 +1218,19 @@ STATIC_INLINE int jl_is_array(void *v) JL_NOTSAFEPOINT return jl_is_array_type(t); } + +STATIC_INLINE int jl_is_yakc_type(void *t) JL_NOTSAFEPOINT +{ + return (jl_is_datatype(t) && + ((jl_datatype_t*)(t))->name == jl_yakc_typename); +} + +STATIC_INLINE int jl_is_yakc(void *v) JL_NOTSAFEPOINT +{ + jl_value_t *t = jl_typeof(v); + return jl_is_yakc_type(t); +} + STATIC_INLINE int jl_is_cpointer_type(jl_value_t *t) JL_NOTSAFEPOINT { return (jl_is_datatype(t) && diff --git a/src/julia_internal.h b/src/julia_internal.h index 79079504c9bba..46dc0324946e6 100644 ---
a/src/julia_internal.h +++ b/src/julia_internal.h @@ -484,6 +484,7 @@ jl_array_t *jl_get_loaded_modules(void); jl_value_t *jl_toplevel_eval_flex(jl_module_t *m, jl_value_t *e, int fast, int expanded); jl_value_t *jl_eval_global_var(jl_module_t *m JL_PROPAGATES_ROOT, jl_sym_t *e); +jl_value_t *jl_interpret_yakc(jl_yakc_t *yakc, jl_value_t **args, size_t nargs); jl_value_t *jl_interpret_toplevel_thunk(jl_module_t *m, jl_code_info_t *src); jl_value_t *jl_interpret_toplevel_expr_in(jl_module_t *m, jl_value_t *e, jl_code_info_t *src, @@ -493,6 +494,8 @@ jl_value_t *jl_call_scm_on_ast(const char *funcname, jl_value_t *expr, jl_module void jl_linenumber_to_lineinfo(jl_code_info_t *ci, jl_module_t *mod, jl_value_t *name); jl_method_instance_t *jl_method_lookup(jl_value_t **args, size_t nargs, size_t world); + +jl_value_t *jl_gf_invoke_by_method(jl_method_t *method, jl_value_t *gf, jl_value_t **args, size_t nargs); jl_value_t *jl_gf_invoke(jl_value_t *types, jl_value_t *f, jl_value_t **args, size_t nargs); JL_DLLEXPORT jl_value_t *jl_matching_methods(jl_tupletype_t *types, int lim, int include_ambiguous, size_t world, size_t *min_valid, size_t *max_valid, int *ambig); @@ -515,6 +518,8 @@ extern jl_array_t *jl_module_init_order JL_GLOBALLY_ROOTED; extern htable_t jl_current_modules JL_GLOBALLY_ROOTED; JL_DLLEXPORT void jl_compile_extern_c(void *llvmmod, void *params, void *sysimg, jl_value_t *declrt, jl_value_t *sigt); +jl_yakc_t *jl_new_yakc(jl_tupletype_t *argt, jl_value_t *rt_lb, jl_value_t *rt_ub, jl_value_t *env, jl_value_t *source); + // Each tuple can exist in one of 4 Vararg states: // NONE: no vararg Tuple{Int,Float32} // INT: vararg with integer length Tuple{Int,Vararg{Float32,2}} @@ -708,6 +713,8 @@ void jl_get_function_id(void *native_code, jl_code_instance_t *ncode, JL_DLLEXPORT jl_array_t *jl_idtable_rehash(jl_array_t *a, size_t newsz); jl_value_t **jl_table_peek_bp(jl_array_t *a, jl_value_t *key) JL_NOTSAFEPOINT; +JL_DLLEXPORT jl_method_t 
*jl_new_method_uninit(jl_module_t*); + JL_DLLEXPORT jl_methtable_t *jl_new_method_table(jl_sym_t *name, jl_module_t *module); jl_method_instance_t *jl_get_specialization1(jl_tupletype_t *types, size_t world, size_t *min_valid, size_t *max_valid, int mt_cache); jl_method_instance_t *jl_get_specialized(jl_method_t *m, jl_value_t *types, jl_svec_t *sp); diff --git a/src/staticdata.c b/src/staticdata.c index 2b70975c7d7ab..e0d1a20dd5b9f 100644 --- a/src/staticdata.c +++ b/src/staticdata.c @@ -46,6 +46,7 @@ static void *const _tags[] = { &jl_vararg_type, &jl_abstractarray_type, &jl_densearray_type, &jl_nothing_type, &jl_function_type, &jl_typeofbottom_type, &jl_unionall_type, &jl_typename_type, &jl_builtin_type, &jl_code_info_type, + &jl_yakc_type, &jl_task_type, &jl_uniontype_type, &jl_abstractstring_type, &jl_array_any_type, &jl_intrinsic_type, &jl_abstractslot_type, &jl_methtable_type, &jl_typemap_level_type, &jl_typemap_entry_type, @@ -60,7 +61,7 @@ static void *const _tags[] = { // special typenames &jl_tuple_typename, &jl_pointer_typename, &jl_llvmpointer_typename, &jl_array_typename, &jl_type_typename, &jl_vararg_typename, &jl_namedtuple_typename, - &jl_vecelement_typename, + &jl_vecelement_typename, &jl_yakc_typename, // special exceptions &jl_errorexception_type, &jl_argumenterror_type, &jl_typeerror_type, &jl_methoderror_type, &jl_loaderror_type, &jl_initerror_type, @@ -85,6 +86,7 @@ static void *const _tags[] = { &jl_builtin_const_arrayref, &jl_builtin_arrayset, &jl_builtin_arraysize, &jl_builtin_apply_type, &jl_builtin_applicable, &jl_builtin_invoke, &jl_builtin__expr, &jl_builtin_ifelse, &jl_builtin__typebody, + &jl_builtin__yakc, NULL }; static jl_value_t **const*const tags = (jl_value_t**const*const)_tags; @@ -125,7 +127,7 @@ static const jl_fptr_args_t id_to_fptrs[] = { &jl_f_arrayref, &jl_f_const_arrayref, &jl_f_arrayset, &jl_f_arraysize, &jl_f_apply_type, &jl_f_applicable, &jl_f_invoke, &jl_f_sizeof, &jl_f__expr, &jl_f__typevar, &jl_f_ifelse, 
&jl_f__structtype, &jl_f__abstracttype, &jl_f__primitivetype, - &jl_f__typebody, &jl_f__setsuper, &jl_f__equiv_typedef, + &jl_f__typebody, &jl_f__setsuper, &jl_f__equiv_typedef, &jl_f__yakc, NULL }; typedef struct { diff --git a/src/yakc.c b/src/yakc.c new file mode 100644 index 0000000000000..2d7fde2ccecb9 --- /dev/null +++ b/src/yakc.c @@ -0,0 +1,42 @@ +#include "julia.h" +#include "julia_internal.h" + +JL_DLLEXPORT jl_value_t *jl_invoke_yakc(jl_yakc_t *yakc, jl_value_t **args, size_t nargs) +{ + // TODO: Compiler support + jl_tupletype_t *argt = jl_tparam0(jl_typeof(yakc)); + if (nargs != jl_nparams(argt)) + jl_error("Incorrect argument count for YAKC"); + for (int i = 0; i < nargs; ++i) + jl_typeassert(args[i], jl_field_type(argt, i)); + jl_value_t *ret; + if (jl_is_method(yakc->source)) { + ret = jl_gf_invoke_by_method((jl_method_t*)yakc->source, yakc, args, nargs + 1); + } else { + ret = jl_interpret_yakc(yakc, args, nargs); + } + jl_typeassert(ret, jl_tparam1(jl_typeof(yakc))); + return ret; +} + +jl_yakc_t *jl_new_yakc(jl_tupletype_t *argt, jl_value_t *rt_lb, jl_value_t *rt_ub, jl_value_t *env, jl_value_t *source) +{ + jl_ptls_t ptls = jl_get_ptls_states(); + jl_value_t *yakc_t = jl_apply_type2((jl_value_t*)jl_yakc_type, argt, rt_ub); + jl_yakc_t *yakc = jl_gc_alloc(ptls, sizeof(jl_yakc_t), yakc_t); + yakc->env = env; + yakc->source = source; + yakc->fptr1 = jl_invoke_yakc; + yakc->fptr = NULL; + return yakc; +} + + +JL_DLLEXPORT jl_method_t* jl_mk_yakc_method(jl_module_t *def_mod) +{ + jl_method_t *m = jl_new_method_uninit(def_mod); + m->name = jl_symbol("YAKC"); + m->sig = (jl_value_t*)jl_anytuple_type; + m->slot_syms = jl_an_empty_string; + return m; +} diff --git a/test/choosetests.jl b/test/choosetests.jl index 9ca97a543123c..3bb900cc3124e 100644 --- a/test/choosetests.jl +++ b/test/choosetests.jl @@ -54,7 +54,8 @@ function choosetests(choices = []) "checked", "bitset", "floatfuncs", "precompile", "boundscheck", "error", "ambiguous", "cartesian", 
"osutils", "channels", "iostream", "secretbuffer", "specificity", - "reinterpretarray", "syntax", "logging", "missing", "asyncmap", "atexit" + "reinterpretarray", "syntax", "logging", "missing", "asyncmap", "atexit", + "yakc" ] tests = [] diff --git a/test/compiler/ssair.jl b/test/compiler/ssair.jl index ae8e86c2c9a5e..8ed910c7fa67f 100644 --- a/test/compiler/ssair.jl +++ b/test/compiler/ssair.jl @@ -121,7 +121,7 @@ let cfg = CFG(BasicBlock[ make_bb([2, 3] , [] ), ], Int[]) insts = Compiler.InstructionStream([], [], Any[], Int32[], UInt8[]) - code = Compiler.IRCode(insts, cfg, LineInfoNode[], [], [], []) + code = Compiler.IRCode(insts, cfg, LineInfoNode[], [], [], [], Compiler.IRCode[]) compact = Compiler.IncrementalCompact(code, true) @test length(compact.result_bbs) == 4 && 0 in compact.result_bbs[3].preds end diff --git a/test/yakc.jl b/test/yakc.jl new file mode 100644 index 0000000000000..c6c29b8e6daf2 --- /dev/null +++ b/test/yakc.jl @@ -0,0 +1,91 @@ +using Test + +const_int() = 1 + +let ci = @code_lowered const_int() + @eval function yakc_trivial() + $(Expr(:call, Core._yakc, Tuple{}, Any, Any, nothing, ci)) + end +end +@test yakc_trivial()() == 1 + +let ci = @code_lowered const_int() + @eval function yakc_simple_inf() + $(Expr(:call, Core._yakc, Tuple{}, Union{}, Any, nothing, ci)) + end +end +@test isa(yakc_simple_inf(), Core.YAKC{Tuple{}, Int}) + +struct YakcClos2Int + a::Int + b::Int +end +(a::YakcClos2Int)() = getfield(a, 1) + getfield(a, 2) +let ci = @code_lowered YakcClos2Int(1, 2)(); + @eval function yakc_trivial_clos() + $(Expr(:call, Core._yakc, Tuple{}, Int64, Int64, (1, 2), ci)) + end +end +@test yakc_trivial_clos()() == 3 + +let ci = @code_lowered YakcClos2Int(1, 2)(); + @eval function yakc_self_call_clos() + $(Expr(:call, Core._yakc, Tuple{}, Int64, Int64, (1, 2), ci))() + end +end +@test yakc_self_call_clos() == 3 +let opt = @code_typed yakc_self_call_clos() + @test length(opt[1].code) == 1 + @test isa(opt[1].code[1], Core.ReturnNode) +end + 
+struct YakcClos1Any + a +end +(a::YakcClos1Any)() = getfield(a, 1) +let ci = @code_lowered YakcClos1Any(1)() + @eval function yakc_pass_clos(x) + $(Expr(:call, Core._yakc, Tuple{}, Any, Any, :((x,)), ci)) + end +end +@test yakc_pass_clos(1)() == 1 +@test yakc_pass_clos("a")() == "a" + +let ci = @code_lowered YakcClos1Any(1)() + @eval function yakc_infer_pass_clos(x) + $(Expr(:call, Core._yakc, Tuple{}, Union{}, Any, :((x,)), ci)) + end +end +@test isa(yakc_infer_pass_clos(1), Core.YAKC{Tuple{}, typeof(1)}) +@test isa(yakc_infer_pass_clos("a"), Core.YAKC{Tuple{}, typeof("a")}) +@test yakc_infer_pass_clos(1)() == 1 +@test yakc_infer_pass_clos("a")() == "a" + +let ci = @code_lowered identity(1) + @eval function yakc_infer_pass_id() + $(Expr(:call, Core._yakc, Tuple{Any}, Any, Any, nothing, ci)) + end +end +function complicated_identity(x) + yakc_infer_pass_id()(x) +end +@test @inferred(complicated_identity(1)) == 1 +@test @inferred(complicated_identity("a")) == "a" + +struct YakcOpt + A +end + +(A::YakcOpt)() = ndims(getfield(A, 1)) + +let ci = @code_lowered YakcOpt([1 2])() + @eval function yakc_opt_ndims(A) + $(Expr(:call, Core._yakc, Tuple{}, Union{}, Any, :((A,)), ci)) + end +end +let A = [1 2] + let yakc = yakc_opt_ndims(A) + @test sizeof(yakc.env) == 0 + @test yakc() == 2 + end +end