Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

inference correctness: fields and globals can revert to undef #53750

Merged
merged 1 commit into from
Apr 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 13 additions & 15 deletions base/compiler/abstractinterpretation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2583,26 +2583,19 @@ function abstract_eval_statement_expr(interp::AbstractInterpreter, e::Expr, vtyp
t = Bool
effects = EFFECTS_TOTAL
exct = Union{}
isa(sym, Symbol) && (sym = GlobalRef(frame_module(sv), sym))
if isa(sym, SlotNumber) && vtypes !== nothing
vtyp = vtypes[slot_id(sym)]
if vtyp.typ === Bottom
t = Const(false) # never assigned previously
elseif !vtyp.undef
t = Const(true) # definitely assigned previously
end
elseif isa(sym, Symbol)
if isdefined(frame_module(sv), sym)
t = Const(true)
elseif InferenceParams(interp).assume_bindings_static
t = Const(false)
else
effects = Effects(EFFECTS_TOTAL; consistent=ALWAYS_FALSE)
end
elseif isa(sym, GlobalRef)
if isdefined(sym.mod, sym.name)
if InferenceParams(interp).assume_bindings_static
t = Const(isdefined_globalref(sym))
elseif isdefinedconst_globalref(sym)
t = Const(true)
elseif InferenceParams(interp).assume_bindings_static
t = Const(false)
else
effects = Effects(EFFECTS_TOTAL; consistent=ALWAYS_FALSE)
end
Expand Down Expand Up @@ -2791,9 +2784,10 @@ function override_effects(effects::Effects, override::EffectsOverride)
end

isdefined_globalref(g::GlobalRef) = !iszero(ccall(:jl_globalref_boundp, Cint, (Any,), g))
isdefinedconst_globalref(g::GlobalRef) = isconst(g) && isdefined_globalref(g)

function abstract_eval_globalref_type(g::GlobalRef)
if isdefined_globalref(g) && isconst(g)
if isdefinedconst_globalref(g)
return Const(ccall(:jl_get_globalref_value, Any, (Any,), g))
end
ty = ccall(:jl_get_binding_type, Any, (Any, Any), g.mod, g.name)
Expand All @@ -2812,11 +2806,15 @@ function abstract_eval_globalref(interp::AbstractInterpreter, g::GlobalRef, sv::
if is_mutation_free_argtype(rt)
inaccessiblememonly = ALWAYS_TRUE
end
elseif isdefined_globalref(g)
nothrow = true
elseif InferenceParams(interp).assume_bindings_static
consistent = inaccessiblememonly = ALWAYS_TRUE
rt = Union{}
if isdefined_globalref(g)
nothrow = true
else
rt = Union{}
end
elseif isdefinedconst_globalref(g)
nothrow = true
end
return RTEffects(rt, nothrow ? Union{} : UndefVarError, Effects(EFFECTS_TOTAL; consistent, nothrow, inaccessiblememonly))
end
Expand Down
3 changes: 1 addition & 2 deletions base/compiler/optimize.jl
Original file line number Diff line number Diff line change
Expand Up @@ -298,8 +298,7 @@ function stmt_effect_flags(𝕃ₒ::AbstractLattice, @nospecialize(stmt), @nospe
isa(stmt, GotoNode) && return (true, false, true)
isa(stmt, GotoIfNot) && return (true, false, ⊑(𝕃ₒ, argextype(stmt.cond, src), Bool))
if isa(stmt, GlobalRef)
nothrow = isdefined(stmt.mod, stmt.name)
consistent = nothrow && isconst(stmt.mod, stmt.name)
nothrow = consistent = isdefinedconst_globalref(stmt)
return (consistent, nothrow, nothrow)
elseif isa(stmt, Expr)
(; head, args) = stmt
Expand Down
85 changes: 32 additions & 53 deletions base/compiler/tfuncs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -394,20 +394,13 @@ end
return isdefined_tfunc(𝕃, arg1, sym)
end
@nospecs function isdefined_tfunc(𝕃::AbstractLattice, arg1, sym)
if isa(arg1, Const)
arg1t = typeof(arg1.val)
else
arg1t = widenconst(arg1)
end
if isType(arg1t)
return Bool
end
arg1t = arg1 isa Const ? typeof(arg1.val) : isconstType(arg1) ? typeof(arg1.parameters[1]) : widenconst(arg1)
a1 = unwrap_unionall(arg1t)
if isa(a1, DataType) && !isabstracttype(a1)
if a1 === Module
hasintersect(widenconst(sym), Symbol) || return Bottom
if isa(sym, Const) && isa(sym.val, Symbol) && isa(arg1, Const) &&
isdefined(arg1.val::Module, sym.val::Symbol)
isdefinedconst_globalref(GlobalRef(arg1.val::Module, sym.val::Symbol))
return Const(true)
end
elseif isa(sym, Const)
Expand All @@ -433,9 +426,8 @@ end
elseif idx <= 0 || (!isvatuple(a1) && idx > fieldcount(a1))
return Const(false)
elseif isa(arg1, Const)
arg1v = (arg1::Const).val
if !ismutable(arg1v) || isdefined(arg1v, idx) || isconst(typeof(arg1v), idx)
return Const(isdefined(arg1v, idx))
if !ismutabletype(a1) || isconst(a1, idx)
return Const(isdefined(arg1.val, idx))
end
elseif !isvatuple(a1)
fieldT = fieldtype(a1, idx)
Expand Down Expand Up @@ -987,7 +979,7 @@ end
# If we have s00 being a const, we can potentially refine our type-based analysis above
if isa(s00, Const) || isconstType(s00)
if !isa(s00, Const)
sv = s00.parameters[1]
sv = (s00::DataType).parameters[1]
else
sv = s00.val
end
Expand All @@ -997,15 +989,16 @@ end
isa(sv, Module) && return false
isa(nval, Int) || return false
end
return isdefined(sv, nval)
return isdefined_tfunc(𝕃, s00, name) === Const(true)
end
boundscheck && return false
# If bounds checking is disabled and all fields are assigned,
# we may assume that we don't throw
isa(sv, Module) && return false
name ⊑ Int || name ⊑ Symbol || return false
for i = 1:fieldcount(typeof(sv))
isdefined(sv, i) || return false
typeof(sv).name.n_uninitialized == 0 && return true
for i = (datatype_min_ninitialized(typeof(sv)) + 1):nfields(sv)
isdefined_tfunc(𝕃, s00, Const(i)) === Const(true) || return false
end
return true
end
Expand Down Expand Up @@ -1244,27 +1237,22 @@ end
return rewrap_unionall(R, s00)
end

@nospecs function getfield_notundefined(typ0, name)
if isa(typ0, Const) && isa(name, Const)
typv = typ0.val
namev = name.val
isa(typv, Module) && return true
if isa(namev, Symbol) || isa(namev, Int)
# Fields are not allowed to transition from defined to undefined, so
# even if the field is not const, all we need to check here is that
# it is defined here.
return isdefined(typv, namev)
end
@nospecs function getfield_notuninit(typ0, name)
if isa(typ0, Const)
# If the object is Const, then we know exactly the bit patterns that
# must be returned by getfield if not an error
return true
end
typ0 = widenconst(typ0)
typ = unwrap_unionall(typ0)
if isa(typ, Union)
return getfield_notundefined(rewrap_unionall(typ.a, typ0), name) &&
getfield_notundefined(rewrap_unionall(typ.b, typ0), name)
return getfield_notuninit(rewrap_unionall(typ.a, typ0), name) &&
getfield_notuninit(rewrap_unionall(typ.b, typ0), name)
end
isa(typ, DataType) || return false
if typ.name === Tuple.name || typ.name === _NAMEDTUPLE_NAME
# tuples and named tuples can't be instantiated with undefined fields,
isabstracttype(typ) && !isconstType(typ) && return false # cannot say anything about abstract types
if typ.name.n_uninitialized == 0
# Types such as tuples and named tuples that can't be instantiated with undefined fields,
# so we don't need to be conservative here
return true
end
Expand Down Expand Up @@ -2436,25 +2424,16 @@ function isdefined_effects(𝕃::AbstractLattice, argtypes::Vector{Any})
# consistent if the first arg is immutable
na = length(argtypes)
2 ≤ na ≤ 3 || return EFFECTS_THROWS
obj, sym = argtypes
wobj = unwrapva(obj)
wobj, sym = argtypes
wobj = unwrapva(wobj)
sym = unwrapva(sym)
consistent = CONSISTENT_IF_INACCESSIBLEMEMONLY
if is_immutable_argtype(wobj)
consistent = ALWAYS_TRUE
else
# Bindings/fields are not allowed to transition from defined to undefined, so even
# if the object is not immutable, we can prove `:consistent`-cy if it is defined:
if isa(wobj, Const) && isa(sym, Const)
objval = wobj.val
symval = sym.val
if isa(objval, Module)
if isa(symval, Symbol) && isdefined(objval, symval)
consistent = ALWAYS_TRUE
end
elseif (isa(symval, Symbol) || isa(symval, Int)) && isdefined(objval, symval)
consistent = ALWAYS_TRUE
end
end
elseif isdefined_tfunc(𝕃, wobj, sym) isa Const
# Some bindings/fields are not allowed to transition from defined to undefined or the reverse, so even
# if the object is not immutable, we can prove `:consistent`-cy of this:
consistent = ALWAYS_TRUE
end
nothrow = isdefined_nothrow(𝕃, argtypes)
if hasintersect(widenconst(wobj), Module)
Expand Down Expand Up @@ -2483,11 +2462,11 @@ function getfield_effects(𝕃::AbstractLattice, argtypes::Vector{Any}, @nospeci
# taint `:consistent` if accessing `isbitstype`-type object field that may be initialized
# with undefined value: note that we don't need to taint `:consistent` if accessing
# uninitialized non-`isbitstype` field since it will simply throw `UndefRefError`
# NOTE `getfield_notundefined` conservatively checks if this field is never initialized
# NOTE `getfield_notuninit` conservatively checks if this field is never initialized
# with undefined value to avoid tainting `:consistent` too aggressively
# TODO this should probably taint `:noub`, however, it would hinder concrete eval for
# `REPLInterpreter` that can ignore `:consistent-cy`, causing worse completions
if !(length(argtypes) ≥ 2 && getfield_notundefined(obj, argtypes[2]))
if !(length(argtypes) ≥ 2 && getfield_notuninit(obj, argtypes[2]))
consistent = ALWAYS_FALSE
end
noub = ALWAYS_TRUE
Expand Down Expand Up @@ -3123,7 +3102,7 @@ end
if M isa Const && s isa Const
M, s = M.val, s.val
if M isa Module && s isa Symbol
return isdefined(M, s)
return isdefinedconst_globalref(GlobalRef(M, s))
end
end
return false
Expand Down Expand Up @@ -3196,9 +3175,9 @@ end
end

function global_assignment_nothrow(M::Module, s::Symbol, @nospecialize(newty))
if isdefined(M, s) && !isconst(M, s)
if !isconst(M, s)
ty = ccall(:jl_get_binding_type, Any, (Any, Any), M, s)
return ty === nothing || newty ty
return ty isa Type && widenconst(newty) <: ty
end
return false
end
Expand All @@ -3214,7 +3193,7 @@ end
end
@nospecs function get_binding_type_tfunc(𝕃::AbstractLattice, M, s)
if get_binding_type_effect_free(M, s)
return Const(Core.get_binding_type((M::Const).val, (s::Const).val))
return Const(Core.get_binding_type((M::Const).val::Module, (s::Const).val::Symbol))
end
return Type
end
Expand Down
22 changes: 8 additions & 14 deletions src/codegen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3186,22 +3186,16 @@ static jl_cgval_t emit_globalref(jl_codectx_t &ctx, jl_module_t *mod, jl_sym_t *
if (bp == NULL)
return jl_cgval_t();
bp = julia_binding_pvalue(ctx, bp);
jl_value_t *ty = nullptr;
if (bnd) {
jl_value_t *v = jl_atomic_load_acquire(&bnd->value); // acquire value for ty
if (v != NULL) {
if (bnd->constp)
return mark_julia_const(ctx, v);
LoadInst *v = ctx.builder.CreateAlignedLoad(ctx.types().T_prjlvalue, bp, Align(sizeof(void*)));
setName(ctx.emission_context, v, jl_symbol_name(name));
v->setOrdering(order);
jl_aliasinfo_t ai = jl_aliasinfo_t::fromTBAA(ctx, ctx.tbaa().tbaa_binding);
ai.decorateInst(v);
jl_value_t *ty = jl_atomic_load_relaxed(&bnd->ty);
return mark_julia_type(ctx, v, true, ty);
}
if (v != NULL && bnd->constp)
return mark_julia_const(ctx, v);
ty = jl_atomic_load_relaxed(&bnd->ty);
}
// todo: use type info to avoid undef check
return emit_checked_var(ctx, bp, name, (jl_value_t*)mod, false, ctx.tbaa().tbaa_binding);
if (ty == nullptr)
ty = (jl_value_t*)jl_any_type;
return update_julia_type(ctx, emit_checked_var(ctx, bp, name, (jl_value_t*)mod, false, ctx.tbaa().tbaa_binding), ty);
}

static jl_cgval_t emit_globalop(jl_codectx_t &ctx, jl_module_t *mod, jl_sym_t *sym, jl_cgval_t rval, const jl_cgval_t &cmp,
Expand Down Expand Up @@ -5459,7 +5453,7 @@ static jl_cgval_t emit_isdefined(jl_codectx_t &ctx, jl_value_t *sym)
}
jl_binding_t *bnd = jl_get_binding(modu, name);
if (bnd) {
if (jl_atomic_load_relaxed(&bnd->value) != NULL)
if (jl_atomic_load_acquire(&bnd->value) != NULL && bnd->constp)
return mark_julia_const(ctx, jl_true);
Value *bp = julia_binding_gv(ctx, bnd);
bp = julia_binding_pvalue(ctx, bp);
Expand Down
64 changes: 32 additions & 32 deletions test/compiler/effects.jl
Original file line number Diff line number Diff line change
Expand Up @@ -249,52 +249,52 @@ struct SyntacticallyDefined{T}
x::T
end

import Core.Compiler: Const, getfield_notundefined
import Core.Compiler: Const, getfield_notuninit
for T = (Base.RefValue, Maybe) # both mutable and immutable
for name = (Const(1), Const(:x))
@test getfield_notundefined(T{String}, name)
@test getfield_notundefined(T{Integer}, name)
@test getfield_notundefined(T{Union{String,Integer}}, name)
@test getfield_notundefined(Union{T{String},T{Integer}}, name)
@test !getfield_notundefined(T{Int}, name)
@test !getfield_notundefined(T{<:Integer}, name)
@test !getfield_notundefined(T{Union{Int32,Int64}}, name)
@test !getfield_notundefined(T, name)
@test getfield_notuninit(T{String}, name)
@test getfield_notuninit(T{Integer}, name)
@test getfield_notuninit(T{Union{String,Integer}}, name)
@test getfield_notuninit(Union{T{String},T{Integer}}, name)
@test !getfield_notuninit(T{Int}, name)
@test !getfield_notuninit(T{<:Integer}, name)
@test !getfield_notuninit(T{Union{Int32,Int64}}, name)
@test !getfield_notuninit(T, name)
end
# throw doesn't account for undefined behavior
for name = (Const(0), Const(2), Const(1.0), Const(:y), Const("x"),
Float64, String, Nothing)
@test getfield_notundefined(T{String}, name)
@test getfield_notundefined(T{Int}, name)
@test getfield_notundefined(T{Integer}, name)
@test getfield_notundefined(T{<:Integer}, name)
@test getfield_notundefined(T{Union{Int32,Int64}}, name)
@test getfield_notundefined(T, name)
@test getfield_notuninit(T{String}, name)
@test getfield_notuninit(T{Int}, name)
@test getfield_notuninit(T{Integer}, name)
@test getfield_notuninit(T{<:Integer}, name)
@test getfield_notuninit(T{Union{Int32,Int64}}, name)
@test getfield_notuninit(T, name)
end
# should not be too conservative when field isn't known very well but object information is accurate
@test getfield_notundefined(T{String}, Int)
@test getfield_notundefined(T{String}, Symbol)
@test getfield_notundefined(T{Integer}, Int)
@test getfield_notundefined(T{Integer}, Symbol)
@test !getfield_notundefined(T{Int}, Int)
@test !getfield_notundefined(T{Int}, Symbol)
@test !getfield_notundefined(T{<:Integer}, Int)
@test !getfield_notundefined(T{<:Integer}, Symbol)
@test getfield_notuninit(T{String}, Int)
@test getfield_notuninit(T{String}, Symbol)
@test getfield_notuninit(T{Integer}, Int)
@test getfield_notuninit(T{Integer}, Symbol)
@test !getfield_notuninit(T{Int}, Int)
@test !getfield_notuninit(T{Int}, Symbol)
@test !getfield_notuninit(T{<:Integer}, Int)
@test !getfield_notuninit(T{<:Integer}, Symbol)
end
# should be conservative when object information isn't accurate
@test !getfield_notundefined(Any, Const(1))
@test !getfield_notundefined(Any, Const(:x))
@test !getfield_notuninit(Any, Const(1))
@test !getfield_notuninit(Any, Const(:x))
# tuples and namedtuples should be okay if not given accurate information
for TupleType = Any[Tuple{Int,Int,Int}, Tuple{Int,Vararg{Int}}, Tuple{Any}, Tuple,
NamedTuple{(:a, :b), Tuple{Int,Int}}, NamedTuple{(:x,),Tuple{Any}}, NamedTuple],
FieldType = Any[Int, Symbol, Any]
@test getfield_notundefined(TupleType, FieldType)
@test getfield_notuninit(TupleType, FieldType)
end
# skip analysis on fields that are known to be defined syntactically
@test Core.Compiler.getfield_notundefined(SyntacticallyDefined{Float64}, Symbol)
@test Core.Compiler.getfield_notundefined(Const(Main), Const(:var))
@test Core.Compiler.getfield_notundefined(Const(Main), Const(42))
# high-level tests for `getfield_notundefined`
@test Core.Compiler.getfield_notuninit(SyntacticallyDefined{Float64}, Symbol)
@test Core.Compiler.getfield_notuninit(Const(Main), Const(:var))
@test Core.Compiler.getfield_notuninit(Const(Main), Const(42))
# high-level tests for `getfield_notuninit`
@test Base.infer_effects() do
Maybe{Int}()
end |> !Core.Compiler.is_consistent
Expand Down Expand Up @@ -904,7 +904,7 @@ end |> Core.Compiler.is_foldable_nothrow
@test Base.infer_effects(Tuple{WrapperOneField{Float64}, Symbol}) do w, s
getfield(w, s)
end |> Core.Compiler.is_foldable
@test Core.Compiler.getfield_notundefined(WrapperOneField{Float64}, Symbol)
@test Core.Compiler.getfield_notuninit(WrapperOneField{Float64}, Symbol)
@test Base.infer_effects(Tuple{WrapperOneField{Symbol}, Symbol}) do w, s
getfield(w, s)
end |> Core.Compiler.is_foldable
Expand Down Expand Up @@ -996,7 +996,7 @@ end
let effects = Base.infer_effects() do
isdefined(defined_ref, :x)
end
@test Core.Compiler.is_consistent(effects)
@test !Core.Compiler.is_consistent(effects)
@test Core.Compiler.is_nothrow(effects)
end
let effects = Base.infer_effects() do
Expand Down
Loading