diff --git a/invenio_records_rest/config.py b/invenio_records_rest/config.py
index 1a00fee..fcb61d4 100644
--- a/invenio_records_rest/config.py
+++ b/invenio_records_rest/config.py
@@ -344,6 +344,13 @@ def deleted_pid_error_handler(error):
     }
 """
 
+RECORDS_REST_FACETS_POST_FILTERS_PROPAGATE = False
+"""Propagate the post_filters of a category to the other categories.
+
+If ``True``, the post_filters selected in one category are also applied as
+filters to the aggregations of all the other categories.
+"""
+
 RECORDS_REST_DEFAULT_CREATE_PERMISSION_FACTORY = deny_all
 """Default create permission factory: reject any request."""
 
diff --git a/invenio_records_rest/facets.py b/invenio_records_rest/facets.py
index 3102996..d8cbe22 100644
--- a/invenio_records_rest/facets.py
+++ b/invenio_records_rest/facets.py
@@ -22,21 +22,33 @@
 
 
 def nested_filter(field, subfield):
-    """Create a nested filter. Similar to the https://github.com/inveniosoftware/invenio-records-resources/blob/master/invenio_records_resources/services/records/facets/facets.py#L94"""
+    """Create a nested filter.
+
+    Similar to the example from
+    https://github.com/inveniosoftware/invenio-records-resources/blob/master/invenio_records_resources/services/records/facets/facets.py#L94
+    """
 
     def inner(values):
         top_level = []
         queries = []
         for value in values:
             subvalues = value.split("::")
-            if len(subvalues)>1:
-                queries.append(dsl.Q("bool", must=[dsl.Q("term", **{field: subvalues[0]}), dsl.Q("term", **{subfield: subvalues[1]})]))
+            if len(subvalues) > 1:
+                queries.append(
+                    dsl.Q(
+                        "bool",
+                        must=[
+                            dsl.Q("term", **{field: subvalues[0]}),
+                            dsl.Q("term", **{subfield: subvalues[1]}),
+                        ],
+                    )
+                )
             else:
                 top_level.append(value)
         if len(top_level):
             queries.append(dsl.Q("terms", **{field: top_level}))
 
-        if len(queries)>1:
+        if len(queries) > 1:
             return dsl.Q("bool", should=queries)
         return queries[0]
 
@@ -139,10 +151,68 @@ def _query_filter(search, urlkwargs, definitions):
     return (search, urlkwargs)
 
 
-def _aggregations(search, definitions):
-    """Add aggregations to query."""
+def remove_filter_from_list(facet_filters, facet_names):
+    """Return a copy of ``facet_filters`` without ``facet_names``.
+
+    A post_filter on a category has to be applied to every aggregation
+    except the one on that same category, so that category is dropped
+    from the filters before they are attached to its aggregation. See
+    the car example in ``default_facets_factory``.
+
+    :param facet_filters: Mapping of facet name to filter. It is left
+        untouched.
+    :param facet_names: Facet names to leave out; names that are not in
+        ``facet_filters`` are ignored.
+    :returns: A shallow copy of ``facet_filters`` without those names.
+    """
+    new_facet_filters = facet_filters.copy()
+    for name in facet_names:
+        new_facet_filters.pop(name, None)
+    return new_facet_filters
+
+
+def _aggregations(search, definitions, updated_filters=None, urlkwargs=None):
+    """Add aggregations to query.
+
+    :param search: Search object whose ``aggs`` are populated.
+    :param definitions: Mapping of aggregation name to its definition
+        or to a callable returning that definition.
+    :param updated_filters: Mapping of facet name to post_filter. Every
+        aggregation named in it is restricted by the post_filters of
+        the other facets. ``None`` means no propagation.
+    :param urlkwargs: URL arguments passed on to ``_create_filter_dsl``.
+    :returns: The search object.
+    """
+    # ``None`` rather than a ``{}`` default, which would be one dict
+    # shared by all the calls.
+    if updated_filters is None:
+        updated_filters = {}
     if definitions:
         for name, agg in definitions.items():
+            if name in updated_filters:
+                # Wrap the aggregation in a filter aggregation holding the
+                # post_filters of all the other facets. With the car example
+                # of ``default_facets_factory`` and brand=ferrari selected,
+                # the "color" aggregation goes from something like
+                #   {"terms": {"field": "color"}}
+                # to something like
+                #   {"filter": {"bool": {"must": [<brand is ferrari>]}},
+                #    "aggs": {"filtered": {"terms": {"field": "color"}}}}
+                facet_filters, _ = _create_filter_dsl(
+                    urlkwargs, remove_filter_from_list(updated_filters, [name])
+                )
+                agg = {
+                    "filter": {
+                        "bool": {
+                            "must": [
+                                facet_filter.to_dict() for facet_filter in facet_filters
+                            ]
+                        }
+                    },
+                    # Resolve a lazy definition here: once it is nested in
+                    # this dict, the callable check below cannot see it.
+                    "aggs": {"filtered": agg() if callable(agg) else agg},
+                }
             search.aggs[name] = agg if not callable(agg) else agg()
     return search
 
@@ -168,17 +238,38 @@ def default_facets_factory(search, index):
     selected_facets = make_comma_list_a_list(request.args.getlist("facets", None))
     all_aggs = facets.get("aggs", {})
 
+    # RECORDS_REST_FACETS_POST_FILTERS_PROPAGATE is easier to explain with an
+    # example. Imagine a website that sells cars, which can be filtered by
+    # two categories: brand (bmw, mercedes, ferrari, ...) and color (white,
+    # blue, red, ...). A user looks for ferrari:
+    # * If the query is done as a standard filter, it also affects the
+    #   aggregations, which return only brand: ferrari and color: red.
+    # * If the query is done as a post_filter and the option is False, the
+    #   restriction is not applied to any category, so the aggregations
+    #   return brand: bmw, mercedes, ferrari, ... and color: white, blue,
+    #   red, ...
+    # * If the query is done as a post_filter and the option is True, the
+    #   restriction is applied to the other categories only, so the
+    #   aggregations return brand: bmw, mercedes, ferrari, ... (the filter
+    #   on ferrari is not applied to the brand itself) and color: red (all
+    #   the ferraris are red).
+    if current_app.config["RECORDS_REST_FACETS_POST_FILTERS_PROPAGATE"]:
+        updated_filters = facets.get("post_filters", {})
+    else:
+        updated_filters = {}
+
     # If no facets were requested, assume default behaviour - Take all.
-    if not selected_facets:
-        search = _aggregations(search, all_aggs)
-    # otherwise, check if there are facets to chose
-    elif selected_facets and all_aggs:
-        aggs = {}
-        # Go through all available facets and check if they were requested.
-        for facet_name, facet_body in all_aggs.items():
-            if facet_name in selected_facets:
-                aggs.update({facet_name: facet_body})
-        search = _aggregations(search, aggs)
+    if all_aggs:
+        if not selected_facets:
+            search = _aggregations(search, all_aggs, updated_filters, urlkwargs)
+        else:
+            # Otherwise keep only the available facets that were requested.
+            aggs = {
+                facet_name: facet_body
+                for facet_name, facet_body in all_aggs.items()
+                if facet_name in selected_facets
+            }
+            search = _aggregations(search, aggs, updated_filters, urlkwargs)
 
     # Query filter
     search, urlkwargs = _query_filter(search, urlkwargs, facets.get("filters", {}))