Skip to content

Commit

Permalink
Allow for :foreigncall to transition to GC safe automatically (#49933)
Browse files Browse the repository at this point in the history
This has been bouncing around as a idea for a while.
One of the challenges around time-to-safepoint has been Julia code
that is calling libraries.

Since foreign code will not include safepoints we see increased latency
when one thread is running a foreign-call and another wants to trigger
GC.

The open design question here is:
- Do we expose this as an option the user must "opt-in", e.g. by using a
  keyword arg to `@ccall` or a specific calling-convetion.
- Or do we turn this on for all ccall, except for Julia runtime calls.

There is relativly little code outside the Julia runtime that needs to
be "GC unsafe",
exception are programs that directly use the Julia C-API. Incidentially
`jl_adopt_thread`
and `@cfunction`/`@ccallable` do the right thing and transition to "GC
unsafe", regardless
of what state the thread currently is in.

I still need to figure out how to reliably detect Julia runtime calls,
but I think we can
switch all other calls to "GC safe". We should also consider
optimizations that mark large
regions of code without Julia runtime interactions as "GC safe" in
particular numeric
for-loops.

Closes #57057

---------

Co-authored-by: Gabriel Baraldi <[email protected]>
(cherry picked from commit 85458a0)
  • Loading branch information
vchuravy authored and KristofferC committed Feb 15, 2025
1 parent 423cb56 commit 5da257d
Show file tree
Hide file tree
Showing 16 changed files with 160 additions and 44 deletions.
2 changes: 1 addition & 1 deletion Compiler/src/abstractinterpretation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3409,7 +3409,7 @@ function abstract_eval_foreigncall(interp::AbstractInterpreter, e::Expr, sstate:
abstract_eval_value(interp, x, sstate, sv)
end
cconv = e.args[5]
if isa(cconv, QuoteNode) && (v = cconv.value; isa(v, Tuple{Symbol, UInt16}))
if isa(cconv, QuoteNode) && (v = cconv.value; isa(v, Tuple{Symbol, UInt16, Bool}))
override = decode_effects_override(v[2])
effects = override_effects(effects, override)
end
Expand Down
2 changes: 1 addition & 1 deletion Compiler/src/validation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ const VALID_EXPR_HEADS = IdDict{Symbol,UnitRange{Int}}(
:meta => 0:typemax(Int),
:global => 1:1,
:globaldecl => 1:2,
:foreigncall => 5:typemax(Int), # name, RT, AT, nreq, (cconv, effects), args..., roots...
:foreigncall => 5:typemax(Int), # name, RT, AT, nreq, (cconv, effects, gc_safe), args..., roots...
:cfunction => 5:5,
:isdefined => 1:2,
:code_coverage_effect => 0:0,
Expand Down
1 change: 1 addition & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ New language features
* Support for Unicode 16 ([#56925]).
* `Threads.@spawn` now takes a `:samepool` argument to specify the same threadpool as the caller.
`Threads.@spawn :samepool foo()` which is shorthand for `Threads.@spawn Threads.threadpool() foo()` ([#57109]).
* The `@ccall` macro can now take a `gc_safe` argument, that if set to true allows the runtime to run garbage collection concurrently to the `ccall`

Language changes
----------------
Expand Down
50 changes: 43 additions & 7 deletions base/c.jl
Original file line number Diff line number Diff line change
Expand Up @@ -268,7 +268,31 @@ The above input outputs this:
(:printf, :Cvoid, [:Cstring, :Cuint], ["%d", :value])
"""
function ccall_macro_parse(expr::Expr)
function ccall_macro_parse(exprs)
gc_safe = false
expr = nothing
if exprs isa Expr
expr = exprs
elseif length(exprs) == 1
expr = exprs[1]
elseif length(exprs) == 2
gc_expr = exprs[1]
expr = exprs[2]
if gc_expr.head == :(=) && gc_expr.args[1] == :gc_safe
if gc_expr.args[2] == true
gc_safe = true
elseif gc_expr.args[2] == false
gc_safe = false
else
throw(ArgumentError("gc_safe must be true or false"))
end
else
throw(ArgumentError("@ccall option must be `gc_safe=true` or `gc_safe=false`"))
end
else
throw(ArgumentError("@ccall needs a function signature with a return type"))
end

# setup and check for errors
if !isexpr(expr, :(::))
throw(ArgumentError("@ccall needs a function signature with a return type"))
Expand Down Expand Up @@ -328,12 +352,11 @@ function ccall_macro_parse(expr::Expr)
pusharg!(a)
end
end

return func, rettype, types, args, nreq
return func, rettype, types, args, gc_safe, nreq
end


function ccall_macro_lower(convention, func, rettype, types, args, nreq)
function ccall_macro_lower(convention, func, rettype, types, args, gc_safe, nreq)
statements = []

# if interpolation was used, ensure the value is a function pointer at runtime.
Expand All @@ -351,9 +374,15 @@ function ccall_macro_lower(convention, func, rettype, types, args, nreq)
else
func = esc(func)
end
cconv = nothing
if convention isa Tuple
cconv = Expr(:cconv, (convention..., gc_safe), nreq)
else
cconv = Expr(:cconv, (convention, UInt16(0), gc_safe), nreq)
end

return Expr(:block, statements...,
Expr(:call, :ccall, func, Expr(:cconv, convention, nreq), esc(rettype),
Expr(:call, :ccall, func, cconv, esc(rettype),
Expr(:tuple, map(esc, types)...), map(esc, args)...))
end

Expand Down Expand Up @@ -404,9 +433,16 @@ Example using an external library:
The string literal could also be used directly before the function
name, if desired `"libglib-2.0".g_uri_escape_string(...`
It's possible to declare the ccall as `gc_safe` by using the `gc_safe = true` option:
@ccall gc_safe=true strlen(s::Cstring)::Csize_t
This allows the garbage collector to run concurrently with the ccall, which can be useful whenever
the `ccall` may block outside of julia.
WARNING: This option should be used with caution, as it can lead to undefined behavior if the ccall
calls back into the julia runtime. (`@cfunction`/`@ccallables` are safe however)
"""
macro ccall(expr)
return ccall_macro_lower(:ccall, ccall_macro_parse(expr)...)
macro ccall(exprs...)
return ccall_macro_lower((:ccall), ccall_macro_parse(exprs)...)
end

macro ccall_effects(effects::UInt16, expr)
Expand Down
2 changes: 1 addition & 1 deletion base/meta.jl
Original file line number Diff line number Diff line change
Expand Up @@ -427,7 +427,7 @@ function _partially_inline!(@nospecialize(x), slot_replacements::Vector{Any},
elseif i == 4
@assert isa(x.args[4], Int)
elseif i == 5
@assert isa((x.args[5]::QuoteNode).value, Union{Symbol, Tuple{Symbol, UInt8}})
@assert isa((x.args[5]::QuoteNode).value, Union{Symbol, Tuple{Symbol, UInt16, Bool}})
else
x.args[i] = _partially_inline!(x.args[i], slot_replacements,
type_signature, static_param_values,
Expand Down
2 changes: 1 addition & 1 deletion base/strings/string.jl
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ end
# but the macro is not available at this time in bootstrap, so we write it manually.
const _string_n_override = 0x04ee
@eval _string_n(n::Integer) = $(Expr(:foreigncall, QuoteNode(:jl_alloc_string), Ref{String},
:(Core.svec(Csize_t)), 1, QuoteNode((:ccall, _string_n_override)), :(convert(Csize_t, n))))
:(Core.svec(Csize_t)), 1, QuoteNode((:ccall, _string_n_override, false)), :(convert(Csize_t, n))))

"""
String(s::AbstractString)
Expand Down
4 changes: 2 additions & 2 deletions doc/src/devdocs/ast.md
Original file line number Diff line number Diff line change
Expand Up @@ -498,9 +498,9 @@ These symbols appear in the `head` field of [`Expr`](@ref)s in lowered form.

The number of required arguments for a varargs function definition.

* `args[5]::QuoteNode{<:Union{Symbol,Tuple{Symbol,UInt16}}`: calling convention
* `args[5]::QuoteNode{<:Union{Symbol,Tuple{Symbol,UInt16}, Tuple{Symbol,UInt16,Bool}}`: calling convention

The calling convention for the call, optionally with effects.
The calling convention for the call, optionally with effects, and `gc_safe` (safe to execute concurrently to GC.).

* `args[6:5+length(args[3])]` : arguments

Expand Down
21 changes: 15 additions & 6 deletions src/ccall.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1134,6 +1134,7 @@ class function_sig_t {
AttributeList attributes; // vector of function call site attributes
Type *lrt; // input parameter of the llvm return type (from julia_struct_to_llvm)
bool retboxed; // input parameter indicating whether lrt is jl_value_t*
bool gc_safe; // input parameter indicating whether the call is safe to execute concurrently to GC
Type *prt; // out parameter of the llvm return type for the function signature
int sret; // out parameter for indicating whether return value has been moved to the first argument position
std::string err_msg;
Expand All @@ -1146,8 +1147,8 @@ class function_sig_t {
size_t nreqargs; // number of required arguments in ccall function definition
jl_codegen_params_t *ctx;

function_sig_t(const char *fname, Type *lrt, jl_value_t *rt, bool retboxed, jl_svec_t *at, jl_unionall_t *unionall_env, size_t nreqargs, CallingConv::ID cc, bool llvmcall, jl_codegen_params_t *ctx)
: lrt(lrt), retboxed(retboxed),
function_sig_t(const char *fname, Type *lrt, jl_value_t *rt, bool retboxed, bool gc_safe, jl_svec_t *at, jl_unionall_t *unionall_env, size_t nreqargs, CallingConv::ID cc, bool llvmcall, jl_codegen_params_t *ctx)
: lrt(lrt), retboxed(retboxed), gc_safe(gc_safe),
prt(NULL), sret(0), cc(cc), llvmcall(llvmcall),
at(at), rt(rt), unionall_env(unionall_env),
nccallargs(jl_svec_len(at)), nreqargs(nreqargs),
Expand Down Expand Up @@ -1295,6 +1296,7 @@ std::string generate_func_sig(const char *fname)
RetAttrs = RetAttrs.addAttribute(LLVMCtx, Attribute::NonNull);
if (rt == jl_bottom_type)
FnAttrs = FnAttrs.addAttribute(LLVMCtx, Attribute::NoReturn);

assert(attributes.isEmpty());
attributes = AttributeList::get(LLVMCtx, FnAttrs, RetAttrs, paramattrs);
return "";
Expand Down Expand Up @@ -1412,7 +1414,7 @@ static const std::string verify_ccall_sig(jl_value_t *&rt, jl_value_t *at,

const int fc_args_start = 6;

// Expr(:foreigncall, pointer, rettype, (argtypes...), nreq, [cconv | (cconv, effects)], args..., roots...)
// Expr(:foreigncall, pointer, rettype, (argtypes...), nreq, gc_safe, [cconv | (cconv, effects)], args..., roots...)
static jl_cgval_t emit_ccall(jl_codectx_t &ctx, jl_value_t **args, size_t nargs)
{
JL_NARGSV(ccall, 5);
Expand All @@ -1424,11 +1426,13 @@ static jl_cgval_t emit_ccall(jl_codectx_t &ctx, jl_value_t **args, size_t nargs)
assert(jl_is_quotenode(args[5]));
jl_value_t *jlcc = jl_quotenode_value(args[5]);
jl_sym_t *cc_sym = NULL;
bool gc_safe = false;
if (jl_is_symbol(jlcc)) {
cc_sym = (jl_sym_t*)jlcc;
}
else if (jl_is_tuple(jlcc)) {
cc_sym = (jl_sym_t*)jl_get_nth_field_noalloc(jlcc, 0);
gc_safe = jl_unbox_bool(jl_get_nth_field_checked(jlcc, 2));
}
assert(jl_is_symbol(cc_sym));
native_sym_arg_t symarg = {};
Expand Down Expand Up @@ -1547,7 +1551,7 @@ static jl_cgval_t emit_ccall(jl_codectx_t &ctx, jl_value_t **args, size_t nargs)
}
if (rt != args[2] && rt != (jl_value_t*)jl_any_type)
jl_temporary_root(ctx, rt);
function_sig_t sig("ccall", lrt, rt, retboxed,
function_sig_t sig("ccall", lrt, rt, retboxed, gc_safe,
(jl_svec_t*)at, unionall, nreqargs,
cc, llvmcall, &ctx.emission_context);
for (size_t i = 0; i < nccallargs; i++) {
Expand Down Expand Up @@ -2158,11 +2162,16 @@ jl_cgval_t function_sig_t::emit_a_ccall(
}
}

OperandBundleDef OpBundle("jl_roots", gc_uses);
// Potentially we could drop `jl_roots(gc_uses)` in the presence of `gc-transition(gc_uses)`
SmallVector<OperandBundleDef, 2> bundles;
if (!gc_uses.empty())
bundles.push_back(OperandBundleDef("jl_roots", gc_uses));
if (gc_safe)
bundles.push_back(OperandBundleDef("gc-transition", ArrayRef<Value*> {}));
// the actual call
CallInst *ret = ctx.builder.CreateCall(functype, llvmf,
argvals,
ArrayRef<OperandBundleDef>(&OpBundle, gc_uses.empty() ? 0 : 1));
bundles);
((CallInst*)ret)->setAttributes(attributes);

if (cc != CallingConv::C)
Expand Down
4 changes: 2 additions & 2 deletions src/codegen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8041,7 +8041,7 @@ static jl_cgval_t emit_cfunction(jl_codectx_t &ctx, jl_value_t *output_type, con
if (rt != declrt && rt != (jl_value_t*)jl_any_type)
jl_temporary_root(ctx, rt);

function_sig_t sig("cfunction", lrt, rt, retboxed, argt, unionall_env, false, CallingConv::C, false, &ctx.emission_context);
function_sig_t sig("cfunction", lrt, rt, retboxed, false, argt, unionall_env, false, CallingConv::C, false, &ctx.emission_context);
assert(sig.fargt.size() + sig.sret == sig.fargt_sig.size());
if (!sig.err_msg.empty()) {
emit_error(ctx, sig.err_msg);
Expand Down Expand Up @@ -8181,7 +8181,7 @@ const char *jl_generate_ccallable(Module *llvmmod, void *sysimg_handle, jl_value
}
jl_value_t *err;
{ // scope block for sig
function_sig_t sig("cfunction", lcrt, crt, toboxed,
function_sig_t sig("cfunction", lcrt, crt, toboxed, false,
argtypes, NULL, false, CallingConv::C, false, &params);
if (sig.err_msg.empty()) {
if (sysimg_handle) {
Expand Down
14 changes: 5 additions & 9 deletions src/llvm-codegen-shared.h
Original file line number Diff line number Diff line change
Expand Up @@ -244,21 +244,17 @@ static inline llvm::Value *emit_gc_state_set(llvm::IRBuilder<> &builder, llvm::T
unsigned offset = offsetof(jl_tls_states_t, gc_state);
Value *gc_state = builder.CreateConstInBoundsGEP1_32(T_int8, ptls, offset, "gc_state");
if (old_state == nullptr) {
old_state = builder.CreateLoad(T_int8, gc_state);
old_state = builder.CreateLoad(T_int8, gc_state, "old_state");
cast<LoadInst>(old_state)->setOrdering(AtomicOrdering::Monotonic);
}
builder.CreateAlignedStore(state, gc_state, Align(sizeof(void*)))->setOrdering(AtomicOrdering::Release);
if (auto *C = dyn_cast<ConstantInt>(old_state))
if (C->isZero())
return old_state;
if (auto *C = dyn_cast<ConstantInt>(state))
if (!C->isZero())
return old_state;
if (auto *C2 = dyn_cast<ConstantInt>(state))
if (C->getZExtValue() == C2->getZExtValue())
return old_state;
BasicBlock *passBB = BasicBlock::Create(builder.getContext(), "safepoint", builder.GetInsertBlock()->getParent());
BasicBlock *exitBB = BasicBlock::Create(builder.getContext(), "after_safepoint", builder.GetInsertBlock()->getParent());
Constant *zero8 = ConstantInt::get(T_int8, 0);
builder.CreateCondBr(builder.CreateOr(builder.CreateICmpEQ(old_state, zero8), // if (!old_state || !state)
builder.CreateICmpEQ(state, zero8)),
builder.CreateCondBr(builder.CreateICmpEQ(old_state, state, "is_new_state"), // Safepoint whenever we change the GC state
passBB, exitBB);
builder.SetInsertPoint(passBB);
MDNode *tbaa = get_tbaa_const(builder.getContext());
Expand Down
49 changes: 40 additions & 9 deletions src/llvm-late-gc-lowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2181,16 +2181,47 @@ bool LateLowerGCFrame::CleanupIR(Function &F, State *S, bool *CFGModified) {
NewCall->copyMetadata(*CI);
CI->replaceAllUsesWith(NewCall);
UpdatePtrNumbering(CI, NewCall, S);
} else if (CI->arg_size() == CI->getNumOperands()) {
/* No operand bundle to lower */
++it;
continue;
} else {
CallInst *NewCall = CallInst::Create(CI, None, CI);
NewCall->takeName(CI);
NewCall->copyMetadata(*CI);
CI->replaceAllUsesWith(NewCall);
UpdatePtrNumbering(CI, NewCall, S);
SmallVector<OperandBundleDef,2> bundles;
CI->getOperandBundlesAsDefs(bundles);
bool gc_transition = false;
for (auto &bundle: bundles)
if (bundle.getTag() == "gc-transition")
gc_transition = true;

// In theory LLVM wants us to lower this using RewriteStatepointsForGC
if (gc_transition) {
// Insert the operations to switch to gc_safe if necessary.
IRBuilder<> builder(CI);
Value *pgcstack = getOrAddPGCstack(F);
assert(pgcstack);
// We dont use emit_state_set here because safepoints are unconditional for any code that reaches this
// We are basically guaranteed to go from gc_unsafe to gc_safe and back, and both transitions need a safepoint
// We also can't add any BBs here, so just avoiding the branches is good
Value *ptls = get_current_ptls_from_task(builder, get_current_task_from_pgcstack(builder, pgcstack), tbaa_gcframe);
unsigned offset = offsetof(jl_tls_states_t, gc_state);
Value *gc_state = builder.CreateConstInBoundsGEP1_32(Type::getInt8Ty(builder.getContext()), ptls, offset, "gc_state");
LoadInst *last_gc_state = builder.CreateAlignedLoad(Type::getInt8Ty(builder.getContext()), gc_state, Align(sizeof(void*)));
last_gc_state->setOrdering(AtomicOrdering::Monotonic);
builder.CreateAlignedStore(builder.getInt8(JL_GC_STATE_SAFE), gc_state, Align(sizeof(void*)))->setOrdering(AtomicOrdering::Release);
MDNode *tbaa = get_tbaa_const(builder.getContext());
emit_gc_safepoint(builder, T_size, ptls, tbaa, false);
builder.SetInsertPoint(CI->getNextNode());
builder.CreateAlignedStore(last_gc_state, gc_state, Align(sizeof(void*)))->setOrdering(AtomicOrdering::Release);
emit_gc_safepoint(builder, T_size, ptls, tbaa, false);
}
if (CI->arg_size() == CI->getNumOperands()) {
/* No operand bundle to lower */
++it;
continue;
} else {
// remove operand bundle
CallInst *NewCall = CallInst::Create(CI, None, CI);
NewCall->takeName(CI);
NewCall->copyMetadata(*CI);
CI->replaceAllUsesWith(NewCall);
UpdatePtrNumbering(CI, NewCall, S);
}
}
if (!CI->use_empty()) {
CI->replaceAllUsesWith(UndefValue::get(CI->getType()));
Expand Down
21 changes: 21 additions & 0 deletions src/llvm-pass-helpers.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,27 @@ llvm::CallInst *JuliaPassContext::getPGCstack(llvm::Function &F) const
return nullptr;
}

llvm::CallInst *JuliaPassContext::getOrAddPGCstack(llvm::Function &F)
{
if (pgcstack_getter || adoptthread_func)
for (auto &I : F.getEntryBlock()) {
if (CallInst *callInst = dyn_cast<CallInst>(&I)) {
Value *callee = callInst->getCalledOperand();
if ((pgcstack_getter && callee == pgcstack_getter) ||
(adoptthread_func && callee == adoptthread_func)) {
return callInst;
}
}
}
IRBuilder<> builder(&F.getEntryBlock().front());
if (pgcstack_getter)
return builder.CreateCall(pgcstack_getter);
auto FT = FunctionType::get(PointerType::get(F.getContext(), 0), false);
auto F2 = Function::Create(FT, Function::ExternalLinkage, "julia.get_pgcstack", F.getParent());
pgcstack_getter = F2;
return builder.CreateCall( F2);
}

llvm::Function *JuliaPassContext::getOrNull(
const jl_intrinsics::IntrinsicDescription &desc) const
{
Expand Down
5 changes: 4 additions & 1 deletion src/llvm-pass-helpers.h
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,10 @@ struct JuliaPassContext {
// point of the given function, if there exists such a call.
// Otherwise, `nullptr` is returned.
llvm::CallInst *getPGCstack(llvm::Function &F) const;

// Gets a call to the `julia.get_pgcstack' intrinsic in the entry
// point of the given function, if there exists such a call.
// Otherwise, creates a new call to the intrinsic
llvm::CallInst *getOrAddPGCstack(llvm::Function &F);
// Gets the intrinsic or well-known function that conforms to
// the given description if it exists in the module. If not,
// `nullptr` is returned.
Expand Down
7 changes: 6 additions & 1 deletion src/llvm-ptls.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -196,8 +196,13 @@ void LowerPTLS::fix_pgcstack_use(CallInst *pgcstack, Function *pgcstack_getter,
last_gc_state->addIncoming(prior, fastTerm->getParent());
for (auto &BB : *pgcstack->getParent()->getParent()) {
if (isa<ReturnInst>(BB.getTerminator())) {
// Don't use emit_gc_safe_leave here, as that introduces a new BB while iterating BBs
builder.SetInsertPoint(BB.getTerminator());
emit_gc_unsafe_leave(builder, T_size, get_current_ptls_from_task(builder, get_current_task_from_pgcstack(builder, phi), tbaa), last_gc_state, true);
Value *ptls = get_current_ptls_from_task(builder, get_current_task_from_pgcstack(builder, phi), tbaa_gcframe);
unsigned offset = offsetof(jl_tls_states_t, gc_state);
Value *gc_state = builder.CreateConstInBoundsGEP1_32(Type::getInt8Ty(builder.getContext()), ptls, offset, "gc_state");
builder.CreateAlignedStore(last_gc_state, gc_state, Align(sizeof(void*)))->setOrdering(AtomicOrdering::Release);
emit_gc_safepoint(builder, T_size, ptls, tbaa, true);
}
}
}
Expand Down
Loading

0 comments on commit 5da257d

Please sign in to comment.