Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix deprecation warnings related to asyncio #364

Merged
merged 6 commits into from
Jan 12, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 14 additions & 10 deletions tests/ext/aiobotocore/test_aiobotocore.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,17 +12,19 @@


@pytest.fixture(scope='function')
def recorder(loop):
def recorder(event_loop):
"""
Clean up before and after each test run
"""
xray_recorder.configure(service='test', sampling=False, context=AsyncContext(loop=loop))
xray_recorder.configure(
service='test', sampling=False, context=AsyncContext(loop=event_loop)
)
xray_recorder.clear_trace_entities()
yield xray_recorder
xray_recorder.clear_trace_entities()


async def test_describe_table(loop, recorder):
async def test_describe_table(event_loop, recorder):
segment = recorder.begin_segment('name')

req_id = '1234'
Expand All @@ -45,7 +47,7 @@ async def test_describe_table(loop, recorder):
assert aws_meta['operation'] == 'DescribeTable'


async def test_s3_parameter_capture(loop, recorder):
async def test_s3_parameter_capture(event_loop, recorder):
segment = recorder.begin_segment('name')

bucket_name = 'mybucket'
Expand All @@ -70,7 +72,7 @@ async def test_s3_parameter_capture(loop, recorder):
assert aws_meta['operation'] == 'GetObject'


async def test_list_parameter_counting(loop, recorder):
async def test_list_parameter_counting(event_loop, recorder):
"""
Test special parameters that have shape of list are recorded
as count based on `para_whitelist.json`
Expand Down Expand Up @@ -103,7 +105,7 @@ async def test_list_parameter_counting(loop, recorder):
assert aws_meta['queue_name_prefix'] == queue_name_prefix


async def test_map_parameter_grouping(loop, recorder):
async def test_map_parameter_grouping(event_loop, recorder):
"""
Test special parameters that have shape of map are recorded
as a list of keys based on `para_whitelist.json`
Expand Down Expand Up @@ -131,9 +133,10 @@ async def test_map_parameter_grouping(loop, recorder):
assert sorted(aws_meta['table_names']) == ['table1', 'table2']


async def test_context_missing_not_swallow_return(loop, recorder):
async def test_context_missing_not_swallow_return(event_loop, recorder):
xray_recorder.configure(service='test', sampling=False,
context=AsyncContext(loop=loop), context_missing='LOG_ERROR')
context=AsyncContext(loop=event_loop),
context_missing='LOG_ERROR')

response = {'ResponseMetadata': {'RequestId': '1234', 'HTTPStatusCode': 403}}

Expand All @@ -146,9 +149,10 @@ async def test_context_missing_not_swallow_return(loop, recorder):
assert actual_resp == response


async def test_context_missing_not_suppress_exception(loop, recorder):
async def test_context_missing_not_suppress_exception(event_loop, recorder):
xray_recorder.configure(service='test', sampling=False,
context=AsyncContext(loop=loop), context_missing='LOG_ERROR')
context=AsyncContext(loop=event_loop),
context_missing='LOG_ERROR')

session = get_session()
async with session.create_client('dynamodb', region_name='eu-west-2') as client:
Expand Down
66 changes: 35 additions & 31 deletions tests/ext/aiohttp/test_middleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
Expects pytest-aiohttp
"""
import asyncio
import sys
from aws_xray_sdk import global_sdk_config
try:
from unittest.mock import patch
Expand Down Expand Up @@ -84,7 +85,10 @@ async def handle_delay(self, request: web.Request) -> web.Response:
"""
Handle /delay request
"""
await asyncio.sleep(0.3, loop=self._loop)
if sys.version_info >= (3, 8):
await asyncio.sleep(0.3)
else:
await asyncio.sleep(0.3, loop=self._loop)
return web.Response(text="ok")

def get_app(self) -> web.Application:
Expand Down Expand Up @@ -120,15 +124,15 @@ def recorder(loop):
patcher.stop()


async def test_ok(test_client, loop, recorder):
async def test_ok(aiohttp_client, loop, recorder):
"""
Test a normal response

:param test_client: AioHttp test client fixture
:param aiohttp_client: AioHttp test client fixture
:param loop: Eventloop fixture
:param recorder: X-Ray recorder fixture
"""
client = await test_client(ServerTest.app(loop=loop))
client = await aiohttp_client(ServerTest.app(loop=loop))

resp = await client.get('/')
assert resp.status == 200
Expand All @@ -144,15 +148,15 @@ async def test_ok(test_client, loop, recorder):
assert response['status'] == 200


async def test_ok_x_forwarded_for(test_client, loop, recorder):
async def test_ok_x_forwarded_for(aiohttp_client, loop, recorder):
"""
Test a normal response with x_forwarded_for headers

:param test_client: AioHttp test client fixture
:param aiohttp_client: AioHttp test client fixture
:param loop: Eventloop fixture
:param recorder: X-Ray recorder fixture
"""
client = await test_client(ServerTest.app(loop=loop))
client = await aiohttp_client(ServerTest.app(loop=loop))

resp = await client.get('/', headers={'X-Forwarded-For': 'foo'})
assert resp.status == 200
Expand All @@ -162,15 +166,15 @@ async def test_ok_x_forwarded_for(test_client, loop, recorder):
assert segment.http['request']['x_forwarded_for']


async def test_ok_content_length(test_client, loop, recorder):
async def test_ok_content_length(aiohttp_client, loop, recorder):
"""
Test a normal response with content length as response header

:param test_client: AioHttp test client fixture
:param aiohttp_client: AioHttp test client fixture
:param loop: Eventloop fixture
:param recorder: X-Ray recorder fixture
"""
client = await test_client(ServerTest.app(loop=loop))
client = await aiohttp_client(ServerTest.app(loop=loop))

resp = await client.get('/?content_length=100')
assert resp.status == 200
Expand All @@ -179,15 +183,15 @@ async def test_ok_content_length(test_client, loop, recorder):
assert segment.http['response']['content_length'] == 100


async def test_error(test_client, loop, recorder):
async def test_error(aiohttp_client, loop, recorder):
"""
Test a 4XX response

:param test_client: AioHttp test client fixture
:param aiohttp_client: AioHttp test client fixture
:param loop: Eventloop fixture
:param recorder: X-Ray recorder fixture
"""
client = await test_client(ServerTest.app(loop=loop))
client = await aiohttp_client(ServerTest.app(loop=loop))

resp = await client.get('/error')
assert resp.status == 404
Expand All @@ -204,15 +208,15 @@ async def test_error(test_client, loop, recorder):
assert response['status'] == 404


async def test_exception(test_client, loop, recorder):
async def test_exception(aiohttp_client, loop, recorder):
"""
Test handling an exception

:param test_client: AioHttp test client fixture
:param aiohttp_client: AioHttp test client fixture
:param loop: Eventloop fixture
:param recorder: X-Ray recorder fixture
"""
client = await test_client(ServerTest.app(loop=loop))
client = await aiohttp_client(ServerTest.app(loop=loop))

with pytest.raises(Exception):
await client.get('/exception')
Expand All @@ -231,15 +235,15 @@ async def test_exception(test_client, loop, recorder):
assert exception.type == 'CancelledError'


async def test_unhauthorized(test_client, loop, recorder):
async def test_unhauthorized(aiohttp_client, loop, recorder):
"""
Test a 401 response

:param test_client: AioHttp test client fixture
:param aiohttp_client: AioHttp test client fixture
:param loop: Eventloop fixture
:param recorder: X-Ray recorder fixture
"""
client = await test_client(ServerTest.app(loop=loop))
client = await aiohttp_client(ServerTest.app(loop=loop))

resp = await client.get('/unauthorized')
assert resp.status == 401
Expand All @@ -256,8 +260,8 @@ async def test_unhauthorized(test_client, loop, recorder):
assert response['status'] == 401


async def test_response_trace_header(test_client, loop, recorder):
client = await test_client(ServerTest.app(loop=loop))
async def test_response_trace_header(aiohttp_client, loop, recorder):
client = await aiohttp_client(ServerTest.app(loop=loop))
resp = await client.get('/')
xray_header = resp.headers[http.XRAY_HEADER]
segment = recorder.emitter.pop()
Expand All @@ -266,42 +270,42 @@ async def test_response_trace_header(test_client, loop, recorder):
assert expected in xray_header


async def test_concurrent(test_client, loop, recorder):
async def test_concurrent(aiohttp_client, loop, recorder):
"""
Test multiple concurrent requests

:param test_client: AioHttp test client fixture
:param aiohttp_client: AioHttp test client fixture
:param loop: Eventloop fixture
:param recorder: X-Ray recorder fixture
"""
client = await test_client(ServerTest.app(loop=loop))
client = await aiohttp_client(ServerTest.app(loop=loop))

recorder.emitter = CustomStubbedEmitter()

async def get_delay():
resp = await client.get('/delay')
assert resp.status == 200

await asyncio.wait([get_delay(), get_delay(), get_delay(),
get_delay(), get_delay(), get_delay(),
get_delay(), get_delay(), get_delay()],
loop=loop)
if sys.version_info >= (3, 8):
await asyncio.wait([loop.create_task(get_delay()) for i in range(9)])
else:
await asyncio.wait([loop.create_task(get_delay()) for i in range(9)], loop=loop)

# Ensure all ID's are different
ids = [item.id for item in recorder.emitter.local]
assert len(ids) == len(set(ids))


async def test_disabled_sdk(test_client, loop, recorder):
async def test_disabled_sdk(aiohttp_client, loop, recorder):
"""
Test a normal response when the SDK is disabled.

:param test_client: AioHttp test client fixture
:param aiohttp_client: AioHttp test client fixture
:param loop: Eventloop fixture
:param recorder: X-Ray recorder fixture
"""
global_sdk_config.set_sdk_enabled(False)
client = await test_client(ServerTest.app(loop=loop))
client = await aiohttp_client(ServerTest.app(loop=loop))

resp = await client.get('/')
assert resp.status == 200
Expand Down
24 changes: 18 additions & 6 deletions tests/test_async_local_storage.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
import asyncio
import random
import sys

from aws_xray_sdk.core.async_context import TaskLocalStorage


def test_localstorage_isolation(loop):
local_storage = TaskLocalStorage(loop=loop)
def test_localstorage_isolation(event_loop):
local_storage = TaskLocalStorage(loop=event_loop)

async def _test():
"""
Expand All @@ -19,7 +20,10 @@ async def _test():
random_int = random.random()
local_storage.randint = random_int

await asyncio.sleep(0.0, loop=loop)
if sys.version_info >= (3, 8):
await asyncio.sleep(0.0)
else:
await asyncio.sleep(0.0, loop=event_loop)

current_random_int = local_storage.randint
assert random_int == current_random_int
Expand All @@ -29,9 +33,17 @@ async def _test():
return False

# Run loads of concurrent tasks
results = loop.run_until_complete(
asyncio.wait([_test() for _ in range(0, 100)], loop=loop)
)
if sys.version_info >= (3, 8):
results = event_loop.run_until_complete(
asyncio.wait([event_loop.create_task(_test()) for _ in range(0, 100)])
)
else:
results = event_loop.run_until_complete(
asyncio.wait(
[event_loop.create_task(_test()) for _ in range(0, 100)],
loop=event_loop,
)
)
results = [item.result() for item in results[0]]

# Double check all is good
Expand Down
18 changes: 12 additions & 6 deletions tests/test_async_recorder.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,10 @@ async def async_method():
await async_method2()


async def test_capture(loop):
xray_recorder.configure(service='test', sampling=False, context=AsyncContext(loop=loop))
async def test_capture(event_loop):
xray_recorder.configure(
service='test', sampling=False, context=AsyncContext(loop=event_loop)
)

segment = xray_recorder.begin_segment('name')

Expand All @@ -44,8 +46,10 @@ async def test_capture(loop):
assert platform.python_implementation() == service.get('runtime')
assert platform.python_version() == service.get('runtime_version')

async def test_concurrent_calls(loop):
xray_recorder.configure(service='test', sampling=False, context=AsyncContext(loop=loop))
async def test_concurrent_calls(event_loop):
xray_recorder.configure(
service='test', sampling=False, context=AsyncContext(loop=event_loop)
)
async with xray_recorder.in_segment_async('segment') as segment:
global counter
counter = 0
Expand All @@ -67,8 +71,10 @@ async def assert_task():
assert subseg_parent_id == segment.id


async def test_async_context_managers(loop):
xray_recorder.configure(service='test', sampling=False, context=AsyncContext(loop=loop))
async def test_async_context_managers(event_loop):
xray_recorder.configure(
service='test', sampling=False, context=AsyncContext(loop=event_loop)
)

async with xray_recorder.in_segment_async('segment') as segment:
async with xray_recorder.capture_async('aio_capture') as subsegment:
Expand Down
12 changes: 4 additions & 8 deletions tox.ini
Original file line number Diff line number Diff line change
Expand Up @@ -68,17 +68,13 @@ deps =
py34: typing >= 3.7.4.3

; Python 3.5+ only deps
; for some reason pytest-aiohttp is required for "core" tests
; TODO: find and replace by more direct dependency
py{35,36,37,38,39}: pytest-aiohttp
py{35,36,37,38,39}: pytest-asyncio

ext-aiobotocore: aiobotocore >= 0.10.0
ext-aiobotocore: pytest-aiohttp
ext-aiobotocore: pytest-asyncio

ext-aiohttp: aiohttp >= 3.0.0
; Breaking change where the `test_client` fixture was renamed.
; Also, the stable version is only supported for Python 3.7+
ext-aiohttp: pytest-aiohttp < 1.0.0
ext-aiohttp: aiohttp >= 3.3.0
ext-aiohttp: pytest-aiohttp

ext-httpx: httpx >= 0.20
ext-httpx: pytest-asyncio >= 0.19
Expand Down