Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add support for async contextvars #153

Closed
wants to merge 7 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 43 additions & 3 deletions pytest_asyncio/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
import functools
import inspect
import socket
from contextvars import Context, copy_context
from asyncio import coroutines

import pytest
try:
Expand Down Expand Up @@ -102,6 +104,7 @@ async def async_finalizer():
return asyncio.get_event_loop().run_until_complete(setup())

fixturedef.func = wrapper

elif inspect.iscoroutinefunction(fixturedef.func):
coro = fixturedef.func

Expand Down Expand Up @@ -154,9 +157,12 @@ def inner(**kwargs):


def pytest_runtest_setup(item):
if 'asyncio' in item.keywords and 'event_loop' not in item.fixturenames:
# inject an event loop fixture for all async tests
item.fixturenames.append('event_loop')
if 'asyncio' in item.keywords:
if 'event_loop' not in item.fixturenames:
# inject an event loop fixture for all async tests
item.fixturenames.append('event_loop')
if 'context' not in item.fixturenames:
item.fixturenames.append('context')
if item.get_closest_marker("asyncio") is not None \
and not getattr(item.obj, 'hypothesis', False) \
and getattr(item.obj, 'is_hypothesis_test', False):
Expand All @@ -165,6 +171,40 @@ def pytest_runtest_setup(item):
'only works with Hypothesis 3.64.0 or later.' % item
)

class Task(asyncio.tasks._PyTask):
def __init__(self, coro, *, loop=None, name=None, context=None):
asyncio.futures._PyFuture.__init__(self, loop=loop)
if self._source_traceback:
del self._source_traceback[-1]
if not coroutines.iscoroutine(coro):
# raise after Future.__init__(), attrs are required for __del__
# prevent logging for pending task in __del__
self._log_destroy_pending = False
raise TypeError(f"a coroutine was expected, got {coro!r}")

if name is None:
self._name = f'Task-{asyncio.tasks._task_name_counter()}'
else:
self._name = str(name)

self._must_cancel = False
self._fut_waiter = None
self._coro = coro
self._context = context if context is not None else copy_context()

self._loop.call_soon(self.__step, context=self._context)
asyncio._register_task(self)


@pytest.fixture
def context(event_loop, request):
"""Create an empty context for the async test case and it's async fixtures."""
context = Context()
def taskfactory(loop, coro):
return Task(coro, loop=loop, context=context)
event_loop.set_task_factory(taskfactory)
return context


@pytest.fixture
def event_loop(request):
Expand Down
21 changes: 21 additions & 0 deletions tests/test_contextvars.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
"""Quick'n'dirty unit tests for provided fixtures and markers."""
import asyncio
import pytest

import pytest_asyncio.plugin

from contextvars import ContextVar


ctxvar = ContextVar('ctxvar')


@pytest.fixture
async def set_some_context(context):
ctxvar.set('quarantine is fun')


@pytest.mark.asyncio
async def test_test(set_some_context):
# print ("Context in test:", list(context.items()))
assert ctxvar.get() == 'quarantine is fun'