Skip to content

Commit

Permalink
improve async search client handling (#17319)
Browse files Browse the repository at this point in the history
  • Loading branch information
logan-markewich authored Dec 19, 2024
1 parent ff661ea commit 6019481
Show file tree
Hide file tree
Showing 2 changed files with 104 additions and 65 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import json
import logging
from enum import auto
from typing import Any, Callable, Dict, List, Optional, Tuple, Union, cast
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

from azure.search.documents import SearchClient
from azure.search.documents.aio import SearchClient as AsyncSearchClient
Expand Down Expand Up @@ -523,12 +523,17 @@ async def _avalidate_index(self, index_name: Optional[str]) -> None:

def __init__(
self,
search_or_index_client: Any,
search_or_index_client: Union[
SearchClient, SearchIndexClient, AsyncSearchClient, AsyncSearchIndexClient
],
id_field_key: str,
chunk_field_key: str,
embedding_field_key: str,
metadata_string_field_key: str,
doc_id_field_key: str,
async_search_or_index_client: Optional[
Union[AsyncSearchClient, AsyncSearchIndexClient]
] = None,
filterable_metadata_field_keys: Optional[
Union[
List[str],
Expand Down Expand Up @@ -617,13 +622,6 @@ def __init__(
self._user_agent = (
f"{base_user_agent} {user_agent}" if user_agent else base_user_agent
)

self._index_client: SearchIndexClient = cast(SearchIndexClient, None)
self._async_index_client: AsyncSearchIndexClient = cast(
AsyncSearchIndexClient, None
)
self._search_client: SearchClient = cast(SearchClient, None)
self._async_search_client: AsyncSearchClient = cast(AsyncSearchClient, None)
self._embedding_dimensionality = embedding_dimensionality
self._index_name = index_name

Expand All @@ -639,11 +637,22 @@ def __init__(
self._language_analyzer = language_analyzer
self._compression_type = compression_type.lower()

# Validate search_or_index_client
# Initialize clients to None
self._index_client = None
self._async_index_client = None
self._search_client = None
self._async_search_client = None

if search_or_index_client and async_search_or_index_client is None:
logger.warning(
"async_search_or_index_client is None. Depending on the client type passed "
"in, sync or async functions may not work."
)

# Validate sync search_or_index_client
if search_or_index_client is not None:
if isinstance(search_or_index_client, SearchIndexClient):
# If SearchIndexClient is supplied so must index_name
self._index_client = cast(SearchIndexClient, search_or_index_client)
self._index_client = search_or_index_client
self._index_client._client._config.user_agent_policy.add_user_agent(
self._user_agent
)
Expand All @@ -660,18 +669,32 @@ def __init__(
self._user_agent
)

elif isinstance(search_or_index_client, AsyncSearchIndexClient):
# If SearchIndexClient is supplied so must index_name
self._async_index_client = cast(
AsyncSearchIndexClient, search_or_index_client
elif isinstance(search_or_index_client, SearchClient):
self._search_client = search_or_index_client
self._search_client._client._config.user_agent_policy.add_user_agent(
self._user_agent
)
# Validate index_name
if index_name:
raise ValueError(
"index_name cannot be supplied if search_or_index_client "
"is of type azure.search.documents.SearchClient"
)

# Validate async search_or_index_client -- if not provided, assume the search_or_index_client could be async
async_search_or_index_client = (
async_search_or_index_client or search_or_index_client
)
if async_search_or_index_client is not None:
if isinstance(async_search_or_index_client, AsyncSearchIndexClient):
self._async_index_client = async_search_or_index_client
self._async_index_client._client._config.user_agent_policy.add_user_agent(
self._user_agent
)

if not index_name:
raise ValueError(
"index_name must be supplied if search_or_index_client is of "
"index_name must be supplied if async_search_or_index_client is of "
"type azure.search.documents.aio.SearchIndexClient"
)

Expand All @@ -682,58 +705,40 @@ def __init__(
self._user_agent
)

elif isinstance(search_or_index_client, SearchClient):
self._search_client = cast(SearchClient, search_or_index_client)
self._search_client._client._config.user_agent_policy.add_user_agent(
self._user_agent
)
# Validate index_name
if index_name:
raise ValueError(
"index_name cannot be supplied if search_or_index_client "
"is of type azure.search.documents.SearchClient"
)

elif isinstance(search_or_index_client, AsyncSearchClient):
self._async_search_client = cast(
AsyncSearchClient, search_or_index_client
)
elif isinstance(async_search_or_index_client, AsyncSearchClient):
self._async_search_client = async_search_or_index_client
self._async_search_client._client._config.user_agent_policy.add_user_agent(
self._user_agent
)

# Validate index_name
if index_name:
raise ValueError(
"index_name cannot be supplied if search_or_index_client "
"is of type azure.search.documents.SearchClient"
"index_name cannot be supplied if async_search_or_index_client "
"is of type azure.search.documents.aio.SearchClient"
)

if isinstance(search_or_index_client, AsyncSearchIndexClient):
if not self._async_index_client and not self._async_search_client:
raise ValueError(
"search_or_index_client must be of type "
"azure.search.documents.SearchIndexClient or "
"azure.search.documents.SearchClient"
)

if isinstance(search_or_index_client, SearchIndexClient):
if not self._index_client and not self._search_client:
raise ValueError(
"search_or_index_client must be of type "
"azure.search.documents.SearchIndexClient or "
"azure.search.documents.SearchClient"
)
else:
raise ValueError("search_or_index_client not specified")
# Validate that at least one client was provided
if not any(
[
self._search_client,
self._async_search_client,
self._index_client,
self._async_index_client,
]
):
raise ValueError(
"Either search_or_index_client or async_search_or_index_client must be provided"
)

# Validate index management requirements
if index_management == IndexManagement.CREATE_IF_NOT_EXISTS and not (
self._index_client or self._async_index_client
):
raise ValueError(
"index_management has value of IndexManagement.CREATE_IF_NOT_EXISTS "
"but search_or_index_client is not of type "
"azure.search.documents.SearchIndexClient or azure.search.documents.aio.SearchIndexClient "
"but neither search_or_index_client nor async_search_or_index_client is of type "
"azure.search.documents.SearchIndexClient or azure.search.documents.aio.SearchIndexClient"
)

self._index_management = index_management
Expand Down Expand Up @@ -1161,20 +1166,36 @@ def query(self, query: VectorStoreQuery, **kwargs: Any) -> VectorStoreQueryResul
odata_filter = self._create_odata_filter(query.filters)
azure_query_result_search: AzureQueryResultSearchBase = (
AzureQueryResultSearchDefault(
query, self._field_mapping, odata_filter, self._search_client
query,
self._field_mapping,
odata_filter,
self._search_client,
self._async_search_client,
)
)
if query.mode == VectorStoreQueryMode.SPARSE:
azure_query_result_search = AzureQueryResultSearchSparse(
query, self._field_mapping, odata_filter, self._search_client
query,
self._field_mapping,
odata_filter,
self._search_client,
self._async_search_client,
)
elif query.mode == VectorStoreQueryMode.HYBRID:
azure_query_result_search = AzureQueryResultSearchHybrid(
query, self._field_mapping, odata_filter, self._search_client
query,
self._field_mapping,
odata_filter,
self._search_client,
self._async_search_client,
)
elif query.mode == VectorStoreQueryMode.SEMANTIC_HYBRID:
azure_query_result_search = AzureQueryResultSearchSemanticHybrid(
query, self._field_mapping, odata_filter, self._search_client
query,
self._field_mapping,
odata_filter,
self._search_client,
self._async_search_client,
)
return azure_query_result_search.search()

Expand All @@ -1193,20 +1214,36 @@ async def aquery(

azure_query_result_search: AzureQueryResultSearchBase = (
AzureQueryResultSearchDefault(
query, self._field_mapping, odata_filter, self._async_search_client
query,
self._field_mapping,
odata_filter,
self._search_client,
self._async_search_client,
)
)
if query.mode == VectorStoreQueryMode.SPARSE:
azure_query_result_search = AzureQueryResultSearchSparse(
query, self._field_mapping, odata_filter, self._async_search_client
query,
self._field_mapping,
odata_filter,
self._search_client,
self._async_search_client,
)
elif query.mode == VectorStoreQueryMode.HYBRID:
azure_query_result_search = AzureQueryResultSearchHybrid(
query, self._field_mapping, odata_filter, self._async_search_client
query,
self._field_mapping,
odata_filter,
self._search_client,
self._async_search_client,
)
elif query.mode == VectorStoreQueryMode.SEMANTIC_HYBRID:
azure_query_result_search = AzureQueryResultSearchSemanticHybrid(
query, self._field_mapping, odata_filter, self._async_search_client
query,
self._field_mapping,
odata_filter,
self._search_client,
self._async_search_client,
)
return await azure_query_result_search.asearch()

Expand Down Expand Up @@ -1339,12 +1376,14 @@ def __init__(
query: VectorStoreQuery,
field_mapping: Dict[str, str],
odata_filter: Optional[str],
search_client: Any,
search_client: SearchClient,
async_search_client: AsyncSearchClient,
) -> None:
self._query = query
self._field_mapping = field_mapping
self._odata_filter = odata_filter
self._search_client = search_client
self._async_search_client = async_search_client

@property
def _select_fields(self) -> List[str]:
Expand Down Expand Up @@ -1417,7 +1456,7 @@ def _create_query_result(
async def _acreate_query_result(
self, search_query: str, vectors: Optional[List[Any]]
) -> VectorStoreQueryResult:
results = await self._search_client.search(
results = await self._async_search_client.search(
search_text=search_query,
vector_queries=vectors,
top=self._query.similarity_top_k,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ exclude = ["**/BUILD"]
license = "MIT"
name = "llama-index-vector-stores-azureaisearch"
readme = "README.md"
version = "0.3.0"
version = "0.3.1"

[tool.poetry.dependencies]
python = ">=3.9,<4.0"
Expand Down

0 comments on commit 6019481

Please sign in to comment.