Skip to content

Commit

Permalink
gf: support more dispatch on abstract types
Browse files Browse the repository at this point in the history
This removes the restriction on defining dispatch over user-defined abstract types.

The "cannot add methods to an abstract type" error is now only
applicable to a couple types (`Any`, `Function`, and functions),
and instead now gives a "not implemented yet" message.

fixes #14919 for 99% of cases
  • Loading branch information
vtjnash committed May 3, 2019
1 parent 1707e13 commit 84e00cb
Show file tree
Hide file tree
Showing 18 changed files with 311 additions and 184 deletions.
2 changes: 1 addition & 1 deletion base/compiler/abstractinterpretation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ call_result_unused(frame::InferenceState, pc::LineNum=frame.currpc) =
function abstract_call_gf_by_type(@nospecialize(f), argtypes::Vector{Any}, @nospecialize(atype), sv::InferenceState,
max_methods = sv.params.MAX_METHODS)
atype_params = unwrap_unionall(atype).parameters
ft = unwrap_unionall(atype_params[1]) # TODO: ccall jl_first_argument_datatype here
ft = unwrap_unionall(atype_params[1]) # TODO: ccall jl_method_table_for here
isa(ft, DataType) || return Any # the function being called is unknown. can't properly handle this backedge right now
ftname = ft.name
isdefined(ftname, :mt) || return Any # not callable. should be Bottom, but can't track this backedge right now
Expand Down
32 changes: 17 additions & 15 deletions base/compiler/ssair/inlining.jl
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
# This file is a part of Julia. License is MIT: https://julialang.org/license

struct InvokeData
mt::Core.MethodTable
entry::Core.TypeMapEntry
types0
min_valid::UInt
max_valid::UInt
end

struct Signature
Expand Down Expand Up @@ -581,9 +582,9 @@ function spec_lambda(@nospecialize(atype), sv::OptimizationState, @nospecialize(
else
invoke_data = invoke_data::InvokeData
atype <: invoke_data.types0 || return nothing
mi = ccall(:jl_get_invoke_lambda, Any, (Any, Any, Any, UInt),
invoke_data.mt, invoke_data.entry, atype, sv.params.world)
#XXX: compute min/max_valid
mi = ccall(:jl_get_invoke_lambda, Any, (Any, Any), invoke_data.entry, atype)
min_valid[1] = invoke_data.min_valid
max_valid[1] = invoke_data.max_valid
end
mi !== nothing && add_backedge!(mi::MethodInstance, sv)
update_valid_age!(min_valid[1], max_valid[1], sv)
Expand Down Expand Up @@ -924,6 +925,7 @@ function inline_invoke!(ir::IRCode, idx::Int, sig::Signature, invoke_data::Invok
result = analyze_method!(idx, sig, metharg, methsp, method, stmt, sv, true, invoke_data,
calltype)
handle_single_case!(ir, stmt, idx, result, true, todo, sv)
update_valid_age!(invoke_data.min_valid, invoke_data.max_valid, sv)
return nothing
end

Expand Down Expand Up @@ -1109,26 +1111,26 @@ end

function compute_invoke_data(@nospecialize(atypes), params::Params)
ft = widenconst(atypes[2])
invoke_tt = widenconst(atypes[3])
mt = argument_mt(ft)
if mt === nothing || !isType(invoke_tt) || has_free_typevars(invoke_tt) ||
has_free_typevars(ft) || (ft <: Builtin)
if !isdispatchelem(ft) || has_free_typevars(ft) || (ft <: Builtin)
# TODO: this can be rather aggressive at preventing inlining of closures
# XXX: this is wrong for `ft <: Type`, since we are failing to check that
# the result doesn't have subtypes, or to do an intersection lookup
# but we need to check that `ft` can't have a subtype at runtime before using the supertype lookup below
return nothing
end
if !(isa(invoke_tt.parameters[1], Type) &&
invoke_tt.parameters[1] <: Tuple)
invoke_tt = widenconst(atypes[3])
if !isType(invoke_tt) || has_free_typevars(invoke_tt)
return nothing
end
invoke_tt = invoke_tt.parameters[1]
if !(isa(unwrap_unionall(invoke_tt), DataType) && invoke_tt <: Tuple)
return nothing
end
invoke_types = rewrap_unionall(Tuple{ft, unwrap_unionall(invoke_tt).parameters...}, invoke_tt)
min_valid = UInt[typemin(UInt)]
max_valid = UInt[typemax(UInt)]
invoke_entry = ccall(:jl_gf_invoke_lookup, Any, (Any, UInt),
invoke_types, params.world)
invoke_types, params.world) # XXX: min_valid, max_valid
invoke_entry === nothing && return nothing
#XXX: update_valid_age!(min_valid[1], max_valid[1], sv)
invoke_data = InvokeData(mt, invoke_entry, invoke_types)
invoke_data = InvokeData(invoke_entry, invoke_types, min_valid[1], max_valid[1])
atype0 = atypes[2]
atypes = atypes[4:end]
pushfirst!(atypes, atype0)
Expand Down
7 changes: 3 additions & 4 deletions base/compiler/tfuncs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1123,11 +1123,10 @@ end
add_tfunc(apply_type, 1, INT_INF, apply_type_tfunc, 10)

function invoke_tfunc(@nospecialize(ft), @nospecialize(types), @nospecialize(argtype), sv::InferenceState)
argument_mt(ft) === nothing && return Any
argtype = typeintersect(types, argtype)
if argtype === Bottom
return Bottom
end
argtype === Bottom && return Bottom
argtype isa DataType || return Any # other cases are not implemented below
isdispatchelem(ft) || return Any # check that we might not have a subtype of `ft` at runtime, before doing supertype lookup below
types = rewrap_unionall(Tuple{ft, unwrap_unionall(types).parameters...}, types)
argtype = Tuple{ft, argtype.parameters...}
entry = ccall(:jl_gf_invoke_lookup, Any, (Any, UInt), types, sv.params.world)
Expand Down
8 changes: 1 addition & 7 deletions base/reflection.jl
Original file line number Diff line number Diff line change
Expand Up @@ -634,11 +634,6 @@ function fieldindex(T::DataType, name::Symbol, err::Bool=true)
end

argument_datatype(@nospecialize t) = ccall(:jl_argument_datatype, Any, (Any,), t)
function argument_mt(@nospecialize t)
dt = argument_datatype(t)
(dt === nothing || !isdefined(dt.name, :mt)) && return nothing
dt.name.mt
end

"""
fieldcount(t::Type)
Expand Down Expand Up @@ -1288,8 +1283,7 @@ function delete_method(m::Method)
end

function get_methodtable(m::Method)
ft = ccall(:jl_first_argument_datatype, Any, (Any,), m.sig)
(ft::DataType).name.mt
return ccall(:jl_method_table_for, Any, (Any,), m.sig)::Core.MethodTable
end

"""
Expand Down
39 changes: 21 additions & 18 deletions src/datatype.c
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ JL_DLLEXPORT jl_methtable_t *jl_new_method_table(jl_sym_t *name, jl_module_t *mo
mt->backedges = NULL;
JL_MUTEX_INIT(&mt->writelock);
mt->offs = 1;
mt->frozen = 0;
return mt;
}

Expand Down Expand Up @@ -452,11 +453,8 @@ JL_DLLEXPORT jl_datatype_t *jl_new_datatype(
jl_typename_t *tn = NULL;
JL_GC_PUSH2(&t, &tn);

if (t == NULL)
t = jl_new_uninitialized_datatype();
else
tn = t->name;
// init before possibly calling jl_new_typename_in
// init enough before possibly calling jl_new_typename_in
t = jl_new_uninitialized_datatype();
t->super = super;
if (super != NULL) jl_gc_wb(t, t->super);
t->parameters = parameters;
Expand All @@ -471,23 +469,28 @@ JL_DLLEXPORT jl_datatype_t *jl_new_datatype(
t->ditype = NULL;
t->size = 0;

if (tn == NULL) {
t->name = NULL;
if (jl_is_typename(name)) {
tn = (jl_typename_t*)name;
t->name = NULL;
if (jl_is_typename(name)) {
// This code-path is used by the Serialization module to by-pass normal expectations
tn = (jl_typename_t*)name;
}
else {
tn = jl_new_typename_in((jl_sym_t*)name, module);
if (super == jl_function_type || super == jl_builtin_type || jl_symbol_name(name)[0] == '#') {
// Callable objects (including compiler-generated closures) get independent method tables
// as an optimization
tn->mt = jl_new_method_table(name, module);
jl_gc_wb(tn, tn->mt);
if (jl_svec_len(parameters) > 0)
tn->mt->offs = 0;
}
else {
tn = jl_new_typename_in((jl_sym_t*)name, module);
if (!abstract) {
tn->mt = jl_new_method_table(name, module);
jl_gc_wb(tn, tn->mt);
if (jl_svec_len(parameters) > 0)
tn->mt->offs = 0;
}
// Everything else, gets to use the unified table
tn->mt = jl_nonfunction_mt;
}
t->name = tn;
jl_gc_wb(t, t->name);
}
t->name = tn;
jl_gc_wb(t, t->name);
t->name->names = fnames;
jl_gc_wb(t->name, t->name->names);

Expand Down
70 changes: 45 additions & 25 deletions src/dump.c
Original file line number Diff line number Diff line change
Expand Up @@ -313,16 +313,33 @@ static void jl_serialize_datatype(jl_serializer_state *s, jl_datatype_t *dt) JL_
tag = 10;
}

if (strncmp(jl_symbol_name(dt->name->name), "#kw#", 4) == 0) {
/* XXX: yuck, but the auto-generated kw types from the serializer isn't a real type, so we *must* be very careful */
assert(tag == 0 || tag == 5 || tag == 6 || tag == 10);
if (tag == 6) {
if (strncmp(jl_symbol_name(dt->name->name), "#kw#", 4) == 0 && !internal && tag != 0) {
/* XXX: yuck, this is horrible, but the auto-generated kw types from the serializer isn't a real type, so we *must* be very careful */
assert(tag == 6); // other struct types should never exist
tag = 9;
if (jl_type_type_mt->kwsorter != NULL && dt == (jl_datatype_t*)jl_typeof(jl_type_type_mt->kwsorter)) {
dt = jl_datatype_type; // any representative member with this MethodTable
}
else if (jl_nonfunction_mt->kwsorter != NULL && dt == (jl_datatype_t*)jl_typeof(jl_nonfunction_mt->kwsorter)) {
dt = jl_symbol_type; // any representative member with this MethodTable
}
else {
// search for the representative member of this MethodTable
jl_methtable_t *mt = dt->name->mt;
jl_datatype_t *primarydt = (jl_datatype_t*)jl_unwrap_unionall(jl_get_global(mt->module, mt->name));
size_t l = strlen(jl_symbol_name(mt->name));
char *prefixed;
prefixed = (char*)malloc(l + 2);
prefixed[0] = '#';
strcpy(&prefixed[1], jl_symbol_name(mt->name));
jl_sym_t *tname = jl_symbol(prefixed);
free(prefixed);
jl_value_t *primarydt = jl_get_global(mt->module, tname);
if (!primarydt)
primarydt = jl_get_global(mt->module, mt->name);
primarydt = jl_unwrap_unionall(primarydt);
assert(jl_is_datatype(primarydt));
assert(jl_typeof(primarydt->name->mt->kwsorter) == (jl_value_t*)dt);
dt = primarydt;
tag = 9;
assert(primarydt == (jl_value_t*)jl_any_type || jl_typeof(((jl_datatype_t*)primarydt)->name->mt->kwsorter) == (jl_value_t*)dt);
dt = (jl_datatype_t*)primarydt;
}
}

Expand Down Expand Up @@ -781,9 +798,9 @@ static void jl_serialize_value_(jl_serializer_state *s, jl_value_t *v, int as_li
write_uint8(s->s, internal);
if (!internal)
return;
jl_datatype_t *gf = jl_first_argument_datatype((jl_value_t*)m->sig);
assert(jl_is_datatype(gf) && gf->name->mt);
external_mt = !module_in_worklist(gf->name->mt->module);
jl_methtable_t *mt = jl_method_table_for((jl_value_t*)m->sig);
assert((jl_value_t*)mt != jl_nothing);
external_mt = !module_in_worklist(mt->module);
jl_serialize_value(s, m->specializations);
jl_serialize_value(s, (jl_value_t*)m->name);
jl_serialize_value(s, (jl_value_t*)m->file);
Expand Down Expand Up @@ -1111,9 +1128,9 @@ static int jl_collect_methcache_from_mod(jl_typemap_entry_t *ml, void *closure)
return 1;
}

static void jl_collect_methtable_from_mod(jl_array_t *s, jl_typename_t *tn) JL_GC_DISABLED
static void jl_collect_methtable_from_mod(jl_array_t *s, jl_methtable_t *mt) JL_GC_DISABLED
{
jl_typemap_visitor(tn->mt->defs, jl_collect_methcache_from_mod, (void*)s);
jl_typemap_visitor(mt->defs, jl_collect_methcache_from_mod, (void*)s);
}

static void jl_collect_lambdas_from_mod(jl_array_t *s, jl_module_t *m) JL_GC_DISABLED
Expand All @@ -1133,8 +1150,8 @@ static void jl_collect_lambdas_from_mod(jl_array_t *s, jl_module_t *m) JL_GC_DIS
jl_methtable_t *mt = tn->mt;
if (mt != NULL &&
(jl_value_t*)mt != jl_nothing &&
(mt != jl_type_type_mt || tn == jl_type_typename)) {
jl_collect_methtable_from_mod(s, tn);
(mt != jl_type_type_mt && mt != jl_nonfunction_mt)) {
jl_collect_methtable_from_mod(s, mt);
jl_collect_missing_backedges_to_mod(mt);
}
}
Expand Down Expand Up @@ -2171,9 +2188,9 @@ static void jl_insert_methods(jl_array_t *list)
jl_method_t *meth = (jl_method_t*)jl_array_ptr_ref(list, i);
jl_tupletype_t *simpletype = (jl_tupletype_t*)jl_array_ptr_ref(list, i + 1);
assert(jl_is_method(meth));
jl_datatype_t *gf = jl_first_argument_datatype((jl_value_t*)meth->sig);
assert(jl_is_datatype(gf) && gf->name->mt);
jl_method_table_insert(gf->name->mt, meth, simpletype);
jl_methtable_t *mt = jl_method_table_for((jl_value_t*)meth->sig);
assert((jl_value_t*)mt != jl_nothing);
jl_method_table_insert(mt, meth, simpletype);
}
}

Expand Down Expand Up @@ -2240,9 +2257,8 @@ static void jl_insert_backedges(jl_array_t *list, arraylist_t *dependent_worlds)
jl_method_instance_add_backedge((jl_method_instance_t*)callee, caller);
}
else {
jl_datatype_t *ftype = jl_first_argument_datatype(callee);
jl_methtable_t *mt = ftype->name->mt;
assert(jl_is_datatype(ftype) && mt);
jl_methtable_t *mt = jl_method_table_for(callee);
assert((jl_value_t*)mt != jl_nothing);
jl_method_table_add_backedge(mt, callee, (jl_value_t*)caller);
}
}
Expand Down Expand Up @@ -2767,6 +2783,10 @@ JL_DLLEXPORT int jl_save_incremental(const char *fname, jl_array_t *worklist)
assert(jl_is_module(m));
jl_collect_lambdas_from_mod(lambdas, m);
}
jl_collect_methtable_from_mod(lambdas, jl_type_type_mt);
jl_collect_missing_backedges_to_mod(jl_type_type_mt);
jl_collect_methtable_from_mod(lambdas, jl_nonfunction_mt);
jl_collect_missing_backedges_to_mod(jl_nonfunction_mt);

jl_collect_backedges(edges);

Expand Down Expand Up @@ -3041,8 +3061,8 @@ static jl_method_t *jl_lookup_method_worldset(jl_methtable_t *mt, jl_datatype_t
static jl_method_t *jl_recache_method(jl_method_t *m, size_t start, arraylist_t *dependent_worlds)
{
jl_datatype_t *sig = (jl_datatype_t*)m->sig;
jl_datatype_t *ftype = jl_first_argument_datatype((jl_value_t*)sig);
jl_methtable_t *mt = ftype->name->mt;
jl_methtable_t *mt = jl_method_table_for((jl_value_t*)m->sig);
assert((jl_value_t*)mt != jl_nothing);
jl_set_typeof(m, (void*)(intptr_t)0x30); // invalidate the old value to help catch errors
jl_method_t *_new = jl_lookup_method_worldset(mt, sig, dependent_worlds);
jl_update_backref_list((jl_value_t*)m, (jl_value_t*)_new, start);
Expand All @@ -3053,8 +3073,8 @@ static jl_method_instance_t *jl_recache_method_instance(jl_method_instance_t *mi
{
jl_datatype_t *sig = (jl_datatype_t*)mi->def.value;
assert(jl_is_datatype(sig) || jl_is_unionall(sig));
jl_datatype_t *ftype = jl_first_argument_datatype((jl_value_t*)sig);
jl_methtable_t *mt = ftype->name->mt;
jl_methtable_t *mt = jl_method_table_for((jl_value_t*)sig);
assert((jl_value_t*)mt != jl_nothing);
jl_method_t *m = jl_lookup_method_worldset(mt, sig, dependent_worlds);
jl_datatype_t *argtypes = (jl_datatype_t*)mi->specTypes;
jl_set_typeof(mi, (void*)(intptr_t)0x40); // invalidate the old value to help catch errors
Expand Down
Loading

0 comments on commit 84e00cb

Please sign in to comment.