From a34dc966cf28bbb6dbcf7daee55f63f4bf4d2774 Mon Sep 17 00:00:00 2001 From: Pablo Saiz Date: Thu, 7 Dec 2023 18:25:47 +0100 Subject: [PATCH] facet: Allow more than one possibility on range facets --- invenio_records_rest/facets.py | 90 ++++++++++++++++++++-------------- 1 file changed, 52 insertions(+), 38 deletions(-) diff --git a/invenio_records_rest/facets.py b/invenio_records_rest/facets.py index 3102996..fbd775a 100644 --- a/invenio_records_rest/facets.py +++ b/invenio_records_rest/facets.py @@ -29,14 +29,22 @@ def inner(values): queries = [] for value in values: subvalues = value.split("::") - if len(subvalues)>1: - queries.append(dsl.Q("bool", must=[dsl.Q("term", **{field: subvalues[0]}), dsl.Q("term", **{subfield: subvalues[1]})])) + if len(subvalues) > 1: + queries.append( + dsl.Q( + "bool", + must=[ + dsl.Q("term", **{field: subvalues[0]}), + dsl.Q("term", **{subfield: subvalues[1]}), + ], + ) + ) else: top_level.append(value) if len(top_level): queries.append(dsl.Q("terms", **{field: top_level})) - if len(queries)>1: + if len(queries) > 1: return dsl.Q("bool", should=queries) return queries[0] @@ -67,41 +75,47 @@ def range_filter(field, start_date_math=None, end_date_math=None, **kwargs): """ def inner(values): - if len(values) != 1 or values[0].count("--") != 1 or values[0] == "--": - raise RESTValidationError( - errors=[FieldError(field, "Invalid range format.")] - ) - - range_ends = values[0].split("--") - range_args = dict() - - ineq_opers = [ - {"strict": "gt", "nonstrict": "gte"}, - {"strict": "lt", "nonstrict": "lte"}, - ] - date_maths = [start_date_math, end_date_math] - - # Add the proper values to the dict - for range_end, strict, opers, date_math in zip( - range_ends, [">", "<"], ineq_opers, date_maths - ): - if range_end != "": - # If first char is '>' for start or '<' for end - if range_end[0] == strict: - dict_key = opers["strict"] - range_end = range_end[1:] - else: - dict_key = opers["nonstrict"] - - if date_math: - range_end = "{0}||{1}".format(range_end, date_math) - - range_args[dict_key] = range_end - - args = kwargs.copy() - args.update(range_args) - - return dsl.query.Range(**{field: args}) + queries = [] + for value in values: + if value.count("--") != 1 or value == "--": + raise RESTValidationError( + errors=[FieldError(field, "Invalid range format.")] + ) + + range_ends = value.split("--") + range_args = dict() + + ineq_opers = [ + {"strict": "gt", "nonstrict": "gte"}, + {"strict": "lt", "nonstrict": "lte"}, + ] + date_maths = [start_date_math, end_date_math] + + # Add the proper values to the dict + for range_end, strict, opers, date_math in zip( + range_ends, [">", "<"], ineq_opers, date_maths + ): + if range_end != "": + # If first char is '>' for start or '<' for end + if range_end[0] == strict: + dict_key = opers["strict"] + range_end = range_end[1:] + else: + dict_key = opers["nonstrict"] + + if date_math: + range_end = "{0}||{1}".format(range_end, date_math) + + range_args[dict_key] = range_end + + args = kwargs.copy() + args.update(range_args) + + queries.append(dsl.query.Range(**{field: args})) + + if len(queries) > 1: + return dsl.Q("bool", should=queries) + return queries[0] return inner