diff --git a/home/forms/search.py b/home/forms/search.py index 70449642..90a8e3b8 100644 --- a/home/forms/search.py +++ b/home/forms/search.py @@ -1,28 +1,22 @@ from copy import deepcopy from urllib.parse import urlencode -from data_platform_catalogue.search_types import ResultType +from data_platform_catalogue.search_types import DomainOption, ResultType from django import forms -from ..models.domain_model import Domain, DomainModel -from ..service.search_facet_fetcher import SearchFacetFetcher +from ..models.domain_model import Domain +from ..service.domain_fetcher import DomainFetcher from ..service.search_tag_fetcher import SearchTagFetcher def get_domain_choices() -> list[Domain]: - """Make API call to obtain domain choices""" + """Make Domains API call to obtain domain choices""" choices = [ Domain("", "All domains"), ] - facets = SearchFacetFetcher().fetch() - choices.extend(DomainModel(facets).top_level_domains) - return choices - - -def get_subdomain_choices() -> list[Domain]: - choices = [Domain("", "All subdomains")] - facets = SearchFacetFetcher().fetch() - choices.extend(DomainModel(facets).all_subdomains()) + list_domain_options: list[DomainOption] = DomainFetcher().fetch() + domains: list[Domain] = [Domain(d.urn, d.name) for d in list_domain_options] + choices.extend(domains) return choices @@ -53,27 +47,6 @@ def get_tags(): return tags -class SelectWithOptionAttribute(forms.Select): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.domain_model = None - - def create_option( - self, name, urn, label, selected, index, subindex=None, attrs=None - ): - option = super().create_option( - name, urn, label, selected, index, subindex, attrs - ) - - facets = SearchFacetFetcher().fetch() - self.domain_model = self.domain_model or DomainModel(facets) - - if urn: - option["attrs"]["data-parent"] = self.domain_model.get_parent_urn(urn) - - return option - - class SearchForm(forms.Form): """Django form to represent search page inputs""" @@ -97,13 +70,6 @@ class SearchForm(forms.Form): } ), ) - subdomain = forms.ChoiceField( - choices=get_subdomain_choices, - required=False, - widget=SelectWithOptionAttribute( - attrs={"form": "searchform", "class": "govuk-select"} - ), - ) where_to_access = forms.MultipleChoiceField( choices=get_where_to_access_choices, required=False, @@ -171,6 +137,4 @@ def encode_without_filter(self, filter_name: str, filter_value: str): value.remove(filter_value) elif isinstance(value, str) and filter_value == value: query_params.pop(filter_name) - if filter_name == "domain": - query_params.pop("subdomain") return f"?{urlencode(query_params, doseq=True)}" diff --git a/home/models/domain_model.py b/home/models/domain_model.py index c0b96dbd..b06cdb3f 100644 --- a/home/models/domain_model.py +++ b/home/models/domain_model.py @@ -1,7 +1,7 @@ import logging from typing import NamedTuple -from data_platform_catalogue.search_types import SearchFacets +from data_platform_catalogue.search_types import DomainOption logger = logging.getLogger(__name__) @@ -12,40 +12,14 @@ class Domain(NamedTuple): class DomainModel: - """ - Store information about domains and subdomains - """ - - def __init__(self, search_facets: SearchFacets): + def __init__(self, domains: list[DomainOption]): self.labels = {} - self.top_level_domains = [ - Domain(option.value, option.label) - for option in search_facets.options("domains") - ] - self.top_level_domains.sort(key=lambda d: d.label) - + self.top_level_domains = [Domain(domain.urn, domain.name) for domain in domains] logger.info(f"{self.top_level_domains=}") - self.subdomains = {} - for urn, label in self.top_level_domains: self.labels[urn] = label - def all_subdomains(self) -> list[Domain]: # -> list[Any] - """ - A flat list of all subdomains - """ - subdomains = [] - for domain_choices in self.subdomains.values(): - subdomains.extend(domain_choices) - return subdomains - - def get_parent_urn(self, child_subdomain_urn) -> str | None: - for domain, subdomains in self.subdomains.items(): - for subdomain in subdomains: - if child_subdomain_urn == subdomain.urn: - return domain - def get_label(self, urn): return self.labels.get(urn, urn) diff --git a/home/service/list_domain_fetcher.py b/home/service/domain_fetcher.py similarity index 80% rename from home/service/list_domain_fetcher.py rename to home/service/domain_fetcher.py index 10d4a6e9..f15e3e3c 100644 --- a/home/service/list_domain_fetcher.py +++ b/home/service/domain_fetcher.py @@ -1,12 +1,12 @@ -from data_platform_catalogue.search_types import ListDomainOption +from data_platform_catalogue.search_types import DomainOption from django.core.cache import cache from .base import GenericService -class ListDomainFetcher(GenericService): +class DomainFetcher(GenericService): """ - ListDomainFetcher implementation to fetch domains with the total number of + DomainFetcher implementation to fetch domains with the total number of associated entities from the backend. """ @@ -16,7 +16,7 @@ def __init__(self, filter_zero_entities: bool = True): self.cache_timeout_seconds = 300 self.filter_zero_entities = filter_zero_entities - def fetch(self) -> list[ListDomainOption]: + def fetch(self) -> list[DomainOption]: """ Fetch a static list of options that is independent of the search query and any applied filters. Values are cached for 5 seconds to avoid diff --git a/home/service/search.py b/home/service/search.py index ef197517..fc69df93 100644 --- a/home/service/search.py +++ b/home/service/search.py @@ -3,6 +3,7 @@ from typing import Any from data_platform_catalogue.search_types import ( + DomainOption, MultiSelectFilter, ResultType, SearchResponse, @@ -16,32 +17,13 @@ from home.models.domain_model import DomainModel from .base import GenericService -from .search_facet_fetcher import SearchFacetFetcher - - -def domains_with_their_subdomains( - domain: str, subdomain: str, domain_model: DomainModel -) -> list[str]: - """ - Users can search by domain, and optionally by subdomain. - When subdomain is passed, then we can filter on that directly. - - However, when we filter by domain alone, assets tagged to subdomains - are not automatically included, so we need to include all possible - subdomains in the filter. - """ - if subdomain: - return [subdomain] - - subdomains = domain_model.subdomains.get(domain, []) - subdomains = [subdomain[0] for subdomain in subdomains] - return [domain, *subdomains] if not domain == "" else [] +from .domain_fetcher import DomainFetcher class SearchService(GenericService): def __init__(self, form: SearchForm, page: str, items_per_page: int = 20): - facets = SearchFacetFetcher().fetch() - self.domain_model = DomainModel(facets) + domains: list[DomainOption] = DomainFetcher().fetch() + self.domain_model = DomainModel(domains) self.stemmer = PorterStemmer() self.form = form if self.form.is_bound: @@ -79,18 +61,14 @@ def _get_search_results(self, page: str, items_per_page: int) -> SearchResponse: query = form_data.get("query", "").replace("_", " ") sort = form_data.get("sort", "relevance") domain = form_data.get("domain", "") - subdomain = form_data.get("subdomain", "") tags = form_data.get("tags", "") - domains_and_subdomains = domains_with_their_subdomains( - domain, subdomain, self.domain_model - ) where_to_access = self._build_custom_property_filter( "dc_where_to_access_dataset=", form_data.get("where_to_access", []) ) entity_types = self._build_entity_types(form_data.get("entity_types", [])) filter_value = [] - if domains_and_subdomains: - filter_value.append(MultiSelectFilter("domains", domains_and_subdomains)) + if domain: + filter_value.append(MultiSelectFilter("domains", [domain])) if where_to_access: filter_value.append(MultiSelectFilter("customProperties", where_to_access)) if tags: @@ -167,9 +145,8 @@ def _generate_domain_clear_href( self, ) -> dict[str, str]: domain = self.form.cleaned_data.get("domain", "") - subdomain = self.form.cleaned_data.get("subdomain", "") - label = self.domain_model.get_label(subdomain or domain) + label = self.domain_model.get_label(domain) return { label: ( diff --git a/home/views.py b/home/views.py index 864bf9e6..903f7d2c 100644 --- a/home/views.py +++ b/home/views.py @@ -1,5 +1,5 @@ from data_platform_catalogue.client.exceptions import EntityDoesNotExist -from data_platform_catalogue.search_types import ListDomainOption +from data_platform_catalogue.search_types import DomainOption from django.http import Http404, HttpResponseBadRequest from django.shortcuts import render @@ -9,8 +9,8 @@ DatabaseDetailsService, DatasetDetailsService, ) +from home.service.domain_fetcher import DomainFetcher from home.service.glossary import GlossaryService -from home.service.list_domain_fetcher import ListDomainFetcher from home.service.metadata_specification import MetadataSpecificationService from home.service.search import SearchService @@ -19,7 +19,7 @@ def home_view(request): """ Displys only domains that have entities tagged for display in the catalog. """ - domains: list[ListDomainOption] = ListDomainFetcher().fetch() + domains: list[DomainOption] = DomainFetcher().fetch() context = {"domains": domains, "h1_value": "Home"} return render(request, "home.html", context) diff --git a/lib/datahub-client/data_platform_catalogue/client/datahub_client.py b/lib/datahub-client/data_platform_catalogue/client/datahub_client.py index 7b762ded..42233d8a 100644 --- a/lib/datahub-client/data_platform_catalogue/client/datahub_client.py +++ b/lib/datahub-client/data_platform_catalogue/client/datahub_client.py @@ -3,30 +3,6 @@ from importlib.resources import files from typing import Sequence -from datahub.configuration.common import ConfigurationError -from datahub.emitter import mce_builder -from datahub.emitter.mcp import MetadataChangeProposalWrapper -from datahub.ingestion.graph.client import DatahubClientConfig, DataHubGraph -from datahub.ingestion.source.common.subtypes import ( - DatasetContainerSubTypes, - DatasetSubTypes, -) -from datahub.metadata import schema_classes -from datahub.metadata.com.linkedin.pegasus2avro.common import DataPlatformInstance -from datahub.metadata.schema_classes import ( - ChangeTypeClass, - ContainerClass, - ContainerPropertiesClass, - DatasetPropertiesClass, - DomainPropertiesClass, - DomainsClass, - OtherSchemaClass, - SchemaFieldClass, - SchemaFieldDataTypeClass, - SchemaMetadataClass, - SubTypesClass, -) - from data_platform_catalogue.client.exceptions import ( AspectDoesNotExist, ConnectivityError, @@ -57,13 +33,36 @@ Table, ) from data_platform_catalogue.search_types import ( - ListDomainOption, + DomainOption, MultiSelectFilter, ResultType, SearchFacets, SearchResponse, SortOption, ) +from datahub.configuration.common import ConfigurationError +from datahub.emitter import mce_builder +from datahub.emitter.mcp import MetadataChangeProposalWrapper +from datahub.ingestion.graph.client import DatahubClientConfig, DataHubGraph +from datahub.ingestion.source.common.subtypes import ( + DatasetContainerSubTypes, + DatasetSubTypes, +) +from datahub.metadata import schema_classes +from datahub.metadata.com.linkedin.pegasus2avro.common import DataPlatformInstance +from datahub.metadata.schema_classes import ( + ChangeTypeClass, + ContainerClass, + ContainerPropertiesClass, + DatasetPropertiesClass, + DomainPropertiesClass, + DomainsClass, + OtherSchemaClass, + SchemaFieldClass, + SchemaFieldDataTypeClass, + SchemaMetadataClass, + SubTypesClass, +) logger = logging.getLogger(__name__) @@ -230,9 +229,9 @@ def list_domains( MultiSelectFilter("tags", ["urn:li:tag:dc_display_in_catalogue"]) ], count: int = 1000, - ) -> list[ListDomainOption]: + ) -> list[DomainOption]: """ - Returns a list of ListDomainOption objects + Returns a list of DomainOption objects """ return self.search_client.list_domains( query=query, filters=filters, count=count diff --git a/lib/datahub-client/data_platform_catalogue/client/search.py b/lib/datahub-client/data_platform_catalogue/client/search.py index 7117e1e3..e27967dc 100644 --- a/lib/datahub-client/data_platform_catalogue/client/search.py +++ b/lib/datahub-client/data_platform_catalogue/client/search.py @@ -3,9 +3,6 @@ from importlib.resources import files from typing import Any, Sequence -from datahub.configuration.common import GraphError # pylint: disable=E0611 -from datahub.ingestion.graph.client import DataHubGraph # pylint: disable=E0611 - from data_platform_catalogue.client.exceptions import CatalogueError from data_platform_catalogue.client.graphql_helpers import ( parse_created_and_modified, @@ -19,8 +16,8 @@ ) from data_platform_catalogue.entities import EntityRef from data_platform_catalogue.search_types import ( + DomainOption, FacetOption, - ListDomainOption, MultiSelectFilter, ResultType, SearchFacets, @@ -28,6 +25,8 @@ SearchResult, SortOption, ) +from datahub.configuration.common import GraphError # pylint: disable=E0611 +from datahub.ingestion.graph.client import DataHubGraph # pylint: disable=E0611 logger = logging.getLogger(__name__) @@ -193,7 +192,7 @@ def list_domains( MultiSelectFilter("tags", ["urn:li:tag:dc_display_in_catalogue"]) ], count: int = 1000, - ) -> list[ListDomainOption]: + ) -> list[DomainOption]: """ Returns domains that can be used to filter the search results. """ @@ -255,8 +254,8 @@ def _map_filters(self, filters: Sequence[MultiSelectFilter]): def _parse_list_domains( self, list_domains_result: list[dict[str, Any]] - ) -> list[ListDomainOption]: - list_domain_options: list[ListDomainOption] = [] + ) -> list[DomainOption]: + list_domain_options: list[DomainOption] = [] for domain in list_domains_result: urn = domain.get("urn", "") @@ -265,7 +264,7 @@ def _parse_list_domains( entities = domain.get("entities", {}) total = entities.get("total", 0) - list_domain_options.append(ListDomainOption(urn, name, total)) + list_domain_options.append(DomainOption(urn, name, total)) return list_domain_options def _parse_result( diff --git a/lib/datahub-client/data_platform_catalogue/search_types.py b/lib/datahub-client/data_platform_catalogue/search_types.py index b09a1628..16a6b704 100644 --- a/lib/datahub-client/data_platform_catalogue/search_types.py +++ b/lib/datahub-client/data_platform_catalogue/search_types.py @@ -55,7 +55,7 @@ class FacetOption: @dataclass -class ListDomainOption: +class DomainOption: """ A representation of a domain and the number of associated entities represented by total. diff --git a/lib/datahub-client/tests/test_integration_with_datahub_server.py b/lib/datahub-client/tests/test_integration_with_datahub_server.py index 65342866..60b94362 100644 --- a/lib/datahub-client/tests/test_integration_with_datahub_server.py +++ b/lib/datahub-client/tests/test_integration_with_datahub_server.py @@ -12,7 +12,6 @@ from datetime import datetime, timezone import pytest - from data_platform_catalogue.client.datahub_client import DataHubCatalogueClient from data_platform_catalogue.entities import ( AccessInformation, @@ -32,7 +31,7 @@ UsageRestrictions, ) from data_platform_catalogue.search_types import ( - ListDomainOption, + DomainOption, MultiSelectFilter, ResultType, ) @@ -48,7 +47,7 @@ def test_list_domains(): response = client.list_domains() assert len(response) > 0 domain = response[0] - assert isinstance(domain, ListDomainOption) + assert isinstance(domain, DomainOption) assert "urn:li:domain" in domain.urn diff --git a/tests/conftest.py b/tests/conftest.py index 77490b3e..b70e0489 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -26,8 +26,8 @@ UsageRestrictions, ) from data_platform_catalogue.search_types import ( + DomainOption, FacetOption, - ListDomainOption, ResultType, SearchFacets, SearchResponse, @@ -40,7 +40,7 @@ from home.forms.search import SearchForm from home.models.domain_model import DomainModel from home.service.details import DatabaseDetailsService -from home.service.list_domain_fetcher import ListDomainFetcher +from home.service.domain_fetcher import DomainFetcher from home.service.search import SearchService from home.service.search_facet_fetcher import SearchFacetFetcher from home.service.search_tag_fetcher import SearchTagFetcher @@ -294,22 +294,22 @@ def mock_catalogue(request, example_database): mock_list_domains_response( mock_catalogue, domains=[ - ListDomainOption( + DomainOption( urn="urn:li:domain:prisons", name="Prisons", total=fake.random_int(min=1, max=100), ), - ListDomainOption( + DomainOption( urn="urn:li:domain:courts", name="Courts", total=fake.random_int(min=1, max=100), ), - ListDomainOption( + DomainOption( urn="urn:li:domain:finance", name="Finance", total=fake.random_int(min=1, max=100), ), - ListDomainOption( + DomainOption( urn="urn:li:domain:hq", name="HQ", total=0, @@ -443,7 +443,7 @@ def search_facets(): @pytest.fixture def list_domains(filter_zero_entities): - return ListDomainFetcher(filter_zero_entities).fetch() + return DomainFetcher(filter_zero_entities).fetch() @pytest.fixture @@ -452,8 +452,11 @@ def search_tags(): @pytest.fixture -def valid_domain(search_facets): - return DomainModel(search_facets).top_level_domains[0] +def valid_domain(): + domains = DomainFetcher().fetch() + return DomainModel( + domains, + ).top_level_domains[0] @pytest.fixture diff --git a/tests/home/service/test_search.py b/tests/home/service/test_search.py index 83545492..96cfebc3 100644 --- a/tests/home/service/test_search.py +++ b/tests/home/service/test_search.py @@ -6,7 +6,7 @@ import pytest from home.forms.search import SearchForm -from home.service.search import SearchService, domains_with_their_subdomains +from home.service.search import SearchService dev_env = True if os.environ.get("ENV") == "dev" else False run_for_dev = pytest.mark.skipif( @@ -48,7 +48,6 @@ def test_get_context_remove_filter_hrefs(self, search_context, valid_domain): "analytical_platform": ( "?query=test&" f"domain={quote(valid_domain.urn)}&" - "subdomain=&" "entity_types=TABLE&" "sort=ascending&" "clear_filter=False&" @@ -61,7 +60,6 @@ def test_get_context_remove_filter_hrefs(self, search_context, valid_domain): "Table": ( "?query=test&" f"domain={quote(valid_domain.urn)}&" - "subdomain=&" "where_to_access=analytical_platform&" "sort=ascending&" "clear_filter=False&" @@ -74,7 +72,6 @@ def test_get_context_remove_filter_hrefs(self, search_context, valid_domain): "tag-1": ( "?query=test&" f"domain={quote(valid_domain.urn)}&" - "subdomain=&" "where_to_access=analytical_platform&" "entity_types=TABLE&" "sort=ascending&" @@ -157,28 +154,3 @@ def test_highlight_results_with_query(self, search_service): search_service.results.page_results != search_service.highlighted_results.page_results ) - - -@run_for_dev -@pytest.mark.parametrize( - "domain, subdomain, expected_subdomains", - [ - ("does-not-exist", "", []), - ( - "urn:li:domain:HMPPS", - "", - [ - "urn:li:domain:HMPPS", - "urn:li:domain:2feb789b-44d3-4412-b998-1f26819fabf9", - "urn:li:domain:abe153c1-416b-4abb-be7f-6accf2abb10a", - ], - ), - ( - "urn:li:domain:HMPPS", - "urn:li:domain:2feb789b-44d3-4412-b998-1f26819fabf9", - ["urn:li:domain:2feb789b-44d3-4412-b998-1f26819fabf9"], - ), - ], -) -def test_domain_expansion(domain, subdomain, expected_subdomains): - assert domains_with_their_subdomains(domain, subdomain) == expected_subdomains