Skip to content

Commit

Permalink
Stop unwrapping types while mapping (#585)
Browse files Browse the repository at this point in the history
This represents a small change of philosophy: formerly, when mapping types to SyntaxNodes, we would unwrap compiler annotations like `Core.Const`. Now, we use the type exactly as assigned by inference, and handle unwrapping during `show`.
  • Loading branch information
timholy authored Aug 8, 2024
1 parent 602c849 commit db6c69b
Show file tree
Hide file tree
Showing 4 changed files with 133 additions and 54 deletions.
31 changes: 16 additions & 15 deletions TypedSyntax/src/node.jl
Original file line number Diff line number Diff line change
Expand Up @@ -191,12 +191,12 @@ function map_signature!(sig::TypedSyntaxNode, slotnames::Vector{Symbol}, slottyp
kwdivider = 1
if havekws && slotnames[1] !== Symbol("#self#")
kwdivider = findfirst(1:length(slotnames)) do i
slotnames[i] == Symbol("") && unwrapinternal(slottypes[i]) <: Function # this should be the parent function as an argument
slotnames[i] == Symbol("") && isa(unwrapinternal(slottypes[i]), Function) # this should be the parent function as an argument
end
if kwdivider === nothing
kwdivider = 1
end
if length(slottypes) >= 2 && slotnames[2] == Symbol("") && (nt = unwrapinternal(slottypes[2])) <: NamedTuple
if length(slottypes) >= 2 && slotnames[2] == Symbol("") && (nt = unwrapinternal(slottypes[2]); isa(nt, Type)) && nt <: NamedTuple
# Match kwargs
argcontainer = children(last(children(sig)))
offset = length(children(sig)) - 1
Expand Down Expand Up @@ -244,7 +244,7 @@ function map_signature!(sig::TypedSyntaxNode, slotnames::Vector{Symbol}, slottyp
if kind(arg) == K"::" && length(children(arg)) == 2
arg = child(arg, 1)
end
arg.typ = unwrapinternal(slottypes[idx])
arg.typ = slottypes[idx]
end

# It's annoying to print the signature as `foo::typeof(foo)(a::Int)`
Expand Down Expand Up @@ -276,7 +276,7 @@ function striparg(arg)
end

function unwrapinternal(@nospecialize(T))
isa(T, Core.Const) && return Core.Typeof(T.val)
isa(T, Core.Const) && return T.val
isa(T, Core.PartialStruct) && return T.typ
return T
end
Expand All @@ -287,10 +287,10 @@ function gettyp(node2ssa, node, src)
ssavaluetypes = src.ssavaluetypes::Vector{Any}
if isa(stmt, Core.ReturnNode)
arg = stmt.val
isa(arg, SSAValue) && return unwrapinternal(ssavaluetypes[arg.id])
is_slot(arg) && return unwrapinternal((src.slottypes::Vector{Any})[arg.id])
isa(arg, SSAValue) && return ssavaluetypes[arg.id]
is_slot(arg) && return (src.slottypes::Vector{Any})[arg.id]
end
return unwrapinternal(ssavaluetypes[i])
return ssavaluetypes[i]
end

Base.copy(tsd::TypedSyntaxData) = TypedSyntaxData(tsd.source, tsd.typedsource, tsd.raw, tsd.position, tsd.val, tsd.typ, tsd.runtime)
Expand Down Expand Up @@ -608,7 +608,7 @@ function map_ssas_to_source(src::CodeInfo, mi::MethodInstance, rootnode::SyntaxN
argmapping = typeof(rootnode)[] # temporary storage
for (i, mapped, stmt) in zip(eachindex(mappings), mappings, src.code)
empty!(argmapping)
if is_slot(stmt) || isa(stmt, SSAValue)
if is_slot(stmt) || isa(stmt, SSAValue) || isa(stmt, GlobalRef)
append_targets_for_arg!(mapped, i, stmt)
elseif isa(stmt, Core.ReturnNode)
append_targets_for_line!(mapped, i, append_targets_for_arg!(argmapping, i, stmt.val))
Expand All @@ -626,16 +626,14 @@ function map_ssas_to_source(src::CodeInfo, mi::MethodInstance, rootnode::SyntaxN
append_targets_for_arg!(mapped, i, stmt)
filter_assignment_targets!(mapped, true) # match the RHS of assignments
if length(mapped) == 1
symtyps[only(mapped)] = unwrapinternal(
(is_slot(stmt) & have_slottypes) ? slottypes[(stmt::SlotType).id] :
symtyps[only(mapped)] = (is_slot(stmt) & have_slottypes) ? slottypes[(stmt::SlotType).id] :
isa(stmt, SSAValue) ? ssavaluetypes[stmt.id] : #=literal=#typeof(stmt)
)
end
# Now try to assign types to the LHS of the assignment
append_targets_for_arg!(argmapping, i, lhs)
filter_assignment_targets!(argmapping, false) # match the LHS of assignments
if length(argmapping) == 1
T = unwrapinternal(ssavaluetypes[i])
T = ssavaluetypes[i]
symtyps[only(argmapping)] = T
end
empty!(argmapping)
Expand Down Expand Up @@ -738,7 +736,7 @@ function map_ssas_to_source(src::CodeInfo, mi::MethodInstance, rootnode::SyntaxN
if isexpr(nextstmt, :call)
f = nextstmt.args[1]
if isa(f, GlobalRef) && f.mod == Base && f.name == :broadcasted
empty!(mapped)
# empty!(mapped)
break
elseif isa(f, GlobalRef) && f.mod == Base && f.name == :materialize && nextstmt.args[2] === SSAValue(i)
push!(mappings[inext], node)
Expand Down Expand Up @@ -793,14 +791,14 @@ function map_ssas_to_source(src::CodeInfo, mi::MethodInstance, rootnode::SyntaxN
haskey(symtyps, t) && continue
if skipped_parent(t) == node
is_prec_assignment(node) && t == child(node, 1) && continue
symtyps[t] = unwrapinternal(if j > 0
symtyps[t] = if j > 0
ssavaluetypes[j]
elseif have_slottypes
# We failed to find it as an SSAValue, it must have type assigned at function entry
slottypes[arg.id]
else
nothing
end)
end
break
end
end
Expand Down Expand Up @@ -904,6 +902,9 @@ function skipped_parent(node::SyntaxNode)
pnode === nothing && return node
ppnode = pnode.parent
if ppnode !== nothing && kind(pnode) KSet"... quote" # might need to add more things here
if kind(node) == K"Identifier" && kind(pnode) == K"quote" && kind(ppnode) == K"." && sourcetext(node) == "materialize"
return ppnode.parent
end
return ppnode
end
return pnode
Expand Down
45 changes: 41 additions & 4 deletions TypedSyntax/src/show.jl
Original file line number Diff line number Diff line change
Expand Up @@ -87,22 +87,56 @@ end
function is_show_annotation(@nospecialize(T); type_annotations::Bool, hide_type_stable::Bool)
type_annotations || return false
if isa(T, Core.Const)
T = typeof(T.val)
isa(T.val, Module) && return false
T = Core.Typeof(T.val)
end
isa(T, Type) || return false
hide_type_stable || return true
return isa(T, Type) && is_type_unstable(T)
end

# Is the type equivalent to the source-text?
# We use `endswith` to handle module qualification
is_type_transparent(node, @nospecialize(T)) = endswith(replace(sprint(show, T), r"\s" => ""), replace(sourcetext(node), r"\s" => ""))

function is_callfunc(node::MaybeTypedSyntaxNode, @nospecialize(T))
thisnode = node
pnode = node.parent
while pnode !== nothing && kind(pnode) KSet"quote ." && pnode.parent !== nothing
thisnode = pnode
pnode = pnode.parent
end
if pnode !== nothing && kind(pnode) (K"call", K"curly") && ((is_infix_op_call(pnode) && is_operator(thisnode)) || thisnode === pnode.children[1])
if isa(T, Core.Const)
T = T.val
end
if isa(T, Type) || isa(T, Function)
T === Colon() && sourcetext(node) == ":" && return true
return is_type_transparent(node, T)
end
end
return false
end

function type_annotation_mode(node, @nospecialize(T); type_annotations::Bool, hide_type_stable::Bool)
kind(node) == K"return" && return false, "", "", ""
is_callfunc(node, T) && return false, "", "", ""
type_annotate = is_show_annotation(T; type_annotations, hide_type_stable)
pre = pre2 = post = ""
if type_annotate
if T isa DataType && T <: Type && isassigned(T.parameters, 1)
if replace(sourcetext(node), r"\s" => "") == replace(sprint(show, T.parameters[1]), r"\s" => "")
return false, pre, pre2, post
# Try stripping Core.Const and Type{T} wrappers to check if we need to avoid `String::Type{String}`
# or `String::Core.Const(String)` annotations
S = nothing
if isa(T, Core.Const)
val = T.val
if isa(val, DataType)
S = val
end
elseif isa(T, DataType) && T <: Type && isassigned(T.parameters, 1)
S = T.parameters[1]
end
if S !== nothing && is_type_transparent(node, S)
return false, pre, pre2, post
end
if kind(node) KSet":: where" || is_infix_op_call(node) || (is_prec_assignment(node) && kind(node) != K"=")
pre, post = "(", ")"
Expand All @@ -118,6 +152,9 @@ function show_annotation(io, @nospecialize(T), post, node, position; iswarn::Boo
inlay_hints = get(io, :inlay_hints, nothing)

print(io, post)
if isa(T, Core.Const) && isa(T.val, Type)
T = Type{T.val}
end
T_str = string(T)
if iswarn && is_type_unstable(T)
color = is_small_union_or_tunion(T) ? :yellow : :red
Expand Down
Loading

0 comments on commit db6c69b

Please sign in to comment.