Skip to content

Commit

Permalink
Allow const declarations on mutable fields (#43305)
Browse files Browse the repository at this point in the history
Mark some builtin types also, although Serialization relies upon being
able to mutilate the Method objects, so we do not yet mark those.

Replaces #11430

Co-authored-by: Matt Bauman <[email protected]>
  • Loading branch information
vtjnash and mbauman authored Dec 17, 2021
1 parent 08aa0ac commit 63f6294
Show file tree
Hide file tree
Showing 18 changed files with 294 additions and 165 deletions.
3 changes: 3 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@ New language features
* Support for Unicode 14.0.0 ([#43443]).
* `try`-blocks can now optionally have an `else`-block which is executed right after the main body only if
no errors were thrown. ([#42211])
* Mutable struct fields may now be annotated as `const` to prevent changing
them after construction, providing for greater clarity and optimization
ability of these objects ([#43305]).

Language changes
----------------
Expand Down
3 changes: 2 additions & 1 deletion base/compiler/ssair/passes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -846,7 +846,7 @@ function sroa_mutables!(ir::IRCode, defuses::IdDict{Int, Tuple{SPCSet, SSADefUse
if isa(typ, UnionAll)
typ = unwrap_unionall(typ)
end
# Could still end up here if we tried to setfield! and immutable, which would
# Could still end up here if we tried to setfield! on an immutable, which would
# error at runtime, but is not illegal to have in the IR.
ismutabletype(typ) || continue
typ = typ::DataType
Expand All @@ -871,6 +871,7 @@ function sroa_mutables!(ir::IRCode, defuses::IdDict{Int, Tuple{SPCSet, SSADefUse
stmt = ir[SSAValue(def)]::Expr # == `setfield!` call
field = try_compute_fieldidx_stmt(ir, stmt, typ)
field === nothing && @goto skip
isconst(typ, field) && @goto skip # we discovered an attempt to mutate a const field, which must error
push!(fielddefuse[field].defs, def)
end
# Check that the defexpr has defined values for all the fields
Expand Down
90 changes: 18 additions & 72 deletions base/compiler/tfuncs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,19 +24,7 @@ function find_tfunc(@nospecialize f)
end
end

const DATATYPE_NAME_FIELDINDEX = fieldindex(DataType, :name)
const DATATYPE_PARAMETERS_FIELDINDEX = fieldindex(DataType, :parameters)
const DATATYPE_TYPES_FIELDINDEX = fieldindex(DataType, :types)
const DATATYPE_SUPER_FIELDINDEX = fieldindex(DataType, :super)
const DATATYPE_INSTANCE_FIELDINDEX = fieldindex(DataType, :instance)
const DATATYPE_HASH_FIELDINDEX = fieldindex(DataType, :hash)

const TYPENAME_NAME_FIELDINDEX = fieldindex(Core.TypeName, :name)
const TYPENAME_MODULE_FIELDINDEX = fieldindex(Core.TypeName, :module)
const TYPENAME_NAMES_FIELDINDEX = fieldindex(Core.TypeName, :names)
const TYPENAME_WRAPPER_FIELDINDEX = fieldindex(Core.TypeName, :wrapper)
const TYPENAME_HASH_FIELDINDEX = fieldindex(Core.TypeName, :hash)
const TYPENAME_FLAGS_FIELDINDEX = fieldindex(Core.TypeName, :flags)

##########
# tfuncs #
Expand Down Expand Up @@ -305,7 +293,7 @@ function isdefined_tfunc(@nospecialize(arg1), @nospecialize(sym))
return Const(false)
elseif isa(arg1, Const)
arg1v = (arg1::Const).val
if !ismutable(arg1v) || isdefined(arg1v, idx) || (isa(arg1v, DataType) && is_dt_const_field(idx))
if !ismutable(arg1v) || isdefined(arg1v, idx) || isconst(typeof(arg1v), idx)
return Const(isdefined(arg1v, idx))
end
elseif !isvatuple(a1)
Expand Down Expand Up @@ -646,23 +634,6 @@ function subtype_tfunc(@nospecialize(a), @nospecialize(b))
end
add_tfunc(<:, 2, 2, subtype_tfunc, 10)

is_dt_const_field(fld::Int) = (
fld == DATATYPE_NAME_FIELDINDEX ||
fld == DATATYPE_PARAMETERS_FIELDINDEX ||
fld == DATATYPE_TYPES_FIELDINDEX ||
fld == DATATYPE_SUPER_FIELDINDEX ||
fld == DATATYPE_INSTANCE_FIELDINDEX ||
fld == DATATYPE_HASH_FIELDINDEX
)
function const_datatype_getfield_tfunc(@nospecialize(sv), fld::Int)
if fld == DATATYPE_INSTANCE_FIELDINDEX
return isdefined(sv, fld) ? Const(getfield(sv, fld)) : Union{}
elseif is_dt_const_field(fld) && isdefined(sv, fld)
return Const(getfield(sv, fld))
end
return nothing
end

function fieldcount_noerror(@nospecialize t)
if t isa UnionAll || t isa Union
t = argument_datatype(t)
Expand Down Expand Up @@ -801,41 +772,27 @@ function getfield_tfunc(@nospecialize(s00), @nospecialize(name))
end
if isa(name, Const)
nv = name.val
if !(isa(nv,Symbol) || isa(nv,Int))
if isa(sv, Module)
if isa(nv, Symbol)
return abstract_eval_global(sv, nv)
end
return Bottom
end
if isa(sv, UnionAll)
if nv === :var || nv === 1
return Const(sv.var)
elseif nv === :body || nv === 2
return Const(sv.body)
end
elseif isa(sv, DataType)
idx = nv
if isa(idx, Symbol)
idx = fieldindex(DataType, idx, false)
end
if isa(idx, Int)
t = const_datatype_getfield_tfunc(sv, idx)
t === nothing || return t
end
elseif isa(sv, Core.TypeName)
fld = isa(nv, Symbol) ? fieldindex(Core.TypeName, nv, false) : nv
if (fld == TYPENAME_NAME_FIELDINDEX ||
fld == TYPENAME_MODULE_FIELDINDEX ||
fld == TYPENAME_WRAPPER_FIELDINDEX ||
fld == TYPENAME_HASH_FIELDINDEX ||
fld == TYPENAME_FLAGS_FIELDINDEX ||
(fld == TYPENAME_NAMES_FIELDINDEX && isdefined(sv, fld)))
return Const(getfield(sv, fld))
end
if isa(nv, Symbol)
nv = fieldindex(typeof(sv), nv, false)
end
if isa(sv, Module) && isa(nv, Symbol)
return abstract_eval_global(sv, nv)
if !isa(nv, Int)
return Bottom
end
if (isa(sv, SimpleVector) || !ismutable(sv)) && isdefined(sv, nv)
if isa(sv, DataType) && nv == DATATYPE_TYPES_FIELDINDEX && isdefined(sv, nv)
return Const(getfield(sv, nv))
end
if isconst(typeof(sv), nv)
if isdefined(sv, nv)
return Const(getfield(sv, nv))
end
return Union{}
end
end
s = typeof(sv)
elseif isa(s00, PartialStruct)
Expand All @@ -855,11 +812,11 @@ function getfield_tfunc(@nospecialize(s00), @nospecialize(name))
return Any
end
s = s::DataType
if s <: Tuple && name Symbol
if s <: Tuple && !(Int <: widenconst(name))
return Bottom
end
if s <: Module
if name Int
if !(Symbol <: widenconst(name))
return Bottom
end
return Any
Expand Down Expand Up @@ -920,17 +877,6 @@ function getfield_tfunc(@nospecialize(s00), @nospecialize(name))
if fld < 1 || fld > nf
return Bottom
end
if isconstType(s00)
sp = s00.parameters[1]
elseif isa(s00, Const)
sp = s00.val
else
sp = nothing
end
if isa(sp, DataType)
t = const_datatype_getfield_tfunc(sp, fld)
t !== nothing && return t
end
R = ftypes[fld]
if isempty(s.parameters)
return R
Expand Down
25 changes: 24 additions & 1 deletion base/reflection.jl
Original file line number Diff line number Diff line change
Expand Up @@ -248,11 +248,34 @@ parentmodule(t::UnionAll) = parentmodule(unwrap_unionall(t))
"""
isconst(m::Module, s::Symbol) -> Bool
Determine whether a global is declared `const` in a given `Module`.
Determine whether a global is declared `const` in a given module `m`.
"""
isconst(m::Module, s::Symbol) =
ccall(:jl_is_const, Cint, (Any, Any), m, s) != 0

"""
isconst(t::DataType, s::Union{Int,Symbol}) -> Bool
Determine whether a field `s` is declared `const` in a given type `t`.
"""
function isconst(@nospecialize(t::Type), s::Symbol)
t = unwrap_unionall(t)
isa(t, DataType) || return false
return isconst(t, fieldindex(t, s, false))
end
function isconst(@nospecialize(t::Type), s::Int)
t = unwrap_unionall(t)
# TODO: what to do for `Union`?
isa(t, DataType) || return false # uncertain
ismutabletype(t) || return true # immutable structs are always const
1 <= s <= length(t.name.names) || return true # OOB reads are "const" since they always throw
constfields = t.name.constfields
constfields === C_NULL && return false
s -= 1
return unsafe_load(Ptr{UInt32}(constfields), 1 + s÷32) & (1 << (s%32)) != 0
end


"""
@locals()
Expand Down
10 changes: 5 additions & 5 deletions doc/src/base/base.md
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,12 @@ Base.isstructtype
Base.nameof(::DataType)
Base.fieldnames
Base.fieldname
Core.fieldtype
Base.fieldtypes
Base.fieldcount
Base.hasfield
Core.nfields
Base.isconst
```

### Memory layout
Expand All @@ -190,9 +195,6 @@ Base.sizeof(::Type)
Base.isconcretetype
Base.isbits
Base.isbitstype
Core.fieldtype
Base.fieldtypes
Base.fieldcount
Base.fieldoffset
Base.datatype_alignment
Base.datatype_haspadding
Expand Down Expand Up @@ -418,8 +420,6 @@ Base.@__DIR__
Base.@__LINE__
Base.fullname
Base.names
Core.nfields
Base.isconst
Base.nameof(::Function)
Base.functionloc(::Any, ::Any)
Base.functionloc(::Method)
Expand Down
2 changes: 1 addition & 1 deletion src/ast.scm
Original file line number Diff line number Diff line change
Expand Up @@ -378,7 +378,7 @@
(or (symbol? e) (decl? e)))

(define (eventually-decl? e)
(or (decl? e) (and (pair? e) (eq? (car e) 'atomic) (symdecl? (cadr e)))))
(or (symbol? e) (and (pair? e) (memq (car e) '(|::| atomic const)) (eventually-decl? (cadr e)))))

(define (make-decl n t) `(|::| ,n ,t))

Expand Down
8 changes: 8 additions & 0 deletions src/builtins.c
Original file line number Diff line number Diff line change
Expand Up @@ -855,6 +855,10 @@ static inline size_t get_checked_fieldindex(const char *name, jl_datatype_t *st,
JL_TYPECHKS(name, symbol, arg);
idx = jl_field_index(st, (jl_sym_t*)arg, 1);
}
if (mutabl && jl_field_isconst(st, idx)) {
jl_errorf("%s: const field .%s of type %s cannot be changed", name,
jl_symbol_name((jl_sym_t*)jl_svec_ref(jl_field_names(st), idx)), jl_symbol_name(st->name->name));
}
return idx;
}

Expand Down Expand Up @@ -1604,6 +1608,10 @@ static int equiv_type(jl_value_t *ta, jl_value_t *tb)
? dtb->name->atomicfields == NULL
: (dtb->name->atomicfields != NULL &&
memcmp(dta->name->atomicfields, dtb->name->atomicfields, (jl_svec_len(dta->name->names) + 31) / 32 * sizeof(uint32_t)) == 0)) &&
(dta->name->constfields == NULL
? dtb->name->constfields == NULL
: (dtb->name->constfields != NULL &&
memcmp(dta->name->constfields, dtb->name->constfields, (jl_svec_len(dta->name->names) + 31) / 32 * sizeof(uint32_t)) == 0)) &&
jl_egal((jl_value_t*)jl_field_names(dta), (jl_value_t*)jl_field_names(dtb)) &&
jl_nparams(dta) == jl_nparams(dtb)))
return 0;
Expand Down
16 changes: 4 additions & 12 deletions src/cgutils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2249,10 +2249,10 @@ static jl_cgval_t emit_getfield_knownidx(jl_codectx_t &ctx, const jl_cgval_t &st
else {
ptindex = emit_struct_gep(ctx, cast<StructType>(lt), staddr, byte_offset + fsz);
}
return emit_unionload(ctx, addr, ptindex, jfty, fsz, al, tbaa, jt->name->mutabl, union_max, tbaa_unionselbyte);
return emit_unionload(ctx, addr, ptindex, jfty, fsz, al, tbaa, !jl_field_isconst(jt, idx), union_max, tbaa_unionselbyte);
}
assert(jl_is_concrete_type(jfty));
if (!jt->name->mutabl && !(maybe_null && (jfty == (jl_value_t*)jl_bool_type ||
if (jl_field_isconst(jt, idx) && !(maybe_null && (jfty == (jl_value_t*)jl_bool_type ||
((jl_datatype_t*)jfty)->layout->npointers))) {
// just compute the pointer and let user load it when necessary
return mark_julia_slot(addr, jfty, NULL, tbaa);
Expand Down Expand Up @@ -3283,21 +3283,13 @@ static void emit_write_multibarrier(jl_codectx_t &ctx, Value *parent, Value *agg
emit_write_barrier(ctx, parent, ptrs);
}


static jl_cgval_t emit_setfield(jl_codectx_t &ctx,
jl_datatype_t *sty, const jl_cgval_t &strct, size_t idx0,
jl_cgval_t rhs, jl_cgval_t cmp,
bool checked, bool wb, AtomicOrdering Order, AtomicOrdering FailOrder,
bool wb, AtomicOrdering Order, AtomicOrdering FailOrder,
bool needlock, bool issetfield, bool isreplacefield, bool isswapfield, bool ismodifyfield,
const jl_cgval_t *modifyop, const std::string &fname)
{
if (!sty->name->mutabl && checked) {
std::string msg = fname + ": immutable struct of type "
+ std::string(jl_symbol_name(sty->name->name))
+ " cannot be changed";
emit_error(ctx, msg);
return jl_cgval_t();
}
assert(strct.ispointer());
size_t byte_offset = jl_field_offset(sty, idx0);
Value *addr = data_pointer(ctx, strct);
Expand Down Expand Up @@ -3574,7 +3566,7 @@ static jl_cgval_t emit_new_struct(jl_codectx_t &ctx, jl_value_t *ty, size_t narg
else
need_wb = false;
emit_typecheck(ctx, rhs, jl_svecref(sty->types, i), "new"); // n.b. ty argument must be concrete
emit_setfield(ctx, sty, strctinfo, i, rhs, jl_cgval_t(), false, need_wb, AtomicOrdering::NotAtomic, AtomicOrdering::NotAtomic, false, true, false, false, false, nullptr, "");
emit_setfield(ctx, sty, strctinfo, i, rhs, jl_cgval_t(), need_wb, AtomicOrdering::NotAtomic, AtomicOrdering::NotAtomic, false, true, false, false, false, nullptr, "");
}
return strctinfo;
}
Expand Down
41 changes: 27 additions & 14 deletions src/codegen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2506,6 +2506,7 @@ static bool emit_f_opfield(jl_codectx_t &ctx, jl_cgval_t *ret, jl_value_t *f,
bool isboxed = jl_field_isptr(uty, idx);
bool isatomic = jl_field_isatomic(uty, idx);
bool needlock = isatomic && !isboxed && jl_datatype_size(jl_field_type(uty, idx)) > MAX_ATOMIC_SIZE;
*ret = jl_cgval_t();
if (isatomic == (order == jl_memory_order_notatomic)) {
emit_atomic_error(ctx,
issetfield ?
Expand All @@ -2519,25 +2520,37 @@ static bool emit_f_opfield(jl_codectx_t &ctx, jl_cgval_t *ret, jl_value_t *f,
: "swapfield!: non-atomic field cannot be written atomically") :
(isatomic ? "modifyfield!: atomic field cannot be written non-atomically"
: "modifyfield!: non-atomic field cannot be written atomically"));
*ret = jl_cgval_t();
return true;
}
if (isatomic == (fail_order == jl_memory_order_notatomic)) {
else if (isatomic == (fail_order == jl_memory_order_notatomic)) {
emit_atomic_error(ctx,
(isatomic ? "replacefield!: atomic field cannot be accessed non-atomically"
: "replacefield!: non-atomic field cannot be accessed atomically"));
*ret = jl_cgval_t();
return true;
}
*ret = emit_setfield(ctx, uty, obj, idx, val, cmp, true, true,
(needlock || order <= jl_memory_order_notatomic)
? (isboxed ? AtomicOrdering::Unordered : AtomicOrdering::NotAtomic) // TODO: we should do this for anything with CountTrackedPointers(elty).count > 0
: get_llvm_atomic_order(order),
(needlock || fail_order <= jl_memory_order_notatomic)
? (isboxed ? AtomicOrdering::Unordered : AtomicOrdering::NotAtomic) // TODO: we should do this for anything with CountTrackedPointers(elty).count > 0
: get_llvm_atomic_order(fail_order),
needlock, issetfield, isreplacefield, isswapfield, ismodifyfield,
modifyop, fname);
else if (!uty->name->mutabl) {
std::string msg = fname + ": immutable struct of type "
+ std::string(jl_symbol_name(uty->name->name))
+ " cannot be changed";
emit_error(ctx, msg);
}
else if (jl_field_isconst(uty, idx)) {
std::string msg = fname + ": const field ."
+ std::string(jl_symbol_name((jl_sym_t*)jl_svec_ref(jl_field_names(uty), idx)))
+ " of type "
+ std::string(jl_symbol_name(uty->name->name))
+ " cannot be changed";
emit_error(ctx, msg);
}
else {
*ret = emit_setfield(ctx, uty, obj, idx, val, cmp, true,
(needlock || order <= jl_memory_order_notatomic)
? (isboxed ? AtomicOrdering::Unordered : AtomicOrdering::NotAtomic) // TODO: we should do this for anything with CountTrackedPointers(elty).count > 0
: get_llvm_atomic_order(order),
(needlock || fail_order <= jl_memory_order_notatomic)
? (isboxed ? AtomicOrdering::Unordered : AtomicOrdering::NotAtomic) // TODO: we should do this for anything with CountTrackedPointers(elty).count > 0
: get_llvm_atomic_order(fail_order),
needlock, issetfield, isreplacefield, isswapfield, ismodifyfield,
modifyop, fname);
}
return true;
}
}
Expand Down
Loading

0 comments on commit 63f6294

Please sign in to comment.