Skip to content

Commit

Permalink
Add support for tuple-format custom codecs on composite types
Browse files Browse the repository at this point in the history
It is now possible to `set_type_codec('mycomposite', ... format='tuple')`,
which is useful for types that are represented by a composite type in
Postgres, but are an integral type in Python, e.g. `complex`.

Fixes: #1060
  • Loading branch information
elprans committed Aug 15, 2023
1 parent 511aeb2 commit 0e72323
Show file tree
Hide file tree
Showing 8 changed files with 192 additions and 56 deletions.
34 changes: 29 additions & 5 deletions asyncpg/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -1154,6 +1154,9 @@ async def set_type_codec(self, typename, *,
| ``time with | (``microseconds``, |
| time zone`` | ``time zone offset in seconds``) |
+-----------------+---------------------------------------------+
| any composite | Composite value elements |
| type | |
+-----------------+---------------------------------------------+
:param encoder:
Callable accepting a Python object as a single argument and
Expand Down Expand Up @@ -1208,6 +1211,10 @@ async def set_type_codec(self, typename, *,
The ``binary`` keyword argument was removed in favor of
``format``.
.. versionchanged:: 0.29.0
Custom codecs for composite types are now supported with
``format='tuple'``.
.. note::
It is recommended to use the ``'binary'`` or ``'tuple'`` *format*
Expand All @@ -1218,11 +1225,28 @@ async def set_type_codec(self, typename, *,
codecs.
"""
self._check_open()
settings = self._protocol.get_settings()
typeinfo = await self._introspect_type(typename, schema)
if not introspection.is_scalar_type(typeinfo):
full_typeinfos = []
if introspection.is_scalar_type(typeinfo):
kind = 'scalar'
elif introspection.is_composite_type(typeinfo):
if format != 'tuple':
raise exceptions.UnsupportedClientFeatureError(
'only tuple-format codecs can be used on composite types',
hint="Use `set_type_codec(..., format='tuple')` and "
"pass/interpret data as a Python tuple. See an "
"example at https://magicstack.github.io/asyncpg/"
"current/usage.html#example-decoding-complex-types",
)
kind = 'composite'
full_typeinfos, _ = await self._introspect_types(
(typeinfo['oid'],), 10)
else:
raise exceptions.InterfaceError(
'cannot use custom codec on non-scalar type {}.{}'.format(
schema, typename))
f'cannot use custom codec on type {schema}.{typename}: '
f'it is neither a scalar type nor a composite type'
)
if introspection.is_domain_type(typeinfo):
raise exceptions.UnsupportedClientFeatureError(
'custom codecs on domain types are not supported',
Expand All @@ -1234,8 +1258,8 @@ async def set_type_codec(self, typename, *,
)

oid = typeinfo['oid']
self._protocol.get_settings().add_python_codec(
oid, typename, schema, 'scalar',
settings.add_python_codec(
oid, typename, schema, full_typeinfos, kind,
encoder, decoder, format)

# Statement cache is no longer valid due to codec changes.
Expand Down
4 changes: 4 additions & 0 deletions asyncpg/introspection.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,3 +286,7 @@ def is_scalar_type(typeinfo) -> bool:

def is_domain_type(typeinfo) -> bool:
return typeinfo['kind'] == b'd'


def is_composite_type(typeinfo) -> bool:
return typeinfo['kind'] == b'c'
3 changes: 3 additions & 0 deletions asyncpg/protocol/codecs/base.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ cdef class Codec:

encode_func c_encoder
decode_func c_decoder
Codec base_codec

object py_encoder
object py_decoder
Expand All @@ -79,6 +80,7 @@ cdef class Codec:
CodecType type, ServerDataFormat format,
ClientExchangeFormat xformat,
encode_func c_encoder, decode_func c_decoder,
Codec base_codec,
object py_encoder, object py_decoder,
Codec element_codec, tuple element_type_oids,
object element_names, list element_codecs,
Expand Down Expand Up @@ -169,6 +171,7 @@ cdef class Codec:
object decoder,
encode_func c_encoder,
decode_func c_decoder,
Codec base_codec,
ServerDataFormat format,
ClientExchangeFormat xformat)

Expand Down
90 changes: 62 additions & 28 deletions asyncpg/protocol/codecs/base.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,25 @@ cdef class Codec:
self.oid = oid
self.type = CODEC_UNDEFINED

cdef init(self, str name, str schema, str kind,
CodecType type, ServerDataFormat format,
ClientExchangeFormat xformat,
encode_func c_encoder, decode_func c_decoder,
object py_encoder, object py_decoder,
Codec element_codec, tuple element_type_oids,
object element_names, list element_codecs,
Py_UCS4 element_delimiter):
cdef init(
self,
str name,
str schema,
str kind,
CodecType type,
ServerDataFormat format,
ClientExchangeFormat xformat,
encode_func c_encoder,
decode_func c_decoder,
Codec base_codec,
object py_encoder,
object py_decoder,
Codec element_codec,
tuple element_type_oids,
object element_names,
list element_codecs,
Py_UCS4 element_delimiter,
):

self.name = name
self.schema = schema
Expand All @@ -40,6 +51,7 @@ cdef class Codec:
self.xformat = xformat
self.c_encoder = c_encoder
self.c_decoder = c_decoder
self.base_codec = base_codec
self.py_encoder = py_encoder
self.py_decoder = py_decoder
self.element_codec = element_codec
Expand All @@ -48,6 +60,12 @@ cdef class Codec:
self.element_delimiter = element_delimiter
self.element_names = element_names

if base_codec is not None:
if c_encoder != NULL or c_decoder != NULL:
raise exceptions.InternalClientError(
'base_codec is mutually exclusive with c_encoder/c_decoder'
)

if element_names is not None:
self.record_desc = record.ApgRecordDesc_New(
element_names, tuple(element_names))
Expand Down Expand Up @@ -98,7 +116,7 @@ cdef class Codec:
codec = Codec(self.oid)
codec.init(self.name, self.schema, self.kind,
self.type, self.format, self.xformat,
self.c_encoder, self.c_decoder,
self.c_encoder, self.c_decoder, self.base_codec,
self.py_encoder, self.py_decoder,
self.element_codec,
self.element_type_oids, self.element_names,
Expand Down Expand Up @@ -196,7 +214,10 @@ cdef class Codec:
raise exceptions.InternalClientError(
'unexpected data format: {}'.format(self.format))
elif self.xformat == PG_XFORMAT_TUPLE:
self.c_encoder(settings, buf, data)
if self.base_codec is not None:
self.base_codec.encode(settings, buf, data)
else:
self.c_encoder(settings, buf, data)
else:
raise exceptions.InternalClientError(
'unexpected exchange format: {}'.format(self.xformat))
Expand Down Expand Up @@ -295,7 +316,10 @@ cdef class Codec:
raise exceptions.InternalClientError(
'unexpected data format: {}'.format(self.format))
elif self.xformat == PG_XFORMAT_TUPLE:
data = self.c_decoder(settings, buf)
if self.base_codec is not None:
data = self.base_codec.decode(settings, buf)
else:
data = self.c_decoder(settings, buf)
else:
raise exceptions.InternalClientError(
'unexpected exchange format: {}'.format(self.xformat))
Expand Down Expand Up @@ -367,8 +391,8 @@ cdef class Codec:
cdef Codec codec
codec = Codec(oid)
codec.init(name, schema, 'array', CODEC_ARRAY, element_codec.format,
PG_XFORMAT_OBJECT, NULL, NULL, None, None, element_codec,
None, None, None, element_delimiter)
PG_XFORMAT_OBJECT, NULL, NULL, None, None, None,
element_codec, None, None, None, element_delimiter)
return codec

@staticmethod
Expand All @@ -379,8 +403,8 @@ cdef class Codec:
cdef Codec codec
codec = Codec(oid)
codec.init(name, schema, 'range', CODEC_RANGE, element_codec.format,
PG_XFORMAT_OBJECT, NULL, NULL, None, None, element_codec,
None, None, None, 0)
PG_XFORMAT_OBJECT, NULL, NULL, None, None, None,
element_codec, None, None, None, 0)
return codec

@staticmethod
Expand All @@ -391,7 +415,7 @@ cdef class Codec:
cdef Codec codec
codec = Codec(oid)
codec.init(name, schema, 'multirange', CODEC_MULTIRANGE,
element_codec.format, PG_XFORMAT_OBJECT, NULL, NULL,
element_codec.format, PG_XFORMAT_OBJECT, NULL, NULL, None,
None, None, element_codec, None, None, None, 0)
return codec

Expand All @@ -407,7 +431,7 @@ cdef class Codec:
codec = Codec(oid)
codec.init(name, schema, 'composite', CODEC_COMPOSITE,
format, PG_XFORMAT_OBJECT, NULL, NULL, None, None, None,
element_type_oids, element_names, element_codecs, 0)
None, element_type_oids, element_names, element_codecs, 0)
return codec

@staticmethod
Expand All @@ -419,12 +443,13 @@ cdef class Codec:
object decoder,
encode_func c_encoder,
decode_func c_decoder,
Codec base_codec,
ServerDataFormat format,
ClientExchangeFormat xformat):
cdef Codec codec
codec = Codec(oid)
codec.init(name, schema, kind, CODEC_PY, format, xformat,
c_encoder, c_decoder, encoder, decoder,
c_encoder, c_decoder, base_codec, encoder, decoder,
None, None, None, None, 0)
return codec

Expand Down Expand Up @@ -596,34 +621,43 @@ cdef class DataCodecConfig:
self.declare_fallback_codec(oid, name, schema)

def add_python_codec(self, typeoid, typename, typeschema, typekind,
encoder, decoder, format, xformat):
typeinfos, encoder, decoder, format, xformat):
cdef:
Codec core_codec
Codec core_codec = None
encode_func c_encoder = NULL
decode_func c_decoder = NULL
Codec base_codec = None
uint32_t oid = pylong_as_oid(typeoid)
bint codec_set = False

# Clear all previous overrides (this also clears type cache).
self.remove_python_codec(typeoid, typename, typeschema)

if typeinfos:
self.add_types(typeinfos)

if format == PG_FORMAT_ANY:
formats = (PG_FORMAT_TEXT, PG_FORMAT_BINARY)
else:
formats = (format,)

for fmt in formats:
if xformat == PG_XFORMAT_TUPLE:
core_codec = get_core_codec(oid, fmt, xformat)
if core_codec is None:
continue
c_encoder = core_codec.c_encoder
c_decoder = core_codec.c_decoder
if typekind == "scalar":
core_codec = get_core_codec(oid, fmt, xformat)
if core_codec is None:
continue
c_encoder = core_codec.c_encoder
c_decoder = core_codec.c_decoder
elif typekind == "composite":
base_codec = self.get_codec(oid, fmt)
if base_codec is None:
continue

self._custom_type_codecs[typeoid, fmt] = \
Codec.new_python_codec(oid, typename, typeschema, typekind,
encoder, decoder, c_encoder, c_decoder,
fmt, xformat)
base_codec, fmt, xformat)
codec_set = True

if not codec_set:
Expand Down Expand Up @@ -829,7 +863,7 @@ cdef register_core_codec(uint32_t oid,

codec = Codec(oid)
codec.init(name, 'pg_catalog', kind, CODEC_C, format, xformat,
encode, decode, None, None, None, None, None, None, 0)
encode, decode, None, None, None, None, None, None, None, 0)
cpython.Py_INCREF(codec) # immortalize

if format == PG_FORMAT_BINARY:
Expand All @@ -853,7 +887,7 @@ cdef register_extra_codec(str name,

codec = Codec(INVALIDOID)
codec.init(name, None, kind, CODEC_C, format, PG_XFORMAT_OBJECT,
encode, decode, None, None, None, None, None, None, 0)
encode, decode, None, None, None, None, None, None, None, 0)
EXTRA_CODECS[name, format] = codec


Expand Down
2 changes: 1 addition & 1 deletion asyncpg/protocol/settings.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ cdef class ConnectionSettings(pgproto.CodecContext):
cpdef get_text_codec(self)
cpdef inline register_data_types(self, types)
cpdef inline add_python_codec(
self, typeoid, typename, typeschema, typekind, encoder,
self, typeoid, typename, typeschema, typeinfos, typekind, encoder,
decoder, format)
cpdef inline remove_python_codec(
self, typeoid, typename, typeschema)
Expand Down
6 changes: 4 additions & 2 deletions asyncpg/protocol/settings.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,8 @@ cdef class ConnectionSettings(pgproto.CodecContext):
self._data_codecs.add_types(types)

cpdef inline add_python_codec(self, typeoid, typename, typeschema,
typekind, encoder, decoder, format):
typeinfos, typekind, encoder, decoder,
format):
cdef:
ServerDataFormat _format
ClientExchangeFormat xformat
Expand All @@ -57,7 +58,8 @@ cdef class ConnectionSettings(pgproto.CodecContext):
))

self._data_codecs.add_python_codec(typeoid, typename, typeschema,
typekind, encoder, decoder,
typekind, typeinfos,
encoder, decoder,
_format, xformat)

cpdef inline remove_python_codec(self, typeoid, typename, typeschema):
Expand Down
43 changes: 41 additions & 2 deletions docs/usage.rst
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,46 @@ JSON values using the :mod:`json <python:json>` module.
finally:
await conn.close()
asyncio.get_event_loop().run_until_complete(main())
asyncio.run(main())
Example: complex types
~~~~~~~~~~~~~~~~~~~~~~

The example below shows how to configure asyncpg to encode and decode
Python :class:`complex <python:complex>` values to a custom composite
type in PostgreSQL.

.. code-block:: python
import asyncio
import asyncpg
async def main():
conn = await asyncpg.connect()
try:
await conn.execute(
'''
CREATE TYPE mycomplex AS (
r float,
i float
);'''
)
await conn.set_type_codec(
'complex',
encoder=lambda x: (x.real, x.imag),
decoder=lambda t: complex(t[0], t[1]),
format='tuple',
)
res = await conn.fetchval('SELECT $1::complex', (1+2j))
finally:
await conn.close()
asyncio.run(main())
Example: automatic conversion of PostGIS types
Expand Down Expand Up @@ -274,7 +313,7 @@ will work.
finally:
await conn.close()
asyncio.get_event_loop().run_until_complete(main())
asyncio.run(main())
Example: decoding numeric columns as floats
Expand Down
Loading

0 comments on commit 0e72323

Please sign in to comment.