diff --git a/invenio_records_rest/facets.py b/invenio_records_rest/facets.py index d8cbe22..c0e6667 100644 --- a/invenio_records_rest/facets.py +++ b/invenio_records_rest/facets.py @@ -79,41 +79,47 @@ def range_filter(field, start_date_math=None, end_date_math=None, **kwargs): """ def inner(values): - if len(values) != 1 or values[0].count("--") != 1 or values[0] == "--": - raise RESTValidationError( - errors=[FieldError(field, "Invalid range format.")] - ) - - range_ends = values[0].split("--") - range_args = dict() - - ineq_opers = [ - {"strict": "gt", "nonstrict": "gte"}, - {"strict": "lt", "nonstrict": "lte"}, - ] - date_maths = [start_date_math, end_date_math] - - # Add the proper values to the dict - for range_end, strict, opers, date_math in zip( - range_ends, [">", "<"], ineq_opers, date_maths - ): - if range_end != "": - # If first char is '>' for start or '<' for end - if range_end[0] == strict: - dict_key = opers["strict"] - range_end = range_end[1:] - else: - dict_key = opers["nonstrict"] - - if date_math: - range_end = "{0}||{1}".format(range_end, date_math) - - range_args[dict_key] = range_end - - args = kwargs.copy() - args.update(range_args) - - return dsl.query.Range(**{field: args}) + queries = [] + for value in values: + if value.count("--") != 1 or value == "--": + raise RESTValidationError( + errors=[FieldError(field, "Invalid range format.")] + ) + + range_ends = value.split("--") + range_args = dict() + + ineq_opers = [ + {"strict": "gt", "nonstrict": "gte"}, + {"strict": "lt", "nonstrict": "lte"}, + ] + date_maths = [start_date_math, end_date_math] + + # Add the proper values to the dict + for range_end, strict, opers, date_math in zip( + range_ends, [">", "<"], ineq_opers, date_maths + ): + if range_end != "": + # If first char is '>' for start or '<' for end + if range_end[0] == strict: + dict_key = opers["strict"] + range_end = range_end[1:] + else: + dict_key = opers["nonstrict"] + + if date_math: + range_end = "{0}||{1}".format(range_end, date_math) + + range_args[dict_key] = range_end + + args = kwargs.copy() + args.update(range_args) + + queries.append(dsl.query.Range(**{field: args})) + + if len(queries) > 1: + return dsl.Q("bool", should=queries) + return queries[0] return inner