Skip to content

Commit

Permalink
gf: support more dispatch on abstract types (#31916)
Browse files Browse the repository at this point in the history
This removes the restriction on defining dispatch over user-defined abstract types.

The "cannot add methods to an abstract type" error is now only
applicable to a couple types (`Any`, `Function`, and functions),
and instead now gives a "not implemented yet" message.

fixes #14919 for 99% of cases
  • Loading branch information
vtjnash authored May 15, 2019
1 parent 5c5f5c2 commit 99d2406
Show file tree
Hide file tree
Showing 20 changed files with 340 additions and 202 deletions.
2 changes: 1 addition & 1 deletion base/compiler/abstractinterpretation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ call_result_unused(frame::InferenceState, pc::LineNum=frame.currpc) =
function abstract_call_gf_by_type(@nospecialize(f), argtypes::Vector{Any}, @nospecialize(atype), sv::InferenceState,
max_methods = sv.params.MAX_METHODS)
atype_params = unwrap_unionall(atype).parameters
ft = unwrap_unionall(atype_params[1]) # TODO: ccall jl_first_argument_datatype here
ft = unwrap_unionall(atype_params[1]) # TODO: ccall jl_method_table_for here
isa(ft, DataType) || return Any # the function being called is unknown. can't properly handle this backedge right now
ftname = ft.name
isdefined(ftname, :mt) || return Any # not callable. should be Bottom, but can't track this backedge right now
Expand Down
32 changes: 17 additions & 15 deletions base/compiler/ssair/inlining.jl
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
# This file is a part of Julia. License is MIT: https://julialang.org/license

struct InvokeData
mt::Core.MethodTable
entry::Core.TypeMapEntry
types0
min_valid::UInt
max_valid::UInt
end

struct Signature
Expand Down Expand Up @@ -581,9 +582,9 @@ function spec_lambda(@nospecialize(atype), sv::OptimizationState, @nospecialize(
else
invoke_data = invoke_data::InvokeData
atype <: invoke_data.types0 || return nothing
mi = ccall(:jl_get_invoke_lambda, Any, (Any, Any, Any, UInt),
invoke_data.mt, invoke_data.entry, atype, sv.params.world)
#XXX: compute min/max_valid
mi = ccall(:jl_get_invoke_lambda, Any, (Any, Any), invoke_data.entry, atype)
min_valid[1] = invoke_data.min_valid
max_valid[1] = invoke_data.max_valid
end
mi !== nothing && add_backedge!(mi::MethodInstance, sv)
update_valid_age!(min_valid[1], max_valid[1], sv)
Expand Down Expand Up @@ -924,6 +925,7 @@ function inline_invoke!(ir::IRCode, idx::Int, sig::Signature, invoke_data::Invok
result = analyze_method!(idx, sig, metharg, methsp, method, stmt, sv, true, invoke_data,
calltype)
handle_single_case!(ir, stmt, idx, result, true, todo, sv)
update_valid_age!(invoke_data.min_valid, invoke_data.max_valid, sv)
return nothing
end

Expand Down Expand Up @@ -1109,26 +1111,26 @@ end

function compute_invoke_data(@nospecialize(atypes), params::Params)
ft = widenconst(atypes[2])
invoke_tt = widenconst(atypes[3])
mt = argument_mt(ft)
if mt === nothing || !isType(invoke_tt) || has_free_typevars(invoke_tt) ||
has_free_typevars(ft) || (ft <: Builtin)
if !isdispatchelem(ft) || has_free_typevars(ft) || (ft <: Builtin)
# TODO: this can be rather aggressive at preventing inlining of closures
# XXX: this is wrong for `ft <: Type`, since we are failing to check that
# the result doesn't have subtypes, or to do an intersection lookup
# but we need to check that `ft` can't have a subtype at runtime before using the supertype lookup below
return nothing
end
if !(isa(invoke_tt.parameters[1], Type) &&
invoke_tt.parameters[1] <: Tuple)
invoke_tt = widenconst(atypes[3])
if !isType(invoke_tt) || has_free_typevars(invoke_tt)
return nothing
end
invoke_tt = invoke_tt.parameters[1]
if !(isa(unwrap_unionall(invoke_tt), DataType) && invoke_tt <: Tuple)
return nothing
end
invoke_types = rewrap_unionall(Tuple{ft, unwrap_unionall(invoke_tt).parameters...}, invoke_tt)
min_valid = UInt[typemin(UInt)]
max_valid = UInt[typemax(UInt)]
invoke_entry = ccall(:jl_gf_invoke_lookup, Any, (Any, UInt),
invoke_types, params.world)
invoke_types, params.world) # XXX: min_valid, max_valid
invoke_entry === nothing && return nothing
#XXX: update_valid_age!(min_valid[1], max_valid[1], sv)
invoke_data = InvokeData(mt, invoke_entry, invoke_types)
invoke_data = InvokeData(invoke_entry, invoke_types, min_valid[1], max_valid[1])
atype0 = atypes[2]
atypes = atypes[4:end]
pushfirst!(atypes, atype0)
Expand Down
7 changes: 3 additions & 4 deletions base/compiler/tfuncs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1128,11 +1128,10 @@ end
add_tfunc(apply_type, 1, INT_INF, apply_type_tfunc, 10)

function invoke_tfunc(@nospecialize(ft), @nospecialize(types), @nospecialize(argtype), sv::InferenceState)
argument_mt(ft) === nothing && return Any
argtype = typeintersect(types, argtype)
if argtype === Bottom
return Bottom
end
argtype === Bottom && return Bottom
argtype isa DataType || return Any # other cases are not implemented below
isdispatchelem(ft) || return Any # check that we might not have a subtype of `ft` at runtime, before doing supertype lookup below
types = rewrap_unionall(Tuple{ft, unwrap_unionall(types).parameters...}, types)
argtype = Tuple{ft, argtype.parameters...}
entry = ccall(:jl_gf_invoke_lookup, Any, (Any, UInt), types, sv.params.world)
Expand Down
8 changes: 1 addition & 7 deletions base/reflection.jl
Original file line number Diff line number Diff line change
Expand Up @@ -638,11 +638,6 @@ function fieldindex(T::DataType, name::Symbol, err::Bool=true)
end

argument_datatype(@nospecialize t) = ccall(:jl_argument_datatype, Any, (Any,), t)
function argument_mt(@nospecialize t)
dt = argument_datatype(t)
(dt === nothing || !isdefined(dt.name, :mt)) && return nothing
dt.name.mt
end

"""
fieldcount(t::Type)
Expand Down Expand Up @@ -1295,8 +1290,7 @@ function delete_method(m::Method)
end

function get_methodtable(m::Method)
ft = ccall(:jl_first_argument_datatype, Any, (Any,), m.sig)
(ft::DataType).name.mt
return ccall(:jl_method_table_for, Any, (Any,), m.sig)::Core.MethodTable
end

"""
Expand Down
2 changes: 0 additions & 2 deletions src/common_symbols1.inc
Original file line number Diff line number Diff line change
Expand Up @@ -100,5 +100,3 @@ jl_symbol("haskey"),
jl_symbol("setproperty!"),
jl_symbol("promote"),
jl_symbol("undef"),
jl_symbol("Vector"),
jl_symbol("parent"),
4 changes: 2 additions & 2 deletions src/common_symbols2.inc
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
jl_symbol("Vector"),
jl_symbol("parent"),
jl_symbol("_promote"),
jl_symbol("Ref"),
jl_symbol("push!"),
Expand Down Expand Up @@ -250,5 +252,3 @@ jl_symbol("checked_add"),
jl_symbol("mod"),
jl_symbol("unsafe_write"),
jl_symbol("libuv.jl"),
jl_symbol("Matrix"),
jl_symbol("a"),
39 changes: 21 additions & 18 deletions src/datatype.c
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ JL_DLLEXPORT jl_methtable_t *jl_new_method_table(jl_sym_t *name, jl_module_t *mo
mt->backedges = NULL;
JL_MUTEX_INIT(&mt->writelock);
mt->offs = 1;
mt->frozen = 0;
return mt;
}

Expand Down Expand Up @@ -452,11 +453,8 @@ JL_DLLEXPORT jl_datatype_t *jl_new_datatype(
jl_typename_t *tn = NULL;
JL_GC_PUSH2(&t, &tn);

if (t == NULL)
t = jl_new_uninitialized_datatype();
else
tn = t->name;
// init before possibly calling jl_new_typename_in
// init enough before possibly calling jl_new_typename_in
t = jl_new_uninitialized_datatype();
t->super = super;
if (super != NULL) jl_gc_wb(t, t->super);
t->parameters = parameters;
Expand All @@ -471,23 +469,28 @@ JL_DLLEXPORT jl_datatype_t *jl_new_datatype(
t->ditype = NULL;
t->size = 0;

if (tn == NULL) {
t->name = NULL;
if (jl_is_typename(name)) {
tn = (jl_typename_t*)name;
t->name = NULL;
if (jl_is_typename(name)) {
// This code-path is used by the Serialization module to by-pass normal expectations
tn = (jl_typename_t*)name;
}
else {
tn = jl_new_typename_in((jl_sym_t*)name, module);
if (super == jl_function_type || super == jl_builtin_type || jl_symbol_name(name)[0] == '#') {
// Callable objects (including compiler-generated closures) get independent method tables
// as an optimization
tn->mt = jl_new_method_table(name, module);
jl_gc_wb(tn, tn->mt);
if (jl_svec_len(parameters) > 0)
tn->mt->offs = 0;
}
else {
tn = jl_new_typename_in((jl_sym_t*)name, module);
if (!abstract) {
tn->mt = jl_new_method_table(name, module);
jl_gc_wb(tn, tn->mt);
if (jl_svec_len(parameters) > 0)
tn->mt->offs = 0;
}
// Everything else, gets to use the unified table
tn->mt = jl_nonfunction_mt;
}
t->name = tn;
jl_gc_wb(t, t->name);
}
t->name = tn;
jl_gc_wb(t, t->name);
t->name->names = fnames;
jl_gc_wb(t->name, t->name->names);

Expand Down
Loading

0 comments on commit 99d2406

Please sign in to comment.