diff --git a/CHANGES.txt b/CHANGES.txt index 764e343a..69ed2c58 100644 --- a/CHANGES.txt +++ b/CHANGES.txt @@ -1,5 +1,10 @@ CHANGES ------- +X.X.X (XXXX-XX-XX) +^^^^^^^^^^^^^^^^^^^ + +* Add `async with` support to cursor context + 0.15.0 (2018-08-14) ^^^^^^^^^^^^^^^^^^^ diff --git a/aiopg/cursor.py b/aiopg/cursor.py index edfe1be7..82a1e256 100644 --- a/aiopg/cursor.py +++ b/aiopg/cursor.py @@ -50,9 +50,18 @@ def description(self): def close(self): """Close the cursor now.""" - if not self.closed: + if self.closed: + return + + try: self._impl.close() self._conn.cursor_closed(self) + except psycopg2.ProgrammingError: + # seen instances where the cursor fails to close: + # https://github.com/aio-libs/aiopg/issues/364 + # We close it here so we don't return a bad connection to the pool + self._conn.cursor_closed(self) + raise @property def closed(self): diff --git a/aiopg/pool.py b/aiopg/pool.py index d2a4892c..e8e1c60d 100644 --- a/aiopg/pool.py +++ b/aiopg/pool.py @@ -262,15 +262,12 @@ def release(self, conn): fut = ensure_future(self._wakeup(), loop=self._loop) return fut - @asyncio.coroutine - def cursor(self, name=None, cursor_factory=None, - scrollable=None, withhold=False, *, timeout=None): - """XXX""" - conn = yield from self.acquire() - cur = yield from conn.cursor(name=name, cursor_factory=cursor_factory, - scrollable=scrollable, withhold=withhold, - timeout=timeout) - return _PoolCursorContextManager(self, conn, cur) + def cursor(self, name=None, cursor_factory=None, scrollable=None, + withhold=False, *, timeout=None): + cursor_kwargs = dict(name=name, cursor_factory=cursor_factory, + scrollable=scrollable, withhold=withhold, + timeout=timeout) + return _PoolCursorContextManager(self, cursor_kwargs) def __enter__(self): raise RuntimeError( diff --git a/aiopg/utils.py b/aiopg/utils.py index 31d8e010..f8859539 100644 --- a/aiopg/utils.py +++ b/aiopg/utils.py @@ -1,6 +1,5 @@ import asyncio import sys -import psycopg2 PY_35 = sys.version_info >= (3, 5) PY_352 = sys.version_info >= (3, 5, 2) @@ -191,6 +190,10 @@ def __init__(self, pool, conn): self._pool = pool self._conn = conn + @property + def conn(self): + return self._conn + def __enter__(self): assert self._conn return self._conn @@ -221,24 +224,27 @@ def __aexit__(self, exc_type, exc_val, exc_tb): class _PoolCursorContextManager: """Context manager. - This enables the following idiom for acquiring and releasing a + This enables the following idioms for acquiring and releasing a cursor around a block: with (yield from pool.cursor()) as cur: yield from cur.execute("SELECT 1") + async with pool.cursor() as cur: + yield from cur.execute("SELECT 1") + while failing loudly when accidentally using: with pool: """ - __slots__ = ('_pool', '_conn', '_cur') + __slots__ = ('_pool', '_cursor_kwargs', '_cur') - def __init__(self, pool, conn, cur): + def __init__(self, pool, cursor_kwargs=None): self._pool = pool - self._conn = conn - self._cur = cur + self._cursor_kwargs = cursor_kwargs + self._cur = None def __enter__(self): return self._cur @@ -246,20 +252,68 @@ def __enter__(self): def __exit__(self, *args): try: self._cur.close() - except psycopg2.ProgrammingError: - # seen instances where the cursor fails to close: - # https://github.com/aio-libs/aiopg/issues/364 - # We close it here so we don't return a bad connection to the pool - self._conn.close() - raise finally: try: - self._pool.release(self._conn) + self._pool.__exit__(*args) finally: - self._pool = None - self._conn = None self._cur = None + @asyncio.coroutine + def _init_cursor(self, with_aenter): + assert not self._cur + + if with_aenter: + conn = None + else: + conn = yield from self._pool.acquire() + + # self._pool now morphs into a _PoolConnectionContextManager + self._pool = _PoolConnectionContextManager(self._pool, conn) + + if with_aenter: + # this will create the connection + yield from self._pool.__aenter__() + self._cur = yield from self._pool.conn.cursor( + **self._cursor_kwargs) + + return self._cur + else: + self._cur = yield from self._pool.conn.cursor( + **self._cursor_kwargs) + return self + + @asyncio.coroutine + def __iter__(self): + # This will get hit if you use "yield from pool.cursor()" + result = yield from self._init_cursor(False) + return result + + def __await__(self): + # This will get hit directly if you "await pool.cursor()" + # this is using a trick similar to the one here: + # https://magicstack.github.io/asyncpg/current/_modules/asyncpg/pool.html + # however since `self._init()` is an "asyncio.coroutine" we can't use + # just return self._init().__await__() as that returns a generator + # without an "__await__" attribute and we can't return a coroutine from + # here + value = yield from self._init_cursor(False) + return value + + if PY_35: + @asyncio.coroutine + def __aenter__(self): + value = yield from self._init_cursor(True) + return value + + @asyncio.coroutine + def __aexit__(self, exc_type, exc_val, exc_tb): + try: + yield from self._cur.__aexit__(exc_type, exc_val, exc_tb) + self._cur = None + finally: + yield from self._pool.__aexit__(exc_type, exc_val, exc_tb) + self._pool = None + if not PY_35: try: diff --git a/tests/pep492/test_async_await.py b/tests/pep492/test_async_await.py index 6174cb0b..57b4589b 100644 --- a/tests/pep492/test_async_await.py +++ b/tests/pep492/test_async_await.py @@ -6,7 +6,6 @@ from aiopg.sa import SAConnection -@asyncio.coroutine async def test_cursor_await(make_connection): conn = await make_connection() @@ -17,7 +16,6 @@ async def test_cursor_await(make_connection): cursor.close() -@asyncio.coroutine async def test_connect_context_manager(loop, pg_params): async with aiopg.connect(loop=loop, **pg_params) as conn: cursor = await conn.cursor() @@ -28,7 +26,26 @@ async def test_connect_context_manager(loop, pg_params): assert conn.closed -@asyncio.coroutine +async def test_pool_cursor_context_manager(loop, pg_params): + async with aiopg.create_pool(loop=loop, **pg_params) as pool: + async with pool.cursor() as cursor: + await cursor.execute('SELECT 42') + resp = await cursor.fetchone() + assert resp == (42, ) + assert cursor.closed + assert pool.closed + + +async def test_pool_cursor_await_context_manager(loop, pg_params): + async with aiopg.create_pool(loop=loop, **pg_params) as pool: + with (await pool.cursor()) as cursor: + await cursor.execute('SELECT 42') + resp = await cursor.fetchone() + assert resp == (42, ) + assert cursor.closed + assert pool.closed + + async def test_connection_context_manager(make_connection): conn = await make_connection() assert not conn.closed @@ -41,7 +58,6 @@ async def test_connection_context_manager(make_connection): assert conn.closed -@asyncio.coroutine async def test_cursor_create_with_context_manager(make_connection): conn = await make_connection() @@ -54,7 +70,6 @@ async def test_cursor_create_with_context_manager(make_connection): assert cursor.closed -@asyncio.coroutine async def test_two_cursor_create_with_context_manager(make_connection): conn = await make_connection() @@ -63,7 +78,6 @@ async def test_two_cursor_create_with_context_manager(make_connection): assert not cursor2.closed -@asyncio.coroutine async def test_pool_context_manager_timeout(pg_params, loop): async with aiopg.create_pool(loop=loop, **pg_params, minsize=1, maxsize=1) as pool: @@ -79,7 +93,7 @@ async def test_pool_context_manager_timeout(pg_params, loop): fut.cancel() cursor_ctx = await pool.cursor() with cursor_ctx as cursor: - resp = await cursor.execute('SELECT 42;') + await cursor.execute('SELECT 42;') resp = await cursor.fetchone() assert resp == (42, ) @@ -87,7 +101,6 @@ async def test_pool_context_manager_timeout(pg_params, loop): assert pool.closed -@asyncio.coroutine async def test_cursor_with_context_manager(make_connection): conn = await make_connection() cursor = await conn.cursor() @@ -100,7 +113,6 @@ async def test_cursor_with_context_manager(make_connection): assert cursor.closed -@asyncio.coroutine async def test_cursor_lightweight(make_connection): conn = await make_connection() cursor = await conn.cursor() @@ -112,7 +124,6 @@ async def test_cursor_lightweight(make_connection): assert cursor.closed -@asyncio.coroutine async def test_pool_context_manager(pg_params, loop): pool = await aiopg.create_pool(loop=loop, **pg_params) @@ -127,7 +138,6 @@ async def test_pool_context_manager(pg_params, loop): assert pool.closed -@asyncio.coroutine async def test_create_pool_context_manager(pg_params, loop): async with aiopg.create_pool(loop=loop, **pg_params) as pool: async with pool.acquire() as conn: @@ -141,7 +151,6 @@ async def test_create_pool_context_manager(pg_params, loop): assert pool.closed -@asyncio.coroutine async def test_cursor_aiter(make_connection): result = [] conn = await make_connection() @@ -156,7 +165,6 @@ async def test_cursor_aiter(make_connection): assert conn.closed -@asyncio.coroutine async def test_engine_context_manager(pg_params, loop): engine = await aiopg.sa.create_engine(loop=loop, **pg_params) async with engine: @@ -166,7 +174,6 @@ async def test_engine_context_manager(pg_params, loop): assert engine.closed -@asyncio.coroutine async def test_create_engine_context_manager(pg_params, loop): async with aiopg.sa.create_engine(loop=loop, **pg_params) as engine: async with engine.acquire() as conn: @@ -174,7 +181,6 @@ async def test_create_engine_context_manager(pg_params, loop): assert engine.closed -@asyncio.coroutine async def test_result_proxy_aiter(pg_params, loop): sql = 'SELECT generate_series(1, 5);' result = [] @@ -188,7 +194,6 @@ async def test_result_proxy_aiter(pg_params, loop): assert conn.closed -@asyncio.coroutine async def test_transaction_context_manager(pg_params, loop): sql = 'SELECT generate_series(1, 5);' result = [] @@ -215,7 +220,6 @@ async def test_transaction_context_manager(pg_params, loop): assert conn.closed -@asyncio.coroutine async def test_transaction_context_manager_error(pg_params, loop): async with aiopg.sa.create_engine(loop=loop, **pg_params) as engine: async with engine.acquire() as conn: @@ -228,7 +232,6 @@ async def test_transaction_context_manager_error(pg_params, loop): assert conn.closed -@asyncio.coroutine async def test_transaction_context_manager_commit_once(pg_params, loop): async with aiopg.sa.create_engine(loop=loop, **pg_params) as engine: async with engine.acquire() as conn: @@ -248,7 +251,6 @@ async def test_transaction_context_manager_commit_once(pg_params, loop): assert conn.closed -@asyncio.coroutine async def test_transaction_context_manager_nested_commit(pg_params, loop): sql = 'SELECT generate_series(1, 5);' result = [] @@ -278,7 +280,6 @@ async def test_transaction_context_manager_nested_commit(pg_params, loop): assert conn.closed -@asyncio.coroutine async def test_sa_connection_execute(pg_params, loop): sql = 'SELECT generate_series(1, 5);' result = [] diff --git a/tests/test_connection.py b/tests/test_connection.py index 4d07c3c6..f1e0464d 100755 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -9,6 +9,7 @@ import time import sys +from psycopg2.extensions import parse_dsn from aiopg.connection import Connection, TIMEOUT from aiopg.cursor import Cursor from aiopg.utils import ensure_future @@ -147,12 +148,13 @@ def test_set_session(connect): @asyncio.coroutine def test_dsn(connect, pg_params): conn = yield from connect() - pg_params['password'] = 'x' * len(pg_params['password']) - assert 'dbname' in conn.dsn - assert 'user' in conn.dsn - assert 'password' in conn.dsn - assert 'host' in conn.dsn - assert 'port' in conn.dsn + + pg_params = pg_params.copy() + pg_params['password'] = 'xxx' + pg_params['dbname'] = pg_params.pop('database') + pg_params['port'] = str(pg_params['port']) + + assert parse_dsn(conn.dsn) == pg_params @asyncio.coroutine @@ -211,7 +213,8 @@ def test_isolation_level(connect): assert psycopg2.extensions.ISOLATION_LEVEL_DEFAULT == conn.isolation_level with pytest.raises(psycopg2.ProgrammingError): - yield from conn.set_isolation_level(1) + yield from conn.set_isolation_level( + psycopg2.extensions.ISOLATION_LEVEL_READ_COMMITTED) assert psycopg2.extensions.ISOLATION_LEVEL_DEFAULT == conn.isolation_level @@ -499,7 +502,7 @@ def test_connect_to_unsupported_port(unused_port, loop, pg_params): pg_params['port'] = port with pytest.raises(psycopg2.OperationalError): - yield from aiopg.connect(loop=loop, **pg_params) + yield from aiopg.connect(loop=loop, timeout=3, **pg_params) @asyncio.coroutine @@ -689,6 +692,7 @@ def test_connection_on_server_restart(connect, pg_server, docker): yield from cur.execute('SELECT 1') ret = yield from cur.fetchone() assert (1,) == ret + docker.restart(container=pg_server['Id']) with pytest.raises(psycopg2.OperationalError):