Skip to content

Commit

Permalink
add entry point to construct an OpaqueClosure from pre-optimized IRCo…
Browse files Browse the repository at this point in the history
…de (JuliaLang#44197)

* add entry point to construct an OpaqueClosure from pre-optimized IRCode

* update `jl_new_codeinst` signature

* fixes to OpaqueClosure argument count handling and MethodError display

* more test coverage

Co-authored-by: Shuhei Kadowaki <[email protected]>
  • Loading branch information
JeffBezanson and aviatesk authored Apr 28, 2022
1 parent 9320fba commit 2049baa
Show file tree
Hide file tree
Showing 5 changed files with 134 additions and 24 deletions.
6 changes: 5 additions & 1 deletion base/errorshow.jl
Original file line number Diff line number Diff line change
Expand Up @@ -409,7 +409,11 @@ function show_method_candidates(io::IO, ex::MethodError, @nospecialize kwargs=()
buf = IOBuffer()
iob0 = iob = IOContext(buf, io)
tv = Any[]
sig0 = method.sig
if func isa Core.OpaqueClosure
sig0 = signature_type(func, typeof(func).parameters[1])
else
sig0 = method.sig
end
while isa(sig0, UnionAll)
push!(tv, sig0.var)
iob = IOContext(iob, :unionall_env => sig0.var)
Expand Down
3 changes: 3 additions & 0 deletions base/methodshow.jl
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,9 @@ end

# NOTE: second argument is deprecated and is no longer used
function kwarg_decl(m::Method, kwtype = nothing)
if m.sig === Tuple # OpaqueClosure
return Symbol[]
end
mt = get_methodtable(m)
if isdefined(mt, :kwsorter)
kwtype = typeof(mt.kwsorter)
Expand Down
21 changes: 12 additions & 9 deletions src/method.c
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ extern jl_value_t *jl_builtin_getfield;
extern jl_value_t *jl_builtin_tuple;

jl_method_t *jl_make_opaque_closure_method(jl_module_t *module, jl_value_t *name,
jl_value_t *nargs, jl_value_t *functionloc, jl_code_info_t *ci, int isva);
int nargs, jl_value_t *functionloc, jl_code_info_t *ci, int isva);

static void check_c_types(const char *where, jl_value_t *rt, jl_value_t *at)
{
Expand Down Expand Up @@ -51,11 +51,14 @@ static jl_value_t *resolve_globals(jl_value_t *expr, jl_module_t *module, jl_sve
return jl_module_globalref(module, (jl_sym_t*)expr);
}
else if (jl_is_returnnode(expr)) {
jl_value_t *val = resolve_globals(jl_returnnode_value(expr), module, sparam_vals, binding_effects, eager_resolve);
if (val != jl_returnnode_value(expr)) {
JL_GC_PUSH1(&val);
expr = jl_new_struct(jl_returnnode_type, val);
JL_GC_POP();
jl_value_t *retval = jl_returnnode_value(expr);
if (retval) {
jl_value_t *val = resolve_globals(retval, module, sparam_vals, binding_effects, eager_resolve);
if (val != retval) {
JL_GC_PUSH1(&val);
expr = jl_new_struct(jl_returnnode_type, val);
JL_GC_POP();
}
}
return expr;
}
Expand Down Expand Up @@ -102,7 +105,7 @@ static jl_value_t *resolve_globals(jl_value_t *expr, jl_module_t *module, jl_sve
if (!jl_is_code_info(ci)) {
jl_error("opaque_closure_method: lambda should be a CodeInfo");
}
jl_method_t *m = jl_make_opaque_closure_method(module, name, nargs, functionloc, (jl_code_info_t*)ci, isva);
jl_method_t *m = jl_make_opaque_closure_method(module, name, jl_unbox_long(nargs), functionloc, (jl_code_info_t*)ci, isva);
return (jl_value_t*)m;
}
if (e->head == jl_cfunction_sym) {
Expand Down Expand Up @@ -782,7 +785,7 @@ JL_DLLEXPORT jl_method_t *jl_new_method_uninit(jl_module_t *module)
// method definition ----------------------------------------------------------

jl_method_t *jl_make_opaque_closure_method(jl_module_t *module, jl_value_t *name,
jl_value_t *nargs, jl_value_t *functionloc, jl_code_info_t *ci, int isva)
int nargs, jl_value_t *functionloc, jl_code_info_t *ci, int isva)
{
jl_method_t *m = jl_new_method_uninit(module);
JL_GC_PUSH1(&m);
Expand All @@ -796,7 +799,7 @@ jl_method_t *jl_make_opaque_closure_method(jl_module_t *module, jl_value_t *name
assert(jl_is_symbol(name));
m->name = (jl_sym_t*)name;
}
m->nargs = jl_unbox_long(nargs) + 1;
m->nargs = nargs + 1;
assert(jl_is_linenode(functionloc));
jl_value_t *file = jl_linenode_file(functionloc);
m->file = jl_is_symbol(file) ? (jl_sym_t*)file : jl_empty_sym;
Expand Down
82 changes: 68 additions & 14 deletions src/opaque_closure.c
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,23 @@ JL_DLLEXPORT int jl_is_valid_oc_argtype(jl_tupletype_t *argt, jl_method_t *sourc
return 1;
}

jl_opaque_closure_t *jl_new_opaque_closure(jl_tupletype_t *argt, jl_value_t *rt_lb, jl_value_t *rt_ub,
jl_value_t *source_, jl_value_t **env, size_t nenv)
static jl_value_t *prepend_type(jl_value_t *t0, jl_tupletype_t *t)
{
jl_svec_t *sig_args = NULL;
JL_GC_PUSH1(&sig_args);
size_t nsig = 1 + jl_svec_len(t->parameters);
sig_args = jl_alloc_svec_uninit(nsig);
jl_svecset(sig_args, 0, t0);
for (size_t i = 0; i < nsig-1; ++i) {
jl_svecset(sig_args, 1+i, jl_tparam(t, i));
}
jl_value_t *sigtype = (jl_value_t*)jl_apply_tuple_type_v(jl_svec_data(sig_args), nsig);
JL_GC_POP();
return sigtype;
}

static jl_opaque_closure_t *new_opaque_closure(jl_tupletype_t *argt, jl_value_t *rt_lb, jl_value_t *rt_ub,
jl_value_t *source_, jl_value_t *captures)
{
if (!jl_is_tuple_type((jl_value_t*)argt)) {
jl_error("OpaqueClosure argument tuple must be a tuple type");
Expand All @@ -40,26 +55,19 @@ jl_opaque_closure_t *jl_new_opaque_closure(jl_tupletype_t *argt, jl_value_t *rt_
}
if (jl_nparams(argt) + 1 - jl_is_va_tuple(argt) < source->nargs - source->isva)
jl_error("Argument type tuple has too few required arguments for method");
jl_task_t *ct = jl_current_task;
jl_value_t *sigtype = NULL;
JL_GC_PUSH1(&sigtype);
sigtype = prepend_type(jl_typeof(captures), argt);

jl_value_t *oc_type JL_ALWAYS_LEAFTYPE;
oc_type = jl_apply_type2((jl_value_t*)jl_opaque_closure_type, (jl_value_t*)argt, rt_ub);
JL_GC_PROMISE_ROOTED(oc_type);
jl_value_t *captures = NULL, *sigtype = NULL;
jl_svec_t *sig_args = NULL;
JL_GC_PUSH3(&captures, &sigtype, &sig_args);
captures = jl_f_tuple(NULL, env, nenv);

size_t nsig = 1 + jl_svec_len(argt->parameters);
sig_args = jl_alloc_svec_uninit(nsig);
jl_svecset(sig_args, 0, jl_typeof(captures));
for (size_t i = 0; i < nsig-1; ++i) {
jl_svecset(sig_args, 1+i, jl_tparam(argt, i));
}
sigtype = (jl_value_t*)jl_apply_tuple_type_v(jl_svec_data(sig_args), nsig);
jl_method_instance_t *mi = jl_specializations_get_linfo(source, sigtype, jl_emptysvec);
size_t world = jl_atomic_load_acquire(&jl_world_counter);
jl_code_instance_t *ci = jl_compile_method_internal(mi, world);

jl_task_t *ct = jl_current_task;
jl_opaque_closure_t *oc = (jl_opaque_closure_t*)jl_gc_alloc(ct->ptls, sizeof(jl_opaque_closure_t), oc_type);
JL_GC_POP();
oc->source = source;
Expand All @@ -82,6 +90,52 @@ jl_opaque_closure_t *jl_new_opaque_closure(jl_tupletype_t *argt, jl_value_t *rt_
return oc;
}

jl_opaque_closure_t *jl_new_opaque_closure(jl_tupletype_t *argt, jl_value_t *rt_lb, jl_value_t *rt_ub,
jl_value_t *source_, jl_value_t **env, size_t nenv)
{
jl_value_t *captures = jl_f_tuple(NULL, env, nenv);
JL_GC_PUSH1(&captures);
jl_opaque_closure_t *oc = new_opaque_closure(argt, rt_lb, rt_ub, source_, captures);
JL_GC_POP();
return oc;
}

jl_method_t *jl_make_opaque_closure_method(jl_module_t *module, jl_value_t *name,
int nargs, jl_value_t *functionloc, jl_code_info_t *ci, int isva);

JL_DLLEXPORT jl_code_instance_t* jl_new_codeinst(
jl_method_instance_t *mi, jl_value_t *rettype,
jl_value_t *inferred_const, jl_value_t *inferred,
int32_t const_flags, size_t min_world, size_t max_world,
uint32_t ipo_effects, uint32_t effects, jl_value_t *argescapes,
uint8_t relocatability);

JL_DLLEXPORT void jl_mi_cache_insert(jl_method_instance_t *mi JL_ROOTING_ARGUMENT,
jl_code_instance_t *ci JL_ROOTED_ARGUMENT JL_MAYBE_UNROOTED);

JL_DLLEXPORT jl_opaque_closure_t *jl_new_opaque_closure_from_code_info(jl_tupletype_t *argt, jl_value_t *rt_lb, jl_value_t *rt_ub,
jl_module_t *mod, jl_code_info_t *ci, int lineno, jl_value_t *file, int nargs, int isva, jl_value_t *env)
{
if (!ci->inferred)
jl_error("CodeInfo must already be inferred");
jl_value_t *root = NULL, *sigtype = NULL;
jl_code_instance_t *inst = NULL;
JL_GC_PUSH3(&root, &sigtype, &inst);
root = jl_box_long(lineno);
root = jl_new_struct(jl_linenumbernode_type, root, file);
root = (jl_value_t*)jl_make_opaque_closure_method(mod, jl_nothing, nargs, root, ci, isva);

sigtype = prepend_type(jl_typeof(env), argt);
jl_method_instance_t *mi = jl_specializations_get_linfo((jl_method_t*)root, sigtype, jl_emptysvec);
inst = jl_new_codeinst(mi, rt_ub, NULL, (jl_value_t*)ci,
0, ((jl_method_t*)root)->primary_world, -1, 0, 0, jl_nothing, 0);
jl_mi_cache_insert(mi, inst);

jl_opaque_closure_t *oc = new_opaque_closure(argt, rt_lb, rt_ub, root, env);
JL_GC_POP();
return oc;
}

JL_CALLABLE(jl_new_opaque_closure_jlcall)
{
if (nargs < 4)
Expand Down
46 changes: 46 additions & 0 deletions test/opaque_closure.jl
Original file line number Diff line number Diff line change
Expand Up @@ -239,3 +239,49 @@ end
let oc = @opaque a->sin(a)
@test length(code_typed(oc, (Int,))) == 1
end

# constructing an opaque closure from IRCode
using Core.Compiler: IRCode
using Core: CodeInfo

function OC(ir::IRCode, nargs::Int, isva::Bool, env...)
if (isva && nargs > length(ir.argtypes)) || (!isva && nargs != length(ir.argtypes)-1)
throw(ArgumentError("invalid argument count"))
end
src = ccall(:jl_new_code_info_uninit, Ref{CodeInfo}, ())
src.slotflags = UInt8[]
src.slotnames = fill(:none, nargs+1)
Core.Compiler.replace_code_newstyle!(src, ir, nargs+1)
Core.Compiler.widen_all_consts!(src)
src.inferred = true
# NOTE: we need ir.argtypes[1] == typeof(env)

ccall(:jl_new_opaque_closure_from_code_info, Any, (Any, Any, Any, Any, Any, Cint, Any, Cint, Cint, Any),
Tuple{ir.argtypes[2:end]...}, Union{}, Any, @__MODULE__, src, 0, nothing, nargs, isva, env)
end

function OC(src::CodeInfo, env...)
M = src.parent.def
sig = Base.tuple_type_tail(src.parent.specTypes)

ccall(:jl_new_opaque_closure_from_code_info, Any, (Any, Any, Any, Any, Any, Cint, Any, Cint, Cint, Any),
sig, Union{}, Any, @__MODULE__, src, 0, nothing, M.nargs - 1, M.isva, env)
end

let ci = code_typed(+, (Int, Int))[1][1]
ir = Core.Compiler.inflate_ir(ci)
@test OC(ir, 2, false)(40, 2) == 42
@test OC(ci)(40, 2) == 42
end

let ci = code_typed((x, y...)->(x, y), (Int, Int))[1][1]
ir = Core.Compiler.inflate_ir(ci)
@test OC(ir, 2, true)(40, 2) === (40, (2,))
@test OC(ci)(40, 2) === (40, (2,))
end

let ci = code_typed((x, y...)->(x, y), (Int, Int))[1][1]
ir = Core.Compiler.inflate_ir(ci)
@test_throws MethodError OC(ir, 2, true)(1, 2, 3)
@test_throws MethodError OC(ci)(1, 2, 3)
end

0 comments on commit 2049baa

Please sign in to comment.