Skip to content

Commit

Permalink
Query entities: a parameter to find your entities (#2869)
Browse files Browse the repository at this point in the history
* Query entities: a parameter to find your entities

* Dedup Entity type model and use RelationNodeType instead
  • Loading branch information
jotare authored Feb 12, 2025
1 parent 42c4a36 commit 754cfa6
Show file tree
Hide file tree
Showing 5 changed files with 70 additions and 9 deletions.
1 change: 1 addition & 0 deletions nucliadb/src/nucliadb/search/search/find.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,6 +269,7 @@ async def query_parser_from_find_request(
kbid=kbid,
features=item.features,
query=item.query,
query_entities=item.query_entities,
label_filters=item.filters,
keyword_filters=item.keyword_filters,
faceted=None,
Expand Down
19 changes: 18 additions & 1 deletion nucliadb/src/nucliadb/search/search/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from typing import Any, Awaitable, Optional, Union

from nucliadb.common import datamanagers
from nucliadb.common.models_utils.from_proto import RelationNodeTypeMap
from nucliadb.search import logger
from nucliadb.search.predict import SendToPredictError
from nucliadb.search.search.filters import (
Expand All @@ -49,6 +50,7 @@
from nucliadb_models.metadata import ResourceProcessingStatus
from nucliadb_models.search import (
Filter,
KnowledgeGraphEntity,
MaxTokens,
MinScore,
SearchOptions,
Expand Down Expand Up @@ -94,6 +96,7 @@ def __init__(
keyword_filters: Union[list[str], list[Filter]],
top_k: int,
min_score: MinScore,
query_entities: Optional[list[KnowledgeGraphEntity]] = None,
faceted: Optional[list[str]] = None,
sort: Optional[SortOptions] = None,
range_creation_start: Optional[datetime] = None,
Expand All @@ -120,6 +123,7 @@ def __init__(
self.kbid = kbid
self.features = features
self.query = query
self.query_entities = query_entities
self.hidden = hidden
if self.hidden is not None:
if self.hidden:
Expand Down Expand Up @@ -231,6 +235,7 @@ async def parse(self) -> tuple[nodereader_pb2.SearchRequest, bool, list[str]]:
self.parse_document_search(request)
self.parse_paragraph_search(request)
incomplete = await self.parse_vector_search(request)
# BUG: autofilters are not used to filter, but we say we do
autofilters = await self.parse_relation_search(request)
await self.parse_synonyms(request)
await self.parse_min_score(request, incomplete)
Expand Down Expand Up @@ -372,8 +377,20 @@ async def parse_vector_search(self, request: nodereader_pb2.SearchRequest) -> bo

async def parse_relation_search(self, request: nodereader_pb2.SearchRequest) -> list[str]:
autofilters = []
# BUG: autofiler should autofilter, not enable relation search
if self.has_relations_search or self.autofilter:
detected_entities = await self.fetcher.get_detected_entities()
if self.query_entities:
detected_entities = []
for entity in self.query_entities:
relation_node = utils_pb2.RelationNode()
relation_node.value = entity.name
if entity.type is not None:
relation_node.ntype = RelationNodeTypeMap[entity.type]
if entity.subtype is not None:
relation_node.subtype = entity.subtype
detected_entities.append(relation_node)
else:
detected_entities = await self.fetcher.get_detected_entities()
meta_cache = await self.fetcher.get_entities_meta_cache()
detected_entities = expand_entities(meta_cache, detected_entities)
if self.has_relations_search:
Expand Down
4 changes: 2 additions & 2 deletions nucliadb/tests/ndbfixtures/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
#
from os.path import dirname
from typing import AsyncIterable, AsyncIterator, Iterator
from unittest.mock import Mock, patch
from unittest.mock import AsyncMock, Mock, patch

import pytest
from pytest_mock import MockerFixture
Expand Down Expand Up @@ -92,7 +92,7 @@ async def local_files():

@pytest.fixture(scope="function")
def predict_mock() -> Mock: # type: ignore
mock = Mock()
mock = AsyncMock()
with global_utility(Utility.PREDICT, mock):
yield mock

Expand Down
37 changes: 37 additions & 0 deletions nucliadb/tests/nucliadb/integration/search/test_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from httpx import AsyncClient
from nats.aio.client import Client
from nats.js import JetStreamContext
from pytest_mock import MockerFixture

from nucliadb.common.cluster.settings import settings as cluster_settings
from nucliadb.common.maindb.utils import get_driver
Expand Down Expand Up @@ -805,6 +806,42 @@ async def test_search_automatic_relations(
)


@pytest.mark.deploy_modes("standalone")
async def test_search_user_relations(
nucliadb_reader: AsyncClient,
nucliadb_writer: AsyncClient,
nucliadb_ingest_grpc: WriterStub,
standalone_knowledgebox: str,
predict_mock: AsyncMock,
mocker: MockerFixture,
):
kbid = standalone_knowledgebox

from nucliadb.search.search import find

spy = mocker.spy(find, "node_query")
with patch.object(predict_mock, "detect_entities", AsyncMock(return_value=[])):
resp = await nucliadb_reader.post(
f"/kb/{kbid}/find",
json={
"query": "What relates Newton and Becquer?",
"query_entities": [
{"name": "Newton"},
{"name": "Becquer", "type": "entity", "subtype": "person"},
],
"features": ["relations"],
},
)
assert resp.status_code == 200

assert spy.call_count == 1
request = spy.call_args.args[2]
assert len(request.relation_subgraph.entry_points) == 2
assert request.relation_subgraph.entry_points[0].value == "Newton"
assert request.relation_subgraph.entry_points[1].value == "Becquer"
assert request.relation_subgraph.entry_points[1].subtype == "person"


async def get_audit_messages(sub):
msg = await sub.fetch(1)
auditreq = AuditRequest()
Expand Down
18 changes: 12 additions & 6 deletions nucliadb_models/src/nucliadb_models/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@

# Bw/c import to avoid breaking users
# noqa isort: skip
from nucliadb_models.metadata import RelationType, ResourceProcessingStatus
from nucliadb_models.metadata import RelationNodeType, RelationType, ResourceProcessingStatus
from nucliadb_models.resource import ExtractedDataTypeName, Resource
from nucliadb_models.security import RequestSecurity
from nucliadb_models.utils import DateTime
Expand Down Expand Up @@ -228,11 +228,8 @@ class RelationDirection(str, Enum):
OUT = "out"


class EntityType(str, Enum):
ENTITY = "entity"
LABEL = "label"
RESOURCE = "resource"
USER = "user"
# Bw/c we use to have this model duplicated
EntityType = RelationNodeType


class DirectionalRelation(BaseModel):
Expand Down Expand Up @@ -1684,7 +1681,16 @@ class SummarizedResponse(BaseModel):
)


class KnowledgeGraphEntity(BaseModel):
name: str
type: Optional[RelationNodeType] = None
subtype: Optional[str] = None


class FindRequest(BaseSearchRequest):
query_entities: SkipJsonSchema[Optional[list[KnowledgeGraphEntity]]] = Field(
default=None, title="Query entities", description="Entities to use in a knowledge graph search"
)
features: list[SearchOptions] = SearchParamDefaults.search_features.to_pydantic_field(
default=[
SearchOptions.KEYWORD,
Expand Down

0 comments on commit 754cfa6

Please sign in to comment.