Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update search form domain filters #597

Merged
8 commits merged on Jul 29, 2024 (author and branch names were lost in extraction)
54 changes: 9 additions & 45 deletions home/forms/search.py
Original file line number Diff line number Diff line change
@@ -1,28 +1,22 @@
from copy import deepcopy
from urllib.parse import urlencode

from data_platform_catalogue.search_types import ResultType
from data_platform_catalogue.search_types import ListDomainOption, ResultType
from django import forms

from ..models.domain_model import Domain, DomainModel
from ..service.search_facet_fetcher import SearchFacetFetcher
from ..models.domain_model import Domain
from ..service.list_domain_fetcher import ListDomainFetcher
from ..service.search_tag_fetcher import SearchTagFetcher


def get_domain_choices() -> list[Domain]:
"""Make API call to obtain domain choices"""
def get_list_domain_choices() -> list[Domain]:
"""Make ListDomains API call to obtain domain choices"""
choices = [
Domain("", "All domains"),
]
facets = SearchFacetFetcher().fetch()
choices.extend(DomainModel(facets).top_level_domains)
return choices


def get_subdomain_choices() -> list[Domain]:
choices = [Domain("", "All subdomains")]
facets = SearchFacetFetcher().fetch()
choices.extend(DomainModel(facets).all_subdomains())
list_domain_options: list[ListDomainOption] = ListDomainFetcher().fetch()
domains: list[Domain] = [Domain(d.urn, d.name) for d in list_domain_options]
choices.extend(domains)
return choices


@@ -53,27 +47,6 @@ def get_tags():
return tags


class SelectWithOptionAttribute(forms.Select):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.domain_model = None

def create_option(
self, name, urn, label, selected, index, subindex=None, attrs=None
):
option = super().create_option(
name, urn, label, selected, index, subindex, attrs
)

facets = SearchFacetFetcher().fetch()
self.domain_model = self.domain_model or DomainModel(facets)

if urn:
option["attrs"]["data-parent"] = self.domain_model.get_parent_urn(urn)

return option


class SearchForm(forms.Form):
"""Django form to represent search page inputs"""

@@ -86,7 +59,7 @@ class SearchForm(forms.Form):
),
)
domain = forms.ChoiceField(
choices=get_domain_choices,
choices=get_list_domain_choices,
required=False,
widget=forms.Select(
attrs={
@@ -97,13 +70,6 @@ class SearchForm(forms.Form):
}
),
)
subdomain = forms.ChoiceField(
choices=get_subdomain_choices,
required=False,
widget=SelectWithOptionAttribute(
attrs={"form": "searchform", "class": "govuk-select"}
),
)
where_to_access = forms.MultipleChoiceField(
choices=get_where_to_access_choices,
required=False,
@@ -171,6 +137,4 @@ def encode_without_filter(self, filter_name: str, filter_value: str):
value.remove(filter_value)
elif isinstance(value, str) and filter_value == value:
query_params.pop(filter_name)
if filter_name == "domain":
query_params.pop("subdomain")
return f"?{urlencode(query_params, doseq=True)}"
34 changes: 4 additions & 30 deletions home/models/domain_model.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import logging
from typing import NamedTuple

from data_platform_catalogue.search_types import SearchFacets
from data_platform_catalogue.search_types import ListDomainOption

logger = logging.getLogger(__name__)

@@ -11,41 +11,15 @@ class Domain(NamedTuple):
label: str


class DomainModel:
"""
Store information about domains and subdomains
"""

def __init__(self, search_facets: SearchFacets):
class ListDomainModel:
def __init__(self, domains: list[ListDomainOption]):
self.labels = {}

self.top_level_domains = [
Domain(option.value, option.label)
for option in search_facets.options("domains")
]
self.top_level_domains.sort(key=lambda d: d.label)

self.top_level_domains = [Domain(domain.urn, domain.name) for domain in domains]
logger.info(f"{self.top_level_domains=}")

self.subdomains = {}

for urn, label in self.top_level_domains:
self.labels[urn] = label

def all_subdomains(self) -> list[Domain]: # -> list[Any]
"""
A flat list of all subdomains
"""
subdomains = []
for domain_choices in self.subdomains.values():
subdomains.extend(domain_choices)
return subdomains

def get_parent_urn(self, child_subdomain_urn) -> str | None:
for domain, subdomains in self.subdomains.items():
for subdomain in subdomains:
if child_subdomain_urn == subdomain.urn:
return domain

def get_label(self, urn):
return self.labels.get(urn, urn)
39 changes: 8 additions & 31 deletions home/service/search.py
Original file line number Diff line number Diff line change
@@ -3,6 +3,7 @@
from typing import Any

from data_platform_catalogue.search_types import (
ListDomainOption,
MultiSelectFilter,
ResultType,
SearchResponse,
@@ -13,35 +14,16 @@
from nltk.stem import PorterStemmer

from home.forms.search import SearchForm
from home.models.domain_model import DomainModel
from home.models.domain_model import ListDomainModel

from .base import GenericService
from .search_facet_fetcher import SearchFacetFetcher


def domains_with_their_subdomains(
domain: str, subdomain: str, domain_model: DomainModel
) -> list[str]:
"""
Users can search by domain, and optionally by subdomain.
When subdomain is passed, then we can filter on that directly.

However, when we filter by domain alone, assets tagged to subdomains
are not automatically included, so we need to include all possible
subdomains in the filter.
"""
if subdomain:
return [subdomain]

subdomains = domain_model.subdomains.get(domain, [])
subdomains = [subdomain[0] for subdomain in subdomains]
return [domain, *subdomains] if not domain == "" else []
from .list_domain_fetcher import ListDomainFetcher


class SearchService(GenericService):
def __init__(self, form: SearchForm, page: str, items_per_page: int = 20):
facets = SearchFacetFetcher().fetch()
self.domain_model = DomainModel(facets)
domains: list[ListDomainOption] = ListDomainFetcher().fetch()
self.domain_model = ListDomainModel(domains)
self.stemmer = PorterStemmer()
self.form = form
if self.form.is_bound:
@@ -79,18 +61,14 @@ def _get_search_results(self, page: str, items_per_page: int) -> SearchResponse:
query = form_data.get("query", "").replace("_", " ")
sort = form_data.get("sort", "relevance")
domain = form_data.get("domain", "")
subdomain = form_data.get("subdomain", "")
tags = form_data.get("tags", "")
domains_and_subdomains = domains_with_their_subdomains(
domain, subdomain, self.domain_model
)
where_to_access = self._build_custom_property_filter(
"dc_where_to_access_dataset=", form_data.get("where_to_access", [])
)
entity_types = self._build_entity_types(form_data.get("entity_types", []))
filter_value = []
if domains_and_subdomains:
filter_value.append(MultiSelectFilter("domains", domains_and_subdomains))
if domain:
filter_value.append(MultiSelectFilter("domains", [domain]))
if where_to_access:
filter_value.append(MultiSelectFilter("customProperties", where_to_access))
if tags:
@@ -167,9 +145,8 @@ def _generate_domain_clear_href(
self,
) -> dict[str, str]:
domain = self.form.cleaned_data.get("domain", "")
subdomain = self.form.cleaned_data.get("subdomain", "")

label = self.domain_model.get_label(subdomain or domain)
label = self.domain_model.get_label(domain)

return {
label: (
Expand Down
9 changes: 6 additions & 3 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -38,7 +38,7 @@
from faker import Faker

from home.forms.search import SearchForm
from home.models.domain_model import DomainModel
from home.models.domain_model import ListDomainModel
from home.service.details import DatabaseDetailsService
from home.service.list_domain_fetcher import ListDomainFetcher
from home.service.search import SearchService
@@ -452,8 +452,11 @@ def search_tags():


@pytest.fixture
def valid_domain(search_facets):
return DomainModel(search_facets).top_level_domains[0]
def valid_domain():
domains = ListDomainFetcher().fetch()
return ListDomainModel(
domains,
).top_level_domains[0]


@pytest.fixture
Expand Down
30 changes: 1 addition & 29 deletions tests/home/service/test_search.py
Original file line number Diff line number Diff line change
@@ -6,7 +6,7 @@
import pytest

from home.forms.search import SearchForm
from home.service.search import SearchService, domains_with_their_subdomains
from home.service.search import SearchService

dev_env = True if os.environ.get("ENV") == "dev" else False
run_for_dev = pytest.mark.skipif(
@@ -48,7 +48,6 @@ def test_get_context_remove_filter_hrefs(self, search_context, valid_domain):
"analytical_platform": (
"?query=test&"
f"domain={quote(valid_domain.urn)}&"
"subdomain=&"
"entity_types=TABLE&"
"sort=ascending&"
"clear_filter=False&"
@@ -61,7 +60,6 @@ def test_get_context_remove_filter_hrefs(self, search_context, valid_domain):
"Table": (
"?query=test&"
f"domain={quote(valid_domain.urn)}&"
"subdomain=&"
"where_to_access=analytical_platform&"
"sort=ascending&"
"clear_filter=False&"
@@ -74,7 +72,6 @@ def test_get_context_remove_filter_hrefs(self, search_context, valid_domain):
"tag-1": (
"?query=test&"
f"domain={quote(valid_domain.urn)}&"
"subdomain=&"
"where_to_access=analytical_platform&"
"entity_types=TABLE&"
"sort=ascending&"
@@ -157,28 +154,3 @@ def test_highlight_results_with_query(self, search_service):
search_service.results.page_results
!= search_service.highlighted_results.page_results
)


@run_for_dev
@pytest.mark.parametrize(
"domain, subdomain, expected_subdomains",
[
("does-not-exist", "", []),
(
"urn:li:domain:HMPPS",
"",
[
"urn:li:domain:HMPPS",
"urn:li:domain:2feb789b-44d3-4412-b998-1f26819fabf9",
"urn:li:domain:abe153c1-416b-4abb-be7f-6accf2abb10a",
],
),
(
"urn:li:domain:HMPPS",
"urn:li:domain:2feb789b-44d3-4412-b998-1f26819fabf9",
["urn:li:domain:2feb789b-44d3-4412-b998-1f26819fabf9"],
),
],
)
def test_domain_expansion(domain, subdomain, expected_subdomains):
assert domains_with_their_subdomains(domain, subdomain) == expected_subdomains
Loading