Skip to content

Commit

Permalink
Add support for more range and multirange types
Browse files Browse the repository at this point in the history
  • Loading branch information
tlocke committed Jan 3, 2024
1 parent 796cd01 commit 8adb556
Showing 1 changed file with 95 additions and 59 deletions.
154 changes: 95 additions & 59 deletions pg8000/converters.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,16 +40,22 @@
DATE = 1082
DATE_ARRAY = 1182
DATEMULTIRANGE = 4535
DATEMULTIRANGE_ARRAY = 6155
DATERANGE = 3912
DATERANGE_ARRAY = 3913
FLOAT = 701
FLOAT_ARRAY = 1022
INET = 869
INET_ARRAY = 1041
INT2VECTOR = 22
INT4MULTIRANGE = 4451
INT4MULTIRANGE_ARRAY = 6150
INT4RANGE = 3904
INT4RANGE_ARRAY = 3905
INT8MULTIRANGE = 4536
INT8MULTIRANGE_ARRAY = 6157
INT8RANGE = 3926
INT8RANGE_ARRAY = 3927
INTEGER = 23
INTEGER_ARRAY = 1007
INTERVAL = 1186
Expand All @@ -67,7 +73,9 @@
NUMERIC = 1700
NUMERIC_ARRAY = 1231
NUMRANGE = 3906
NUMRANGE_ARRAY = 3907
NUMMULTIRANGE = 4532
NUMMULTIRANGE_ARRAY = 6151
NULLTYPE = -1
OID = 26
POINT = 600
Expand All @@ -87,9 +95,13 @@
TIMESTAMPTZ = 1184
TIMESTAMPTZ_ARRAY = 1185
TSMULTIRANGE = 4533
TSMULTIRANGE_ARRAY = 6152
TSRANGE = 3908
TSRANGE_ARRAY = 3909
TSTZMULTIRANGE = 4534
TSTZMULTIRANGE_ARRAY = 6153
TSTZRANGE = 3910
TSTZRANGE_ARRAY = 3911
UNKNOWN = 705
UUID_TYPE = 2950
UUID_ARRAY = 2951
Expand Down Expand Up @@ -291,6 +303,65 @@ def uuid_in(data):
return UUID(data)


def _range_in(elem_func):
def range_in(data):
if data == "empty":
return Range(is_empty=True)
else:
le, ue = [None if v == "" else elem_func(v) for v in data[1:-1].split(",")]
return Range(le, ue, bounds=f"{data[0]}{data[-1]}")

return range_in


daterange_in = _range_in(date_in)
int4range_in = _range_in(int)
int8range_in = _range_in(int)
numrange_in = _range_in(Decimal)


def ts_in(data):
return timestamp_in(data[1:-1])


def tstz_in(data):
return timestamptz_in(data[1:-1])


tsrange_in = _range_in(ts_in)
tstzrange_in = _range_in(tstz_in)


def _multirange_in(adapter):
def f(data):
in_range = False
result = []
val = []
for c in data:
if in_range:
val.append(c)
if c in "])":
value = "".join(val)
val.clear()
result.append(adapter(value))
in_range = False
elif c in "[(":
val.append(c)
in_range = True

return result

return f


datemultirange_in = _multirange_in(daterange_in)
int4multirange_in = _multirange_in(int4range_in)
int8multirange_in = _multirange_in(int8range_in)
nummultirange_in = _multirange_in(numrange_in)
tsmultirange_in = _multirange_in(tsrange_in)
tstzmultirange_in = _multirange_in(tstzrange_in)


class ParserState(Enum):
InString = 1
InEscape = 2
Expand Down Expand Up @@ -353,16 +424,28 @@ def f(data):
bytes_array_in = _array_in(bytes_in)
cidr_array_in = _array_in(cidr_in)
date_array_in = _array_in(date_in)
datemultirange_array_in = _array_in(datemultirange_in)
daterange_array_in = _array_in(daterange_in)
inet_array_in = _array_in(inet_in)
int_array_in = _array_in(int)
int4multirange_array_in = _array_in(int4multirange_in)
int4range_array_in = _array_in(int4range_in)
int8multirange_array_in = _array_in(int8multirange_in)
int8range_array_in = _array_in(int8range_in)
interval_array_in = _array_in(interval_in)
json_array_in = _array_in(json_in)
float_array_in = _array_in(float)
numeric_array_in = _array_in(numeric_in)
nummultirange_array_in = _array_in(nummultirange_in)
numrange_array_in = _array_in(numrange_in)
string_array_in = _array_in(string_in)
time_array_in = _array_in(time_in)
timestamp_array_in = _array_in(timestamp_in)
timestamptz_array_in = _array_in(timestamptz_in)
tsrange_array_in = _array_in(tsrange_in)
tsmultirange_array_in = _array_in(tsmultirange_in)
tstzmultirange_array_in = _array_in(tstzmultirange_in)
tstzrange_array_in = _array_in(tstzrange_in)
uuid_array_in = _array_in(uuid_in)


Expand Down Expand Up @@ -443,65 +526,6 @@ def composite_out(ar):
return f'({",".join(result)})'


def _range_in(elem_func):
def range_in(data):
if data == "empty":
return Range(is_empty=True)
else:
le, ue = [None if v == "" else elem_func(v) for v in data[1:-1].split(",")]
return Range(le, ue, bounds=f"{data[0]}{data[-1]}")

return range_in


daterange_in = _range_in(date_in)
int4range_in = _range_in(int)
int8range_in = _range_in(int)
numrange_in = _range_in(Decimal)


def ts_in(data):
return timestamp_in(data[1:-1])


def tstz_in(data):
return timestamptz_in(data[1:-1])


tsrange_in = _range_in(ts_in)
tstzrange_in = _range_in(tstz_in)


def _multirange_in(adapter):
def f(data):
in_range = False
result = []
val = []
for c in data:
if in_range:
val.append(c)
if c in "])":
value = "".join(val)
val.clear()
result.append(adapter(value))
in_range = False
elif c in "[(":
val.append(c)
in_range = True

return result

return f


datemultirange_in = _multirange_in(daterange_in)
int4multirange_in = _multirange_in(int4range_in)
int8multirange_in = _multirange_in(int8range_in)
nummultirange_in = _multirange_in(numrange_in)
tsmultirange_in = _multirange_in(tsrange_in)
tstzmultirange_in = _multirange_in(tstzrange_in)


def record_in(data):
state = ParserState.Out
results = []
Expand Down Expand Up @@ -605,15 +629,21 @@ def record_in(data):
DATE: date_in, # date
DATE_ARRAY: date_array_in, # date[]
DATEMULTIRANGE: datemultirange_in, # datemultirange
DATEMULTIRANGE_ARRAY: datemultirange_array_in, # datemultirange[]
DATERANGE: daterange_in, # daterange
DATERANGE_ARRAY: daterange_array_in, # daterange[]
FLOAT: float, # float8
FLOAT_ARRAY: float_array_in, # float8[]
INET: inet_in, # inet
INET_ARRAY: inet_array_in, # inet[]
INT4MULTIRANGE: int4multirange_in, # int4multirange
INT4MULTIRANGE_ARRAY: int4multirange_array_in, # int4multirange[]
INT4RANGE: int4range_in, # int4range
INT4RANGE_ARRAY: int4range_array_in, # int4range[]
INT8MULTIRANGE: int8multirange_in, # int8multirange
INT8MULTIRANGE_ARRAY: int8multirange_array_in, # int8multirange[]
INT8RANGE: int8range_in, # int8range
INT8RANGE_ARRAY: int8range_array_in, # int8range[]
INTEGER: int, # int4
INTEGER_ARRAY: int_array_in, # int4[]
JSON: json_in, # json
Expand All @@ -628,7 +658,9 @@ def record_in(data):
NUMERIC: numeric_in, # numeric
NUMERIC_ARRAY: numeric_array_in, # numeric[]
NUMRANGE: numrange_in, # numrange
NUMRANGE_ARRAY: numrange_array_in, # numrange[]
NUMMULTIRANGE: nummultirange_in, # nummultirange
NUMMULTIRANGE_ARRAY: nummultirange_array_in, # nummultirange[]
OID: int, # oid
POINT: point_in, # point
INTERVAL: interval_in, # interval
Expand All @@ -649,9 +681,13 @@ def record_in(data):
TIMESTAMPTZ: timestamptz_in, # timestamptz
TIMESTAMPTZ_ARRAY: timestamptz_array_in, # timestamptz
TSMULTIRANGE: tsmultirange_in, # tsmultirange
TSMULTIRANGE_ARRAY: tsmultirange_array_in, # tsmultirange[]
TSRANGE: tsrange_in, # tsrange
TSRANGE_ARRAY: tsrange_array_in, # tsrange[]
TSTZMULTIRANGE: tstzmultirange_in, # tstzmultirange
TSTZMULTIRANGE_ARRAY: tstzmultirange_array_in, # tstzmultirange[]
TSTZRANGE: tstzrange_in, # tstzrange
TSTZRANGE_ARRAY: tstzrange_array_in, # tstzrange[]
UNKNOWN: string_in, # unknown
UUID_ARRAY: uuid_array_in, # uuid[]
UUID_TYPE: uuid_in, # uuid
Expand Down

0 comments on commit 8adb556

Please sign in to comment.