Skip to content

Commit

Permalink
Adjust getsrc to return MethodInstance as well
Browse files Browse the repository at this point in the history
I'm not 100% sure about the lookup here, but this at least seems to get
the nightly tests back into the state they were in before the upstream
CodeInstance change.
  • Loading branch information
topolarity committed Feb 21, 2024
1 parent b4d0505 commit 9940f02
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 29 deletions.
19 changes: 13 additions & 6 deletions TypedSyntax/src/node.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,8 @@ const no_default_value = NoDefaultValue()
# where `mappings[i]` corresponds to the list of nodes matching `(src::CodeInfo).code[i]`.
function tsn_and_mappings(@nospecialize(f), @nospecialize(t); kwargs...)
m = which(f, t)
src, rt = getsrc(f, t)
tsn_and_mappings(m, src, rt; kwargs...)
src, rt, mi = getsrc(f, t)
tsn_and_mappings(mi, src, rt; kwargs...)
end

function tsn_and_mappings(mi::MethodInstance, src::CodeInfo, @nospecialize(rt); warn::Bool=true, strip_macros::Bool=false, kwargs...)
Expand Down Expand Up @@ -65,8 +65,8 @@ function TypedSyntaxNode(mi::MethodInstance; kwargs...)
tsn_and_mappings(mi, src, rt; kwargs...)[1]
end

TypedSyntaxNode(rootnode::SyntaxNode, src::CodeInfo, Δline::Integer=0) =
TypedSyntaxNode(rootnode, src, map_ssas_to_source(src, rootnode, Δline)...)
TypedSyntaxNode(rootnode::SyntaxNode, src::CodeInfo, mi::MethodInstance, Δline::Integer=0) =
TypedSyntaxNode(rootnode, src, map_ssas_to_source(src, mi, rootnode, Δline)...)

function TypedSyntaxNode(rootnode::SyntaxNode, src::CodeInfo, mappings, symtyps)
# There may be ambiguous assignments back to the source; preserve just the unambiguous ones
Expand Down Expand Up @@ -307,15 +307,22 @@ end

function getsrc(@nospecialize(f), @nospecialize(t))
srcrts = code_typed(f, t; debuginfo=:source, optimize=false)
return only(srcrts)
src, rt = only(srcrts)
if hasfield(typeof(src), :parent)
mi = src.parent
else
mi = Base.method_instance(f, t)
end
return src, rt, mi
end

function getsrc(mi::MethodInstance)
cis = Base.code_typed_by_type(mi.specTypes; debuginfo=:source, optimize=false)
isempty(cis) && error("no applicable type-inferred code found for ", mi)
length(cis) == 1 || error("got $(length(cis)) possible type-inferred results for ", mi,
", you may need a more specialized signature")
return cis[1]::Pair{CodeInfo}
src::CodeInfo, rt = cis[1]
return src, rt
end

function is_function_def(node) # this is not `Base.is_function_def`
Expand Down
45 changes: 22 additions & 23 deletions TypedSyntax/test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@ include("test_module.jl")
"""
rootnode = JuliaSyntax.parsestmt(SyntaxNode, st; filename="TSN1.jl")
TSN.eval(Expr(rootnode))
src, _ = getsrc(TSN.f, (Float32, Int, Float64))
tsn = TypedSyntaxNode(rootnode, src)
src, _, mi = getsrc(TSN.f, (Float32, Int, Float64))
tsn = TypedSyntaxNode(rootnode, src, mi)
sig, body = children(tsn)
@test children(sig)[2].typ === Float32
@test children(sig)[3].typ === Int
Expand All @@ -33,8 +33,8 @@ include("test_module.jl")
"""
rootnode = JuliaSyntax.parsestmt(SyntaxNode, st; filename="TSN2.jl")
TSN.eval(Expr(rootnode))
src, _ = getsrc(TSN.g, (Int16, Int16, Int32))
tsn = TypedSyntaxNode(rootnode, src)
src, _, mi = getsrc(TSN.g, (Int16, Int16, Int32))
tsn = TypedSyntaxNode(rootnode, src, mi)
sig, body = children(tsn)
@test length(children(sig)) == 4
@test children(body)[2].typ === Int32
Expand All @@ -46,8 +46,8 @@ include("test_module.jl")
st = "math(x) = x + sin(x + π / 4)"
rootnode = JuliaSyntax.parsestmt(SyntaxNode, st; filename="TSN2.jl")
TSN.eval(Expr(rootnode))
src, _ = getsrc(TSN.math, (Int,))
tsn = TypedSyntaxNode(rootnode, src)
src, _, mi = getsrc(TSN.math, (Int,))
tsn = TypedSyntaxNode(rootnode, src, mi)
sig, body = children(tsn)
@test has_name_typ(child(body, 1), :x, Int)
@test has_name_typ(child(body, 3, 2, 1), :x, Int)
Expand All @@ -70,8 +70,8 @@ include("test_module.jl")
st = "math2(x) = sin(x) + sin(x)"
rootnode = JuliaSyntax.parsestmt(SyntaxNode, st; filename="TSN2.jl")
TSN.eval(Expr(rootnode))
src, _ = getsrc(TSN.math2, (Int,))
tsn = TypedSyntaxNode(rootnode, src)
src, _, mi = getsrc(TSN.math2, (Int,))
tsn = TypedSyntaxNode(rootnode, src, mi)
sig, body = children(tsn)
@test body.typ === Float64
@test_broken child(body, 1).typ === Float64
Expand All @@ -91,8 +91,8 @@ include("test_module.jl")
)
rootnode = JuliaSyntax.parsestmt(SyntaxNode, st; filename="TSN3.jl")
TSN.eval(Expr(rootnode))
src, _ = getsrc(TSN.firstfirst, (Vector{Vector{Real}},))
tsn = TypedSyntaxNode(rootnode, src)
src, _, mi = getsrc(TSN.firstfirst, (Vector{Vector{Real}},))
tsn = TypedSyntaxNode(rootnode, src, mi)
sig, body = children(tsn)
@test child(body, idxsinner...).typ === nothing
@test child(body, idxsouter...).typ === Vector{Real}
Expand Down Expand Up @@ -150,8 +150,8 @@ include("test_module.jl")
"""
rootnode = JuliaSyntax.parsestmt(SyntaxNode, st; filename="TSN4.jl")
TSN.eval(Expr(rootnode))
src, rt = getsrc(TSN.setlist!, (Vector{Vector{Float32}}, Vector{Vector{UInt8}}, Int, Int))
tsn = TypedSyntaxNode(rootnode, src)
src, rt, mi = getsrc(TSN.setlist!, (Vector{Vector{Float32}}, Vector{Vector{UInt8}}, Int, Int))
tsn = TypedSyntaxNode(rootnode, src, mi)
sig, body = children(tsn)
nodelist = child(body, 1, 2, 1, 1) # `listget`
@test sourcetext(nodelist) == "listget" && nodelist.typ === Vector{Vector{UInt8}}
Expand All @@ -175,8 +175,8 @@ include("test_module.jl")
"""
rootnode = JuliaSyntax.parsestmt(SyntaxNode, st; filename="TSN5.jl")
TSN.eval(Expr(rootnode))
src, rt = getsrc(TSN.callfindmin, (Vector{Float64},))
tsn = TypedSyntaxNode(rootnode, src)
src, rt, mi = getsrc(TSN.callfindmin, (Vector{Float64},))
tsn = TypedSyntaxNode(rootnode, src, mi)
sig, body = children(tsn)
t = child(body, 1, 1)
@test kind(t) == K"tuple"
Expand Down Expand Up @@ -280,18 +280,18 @@ include("test_module.jl")
"""
rootnode = JuliaSyntax.parsestmt(SyntaxNode, st; filename="TSN6.jl")
TSN.eval(Expr(rootnode))
src, rt = getsrc(TSN.avoidzero, (Int,))
src, rt, mi = getsrc(TSN.avoidzero, (Int,))
# src looks like this:
# %1 = Main.TSN.:(var"#avoidzero#6")(true, #self#, x)::Float64
# return %1
# Consequently there is nothing to match, but at least we shouldn't error
tsn = TypedSyntaxNode(rootnode, src)
tsn = TypedSyntaxNode(rootnode, src, mi)
@test isa(tsn, TypedSyntaxNode)
@test rt === Float64
# Try the kwbodyfunc
m = which(TSN.avoidzero, (Int,))
src, rt = getsrc(Base.bodyfunction(m), (Bool, typeof(TSN.avoidzero), Int,))
tsn = TypedSyntaxNode(rootnode, src)
src, rt, mi = getsrc(Base.bodyfunction(m), (Bool, typeof(TSN.avoidzero), Int,))
tsn = TypedSyntaxNode(rootnode, src, mi)
sig, body = children(tsn)
isz = child(body, 2, 1, 1)
@test kind(isz) == K"call" && child(isz, 1).val == :iszero
Expand Down Expand Up @@ -520,8 +520,8 @@ include("test_module.jl")
@test_broken body.typ == Int

# Construction from MethodInstance
src, rt = TypedSyntax.getsrc(TSN.myoftype, (Float64, Int))
tsn = TypedSyntaxNode(src.parent)
src, rt, mi = TypedSyntax.getsrc(TSN.myoftype, (Float64, Int))
tsn = TypedSyntaxNode(mi)
sig, body = children(tsn)
node = child(body, 1)
@test node.typ === Type{Float64}
Expand Down Expand Up @@ -643,10 +643,9 @@ include("test_module.jl")
@test isa(tsnc, TypedSyntaxNode)

# issue 487
m = which(TSN.f487, (Int,))
src, rt = getsrc(TSN.f487, (Int,))
src, rt, mi = getsrc(TSN.f487, (Int,))
rt = Core.Const(1)
tsn, _ = TypedSyntax.tsn_and_mappings(m, src, rt)
tsn, _ = TypedSyntax.tsn_and_mappings(mi, src, rt)
@test_nowarn str = sprint(tsn; context=:color=>false) do io, obj
printstyled(io, obj; hide_type_stable=false)
end
Expand Down

0 comments on commit 9940f02

Please sign in to comment.