diff --git a/base/compiler/ssair/passes.jl b/base/compiler/ssair/passes.jl index 000bb1849edea..0a1badff94baf 100644 --- a/base/compiler/ssair/passes.jl +++ b/base/compiler/ssair/passes.jl @@ -852,13 +852,17 @@ function sroa_mutables!(ir::IRCode, defuses::IdDict{Int, Tuple{SPCSet, SSADefUse typ = typ::DataType # Partition defuses by field fielddefuse = SSADefUse[SSADefUse() for _ = 1:fieldcount(typ)] + all_forwarded = true for use in defuse.uses stmt = ir[SSAValue(use)] # == `getfield` call # We may have discovered above that this use is dead # after the getfield elim of immutables. In that case, # it would have been deleted. That's fine, just ignore # the use in that case. - stmt === nothing && continue + if stmt === nothing + all_forwarded = false + continue + end field = try_compute_fieldidx_stmt(ir, stmt::Expr, typ) field === nothing && @goto skip push!(fielddefuse[field].uses, use) @@ -928,7 +932,12 @@ function sroa_mutables!(ir::IRCode, defuses::IdDict{Int, Tuple{SPCSet, SSADefUse end end preserve_uses === nothing && continue - push!(intermediaries, newidx) + if all_forwarded + # this means all ccall preserves have been replaced with forwarded loads + # so we can potentially eliminate the allocation, otherwise we must preserve + # the whole allocation. + push!(intermediaries, newidx) + end # Insert the new preserves for (use, new_preserves) in preserve_uses ir[SSAValue(use)] = form_new_preserves(ir[SSAValue(use)]::Expr, intermediaries, new_preserves) @@ -938,7 +947,7 @@ function sroa_mutables!(ir::IRCode, defuses::IdDict{Int, Tuple{SPCSet, SSADefUse end end -function form_new_preserves(origex::Expr, preserved::Vector{Int}, new_preserves::Vector{Any}) +function form_new_preserves(origex::Expr, intermediates::Vector{Int}, new_preserves::Vector{Any}) newex = Expr(:foreigncall) nccallargs = length(origex.args[3]::SimpleVector) for i in 1:(6+nccallargs-1) @@ -946,7 +955,8 @@ function form_new_preserves(origex::Expr, preserved::Vector{Int}, new_preserves: end for i in (6+nccallargs):length(origex.args) x = origex.args[i] - if isa(x, SSAValue) && x.id in preserved + # don't need to preserve intermediaries + if isa(x, SSAValue) && x.id in intermediates continue end push!(newex.args, x) diff --git a/test/compiler/irpasses.jl b/test/compiler/irpasses.jl index dbffa41edc7ae..05fb890f91848 100644 --- a/test/compiler/irpasses.jl +++ b/test/compiler/irpasses.jl @@ -672,3 +672,35 @@ let return Core.Compiler.widenconst(ft) !== typeof(typeassert) end end + +let + # Test for https://github.com/JuliaLang/julia/issues/43402 + # Ensure that structs required not used outside of the ccall, + # still get listed in the ccall_preserves + + src = @eval Module() begin + @inline function effectful() + s1 = Ref{Csize_t}() + s2 = Ref{Csize_t}() + ccall(:some_ccall, Cvoid, + (Ref{Csize_t},Ref{Csize_t}), + s1, s2) + return s1[], s2[] + end + + code_typed() do + s1, s2 = effectful() + return s1 + end |> only |> first + end + + refs = map(Core.SSAValue, findall(x->x isa Expr && x.head == :new, src.code)) + some_ccall = findfirst(x -> x isa Expr && x.head == :foreigncall && x.args[1] == :(:some_ccall), src.code) + @assert some_ccall !== nothing + stmt = src.code[some_ccall] + nccallargs = length(stmt.args[3]::Core.SimpleVector) + preserves = stmt.args[6+nccallargs:end] + @test length(refs) == 2 + @test length(preserves) == 2 + @test all(alloc -> alloc in preserves, refs) +end