Skip to content

Commit

Permalink
[wasm] Improve jiterpreter trace entry point selection heuristic (#82604
Browse files Browse the repository at this point in the history
)

This PR adjusts the jiterpreter's heuristic that decides where it's best to put entry points:
* Adds a requirement that entry points be at least a certain distance apart, since in some cases we can end up with trace entry points right next to each other, which isn't very useful and adds overhead. (Backwards branch targets are exempted from this so loops will still JIT properly).
* If we fail to create a trace exactly located at a backwards branch target, continue trying at blocks afterward. This should help in the rare case where the body of a loop begins with an unsupported instruction.
* When considering how long a trace actually is, we treat conditional aborts (like calls and throws) separately from ignored and supported instructions, so they don't count towards the overall size of the trace. These instructions aren't actually doing any useful work and if executed the trace will exit, so it's better not to consider them when deciding whether a trace is worth compiling.
This PR also manually inlines trace entry logic.
  • Loading branch information
kg authored Mar 4, 2023
1 parent f6d564e commit 507acb6
Show file tree
Hide file tree
Showing 4 changed files with 52 additions and 44 deletions.
33 changes: 2 additions & 31 deletions src/mono/mono/mini/interp/interp.c
Original file line number Diff line number Diff line change
Expand Up @@ -3798,31 +3798,6 @@ max_d (double lhs, double rhs)
return fmax (lhs, rhs);
}

#ifdef HOST_BROWSER
MONO_ALWAYS_INLINE static ptrdiff_t
mono_interp_tier_enter_jiterpreter (
JiterpreterThunk thunk, InterpFrame *frame, unsigned char *locals, ThreadContext *context,
const guint16 *ip
)
{
// g_assert(thunk);
ptrdiff_t offset = thunk(frame, locals);
/*
* Verify that the offset returned by the thunk is not total garbage
* FIXME: These constants might actually be too small since a method
* could have massive amounts of IL - maybe we should disable the jiterpreter
* for methods that big
*/
// g_assertf((offset >= -0xFFFFF) && (offset <= 0xFFFFF), "thunk returned an obviously invalid offset: %i", offset);
#ifdef ENABLE_EXPERIMENT_TIERED
if (offset < 0) {
mini_tiered_inc (frame->imethod->method, &frame->imethod->tiered_counter, 0);
}
#endif
return offset;
}
#endif // HOST_BROWSER

/*
* If CLAUSE_ARGS is non-null, start executing from it.
* The ERROR argument is used to avoid declaring an error object for every interp frame, its not used
Expand Down Expand Up @@ -7780,9 +7755,7 @@ MINT_IN_CASE(MINT_BRTRUE_I8_SP) ZEROP_SP(gint64, !=); MINT_IN_BREAK;
// now execute the trace
// this isn't important for performance, but it makes it easier to use the
// jiterpreter early in automated tests where code only runs once
offset = mono_interp_tier_enter_jiterpreter (
prepare_result, frame, locals, context, ip
);
offset = prepare_result(frame, locals);
ip = (guint16*) (((guint8*)ip) + offset);
break;
}
Expand All @@ -7795,9 +7768,7 @@ MINT_IN_CASE(MINT_BRTRUE_I8_SP) ZEROP_SP(gint64, !=); MINT_IN_BREAK;

MINT_IN_CASE(MINT_TIER_ENTER_JITERPRETER) {
JiterpreterThunk thunk = (void*)READ32(ip + 1);
ptrdiff_t offset = mono_interp_tier_enter_jiterpreter (
thunk, frame, locals, context, ip
);
ptrdiff_t offset = thunk(frame, locals);
ip = (guint16*) (((guint8*)ip) + offset);
MINT_IN_BREAK;
}
Expand Down
40 changes: 35 additions & 5 deletions src/mono/mono/mini/interp/jiterpreter.c
Original file line number Diff line number Diff line change
Expand Up @@ -613,6 +613,7 @@ mono_jiterp_cas_i64 (volatile int64_t *addr, int64_t *newVal, int64_t *expected,
#define TRACE_IGNORE -1
#define TRACE_CONTINUE 0
#define TRACE_ABORT 1
#define TRACE_CONDITIONAL_ABORT 2

/*
* This function provides an approximate answer for "will this instruction cause the jiterpreter
Expand Down Expand Up @@ -703,7 +704,7 @@ jiterp_should_abort_trace (InterpInst *ins, gboolean *inside_branch_block)
// Detect backwards branches
if (ins->info.target_bb->il_offset <= ins->il_offset) {
if (*inside_branch_block)
return TRACE_CONTINUE;
return TRACE_CONDITIONAL_ABORT;
else
return mono_opt_jiterpreter_backward_branches_enabled ? TRACE_CONTINUE : TRACE_ABORT;
}
Expand All @@ -714,7 +715,7 @@ jiterp_should_abort_trace (InterpInst *ins, gboolean *inside_branch_block)
case MINT_MONO_RETHROW:
case MINT_THROW:
if (*inside_branch_block)
return TRACE_CONTINUE;
return TRACE_CONDITIONAL_ABORT;

return TRACE_ABORT;

Expand Down Expand Up @@ -755,13 +756,13 @@ jiterp_should_abort_trace (InterpInst *ins, gboolean *inside_branch_block)
(opcode <= MINT_CALLI_NAT_FAST)
// (opcode <= MINT_JIT_CALL2)
)
return *inside_branch_block ? TRACE_CONTINUE : TRACE_ABORT;
return *inside_branch_block ? TRACE_CONDITIONAL_ABORT : TRACE_ABORT;
else if (
// returns
(opcode >= MINT_RET) &&
(opcode <= MINT_RET_U2)
)
return *inside_branch_block ? TRACE_CONTINUE : TRACE_ABORT;
return *inside_branch_block ? TRACE_CONDITIONAL_ABORT : TRACE_ABORT;
else if (
(opcode >= MINT_LDC_I4_M1) &&
(opcode <= MINT_LDC_R8)
Expand Down Expand Up @@ -834,6 +835,10 @@ should_generate_trace_here (InterpBasicBlock *bb) {
case TRACE_ABORT:
jiterpreter_abort_counts[ins->opcode]++;
return current_trace_length >= mono_opt_jiterpreter_minimum_trace_length;
case TRACE_CONDITIONAL_ABORT:
// FIXME: Stop traces that contain these early on, as long as we are relatively certain
// that these instructions will be hit (i.e. they are not unlikely branches)
break;
case TRACE_IGNORE:
break;
default:
Expand Down Expand Up @@ -925,6 +930,9 @@ jiterp_insert_entry_points (void *_imethod, void *_td)
if (!mono_opt_jiterpreter_traces_enabled)
return;

// Start with a high instruction counter so the distance check will pass
int instruction_count = mono_opt_jiterpreter_minimum_distance_between_traces;

for (InterpBasicBlock *bb = td->entry_bb; bb != NULL; bb = bb->next_bb) {
// Enter trace at top of functions
gboolean is_backwards_branch = FALSE,
Expand All @@ -941,7 +949,16 @@ jiterp_insert_entry_points (void *_imethod, void *_td)
// multiple times and waste some work. At present this is unavoidable because
// control flow means we can end up with two traces covering different subsets
// of the same method in order to handle loops and resuming
gboolean should_generate = enabled && should_generate_trace_here(bb);
gboolean should_generate = enabled &&
// Only insert a trace if the heuristic says this location will likely produce a long
// enough one to be worth it
should_generate_trace_here(bb) &&
// And don't insert another trace if we inserted one too recently, unless this
// is a backwards branch target
(
(instruction_count >= mono_opt_jiterpreter_minimum_distance_between_traces) ||
is_backwards_branch
);

if (mono_opt_jiterpreter_call_resume_enabled && bb->contains_call_instruction)
enter_at_next = TRUE;
Expand All @@ -957,12 +974,25 @@ jiterp_insert_entry_points (void *_imethod, void *_td)
InterpInst *ins = mono_jiterp_insert_ins (td, NULL, MINT_TIER_PREPARE_JITERPRETER);
memcpy(ins->data, &trace_index, sizeof (trace_index));

// Clear the instruction counter
instruction_count = 0;

// Note that we only clear enter_at_next here, after generating a trace.
// This means that the flag will stay set intentionally if we keep failing
// to generate traces, perhaps due to a string of small basic blocks
// or multiple call instructions.
enter_at_next = bb->contains_call_instruction;
} else if (is_backwards_branch && enabled && !should_generate) {
// We failed to start a trace at a backwards branch target, but that might just mean
// that the loop body starts with one or two unsupported opcodes, so it may be
// worthwhile to try again later
enter_at_next = TRUE;
}

// Increase the instruction counter. If we inserted an entry point at the top of this bb,
// the new instruction counter will be the number of instructions in the block, so if
// it's big enough we'll be able to insert another entry point right away.
instruction_count += bb->in_count;
}
}

Expand Down
2 changes: 2 additions & 0 deletions src/mono/mono/utils/options-def.h
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,8 @@ DEFINE_BOOL(jiterpreter_backward_branches_enabled, "jiterpreter-backward-branche
DEFINE_BOOL(jiterpreter_direct_jit_call, "jiterpreter-direct-jit-calls", TRUE, "Bypass gsharedvt wrappers when compiling JIT call wrappers")
// any trace that doesn't have at least this many meaningful (non-nop) opcodes in it will be rejected
DEFINE_INT(jiterpreter_minimum_trace_length, "jiterpreter-minimum-trace-length", 10, "Reject traces shorter than this number of meaningful opcodes")
// ensure that we don't create trace entry points too close together
DEFINE_INT(jiterpreter_minimum_distance_between_traces, "jiterpreter-minimum-distance-between-traces", 6, "Don't insert entry points closer together than this")
// once a trace entry point is inserted, we only actually JIT code for it once it's been hit this many times
DEFINE_INT(jiterpreter_minimum_trace_hit_count, "jiterpreter-minimum-trace-hit-count", 5000, "JIT trace entry points once they are hit this many times")
// After a do_jit_call call site is hit this many times, we will queue it to be jitted
Expand Down
21 changes: 13 additions & 8 deletions src/mono/wasm/runtime/jiterpreter-trace-generator.ts
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,7 @@ export function generate_wasm_body (
// because a backward branch might target a point in the middle of the trace
(isFirstInstruction && backwardBranchTable),
needsFallthroughEipUpdate = needsEipCheck && !isFirstInstruction;
let isDeadOpcode = false,
let isLowValueOpcode = false,
skipDregInvalidation = false;

// We record the offset of each backward branch we encounter, so that later branch
Expand Down Expand Up @@ -317,7 +317,7 @@ export function generate_wasm_body (
}

case MintOpcode.MINT_TIER_ENTER_JITERPRETER:
isDeadOpcode = true;
isLowValueOpcode = true;
// If we hit an enter opcode and we're not currently in a branch block
// or the enter opcode is the first opcode in a branch block, this likely
// indicates that we've reached a loop body that was already jitted before
Expand Down Expand Up @@ -363,7 +363,7 @@ export function generate_wasm_body (
case MintOpcode.MINT_SDB_BREAKPOINT:
case MintOpcode.MINT_SDB_INTR_LOC:
case MintOpcode.MINT_SDB_SEQ_POINT:
isDeadOpcode = true;
isLowValueOpcode = true;
break;

case MintOpcode.MINT_SAFEPOINT:
Expand Down Expand Up @@ -792,6 +792,7 @@ export function generate_wasm_body (
// to abort the entire trace if we have branch support enabled - the call
// might be infrequently hit and as a result it's worth it to keep going.
append_bailout(builder, ip, BailoutReason.Call);
isLowValueOpcode = true;
} else {
// We're in a block that executes unconditionally, and no branches have been
// executed before now so the trace will always need to bail out into the
Expand All @@ -815,6 +816,7 @@ export function generate_wasm_body (
? BailoutReason.CallDelegate
: BailoutReason.Call
);
isLowValueOpcode = true;
} else {
ip = abort;
}
Expand All @@ -828,6 +830,7 @@ export function generate_wasm_body (
// Otherwise, it may be in a branch that is unlikely to execute
if (builder.branchTargets.size > 0) {
append_bailout(builder, ip, BailoutReason.Throw);
isLowValueOpcode = true;
} else {
ip = abort;
}
Expand Down Expand Up @@ -1031,9 +1034,10 @@ export function generate_wasm_body (
(opcode <= MintOpcode.MINT_RET_I8_IMM)
)
) {
if ((builder.branchTargets.size > 0) || trapTraceErrors || builder.options.countBailouts)
if ((builder.branchTargets.size > 0) || trapTraceErrors || builder.options.countBailouts) {
append_bailout(builder, ip, BailoutReason.Return);
else
isLowValueOpcode = true;
} else
ip = abort;
} else if (
(opcode >= MintOpcode.MINT_LDC_I4_M1) &&
Expand Down Expand Up @@ -1102,9 +1106,10 @@ export function generate_wasm_body (
// types can be handled by emit_branch or emit_relop_branch,
// to only perform a conditional bailout
// complex safepoint branches, just generate a bailout
if (builder.branchTargets.size > 0)
if (builder.branchTargets.size > 0) {
append_bailout(builder, ip, BailoutReason.ComplexBranch);
else
isLowValueOpcode = true;
} else
ip = abort;
} else {
ip = abort;
Expand Down Expand Up @@ -1147,7 +1152,7 @@ export function generate_wasm_body (
builder.traceBuf.push(stmtText);
}

if (!isDeadOpcode)
if (!isLowValueOpcode)
result++;

ip += <any>(info[1] * 2);
Expand Down

0 comments on commit 507acb6

Please sign in to comment.