Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add session.close(), remove internal _context #69

Merged
merged 1 commit into from
Apr 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 54 additions & 5 deletions tests/test_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -388,7 +388,7 @@ async def test_semaphore(client_cls, get_method, iter_method, mockserver):


@pytest.mark.asyncio
async def test_session(mockserver):
async def test_session_context_manager(mockserver):
client = AsyncZyteAPI(api_key="a", api_url=mockserver.urljoin("/"))
queries = [
{"url": "https://a.example", "httpResponseBody": True},
Expand All @@ -408,24 +408,73 @@ async def test_session(mockserver):
]
actual_results = []
async with client.session() as session:
assert session._context.connector.limit == client.n_conn
assert session._session.connector.limit == client.n_conn
actual_results.append(await session.get(queries[0]))
for future in session.iter(queries[1:]):
try:
result = await future
except Exception as e:
result = e
actual_results.append(result)
aiohttp_session = session._context
aiohttp_session = session._session
assert not aiohttp_session.closed
assert aiohttp_session.closed
assert session._context is None

with pytest.raises(RuntimeError):
await session.get(queries[0])

with pytest.raises(RuntimeError):
session.iter(queries[1:])
future = next(iter(session.iter(queries[1:])))
await future

assert len(actual_results) == len(expected_results)
for actual_result in actual_results:
if isinstance(actual_result, Exception):
assert Exception in expected_results
else:
assert actual_result in expected_results


@pytest.mark.asyncio
async def test_session_no_context_manager(mockserver):
client = AsyncZyteAPI(api_key="a", api_url=mockserver.urljoin("/"))
queries = [
{"url": "https://a.example", "httpResponseBody": True},
{"url": "https://exception.example", "httpResponseBody": True},
{"url": "https://b.example", "httpResponseBody": True},
]
expected_results = [
{
"url": "https://a.example",
"httpResponseBody": "PGh0bWw+PGJvZHk+SGVsbG88aDE+V29ybGQhPC9oMT48L2JvZHk+PC9odG1sPg==",
},
Exception,
{
"url": "https://b.example",
"httpResponseBody": "PGh0bWw+PGJvZHk+SGVsbG88aDE+V29ybGQhPC9oMT48L2JvZHk+PC9odG1sPg==",
},
]
actual_results = []
session = client.session()
assert session._session.connector.limit == client.n_conn
actual_results.append(await session.get(queries[0]))
for future in session.iter(queries[1:]):
try:
result = await future
except Exception as e:
result = e
actual_results.append(result)
aiohttp_session = session._session
assert not aiohttp_session.closed
await session.close()
assert aiohttp_session.closed

with pytest.raises(RuntimeError):
await session.get(queries[0])

with pytest.raises(RuntimeError):
future = next(iter(session.iter(queries[1:])))
await future

assert len(actual_results) == len(expected_results)
for actual_result in actual_results:
Expand Down
50 changes: 45 additions & 5 deletions tests/test_sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def test_semaphore(mockserver):
assert client._async_client._semaphore.__aexit__.call_count == len(queries)


def test_session(mockserver):
def test_session_context_manager(mockserver):
client = ZyteAPI(api_key="a", api_url=mockserver.urljoin("/"))
queries = [
{"url": "https://a.example", "httpResponseBody": True},
Expand All @@ -87,20 +87,60 @@ def test_session(mockserver):
]
actual_results = []
with client.session() as session:
assert session._context.connector.limit == client._async_client.n_conn
assert session._session.connector.limit == client._async_client.n_conn
actual_results.append(session.get(queries[0]))
for result in session.iter(queries[1:]):
actual_results.append(result)
aiohttp_session = session._context
aiohttp_session = session._session
assert not aiohttp_session.closed
assert aiohttp_session.closed
assert session._context is None

with pytest.raises(RuntimeError):
session.get(queries[0])

assert isinstance(next(iter(session.iter(queries[1:]))), RuntimeError)

assert len(actual_results) == len(expected_results)
for actual_result in actual_results:
if isinstance(actual_result, Exception):
assert Exception in expected_results
else:
assert actual_result in expected_results


def test_session_no_context_manager(mockserver):
client = ZyteAPI(api_key="a", api_url=mockserver.urljoin("/"))
queries = [
{"url": "https://a.example", "httpResponseBody": True},
{"url": "https://exception.example", "httpResponseBody": True},
{"url": "https://b.example", "httpResponseBody": True},
]
expected_results = [
{
"url": "https://a.example",
"httpResponseBody": "PGh0bWw+PGJvZHk+SGVsbG88aDE+V29ybGQhPC9oMT48L2JvZHk+PC9odG1sPg==",
},
Exception,
{
"url": "https://b.example",
"httpResponseBody": "PGh0bWw+PGJvZHk+SGVsbG88aDE+V29ybGQhPC9oMT48L2JvZHk+PC9odG1sPg==",
},
]
actual_results = []
session = client.session()
assert session._session.connector.limit == client._async_client.n_conn
actual_results.append(session.get(queries[0]))
for result in session.iter(queries[1:]):
actual_results.append(result)
aiohttp_session = session._session
assert not aiohttp_session.closed
session.close()
assert aiohttp_session.closed

with pytest.raises(RuntimeError):
session.iter(queries[1:])
session.get(queries[0])

assert isinstance(next(iter(session.iter(queries[1:]))), RuntimeError)

assert len(actual_results) == len(expected_results)
for actual_result in actual_results:
Expand Down
3 changes: 2 additions & 1 deletion tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,8 @@ def test_process_query_bytes():
_process_query({"url": b"https://example.com"})


def test_deprecated_create_session():
@pytest.mark.asyncio # https://github.com/aio-libs/aiohttp/pull/1468
async def test_deprecated_create_session():
from zyte_api.aio.client import create_session as _create_session

with pytest.warns(
Expand Down
34 changes: 19 additions & 15 deletions zyte_api/_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,23 +33,15 @@ class _AsyncSession:
def __init__(self, client, **session_kwargs):
self._client = client
self._session = create_session(client.n_conn, **session_kwargs)
self._context = None

async def __aenter__(self):
self._context = await self._session.__aenter__()
return self

async def __aexit__(self, *exc_info):
result = await self._context.__aexit__(*exc_info)
self._context = None
return result
await self._session.close()

def _check_context(self):
if self._context is None:
raise RuntimeError(
"Attempt to use session method on a session either not opened "
"or already closed."
)
async def close(self):
await self._session.close()

async def get(
self,
Expand All @@ -59,13 +51,12 @@ async def get(
handle_retries=True,
retrying: Optional[AsyncRetrying] = None,
):
self._check_context()
return await self._client.get(
query=query,
endpoint=endpoint,
handle_retries=handle_retries,
retrying=retrying,
session=self._context,
session=self._session,
)

def iter(
Expand All @@ -76,11 +67,10 @@ def iter(
handle_retries=True,
retrying: Optional[AsyncRetrying] = None,
) -> Iterator[Future]:
self._check_context()
return self._client.iter(
queries=queries,
endpoint=endpoint,
session=self._context,
session=self._session,
handle_retries=handle_retries,
retrying=retrying,
)
Expand Down Expand Up @@ -208,4 +198,18 @@ def _request(query):
return asyncio.as_completed([_request(query) for query in queries])

def session(self, **kwargs):
"""Asynchronous equivalent to :meth:`ZyteAPI.session`.

You do not need to use :meth:`~AsyncZyteAPI.session` as an async
context manager as long as you await ``close()`` on the object it
returns when you are done:

.. code-block:: python

session = client.session()
try:
...
finally:
await session.close()
"""
return _AsyncSession(client=self, **kwargs)
52 changes: 29 additions & 23 deletions zyte_api/_sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,26 +20,24 @@ def _get_loop():
class _Session:
def __init__(self, client, **session_kwargs):
self._client = client
self._session = client._async_client.session(**session_kwargs)
self._context = None

def __enter__(self):
# https://github.com/aio-libs/aiohttp/pull/1468
async def create_session():
return client._async_client.session(**session_kwargs)._session

loop = _get_loop()
self._context = loop.run_until_complete(self._session.__aenter__())._context
self._session = loop.run_until_complete(create_session())

def __enter__(self):
return self

def __exit__(self, *exc_info):
loop = _get_loop()
result = loop.run_until_complete(self._context.__aexit__(*exc_info))
self._context = None
return result
loop.run_until_complete(self._session.close())

def _check_context(self):
if self._context is None:
raise RuntimeError(
"Attempt to use session method on a session either not opened "
"or already closed."
)
def close(self):
loop = _get_loop()
loop.run_until_complete(self._session.close())

def get(
self,
Expand All @@ -49,13 +47,12 @@ def get(
handle_retries=True,
retrying: Optional[AsyncRetrying] = None,
):
self._check_context()
return self._client.get(
query=query,
endpoint=endpoint,
handle_retries=handle_retries,
retrying=retrying,
session=self._context,
session=self._session,
)

def iter(
Expand All @@ -66,11 +63,10 @@ def iter(
handle_retries=True,
retrying: Optional[AsyncRetrying] = None,
) -> Generator[Union[dict, Exception], None, None]:
self._check_context()
return self._client.iter(
queries=queries,
endpoint=endpoint,
session=self._context,
session=self._session,
handle_retries=handle_retries,
retrying=retrying,
)
Expand Down Expand Up @@ -186,15 +182,14 @@ def iter(
yield exception

def session(self, **kwargs):
""":ref:`Context manager <context-managers>` to create a contextual
session.
""":ref:`Context manager <context-managers>` to create a session.

A contextual session is an object that has the same API as the client
object, except:
A session is an object that has the same API as the client object,
except:

- :meth:`get` and :meth:`iter` do not have a *session* parameter,
the contextual session creates an :class:`aiohttp.ClientSession`
object and passes it to :meth:`get` and :meth:`iter` automatically.
the session creates an :class:`aiohttp.ClientSession` object and
passes it to :meth:`get` and :meth:`iter` automatically.

- It does not have a :meth:`session` method.

Expand All @@ -205,5 +200,16 @@ def session(self, **kwargs):
The :class:`aiohttp.ClientSession` object is created with sane defaults
for Zyte API, but you can use *kwargs* to pass additional parameters to
:class:`aiohttp.ClientSession` and even override those sane defaults.

You do not need to use :meth:`session` as a context manager as long as
you call ``close()`` on the object it returns when you are done:

.. code-block:: python

session = client.session()
try:
...
finally:
session.close()
"""
return _Session(client=self, **kwargs)
Loading