Skip to content

Commit

Permalink
AbstractInterpreter: enable partial pure/concrete eval for external…
Browse files Browse the repository at this point in the history
… `AbstractInterpreter` with overlayed method table

Built on top of #44511, and solves <JuliaGPU/GPUCompiler.jl#309>.
This commit allows external `AbstractInterpreter` to use pure/concrete
evals even if it uses an overlayed method table. More specifically, such
`AbstractInterpreter` can use pure/concrete evals as far as any matching
methods in question doesn't come from the overlayed method table:
```julia
@test Base.return_types((), MTOverlayInterp()) do
    isbitstype(Int) ? nothing : missing
end == Any[Nothing]
Base.@assume_effects :terminates_globally function issue41694(x)
    res = 1
    1 < x < 20 || throw("bad")
    while x > 1
        res *= x
        x -= 1
    end
    return res
end
@test Base.return_types((), MTOverlayInterp()) do
    issue41694(3) == 6 ? nothing : missing
end == Any[Nothing]
```
  • Loading branch information
aviatesk committed Mar 14, 2022
1 parent a98f719 commit 86bcb86
Show file tree
Hide file tree
Showing 3 changed files with 94 additions and 56 deletions.
46 changes: 30 additions & 16 deletions base/compiler/abstractinterpretation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,11 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f),
any_const_result = false
const_results = Union{InferenceResult,Nothing,ConstResult}[]
multiple_matches = napplicable > 1
if matches.overlayed
# currently we don't have a good way to execute the overlayed method definition,
# so we should give up pure/concrete eval when any of the matched methods is overlayed
f = nothing
end

val = pure_eval_call(interp, f, applicable, arginfo, sv)
val !== nothing && return CallMeta(val, MethodResultPure(info)) # TODO: add some sort of edge(s)
Expand Down Expand Up @@ -102,7 +107,8 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f),
end
this_argtypes = isa(matches, MethodMatches) ? argtypes : matches.applicable_argtypes[i]
this_arginfo = ArgInfo(fargs, this_argtypes)
const_call_result = abstract_call_method_with_const_args(interp, result, f, this_arginfo, match, sv)
const_call_result = abstract_call_method_with_const_args(interp, result,
f, this_arginfo, match, sv)
effects = result.edge_effects
const_result = nothing
if const_call_result !== nothing
Expand Down Expand Up @@ -144,7 +150,8 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f),
# this is in preparation for inlining, or improving the return result
this_argtypes = isa(matches, MethodMatches) ? argtypes : matches.applicable_argtypes[i]
this_arginfo = ArgInfo(fargs, this_argtypes)
const_call_result = abstract_call_method_with_const_args(interp, result, f, this_arginfo, match, sv)
const_call_result = abstract_call_method_with_const_args(interp, result,
f, this_arginfo, match, sv)
effects = result.edge_effects
const_result = nothing
if const_call_result !== nothing
Expand Down Expand Up @@ -228,6 +235,7 @@ struct MethodMatches
valid_worlds::WorldRange
mt::Core.MethodTable
fullmatch::Bool
overlayed::Bool
end
any_ambig(info::MethodMatchInfo) = info.results.ambig
any_ambig(m::MethodMatches) = any_ambig(m.info)
Expand All @@ -239,6 +247,7 @@ struct UnionSplitMethodMatches
valid_worlds::WorldRange
mts::Vector{Core.MethodTable}
fullmatches::Vector{Bool}
overlayed::Bool
end
any_ambig(m::UnionSplitMethodMatches) = _any(any_ambig, m.info.matches)

Expand All @@ -253,16 +262,19 @@ function find_matching_methods(argtypes::Vector{Any}, @nospecialize(atype), meth
valid_worlds = WorldRange()
mts = Core.MethodTable[]
fullmatches = Bool[]
overlayed = false
for i in 1:length(split_argtypes)
arg_n = split_argtypes[i]::Vector{Any}
sig_n = argtypes_to_type(arg_n)
mt = ccall(:jl_method_table_for, Any, (Any,), sig_n)
mt === nothing && return FailedMethodMatch("Could not identify method table for call")
mt = mt::Core.MethodTable
matches = findall(sig_n, method_table; limit = max_methods)
if matches === missing
result = findall(sig_n, method_table; limit = max_methods)
if result === missing
return FailedMethodMatch("For one of the union split cases, too many methods matched")
end
matches, overlayedᵢ = result
overlayed |= overlayedᵢ
push!(infos, MethodMatchInfo(matches))
for m in matches
push!(applicable, m)
Expand All @@ -288,25 +300,28 @@ function find_matching_methods(argtypes::Vector{Any}, @nospecialize(atype), meth
UnionSplitInfo(infos),
valid_worlds,
mts,
fullmatches)
fullmatches,
overlayed)
else
mt = ccall(:jl_method_table_for, Any, (Any,), atype)
if mt === nothing
return FailedMethodMatch("Could not identify method table for call")
end
mt = mt::Core.MethodTable
matches = findall(atype, method_table; limit = max_methods)
if matches === missing
result = findall(atype, method_table; limit = max_methods)
if result === missing
# this means too many methods matched
# (assume this will always be true, so we don't compute / update valid age in this case)
return FailedMethodMatch("Too many methods matched")
end
matches, overlayed = result
fullmatch = _any(match->(match::MethodMatch).fully_covers, matches)
return MethodMatches(matches.matches,
MethodMatchInfo(matches),
matches.valid_worlds,
mt,
fullmatch)
fullmatch,
overlayed)
end
end

Expand Down Expand Up @@ -659,8 +674,7 @@ end

function pure_eval_eligible(interp::AbstractInterpreter,
@nospecialize(f), applicable::Vector{Any}, arginfo::ArgInfo, sv::InferenceState)
return !isoverlayed(method_table(interp)) &&
f !== nothing &&
return f !== nothing &&
length(applicable) == 1 &&
is_method_pure(applicable[1]::MethodMatch) &&
is_all_const_arg(arginfo)
Expand Down Expand Up @@ -696,8 +710,7 @@ end

function concrete_eval_eligible(interp::AbstractInterpreter,
@nospecialize(f), result::MethodCallResult, arginfo::ArgInfo, sv::InferenceState)
return !isoverlayed(method_table(interp)) &&
f !== nothing &&
return f !== nothing &&
result.edge !== nothing &&
is_total_or_error(result.edge_effects) &&
is_all_const_arg(arginfo)
Expand Down Expand Up @@ -1496,7 +1509,7 @@ function abstract_invoke(interp::AbstractInterpreter, (; fargs, argtypes)::ArgIn
types = rewrap_unionall(Tuple{ft, unwrap_unionall(types).parameters...}, types)::Type
nargtype = Tuple{ft, nargtype.parameters...}
argtype = Tuple{ft, argtype.parameters...}
match, valid_worlds = findsup(types, method_table(interp))
match, valid_worlds, overlayed = findsup(types, method_table(interp))
match === nothing && return CallMeta(Any, false)
update_valid_age!(sv, valid_worlds)
method = match.method
Expand All @@ -1514,7 +1527,8 @@ function abstract_invoke(interp::AbstractInterpreter, (; fargs, argtypes)::ArgIn
# t, a = ti.parameters[i], argtypes′[i]
# argtypes′[i] = t ⊑ a ? t : a
# end
const_call_result = abstract_call_method_with_const_args(interp, result, singleton_type(ft′), arginfo, match, sv)
const_call_result = abstract_call_method_with_const_args(interp, result,
overlayed ? nothing : singleton_type(ft′), arginfo, match, sv)
const_result = nothing
if const_call_result !== nothing
if const_call_result.rt rt
Expand Down Expand Up @@ -1662,8 +1676,8 @@ function abstract_call_opaque_closure(interp::AbstractInterpreter, closure::Part
match = MethodMatch(sig, Core.svec(), closure.source, sig <: rewrap_unionall(sigT, tt))
const_result = nothing
if !result.edgecycle
const_call_result = abstract_call_method_with_const_args(interp, result, nothing,
arginfo, match, sv)
const_call_result = abstract_call_method_with_const_args(interp, result,
nothing, arginfo, match, sv)
if const_call_result !== nothing
if const_call_result.rt rt
(; rt, const_result) = const_call_result
Expand Down
64 changes: 35 additions & 29 deletions base/compiler/methodtable.jl
Original file line number Diff line number Diff line change
Expand Up @@ -40,15 +40,18 @@ end
getindex(result::MethodLookupResult, idx::Int) = getindex(result.matches, idx)::MethodMatch

"""
findall(sig::Type, view::MethodTableView; limit::Int=typemax(Int)) -> MethodLookupResult or missing
findall(sig::Type, view::MethodTableView; limit::Int=typemax(Int)) ->
(matches::MethodLookupResult, overlayed::Bool) or missing
Find all methods in the given method table `view` that are applicable to the
given signature `sig`. If no applicable methods are found, an empty result is
returned. If the number of applicable methods exceeded the specified limit,
`missing` is returned.
Find all methods in the given method table `view` that are applicable to the given signature `sig`.
If no applicable methods are found, an empty result is returned.
If the number of applicable methods exceeded the specified limit, `missing` is returned.
`overlayed` indicates if any matching method is defined in an overlayed method table.
"""
function findall(@nospecialize(sig::Type), table::InternalMethodTable; limit::Int=Int(typemax(Int32)))
return _findall(sig, nothing, table.world, limit)
result = _findall(sig, nothing, table.world, limit)
result === missing && return missing
return result, false
end

function findall(@nospecialize(sig::Type), table::OverlayMethodTable; limit::Int=Int(typemax(Int32)))
Expand All @@ -57,7 +60,7 @@ function findall(@nospecialize(sig::Type), table::OverlayMethodTable; limit::Int
nr = length(result)
if nr 1 && result[nr].fully_covers
# no need to fall back to the internal method table
return result
return result, true
end
# fall back to the internal method table
fallback_result = _findall(sig, nothing, table.world, limit)
Expand All @@ -68,7 +71,7 @@ function findall(@nospecialize(sig::Type), table::OverlayMethodTable; limit::Int
WorldRange(
max(result.valid_worlds.min_world, fallback_result.valid_worlds.min_world),
min(result.valid_worlds.max_world, fallback_result.valid_worlds.max_world)),
result.ambig | fallback_result.ambig)
result.ambig | fallback_result.ambig), !isempty(result)
end

function _findall(@nospecialize(sig::Type), mt::Union{Nothing,Core.MethodTable}, world::UInt, limit::Int)
Expand All @@ -83,31 +86,38 @@ function _findall(@nospecialize(sig::Type), mt::Union{Nothing,Core.MethodTable},
end

"""
findsup(sig::Type, view::MethodTableView) -> Tuple{MethodMatch, WorldRange} or nothing
Find the (unique) method `m` such that `sig <: m.sig`, while being more
specific than any other method with the same property. In other words, find
the method which is the least upper bound (supremum) under the specificity/subtype
relation of the queried `signature`. If `sig` is concrete, this is equivalent to
asking for the method that will be called given arguments whose types match the
given signature. This query is also used to implement `invoke`.
Such a method `m` need not exist. It is possible that no method is an
upper bound of `sig`, or it is possible that among the upper bounds, there
is no least element. In both cases `nothing` is returned.
findsup(sig::Type, view::MethodTableView) ->
(match::MethodMatch, valid_worlds::WorldRange, overlayed::Bool) or nothing
Find the (unique) method such that `sig <: match.method.sig`, while being more
specific than any other method with the same property. In other words, find the method
which is the least upper bound (supremum) under the specificity/subtype relation of
the queried `sig`nature. If `sig` is concrete, this is equivalent to asking for the method
that will be called given arguments whose types match the given signature.
Note that this query is also used to implement `invoke`.
Such a matching method `match` doesn't necessarily exist.
It is possible that no method is an upper bound of `sig`, or
it is possible that among the upper bounds, there is no least element.
In both cases `nothing` is returned.
`overlayed` indicates if the matching method is defined in an overlayed method table.
"""
function findsup(@nospecialize(sig::Type), table::InternalMethodTable)
return _findsup(sig, nothing, table.world)
return (_findsup(sig, nothing, table.world)..., false)
end

function findsup(@nospecialize(sig::Type), table::OverlayMethodTable)
match, valid_worlds = _findsup(sig, table.mt, table.world)
match !== nothing && return match, valid_worlds
match !== nothing && return match, valid_worlds, true
# fall back to the internal method table
fallback_match, fallback_valid_worlds = _findsup(sig, nothing, table.world)
return fallback_match, WorldRange(
max(valid_worlds.min_world, fallback_valid_worlds.min_world),
min(valid_worlds.max_world, fallback_valid_worlds.max_world))
return (
fallback_match,
WorldRange(
max(valid_worlds.min_world, fallback_valid_worlds.min_world),
min(valid_worlds.max_world, fallback_valid_worlds.max_world)),
false)
end

function _findsup(@nospecialize(sig::Type), mt::Union{Nothing,Core.MethodTable}, world::UInt)
Expand All @@ -118,7 +128,3 @@ function _findsup(@nospecialize(sig::Type), mt::Union{Nothing,Core.MethodTable},
valid_worlds = WorldRange(min_valid[], max_valid[])
return match, valid_worlds
end

isoverlayed(::MethodTableView) = error("unsatisfied MethodTableView interface")
isoverlayed(::InternalMethodTable) = false
isoverlayed(::OverlayMethodTable) = true
40 changes: 29 additions & 11 deletions test/compiler/AbstractInterpreter.jl
Original file line number Diff line number Diff line change
Expand Up @@ -41,25 +41,43 @@ import Base.Experimental: @MethodTable, @overlay
@MethodTable(OverlayedMT)
CC.method_table(interp::MTOverlayInterp) = CC.OverlayMethodTable(CC.get_world_counter(interp), OverlayedMT)

@overlay OverlayedMT sin(x::Float64) = 1
@test Base.return_types((Int,), MTOverlayInterp()) do x
sin(x)
end == Any[Int]
@test Base.return_types((Any,), MTOverlayInterp()) do x
Base.@invoke sin(x::Float64)
end == Any[Int]
strangesin(x) = sin(x)
@overlay OverlayedMT strangesin(x::Float64) = iszero(x) ? nothing : cos(x)
@test Base.return_types((Float64,); MTOverlayInterp()) do x
strangesin(x)
end |> only === Union{Float64,Nothing}
@test Base.return_types((Any,); MTOverlayInterp()) do x
Base.@invoke strangesin(x::Float64)
end |> only === Union{Float64,Nothing}

# fallback to the internal method table
@test Base.return_types((Int,), MTOverlayInterp()) do x
cos(x)
end == Any[Float64]
@test Base.return_types((Any,), MTOverlayInterp()) do x
end |> only === Float64
@test Base.return_types((Any,); MTOverlayInterp()) do x
Base.@invoke cos(x::Float64)
end == Any[Float64]
end |> only === Float64

# not fully covered overlay method match
overlay_match(::Any) = nothing
@overlay OverlayedMT overlay_match(::Int) = missing
@test Base.return_types((Any,), MTOverlayInterp()) do x
overlay_match(x)
end == Any[Union{Nothing,Missing}]
end |> only === Union{Nothing,Missing}

# partial pure/concrete evaluation
@test Base.return_types((), MTOverlayInterp()) do
isbitstype(Int) ? nothing : missing
end |> only === Nothing
Base.@assume_effects :terminates_globally function issue41694(x)
res = 1
1 < x < 20 || throw("bad")
while x > 1
res *= x
x -= 1
end
return res
end
@test Base.return_types((), MTOverlayInterp()) do
issue41694(3) == 6 ? nothing : missing
end |> only === Nothing

0 comments on commit 86bcb86

Please sign in to comment.