diff --git a/src/trio/_core/_ki.py b/src/trio/_core/_ki.py index bb881120a..33e363f72 100644 --- a/src/trio/_core/_ki.py +++ b/src/trio/_core/_ki.py @@ -9,6 +9,7 @@ import attrs from .._util import is_main_thread +from ._run_context import GLOBAL_RUN_CONTEXT if TYPE_CHECKING: import types @@ -170,6 +171,16 @@ def legacy_isasyncgenfunction( # NB: according to the signal.signal docs, 'frame' can be None on entry to # this function: def ki_protection_enabled(frame: types.FrameType | None) -> bool: + try: + task = GLOBAL_RUN_CONTEXT.task + except AttributeError: + task_ki_protected = False + task_frame = None + else: + task_ki_protected = task._ki_protected + task_frame = task.coro.cr_frame + del task + while frame is not None: try: v = _CODE_KI_PROTECTION_STATUS_WMAP[frame.f_code] @@ -179,6 +190,8 @@ def ki_protection_enabled(frame: types.FrameType | None) -> bool: return bool(v) if frame.f_code.co_name == "__del__": return True + if frame is task_frame: + return task_ki_protected frame = frame.f_back return True diff --git a/src/trio/_core/_run.py b/src/trio/_core/_run.py index 0c2c3477c..83415e665 100644 --- a/src/trio/_core/_run.py +++ b/src/trio/_core/_run.py @@ -7,7 +7,6 @@ import random import select import sys -import threading import warnings from collections import deque from contextlib import AbstractAsyncContextManager, contextmanager, suppress @@ -39,8 +38,9 @@ from ._entry_queue import EntryQueue, TrioToken from ._exceptions import Cancelled, RunFinishedError, TrioInternalError from ._instrumentation import Instruments -from ._ki import KIManager, disable_ki_protection, enable_ki_protection +from ._ki import KIManager, enable_ki_protection from ._parking_lot import GLOBAL_PARKING_LOT_BREAKER +from ._run_context import GLOBAL_RUN_CONTEXT as GLOBAL_RUN_CONTEXT from ._thread_cache import start_thread_soon from ._traps import ( Abort, @@ -1559,14 +1559,6 @@ def raise_cancel() -> NoReturn: ################################################################ -class RunContext(threading.local): - runner: Runner - task: Task - - -GLOBAL_RUN_CONTEXT: Final = RunContext() - - @attrs.frozen class RunStatistics: """An object containing run-loop-level debugging information. @@ -1670,22 +1662,6 @@ def in_main_thread() -> None: start_thread_soon(get_events, deliver) -@enable_ki_protection -def run_with_ki_protection_enabled(f: Callable[[T], RetT], v: T) -> RetT: - try: - return f(v) - finally: - del v # for the case where f is coro.throw() and v is a (Base)Exception - - -@disable_ki_protection -def run_with_ki_protection_disabled(f: Callable[[T], RetT], v: T) -> RetT: - try: - return f(v) - finally: - del v # for the case where f is coro.throw() and v is a (Base)Exception - - @attrs.define(eq=False) class Runner: clock: Clock @@ -2730,11 +2706,6 @@ def unrolled_run( next_send_fn = task._next_send_fn next_send = task._next_send - run_with = ( - run_with_ki_protection_enabled - if task._ki_protected - else run_with_ki_protection_disabled - ) task._next_send_fn = task._next_send = None final_outcome: Outcome[Any] | None = None try: @@ -2747,17 +2718,16 @@ def unrolled_run( # https://github.com/python/cpython/issues/108668 # So now we send in the Outcome object and unwrap it on the # other side. - msg = task.context.run(run_with, next_send_fn, next_send) + msg = task.context.run(next_send_fn, next_send) except StopIteration as stop_iteration: final_outcome = Value(stop_iteration.value) except BaseException as task_exc: # Store for later, removing uninteresting top frames: 1 # frame we always remove, because it's this function - # another is the run_with # catching it, and then in addition we remove however many # more Context.run adds. tb = task_exc.__traceback__ - for _ in range(2 + CONTEXT_RUN_TB_FRAMES): + for _ in range(1 + CONTEXT_RUN_TB_FRAMES): if tb is not None: # pragma: no branch tb = tb.tb_next final_outcome = Error(task_exc.with_traceback(tb)) diff --git a/src/trio/_core/_run_context.py b/src/trio/_core/_run_context.py new file mode 100644 index 000000000..085bff9a3 --- /dev/null +++ b/src/trio/_core/_run_context.py @@ -0,0 +1,15 @@ +from __future__ import annotations + +import threading +from typing import TYPE_CHECKING, Final + +if TYPE_CHECKING: + from ._run import Runner, Task + + +class RunContext(threading.local): + runner: Runner + task: Task + + +GLOBAL_RUN_CONTEXT: Final = RunContext()