Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix SROA miscompile in large functions #46819

Merged
merged 2 commits into from
Sep 20, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 6 additions & 24 deletions base/compiler/ssair/inlining.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1338,27 +1338,6 @@ function info_effects(@nospecialize(result), match::MethodMatch, state::Inlining
end
end

function compute_joint_effects(info::Union{ConstCallInfo, Vector{MethodMatchInfo}}, state::InliningState)
if isa(info, ConstCallInfo)
(; call, results) = info
infos = isa(call, MethodMatchInfo) ? MethodMatchInfo[call] : call.matches
else
results = nothing
infos = info
end
local all_result_count = 0
local joint_effects::Effects = EFFECTS_TOTAL
for i in 1:length(infos)
meth = infos[i].results
for (j, match) in enumerate(meth)
all_result_count += 1
result = results === nothing ? nothing : results[all_result_count]
joint_effects = merge_effects(joint_effects, info_effects(result, match, state))
end
end
return joint_effects
end

function compute_inlining_cases(info::Union{ConstCallInfo, Vector{MethodMatchInfo}},
flag::UInt8, sig::Signature, state::InliningState)
argtypes = sig.argtypes
Expand All @@ -1376,6 +1355,8 @@ function compute_inlining_cases(info::Union{ConstCallInfo, Vector{MethodMatchInf
local only_method = nothing
local meth::MethodLookupResult
local all_result_count = 0
local joint_effects::Effects = EFFECTS_TOTAL
local nothrow::Bool = true
for i in 1:length(infos)
meth = infos[i].results
if meth.ambig
Expand All @@ -1400,6 +1381,8 @@ function compute_inlining_cases(info::Union{ConstCallInfo, Vector{MethodMatchInf
for (j, match) in enumerate(meth)
all_result_count += 1
result = results === nothing ? nothing : results[all_result_count]
joint_effects = merge_effects(joint_effects, info_effects(result, match, state))
nothrow &= match.fully_covers
any_fully_covered |= match.fully_covers
if !validate_sparams(match.sparams)
if !match.fully_covers
Expand All @@ -1418,6 +1401,8 @@ function compute_inlining_cases(info::Union{ConstCallInfo, Vector{MethodMatchInf
end
end

joint_effects = Effects(joint_effects; nothrow)

if handled_all_cases && revisit_idx !== nothing
# we handled everything except one match with unmatched sparams,
# so try to handle it by bypassing validate_sparams
Expand Down Expand Up @@ -1447,9 +1432,6 @@ function compute_inlining_cases(info::Union{ConstCallInfo, Vector{MethodMatchInf
filter!(case::InliningCase->isdispatchtuple(case.sig), cases)
end

# TODO fuse `compute_joint_effects` into the loop above, which currently causes compilation error
joint_effects = Effects(compute_joint_effects(info, state); nothrow=handled_all_cases)

return cases, (handled_all_cases & any_fully_covered), joint_effects
end

Expand Down
14 changes: 11 additions & 3 deletions base/compiler/ssair/ir.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,18 @@ function cfg_delete_edge!(cfg::CFG, from::Int, to::Int)
nothing
end

function bb_ordering()
lt=(<=)
by=x->first(x.stmts)
ord(lt, by, nothing, Forward)
end
Comment on lines +31 to +35
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How about making this constant?

Suggested change
function bb_ordering()
lt=(<=)
by=x->first(x.stmts)
ord(lt, by, nothing, Forward)
end
const BB_ORDERING = ord((<=), x->first(x.stmts), nothing, Forward)


function block_for_inst(index::Vector{Int}, inst::Int)
return searchsortedfirst(index, inst, lt=(<=))
end

function block_for_inst(index::Vector{BasicBlock}, inst::Int)
return searchsortedfirst(index, BasicBlock(StmtRange(inst, inst)), by=x->first(x.stmts), lt=(<=))-1
return searchsortedfirst(index, BasicBlock(StmtRange(inst, inst)), bb_ordering())-1
end

block_for_inst(cfg::CFG, inst::Int) = block_for_inst(cfg.index, inst)
Expand Down Expand Up @@ -674,7 +680,8 @@ end
function block_for_inst(compact::IncrementalCompact, idx::SSAValue)
id = idx.id
if id < compact.result_idx # if ssa within result
return block_for_inst(compact.result_bbs, id)
return searchsortedfirst(compact.result_bbs, BasicBlock(StmtRange(id, id)),
1, compact.active_result_bb, bb_ordering())-1
else
return block_for_inst(compact.ir.cfg, id)
end
Expand All @@ -683,7 +690,8 @@ end
function block_for_inst(compact::IncrementalCompact, idx::OldSSAValue)
id = idx.id
if id < compact.idx # if ssa within result
return block_for_inst(compact.result_bbs, compact.ssa_rename[id])
id = compact.ssa_rename[id]
return block_for_inst(compact, SSAValue(id))
else
return block_for_inst(compact.ir.cfg, id)
end
Expand Down