diff --git a/src/aerospike_vector_search/admin.py b/src/aerospike_vector_search/admin.py index 89ca5b0b..c84637f1 100644 --- a/src/aerospike_vector_search/admin.py +++ b/src/aerospike_vector_search/admin.py @@ -9,6 +9,7 @@ from .internal import channel_provider from .shared.admin_helpers import BaseClient from .shared.conversions import fromIndexStatusResponse +from .types import IndexDefinition, Role logger = logging.getLogger(__name__) @@ -19,7 +20,7 @@ class Client(BaseClient): This client is designed to conduct Aerospike Vector Search administrative operation such as creating indexes, querying index information, and dropping indexes. - :param seeds: Defines the AVS nodes to which you want AVS to connect. AVS iterates through the seed nodes. After connecting to a node, AVS discovers all of the nodes in the cluster. + :param seeds: Defines the AVS nodes to which you want AVS to connect. AVS iterates through the seed nodes. After connecting to a node, AVS discovers all the nodes in the cluster. :type seeds: Union[types.HostPort, tuple[types.HostPort, ...]] :param listener_name: An external (NATed) address and port combination that differs from the actual address and port where AVS is listening. Clients can access AVS on a node using the advertised listener address and port. Defaults to None. @@ -86,7 +87,7 @@ def index_create( name: str, vector_field: str, dimensions: int, - vector_distance_metric: Optional[types.VectorDistanceMetric] = ( + vector_distance_metric: types.VectorDistanceMetric = ( types.VectorDistanceMetric.SQUARED_EUCLIDEAN ), sets: Optional[str] = None, @@ -113,7 +114,7 @@ def index_create( :param vector_distance_metric: The distance metric used to compare when performing a vector search. Defaults to :class:`VectorDistanceMetric.SQUARED_EUCLIDEAN`. - :type vector_distance_metric: Optional[types.VectorDistanceMetric] + :type vector_distance_metric: types.VectorDistanceMetric :param sets: The set used for the index. Defaults to None. :type sets: Optional[str] @@ -123,7 +124,7 @@ def index_create( specified for :class:`types.HnswParams` will be used. :type index_params: Optional[types.HnswParams] - :param index_labels: Meta data associated with the index. Defaults to None. + :param index_labels: Metadata associated with the index. Defaults to None. :type index_labels: Optional[dict[str, str]] :param index_storage: Namespace and set where index overhead (non-vector data) is stored. @@ -269,7 +270,7 @@ def index_drop( def index_list( self, timeout: Optional[int] = None, apply_defaults: Optional[bool] = True - ) -> list[dict]: + ) -> list[IndexDefinition]: """ List all indices. @@ -308,7 +309,7 @@ def index_get( name: str, timeout: Optional[int] = None, apply_defaults: Optional[bool] = True, - ) -> dict[str, Union[int, str]]: + ) -> IndexDefinition: """ Retrieve the information related with an index. @@ -630,7 +631,7 @@ def revoke_roles( logger.error("Failed to revoke roles with error: %s", e) raise types.AVSServerError(rpc_error=e) - def list_roles(self, timeout: Optional[int] = None) -> list[dict]: + def list_roles(self, timeout: Optional[int] = None) -> list[Role]: """ List roles available on the AVS server. @@ -663,8 +664,8 @@ def _wait_for_index_creation( *, namespace: str, name: str, - timeout: Optional[int] = sys.maxsize, - wait_interval: Optional[int] = 0.1, + timeout: int = sys.maxsize, + wait_interval: float = 0.1, ) -> None: """ Wait for the index to be created. @@ -697,8 +698,8 @@ def _wait_for_index_deletion( *, namespace: str, name: str, - timeout: Optional[int] = sys.maxsize, - wait_interval: Optional[int] = 0.1, + timeout: int = sys.maxsize, + wait_interval: float = 0.1, ) -> None: """ Wait for the index to be deleted. diff --git a/src/aerospike_vector_search/aio/admin.py b/src/aerospike_vector_search/aio/admin.py index e4e369cc..eae2a817 100644 --- a/src/aerospike_vector_search/aio/admin.py +++ b/src/aerospike_vector_search/aio/admin.py @@ -1,16 +1,15 @@ import asyncio import logging import sys - from typing import Optional, Union import grpc from .internal import channel_provider from .. import types -from ..shared.conversions import fromIndexStatusResponse from ..shared.admin_helpers import BaseClient - +from ..shared.conversions import fromIndexStatusResponse +from ..types import Role, IndexDefinition logger = logging.getLogger(__name__) @@ -21,7 +20,7 @@ class Client(BaseClient): This client is designed to conduct Aerospike Vector Search administrative operation such as creating indexes, querying index information, and dropping indexes. - :param seeds: Defines the AVS nodes to which you want AVS to connect. AVS iterates through the seed nodes. After connecting to a node, AVS discovers all of the nodes in the cluster. + :param seeds: Defines the AVS nodes to which you want AVS to connect. AVS iterates through the seed nodes. After connecting to a node, AVS discovers all the nodes in the cluster. :type seeds: Union[types.HostPort, tuple[types.HostPort, ...]] :param listener_name: An external (NATed) address and port combination that differs from the actual address and port where AVS is listening. Clients can access AVS on a node using the advertised listener address and port. Defaults to None. @@ -88,7 +87,7 @@ async def index_create( name: str, vector_field: str, dimensions: int, - vector_distance_metric: Optional[types.VectorDistanceMetric] = ( + vector_distance_metric: types.VectorDistanceMetric = ( types.VectorDistanceMetric.SQUARED_EUCLIDEAN ), sets: Optional[str] = None, @@ -115,7 +114,7 @@ async def index_create( :param vector_distance_metric: The distance metric used to compare when performing a vector search. Defaults to :class:`VectorDistanceMetric.SQUARED_EUCLIDEAN`. - :type vector_distance_metric: Optional[types.VectorDistanceMetric] + :type vector_distance_metric: types.VectorDistanceMetric :param sets: The set used for the index. Defaults to None. :type sets: Optional[str] @@ -125,7 +124,7 @@ async def index_create( specified for :class:`types.HnswParams` will be used. :type index_params: Optional[types.HnswParams] - :param index_labels: Meta data associated with the index. Defaults to None. + :param index_labels: Metadata associated with the index. Defaults to None. :type index_labels: Optional[dict[str, str]] :param index_storage: Namespace and set where index overhead (non-vector data) is stored. @@ -278,7 +277,7 @@ async def index_drop( async def index_list( self, timeout: Optional[int] = None, apply_defaults: Optional[bool] = True - ) -> list[dict]: + ) -> list[IndexDefinition]: """ List all indices. @@ -320,7 +319,7 @@ async def index_get( name: str, timeout: Optional[int] = None, apply_defaults: Optional[bool] = True, - ) -> dict[str, Union[int, str]]: + ) -> IndexDefinition: """ Retrieve the information related with an index. @@ -594,14 +593,14 @@ async def list_users(self, timeout: Optional[int] = None) -> list[types.User]: async def grant_roles( self, *, username: str, roles: list[str], timeout: Optional[int] = None - ) -> int: + ) : """ Grant roles to existing AVS Users. :param username: Username of the user which will receive the roles. :type username: str - :param roles: Roles the specified user will recieved. + :param roles: Roles the specified user will receive. :type roles: list[str] :param timeout: Time in seconds this operation will wait before raising an :class:`AVSServerError `. Defaults to None. @@ -630,7 +629,7 @@ async def grant_roles( async def revoke_roles( self, *, username: str, roles: list[str], timeout: Optional[int] = None - ) -> int: + ) : """ Revoke roles from existing AVS Users. @@ -664,7 +663,7 @@ async def revoke_roles( logger.error("Failed to revoke roles with error: %s", e) raise types.AVSServerError(rpc_error=e) - async def list_roles(self, timeout: Optional[int] = None) -> None: + async def list_roles(self, timeout: Optional[int] = None) -> list[Role]: """ list roles of existing AVS Users. @@ -699,8 +698,8 @@ async def _wait_for_index_creation( *, namespace: str, name: str, - timeout: Optional[int] = sys.maxsize, - wait_interval: Optional[int] = 0.1, + timeout: int = sys.maxsize, + wait_interval: float = 0.1, ) -> None: """ Wait for the index to be created. @@ -717,7 +716,7 @@ async def _wait_for_index_creation( index_creation_request, credentials=self._channel_provider.get_token(), ) - logger.debug("Index created succesfully") + logger.debug("Index created successfully") # Index has been created return except grpc.RpcError as e: @@ -734,8 +733,8 @@ async def _wait_for_index_deletion( *, namespace: str, name: str, - timeout: Optional[int] = sys.maxsize, - wait_interval: Optional[int] = 0.1, + timeout: int = sys.maxsize, + wait_interval: float = 0.1, ) -> None: """ Wait for the index to be deleted. @@ -759,7 +758,7 @@ async def _wait_for_index_deletion( await asyncio.sleep(wait_interval) except grpc.RpcError as e: if e.code() == grpc.StatusCode.NOT_FOUND: - logger.debug("Index deleted succesfully") + logger.debug("Index deleted successfully") # Index has been created return else: diff --git a/src/aerospike_vector_search/aio/client.py b/src/aerospike_vector_search/aio/client.py index 4bcb87cc..154f3293 100644 --- a/src/aerospike_vector_search/aio/client.py +++ b/src/aerospike_vector_search/aio/client.py @@ -385,7 +385,7 @@ async def exists( exists_request, credentials=self._channel_provider.get_token(), **kwargs ) except grpc.RpcError as e: - logger.error("Failed to verfiy vector existence with error: %s", e) + logger.error("Failed to verify vector existence with error: %s", e) raise types.AVSServerError(rpc_error=e) return self._respond_exists(response) @@ -679,9 +679,9 @@ async def wait_for_index_completion( *, namespace: str, name: str, - timeout: Optional[int] = sys.maxsize, - wait_interval: Optional[int] = 12, - validation_threshold: Optional[int] = 2, + timeout: int = sys.maxsize, + wait_interval: int = 12, + validation_threshold: int = 2, ) -> None: """ Wait for the index to have no pending index update operations. @@ -714,7 +714,7 @@ async def wait_for_index_completion( # Wait interval between polling ( index_stub, - wait_interval, + wait_interval_float, start_time, unmerged_record_initialized, validation_count, @@ -740,7 +740,7 @@ async def wait_for_index_completion( validation_count += 1 else: validation_count = 0 - await asyncio.sleep(wait_interval) + await asyncio.sleep(wait_interval_float) async def close(self): """ diff --git a/src/aerospike_vector_search/aio/internal/channel_provider.py b/src/aerospike_vector_search/aio/internal/channel_provider.py index 8797f877..7f32e41a 100644 --- a/src/aerospike_vector_search/aio/internal/channel_provider.py +++ b/src/aerospike_vector_search/aio/internal/channel_provider.py @@ -1,13 +1,10 @@ import re import asyncio import logging -import jwt -from jwt.exceptions import InvalidTokenError from typing import Optional, Union import google.protobuf.empty_pb2 import grpc -import random from ... import types from ...shared.proto_generated import vector_db_pb2 @@ -18,7 +15,7 @@ logger = logging.getLogger(__name__) -TEND_INTERVAL = 1 +TEND_INTERVAL : int = 1 class ChannelProvider(base_channel_provider.BaseChannelProvider): @@ -72,7 +69,7 @@ def __init__( self._tend_exception: Exception = None async def _is_ready(self): - # Wait 1 round of cluster tending, auth token initialization, and server client compatiblity verfication + # Wait 1 round of cluster tending, auth token initialization, and server client compatibility verification await self._ready.wait() # This propogates any fatal/unexpected errors from client initialization/tending to the client. @@ -92,7 +89,7 @@ async def _tend(self): self._ready.set() - except Exception as e: + except Exception as e: # Set all event to prevent hanging if initial tend fails with error self._tend_ended.set() self._ready.set() @@ -185,7 +182,7 @@ async def _tend_token(self): try: if not self._token: return - elif self._token != True: + elif not self._token: await asyncio.sleep((self._ttl * self._ttl_threshold)) await self._update_token_and_ttl() diff --git a/src/aerospike_vector_search/client.py b/src/aerospike_vector_search/client.py index 99848042..aac80736 100644 --- a/src/aerospike_vector_search/client.py +++ b/src/aerospike_vector_search/client.py @@ -5,6 +5,7 @@ import warnings import grpc +import numpy as np from . import types from .internal import channel_provider @@ -199,7 +200,7 @@ def upsert( self, *, namespace: str, - key: Union[int, str, bytes, bytearray], + key: Union[int, str, bytes, bytearray, np.generic, np.ndarray], record_data: dict[str, Any], set_name: Optional[str] = None, ignore_mem_queue_full: Optional[bool] = False, @@ -369,7 +370,7 @@ def exists( exists_request, credentials=self._channel_provider.get_token(), **kwargs ) except grpc.RpcError as e: - logger.error("Failed to verfiy vector existence with error: %s", e) + logger.error("Failed to verify vector existence with error: %s", e) raise types.AVSServerError(rpc_error=e) return self._respond_exists(response) @@ -661,9 +662,9 @@ def wait_for_index_completion( *, namespace: str, name: str, - timeout: Optional[int] = sys.maxsize, - wait_interval: Optional[int] = 12, - validation_threshold: Optional[int] = 2, + timeout: int = sys.maxsize, + wait_interval: int = 12, + validation_threshold: int = 2, ) -> None: """ Wait for the index to have no pending index update operations. @@ -693,7 +694,7 @@ def wait_for_index_completion( """ ( index_stub, - wait_interval, + wait_interval_float, start_time, unmerged_record_initialized, validation_count, @@ -719,7 +720,7 @@ def wait_for_index_completion( validation_count += 1 else: validation_count = 0 - time.sleep(wait_interval) + time.sleep(wait_interval_float) def close(self): """ diff --git a/src/aerospike_vector_search/internal/channel_provider.py b/src/aerospike_vector_search/internal/channel_provider.py index 3c863c5a..23760689 100644 --- a/src/aerospike_vector_search/internal/channel_provider.py +++ b/src/aerospike_vector_search/internal/channel_provider.py @@ -16,7 +16,7 @@ logger = logging.getLogger(__name__) -TEND_INTERVAL = 1 +TEND_INTERVAL: int = 1 class ChannelProvider(base_channel_provider.BaseChannelProvider): @@ -109,7 +109,7 @@ def _call_get_cluster_id(self, stub): "While tending, failed to get cluster id with error: " + str(e) ) - def _call_get_cluster_endpoints(self, stub): + def _call_get_cluster_endpoints(self, stub) -> vector_db_pb2.ServerEndpointList: try: return ( stub.GetClusterEndpoints( @@ -214,5 +214,5 @@ def close(self): channelEndpoints.channel.close() with self._auth_tending_lock: - if self._auth_timer != None: + if self._auth_timer is not None: self._auth_timer.cancel() diff --git a/src/aerospike_vector_search/shared/admin_helpers.py b/src/aerospike_vector_search/shared/admin_helpers.py index ca49c750..48f6424a 100644 --- a/src/aerospike_vector_search/shared/admin_helpers.py +++ b/src/aerospike_vector_search/shared/admin_helpers.py @@ -1,6 +1,6 @@ import asyncio import logging -from typing import Any, Optional, Tuple, Dict +from typing import Any, Optional, Tuple, List import time import google.protobuf.empty_pb2 @@ -12,32 +12,30 @@ from .proto_generated import types_pb2, user_admin_pb2, index_pb2 from .. import types from . import conversions -from ..types import AVSClientError - -logger = logging.getLogger(__name__) +from ..types import AVSClientError, IndexDefinition, HostPort empty = google.protobuf.empty_pb2.Empty() class BaseClient(object): - def _prepare_seeds(self, seeds) -> None: + def _prepare_seeds(self, seeds) -> tuple[HostPort, ...]: return helpers._prepare_seeds(seeds) def _prepare_index_create( - self, - namespace, - name, - vector_field, - dimensions, - vector_distance_metric, - sets, - index_params, - index_labels, - index_storage, - timeout, - logger, - ) -> None: + self, + namespace: str, + name: str, + vector_field: str, + dimensions: int, + vector_distance_metric: types.VectorDistanceMetric, + sets: Optional[str], + index_params: Optional[types.HnswParams], + index_labels: Optional[dict[str, str]], + index_storage: Optional[types.IndexStorage], + timeout: Optional[int], + logger: logging.Logger + ) -> Tuple[index_pb2_grpc.IndexServiceStub, index_pb2.IndexCreateRequest, dict[str, Any]] : logger.debug( "Creating index: namespace=%s, name=%s, vector_field=%s, dimensions=%d, vector_distance_metric=%s, " @@ -62,16 +60,16 @@ def _prepare_index_create( sets = None if index_params != None: index_params = index_params._to_pb2() - if index_storage != None: + if index_storage is not None: index_storage = index_storage._to_pb2() - id = self._get_index_id(namespace, name) + index_id = self._get_index_id(namespace, name) vector_distance_metric = vector_distance_metric.value index_stub = self._get_index_stub() index_definition = types_pb2.IndexDefinition( - id=id, + id=index_id, vectorDistanceMetric=vector_distance_metric, setFilter=sets, hnswParams=index_params, @@ -87,7 +85,7 @@ def _prepare_index_update( self, namespace: str, name: str, - index_labels: Optional[Dict[str, str]], + index_labels: Optional[dict[str, str]], hnsw_update_params: Optional[types.HnswIndexUpdate], timeout: Optional[int], logger: logging.Logger @@ -126,8 +124,7 @@ def _prepare_index_update( return (index_stub, index_update_request, kwargs) - - def _prepare_index_drop(self, namespace, name, timeout, logger) -> None: + def _prepare_index_drop(self, namespace: str, name: str, timeout: Optional[int], logger: logging.Logger) -> tuple[index_pb2_grpc.IndexServiceStub, index_pb2.IndexDropRequest, dict[str, Any]]: logger.debug( "Dropping index: namespace=%s, name=%s, timeout=%s", @@ -145,7 +142,7 @@ def _prepare_index_drop(self, namespace, name, timeout, logger) -> None: index_drop_request = index_pb2.IndexDropRequest(indexId=index_id) return (index_stub, index_drop_request, kwargs) - def _prepare_index_list(self, timeout, logger, apply_defaults) -> None: + def _prepare_index_list(self, timeout: Optional[int], logger: logging.Logger, apply_defaults: Optional[bool]) -> tuple[index_pb2_grpc.IndexServiceStub, index_pb2.IndexListRequest, dict[str, Any]]: logger.debug( "Getting index list: timeout=%s, apply_defaults=%s", @@ -158,12 +155,12 @@ def _prepare_index_list(self, timeout, logger, apply_defaults) -> None: kwargs["timeout"] = timeout index_stub = self._get_index_stub() - index_list_request = index_pb2.IndexListRequest(applyDefaults=apply_defaults) + index_list_request: index_pb2.IndexListRequest = index_pb2.IndexListRequest(applyDefaults=apply_defaults) return (index_stub, index_list_request, kwargs) def _prepare_index_get( - self, namespace, name, timeout, logger, apply_defaults - ) -> None: + self, namespace: str, name: str, timeout: Optional[int], logger: logging.Logger, apply_defaults: Optional[bool] + ) -> tuple[index_pb2_grpc.IndexServiceStub, index_pb2.IndexGetRequest, dict[str, Any]]: logger.debug( "Getting index information: namespace=%s, name=%s, timeout=%s, apply_defaults=%s", @@ -184,7 +181,8 @@ def _prepare_index_get( ) return (index_stub, index_get_request, kwargs) - def _prepare_index_get_status(self, namespace, name, timeout, logger) -> None: + def _prepare_index_get_status(self, namespace: str, name: str, timeout: Optional[int], logger: logging.Logger) -> tuple[ + index_pb2_grpc.IndexServiceStub, index_pb2.IndexStatusRequest, dict[str, Any]]: logger.debug( "Getting index status: namespace=%s, name=%s, timeout=%s", @@ -202,7 +200,8 @@ def _prepare_index_get_status(self, namespace, name, timeout, logger) -> None: index_get_status_request = index_pb2.IndexStatusRequest(indexId=index_id) return (index_stub, index_get_status_request, kwargs) - def _prepare_add_user(self, username, password, roles, timeout, logger) -> None: + def _prepare_add_user(self, username: str, password: str, roles: list[str], timeout: Optional[int], logger: logging.Logger) -> tuple[ + user_admin_pb2_grpc.UserAdminServiceStub, user_admin_pb2.AddUserRequest, dict[str, Any]]: logger.debug( "Getting index status: username=%s, password=%s, roles=%s, timeout=%s", username, @@ -223,7 +222,8 @@ def _prepare_add_user(self, username, password, roles, timeout, logger) -> None: return (user_admin_stub, add_user_request, kwargs) - def _prepare_update_credentials(self, username, password, timeout, logger) -> None: + def _prepare_update_credentials(self, username: str, password: str, timeout: Optional[int], logger: logging.Logger) -> tuple[ + user_admin_pb2_grpc.UserAdminServiceStub, user_admin_pb2.UpdateCredentialsRequest, dict[str, Any]]: logger.debug( "Getting index status: username=%s, password=%s, timeout=%s", username, @@ -243,7 +243,7 @@ def _prepare_update_credentials(self, username, password, timeout, logger) -> No return (user_admin_stub, update_user_request, kwargs) - def _prepare_drop_user(self, username, timeout, logger) -> None: + def _prepare_drop_user(self, username: str, timeout: Optional[int], logger: logging.Logger) -> tuple[user_admin_pb2_grpc.UserAdminServiceStub, user_admin_pb2.DropUserRequest, dict[str, Any]]: logger.debug("Getting index status: username=%s, timeout=%s", username, timeout) kwargs = {} @@ -255,7 +255,7 @@ def _prepare_drop_user(self, username, timeout, logger) -> None: return (user_admin_stub, drop_user_request, kwargs) - def _prepare_get_user(self, username, timeout, logger) -> None: + def _prepare_get_user(self, username: str, timeout : Optional[int], logger: logging.Logger) -> tuple[user_admin_pb2_grpc.UserAdminServiceStub, user_admin_pb2.GetUserRequest, dict[str, Any]]: logger.debug("Getting index status: username=%s, timeout=%s", username, timeout) kwargs = {} @@ -267,7 +267,7 @@ def _prepare_get_user(self, username, timeout, logger) -> None: return (user_admin_stub, get_user_request, kwargs) - def _prepare_list_users(self, timeout, logger) -> None: + def _prepare_list_users(self, timeout : Optional[int], logger: logging.Logger) -> tuple[user_admin_pb2_grpc.UserAdminServiceStub, Any, dict[str, Any]]: logger.debug("Getting index status") kwargs = {} @@ -279,7 +279,8 @@ def _prepare_list_users(self, timeout, logger) -> None: return (user_admin_stub, list_users_request, kwargs) - def _prepare_grant_roles(self, username, roles, timeout, logger) -> None: + def _prepare_grant_roles(self, username: str, roles: list[str], timeout: Optional[int], logger: logging.Logger) -> tuple[ + user_admin_pb2_grpc.UserAdminServiceStub, user_admin_pb2.GrantRolesRequest, dict[str, Any]]: logger.debug( "Getting index status: username=%s, roles=%s, timeout=%s", username, @@ -298,7 +299,8 @@ def _prepare_grant_roles(self, username, roles, timeout, logger) -> None: return (user_admin_stub, grant_roles_request, kwargs) - def _prepare_revoke_roles(self, username, roles, timeout, logger) -> None: + def _prepare_revoke_roles(self, username: str, roles: list[str], timeout: Optional[int], logger) -> tuple[ + user_admin_pb2_grpc.UserAdminServiceStub, user_admin_pb2.RevokeRolesRequest, dict[str, Any]]: logger.debug( "Getting index status: username=%s, roles=%s, timeout=%s", username, @@ -317,7 +319,7 @@ def _prepare_revoke_roles(self, username, roles, timeout, logger) -> None: return (user_admin_stub, revoke_roles_request, kwargs) - def _prepare_list_roles(self, timeout, logger) -> None: + def _prepare_list_roles(self, timeout: Optional[int], logger: logging.Logger) -> tuple[user_admin_pb2_grpc.UserAdminServiceStub, Any, dict[str, Any]]: logger.debug("Getting index status: timeout=%s", timeout) kwargs = {} @@ -329,52 +331,51 @@ def _prepare_list_roles(self, timeout, logger) -> None: return (user_admin_stub, list_roles_request, kwargs) - def _respond_index_list(self, response) -> None: + def _respond_index_list(self, response) -> List[types.IndexDefinition]: response_list = [] for index in response.indices: response_list.append(conversions.fromIndexDefintion(index)) return response_list - def _respond_index_get(self, response) -> None: - + def _respond_index_get(self, response: types_pb2.IndexDefinition) -> IndexDefinition: return conversions.fromIndexDefintion(response) - def _respond_get_user(self, response) -> None: - + def _respond_get_user(self, response) -> types.User: return types.User(username=response.username, roles=list(response.roles)) - def _respond_list_users(self, response) -> None: + def _respond_list_users(self, response) -> List[types.User]: user_list = [] for user in response.users: user_list.append(types.User(username=user.username, roles=list(user.roles))) return user_list - def _respond_list_roles(self, response) -> None: + def _respond_list_roles(self, response) -> List[types.Role]: role_list = [] for role in response.roles: role_list.append(types.Role(id=role.id)) return role_list - def _get_index_stub(self): + def _get_index_stub(self) -> index_pb2_grpc.IndexServiceStub: return index_pb2_grpc.IndexServiceStub(self._channel_provider.get_channel()) - def _get_user_admin_stub(self): + def _get_user_admin_stub(self) -> user_admin_pb2_grpc.UserAdminServiceStub: return user_admin_pb2_grpc.UserAdminServiceStub( self._channel_provider.get_channel() ) - def _get_index_id(self, namespace, name): + def _get_index_id(self, namespace, name) -> types_pb2.IndexId: return types_pb2.IndexId(namespace=namespace, name=name) - def _get_add_user_request(self, namespace, name): + def _get_add_user_request(self, namespace, name) -> user_admin_pb2.AddUserRequest : return user_admin_pb2.AddUserRequest(namespace=namespace, name=name) - def _prepare_wait_for_index_waiting(self, namespace, name, wait_interval): + def _prepare_wait_for_index_waiting(self, namespace, name, wait_interval) -> ( + Tuple)[index_pb2_grpc.IndexServiceStub, int, float, bool, int, index_pb2.IndexGetRequest]: return helpers._prepare_wait_for_index_waiting( self, namespace, name, wait_interval ) - def _check_timeout(self, start_time, timeout): + def _check_timeout(self, start_time: float, timeout: int): if start_time + timeout < time.monotonic(): raise AVSClientError(message="timed-out waiting for index creation") diff --git a/src/aerospike_vector_search/shared/base_channel_provider.py b/src/aerospike_vector_search/shared/base_channel_provider.py index 10bee5d3..e6f556a5 100644 --- a/src/aerospike_vector_search/shared/base_channel_provider.py +++ b/src/aerospike_vector_search/shared/base_channel_provider.py @@ -1,21 +1,16 @@ import logging import random -import time +from logging import Logger +from typing import Optional, Union, Tuple -from typing import Optional, Union - -import json -import jwt -import re import grpc +import jwt - -from .. import types from . import helpers - -from .proto_generated import vector_db_pb2, auth_pb2 from .proto_generated import auth_pb2_grpc +from .proto_generated import vector_db_pb2, auth_pb2, types_pb2 from .proto_generated import vector_db_pb2_grpc +from .. import types logger = logging.getLogger(__name__) @@ -58,7 +53,7 @@ def __init__( self.ssl_target_name_override = ssl_target_name_override - self._credentials = helpers._get_credentials(username, password) + self._credentials : Optional[types_pb2.Credentials] = helpers._get_credentials(username, password) if self._credentials: self._token = True else: @@ -84,7 +79,7 @@ def __init__( def get_token(self) -> grpc.access_token_call_credentials: return self._token - def _prepare_about(self) -> grpc.Channel: + def _prepare_about(self) -> Tuple[vector_db_pb2_grpc.AboutServiceStub, vector_db_pb2.AboutRequest]: stub = vector_db_pb2_grpc.AboutServiceStub(self.get_channel()) about_request = vector_db_pb2.AboutRequest() return (stub, about_request) @@ -128,7 +123,7 @@ def add_new_channel_to_node_channels(self, node, newEndpoints): new_channel = self._create_channel_from_server_endpoint_list(newEndpoints) self._node_channels[node] = ChannelAndEndpoints(new_channel, newEndpoints) - def init_tend_cluster(self) -> None: + def init_tend_cluster(self) -> tuple[list[ChannelAndEndpoints], bool]: end_tend_cluster = False if self._is_loadbalancer or self._closed: @@ -141,7 +136,7 @@ def init_tend_cluster(self) -> None: return (channels, end_tend_cluster) - def check_cluster_id(self, new_cluster_id) -> None: + def check_cluster_id(self, new_cluster_id) -> bool: if new_cluster_id == self._cluster_id: return False @@ -149,7 +144,7 @@ def check_cluster_id(self, new_cluster_id) -> None: return True - def update_temp_endpoints(self, endpoints, temp_endpoints): + def update_temp_endpoints(self, endpoints, temp_endpoints) -> dict[int, vector_db_pb2.ServerEndpointList]: if len(endpoints) > len(temp_endpoints): return endpoints else: @@ -169,10 +164,10 @@ def check_for_new_endpoints(self, node, newEndpoints): return (channel_endpoints, add_new_channel) - def _get_ttl(self, payload): + def _get_ttl(self, payload) -> int: return payload["exp"] - payload["iat"] - def _prepare_authenticate(self, credentials, logger): + def _prepare_authenticate(self, credentials: Optional[types_pb2.Credentials], logger: Logger): logger.debug("Refreshing auth token") auth_stub = self._get_auth_stub() @@ -180,13 +175,13 @@ def _prepare_authenticate(self, credentials, logger): return (auth_stub, auth_request) - def _get_auth_stub(self): + def _get_auth_stub(self) -> auth_pb2_grpc.AuthServiceStub: return auth_pb2_grpc.AuthServiceStub(self.get_channel()) - def _get_authenticate_request(self, credentials): + def _get_authenticate_request(self, credentials) -> auth_pb2.AuthRequest: return auth_pb2.AuthRequest(credentials=credentials) - def _respond_authenticate(self, token): + def _respond_authenticate(self, token) -> None: payload = jwt.decode( token, "", algorithms=["RS256"], options={"verify_signature": False} ) @@ -195,7 +190,7 @@ def _respond_authenticate(self, token): self._token = grpc.access_token_call_credentials(token) - def verify_compatible_server(self) -> bool: + def verify_compatible_server(self): def parse_version(v: str): return tuple(str(part) if part.isdigit() else part for part in v.split(".")) @@ -204,7 +199,7 @@ def parse_version(v: str): ): self._tend_ended.set() raise types.AVSClientError( - message="This AVS Client version is only compatbile with AVS Servers above the following version number: " + message="This AVS Client version is only compatible with AVS Servers above the following version number: " + self.minimum_required_version ) else: @@ -246,7 +241,7 @@ def _gather_temp_endpoints(self, new_cluster_ids, update_endpoints_stubs): responses.append(response) return responses - def _assign_temporary_endpoints(self, cluster_endpoints_list): + def _assign_temporary_endpoints(self, cluster_endpoints_list) -> dict[int, vector_db_pb2.ServerEndpointList]: # TODO: Worry about thread safety temp_endpoints: dict[int, vector_db_pb2.ServerEndpointList] = {} for endpoints in cluster_endpoints_list: diff --git a/src/aerospike_vector_search/shared/client_helpers.py b/src/aerospike_vector_search/shared/client_helpers.py index 1f825f04..ae0468c8 100644 --- a/src/aerospike_vector_search/shared/client_helpers.py +++ b/src/aerospike_vector_search/shared/client_helpers.py @@ -1,9 +1,10 @@ -from typing import Any, Optional, Union +from logging import Logger +from typing import Any, Optional, Union, Tuple, List import time import numpy as np from . import conversions -from .proto_generated import transact_pb2 +from .proto_generated import transact_pb2, index_pb2, index_pb2_grpc from .proto_generated import transact_pb2_grpc from .. import types from .proto_generated import types_pb2 @@ -13,20 +14,20 @@ class BaseClient(object): - def _prepare_seeds(self, seeds) -> None: + def _prepare_seeds(self, seeds) -> Tuple[types.HostPort, ...]: return helpers._prepare_seeds(seeds) def _prepare_put( self, - namespace, - key, - record_data, - set_name, - write_type, - ignore_mem_queue_full, - timeout, - logger, - ) -> None: + namespace: str, + key: Union[int, str, bytes, bytearray, np.generic, np.ndarray], + record_data: dict[str, Any], + set_name: Optional[str], + write_type: transact_pb2.WriteType, + ignore_mem_queue_full: Optional[bool], + timeout: Optional[int], + logger: Logger, + ) -> tuple[transact_pb2_grpc.TransactServiceStub, transact_pb2.PutRequest, dict[str, Any]]: logger.debug( "Putting record: namespace=%s, key=%s, record_data:%s, set_name:%s, ignore_mem_queue_full %s, timeout:%s", @@ -70,14 +71,14 @@ def _prepare_put( def _prepare_insert( self, - namespace, - key, - record_data, - set_name, - ignore_mem_queue_full, - timeout, - logger, - ) -> None: + namespace: str, + key: Union[int, str, bytes, bytearray, np.generic, np.ndarray], + record_data: dict[str, Any], + set_name: Optional[str], + ignore_mem_queue_full: Optional[bool], + timeout: Optional[int], + logger: Logger, + ) -> tuple[transact_pb2_grpc.TransactServiceStub, transact_pb2.PutRequest, dict[str, Any]]: return self._prepare_put( namespace, key, @@ -91,14 +92,14 @@ def _prepare_insert( def _prepare_update( self, - namespace, - key, - record_data, - set_name, - ignore_mem_queue_full, - timeout, - logger, - ) -> None: + namespace: str, + key: Union[int, str, bytes, bytearray, np.generic, np.ndarray], + record_data: dict[str, Any], + set_name: Optional[str], + ignore_mem_queue_full: Optional[bool], + timeout: Optional[int], + logger: Logger, + ) -> tuple[transact_pb2_grpc.TransactServiceStub, transact_pb2.PutRequest, dict[str, Any]]: return self._prepare_put( namespace, key, @@ -110,16 +111,17 @@ def _prepare_update( logger, ) + def _prepare_upsert( self, - namespace, - key, - record_data, - set_name, - ignore_mem_queue_full, - timeout, - logger, - ) -> None: + namespace: str, + key: Union[int, str, bytes, bytearray, np.generic, np.ndarray], + record_data: dict[str, Any], + set_name: Optional[str], + ignore_mem_queue_full: Optional[bool], + timeout: Optional[int], + logger: Logger, + ) -> tuple[transact_pb2_grpc.TransactServiceStub, transact_pb2.PutRequest, dict[str, Any]]: return self._prepare_put( namespace, key, @@ -133,7 +135,7 @@ def _prepare_upsert( def _prepare_get( self, namespace, key, include_fields, exclude_fields, set_name, timeout, logger - ) -> None: + ) -> tuple[transact_pb2_grpc.TransactServiceStub, types_pb2.Key, transact_pb2.GetRequest, dict[str, Any]]: logger.debug( "Getting record: namespace=%s, key=%s, include_fields:%s, exclude_fields:%s, set_name:%s, timeout:%s", @@ -157,7 +159,8 @@ def _prepare_get( return (transact_stub, key, get_request, kwargs) - def _prepare_exists(self, namespace, key, set_name, timeout, logger) -> None: + def _prepare_exists(self, namespace, key, set_name, timeout, logger) -> tuple[ + transact_pb2_grpc.TransactServiceStub, transact_pb2.ExistsRequest, dict[str, Any]]: logger.debug( "Getting record existence: namespace=%s, key=%s, set_name:%s, timeout:%s", @@ -178,7 +181,8 @@ def _prepare_exists(self, namespace, key, set_name, timeout, logger) -> None: return (transact_stub, exists_request, kwargs) - def _prepare_delete(self, namespace, key, set_name, timeout, logger) -> None: + def _prepare_delete(self, namespace, key, set_name, timeout, logger) -> tuple[ + transact_pb2_grpc.TransactServiceStub, transact_pb2.DeleteRequest, dict[str, Any]]: logger.debug( "Deleting record: namespace=%s, key=%s, set_name=%s, timeout:%s", @@ -200,8 +204,8 @@ def _prepare_delete(self, namespace, key, set_name, timeout, logger) -> None: return (transact_stub, delete_request, kwargs) def _prepare_is_indexed( - self, namespace, key, index_name, index_namespace, set_name, timeout, logger - ) -> None: + self, namespace: str, key: Union[int, str, bytes, bytearray, np.generic, np.ndarray], index_name: str, index_namespace: Optional[str], set_name: Optional[str], timeout: Optional[int],logger: Logger + ) -> tuple[transact_pb2_grpc.TransactServiceStub, transact_pb2.IsIndexedRequest, dict[str, Any]]: kwargs = {} if timeout is not None: @@ -229,16 +233,16 @@ def _prepare_is_indexed( def _prepare_vector_search( self, - namespace, - index_name, - query, - limit, - search_params, - include_fields, - exclude_fields, - timeout, - logger, - ) -> None: + namespace: str, + index_name: str, + query: Union[List[Union[bool, float]], np.ndarray], + limit: int, + search_params: Optional[types.HnswSearchParams], + include_fields: Optional[List[str]], + exclude_fields: Optional[List[str]], + timeout: Optional[int], + logger: Logger, + ) -> tuple[transact_pb2_grpc.TransactServiceStub, Any, dict[str, Any]]: kwargs = {} if timeout is not None: @@ -280,24 +284,24 @@ def _prepare_vector_search( return (transact_stub, vector_search_request, kwargs) - def _get_transact_stub(self): + def _get_transact_stub(self) -> transact_pb2_grpc.TransactServiceStub: return transact_pb2_grpc.TransactServiceStub( self._channel_provider.get_channel() ) - def _respond_get(self, response, key) -> None: + def _respond_get(self, response, key) -> types.RecordWithKey: return types.RecordWithKey( key=conversions.fromVectorDbKey(key), fields=conversions.fromVectorDbRecord(response), ) - def _respond_exists(self, response) -> None: + def _respond_exists(self, response) -> bool: return response.value - def _respond_is_indexed(self, response) -> None: + def _respond_is_indexed(self, response) -> bool: return response.value - def _respond_neighbor(self, response) -> None: + def _respond_neighbor(self, response) -> types.Neighbor: return conversions.fromVectorDbNeighbor(response) def _get_projection_spec( @@ -305,7 +309,7 @@ def _get_projection_spec( *, include_fields: Optional[list] = None, exclude_fields: Optional[list] = None, - ): + ) -> transact_pb2.ProjectionSpec: # include all fields by default if include_fields is None: include = transact_pb2.ProjectionFilter( @@ -330,8 +334,8 @@ def _get_projection_spec( return projection_spec def _get_key( - self, namespace: str, set: str, key: Union[int, str, bytes, bytearray] - ): + self, namespace: str, set: Optional[str], key: Union[int, str, bytes, bytearray, np.generic, np.ndarray] + ) -> types_pb2.Key: if isinstance(key, np.ndarray): key = key.tobytes() @@ -349,17 +353,18 @@ def _get_key( raise Exception("Invalid key type" + str(type(key))) return key - def _prepare_wait_for_index_waiting(self, namespace, name, wait_interval): + def _prepare_wait_for_index_waiting(self, namespace: str, name: str, wait_interval: int) -> ( + Tuple)[index_pb2_grpc.IndexServiceStub, float, float, bool, int, index_pb2.IndexGetRequest]: return helpers._prepare_wait_for_index_waiting( self, namespace, name, wait_interval ) - def _check_timeout(self, start_time, timeout): + def _check_timeout(self, start_time: float, timeout: int): if start_time + timeout < time.monotonic(): raise AVSClientError(message="timed-out waiting for index creation") def _check_completion_condition( - self, start_time, timeout, index_status, unmerged_record_initialized + self, start_time: float, timeout:int , index_status, unmerged_record_initialized ): self._check_timeout(start_time, timeout) diff --git a/src/aerospike_vector_search/shared/conversions.py b/src/aerospike_vector_search/shared/conversions.py index 87da1941..5a74f1f7 100644 --- a/src/aerospike_vector_search/shared/conversions.py +++ b/src/aerospike_vector_search/shared/conversions.py @@ -48,7 +48,7 @@ def toVectorDbValue(value: Any) -> types_pb2.Value: raise Exception("Invalid type " + str(type(value))) -def toMapKey(value): +def toMapKey(value) -> types_pb2.MapKey: if isinstance(value, str): return types_pb2.MapKey(stringValue=value) elif isinstance(value, int): @@ -83,78 +83,78 @@ def fromVectorDbRecord(record: types_pb2.Record) -> dict[str, Any]: return fields -def fromVectorDbNeighbor(input: types_pb2.Neighbor) -> types.Neighbor: +def fromVectorDbNeighbor(input_vectordb_neighbor: types_pb2.Neighbor) -> types.Neighbor: return types.Neighbor( - key=fromVectorDbKey(input.key), - fields=fromVectorDbRecord(input.record), - distance=input.distance, + key=fromVectorDbKey(input_vectordb_neighbor.key), + fields=fromVectorDbRecord(input_vectordb_neighbor.record), + distance=input_vectordb_neighbor.distance, ) -def fromIndexDefintion(input: types_pb2.IndexDefinition) -> types.IndexDefinition: +def fromIndexDefintion(input_data: types_pb2.IndexDefinition) -> types.IndexDefinition: return types.IndexDefinition( id=types.IndexId( - namespace=input.id.namespace, - name=input.id.name, + namespace=input_data.id.namespace, + name=input_data.id.name, ), - dimensions=input.dimensions, - vector_distance_metric=input.vectorDistanceMetric, - field=input.field, - sets=input.setFilter, + dimensions=input_data.dimensions, + vector_distance_metric=input_data.vectorDistanceMetric, + field=input_data.field, + sets=input_data.setFilter, hnsw_params=types.HnswParams( - m=input.hnswParams.m, - ef_construction=input.hnswParams.efConstruction, - ef=input.hnswParams.ef, + m=input_data.hnswParams.m, + ef_construction=input_data.hnswParams.efConstruction, + ef=input_data.hnswParams.ef, batching_params=types.HnswBatchingParams( - max_index_records=input.hnswParams.batchingParams.maxIndexRecords, - index_interval=input.hnswParams.batchingParams.indexInterval, - max_reindex_records = input.hnswParams.batchingParams.maxReindexRecords, - reindex_interval = input.hnswParams.batchingParams.reindexInterval + max_index_records=input_data.hnswParams.batchingParams.maxIndexRecords, + index_interval=input_data.hnswParams.batchingParams.indexInterval, + max_reindex_records = input_data.hnswParams.batchingParams.maxReindexRecords, + reindex_interval = input_data.hnswParams.batchingParams.reindexInterval ), - max_mem_queue_size=input.hnswParams.maxMemQueueSize, + max_mem_queue_size=input_data.hnswParams.maxMemQueueSize, index_caching_params=types.HnswCachingParams( - max_entries=input.hnswParams.indexCachingParams.maxEntries, - expiry=input.hnswParams.indexCachingParams.expiry, + max_entries=input_data.hnswParams.indexCachingParams.maxEntries, + expiry=input_data.hnswParams.indexCachingParams.expiry, ), healer_params=types.HnswHealerParams( - max_scan_rate_per_node=input.hnswParams.healerParams.maxScanRatePerNode, - max_scan_page_size=input.hnswParams.healerParams.maxScanPageSize, - re_index_percent=input.hnswParams.healerParams.reindexPercent, - schedule=input.hnswParams.healerParams.schedule, - parallelism=input.hnswParams.healerParams.parallelism, + max_scan_rate_per_node=input_data.hnswParams.healerParams.maxScanRatePerNode, + max_scan_page_size=input_data.hnswParams.healerParams.maxScanPageSize, + re_index_percent=input_data.hnswParams.healerParams.reindexPercent, + schedule=input_data.hnswParams.healerParams.schedule, + parallelism=input_data.hnswParams.healerParams.parallelism, ), merge_params=types.HnswIndexMergeParams( - index_parallelism=input.hnswParams.mergeParams.indexParallelism, - reindex_parallelism=input.hnswParams.mergeParams.reIndexParallelism, + index_parallelism=input_data.hnswParams.mergeParams.indexParallelism, + reindex_parallelism=input_data.hnswParams.mergeParams.reIndexParallelism, ), ), - index_labels=input.labels, + index_labels=input_data.labels, storage=types.IndexStorage( - namespace=input.storage.namespace, set_name=input.storage.set + namespace=input_data.storage.namespace, set_name=input_data.storage.set ), ) -def fromVectorDbValue(input: types_pb2.Value) -> Any: - if input.HasField("stringValue"): - return input.stringValue - elif input.HasField("intValue"): - return input.intValue - elif input.HasField("longValue"): - return input.longValue - elif input.HasField("bytesValue"): - return input.bytesValue - elif input.HasField("mapValue"): - dict = {} - for entry in input.mapValue.entries: +def fromVectorDbValue(input_vector: types_pb2.Value) -> Any: + if input_vector.HasField("stringValue"): + return input_vector.stringValue + elif input_vector.HasField("intValue"): + return input_vector.intValue + elif input_vector.HasField("longValue"): + return input_vector.longValue + elif input_vector.HasField("bytesValue"): + return input_vector.bytesValue + elif input_vector.HasField("mapValue"): + data = {} + for entry in input_vector.mapValue.entries: k = fromVectorDbValue(entry.key) v = fromVectorDbValue(entry.value) - dict[k] = v - return dict - elif input.HasField("listValue"): - return [fromVectorDbValue(v) for v in input.listValue.entries] - elif input.HasField("vectorValue"): - vector = input.vectorValue + data[k] = v + return data + elif input_vector.HasField("listValue"): + return [fromVectorDbValue(v) for v in input_vector.listValue.entries] + elif input_vector.HasField("vectorValue"): + vector = input_vector.vectorValue if vector.HasField("floatData"): return [v for v in vector.floatData.value] if vector.HasField("boolData"): diff --git a/src/aerospike_vector_search/shared/helpers.py b/src/aerospike_vector_search/shared/helpers.py index 0604af07..22c791e6 100644 --- a/src/aerospike_vector_search/shared/helpers.py +++ b/src/aerospike_vector_search/shared/helpers.py @@ -2,9 +2,10 @@ from .. import types from .proto_generated import types_pb2, index_pb2 from .proto_generated import index_pb2_grpc +from typing import Union, Tuple, Optional -def _prepare_seeds(seeds) -> None: +def _prepare_seeds(seeds: Union[types.HostPort, Tuple[types.HostPort, ...]]) -> Tuple[types.HostPort, ...]: if not seeds: raise types.AVSClientError(message="at least one seed host needed") @@ -15,7 +16,8 @@ def _prepare_seeds(seeds) -> None: return seeds -def _prepare_wait_for_index_waiting(client, namespace, name, wait_interval): +def _prepare_wait_for_index_waiting(client, namespace: str, name: str, wait_interval: Optional[int]) -> ( + Tuple)[index_pb2_grpc.IndexServiceStub, int, float, bool, int, index_pb2.IndexGetRequest]: unmerged_record_initialized = False start_time = time.monotonic() @@ -34,7 +36,7 @@ def _prepare_wait_for_index_waiting(client, namespace, name, wait_interval): ) -def _get_credentials(username, password): +def _get_credentials(username: str, password: str) -> Optional[types_pb2.Credentials]: if not username: return None return types_pb2.Credentials( diff --git a/src/aerospike_vector_search/types.py b/src/aerospike_vector_search/types.py index 386ec529..988e0033 100644 --- a/src/aerospike_vector_search/types.py +++ b/src/aerospike_vector_search/types.py @@ -7,7 +7,7 @@ class HostPort(object): """ represents host, port and TLS usage information. - Used primarily when intializing client. + Used primarily when initializing client. :param host: The host address. :type host: str @@ -50,13 +50,13 @@ def __repr__(self) -> str: f"key={self.key})" ) - def __str__(self): + def __str__(self) -> str: """ Returns a string representation of the key. """ return f"Key: namespace='{self.namespace}', set='{self.set}', key={self.key}" - def __eq__(self, other): + def __eq__(self, other) -> bool: if not isinstance(other, Key): return NotImplemented @@ -76,7 +76,7 @@ class RecordWithKey(object): :param key: (Key): The key of the record. :type key: Key - :param fields: : The fields associated with the record. + :param fields: The fields associated with the record. :type fields: dict[str, Any] """ @@ -84,7 +84,7 @@ def __init__(self, *, key: Key, fields: dict[str, Any]) -> None: self.key = key self.fields = fields - def __str__(self): + def __str__(self) -> str: """ Returns a string representation of the record, including a key and fields. """ @@ -142,7 +142,7 @@ def __repr__(self) -> str: f"distance={self.distance})" ) - def __str__(self): + def __str__(self) -> str: """ Returns a string representation of the neighboring record. """ @@ -424,7 +424,7 @@ class HnswCachingParams(object): :param max_entries: maximum number of entries to cache. Default is the global cache config, which is configured in the AVS Server. :type max_entries: Optional[int] - :param expiry: Cache entries will expire after this time in millseconds has expired after the entry was add to the cache. + :param expiry: Cache entries will expire after this time in milliseconds has expired after the entry was add to the cache. Default is the global cache config, which is configured in the AVS Server. :type expiry: Optional[int] @@ -547,13 +547,13 @@ def __init__( m: Optional[int] = None, ef_construction: Optional[int] = None, ef: Optional[int] = None, - batching_params: Optional[HnswBatchingParams] = HnswBatchingParams(), + batching_params: HnswBatchingParams = HnswBatchingParams(), max_mem_queue_size: Optional[int] = None, - index_caching_params: Optional[HnswCachingParams] = HnswCachingParams(), - healer_params: Optional[HnswHealerParams] = HnswHealerParams(), - merge_params: Optional[HnswIndexMergeParams] = HnswIndexMergeParams(), - enable_vector_integrity_check : Optional[bool] = True, - record_caching_params : Optional[HnswCachingParams] = HnswCachingParams() + index_caching_params: HnswCachingParams = HnswCachingParams(), + healer_params: HnswHealerParams = HnswHealerParams(), + merge_params: HnswIndexMergeParams = HnswIndexMergeParams(), + enable_vector_integrity_check : bool = True, + record_caching_params : HnswCachingParams = HnswCachingParams() ) -> None: self.m = m self.ef_construction = ef_construction @@ -660,7 +660,7 @@ def __init__(self, *, ef: Optional[int] = None) -> None: self.ef = ef - def _to_pb2(self): + def _to_pb2(self) -> types_pb2.HnswSearchParams: params = types_pb2.HnswSearchParams() params.ef = self.ef return params @@ -900,7 +900,7 @@ class IndexDefinition(object): :param storage: Index storage details. :type storage: Optional[IndexStorage] default None - :param index_labels: Meta data associated with the index. Defaults to None. + :param index_labels: Metadata associated with the index. Defaults to None. :type index_labels: Optional[dict[str, str]] """ @@ -984,13 +984,13 @@ class AVSServerError(AVSError): def __init__(self, *, rpc_error) -> None: self.rpc_error = rpc_error - def __str__(self): + def __str__(self) -> str: return f"AVSServerError(rpc_error={self.rpc_error})" class AVSClientError(AVSError): """ - Custom exception raised for errors related to AVS client-side failures.. + Custom exception raised for errors related to AVS client-side failures. :param message: error messaging raised by the AVS Client. Defaults to None. :type set_name: str @@ -1000,7 +1000,7 @@ class AVSClientError(AVSError): def __init__(self, *, message) -> None: self.message = message - def __str__(self): + def __str__(self) -> str: return f"AVSClientError(message={self.message})"