Skip to content

Commit

Permalink
Support deleting methods during precompilation for stdlib excision (#…
Browse files Browse the repository at this point in the history
…51641)

The problem with the `delete_method` in `__init__` approach is that we
invalidate the method-table,
after we have performed all of the caching work. A package dependent on
`Random`, will still see
the stub method in Base and thus when we delete the stub, we may
invalidate useful work.

Instead we delete the methods when Random is being loaded, thus a
dependent package only ever sees
the method table with all the methods in Random, and non of the stubs
methods.

The only invalidation that thus may happen are calls to `rand` and
`randn` without first doing an `import Random`.
  • Loading branch information
vchuravy authored Oct 18, 2023
1 parent 1b4a194 commit 31ccfb6
Show file tree
Hide file tree
Showing 12 changed files with 95 additions and 23 deletions.
1 change: 1 addition & 0 deletions base/compiler/typeinfer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
# Tracking of newly-inferred CodeInstances during precompilation
const track_newly_inferred = RefValue{Bool}(false)
const newly_inferred = CodeInstance[]
const newly_deleted = Method[]

# build (and start inferring) the inference frame for the top-level MethodInstance
function typeinf(interp::AbstractInterpreter, result::InferenceResult, cache::Symbol)
Expand Down
1 change: 1 addition & 0 deletions base/loading.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2297,6 +2297,7 @@ function include_package_for_output(pkg::PkgId, input::String, depot_path::Vecto
end

ccall(:jl_set_newly_inferred, Cvoid, (Any,), Core.Compiler.newly_inferred)
ccall(:jl_set_newly_deleted, Cvoid, (Any,), Core.Compiler.newly_deleted)
Core.Compiler.track_newly_inferred.x = true
try
Base.include(Base.__toplevel__, input)
Expand Down
3 changes: 2 additions & 1 deletion base/stubs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,8 @@ function delete_stubs(mod)
if obj isa Function
ms = Base.methods(obj, mod)
for m in ms
Base.delete_method(m)
ccall(:jl_push_newly_deleted, Cvoid, (Any,), m)
ccall(:jl_method_table_disable_incremental, Cvoid, (Any, Any), Base.get_methodtable(m), m)
end
end
end
Expand Down
1 change: 1 addition & 0 deletions doc/src/devdocs/locks.md
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ The following is a leaf lock (level 2), and only acquires level 1 locks (safepoi
> * Module->lock
> * JLDebuginfoPlugin::PluginMutex
> * newly_inferred_mutex
> * newly_deleted_mutex
The following is a level 3 lock, which can only acquire level 1 or level 2 locks internally:

Expand Down
3 changes: 2 additions & 1 deletion src/clangsa/GCChecker.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1438,7 +1438,8 @@ bool GCChecker::evalCall(const CallEvent &Call, CheckerContext &C) const {
} else if (name == "JL_GC_PUSH1" || name == "JL_GC_PUSH2" ||
name == "JL_GC_PUSH3" || name == "JL_GC_PUSH4" ||
name == "JL_GC_PUSH5" || name == "JL_GC_PUSH6" ||
name == "JL_GC_PUSH7" || name == "JL_GC_PUSH8") {
name == "JL_GC_PUSH7" || name == "JL_GC_PUSH8" ||
name == "JL_GC_PUSH9") {
ProgramStateRef State = C.getState();
// Transform slots to roots, transform values to rooted
unsigned NumArgs = CE->getNumArgs();
Expand Down
20 changes: 16 additions & 4 deletions src/gf.c
Original file line number Diff line number Diff line change
Expand Up @@ -1858,9 +1858,9 @@ static jl_typemap_entry_t *do_typemap_search(jl_methtable_t *mt JL_PROPAGATES_RO
}
#endif

static void jl_method_table_invalidate(jl_methtable_t *mt, jl_typemap_entry_t *methodentry, size_t max_world)
static void jl_method_table_invalidate(jl_methtable_t *mt, jl_typemap_entry_t *methodentry, size_t max_world, int tracked)
{
if (jl_options.incremental && jl_generating_output())
if (!tracked && jl_options.incremental && jl_generating_output())
jl_error("Method deletion is not possible during Module precompile.");
jl_method_t *method = methodentry->func.method;
assert(!method->is_for_opaque_closure);
Expand Down Expand Up @@ -1917,10 +1917,22 @@ JL_DLLEXPORT void jl_method_table_disable(jl_methtable_t *mt, jl_method_t *metho
JL_LOCK(&mt->writelock);
// Narrow the world age on the method to make it uncallable
size_t world = jl_atomic_fetch_add(&jl_world_counter, 1);
jl_method_table_invalidate(mt, methodentry, world);
jl_method_table_invalidate(mt, methodentry, world, 0);
JL_UNLOCK(&mt->writelock);
}

JL_DLLEXPORT void jl_method_table_disable_incremental(jl_methtable_t *mt, jl_method_t *method)
{
jl_typemap_entry_t *methodentry = do_typemap_search(mt, method);
JL_LOCK(&mt->writelock);
// Narrow the world age on the method to make it uncallable
// size_t world = jl_atomic_load_acquire(&jl_world_counter);
size_t world = jl_atomic_fetch_add(&jl_world_counter, 1);
jl_method_table_invalidate(mt, methodentry, world, 1);
JL_UNLOCK(&mt->writelock);
}


static int jl_type_intersection2(jl_value_t *t1, jl_value_t *t2, jl_value_t **isect JL_REQUIRE_ROOTED_SLOT, jl_value_t **isect2 JL_REQUIRE_ROOTED_SLOT)
{
*isect2 = NULL;
Expand Down Expand Up @@ -2011,7 +2023,7 @@ JL_DLLEXPORT void jl_method_table_insert(jl_methtable_t *mt, jl_method_t *method
oldvalue = (jl_value_t*)replaced;
invalidated = 1;
method_overwrite(newentry, replaced->func.method);
jl_method_table_invalidate(mt, replaced, max_world);
jl_method_table_invalidate(mt, replaced, max_world, 0);
}
else {
jl_method_t *const *d;
Expand Down
2 changes: 2 additions & 0 deletions src/init.c
Original file line number Diff line number Diff line change
Expand Up @@ -709,6 +709,7 @@ static void jl_set_io_wait(int v)
extern jl_mutex_t jl_modules_mutex;
extern jl_mutex_t precomp_statement_out_lock;
extern jl_mutex_t newly_inferred_mutex;
extern jl_mutex_t newly_deleted_mutex;
extern jl_mutex_t global_roots_lock;

static void restore_fp_env(void)
Expand All @@ -726,6 +727,7 @@ static void init_global_mutexes(void) {
JL_MUTEX_INIT(&jl_modules_mutex, "jl_modules_mutex");
JL_MUTEX_INIT(&precomp_statement_out_lock, "precomp_statement_out_lock");
JL_MUTEX_INIT(&newly_inferred_mutex, "newly_inferred_mutex");
JL_MUTEX_INIT(&newly_deleted_mutex, "newly_deleted_mutex");
JL_MUTEX_INIT(&global_roots_lock, "global_roots_lock");
JL_MUTEX_INIT(&jl_codegen_lock, "jl_codegen_lock");
JL_MUTEX_INIT(&typecache_lock, "typecache_lock");
Expand Down
1 change: 1 addition & 0 deletions src/jl_exported_funcs.inc
Original file line number Diff line number Diff line change
Expand Up @@ -323,6 +323,7 @@
XX(jl_method_instance_add_backedge) \
XX(jl_method_table_add_backedge) \
XX(jl_method_table_disable) \
XX(jl_method_table_disable_incremental) \
XX(jl_method_table_for) \
XX(jl_method_table_insert) \
XX(jl_methtable_lookup) \
Expand Down
7 changes: 7 additions & 0 deletions src/julia.h
Original file line number Diff line number Diff line change
Expand Up @@ -932,6 +932,7 @@ extern void JL_GC_PUSH4(void *, void *, void *, void *) JL_NOTSAFEPOINT;
extern void JL_GC_PUSH5(void *, void *, void *, void *, void *) JL_NOTSAFEPOINT;
extern void JL_GC_PUSH7(void *, void *, void *, void *, void *, void *, void *) JL_NOTSAFEPOINT;
extern void JL_GC_PUSH8(void *, void *, void *, void *, void *, void *, void *, void *) JL_NOTSAFEPOINT;
extern void JL_GC_PUSH9(void *, void *, void *, void *, void *, void *, void *, void *, void *) JL_NOTSAFEPOINT;
extern void _JL_GC_PUSHARGS(jl_value_t **, size_t) JL_NOTSAFEPOINT;
// This is necessary, because otherwise the analyzer considers this undefined
// behavior and terminates the exploration
Expand Down Expand Up @@ -974,6 +975,9 @@ extern void JL_GC_POP() JL_NOTSAFEPOINT;
#define JL_GC_PUSH8(arg1, arg2, arg3, arg4, arg5, arg6, arg7, arg8) \
void *__gc_stkf[] = {(void*)JL_GC_ENCODE_PUSH(8), jl_pgcstack, arg1, arg2, arg3, arg4, arg5, arg6, arg7, arg8}; \
jl_pgcstack = (jl_gcframe_t*)__gc_stkf;
#define JL_GC_PUSH9(arg1, arg2, arg3, arg4, arg5, arg6, arg7, arg8, arg9) \
void *__gc_stkf[] = {(void*)JL_GC_ENCODE_PUSH(9), jl_pgcstack, arg1, arg2, arg3, arg4, arg5, arg6, arg7, arg8, arg9}; \
jl_pgcstack = (jl_gcframe_t*)__gc_stkf;


#define JL_GC_PUSHARGS(rts_var,n) \
Expand Down Expand Up @@ -1901,6 +1905,9 @@ JL_DLLEXPORT jl_value_t *jl_restore_incremental(const char *fname, jl_array_t *d

JL_DLLEXPORT void jl_set_newly_inferred(jl_value_t *newly_inferred);
JL_DLLEXPORT void jl_push_newly_inferred(jl_value_t *ci);
JL_DLLEXPORT void jl_method_table_disable_incremental(jl_methtable_t *mt, jl_method_t *m);
JL_DLLEXPORT void jl_set_newly_deleted(jl_value_t *newly_deleted);
JL_DLLEXPORT void jl_push_newly_deleted(jl_value_t *m);
JL_DLLEXPORT void jl_write_compiler_output(void);

// parsing
Expand Down
35 changes: 26 additions & 9 deletions src/staticdata.c
Original file line number Diff line number Diff line change
Expand Up @@ -2393,7 +2393,7 @@ static void jl_prepare_serialization_data(jl_array_t *mod_array, jl_array_t *new
static void jl_save_system_image_to_stream(ios_t *f, jl_array_t *mod_array,
jl_array_t *worklist, jl_array_t *extext_methods,
jl_array_t *new_specializations, jl_array_t *method_roots_list,
jl_array_t *ext_targets, jl_array_t *edges) JL_GC_DISABLED
jl_array_t *ext_targets, jl_array_t *edges, jl_array_t *newly_deleted) JL_GC_DISABLED
{
htable_new(&field_replace, 0);
// strip metadata and IR when requested
Expand Down Expand Up @@ -2514,6 +2514,10 @@ static void jl_save_system_image_to_stream(ios_t *f, jl_array_t *mod_array,
jl_queue_for_serialization(&s, ext_targets);
jl_queue_for_serialization(&s, edges);
}
if (newly_deleted) {
jl_queue_for_serialization(&s, newly_deleted);
}

jl_serialize_reachable(&s);
// step 1.2: ensure all gvars are part of the sysimage too
record_gvars(&s, &gvars);
Expand Down Expand Up @@ -2647,6 +2651,7 @@ static void jl_save_system_image_to_stream(ios_t *f, jl_array_t *mod_array,
jl_write_value(&s, method_roots_list);
jl_write_value(&s, ext_targets);
jl_write_value(&s, edges);
jl_write_value(&s, newly_deleted);
}
write_uint32(f, jl_array_len(s.link_ids_gctags));
ios_write(f, (char*)jl_array_data(s.link_ids_gctags), jl_array_len(s.link_ids_gctags) * sizeof(uint32_t));
Expand Down Expand Up @@ -2725,11 +2730,11 @@ JL_DLLEXPORT void jl_create_system_image(void **_native_data, jl_array_t *workli
}

jl_array_t *mod_array = NULL, *extext_methods = NULL, *new_specializations = NULL;
jl_array_t *method_roots_list = NULL, *ext_targets = NULL, *edges = NULL;
jl_array_t *method_roots_list = NULL, *ext_targets = NULL, *edges = NULL, *_newly_deleted = NULL;
int64_t checksumpos = 0;
int64_t checksumpos_ff = 0;
int64_t datastartpos = 0;
JL_GC_PUSH6(&mod_array, &extext_methods, &new_specializations, &method_roots_list, &ext_targets, &edges);
JL_GC_PUSH7(&mod_array, &extext_methods, &new_specializations, &method_roots_list, &ext_targets, &edges, &_newly_deleted);

if (worklist) {
mod_array = jl_get_loaded_modules(); // __toplevel__ modules loaded in this session (from Base.loaded_modules_array)
Expand Down Expand Up @@ -2776,7 +2781,10 @@ JL_DLLEXPORT void jl_create_system_image(void **_native_data, jl_array_t *workli
}
if (_native_data != NULL)
native_functions = *_native_data;
jl_save_system_image_to_stream(ff, mod_array, worklist, extext_methods, new_specializations, method_roots_list, ext_targets, edges);
// Otherwise serialization will be confused.
if (newly_deleted)
_newly_deleted = jl_array_copy(newly_deleted);
jl_save_system_image_to_stream(ff, mod_array, worklist, extext_methods, new_specializations, method_roots_list, ext_targets, edges, _newly_deleted);
if (_native_data != NULL)
native_functions = NULL;
// make sure we don't run any Julia code concurrently before this point
Expand Down Expand Up @@ -2860,6 +2868,7 @@ static void jl_restore_system_image_from_stream_(ios_t *f, jl_image_t *image, jl
jl_array_t **extext_methods,
jl_array_t **new_specializations, jl_array_t **method_roots_list,
jl_array_t **ext_targets, jl_array_t **edges,
jl_array_t **newly_deleted,
char **base, arraylist_t *ccallable_list, pkgcachesizes *cachesizes) JL_GC_DISABLED
{
int en = jl_gc_enable(0);
Expand Down Expand Up @@ -2921,7 +2930,7 @@ static void jl_restore_system_image_from_stream_(ios_t *f, jl_image_t *image, jl
assert(!ios_eof(f));
s.s = f;
uintptr_t offset_restored = 0, offset_init_order = 0, offset_extext_methods = 0, offset_new_specializations = 0, offset_method_roots_list = 0;
uintptr_t offset_ext_targets = 0, offset_edges = 0;
uintptr_t offset_ext_targets = 0, offset_edges = 0, offset_newly_deleted = 0;
if (!s.incremental) {
size_t i;
for (i = 0; tags[i] != NULL; i++) {
Expand Down Expand Up @@ -2955,6 +2964,7 @@ static void jl_restore_system_image_from_stream_(ios_t *f, jl_image_t *image, jl
offset_method_roots_list = jl_read_offset(&s);
offset_ext_targets = jl_read_offset(&s);
offset_edges = jl_read_offset(&s);
offset_newly_deleted = jl_read_offset(&s);
}
s.buildid_depmods_idxs = depmod_to_imageidx(depmods);
size_t nlinks_gctags = read_uint32(f);
Expand Down Expand Up @@ -2988,6 +2998,7 @@ static void jl_restore_system_image_from_stream_(ios_t *f, jl_image_t *image, jl
*method_roots_list = (jl_array_t*)jl_delayed_reloc(&s, offset_method_roots_list);
*ext_targets = (jl_array_t*)jl_delayed_reloc(&s, offset_ext_targets);
*edges = (jl_array_t*)jl_delayed_reloc(&s, offset_edges);
*newly_deleted = (jl_array_t*)jl_delayed_reloc(&s, offset_newly_deleted);
if (!*new_specializations)
*new_specializations = jl_alloc_vec_any(0);
}
Expand Down Expand Up @@ -3175,6 +3186,11 @@ static void jl_restore_system_image_from_stream_(ios_t *f, jl_image_t *image, jl
assert(jl_is_datatype(obj));
jl_cache_type_((jl_datatype_t*)obj);
}

// Delete methods before inserting new ones.
if (newly_deleted)
jl_delete_methods(*newly_deleted);

// Perform fixups: things like updating world ages, inserting methods & specializations, etc.
size_t world = jl_atomic_load_acquire(&jl_world_counter);
for (size_t i = 0; i < s.uniquing_objs.len; i++) {
Expand Down Expand Up @@ -3401,11 +3417,11 @@ static jl_value_t *jl_restore_package_image_from_stream(void* pkgimage_handle, i
assert(datastartpos > 0 && datastartpos < dataendpos);
needs_permalloc = jl_options.permalloc_pkgimg || needs_permalloc;
jl_value_t *restored = NULL;
jl_array_t *init_order = NULL, *extext_methods = NULL, *new_specializations = NULL, *method_roots_list = NULL, *ext_targets = NULL, *edges = NULL;
jl_array_t *init_order = NULL, *extext_methods = NULL, *new_specializations = NULL, *method_roots_list = NULL, *ext_targets = NULL, *edges = NULL, *newly_deleted = NULL;
jl_svec_t *cachesizes_sv = NULL;
char *base;
arraylist_t ccallable_list;
JL_GC_PUSH8(&restored, &init_order, &extext_methods, &new_specializations, &method_roots_list, &ext_targets, &edges, &cachesizes_sv);
JL_GC_PUSH9(&restored, &init_order, &extext_methods, &new_specializations, &method_roots_list, &ext_targets, &edges, &newly_deleted, &cachesizes_sv);

{ // make a permanent in-memory copy of f (excluding the header)
ios_bufmode(f, bm_none);
Expand All @@ -3429,11 +3445,12 @@ static jl_value_t *jl_restore_package_image_from_stream(void* pkgimage_handle, i
ios_close(f);
ios_static_buffer(f, sysimg, len);
pkgcachesizes cachesizes;
jl_restore_system_image_from_stream_(f, image, depmods, checksum, (jl_array_t**)&restored, &init_order, &extext_methods, &new_specializations, &method_roots_list, &ext_targets, &edges, &base, &ccallable_list, &cachesizes);
jl_restore_system_image_from_stream_(f, image, depmods, checksum, (jl_array_t**)&restored, &init_order, &extext_methods, &new_specializations, &method_roots_list, &ext_targets, &edges, &newly_deleted, &base, &ccallable_list, &cachesizes);
JL_SIGATOMIC_END();

// Insert method extensions
jl_insert_methods(extext_methods);

// No special processing of `new_specializations` is required because recaching handled it
// Add roots to methods
jl_copy_roots(method_roots_list, jl_worklist_key((jl_array_t*)restored));
Expand Down Expand Up @@ -3469,7 +3486,7 @@ static jl_value_t *jl_restore_package_image_from_stream(void* pkgimage_handle, i
static void jl_restore_system_image_from_stream(ios_t *f, jl_image_t *image, uint32_t checksum)
{
JL_TIMING(LOAD_IMAGE, LOAD_Sysimg);
jl_restore_system_image_from_stream_(f, image, NULL, checksum | ((uint64_t)0xfdfcfbfa << 32), NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL);
jl_restore_system_image_from_stream_(f, image, NULL, checksum | ((uint64_t)0xfdfcfbfa << 32), NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL);
}

JL_DLLEXPORT jl_value_t *jl_restore_incremental_from_buf(void* pkgimage_handle, const char *buf, jl_image_t *image, size_t sz, jl_array_t *depmods, int completeinfo, const char *pkgname, bool needs_permalloc)
Expand Down
34 changes: 34 additions & 0 deletions src/staticdata_utils.c
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,27 @@ JL_DLLEXPORT void jl_push_newly_inferred(jl_value_t* ci)
JL_UNLOCK(&newly_inferred_mutex);
}

static jl_array_t *newly_deleted JL_GLOBALLY_ROOTED /*FIXME*/;
// Mutex for newly_deleted
jl_mutex_t newly_deleted_mutex;

// Register array of newly-inferred MethodInstances
// This gets called as the first step of Base.include_package_for_output
JL_DLLEXPORT void jl_set_newly_deleted(jl_value_t* _newly_deleted)
{
assert(newly_deleted == NULL || jl_is_array(_newly_deleted));
newly_deleted = (jl_array_t*) _newly_deleted;
}

JL_DLLEXPORT void jl_push_newly_deleted(jl_value_t* m)
{
JL_LOCK(&newly_deleted_mutex);
size_t end = jl_array_len(newly_deleted);
jl_array_grow_end(newly_deleted, 1);
jl_arrayset(newly_deleted, m, end);
JL_UNLOCK(&newly_deleted_mutex);
}


// compute whether a type references something internal to worklist
// and thus could not have existed before deserialize
Expand Down Expand Up @@ -817,6 +838,19 @@ static void jl_insert_methods(jl_array_t *list)
}
}

static void jl_delete_methods(jl_array_t *list)
{
size_t i, l = jl_array_len(list);
for (i = 0; i < l; i++) {
jl_method_t *meth = (jl_method_t*)jl_array_ptr_ref(list, i);
assert(jl_is_method(meth));
assert(!meth->is_for_opaque_closure);
jl_methtable_t *mt = jl_method_get_table(meth);
assert((jl_value_t*)mt != jl_nothing);
jl_method_table_disable_incremental(mt, meth);
}
}

static void jl_copy_roots(jl_array_t *method_roots_list, uint64_t key)
{
size_t i, l = jl_array_len(method_roots_list);
Expand Down
10 changes: 2 additions & 8 deletions stdlib/Random/src/Random.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,14 +31,8 @@ export rand!, randn!,

## general definitions

module Stubs
function __init__()
# Remove the shim methods
if !Base.generating_output()
Base.Stubs.delete_stubs(Base.Stubs.Random)
end
end
end
# Remove the shim methods
Base.Stubs.delete_stubs(Base.Stubs.Random)

"""
AbstractRNG
Expand Down

0 comments on commit 31ccfb6

Please sign in to comment.