Skip to content

Commit

Permalink
Merge pull request #57 from aio-libs/even-more-async-support
Browse files Browse the repository at this point in the history
Even more async support
  • Loading branch information
jettify committed Jan 18, 2016
2 parents ac83511 + e6c4fd5 commit 52149b6
Show file tree
Hide file tree
Showing 5 changed files with 114 additions and 15 deletions.
13 changes: 11 additions & 2 deletions aiomysql/sa/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from .result import create_result_proxy
from .transaction import (RootTransaction, Transaction,
NestedTransaction, TwoPhaseTransaction)
from ..utils import _TransactionContextManager, _SAConnectionContextManager


class SAConnection:
Expand All @@ -23,7 +24,6 @@ def __init__(self, connection, engine):
self._engine = engine
self._dialect = engine.dialect

@asyncio.coroutine
def execute(self, query, *multiparams, **params):
"""Executes a SQL query with optional parameters.
Expand Down Expand Up @@ -61,6 +61,11 @@ def execute(self, query, *multiparams, **params):
execution.
"""
coro = self._execute(query, *multiparams, **params)
return _SAConnectionContextManager(coro)

@asyncio.coroutine
def _execute(self, query, *multiparams, **params):
cursor = yield from self._connection.cursor()
dp = _distill_params(multiparams, params)
if len(dp) > 1:
Expand Down Expand Up @@ -124,7 +129,6 @@ def closed(self):
def connection(self):
return self._connection

@asyncio.coroutine
def begin(self):
"""Begin a transaction and return a transaction handle.
Expand Down Expand Up @@ -152,6 +156,11 @@ def begin(self):
.begin_twophase - use a two phase/XA transaction
"""
coro = self._begin()
return _TransactionContextManager(coro)

@asyncio.coroutine
def _begin(self):
if self._transaction is None:
self._transaction = RootTransaction(self)
yield from self._begin_impl()
Expand Down
4 changes: 3 additions & 1 deletion aiomysql/sa/result.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,7 +235,9 @@ def _prepare(self):
cursor = self._cursor
if cursor.description is not None:
self._metadata = ResultMetaData(self, cursor.description)
callback = lambda wr: asyncio.Task(cursor.close(), loop=loop)

def callback(wr):
asyncio.Task(cursor.close(), loop=loop)
self._weak = weakref.ref(self, callback)
else:
self._metadata = None
Expand Down
21 changes: 12 additions & 9 deletions aiomysql/sa/transaction.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import asyncio

from . import exc
from ..utils import PY_35


class Transaction(object):
Expand Down Expand Up @@ -86,16 +87,18 @@ def commit(self):
def _do_commit(self):
pass

@asyncio.coroutine
def __aenter__(self):
return self
if PY_35: # pragma: no branch
@asyncio.coroutine
def __aenter__(self):
return self

@asyncio.coroutine
def __aexit__(self, exc_type, exc_val, exc_tb):
if exc_type is None:
yield from self.commit()
else:
yield from self.rollback()
@asyncio.coroutine
def __aexit__(self, exc_type, exc_val, exc_tb):
if exc_type:
yield from self.rollback()
else:
if self._is_active:
yield from self.commit()


class RootTransaction(Transaction):
Expand Down
23 changes: 23 additions & 0 deletions aiomysql/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,29 @@ def __aexit__(self, exc_type, exc, tb):
self._obj = None


class _SAConnectionContextManager(_ContextManager):

if PY_35: # pragma: no branch
@asyncio.coroutine
def __aiter__(self):
result = yield from self._coro
return result


class _TransactionContextManager(_ContextManager):

if PY_35: # pragma: no branch

@asyncio.coroutine
def __aexit__(self, exc_type, exc, tb):
if exc_type:
yield from self._obj.rollback()
else:
if self._obj.is_active:
yield from self._obj.commit()
self._obj = None


class _PoolAcquireContextManager(_ContextManager):

__slots__ = ('_coro', '_conn', '_pool')
Expand Down
68 changes: 65 additions & 3 deletions tests/pep492/test_async_with.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,7 @@ async def go():
async with conn:
await self._prepare(conn.connection)
ret = []
async for i in (await conn.execute(tbl.select())):
async for i in conn.execute(tbl.select()):
ret.append(i)
assert [(1, 'a'), (2, 'b'), (3, 'c')] == ret
assert conn.closed
Expand Down Expand Up @@ -249,7 +249,7 @@ async def go():
await self._prepare(conn.connection)

ret = []
async for i in (await conn.execute(tbl.select())):
async for i in conn.execute(tbl.select()):
ret.append(i)
assert [(1, 'a'), (2, 'b'), (3, 'c')] == ret

Expand All @@ -264,8 +264,70 @@ async def go():
await self._prepare(conn.connection)

ret = []
async for i in (await conn.execute(tbl.select())):
async for i in conn.execute(tbl.select()):
ret.append(i)
assert [(1, 'a'), (2, 'b'), (3, 'c')] == ret

self.loop.run_until_complete(go())

def test_transaction_context_manager(self):
async def go():
kw = self._conn_kw()
async with sa.create_engine(**kw) as engine:
async with engine.acquire() as conn:
await self._prepare(conn.connection)
async with conn.begin() as tr:
async with conn.execute(tbl.select()) as cursor:
ret = []
async for i in conn.execute(tbl.select()):
ret.append(i)
assert [(1, 'a'), (2, 'b'), (3, 'c')] == ret
assert cursor.closed
assert not tr.is_active

tr2 = await conn.begin()
async with tr2:
assert tr2.is_active
async with conn.execute('SELECT 1;') as cursor:
rec = await cursor.scalar()
assert rec == 1
cursor.close()
assert not tr2.is_active

assert conn.closed
self.loop.run_until_complete(go())

def test_transaction_context_manager_error(self):
async def go():
kw = self._conn_kw()
async with sa.create_engine(**kw) as engine:
async with engine.acquire() as conn:
with pytest.raises(RuntimeError) as ctx:
async with conn.begin() as tr:
assert tr.is_active
raise RuntimeError('boom')
assert str(ctx.value) == 'boom'
assert not tr.is_active
assert conn.closed
self.loop.run_until_complete(go())

def test_transaction_context_manager_commit_once(self):
async def go():
kw = self._conn_kw()
async with sa.create_engine(**kw) as engine:
async with engine.acquire() as conn:
async with conn.begin() as tr:
# check that in context manager we do not execute
# commit for second time. Two commits in row causes
# InvalidRequestError exception
await tr.commit()
assert not tr.is_active

tr2 = await conn.begin()
async with tr2:
assert tr2.is_active
# check for double commit one more time
await tr2.commit()
assert not tr2.is_active
assert conn.closed
self.loop.run_until_complete(go())

0 comments on commit 52149b6

Please sign in to comment.