Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add fetchmany to execute many *and* return rows #1175

Merged
merged 2 commits into from
Oct 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 50 additions & 2 deletions asyncpg/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -756,6 +756,44 @@ async def fetchrow(
return None
return data[0]

async def fetchmany(
self, query, args, *, timeout: float=None, record_class=None
):
"""Run a query for each sequence of arguments in *args*
and return the results as a list of :class:`Record`.

:param query:
Query to execute.
:param args:
An iterable containing sequences of arguments for the query.
:param float timeout:
Optional timeout value in seconds.
:param type record_class:
If specified, the class to use for records returned by this method.
Must be a subclass of :class:`~asyncpg.Record`. If not specified,
a per-connection *record_class* is used.

:return list:
A list of :class:`~asyncpg.Record` instances. If specified, the
actual type of list elements would be *record_class*.

Example:

.. code-block:: pycon

>>> rows = await con.fetchmany('''
... INSERT INTO mytab (a, b) VALUES ($1, $2) RETURNING a;
... ''', [('x', 1), ('y', 2), ('z', 3)])
>>> rows
[<Record row=('x',)>, <Record row=('y',)>, <Record row=('z',)>]

.. versionadded:: 0.30.0
"""
self._check_open()
return await self._executemany(
query, args, timeout, return_rows=True, record_class=record_class
)

async def copy_from_table(self, table_name, *, output,
columns=None, schema_name=None, timeout=None,
format=None, oids=None, delimiter=None,
Expand Down Expand Up @@ -1896,17 +1934,27 @@ async def __execute(
)
return result, stmt

async def _executemany(self, query, args, timeout):
async def _executemany(
self,
query,
args,
timeout,
return_rows=False,
record_class=None,
):
executor = lambda stmt, timeout: self._protocol.bind_execute_many(
state=stmt,
args=args,
portal_name='',
timeout=timeout,
return_rows=return_rows,
)
timeout = self._protocol._get_timeout(timeout)
with self._stmt_exclusive_section:
with self._time_and_log(query, args, timeout):
result, _ = await self._do_execute(query, executor, timeout)
result, _ = await self._do_execute(
query, executor, timeout, record_class=record_class
)
return result

async def _do_execute(
Expand Down
16 changes: 16 additions & 0 deletions asyncpg/pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -609,6 +609,22 @@ async def fetchrow(self, query, *args, timeout=None, record_class=None):
record_class=record_class
)

async def fetchmany(self, query, args, *, timeout=None, record_class=None):
"""Run a query for each sequence of arguments in *args*
and return the results as a list of :class:`Record`.

Pool performs this operation using one of its connections. Other than
that, it behaves identically to
:meth:`Connection.fetchmany()
<asyncpg.connection.Connection.fetchmany>`.

.. versionadded:: 0.30.0
"""
async with self.acquire() as con:
return await con.fetchmany(
query, args, timeout=timeout, record_class=record_class
)

async def copy_from_table(
self,
table_name,
Expand Down
28 changes: 27 additions & 1 deletion asyncpg/prepared_stmt.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,27 @@ async def fetchrow(self, *args, timeout=None):
return None
return data[0]

@connresource.guarded
async def fetchmany(self, args, *, timeout=None):
"""Execute the statement and return a list of :class:`Record` objects.

:param args: Query arguments.
:param float timeout: Optional timeout value in seconds.

:return: A list of :class:`Record` instances.

.. versionadded:: 0.30.0
"""
return await self.__do_execute(
lambda protocol: protocol.bind_execute_many(
self._state,
args,
portal_name='',
timeout=timeout,
return_rows=True,
)
)

@connresource.guarded
async def executemany(self, args, *, timeout: float=None):
"""Execute the statement for each sequence of arguments in *args*.
Expand All @@ -222,7 +243,12 @@ async def executemany(self, args, *, timeout: float=None):
"""
return await self.__do_execute(
lambda protocol: protocol.bind_execute_many(
self._state, args, '', timeout))
self._state,
args,
portal_name='',
timeout=timeout,
return_rows=False,
))

async def __do_execute(self, executor):
protocol = self._connection._protocol
Expand Down
2 changes: 1 addition & 1 deletion asyncpg/protocol/coreproto.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ cdef class CoreProtocol:
cdef _bind_execute(self, str portal_name, str stmt_name,
WriteBuffer bind_data, int32_t limit)
cdef bint _bind_execute_many(self, str portal_name, str stmt_name,
object bind_data)
object bind_data, bint return_rows)
cdef bint _bind_execute_many_more(self, bint first=*)
cdef _bind_execute_many_fail(self, object error, bint first=*)
cdef _bind(self, str portal_name, str stmt_name,
Expand Down
6 changes: 3 additions & 3 deletions asyncpg/protocol/coreproto.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -1020,12 +1020,12 @@ cdef class CoreProtocol:
self._send_bind_message(portal_name, stmt_name, bind_data, limit)

cdef bint _bind_execute_many(self, str portal_name, str stmt_name,
object bind_data):
object bind_data, bint return_rows):
self._ensure_connected()
self._set_state(PROTOCOL_BIND_EXECUTE_MANY)

self.result = None
self._discard_data = True
self.result = [] if return_rows else None
self._discard_data = not return_rows
self._execute_iter = bind_data
self._execute_portal_name = portal_name
self._execute_stmt_name = stmt_name
Expand Down
4 changes: 3 additions & 1 deletion asyncpg/protocol/protocol.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,7 @@ cdef class BaseProtocol(CoreProtocol):
args,
portal_name: str,
timeout,
return_rows: bool,
):
if self.cancel_waiter is not None:
await self.cancel_waiter
Expand All @@ -237,7 +238,8 @@ cdef class BaseProtocol(CoreProtocol):
more = self._bind_execute_many(
portal_name,
state.name,
arg_bufs) # network op
arg_bufs,
return_rows) # network op

self.last_query = state.query
self.statement = state
Expand Down
39 changes: 39 additions & 0 deletions tests/test_execute.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,45 @@ async def test_executemany_basic(self):
('a', 1), ('b', 2), ('c', 3), ('d', 4)
])

async def test_executemany_returning(self):
result = await self.con.fetchmany('''
INSERT INTO exmany VALUES($1, $2) RETURNING a, b
''', [
('a', 1), ('b', 2), ('c', 3), ('d', 4)
])
self.assertEqual(result, [
('a', 1), ('b', 2), ('c', 3), ('d', 4)
])
result = await self.con.fetch('''
SELECT * FROM exmany
''')
self.assertEqual(result, [
('a', 1), ('b', 2), ('c', 3), ('d', 4)
])

# Empty set
await self.con.fetchmany('''
INSERT INTO exmany VALUES($1, $2) RETURNING a, b
''', ())
result = await self.con.fetch('''
SELECT * FROM exmany
''')
self.assertEqual(result, [
('a', 1), ('b', 2), ('c', 3), ('d', 4)
])

# Without "RETURNING"
result = await self.con.fetchmany('''
INSERT INTO exmany VALUES($1, $2)
''', [('e', 5), ('f', 6)])
self.assertEqual(result, [])
result = await self.con.fetch('''
SELECT * FROM exmany
''')
self.assertEqual(result, [
('a', 1), ('b', 2), ('c', 3), ('d', 4), ('e', 5), ('f', 6)
])

async def test_executemany_bad_input(self):
with self.assertRaisesRegex(
exceptions.DataError,
Expand Down
14 changes: 14 additions & 0 deletions tests/test_prepare.py
Original file line number Diff line number Diff line change
Expand Up @@ -611,3 +611,17 @@ async def test_prepare_explicitly_named(self):
'prepared statement "foobar" already exists',
):
await self.con.prepare('select 1', name='foobar')

async def test_prepare_fetchmany(self):
tr = self.con.transaction()
await tr.start()
try:
await self.con.execute('CREATE TABLE fetchmany (a int, b text)')

stmt = await self.con.prepare(
'INSERT INTO fetchmany (a, b) VALUES ($1, $2) RETURNING a, b'
)
result = await stmt.fetchmany([(1, 'a'), (2, 'b'), (3, 'c')])
self.assertEqual(result, [(1, 'a'), (2, 'b'), (3, 'c')])
finally:
await tr.rollback()
Loading