From 10b73fcfeb8e4778ba4ba35a1497004e13f7f67d Mon Sep 17 00:00:00 2001 From: Vitaly Burovoy Date: Thu, 25 May 2017 14:03:11 +0000 Subject: [PATCH 1/4] Fix typo PooledConnectionProxy=>PoolConnectionProxy --- asyncpg/connection.py | 2 +- tests/test_pool.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/asyncpg/connection.py b/asyncpg/connection.py index 47f5656f..c556a20c 100644 --- a/asyncpg/connection.py +++ b/asyncpg/connection.py @@ -918,7 +918,7 @@ def _notify(self, pid, channel, payload): else: # `_proxy` is not None when the connection is a member # of a connection pool. Which means that the user is working - # with a PooledConnectionProxy instance, and expects to see it + # with a `PoolConnectionProxy` instance, and expects to see it # (and not the actual Connection) in their event callbacks. con_ref = self._proxy diff --git a/tests/test_pool.py b/tests/test_pool.py index 278e8eee..d9412768 100644 --- a/tests/test_pool.py +++ b/tests/test_pool.py @@ -131,7 +131,7 @@ async def test_pool_07(self): cons = set() async def setup(con): - if con._con not in cons: # `con` is `PooledConnectionProxy`. + if con._con not in cons: # `con` is `PoolConnectionProxy`. raise RuntimeError('init was not called before setup') async def init(con): @@ -141,7 +141,7 @@ async def init(con): async def user(pool): async with pool.acquire() as con: - if con._con not in cons: # `con` is `PooledConnectionProxy`. + if con._con not in cons: # `con` is `PoolConnectionProxy`. raise RuntimeError('init was not called') async with self.create_pool(database='postgres', From b6cf4da346c547053ad707f9600d79aaa820f016 Mon Sep 17 00:00:00 2001 From: Vitaly Burovoy Date: Wed, 28 Dec 2016 06:11:45 +0000 Subject: [PATCH 2/4] New severity "V" field as "severity_en" attribute It is identical to the regular severity "S", but it is never localized. Introduced in Postgres-9.6: https://www.postgresql.org/docs/9.6/static/protocol-error-fields.html --- asyncpg/exceptions/_base.py | 1 + 1 file changed, 1 insertion(+) diff --git a/asyncpg/exceptions/_base.py b/asyncpg/exceptions/_base.py index 72e9ec73..9a059d7b 100644 --- a/asyncpg/exceptions/_base.py +++ b/asyncpg/exceptions/_base.py @@ -21,6 +21,7 @@ class PostgresMessageMeta(type): _message_map = {} _field_map = { 'S': 'severity', + 'V': 'severity_en', 'C': 'sqlstate', 'M': 'message', 'D': 'detail', From 30db879401a87c9b52663779ce3b063fcde96cb2 Mon Sep 17 00:00:00 2001 From: Vitaly Burovoy Date: Sat, 27 May 2017 14:35:12 +0000 Subject: [PATCH 3/4] Implement "as_dict" method for PostgresMessage Sometimes it is useful to get all non-None members, usually for logging. Do it in a base class to avoid copy-pasting in users' code. --- asyncpg/exceptions/_base.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/asyncpg/exceptions/_base.py b/asyncpg/exceptions/_base.py index 9a059d7b..aa51dd02 100644 --- a/asyncpg/exceptions/_base.py +++ b/asyncpg/exceptions/_base.py @@ -127,6 +127,15 @@ def new(cls, fields, query=None): return e + def as_dict(self): + message = {} + for f in type(self)._field_map.values(): + val = getattr(self, f) + if val is not None: + message[f] = val + + return message + class PostgresError(PostgresMessage, Exception): """Base class for all Postgres errors.""" From 844804a2db6b2655122783a3f113fb16b2ada1a3 Mon Sep 17 00:00:00 2001 From: Vitaly Burovoy Date: Wed, 28 Dec 2016 06:27:03 +0000 Subject: [PATCH 4/4] Add API for receiving asynchronous notices Notice message types are: WARNING, NOTICE, DEBUG, INFO, or LOG. https://www.postgresql.org/docs/current/static/protocol-error-fields.html Issue #144. --- asyncpg/connection.py | 45 +++++++++++++++- asyncpg/exceptions/_base.py | 9 +++- asyncpg/protocol/coreproto.pxd | 1 + asyncpg/protocol/coreproto.pyx | 12 ++++- asyncpg/protocol/protocol.pyx | 8 +++ tests/test_listeners.py | 99 ++++++++++++++++++++++++++++++++++ 6 files changed, 170 insertions(+), 4 deletions(-) diff --git a/asyncpg/connection.py b/asyncpg/connection.py index c556a20c..9cafedec 100644 --- a/asyncpg/connection.py +++ b/asyncpg/connection.py @@ -41,7 +41,7 @@ class Connection(metaclass=ConnectionMeta): '_stmt_cache', '_stmts_to_close', '_listeners', '_server_version', '_server_caps', '_intro_query', '_reset_query', '_proxy', '_stmt_exclusive_section', - '_config', '_params', '_addr') + '_config', '_params', '_addr', '_notice_callbacks') def __init__(self, protocol, transport, loop, addr: (str, int) or str, @@ -69,6 +69,7 @@ def __init__(self, protocol, transport, loop, self._stmts_to_close = set() self._listeners = {} + self._notice_callbacks = set() settings = self._protocol.get_settings() ver_string = settings.server_version @@ -126,6 +127,26 @@ async def remove_listener(self, channel, callback): del self._listeners[channel] await self.fetch('UNLISTEN {}'.format(channel)) + def add_notice_callback(self, callback): + """Add a callback for Postgres notices (NOTICE, DEBUG, LOG etc.). + + It will be called when asyncronous NoticeResponse is received + from the connection. Possible message types are: WARNING, NOTICE, DEBUG, + INFO, or LOG. + + :param callable callback: + A callable receiving the following arguments: + **connection**: a Connection the callback is registered with; + **message**: the `exceptions.PostgresNotice` message. + """ + if self.is_closed(): + raise exceptions.InterfaceError('connection is closed') + self._notice_callbacks.add(callback) + + def remove_notice_callback(self, callback): + """Remove a callback for notices.""" + self._notice_callbacks.discard(callback) + def get_server_pid(self): """Return the PID of the Postgres server the connection is bound to.""" return self._protocol.get_server_pid() @@ -821,6 +842,7 @@ async def close(self): self._listeners = {} self._aborted = True await self._protocol.close() + self._notice_callbacks = set() def terminate(self): """Terminate the connection without waiting for pending data.""" @@ -828,6 +850,7 @@ def terminate(self): self._listeners = {} self._aborted = True self._protocol.abort() + self._notice_callbacks = set() async def reset(self): self._check_open() @@ -909,6 +932,26 @@ async def cancel(): self._loop.create_task(cancel()) + def _notice(self, message): + if self._proxy is None: + con_ref = self + else: + # See the comment in the `_notify` below. + con_ref = self._proxy + + for cb in self._notice_callbacks: + self._loop.call_soon(self._call_notice_cb, cb, con_ref, message) + + def _call_notice_cb(self, cb, con_ref, message): + try: + cb(con_ref, message) + except Exception as ex: + self._loop.call_exception_handler({ + 'message': 'Unhandled exception in asyncpg notice message ' + 'callback {!r}'.format(cb), + 'exception': ex + }) + def _notify(self, pid, channel, payload): if channel not in self._listeners: return diff --git a/asyncpg/exceptions/_base.py b/asyncpg/exceptions/_base.py index aa51dd02..41518a53 100644 --- a/asyncpg/exceptions/_base.py +++ b/asyncpg/exceptions/_base.py @@ -9,7 +9,7 @@ __all__ = ('PostgresError', 'FatalPostgresError', 'UnknownPostgresError', - 'InterfaceError') + 'InterfaceError', 'PostgresNotice') def _is_asyncpg_class(cls): @@ -151,3 +151,10 @@ class UnknownPostgresError(FatalPostgresError): class InterfaceError(Exception): """An error caused by improper use of asyncpg API.""" + + +class PostgresNotice(PostgresMessage): + sqlstate = '00000' + + def __init__(self, message): + self.args = [message] diff --git a/asyncpg/protocol/coreproto.pxd b/asyncpg/protocol/coreproto.pxd index c3b18f3d..60efa591 100644 --- a/asyncpg/protocol/coreproto.pxd +++ b/asyncpg/protocol/coreproto.pxd @@ -170,5 +170,6 @@ cdef class CoreProtocol: cdef _on_result(self) cdef _on_notification(self, pid, channel, payload) + cdef _on_notice(self, parsed) cdef _set_server_parameter(self, name, val) cdef _on_connection_lost(self, exc) diff --git a/asyncpg/protocol/coreproto.pyx b/asyncpg/protocol/coreproto.pyx index e8ae79a0..bfd37783 100644 --- a/asyncpg/protocol/coreproto.pyx +++ b/asyncpg/protocol/coreproto.pyx @@ -56,6 +56,10 @@ cdef class CoreProtocol: # NotificationResponse self._parse_msg_notification() continue + elif mtype == b'N': + # 'N' - NoticeResponse + self._on_notice(self._parse_msg_error_response(False)) + continue if state == PROTOCOL_AUTH: self._process__auth(mtype) @@ -302,10 +306,9 @@ cdef class CoreProtocol: self._push_result() cdef _process__simple_query(self, char mtype): - if mtype in {b'D', b'I', b'N', b'T'}: + if mtype in {b'D', b'I', b'T'}: # 'D' - DataRow # 'I' - EmptyQueryResponse - # 'N' - NoticeResponse # 'T' - RowDescription self.buffer.consume_message() @@ -614,6 +617,8 @@ cdef class CoreProtocol: if is_error: self.result_type = RESULT_FAILED self.result = parsed + else: + return parsed cdef _push_result(self): try: @@ -910,6 +915,9 @@ cdef class CoreProtocol: cdef _on_result(self): pass + cdef _on_notice(self, parsed): + pass + cdef _on_notification(self, pid, channel, payload): pass diff --git a/asyncpg/protocol/protocol.pyx b/asyncpg/protocol/protocol.pyx index 3279c7a6..cc8c7bba 100644 --- a/asyncpg/protocol/protocol.pyx +++ b/asyncpg/protocol/protocol.pyx @@ -731,6 +731,14 @@ cdef class BaseProtocol(CoreProtocol): self.last_query = None self.return_extra = False + cdef _on_notice(self, parsed): + # Check it here to avoid unnecessary object creation. + if self.connection._notice_callbacks: + message = apg_exc_base.PostgresMessage.new( + parsed, query=self.last_query) + + self.connection._notice(message) + cdef _on_notification(self, pid, channel, payload): self.connection._notify(pid, channel, payload) diff --git a/tests/test_listeners.py b/tests/test_listeners.py index 5df8861a..d43ab363 100644 --- a/tests/test_listeners.py +++ b/tests/test_listeners.py @@ -8,6 +8,7 @@ import asyncio from asyncpg import _testbase as tb +from asyncpg.exceptions import PostgresNotice, PostgresWarning class TestListeners(tb.ClusterTestCase): @@ -74,3 +75,101 @@ def listener1(*args): self.assertEqual( await q1.get(), (con1, con2.get_server_pid(), 'ipc', 'hello')) + + +class TestNotices(tb.ConnectedTestCase): + async def test_notify_01(self): + q1 = asyncio.Queue(loop=self.loop) + + def notice_callb(con, message): + # data in the message depend on PG's version, hide some values + if message.server_source_line is not None : + message.server_source_line = '***' + + q1.put_nowait((con, type(message), message.as_dict())) + + con = self.con + con.add_notice_callback(notice_callb) + await con.execute( + "DO $$ BEGIN RAISE NOTICE 'catch me!'; END; $$ LANGUAGE plpgsql" + ) + await con.execute( + "DO $$ BEGIN RAISE WARNING 'catch me!'; END; $$ LANGUAGE plpgsql" + ) + + expect_msg = { + 'context': 'PL/pgSQL function inline_code_block line 1 at RAISE', + 'message': 'catch me!', + 'server_source_filename': 'pl_exec.c', + 'server_source_function': 'exec_stmt_raise', + 'server_source_line': '***'} + + expect_msg_notice = expect_msg.copy() + expect_msg_notice.update({ + 'severity': 'NOTICE', + 'severity_en': 'NOTICE', + 'sqlstate': '00000', + }) + + expect_msg_warn = expect_msg.copy() + expect_msg_warn.update({ + 'severity': 'WARNING', + 'severity_en': 'WARNING', + 'sqlstate': '01000', + }) + + if con.get_server_version() < (9, 6): + del expect_msg_notice['context'] + del expect_msg_notice['severity_en'] + del expect_msg_warn['context'] + del expect_msg_warn['severity_en'] + + self.assertEqual( + await q1.get(), + (con, PostgresNotice, expect_msg_notice)) + + self.assertEqual( + await q1.get(), + (con, PostgresWarning, expect_msg_warn)) + + con.remove_notice_callback(notice_callb) + await con.execute( + "DO $$ BEGIN RAISE NOTICE '/dev/null!'; END; $$ LANGUAGE plpgsql" + ) + + self.assertTrue(q1.empty()) + + + async def test_notify_sequence(self): + q1 = asyncio.Queue(loop=self.loop) + + cur_id = None + + def notice_callb(con, message): + q1.put_nowait((con, cur_id, message.message)) + + con = self.con + await con.execute( + "CREATE FUNCTION _test(i INT) RETURNS int LANGUAGE plpgsql AS $$" + " BEGIN" + " RAISE NOTICE '1_%', i;" + " PERFORM pg_sleep(0.1);" + " RAISE NOTICE '2_%', i;" + " RETURN i;" + " END" + "$$" + ) + con.add_notice_callback(notice_callb) + for cur_id in range(10): + await con.execute("SELECT _test($1)", cur_id) + + for cur_id in range(10): + self.assertEqual( + q1.get_nowait(), + (con, cur_id, '1_%s' % cur_id)) + self.assertEqual( + q1.get_nowait(), + (con, cur_id, '2_%s' % cur_id)) + + con.remove_notice_callback(notice_callb) + self.assertTrue(q1.empty())