Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

community: Added type hinting to OpenSearch clients #27946

Merged
Merged 1 commit on Nov 8, 2024.
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number | Diff line number | Diff line change
Expand Up @@ -2,7 +2,7 @@

import uuid
import warnings
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple
from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, List, Optional, Tuple

import numpy as np
from langchain_core.documents import Document
Expand All @@ -23,57 +23,18 @@
PAINLESS_SCRIPTING_SEARCH = "painless_scripting"
MATCH_ALL_QUERY = {"match_all": {}} # type: Dict

if TYPE_CHECKING:
from opensearchpy import AsyncOpenSearch, OpenSearch

def _import_opensearch() -> Any:
"""Import OpenSearch if available, otherwise raise error."""

def _get_opensearch_client(opensearch_url: str, **kwargs: Any) -> OpenSearch:
"""Get OpenSearch client from the opensearch_url, otherwise raise error."""
try:
from opensearchpy import OpenSearch
except ImportError:
raise ImportError(IMPORT_OPENSEARCH_PY_ERROR)
return OpenSearch


def _import_async_opensearch() -> Any:
"""Import AsyncOpenSearch if available, otherwise raise error."""
try:
from opensearchpy import AsyncOpenSearch
except ImportError:
raise ImportError(IMPORT_ASYNC_OPENSEARCH_PY_ERROR)
return AsyncOpenSearch


def _import_bulk() -> Any:
"""Import bulk if available, otherwise raise error."""
try:
from opensearchpy.helpers import bulk
client = OpenSearch(opensearch_url, **kwargs)
except ImportError:
raise ImportError(IMPORT_OPENSEARCH_PY_ERROR)
return bulk


def _import_async_bulk() -> Any:
"""Import async_bulk if available, otherwise raise error."""
try:
from opensearchpy.helpers import async_bulk
except ImportError:
raise ImportError(IMPORT_ASYNC_OPENSEARCH_PY_ERROR)
return async_bulk


def _import_not_found_error() -> Any:
"""Import not found error if available, otherwise raise error."""
try:
from opensearchpy.exceptions import NotFoundError
except ImportError:
raise ImportError(IMPORT_OPENSEARCH_PY_ERROR)
return NotFoundError


def _get_opensearch_client(opensearch_url: str, **kwargs: Any) -> Any:
"""Get OpenSearch client from the opensearch_url, otherwise raise error."""
try:
opensearch = _import_opensearch()
client = opensearch(opensearch_url, **kwargs)
except ValueError as e:
raise ImportError(
f"OpenSearch client string provided is not in proper format. "
Expand All @@ -82,11 +43,14 @@ def _get_opensearch_client(opensearch_url: str, **kwargs: Any) -> Any:
return client


def _get_async_opensearch_client(opensearch_url: str, **kwargs: Any) -> Any:
def _get_async_opensearch_client(opensearch_url: str, **kwargs: Any) -> AsyncOpenSearch:
"""Get AsyncOpenSearch client from the opensearch_url, otherwise raise error."""
try:
async_opensearch = _import_async_opensearch()
client = async_opensearch(opensearch_url, **kwargs)
from opensearchpy import AsyncOpenSearch

client = AsyncOpenSearch(opensearch_url, **kwargs)
except ImportError:
raise ImportError(IMPORT_ASYNC_OPENSEARCH_PY_ERROR)
except ValueError as e:
raise ImportError(
f"AsyncOpenSearch client string provided is not in proper format. "
Expand Down Expand Up @@ -127,7 +91,7 @@ def _is_aoss_enabled(http_auth: Any) -> bool:


def _bulk_ingest_embeddings(
client: Any,
client: OpenSearch,
index_name: str,
embeddings: List[List[float]],
texts: Iterable[str],
Expand All @@ -142,16 +106,19 @@ def _bulk_ingest_embeddings(
"""Bulk Ingest Embeddings into given index."""
if not mapping:
mapping = dict()
try:
from opensearchpy.exceptions import NotFoundError
from opensearchpy.helpers import bulk
except ImportError:
raise ImportError(IMPORT_OPENSEARCH_PY_ERROR)

bulk = _import_bulk()
not_found_error = _import_not_found_error()
requests = []
return_ids = []
mapping = mapping

try:
client.indices.get(index=index_name)
except not_found_error:
except NotFoundError:
client.indices.create(index=index_name, body=mapping)

for i, text in enumerate(texts):
Expand All @@ -177,7 +144,7 @@ def _bulk_ingest_embeddings(


async def _abulk_ingest_embeddings(
client: Any,
client: AsyncOpenSearch,
index_name: str,
embeddings: List[List[float]],
texts: Iterable[str],
Expand All @@ -193,14 +160,18 @@ async def _abulk_ingest_embeddings(
if not mapping:
mapping = dict()

async_bulk = _import_async_bulk()
not_found_error = _import_not_found_error()
try:
from opensearchpy.exceptions import NotFoundError
from opensearchpy.helpers import async_bulk
except ImportError:
raise ImportError(IMPORT_ASYNC_OPENSEARCH_PY_ERROR)

requests = []
return_ids = []

try:
await client.indices.get(index=index_name)
except not_found_error:
except NotFoundError:
await client.indices.create(index=index_name, body=mapping)

for i, text in enumerate(texts):
Expand Down Expand Up @@ -230,7 +201,7 @@ async def _abulk_ingest_embeddings(
def _default_scripting_text_mapping(
dim: int,
vector_field: str = "vector_field",
) -> Dict:
) -> Dict[str, Any]:
"""For Painless Scripting or Script Scoring, the default mapping to create index."""
return {
"mappings": {
Expand All @@ -249,7 +220,7 @@ def _default_text_mapping(
ef_construction: int = 512,
m: int = 16,
vector_field: str = "vector_field",
) -> Dict:
) -> Dict[str, Any]:
"""For Approximate k-NN Search, this is the default mapping to create index."""
return {
"settings": {"index": {"knn": True, "knn.algo_param.ef_search": ef_search}},
Expand All @@ -275,7 +246,7 @@ def _default_approximate_search_query(
k: int = 4,
vector_field: str = "vector_field",
score_threshold: Optional[float] = 0.0,
) -> Dict:
) -> Dict[str, Any]:
"""For Approximate k-NN Search, this is the default query."""
return {
"size": k,
Expand All @@ -291,7 +262,7 @@ def _approximate_search_query_with_boolean_filter(
vector_field: str = "vector_field",
subquery_clause: str = "must",
score_threshold: Optional[float] = 0.0,
) -> Dict:
) -> Dict[str, Any]:
"""For Approximate k-NN Search, with Boolean Filter."""
return {
"size": k,
Expand All @@ -313,7 +284,7 @@ def _approximate_search_query_with_efficient_filter(
k: int = 4,
vector_field: str = "vector_field",
score_threshold: Optional[float] = 0.0,
) -> Dict:
) -> Dict[str, Any]:
"""For Approximate k-NN Search, with Efficient Filter for Lucene and
Faiss Engines."""
search_query = _default_approximate_search_query(
Expand All @@ -330,7 +301,7 @@ def _default_script_query(
pre_filter: Optional[Dict] = None,
vector_field: str = "vector_field",
score_threshold: Optional[float] = 0.0,
) -> Dict:
) -> Dict[str, Any]:
"""For Script Scoring Search, this is the default query."""

if not pre_filter:
Expand Down Expand Up @@ -376,7 +347,7 @@ def _default_painless_scripting_query(
pre_filter: Optional[Dict] = None,
vector_field: str = "vector_field",
score_threshold: Optional[float] = 0.0,
) -> Dict:
) -> Dict[str, Any]:
"""For Painless Scripting Search, this is the default query."""

if not pre_filter:
Expand Down Expand Up @@ -692,7 +663,10 @@ def delete(
refresh_indices: Whether to refresh the index
after deleting documents. Defaults to True.
"""
bulk = _import_bulk()
try:
from opensearchpy.helpers import bulk
except ImportError:
raise ImportError(IMPORT_OPENSEARCH_PY_ERROR)

body = []

Expand Down
Loading