Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

don't use f_locals for foreign async generators #3112

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions newsfragments/3112.bugfix.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
Rework foreign async generator finalization to track async generator
ids rather than mutating ``ag_frame.f_locals``. This fixes an issue
with the previous implementation: locals' lifetimes will no longer be
extended by materialization in the ``ag_frame.f_locals`` dictionary that
the previous finalization dispatcher logic needed to access to do its work.
46 changes: 36 additions & 10 deletions src/trio/_core/_asyncgens.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import sys
import warnings
import weakref
from typing import TYPE_CHECKING, NoReturn
from typing import TYPE_CHECKING, NoReturn, TypeVar

import attrs

Expand All @@ -16,14 +16,31 @@
ASYNCGEN_LOGGER = logging.getLogger("trio.async_generator_errors")

if TYPE_CHECKING:
from collections.abc import Callable
from types import AsyncGeneratorType

from typing_extensions import ParamSpec

_P = ParamSpec("_P")

_WEAK_ASYNC_GEN_SET = weakref.WeakSet[AsyncGeneratorType[object, NoReturn]]
_ASYNC_GEN_SET = set[AsyncGeneratorType[object, NoReturn]]
else:
_WEAK_ASYNC_GEN_SET = weakref.WeakSet
_ASYNC_GEN_SET = set

_R = TypeVar("_R")


@_core.disable_ki_protection
def _call_without_ki_protection(
f: Callable[_P, _R],
/,
*args: _P.args,
**kwargs: _P.kwargs,
) -> _R:
return f(*args, **kwargs)


@attrs.define(eq=False)
class AsyncGenerators:
Expand All @@ -35,6 +52,11 @@ class AsyncGenerators:
# regular set so we don't have to deal with GC firing at
# unexpected times.
alive: _WEAK_ASYNC_GEN_SET | _ASYNC_GEN_SET = attrs.Factory(_WEAK_ASYNC_GEN_SET)
# The ids of foreign async generators are added to this set when first
# iterated. Usually it is not safe to refer to ids like this, but because
# we're using a finalizer we can ensure ids in this set do not outlive
# their async generator.
foreign: set[int] = attrs.Factory(set)

# This collects async generators that get garbage collected during
# the one-tick window between the system nursery closing and the
Expand All @@ -51,10 +73,10 @@ def firstiter(agen: AsyncGeneratorType[object, NoReturn]) -> None:
# An async generator first iterated outside of a Trio
# task doesn't belong to Trio. Probably we're in guest
# mode and the async generator belongs to our host.
# The locals dictionary is the only good place to
# A strong set of ids is one of the only good places to
# remember this fact, at least until
# https://bugs.python.org/issue40916 is implemented.
agen.ag_frame.f_locals["@trio_foreign_asyncgen"] = True
# https://github.com/python/cpython/issues/85093 is implemented.
self.foreign.add(id(agen))
if self.prev_hooks.firstiter is not None:
self.prev_hooks.firstiter(agen)

Expand All @@ -76,13 +98,16 @@ def finalize_in_trio_context(
# have hit it.
self.trailing_needs_finalize.add(agen)

@_core.enable_ki_protection
def finalizer(agen: AsyncGeneratorType[object, NoReturn]) -> None:
graingert marked this conversation as resolved.
Show resolved Hide resolved
agen_name = name_asyncgen(agen)
try:
is_ours = not agen.ag_frame.f_locals.get("@trio_foreign_asyncgen")
except AttributeError: # pragma: no cover
self.foreign.remove(id(agen))
except KeyError:
is_ours = True
else:
is_ours = False

agen_name = name_asyncgen(agen)
graingert marked this conversation as resolved.
Show resolved Hide resolved
if is_ours:
runner.entry_queue.run_sync_soon(
finalize_in_trio_context,
Expand All @@ -105,8 +130,9 @@ def finalizer(agen: AsyncGeneratorType[object, NoReturn]) -> None:
)
else:
# Not ours -> forward to the host loop's async generator finalizer
if self.prev_hooks.finalizer is not None:
self.prev_hooks.finalizer(agen)
finalizer = self.prev_hooks.finalizer
if finalizer is not None:
_call_without_ki_protection(finalizer, agen)
else:
# Host has no finalizer. Reimplement the default
# Python behavior with no hooks installed: throw in
Expand All @@ -116,7 +142,7 @@ def finalizer(agen: AsyncGeneratorType[object, NoReturn]) -> None:
try:
# If the next thing is a yield, this will raise RuntimeError
# which we allow to propagate
closer.send(None)
_call_without_ki_protection(closer.send, None)
except StopIteration:
pass
else:
Expand Down
64 changes: 53 additions & 11 deletions src/trio/_core/_tests/test_guest_mode.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

import asyncio
import contextlib
import contextvars
import queue
import signal
import socket
Expand All @@ -11,6 +10,7 @@
import time
import traceback
import warnings
import weakref
from collections.abc import AsyncGenerator, Awaitable, Callable, Sequence
from functools import partial
from math import inf
Expand All @@ -22,6 +22,7 @@
)

import pytest
import sniffio
from outcome import Outcome

import trio
Expand Down Expand Up @@ -234,7 +235,8 @@ async def trio_main(in_host: InHost) -> str:


def test_guest_mode_sniffio_integration() -> None:
from sniffio import current_async_library, thread_local as sniffio_library
current_async_library = sniffio.current_async_library
sniffio_library = sniffio.thread_local

async def trio_main(in_host: InHost) -> str:
async def synchronize() -> None:
Expand Down Expand Up @@ -458,9 +460,9 @@ def aiotrio_run(

async def aio_main() -> T:
nonlocal run_sync_soon_not_threadsafe
trio_done_fut = loop.create_future()
trio_done_fut: asyncio.Future[Outcome[T]] = loop.create_future()

def trio_done_callback(main_outcome: Outcome[object]) -> None:
def trio_done_callback(main_outcome: Outcome[T]) -> None:
print(f"trio_fn finished: {main_outcome!r}")
trio_done_fut.set_result(main_outcome)

Expand All @@ -479,9 +481,11 @@ def trio_done_callback(main_outcome: Outcome[object]) -> None:
strict_exception_groups=strict_exception_groups,
)

return (await trio_done_fut).unwrap() # type: ignore[no-any-return]
return (await trio_done_fut).unwrap()

try:
# can't use asyncio.run because that fails on Windows (3.8, x64, with
# Komodia LSP) and segfaults on Windows (3.9, x64, with Komodia LSP)
return loop.run_until_complete(aio_main())
finally:
loop.close()
Expand Down Expand Up @@ -655,8 +659,6 @@ async def trio_main(in_host: InHost) -> None:

@restore_unraisablehook()
def test_guest_mode_asyncgens() -> None:
import sniffio

record = set()

async def agen(label: str) -> AsyncGenerator[int, None]:
Expand All @@ -683,9 +685,49 @@ async def trio_main() -> None:

gc_collect_harder()

# Ensure we don't pollute the thread-level context if run under
# an asyncio without contextvars support (3.6)
context = contextvars.copy_context()
context.run(aiotrio_run, trio_main, host_uses_signal_set_wakeup_fd=True)
aiotrio_run(trio_main, host_uses_signal_set_wakeup_fd=True)

assert record == {("asyncio", "asyncio"), ("trio", "trio")}


@restore_unraisablehook()
def test_guest_mode_asyncgens_garbage_collection() -> None:
A5rocks marked this conversation as resolved.
Show resolved Hide resolved
record: set[tuple[str, str, bool]] = set()

async def agen(label: str) -> AsyncGenerator[int, None]:
class A:
pass

a = A()
a_wr = weakref.ref(a)
assert sniffio.current_async_library() == label
try:
yield 1
finally:
library = sniffio.current_async_library()
with contextlib.suppress(trio.Cancelled):
await sys.modules[library].sleep(0)

del a
if sys.implementation.name == "pypy":
gc_collect_harder()
A5rocks marked this conversation as resolved.
Show resolved Hide resolved

record.add((label, library, a_wr() is None))

async def iterate_in_aio() -> None:
await agen("asyncio").asend(None)

async def trio_main() -> None:
task = asyncio.ensure_future(iterate_in_aio())
done_evt = trio.Event()
task.add_done_callback(lambda _: done_evt.set())
with trio.fail_after(1):
await done_evt.wait()

await agen("trio").asend(None)

gc_collect_harder()

aiotrio_run(trio_main, host_uses_signal_set_wakeup_fd=True)

assert record == {("asyncio", "asyncio", True), ("trio", "trio", True)}
Loading