Skip to content

Commit

Permalink
Improve inference for ismutable with missing sparam
Browse files Browse the repository at this point in the history
We could already infer `ismutable(RefValue{T})` if we knew what
`T` was at inference time. However, the mutable does of course
not change depending on what `T` is, so fix that up by adding
an appropriate special case in `_getfield_tfunc`.
  • Loading branch information
Keno committed Sep 9, 2022
1 parent f1a0dd6 commit c889a0c
Show file tree
Hide file tree
Showing 3 changed files with 66 additions and 34 deletions.
93 changes: 60 additions & 33 deletions base/compiler/tfuncs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ function find_tfunc(@nospecialize f)
end

const DATATYPE_TYPES_FIELDINDEX = fieldindex(DataType, :types)
const DATATYPE_NAME_FIELDINDEX = fieldindex(DataType, :name)

##########
# tfuncs #
Expand Down Expand Up @@ -823,7 +824,11 @@ function getfield_nothrow(@nospecialize(s00), @nospecialize(name), boundscheck::
if isa(s, Union)
return getfield_nothrow(rewrap_unionall(s.a, s00), name, boundscheck) &&
getfield_nothrow(rewrap_unionall(s.b, s00), name, boundscheck)
elseif isa(s, DataType)
elseif isType(s)
sv = s.parameters[1]
s = s0 = typeof(sv)
end
if isa(s, DataType)
# Can't say anything about abstract types
isabstracttype(s) && return false
s.name.atomicfields == C_NULL || return false # TODO: currently we're only testing for ordering === :not_atomic
Expand Down Expand Up @@ -863,15 +868,40 @@ function getfield_tfunc(s00, name, order, boundscheck)
return getfield_tfunc(s00, name)
end
getfield_tfunc(@nospecialize(s00), @nospecialize(name)) = _getfield_tfunc(s00, name, false)

function _getfield_fieldindex(@nospecialize(s), name::Const)
nv = name.val
if isa(nv, Symbol)
nv = fieldindex(s, nv, false)
end
if isa(nv, Int)
return nv
end
return nothing
end

function _getfield_tfunc_const(@nospecialize(sv), name::Const, setfield::Bool)
if isa(name, Const)
nv = _getfield_fieldindex(typeof(sv), name)
nv === nothing && return Bottom
if isa(sv, DataType) && nv == DATATYPE_TYPES_FIELDINDEX && isdefined(sv, nv)
return Const(getfield(sv, nv))
end
if isconst(typeof(sv), nv)
if isdefined(sv, nv)
return Const(getfield(sv, nv))
end
return Union{}
end
end
return nothing
end

function _getfield_tfunc(@nospecialize(s00), @nospecialize(name), setfield::Bool)
if isa(s00, Conditional)
return Bottom # Bool has no fields
elseif isa(s00, Const) || isconstType(s00)
if !isa(s00, Const)
sv = s00.parameters[1]
else
sv = s00.val
end
elseif isa(s00, Const)
sv = s00.val
if isa(name, Const)
nv = name.val
if isa(sv, Module)
Expand All @@ -881,31 +911,15 @@ function _getfield_tfunc(@nospecialize(s00), @nospecialize(name), setfield::Bool
end
return Bottom
end
if isa(nv, Symbol)
nv = fieldindex(typeof(sv), nv, false)
end
if !isa(nv, Int)
return Bottom
end
if isa(sv, DataType) && nv == DATATYPE_TYPES_FIELDINDEX && isdefined(sv, nv)
return Const(getfield(sv, nv))
end
if isconst(typeof(sv), nv)
if isdefined(sv, nv)
return Const(getfield(sv, nv))
end
return Union{}
end
r = _getfield_tfunc_const(sv, name, setfield)
r !== nothing && return r
end
s = typeof(sv)
elseif isa(s00, PartialStruct)
s = widenconst(s00)
sty = unwrap_unionall(s)::DataType
if isa(name, Const)
nv = name.val
if isa(nv, Symbol)
nv = fieldindex(sty, nv, false)
end
nv = _getfield_fieldindex(sty, name)
if isa(nv, Int) && 1 <= nv <= length(s00.fields)
return unwrapva(s00.fields[nv])
end
Expand All @@ -917,6 +931,24 @@ function _getfield_tfunc(@nospecialize(s00), @nospecialize(name), setfield::Bool
return tmerge(_getfield_tfunc(rewrap_unionall(s.a, s00), name, setfield),
_getfield_tfunc(rewrap_unionall(s.b, s00), name, setfield))
end
if isType(s)
if isconstType(s)
sv = s00.parameters[1]
r = _getfield_tfunc_const(sv, name, setfield)
r !== nothing && return r
s = typeof(sv)
else
sv = s.parameters[1]
if isa(sv, DataType) && isa(name, Const) && (!isType(sv) && sv !== Core.TypeofBottom)
nv = _getfield_fieldindex(DataType, name)
if nv == DATATYPE_NAME_FIELDINDEX
# N.B. This doesn't work in general, because
return Const(sv.name)
end
s = DataType
end
end
end
isa(s, DataType) || return Any
isabstracttype(s) && return Any
if s <: Tuple && !hasintersect(widenconst(name), Int)
Expand Down Expand Up @@ -972,13 +1004,8 @@ function _getfield_tfunc(@nospecialize(s00), @nospecialize(name), setfield::Bool
end
return t
end
fld = name.val
if isa(fld, Symbol)
fld = fieldindex(s, fld, false)
end
if !isa(fld, Int)
return Bottom
end
fld = _getfield_fieldindex(s, name)
fld === nothing && return Bottom
if s <: Tuple && fld >= nf && isvarargtype(ftypes[nf])
return rewrap_unionall(unwrapva(ftypes[nf]), s00)
end
Expand Down
3 changes: 2 additions & 1 deletion test/compiler/inference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1588,7 +1588,8 @@ g23024(TT::Tuple{DataType}) = f23024(TT[1], v23024)
@test g23024((UInt8,)) === 2

@test !Core.Compiler.isconstType(Type{typeof(Union{})}) # could be Core.TypeofBottom or Type{Union{}} at runtime
@test Base.return_types(supertype, (Type{typeof(Union{})},)) == Any[Any]
@test !isa(Core.Compiler.getfield_tfunc(Type{Core.TypeofBottom}, Core.Compiler.Const(:name)), Core.Compiler.Const)
@test Base.return_types(supertype, (Type{typeof(Union{})},)) == Any[DataType]

# issue #23685
struct Node23685{T}
Expand Down
4 changes: 4 additions & 0 deletions test/compiler/inline.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1496,3 +1496,7 @@ call_twice_sitofp(x::Int) = twice_sitofp(x, 2)
let src = code_typed1(call_twice_sitofp, (Int,))
@test count(iscall((src, Base.sitofp)), src.code) == 1
end

# Test getfield modeling of Type{Ref{_A}} where _A
@test Core.Compiler.getfield_tfunc(Type, Core.Compiler.Const(:parameters)) !== Union{}
@test fully_eliminated(Base.ismutable, Tuple{Base.RefValue})

0 comments on commit c889a0c

Please sign in to comment.