Skip to content

Commit

Permalink
Merge pull request JuliaLang#45103 from JuliaLang/kf/jb/ircode2oc
Browse files Browse the repository at this point in the history
Quality of life improvements for IR2OC branch
  • Loading branch information
Keno authored Apr 29, 2022
2 parents 2049baa + 4206af5 commit 19a3ddd
Show file tree
Hide file tree
Showing 7 changed files with 131 additions and 57 deletions.
7 changes: 5 additions & 2 deletions base/compiler/optimize.jl
Original file line number Diff line number Diff line change
Expand Up @@ -58,10 +58,13 @@ struct InliningState{S <: Union{EdgeTracker, Nothing}, MICache, I<:AbstractInter
interp::I
end

is_source_inferred(@nospecialize(src::Union{CodeInfo, Vector{UInt8}})) =
ccall(:jl_ir_flag_inferred, Bool, (Any,), src)

function inlining_policy(interp::AbstractInterpreter, @nospecialize(src), stmt_flag::UInt8,
mi::MethodInstance, argtypes::Vector{Any})
if isa(src, CodeInfo) || isa(src, Vector{UInt8})
src_inferred = ccall(:jl_ir_flag_inferred, Bool, (Any,), src)
src_inferred = is_source_inferred(src)
src_inlineable = is_stmt_inline(stmt_flag) || ccall(:jl_ir_flag_inlineable, Bool, (Any,), src)
return src_inferred && src_inlineable ? src : nothing
elseif src === nothing && is_stmt_inline(stmt_flag)
Expand All @@ -73,7 +76,7 @@ function inlining_policy(interp::AbstractInterpreter, @nospecialize(src), stmt_f
inf_result === nothing && return nothing
src = inf_result.src
if isa(src, CodeInfo)
src_inferred = ccall(:jl_ir_flag_inferred, Bool, (Any,), src)
src_inferred = is_source_inferred(src)
return src_inferred ? src : nothing
else
return nothing
Expand Down
43 changes: 43 additions & 0 deletions base/opaque_closure.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,3 +24,46 @@ end
macro opaque(ty, ex)
esc(Expr(:opaque_closure, ty, ex))
end

# OpaqueClosure construction from pre-inferred CodeInfo/IRCode
using Core.Compiler: IRCode
using Core: CodeInfo

function compute_ir_rettype(ir::IRCode)
rt = Union{}
for i = 1:length(ir.stmts)
stmt = ir.stmts[i][:inst]
if isa(stmt, Core.Compiler.ReturnNode) && isdefined(stmt, :val)
rt = Core.Compiler.tmerge(Core.Compiler.argextype(stmt.val, ir), rt)
end
end
return Core.Compiler.widenconst(rt)
end

function Core.OpaqueClosure(ir::IRCode, env...;
nargs::Int = length(ir.argtypes)-1,
isva::Bool = false,
rt = compute_ir_rettype(ir))
if (isva && nargs > length(ir.argtypes)) || (!isva && nargs != length(ir.argtypes)-1)
throw(ArgumentError("invalid argument count"))
end
src = ccall(:jl_new_code_info_uninit, Ref{CodeInfo}, ())
src.slotflags = UInt8[]
src.slotnames = fill(:none, nargs+1)
src.slottypes = copy(ir.argtypes)
Core.Compiler.replace_code_newstyle!(src, ir, nargs+1)
Core.Compiler.widen_all_consts!(src)
src.inferred = true
# NOTE: we need ir.argtypes[1] == typeof(env)

ccall(:jl_new_opaque_closure_from_code_info, Any, (Any, Any, Any, Any, Any, Cint, Any, Cint, Cint, Any),
Tuple{ir.argtypes[2:end]...}, Union{}, rt, @__MODULE__, src, 0, nothing, nargs, isva, env)
end

function Core.OpaqueClosure(src::CodeInfo, env...)
M = src.parent.def
sig = Base.tuple_type_tail(src.parent.specTypes)

ccall(:jl_new_opaque_closure_from_code_info, Any, (Any, Any, Any, Any, Any, Cint, Any, Cint, Cint, Any),
sig, Union{}, src.rettype, @__MODULE__, src, 0, nothing, M.nargs - 1, M.isva, env)
end
40 changes: 23 additions & 17 deletions src/aotcompile.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1002,24 +1002,30 @@ void *jl_get_llvmf_defn_impl(jl_method_instance_t *mi, size_t world, char getwra
jl_value_t *jlrettype = (jl_value_t*)jl_any_type;
jl_code_info_t *src = NULL;
JL_GC_PUSH2(&src, &jlrettype);
jl_value_t *ci = jl_rettype_inferred(mi, world, world);
if (ci != jl_nothing) {
jl_code_instance_t *codeinst = (jl_code_instance_t*)ci;
src = (jl_code_info_t*)codeinst->inferred;
if ((jl_value_t*)src != jl_nothing && !jl_is_code_info(src) && jl_is_method(mi->def.method))
src = jl_uncompress_ir(mi->def.method, codeinst, (jl_array_t*)src);
jlrettype = codeinst->rettype;
}
if (!src || (jl_value_t*)src == jl_nothing) {
src = jl_type_infer(mi, world, 0);
if (src)
jlrettype = src->rettype;
else if (jl_is_method(mi->def.method)) {
src = mi->def.method->generator ? jl_code_for_staged(mi) : (jl_code_info_t*)mi->def.method->source;
if (src && !jl_is_code_info(src) && jl_is_method(mi->def.method))
src = jl_uncompress_ir(mi->def.method, NULL, (jl_array_t*)src);
if (jl_is_method(mi->def.method) && mi->def.method->source != NULL && jl_ir_flag_inferred((jl_array_t*)mi->def.method->source)) {
src = (jl_code_info_t*)mi->def.method->source;
if (src && !jl_is_code_info(src))
src = jl_uncompress_ir(mi->def.method, NULL, (jl_array_t*)src);
} else {
jl_value_t *ci = jl_rettype_inferred(mi, world, world);
if (ci != jl_nothing) {
jl_code_instance_t *codeinst = (jl_code_instance_t*)ci;
src = (jl_code_info_t*)codeinst->inferred;
if ((jl_value_t*)src != jl_nothing && !jl_is_code_info(src) && jl_is_method(mi->def.method))
src = jl_uncompress_ir(mi->def.method, codeinst, (jl_array_t*)src);
jlrettype = codeinst->rettype;
}
if (!src || (jl_value_t*)src == jl_nothing) {
src = jl_type_infer(mi, world, 0);
if (src)
jlrettype = src->rettype;
else if (jl_is_method(mi->def.method)) {
src = mi->def.method->generator ? jl_code_for_staged(mi) : (jl_code_info_t*)mi->def.method->source;
if (src && !jl_is_code_info(src) && jl_is_method(mi->def.method))
src = jl_uncompress_ir(mi->def.method, NULL, (jl_array_t*)src);
}
// TODO: use mi->uninferred
}
// TODO: use mi->uninferred
}

// emit this function into a new llvm module
Expand Down
7 changes: 7 additions & 0 deletions src/ircode.c
Original file line number Diff line number Diff line change
Expand Up @@ -755,6 +755,11 @@ JL_DLLEXPORT jl_array_t *jl_compress_ir(jl_method_t *m, jl_code_info_t *code)
jl_encode_value_(&s, jl_get_nth_field((jl_value_t*)code, i), copy);
}

// For opaque closure, also save the slottypes. We technically only need the first slot type,
// but this is simpler for now. We may want to refactor where this gets stored in the future.
if (m->is_for_opaque_closure)
jl_encode_value_(&s, code->slottypes, 1);

if (m->generator)
// can't optimize generated functions
jl_encode_value_(&s, (jl_value_t*)jl_compress_argnames(code->slotnames), 1);
Expand Down Expand Up @@ -834,6 +839,8 @@ JL_DLLEXPORT jl_code_info_t *jl_uncompress_ir(jl_method_t *m, jl_code_instance_t
jl_value_t **fld = (jl_value_t**)((char*)jl_data_ptr(code) + jl_field_offset(jl_code_info_type, i));
*fld = jl_decode_value(&s);
}
if (m->is_for_opaque_closure)
code->slottypes = jl_decode_value(&s);

jl_value_t *slotnames = jl_decode_value(&s);
if (!jl_is_string(slotnames))
Expand Down
34 changes: 29 additions & 5 deletions stdlib/InteractiveUtils/src/codeview.jl
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,13 @@ code_warntype(@nospecialize(f), @nospecialize(t=Base.default_tt(f)); kwargs...)

import Base.CodegenParams

const GENERIC_SIG_WARNING = "; WARNING: This code may not match what actually runs.\n"
const OC_MISMATCH_WARNING =
"""
; WARNING: The pre-inferred opaque closure is not callable with the given arguments
; and will error on dispatch with this signature.
"""

# Printing code representations in IR and assembly
function _dump_function(@nospecialize(f), @nospecialize(t), native::Bool, wrapper::Bool,
strip_ir_metadata::Bool, dump_module::Bool, syntax::Symbol,
Expand All @@ -153,10 +160,28 @@ function _dump_function(@nospecialize(f), @nospecialize(t), native::Bool, wrappe
if isa(f, Core.Builtin)
throw(ArgumentError("argument is not a generic function"))
end
warning = ""
# get the MethodInstance for the method match
world = Base.get_world_counter()
match = Base._which(signature_type(f, t), world)
linfo = Core.Compiler.specialize_method(match)
if !isa(f, Core.OpaqueClosure)
world = Base.get_world_counter()
match = Base._which(signature_type(f, t), world)
linfo = Core.Compiler.specialize_method(match)
# TODO: use jl_is_cacheable_sig instead of isdispatchtuple
isdispatchtuple(linfo.specTypes) || (warning = GENERIC_SIG_WARNING)
else
world = UInt64(f.world)
if Core.Compiler.is_source_inferred(f.source.source)
# OC was constructed from inferred source. There's only one
# specialization and we can't infer anything more precise either.
world = f.source.primary_world
linfo = f.source.specializations[1]
Core.Compiler.hasintersect(typeof(f).parameters[1], t) || (warning = OC_MISMATCH_WARNING)
else
linfo = Core.Compiler.specialize_method(f.source, Tuple{typeof(f.captures), t.parameters...}, Core.svec())
actual = isdispatchtuple(linfo.specTypes)
isdispatchtuple(linfo.specTypes) || (warning = GENERIC_SIG_WARNING)
end
end
# get the code for it
if debuginfo === :default
debuginfo = :source
Expand All @@ -175,8 +200,7 @@ function _dump_function(@nospecialize(f), @nospecialize(t), native::Bool, wrappe
else
str = _dump_function_linfo_llvm(linfo, world, wrapper, strip_ir_metadata, dump_module, optimize, debuginfo, params)
end
# TODO: use jl_is_cacheable_sig instead of isdispatchtuple
isdispatchtuple(linfo.specTypes) || (str = "; WARNING: This code may not match what actually runs.\n" * str)
str = warning * str
return str
end

Expand Down
12 changes: 12 additions & 0 deletions stdlib/InteractiveUtils/test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -670,3 +670,15 @@ let # `default_tt` should work with any function with one method
sin(a)
end); true)
end

@testset "code_llvm on opaque_closure" begin
let ci = code_typed(+, (Int, Int))[1][1]
ir = Core.Compiler.inflate_ir(ci, Any[], Any[Tuple{}, Int, Int])
oc = Core.OpaqueClosure(ir)
@test (code_llvm(devnull, oc, Tuple{Int, Int}); true)
let io = IOBuffer()
code_llvm(io, oc, Tuple{})
@test occursin(InteractiveUtils.OC_MISMATCH_WARNING, String(take!(io)))
end
end
end
45 changes: 12 additions & 33 deletions test/opaque_closure.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
using Test
using InteractiveUtils
using Core: OpaqueClosure

const_int() = 1

Expand Down Expand Up @@ -241,47 +242,25 @@ let oc = @opaque a->sin(a)
end

# constructing an opaque closure from IRCode
using Core.Compiler: IRCode
using Core: CodeInfo

function OC(ir::IRCode, nargs::Int, isva::Bool, env...)
if (isva && nargs > length(ir.argtypes)) || (!isva && nargs != length(ir.argtypes)-1)
throw(ArgumentError("invalid argument count"))
end
src = ccall(:jl_new_code_info_uninit, Ref{CodeInfo}, ())
src.slotflags = UInt8[]
src.slotnames = fill(:none, nargs+1)
Core.Compiler.replace_code_newstyle!(src, ir, nargs+1)
Core.Compiler.widen_all_consts!(src)
src.inferred = true
# NOTE: we need ir.argtypes[1] == typeof(env)

ccall(:jl_new_opaque_closure_from_code_info, Any, (Any, Any, Any, Any, Any, Cint, Any, Cint, Cint, Any),
Tuple{ir.argtypes[2:end]...}, Union{}, Any, @__MODULE__, src, 0, nothing, nargs, isva, env)
end

function OC(src::CodeInfo, env...)
M = src.parent.def
sig = Base.tuple_type_tail(src.parent.specTypes)

ccall(:jl_new_opaque_closure_from_code_info, Any, (Any, Any, Any, Any, Any, Cint, Any, Cint, Cint, Any),
sig, Union{}, Any, @__MODULE__, src, 0, nothing, M.nargs - 1, M.isva, env)
end

let ci = code_typed(+, (Int, Int))[1][1]
ir = Core.Compiler.inflate_ir(ci)
@test OC(ir, 2, false)(40, 2) == 42
@test OC(ci)(40, 2) == 42
@test OpaqueClosure(ir; nargs=2, isva=false)(40, 2) == 42
@test OpaqueClosure(ci)(40, 2) == 42

ir = Core.Compiler.inflate_ir(ci, Any[], Any[Tuple{}, Int, Int])
@test OpaqueClosure(ir; nargs=2, isva=false)(40, 2) == 42
@test isa(OpaqueClosure(ir; nargs=2, isva=false), Core.OpaqueClosure{Tuple{Int, Int}, Int})
@test_throws TypeError OpaqueClosure(ir; nargs=2, isva=false)(40.0, 2)
end

let ci = code_typed((x, y...)->(x, y), (Int, Int))[1][1]
ir = Core.Compiler.inflate_ir(ci)
@test OC(ir, 2, true)(40, 2) === (40, (2,))
@test OC(ci)(40, 2) === (40, (2,))
@test OpaqueClosure(ir; nargs=2, isva=true)(40, 2) === (40, (2,))
@test OpaqueClosure(ci)(40, 2) === (40, (2,))
end

let ci = code_typed((x, y...)->(x, y), (Int, Int))[1][1]
ir = Core.Compiler.inflate_ir(ci)
@test_throws MethodError OC(ir, 2, true)(1, 2, 3)
@test_throws MethodError OC(ci)(1, 2, 3)
@test_throws MethodError OpaqueClosure(ir; nargs=2, isva=true)(1, 2, 3)
@test_throws MethodError OpaqueClosure(ci)(1, 2, 3)
end

0 comments on commit 19a3ddd

Please sign in to comment.