Skip to content

Commit

Permalink
implement integration for exception type inference
Browse files Browse the repository at this point in the history
Integration for JuliaLang/julia#51754.
Statement-wise and call-wise information is available only after
`v"1.11.0-DEV.1127"`.
  • Loading branch information
aviatesk committed Dec 21, 2023
1 parent 82f678d commit 6ea929f
Show file tree
Hide file tree
Showing 12 changed files with 251 additions and 139 deletions.
117 changes: 77 additions & 40 deletions src/Cthulhu.jl

Large diffs are not rendered by default.

63 changes: 49 additions & 14 deletions src/callsite.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,17 +7,19 @@ struct MICallInfo <: CallInfo
mi::MethodInstance
rt
effects::Effects
function MICallInfo(mi::MethodInstance, @nospecialize(rt), effects)
exct
function MICallInfo(mi::MethodInstance, @nospecialize(rt), effects, @nospecialize(exct=nothing))
if isa(rt, LimitedAccuracy)
return LimitedCallInfo(new(mi, ignorelimited(rt), effects))
return LimitedCallInfo(new(mi, ignorelimited(rt), effects, exct))
else
return new(mi, rt, effects)
return new(mi, rt, effects, exct)
end
end
end
get_mi(ci::MICallInfo) = ci.mi
get_rt(ci::CallInfo) = ci.rt
get_rt(ci::MICallInfo) = ci.rt
get_effects(ci::MICallInfo) = ci.effects
get_exct(ci::MICallInfo) = ci.exct

abstract type WrappedCallInfo <: CallInfo end

Expand All @@ -27,6 +29,7 @@ ignorewrappers(ci::WrappedCallInfo) = ignorewrappers(get_wrapped(ci))
get_mi(ci::WrappedCallInfo) = get_mi(ignorewrappers(ci))
get_rt(ci::WrappedCallInfo) = get_rt(ignorewrappers(ci))
get_effects(ci::WrappedCallInfo) = get_effects(ignorewrappers(ci))
get_exct(ci::WrappedCallInfo) = get_exct(ignorewrappers(ci))

# only appears when inspecting pre-optimization states
struct LimitedCallInfo <: WrappedCallInfo
Expand All @@ -38,9 +41,12 @@ struct RTCallInfo <: CallInfo
f
argtyps
rt
exct
end
get_rt(ci::RTCallInfo) = ci.rt
get_mi(ci::RTCallInfo) = nothing
get_effects(ci::RTCallInfo) = Effects()
get_exct(ci::RTCallInfo) = ci.exct

# uncached callsite, we can't recurse into this call
struct UncachedCallInfo <: WrappedCallInfo
Expand All @@ -56,6 +62,7 @@ end
get_mi(::PureCallInfo) = nothing
get_rt(pci::PureCallInfo) = pci.rt
get_effects(::PureCallInfo) = EFFECTS_TOTAL
get_exct(::PureCallInfo) = Union{}

# Failed
struct FailedCallInfo <: CallInfo
Expand All @@ -64,7 +71,8 @@ struct FailedCallInfo <: CallInfo
end
get_mi(ci::FailedCallInfo) = fail(ci)
get_rt(ci::FailedCallInfo) = fail(ci)
get_effects(ci::FailedCallInfo) = Effects()
get_effects(ci::FailedCallInfo) = fail(ci)
get_exct(ci::FailedCallInfo) = fail(ci)
function fail(ci::FailedCallInfo)
@warn "MethodInstance extraction failed." ci.sig ci.rt
return nothing
Expand All @@ -77,7 +85,8 @@ struct GeneratedCallInfo <: CallInfo
end
get_mi(genci::GeneratedCallInfo) = fail(genci)
get_rt(genci::GeneratedCallInfo) = fail(genci)
get_effects(genci::GeneratedCallInfo) = Effects()
get_effects(genci::GeneratedCallInfo) = fail(genci)
get_exct(genci::GeneratedCallInfo) = fail(genci)
function fail(genci::GeneratedCallInfo)
@warn "Can't extract MethodInstance from call to generated functions." genci.sig genci.rt
return nothing
Expand All @@ -86,18 +95,24 @@ end
struct MultiCallInfo <: CallInfo
sig
rt
exct
callinfos::Vector{CallInfo}
MultiCallInfo(@nospecialize(sig), @nospecialize(rt), callinfos::Vector{CallInfo},
@nospecialize(exct=nothing)) =
new(sig, rt, exct, callinfos)
end
# actual code-error
get_mi(ci::MultiCallInfo) = error("Can't extract MethodInstance from multiple call informations")
get_rt(ci::MultiCallInfo) = ci.rt
get_effects(mci::MultiCallInfo) = mapreduce(get_effects, CC.merge_effects, mci.callinfos)
get_exct(ci::MultiCallInfo) = ci.exct

struct TaskCallInfo <: CallInfo
ci::CallInfo
end
get_mi(tci::TaskCallInfo) = get_mi(tci.ci)
get_rt(tci::TaskCallInfo) = get_rt(tci.ci)
get_effects(tci::TaskCallInfo) = get_effects(tci.ci)
get_exct(tci::TaskCallInfo) = get_exct(tci.ci)

struct InvokeCallInfo <: CallInfo
ci::CallInfo
Expand All @@ -106,6 +121,7 @@ end
get_mi(ici::InvokeCallInfo) = get_mi(ici.ci)
get_rt(ici::InvokeCallInfo) = get_rt(ici.ci)
get_effects(ici::InvokeCallInfo) = get_effects(ici.ci)
get_exct(ici::InvokeCallInfo) = get_exct(ici.ci)

# OpaqueClosure CallInfo
struct OCCallInfo <: CallInfo
Expand All @@ -115,6 +131,7 @@ end
get_mi(occi::OCCallInfo) = get_mi(occi.ci)
get_rt(occi::OCCallInfo) = get_rt(occi.ci)
get_effects(occi::OCCallInfo) = get_effects(occi.ci)
get_exct(occi::OCCallInfo) = get_exct(occi.ci)

# Special handling for ReturnTypeCall
struct ReturnTypeCallInfo <: CallInfo
Expand All @@ -123,6 +140,7 @@ end
get_mi((; vmi)::ReturnTypeCallInfo) = isa(vmi, FailedCallInfo) ? nothing : get_mi(vmi)
get_rt((; vmi)::ReturnTypeCallInfo) = Type{isa(vmi, FailedCallInfo) ? Union{} : widenconst(get_rt(vmi))}
get_effects(::ReturnTypeCallInfo) = EFFECTS_TOTAL
get_exct(::ReturnTypeCallInfo) = Union{} # FIXME

struct ConstPropCallInfo <: CallInfo
mi::CallInfo
Expand All @@ -131,6 +149,7 @@ end
get_mi(cpci::ConstPropCallInfo) = cpci.result.linfo
get_rt(cpci::ConstPropCallInfo) = get_rt(cpci.mi)
get_effects(cpci::ConstPropCallInfo) = get_effects(cpci.result)
get_exct(cpci::ConstPropCallInfo) = get_exct(cpci.mi)

struct ConcreteCallInfo <: CallInfo
mi::CallInfo
Expand All @@ -139,6 +158,7 @@ end
get_mi(ceci::ConcreteCallInfo) = get_mi(ceci.mi)
get_rt(ceci::ConcreteCallInfo) = get_rt(ceci.mi)
get_effects(ceci::ConcreteCallInfo) = get_effects(ceci.mi)
get_exct(cici::ConcreteCallInfo) = get_exct(ceci.mi)

struct SemiConcreteCallInfo <: CallInfo
mi::CallInfo
Expand All @@ -147,6 +167,7 @@ end
get_mi(scci::SemiConcreteCallInfo) = get_mi(scci.mi)
get_rt(scci::SemiConcreteCallInfo) = get_rt(scci.mi)
get_effects(scci::SemiConcreteCallInfo) = get_effects(scci.mi)
get_exct(scci::SemiConcreteCallInfo) = get_exct(scci.mi)

# CUDA callsite
struct CuCallInfo <: CallInfo
Expand Down Expand Up @@ -187,22 +208,22 @@ function headstring(@nospecialize(T))
end
end

function __show_limited(limiter, name, tt, @nospecialize(rt), effects)
function __show_limited(limiter, name, tt, @nospecialize(rt), effects, @nospecialize(exct=nothing))
vastring(@nospecialize(T)) = (isvarargtype(T) ? headstring(T)*"..." : string(T)::String)

# If effects are explicitly turned on, make sure to print them, even
# if there otherwise isn't space for them, since the effects are the
# most important piece of information if turned on.
with_effects = get(limiter, :with_effects, false)::Bool
exception_type = get(limiter, :exception_type, false)::Bool && exct !== nothing

if with_effects
limiter.width += textwidth(repr(effects)) + 1
end

with_effects && (limiter.width += textwidth(repr(effects)) + 1)
exception_type && (limiter.width += textwidth(string(exct)) + 1)
if !has_space(limiter, name)
print(limiter, '')
@goto print_effects
end

print(limiter, string(name))
pstrings = String[vastring(T) for T in tt]
headstrings = String[
Expand Down Expand Up @@ -234,15 +255,28 @@ function __show_limited(limiter, name, tt, @nospecialize(rt), effects)
print(limiter, "::…")
end

@label print_effects
@label print_effects
if with_effects
# Print effects unlimited
print(limiter.io, " ", effects)
end
if exception_type
print(limiter.io, ' ', ExctWrapper(exct))
end

return nothing
end

struct ExctWrapper
exct
ExctWrapper(@nospecialize exct) = new(exct)
end

function Base.show(io::IO, (;exct)::ExctWrapper)
color = exct === Union{} ? :green : :yellow
printstyled(io, "(↑::", exct, ")"; color)
end

function show_callinfo(limiter, mici::MICallInfo)
mi = mici.mi
tt = (Base.unwrap_unionall(mi.specTypes)::DataType).parameters[2:end]
Expand All @@ -252,7 +286,8 @@ function show_callinfo(limiter, mici::MICallInfo)
name = mi.def.name
end
rt = get_rt(mici)
__show_limited(limiter, name, tt, rt, get_effects(mici))
exct = get_exct(mici)
__show_limited(limiter, name, tt, rt, get_effects(mici), exct)
end

function show_callinfo(limiter, ci::Union{MultiCallInfo, FailedCallInfo, GeneratedCallInfo})
Expand Down
29 changes: 21 additions & 8 deletions src/codeview.jl
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ function cthulhu_warntype(io::IO, debuginfo::AnyDebugInfo,
if inline_cost
isa(mi, MethodInstance) || error("Need a MethodInstance to show inlining costs. Call `cthulhu_typed` directly instead.")
end
cthulhu_typed(io, debuginfo, src, rt, effects, mi; iswarn=true, optimize, hide_type_stable, inline_cost, interp)
cthulhu_typed(io, debuginfo, src, rt, nothing, effects, mi; iswarn=true, optimize, hide_type_stable, inline_cost, interp)
return nothing
end

Expand All @@ -121,9 +121,12 @@ end
cthulhu_typed(io::IO, debuginfo::DebugInfo, args...; kwargs...) =
cthulhu_typed(io, Symbol(debuginfo), args...; kwargs...)
function cthulhu_typed(io::IO, debuginfo::Symbol,
src::Union{CodeInfo,IRCode}, @nospecialize(rt), effects::Effects, mi::Union{Nothing,MethodInstance};
src::Union{CodeInfo,IRCode}, @nospecialize(rt), @nospecialize(exct),
effects::Effects, mi::Union{Nothing,MethodInstance};
iswarn::Bool=false, hide_type_stable::Bool=false, optimize::Bool=true,
pc2remarks::Union{Nothing,PC2Remarks}=nothing, pc2effects::Union{Nothing,PC2Effects}=nothing,
pc2remarks::Union{Nothing,PC2Remarks}=nothing,
pc2effects::Union{Nothing,PC2Effects}=nothing,
pc2excts::Union{Nothing,PC2Excts}=nothing,
inline_cost::Bool=false, type_annotations::Bool=true, annotate_source::Bool=false,
inlay_types_vscode::Bool=false, diagnostics_vscode::Bool=false, jump_always::Bool=false,
interp::AbstractInterpreter=CthulhuInterpreter())
Expand Down Expand Up @@ -248,18 +251,28 @@ function cthulhu_typed(io::IO, debuginfo::Symbol,
end
end
# postprinter configuration
__postprinter = if type_annotations
___postprinter = if type_annotations
iswarn ? InteractiveUtils.warntype_type_printer : IRShow.default_expr_type_printer
else
Returns(nothing)
end
_postprinter = if isa(src, CodeInfo) && !isnothing(pc2effects)
__postprinter = if isa(src, CodeInfo) && !isnothing(pc2effects)
function (io::IO; idx::Int, @nospecialize(kws...))
__postprinter(io; idx, kws...)
___postprinter(io; idx, kws...)
local effects = get(pc2effects, idx, nothing)
effects === nothing && return
print(io, ' ', effects)
end
else
___postprinter
end
_postprinter = if isa(src, CodeInfo) && !isnothing(pc2excts)
function (io::IO; idx::Int, @nospecialize(kws...))
__postprinter(io; idx, kws...)
local exct = get(pc2excts, idx, nothing)
exct === nothing && return
print(io, ' ', ExctWrapper(exct))
end
else
__postprinter
end
Expand Down Expand Up @@ -293,7 +306,7 @@ function cthulhu_typed(io::IO, debuginfo::Symbol,
cfg = src isa IRCode ? src.cfg : Core.Compiler.compute_basic_blocks(src.code)
max_bb_idx_size = length(string(length(cfg.blocks)))
str = irshow_config.line_info_preprinter(lambda_io, " "^(max_bb_idx_size + 2), -1)
callsite = Callsite(0, MICallInfo(mi, rettype, effects), :invoke)
callsite = Callsite(0, MICallInfo(mi, rettype, effects, exct), :invoke)
println(lambda_io, "", ""^(max_bb_idx_size), str, " ", callsite)
end

Expand Down Expand Up @@ -444,7 +457,7 @@ function Base.show(
return
end
println(io, "Cthulhu.Bookmark (world: ", world, ")")
cthulhu_typed(io, debuginfo, CI, rt, effects, b.mi; iswarn, optimize, hide_type_stable, b.interp)
cthulhu_typed(io, debuginfo, CI, rt, nothing, effects, b.mi; iswarn, optimize, hide_type_stable, b.interp)
end

function InteractiveUtils.code_typed(b::Bookmark; optimize::Bool=true)
Expand Down
13 changes: 7 additions & 6 deletions src/interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -68,13 +68,14 @@ missing `$AbstractCursor` API:
""")
navigate(curs::CthulhuCursor, callsite::Callsite) = CthulhuCursor(get_mi(callsite))

get_remarks(::AbstractInterpreter, ::Union{MethodInstance,InferenceResult}) = nothing
get_remarks(interp::CthulhuInterpreter, key::Union{MethodInstance,InferenceResult}) = get(interp.remarks, key, nothing)
get_remarks(::AbstractInterpreter, ::SemiConcreteCallInfo) = PC2Remarks()
get_remarks(::AbstractInterpreter, ::InferenceKey) = nothing
get_remarks(interp::CthulhuInterpreter, key::InferenceKey) = get(interp.remarks, key, nothing)

get_effects(::AbstractInterpreter, ::Union{MethodInstance,InferenceResult}) = nothing
get_effects(interp::CthulhuInterpreter, key::Union{MethodInstance,InferenceResult}) = get(interp.effects, key, nothing)
get_effects(::AbstractInterpreter, ::SemiConcreteCallInfo) = PC2Effects()
get_effects(::AbstractInterpreter, ::InferenceKey) = nothing
get_effects(interp::CthulhuInterpreter, key::InferenceKey) = get(interp.effects, key, nothing)

get_excts(::AbstractInterpreter, ::InferenceKey) = nothing
get_excts(interp::CthulhuInterpreter, key::InferenceKey) = get(interp.exception_types, key, nothing)

# This method is optional, but should be implemented if there is
# a sensible default cursor for a MethodInstance
Expand Down
38 changes: 29 additions & 9 deletions src/interpreter.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,10 @@ struct InferredSource
stmt_info::Vector{CCCallInfo}
effects::Effects
rt::Any
InferredSource(src::CodeInfo, stmt_info::Vector{CCCallInfo}, effects, @nospecialize(rt)) =
new(src, stmt_info, effects, rt)
exct::Any
InferredSource(src::CodeInfo, stmt_info::Vector{CCCallInfo}, effects, @nospecialize(rt),
@nospecialize(exct)) =
new(src, stmt_info, effects, rt, exct)
end

struct OptimizedSource
Expand All @@ -20,26 +22,31 @@ struct OptimizedSource
effects::Effects
end

const InferenceKey = Union{MethodInstance,InferenceResult}
const InferenceDict{T} = Dict{InferenceKey, T}
const PC2Remarks = Vector{Pair{Int, String}}
const PC2Effects = Dict{Int, Effects}
const PC2Excts = Dict{Int, Any}

struct CthulhuInterpreter <: AbstractInterpreter
native::AbstractInterpreter

unopt::Dict{Union{MethodInstance,InferenceResult}, InferredSource}
unopt::InferenceDict{InferredSource}
opt::Dict{MethodInstance, CodeInstance}

remarks::Dict{Union{MethodInstance,InferenceResult}, PC2Remarks}
effects::Dict{Union{MethodInstance,InferenceResult}, PC2Effects}
remarks::InferenceDict{PC2Remarks}
effects::InferenceDict{PC2Effects}
exception_types::InferenceDict{PC2Excts}
end

function CthulhuInterpreter(interp::AbstractInterpreter=NativeInterpreter())
return CthulhuInterpreter(
interp,
Dict{Union{MethodInstance,InferenceResult}, InferredSource}(),
InferenceDict{InferredSource}(),
Dict{MethodInstance, CodeInstance}(),
Dict{Union{MethodInstance,InferenceResult}, PC2Remarks}(),
Dict{Union{MethodInstance,InferenceResult}, PC2Effects}())
InferenceDict{PC2Remarks}(),
InferenceDict{PC2Effects}(),
InferenceDict{PC2Excts}())
end

import .CC: InferenceParams, OptimizationParams, get_world_counter,
Expand Down Expand Up @@ -138,11 +145,13 @@ function InferredSource(state::InferenceState)
slottypes === nothing ? nothing : copy(slottypes)
end
end
exct = @static VERSION v"1.11.0-DEV.207" ? state.result.exc_result : nothing
return InferredSource(
unoptsrc,
copy(state.stmt_info),
isdefined(CC, :Effects) ? state.ipo_effects : nothing,
state.result.result)
state.result.result,
exct)
end

function CC.finish(state::InferenceState, interp::CthulhuInterpreter)
Expand Down Expand Up @@ -236,3 +245,14 @@ function CC.finish!(interp::CthulhuInterpreter, caller::InferenceResult)
caller.src = create_cthulhu_source(caller.src, caller.ipo_effects)
end
end

@static if VERSION v"1.11.0-DEV.1127"
function CC.update_exc_bestguess!(interp::CthulhuInterpreter, @nospecialize(exct),
frame::InferenceState)
key = CC.any(frame.result.overridden_by_const) ? frame.result : frame.linfo
pc2excts = get!(PC2Excts, interp.exception_types, key)
pc2excts[frame.currpc] = CC.tmerge(CC.typeinf_lattice(interp), exct, get(pc2excts, frame.currpc, Union{}))
return @invoke CC.update_exc_bestguess!(interp::AbstractInterpreter, exct::Any,
frame::InferenceState)
end
end
Loading

0 comments on commit 6ea929f

Please sign in to comment.