Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add API for receiving asynchronous notices #147

Merged
merged 4 commits into from
Jul 5, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 45 additions & 2 deletions asyncpg/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ class Connection(metaclass=ConnectionMeta):
'_stmt_cache', '_stmts_to_close', '_listeners',
'_server_version', '_server_caps', '_intro_query',
'_reset_query', '_proxy', '_stmt_exclusive_section',
'_config', '_params', '_addr')
'_config', '_params', '_addr', '_notice_callbacks')

def __init__(self, protocol, transport, loop,
addr: (str, int) or str,
Expand Down Expand Up @@ -69,6 +69,7 @@ def __init__(self, protocol, transport, loop,
self._stmts_to_close = set()

self._listeners = {}
self._notice_callbacks = set()

settings = self._protocol.get_settings()
ver_string = settings.server_version
Expand Down Expand Up @@ -126,6 +127,26 @@ async def remove_listener(self, channel, callback):
del self._listeners[channel]
await self.fetch('UNLISTEN {}'.format(channel))

def add_notice_callback(self, callback):
"""Add a callback for Postgres notices (NOTICE, DEBUG, LOG etc.).

It will be called when asyncronous NoticeResponse is received
from the connection. Possible message types are: WARNING, NOTICE, DEBUG,
INFO, or LOG.

:param callable callback:
A callable receiving the following arguments:
**connection**: a Connection the callback is registered with;
**message**: the `exceptions.PostgresNotice` message.
"""
if self.is_closed():
raise exceptions.InterfaceError('connection is closed')
self._notice_callbacks.add(callback)

def remove_notice_callback(self, callback):
"""Remove a callback for notices."""
self._notice_callbacks.discard(callback)

def get_server_pid(self):
"""Return the PID of the Postgres server the connection is bound to."""
return self._protocol.get_server_pid()
Expand Down Expand Up @@ -821,13 +842,15 @@ async def close(self):
self._listeners = {}
self._aborted = True
await self._protocol.close()
self._notice_callbacks = set()

def terminate(self):
"""Terminate the connection without waiting for pending data."""
self._mark_stmts_as_closed()
self._listeners = {}
self._aborted = True
self._protocol.abort()
self._notice_callbacks = set()

async def reset(self):
self._check_open()
Expand Down Expand Up @@ -909,6 +932,26 @@ async def cancel():

self._loop.create_task(cancel())

def _notice(self, message):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd rename this to _schedule_notice_callbacks.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Disagree. They are called immediately, there is no schedule process.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Right. Well, you should use loop.call_soon anyways. We never want to call callback right away.

if self._proxy is None:
con_ref = self
else:
# See the comment in the `_notify` below.
con_ref = self._proxy

for cb in self._notice_callbacks:
self._loop.call_soon(self._call_notice_cb, cb, con_ref, message)

def _call_notice_cb(self, cb, con_ref, message):
try:
cb(con_ref, message)
except Exception as ex:
self._loop.call_exception_handler({
'message': 'Unhandled exception in asyncpg notice message '
'callback {!r}'.format(cb),
'exception': ex
})

def _notify(self, pid, channel, payload):
if channel not in self._listeners:
return
Expand All @@ -918,7 +961,7 @@ def _notify(self, pid, channel, payload):
else:
# `_proxy` is not None when the connection is a member
# of a connection pool. Which means that the user is working
# with a PooledConnectionProxy instance, and expects to see it
# with a `PoolConnectionProxy` instance, and expects to see it
# (and not the actual Connection) in their event callbacks.
con_ref = self._proxy

Expand Down
19 changes: 18 additions & 1 deletion asyncpg/exceptions/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@


__all__ = ('PostgresError', 'FatalPostgresError', 'UnknownPostgresError',
'InterfaceError')
'InterfaceError', 'PostgresNotice')


def _is_asyncpg_class(cls):
Expand All @@ -21,6 +21,7 @@ class PostgresMessageMeta(type):
_message_map = {}
_field_map = {
'S': 'severity',
'V': 'severity_en',
'C': 'sqlstate',
'M': 'message',
'D': 'detail',
Expand Down Expand Up @@ -126,6 +127,15 @@ def new(cls, fields, query=None):

return e

def as_dict(self):
message = {}
for f in type(self)._field_map.values():
val = getattr(self, f)
if val is not None:
message[f] = val

return message


class PostgresError(PostgresMessage, Exception):
"""Base class for all Postgres errors."""
Expand All @@ -141,3 +151,10 @@ class UnknownPostgresError(FatalPostgresError):

class InterfaceError(Exception):
"""An error caused by improper use of asyncpg API."""


class PostgresNotice(PostgresMessage):
sqlstate = '00000'

def __init__(self, message):
self.args = [message]
1 change: 1 addition & 0 deletions asyncpg/protocol/coreproto.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -170,5 +170,6 @@ cdef class CoreProtocol:

cdef _on_result(self)
cdef _on_notification(self, pid, channel, payload)
cdef _on_notice(self, parsed)
cdef _set_server_parameter(self, name, val)
cdef _on_connection_lost(self, exc)
12 changes: 10 additions & 2 deletions asyncpg/protocol/coreproto.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,10 @@ cdef class CoreProtocol:
# NotificationResponse
self._parse_msg_notification()
continue
elif mtype == b'N':
# 'N' - NoticeResponse
self._on_notice(self._parse_msg_error_response(False))
continue

if state == PROTOCOL_AUTH:
self._process__auth(mtype)
Expand Down Expand Up @@ -302,10 +306,9 @@ cdef class CoreProtocol:
self._push_result()

cdef _process__simple_query(self, char mtype):
if mtype in {b'D', b'I', b'N', b'T'}:
if mtype in {b'D', b'I', b'T'}:
# 'D' - DataRow
# 'I' - EmptyQueryResponse
# 'N' - NoticeResponse
# 'T' - RowDescription
self.buffer.consume_message()

Expand Down Expand Up @@ -614,6 +617,8 @@ cdef class CoreProtocol:
if is_error:
self.result_type = RESULT_FAILED
self.result = parsed
else:
return parsed

cdef _push_result(self):
try:
Expand Down Expand Up @@ -910,6 +915,9 @@ cdef class CoreProtocol:
cdef _on_result(self):
pass

cdef _on_notice(self, parsed):
pass

cdef _on_notification(self, pid, channel, payload):
pass

Expand Down
8 changes: 8 additions & 0 deletions asyncpg/protocol/protocol.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -731,6 +731,14 @@ cdef class BaseProtocol(CoreProtocol):
self.last_query = None
self.return_extra = False

cdef _on_notice(self, parsed):
# Check it here to avoid unnecessary object creation.
if self.connection._notice_callbacks:
message = apg_exc_base.PostgresMessage.new(
parsed, query=self.last_query)

self.connection._notice(message)

cdef _on_notification(self, pid, channel, payload):
self.connection._notify(pid, channel, payload)

Expand Down
99 changes: 99 additions & 0 deletions tests/test_listeners.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import asyncio

from asyncpg import _testbase as tb
from asyncpg.exceptions import PostgresNotice, PostgresWarning


class TestListeners(tb.ClusterTestCase):
Expand Down Expand Up @@ -74,3 +75,101 @@ def listener1(*args):
self.assertEqual(
await q1.get(),
(con1, con2.get_server_pid(), 'ipc', 'hello'))


class TestNotices(tb.ConnectedTestCase):
async def test_notify_01(self):
q1 = asyncio.Queue(loop=self.loop)

def notice_callb(con, message):
# data in the message depend on PG's version, hide some values
if message.server_source_line is not None :
message.server_source_line = '***'

q1.put_nowait((con, type(message), message.as_dict()))

con = self.con
con.add_notice_callback(notice_callb)
await con.execute(
"DO $$ BEGIN RAISE NOTICE 'catch me!'; END; $$ LANGUAGE plpgsql"
)
await con.execute(
"DO $$ BEGIN RAISE WARNING 'catch me!'; END; $$ LANGUAGE plpgsql"
)

expect_msg = {
'context': 'PL/pgSQL function inline_code_block line 1 at RAISE',
'message': 'catch me!',
'server_source_filename': 'pl_exec.c',
'server_source_function': 'exec_stmt_raise',
'server_source_line': '***'}

expect_msg_notice = expect_msg.copy()
expect_msg_notice.update({
'severity': 'NOTICE',
'severity_en': 'NOTICE',
'sqlstate': '00000',
})

expect_msg_warn = expect_msg.copy()
expect_msg_warn.update({
'severity': 'WARNING',
'severity_en': 'WARNING',
'sqlstate': '01000',
})

if con.get_server_version() < (9, 6):
del expect_msg_notice['context']
del expect_msg_notice['severity_en']
del expect_msg_warn['context']
del expect_msg_warn['severity_en']

self.assertEqual(
await q1.get(),
(con, PostgresNotice, expect_msg_notice))

self.assertEqual(
await q1.get(),
(con, PostgresWarning, expect_msg_warn))

con.remove_notice_callback(notice_callb)
await con.execute(
"DO $$ BEGIN RAISE NOTICE '/dev/null!'; END; $$ LANGUAGE plpgsql"
)

self.assertTrue(q1.empty())


async def test_notify_sequence(self):
q1 = asyncio.Queue(loop=self.loop)

cur_id = None

def notice_callb(con, message):
q1.put_nowait((con, cur_id, message.message))

con = self.con
await con.execute(
"CREATE FUNCTION _test(i INT) RETURNS int LANGUAGE plpgsql AS $$"
" BEGIN"
" RAISE NOTICE '1_%', i;"
" PERFORM pg_sleep(0.1);"
" RAISE NOTICE '2_%', i;"
" RETURN i;"
" END"
"$$"
)
con.add_notice_callback(notice_callb)
for cur_id in range(10):
await con.execute("SELECT _test($1)", cur_id)

for cur_id in range(10):
self.assertEqual(
q1.get_nowait(),
(con, cur_id, '1_%s' % cur_id))
self.assertEqual(
q1.get_nowait(),
(con, cur_id, '2_%s' % cur_id))

con.remove_notice_callback(notice_callb)
self.assertTrue(q1.empty())
4 changes: 2 additions & 2 deletions tests/test_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ async def test_pool_07(self):
cons = set()

async def setup(con):
if con._con not in cons: # `con` is `PooledConnectionProxy`.
if con._con not in cons: # `con` is `PoolConnectionProxy`.
raise RuntimeError('init was not called before setup')

async def init(con):
Expand All @@ -141,7 +141,7 @@ async def init(con):

async def user(pool):
async with pool.acquire() as con:
if con._con not in cons: # `con` is `PooledConnectionProxy`.
if con._con not in cons: # `con` is `PoolConnectionProxy`.
raise RuntimeError('init was not called')

async with self.create_pool(database='postgres',
Expand Down