Skip to content

Commit

Permalink
facet: Allow more than one possibility on range facets
Browse files Browse the repository at this point in the history
  • Loading branch information
psaiz committed Dec 8, 2023
1 parent e9bcb74 commit ace51d5
Showing 1 changed file with 41 additions and 35 deletions.
76 changes: 41 additions & 35 deletions invenio_records_rest/facets.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,41 +79,47 @@ def range_filter(field, start_date_math=None, end_date_math=None, **kwargs):
"""

def inner(values):
if len(values) != 1 or values[0].count("--") != 1 or values[0] == "--":
raise RESTValidationError(
errors=[FieldError(field, "Invalid range format.")]
)

range_ends = values[0].split("--")
range_args = dict()

ineq_opers = [
{"strict": "gt", "nonstrict": "gte"},
{"strict": "lt", "nonstrict": "lte"},
]
date_maths = [start_date_math, end_date_math]

# Add the proper values to the dict
for range_end, strict, opers, date_math in zip(
range_ends, [">", "<"], ineq_opers, date_maths
):
if range_end != "":
# If first char is '>' for start or '<' for end
if range_end[0] == strict:
dict_key = opers["strict"]
range_end = range_end[1:]
else:
dict_key = opers["nonstrict"]

if date_math:
range_end = "{0}||{1}".format(range_end, date_math)

range_args[dict_key] = range_end

args = kwargs.copy()
args.update(range_args)

return dsl.query.Range(**{field: args})
queries = []
for value in values:
if value.count("--") != 1 or value == "--":
raise RESTValidationError(
errors=[FieldError(field, "Invalid range format.")]
)

range_ends = value.split("--")
range_args = dict()

ineq_opers = [
{"strict": "gt", "nonstrict": "gte"},
{"strict": "lt", "nonstrict": "lte"},
]
date_maths = [start_date_math, end_date_math]

# Add the proper values to the dict
for range_end, strict, opers, date_math in zip(
range_ends, [">", "<"], ineq_opers, date_maths
):
if range_end != "":
# If first char is '>' for start or '<' for end
if range_end[0] == strict:
dict_key = opers["strict"]
range_end = range_end[1:]
else:
dict_key = opers["nonstrict"]

if date_math:
range_end = "{0}||{1}".format(range_end, date_math)

range_args[dict_key] = range_end

args = kwargs.copy()
args.update(range_args)

queries.append(dsl.query.Range(**{field: args}))

if len(queries) > 1:
return dsl.Q("bool", should=queries)
return queries[0]

return inner

Expand Down

0 comments on commit ace51d5

Please sign in to comment.