Skip to content

Commit

Permalink
Merge pull request #232 from timholy/teh/methodtrigs
Browse files Browse the repository at this point in the history
Organize inference triggers by Method
  • Loading branch information
timholy authored Jan 24, 2021
2 parents 48ee784 + fd3fcdb commit 15f689e
Show file tree
Hide file tree
Showing 2 changed files with 251 additions and 41 deletions.
226 changes: 187 additions & 39 deletions src/parcel_snoopi_deep.jl
Original file line number Diff line number Diff line change
Expand Up @@ -644,6 +644,14 @@ Return the MethodInstance `mi` of the caller in the selected stackframe in `itri
"""
callerinstance(itrig::InferenceTrigger) = itrig.callerframes[end].linfo

function callerinstances(itrigs::AbstractVector{InferenceTrigger})
callers = Set{MethodInstance}()
for itrig in itrigs
!isempty(itrig.callerframes) && push!(callers, callerinstance(itrig))
end
return callers
end

function callermodule(itrig::InferenceTrigger)
if !isempty(itrig.callerframes)
m = callerinstance(itrig).def
Expand Down Expand Up @@ -853,6 +861,26 @@ function hasparameter(@nospecialize(typ), @nospecialize(ft), exact::Bool)
return false
end

"""
ncallees, ncallers = diversity(itrigs::AbstractVector{InferenceTrigger})
Count the number of distinct MethodInstances among the callees and callers, respectively, among the triggers in `itrigs`.
"""
function diversity(itrigs)
# Analyze caller => callee argument type diversity
callees, callers, ncextra = Set{MethodInstance}(), Set{MethodInstance}(), 0
for itrig in itrigs
push!(callees, MethodInstance(itrig.node))
caller = itrig.callerframes[end].linfo
if isa(caller, MethodInstance)
push!(callers, caller)
else
ncextra += 1
end
end
return length(callees), length(callers) + ncextra
end

# Integrations
AbstractTrees.children(tinf::InferenceTimingNode) = tinf.children

Expand Down Expand Up @@ -936,6 +964,52 @@ InteractiveUtils.edit(node::TriggerNode) = edit(node.itrig)
Base.stacktrace(node::TriggerNode) = stacktrace(node.itrig)
Cthulhu.ascend(node::TriggerNode) = ascend(node.itrig)

### tagged trigger lists
# good for organizing a collection of related triggers

struct TaggedTriggers{TT}
tag::TT
itrigs::Vector{InferenceTrigger}
end

const MethodTriggers = TaggedTriggers{Method}

"""
mtrigs = accumulate_by_source(Method, itrigs::AbstractVector{InferenceTrigger})
Consolidate inference triggers via their caller method. `mtrigs` is a vector of `Method=>list`
pairs, where `list` is a list of `InferenceTrigger`s.
"""
function accumulate_by_source(::Type{Method}, itrigs::AbstractVector{InferenceTrigger})
cs = Dict{Method,Vector{InferenceTrigger}}()
for itrig in itrigs
isempty(itrig.callerframes) && continue
mi = callerinstance(itrig)
m = mi.def
if isa(m, Method)
list = get!(Vector{InferenceTrigger}, cs, m)
push!(list, itrig)
end
end
return sort!([MethodTriggers(m, list) for (m, list) in cs]; by=methtrig->length(methtrig.itrigs))
end

function Base.show(io::IO, methtrigs::MethodTriggers)
ncallees, ncallers = diversity(methtrigs.itrigs)
print(io, methtrigs.tag, " (", ncallees, " callees from ", ncallers, " callers)")
end

function parcel(mtrigs::AbstractVector{MethodTriggers})
bymod = Dict{Module,Vector{MethodTriggers}}()
for mtrig in mtrigs
m = mtrig.tag
modlist = get!(valtype(bymod), bymod, m.module)
push!(modlist, mtrig)
end
sort!(collect(bymod); by=pr->length(pr.second))
end

InteractiveUtils.edit(mtrigs::MethodTriggers) = edit(mtrigs.tag)

### inference trigger locations
# useful for analyzing patterns at the level of Methods rather than MethodInstances
Expand All @@ -954,37 +1028,16 @@ end
Base.show(io::IO, loc::Location) = print(io, loc.func, " at ", loc.file, ':', loc.line)
InteractiveUtils.edit(loc::Location) = edit(string(loc.file), loc.line)

struct LocationTrigger
loc::Location
itrigs::Vector{InferenceTrigger}
end

"""
ncallees, ncallers = diversity(loctrigs::LocationTriggers)
const LocationTriggers = TaggedTriggers{Location}

Count the number of distinct MethodInstances among the callees and callers, respectively, at a particular code location.
"""
function diversity(loctrigs::LocationTrigger)
# Analyze caller => callee argument type diversity
callees, callers, ncextra = Set{MethodInstance}(), Set{MethodInstance}(), 0
for itrig in loctrigs.itrigs
push!(callees, MethodInstance(itrig.node))
caller = itrig.callerframes[end].linfo
if isa(caller, MethodInstance)
push!(callers, caller)
else
ncextra += 1
end
end
return length(callees), length(callers) + ncextra
end
diversity(loctrigs::LocationTriggers) = diversity(loctrigs.itrigs)

function Base.show(io::IO, loctrigs::LocationTrigger)
function Base.show(io::IO, loctrigs::LocationTriggers)
ncallees, ncallers = diversity(loctrigs)
print(io, loctrigs.loc, " (", ncallees, " callees from ", ncallers, " callers)")
print(io, loctrigs.tag, " (", ncallees, " callees from ", ncallers, " callers)")
end

InteractiveUtils.edit(loctrig::LocationTrigger) = edit(loctrig.loc)
InteractiveUtils.edit(loctrig::LocationTriggers) = edit(loctrig.tag)

"""
loctrigs = accumulate_by_source(itrigs::AbstractVector{InferenceTrigger})
Expand All @@ -1002,7 +1055,7 @@ julia> itrigs = inference_triggers(SnoopCompile.itrigs_demo())
Inference triggered to call MethodInstance for double(::Float64) from calldouble1 (/pathto/SnoopCompile/src/parcel_snoopi_deep.jl:762) inlined into MethodInstance for calldouble2(::Vector{Vector{Any}}) (/pathto/SnoopCompile/src/parcel_snoopi_deep.jl:763)
julia> accumulate_by_source(itrigs)
1-element Vector{SnoopCompile.LocationTrigger}:
1-element Vector{SnoopCompile.LocationTriggers}:
calldouble1 at /pathto/SnoopCompile/src/parcel_snoopi_deep.jl:762 (2 callees from 1 callers)
```
"""
Expand All @@ -1013,7 +1066,7 @@ function accumulate_by_source(itrigs::AbstractVector{InferenceTrigger}; bycallee
itrigs_loc = get!(Vector{InferenceTrigger}, cs, lockey)
push!(itrigs_loc, itrig)
end
loctrigs = [LocationTrigger(lockey isa Location ? lockey : lockey[1], itrigs_loc) for (lockey, itrigs_loc) in cs]
loctrigs = [LocationTriggers(lockey isa Location ? lockey : lockey[1], itrigs_loc) for (lockey, itrigs_loc) in cs]
return sort!(loctrigs; by=loctrig->length(loctrig.itrigs))
end

Expand All @@ -1027,7 +1080,7 @@ function location_key(itrig::InferenceTrigger)
return loc, ft
end

filtermod(mod::Module, loctrigs::AbstractVector{LocationTrigger}) = filter(loctrigs) do loctrig
filtermod(mod::Module, loctrigs::AbstractVector{LocationTriggers}) = filter(loctrigs) do loctrig
any(==(mod) callermodule, loctrig.itrigs)
end

Expand Down Expand Up @@ -1094,8 +1147,7 @@ function show_suggest(io::IO, categories, rtcallee, sf)
showvahint = showannotate = false
handled = false
if HasCoreBox categories
printstyled(io, "has Core.Box"; color=:red)
print(io, " (fix this before tackling other problems, see https://timholy.github.io/SnoopCompile.jl/stable/snoopr/#Fixing-Core.Box)")
coreboxmsg(io)
return nothing
end
if categories == [FromTestDirect]
Expand Down Expand Up @@ -1221,6 +1273,11 @@ function show_suggest(io::IO, categories, rtcallee, sf)
# end
end

function coreboxmsg(io::IO)
printstyled(io, "has Core.Box"; color=:red)
print(io, " (fix this before tackling other problems, see https://timholy.github.io/SnoopCompile.jl/stable/snoopr/#Fixing-Core.Box)")
end

"""
isignorable(s::Suggested)
Expand Down Expand Up @@ -1311,13 +1368,8 @@ function suggest(itrig::InferenceTrigger)
maybec = false
for (ct::CodeInfo, _) in cts
# Check for Core.Box
for typlist in (ct.slottypes, ct.ssavaluetypes)
for typ in typlist
if hascorebox(typ)
push!(s.categories, HasCoreBox)
break
end
end
if hascorebox(ct)
push!(s.categories, HasCoreBox)
end
ltidxs = linetable_match(ct.linetable, itrig.callerframes[1])
stmtidxs = findall((ltidxs), ct.codelocs)
Expand Down Expand Up @@ -1449,6 +1501,16 @@ function getcalleef(@nospecialize(callee), ct)
end

function hascorebox(@nospecialize(typ))
if isa(typ, CodeInfo)
ct = typ
for typlist in (ct.slottypes, ct.ssavaluetypes)
for typ in typlist
if hascorebox(typ)
return true
end
end
end
end
typ = unwrapconst(typ)
isa(typ, Type) || return false
typ === Core.Box && return true
Expand All @@ -1463,6 +1525,92 @@ function hascorebox(@nospecialize(typ))
return false
end

function Base.summary(io::IO, mtrigs::MethodTriggers)
callers = callerinstances(mtrigs.itrigs)
m = mtrigs.tag
println(io, m, " had ", length(callers), " specializations")
hascb = false
for mi in callers
tt = Base.unwrap_unionall(mi.specTypes)::DataType
mlist = Base._methods_by_ftype(tt, -1, typemax(UInt))
if length(mlist) < 10
cts = Base.code_typed_by_type(tt; debuginfo=:source)
for (ct::CodeInfo, _) in cts
if hascorebox(ct)
hascb = true
print(io, mi, " ")
coreboxmsg(io)
println(io)
break
end
end
else
@warn "not checking $mi for Core.Box, too many methods"
end
hascb && break
end
loctrigs = accumulate_by_source(mtrigs.itrigs)
sort!(loctrigs; by=loctrig->loctrig.tag.line)
println(io, "Triggering calls:")
for loctrig in loctrigs
itrig = loctrig.itrigs[1]
ft = (Base.unwrap_unionall(MethodInstance(itrig.node).specTypes)::DataType).parameters[1]
loc = loctrig.tag
if loc.func == m.name
print(io, "Line ", loctrig.tag.line)
else
print(io, "Inlined ", loc)
end
println(io, ": calling ", ft2f(ft), " (", length(loctrig.itrigs), " instances)")
end
end
Base.summary(mtrigs::MethodTriggers) = summary(stdout, mtrigs)

struct ClosureF
ft
end
function Base.show(io::IO, cf::ClosureF)
lnns = [LineNumberNode(Int(m.line), m.file) for m in Base.MethodList(cf.ft.name.mt)]
print(io, "closure ", cf.ft, " at ")
if length(lnns) == 1
print(io, lnns[1])
else
sort!(lnns; by=lnn->(lnn.file, lnn.line))
# avoid the repr with #= =#
print(io, '[')
for (i, lnn) in enumerate(lnns)
print(io, lnn.file, ':', lnn.line)
i < length(lnns) && print(io, ", ")
end
print(io, ']')
end
end

function ft2f(@nospecialize(ft))
if isa(ft, DataType)
return ft <: Type ? #= Type{T} =# ft.parameters[1] :
isdefined(ft, :instance) ? #= Function =# ft.instance : #= closure =# ClosureF(ft)
end
error("unhandled: ", ft)
end

function Base.summary(io::IO, loctrig::LocationTriggers)
ncallees, ncallers = diversity(loctrig)
if ncallees > ncallers
callees = unique([Method(itrig.node) for itrig in loctrig.itrigs])
println(io, ncallees, " callees from ", ncallers, " callers, consider despecializing the callee(s):")
show(io, MIME("text/plain"), callees)
println(io, "\nor improving inferrability of the callers")
else
cats_callee_sfs = unique(first, [(suggest(itrig).categories, MethodInstance(itrig.node), itrig.callerframes) for itrig in loctrig.itrigs])
println(io, ncallees, " callees from ", ncallers, " callers, consider improving inference in the caller(s). Recommendations:")
for (catg, callee, sfs) in cats_callee_sfs
show_suggest(io, catg, callee, isempty(sfs) ? "<none>" : sfs[end])
end
end
end
Base.summary(loctrig::LocationTriggers) = summary(stdout, loctrig)

const SuggestNode = AbstractTrees.AnnotationNode{Union{Nothing,Suggested}}
SuggestNode(s::Union{Nothing,Suggested}) = SuggestNode(s, SuggestNode[])

Expand Down Expand Up @@ -1642,7 +1790,7 @@ function max_end_time(node::InferenceTimingNode, recursive::Bool=false, tmax=-on
end

for IO in (IOContext{Base.TTY}, IOContext{IOBuffer}, IOBuffer)
for T = (InferenceTimingNode, InferenceTrigger, Precompiles, MethodLoc, Location, LocationTrigger)
for T = (InferenceTimingNode, InferenceTrigger, Precompiles, MethodLoc, MethodTriggers, Location, LocationTriggers)
@assert precompile(show, (IO, T))
end
end
Expand Down
Loading

0 comments on commit 15f689e

Please sign in to comment.