Skip to content

Commit

Permalink
Add missing context kwarg to _sentry_task_factory (#2267)
Browse files Browse the repository at this point in the history
* Add missing context kwargs to _sentry_task_factory

* Forward context to Task

* Update _sentry_task_factory type comment

* Added type annotations and unit tests

* Suppress linter error

* Fix import error in old Python versions

* Fix again linter error

* Fixed all mypy errors for real

* Fix tests for Python 3.7

* Add pytest.mark.forked to prevent threading test failure

---------

Co-authored-by: Daniel Szoke <[email protected]>
Co-authored-by: Daniel Szoke <[email protected]>
  • Loading branch information
3 people authored Aug 28, 2023
1 parent 6f49e75 commit 838368c
Show file tree
Hide file tree
Showing 2 changed files with 205 additions and 6 deletions.
11 changes: 6 additions & 5 deletions sentry_sdk/integrations/asyncio.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

if TYPE_CHECKING:
from typing import Any
from collections.abc import Coroutine

from sentry_sdk._types import ExcInfo

Expand All @@ -37,8 +38,8 @@ def patch_asyncio():
loop = asyncio.get_running_loop()
orig_task_factory = loop.get_task_factory()

def _sentry_task_factory(loop, coro):
# type: (Any, Any) -> Any
def _sentry_task_factory(loop, coro, **kwargs):
# type: (asyncio.AbstractEventLoop, Coroutine[Any, Any, Any], Any) -> asyncio.Future[Any]

async def _coro_creating_hub_and_span():
# type: () -> Any
Expand All @@ -56,7 +57,7 @@ async def _coro_creating_hub_and_span():

# Trying to use user set task factory (if there is one)
if orig_task_factory:
return orig_task_factory(loop, _coro_creating_hub_and_span())
return orig_task_factory(loop, _coro_creating_hub_and_span(), **kwargs)

# The default task factory in `asyncio` does not have its own function
# but is just a couple of lines in `asyncio.base_events.create_task()`
Expand All @@ -65,13 +66,13 @@ async def _coro_creating_hub_and_span():
# WARNING:
# If the default behavior of the task creation in asyncio changes,
# this will break!
task = Task(_coro_creating_hub_and_span(), loop=loop)
task = Task(_coro_creating_hub_and_span(), loop=loop, **kwargs)
if task._source_traceback: # type: ignore
del task._source_traceback[-1] # type: ignore

return task

loop.set_task_factory(_sentry_task_factory)
loop.set_task_factory(_sentry_task_factory) # type: ignore
except RuntimeError:
# When there is no running loop, we have nothing to patch.
pass
Expand Down
200 changes: 199 additions & 1 deletion tests/integrations/asyncio/test_asyncio_py3.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,35 @@
import asyncio
import inspect
import sys

import pytest

import sentry_sdk
from sentry_sdk.consts import OP
from sentry_sdk.integrations.asyncio import AsyncioIntegration
from sentry_sdk.integrations.asyncio import AsyncioIntegration, patch_asyncio

try:
from unittest.mock import MagicMock, patch
except ImportError:
from mock import MagicMock, patch

try:
from contextvars import Context, ContextVar
except ImportError:
pass # All tests will be skipped with incompatible versions


minimum_python_37 = pytest.mark.skipif(
sys.version_info < (3, 7), reason="Asyncio tests need Python >= 3.7"
)


minimum_python_311 = pytest.mark.skipif(
sys.version_info < (3, 11),
reason="Asyncio task context parameter was introduced in Python 3.11",
)


async def foo():
await asyncio.sleep(0.01)

Expand All @@ -33,6 +50,17 @@ def event_loop(request):
loop.close()


def get_sentry_task_factory(mock_get_running_loop):
"""
Patches (mocked) asyncio and gets the sentry_task_factory.
"""
mock_loop = mock_get_running_loop.return_value
patch_asyncio()
patched_factory = mock_loop.set_task_factory.call_args[0][0]

return patched_factory


@minimum_python_37
@pytest.mark.asyncio
async def test_create_task(
Expand Down Expand Up @@ -170,3 +198,173 @@ async def add(a, b):

result = await asyncio.create_task(add(1, 2))
assert result == 3, result


@minimum_python_311
@pytest.mark.asyncio
async def test_task_with_context(sentry_init):
"""
Integration test to ensure working context parameter in Python 3.11+
"""
sentry_init(
integrations=[
AsyncioIntegration(),
],
)

var = ContextVar("var")
var.set("original value")

async def change_value():
var.set("changed value")

async def retrieve_value():
return var.get()

# Create a context and run both tasks within the context
ctx = Context()
async with asyncio.TaskGroup() as tg:
tg.create_task(change_value(), context=ctx)
retrieve_task = tg.create_task(retrieve_value(), context=ctx)

assert retrieve_task.result() == "changed value"


@minimum_python_37
@patch("asyncio.get_running_loop")
def test_patch_asyncio(mock_get_running_loop):
"""
Test that the patch_asyncio function will patch the task factory.
"""
mock_loop = mock_get_running_loop.return_value

patch_asyncio()

assert mock_loop.set_task_factory.called

set_task_factory_args, _ = mock_loop.set_task_factory.call_args
assert len(set_task_factory_args) == 1

sentry_task_factory, *_ = set_task_factory_args
assert callable(sentry_task_factory)


@minimum_python_37
@pytest.mark.forked
@patch("asyncio.get_running_loop")
@patch("sentry_sdk.integrations.asyncio.Task")
def test_sentry_task_factory_no_factory(MockTask, mock_get_running_loop): # noqa: N803
mock_loop = mock_get_running_loop.return_value
mock_coro = MagicMock()

# Set the original task factory to None
mock_loop.get_task_factory.return_value = None

# Retieve sentry task factory (since it is an inner function within patch_asyncio)
sentry_task_factory = get_sentry_task_factory(mock_get_running_loop)

# The call we are testing
ret_val = sentry_task_factory(mock_loop, mock_coro)

assert MockTask.called
assert ret_val == MockTask.return_value

task_args, task_kwargs = MockTask.call_args
assert len(task_args) == 1

coro_param, *_ = task_args
assert inspect.iscoroutine(coro_param)

assert "loop" in task_kwargs
assert task_kwargs["loop"] == mock_loop


@minimum_python_37
@pytest.mark.forked
@patch("asyncio.get_running_loop")
def test_sentry_task_factory_with_factory(mock_get_running_loop):
mock_loop = mock_get_running_loop.return_value
mock_coro = MagicMock()

# The original task factory will be mocked out here, let's retrieve the value for later
orig_task_factory = mock_loop.get_task_factory.return_value

# Retieve sentry task factory (since it is an inner function within patch_asyncio)
sentry_task_factory = get_sentry_task_factory(mock_get_running_loop)

# The call we are testing
ret_val = sentry_task_factory(mock_loop, mock_coro)

assert orig_task_factory.called
assert ret_val == orig_task_factory.return_value

task_factory_args, _ = orig_task_factory.call_args
assert len(task_factory_args) == 2

loop_arg, coro_arg = task_factory_args
assert loop_arg == mock_loop
assert inspect.iscoroutine(coro_arg)


@minimum_python_311
@patch("asyncio.get_running_loop")
@patch("sentry_sdk.integrations.asyncio.Task")
def test_sentry_task_factory_context_no_factory(
MockTask, mock_get_running_loop # noqa: N803
):
mock_loop = mock_get_running_loop.return_value
mock_coro = MagicMock()
mock_context = MagicMock()

# Set the original task factory to None
mock_loop.get_task_factory.return_value = None

# Retieve sentry task factory (since it is an inner function within patch_asyncio)
sentry_task_factory = get_sentry_task_factory(mock_get_running_loop)

# The call we are testing
ret_val = sentry_task_factory(mock_loop, mock_coro, context=mock_context)

assert MockTask.called
assert ret_val == MockTask.return_value

task_args, task_kwargs = MockTask.call_args
assert len(task_args) == 1

coro_param, *_ = task_args
assert inspect.iscoroutine(coro_param)

assert "loop" in task_kwargs
assert task_kwargs["loop"] == mock_loop
assert "context" in task_kwargs
assert task_kwargs["context"] == mock_context


@minimum_python_311
@patch("asyncio.get_running_loop")
def test_sentry_task_factory_context_with_factory(mock_get_running_loop):
mock_loop = mock_get_running_loop.return_value
mock_coro = MagicMock()
mock_context = MagicMock()

# The original task factory will be mocked out here, let's retrieve the value for later
orig_task_factory = mock_loop.get_task_factory.return_value

# Retieve sentry task factory (since it is an inner function within patch_asyncio)
sentry_task_factory = get_sentry_task_factory(mock_get_running_loop)

# The call we are testing
ret_val = sentry_task_factory(mock_loop, mock_coro, context=mock_context)

assert orig_task_factory.called
assert ret_val == orig_task_factory.return_value

task_factory_args, task_factory_kwargs = orig_task_factory.call_args
assert len(task_factory_args) == 2

loop_arg, coro_arg = task_factory_args
assert loop_arg == mock_loop
assert inspect.iscoroutine(coro_arg)

assert "context" in task_factory_kwargs
assert task_factory_kwargs["context"] == mock_context

0 comments on commit 838368c

Please sign in to comment.