Skip to content

Commit

Permalink
facet: Allow more than one possibility on range facets
Browse files Browse the repository at this point in the history
  • Loading branch information
psaiz committed Dec 7, 2023
1 parent 945a870 commit a34dc96
Showing 1 changed file with 52 additions and 38 deletions.
90 changes: 52 additions & 38 deletions invenio_records_rest/facets.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,14 +29,22 @@ def inner(values):
queries = []
for value in values:
subvalues = value.split("::")
if len(subvalues)>1:
queries.append(dsl.Q("bool", must=[dsl.Q("term", **{field: subvalues[0]}), dsl.Q("term", **{subfield: subvalues[1]})]))
if len(subvalues) > 1:
queries.append(
dsl.Q(
"bool",
must=[
dsl.Q("term", **{field: subvalues[0]}),
dsl.Q("term", **{subfield: subvalues[1]}),
],
)
)
else:
top_level.append(value)
if len(top_level):
queries.append(dsl.Q("terms", **{field: top_level}))

if len(queries)>1:
if len(queries) > 1:
return dsl.Q("bool", should=queries)
return queries[0]

Expand Down Expand Up @@ -67,41 +75,47 @@ def range_filter(field, start_date_math=None, end_date_math=None, **kwargs):
"""

def inner(values):
if len(values) != 1 or values[0].count("--") != 1 or values[0] == "--":
raise RESTValidationError(
errors=[FieldError(field, "Invalid range format.")]
)

range_ends = values[0].split("--")
range_args = dict()

ineq_opers = [
{"strict": "gt", "nonstrict": "gte"},
{"strict": "lt", "nonstrict": "lte"},
]
date_maths = [start_date_math, end_date_math]

# Add the proper values to the dict
for range_end, strict, opers, date_math in zip(
range_ends, [">", "<"], ineq_opers, date_maths
):
if range_end != "":
# If first char is '>' for start or '<' for end
if range_end[0] == strict:
dict_key = opers["strict"]
range_end = range_end[1:]
else:
dict_key = opers["nonstrict"]

if date_math:
range_end = "{0}||{1}".format(range_end, date_math)

range_args[dict_key] = range_end

args = kwargs.copy()
args.update(range_args)

return dsl.query.Range(**{field: args})
queries = []
for value in values:
if value.count("--") != 1 or value == "--":
raise RESTValidationError(
errors=[FieldError(field, "Invalid range format.")]
)

range_ends = value.split("--")
range_args = dict()

ineq_opers = [
{"strict": "gt", "nonstrict": "gte"},
{"strict": "lt", "nonstrict": "lte"},
]
date_maths = [start_date_math, end_date_math]

# Add the proper values to the dict
for range_end, strict, opers, date_math in zip(
range_ends, [">", "<"], ineq_opers, date_maths
):
if range_end != "":
# If first char is '>' for start or '<' for end
if range_end[0] == strict:
dict_key = opers["strict"]
range_end = range_end[1:]
else:
dict_key = opers["nonstrict"]

if date_math:
range_end = "{0}||{1}".format(range_end, date_math)

range_args[dict_key] = range_end

args = kwargs.copy()
args.update(range_args)

queries.append(dsl.query.Range(**{field: args}))

if len(queries) > 1:
return dsl.Q("bool", should=queries)
return queries[0]

return inner

Expand Down

0 comments on commit a34dc96

Please sign in to comment.