Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

VEC-399: Add typehints #63

Merged
merged 11 commits into from
Nov 14, 2024
23 changes: 12 additions & 11 deletions src/aerospike_vector_search/admin.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from .internal import channel_provider
from .shared.admin_helpers import BaseClient
from .shared.conversions import fromIndexStatusResponse
from .types import IndexDefinition, Role

logger = logging.getLogger(__name__)

Expand All @@ -19,7 +20,7 @@ class Client(BaseClient):

This client is designed to conduct Aerospike Vector Search administrative operation such as creating indexes, querying index information, and dropping indexes.

:param seeds: Defines the AVS nodes to which you want AVS to connect. AVS iterates through the seed nodes. After connecting to a node, AVS discovers all of the nodes in the cluster.
:param seeds: Defines the AVS nodes to which you want AVS to connect. AVS iterates through the seed nodes. After connecting to a node, AVS discovers all the nodes in the cluster.
:type seeds: Union[types.HostPort, tuple[types.HostPort, ...]]

:param listener_name: An external (NATed) address and port combination that differs from the actual address and port where AVS is listening. Clients can access AVS on a node using the advertised listener address and port. Defaults to None.
Expand Down Expand Up @@ -86,7 +87,7 @@ def index_create(
name: str,
vector_field: str,
dimensions: int,
vector_distance_metric: Optional[types.VectorDistanceMetric] = (
vector_distance_metric: types.VectorDistanceMetric = (
types.VectorDistanceMetric.SQUARED_EUCLIDEAN
),
sets: Optional[str] = None,
Expand All @@ -113,7 +114,7 @@ def index_create(
:param vector_distance_metric:
The distance metric used to compare when performing a vector search.
Defaults to :class:`VectorDistanceMetric.SQUARED_EUCLIDEAN`.
:type vector_distance_metric: Optional[types.VectorDistanceMetric]
:type vector_distance_metric: types.VectorDistanceMetric

:param sets: The set used for the index. Defaults to None.
:type sets: Optional[str]
Expand All @@ -123,7 +124,7 @@ def index_create(
specified for :class:`types.HnswParams` will be used.
:type index_params: Optional[types.HnswParams]

:param index_labels: Meta data associated with the index. Defaults to None.
:param index_labels: Metadata associated with the index. Defaults to None.
:type index_labels: Optional[dict[str, str]]

:param index_storage: Namespace and set where index overhead (non-vector data) is stored.
Expand Down Expand Up @@ -269,7 +270,7 @@ def index_drop(

def index_list(
self, timeout: Optional[int] = None, apply_defaults: Optional[bool] = True
) -> list[dict]:
) -> list[IndexDefinition]:
"""
List all indices.

Expand Down Expand Up @@ -308,7 +309,7 @@ def index_get(
name: str,
timeout: Optional[int] = None,
apply_defaults: Optional[bool] = True,
) -> dict[str, Union[int, str]]:
) -> IndexDefinition:
"""
Retrieve the information related with an index.

Expand Down Expand Up @@ -630,7 +631,7 @@ def revoke_roles(
logger.error("Failed to revoke roles with error: %s", e)
raise types.AVSServerError(rpc_error=e)

def list_roles(self, timeout: Optional[int] = None) -> list[dict]:
def list_roles(self, timeout: Optional[int] = None) -> list[Role]:
"""
List roles available on the AVS server.

Expand Down Expand Up @@ -663,8 +664,8 @@ def _wait_for_index_creation(
*,
namespace: str,
name: str,
timeout: Optional[int] = sys.maxsize,
wait_interval: Optional[int] = 0.1,
timeout: int = sys.maxsize,
wait_interval: float = 0.1,
) -> None:
"""
Wait for the index to be created.
Expand Down Expand Up @@ -697,8 +698,8 @@ def _wait_for_index_deletion(
*,
namespace: str,
name: str,
timeout: Optional[int] = sys.maxsize,
wait_interval: Optional[int] = 0.1,
timeout: int = sys.maxsize,
wait_interval: float = 0.1,
) -> None:
"""
Wait for the index to be deleted.
Expand Down
37 changes: 18 additions & 19 deletions src/aerospike_vector_search/aio/admin.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,15 @@
import asyncio
import logging
import sys

from typing import Optional, Union

import grpc

from .internal import channel_provider
from .. import types
from ..shared.conversions import fromIndexStatusResponse
from ..shared.admin_helpers import BaseClient

from ..shared.conversions import fromIndexStatusResponse
from ..types import Role, IndexDefinition

logger = logging.getLogger(__name__)

Expand All @@ -21,7 +20,7 @@ class Client(BaseClient):

This client is designed to conduct Aerospike Vector Search administrative operation such as creating indexes, querying index information, and dropping indexes.

:param seeds: Defines the AVS nodes to which you want AVS to connect. AVS iterates through the seed nodes. After connecting to a node, AVS discovers all of the nodes in the cluster.
:param seeds: Defines the AVS nodes to which you want AVS to connect. AVS iterates through the seed nodes. After connecting to a node, AVS discovers all the nodes in the cluster.
:type seeds: Union[types.HostPort, tuple[types.HostPort, ...]]

:param listener_name: An external (NATed) address and port combination that differs from the actual address and port where AVS is listening. Clients can access AVS on a node using the advertised listener address and port. Defaults to None.
Expand Down Expand Up @@ -88,7 +87,7 @@ async def index_create(
name: str,
vector_field: str,
dimensions: int,
vector_distance_metric: Optional[types.VectorDistanceMetric] = (
vector_distance_metric: types.VectorDistanceMetric = (
types.VectorDistanceMetric.SQUARED_EUCLIDEAN
),
sets: Optional[str] = None,
Expand All @@ -115,7 +114,7 @@ async def index_create(
:param vector_distance_metric:
The distance metric used to compare when performing a vector search.
Defaults to :class:`VectorDistanceMetric.SQUARED_EUCLIDEAN`.
:type vector_distance_metric: Optional[types.VectorDistanceMetric]
:type vector_distance_metric: types.VectorDistanceMetric

:param sets: The set used for the index. Defaults to None.
:type sets: Optional[str]
Expand All @@ -125,7 +124,7 @@ async def index_create(
specified for :class:`types.HnswParams` will be used.
:type index_params: Optional[types.HnswParams]

:param index_labels: Meta data associated with the index. Defaults to None.
:param index_labels: Metadata associated with the index. Defaults to None.
:type index_labels: Optional[dict[str, str]]

:param index_storage: Namespace and set where index overhead (non-vector data) is stored.
Expand Down Expand Up @@ -278,7 +277,7 @@ async def index_drop(

async def index_list(
self, timeout: Optional[int] = None, apply_defaults: Optional[bool] = True
) -> list[dict]:
) -> list[IndexDefinition]:
"""
List all indices.

Expand Down Expand Up @@ -320,7 +319,7 @@ async def index_get(
name: str,
timeout: Optional[int] = None,
apply_defaults: Optional[bool] = True,
) -> dict[str, Union[int, str]]:
) -> IndexDefinition:
"""
Retrieve the information related with an index.

Expand Down Expand Up @@ -594,14 +593,14 @@ async def list_users(self, timeout: Optional[int] = None) -> list[types.User]:

async def grant_roles(
self, *, username: str, roles: list[str], timeout: Optional[int] = None
) -> int:
) :
"""
Grant roles to existing AVS Users.

:param username: Username of the user which will receive the roles.
:type username: str

:param roles: Roles the specified user will recieved.
:param roles: Roles the specified user will receive.
:type roles: list[str]

:param timeout: Time in seconds this operation will wait before raising an :class:`AVSServerError <aerospike_vector_search.types.AVSServerError>`. Defaults to None.
Expand Down Expand Up @@ -630,7 +629,7 @@ async def grant_roles(

async def revoke_roles(
self, *, username: str, roles: list[str], timeout: Optional[int] = None
) -> int:
) :
"""
Revoke roles from existing AVS Users.

Expand Down Expand Up @@ -664,7 +663,7 @@ async def revoke_roles(
logger.error("Failed to revoke roles with error: %s", e)
raise types.AVSServerError(rpc_error=e)

async def list_roles(self, timeout: Optional[int] = None) -> None:
async def list_roles(self, timeout: Optional[int] = None) -> list[Role]:
"""
list roles of existing AVS Users.

Expand Down Expand Up @@ -699,8 +698,8 @@ async def _wait_for_index_creation(
*,
namespace: str,
name: str,
timeout: Optional[int] = sys.maxsize,
wait_interval: Optional[int] = 0.1,
timeout: int = sys.maxsize,
wait_interval: float = 0.1,
) -> None:
"""
Wait for the index to be created.
Expand All @@ -717,7 +716,7 @@ async def _wait_for_index_creation(
index_creation_request,
credentials=self._channel_provider.get_token(),
)
logger.debug("Index created succesfully")
logger.debug("Index created successfully")
# Index has been created
return
except grpc.RpcError as e:
Expand All @@ -734,8 +733,8 @@ async def _wait_for_index_deletion(
*,
namespace: str,
name: str,
timeout: Optional[int] = sys.maxsize,
wait_interval: Optional[int] = 0.1,
timeout: int = sys.maxsize,
wait_interval: float = 0.1,
) -> None:
"""
Wait for the index to be deleted.
Expand All @@ -759,7 +758,7 @@ async def _wait_for_index_deletion(
await asyncio.sleep(wait_interval)
except grpc.RpcError as e:
if e.code() == grpc.StatusCode.NOT_FOUND:
logger.debug("Index deleted succesfully")
logger.debug("Index deleted successfully")
# Index has been created
return
else:
Expand Down
12 changes: 6 additions & 6 deletions src/aerospike_vector_search/aio/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -385,7 +385,7 @@ async def exists(
exists_request, credentials=self._channel_provider.get_token(), **kwargs
)
except grpc.RpcError as e:
logger.error("Failed to verfiy vector existence with error: %s", e)
logger.error("Failed to verify vector existence with error: %s", e)
raise types.AVSServerError(rpc_error=e)

return self._respond_exists(response)
Expand Down Expand Up @@ -679,9 +679,9 @@ async def wait_for_index_completion(
*,
namespace: str,
name: str,
timeout: Optional[int] = sys.maxsize,
wait_interval: Optional[int] = 12,
validation_threshold: Optional[int] = 2,
timeout: int = sys.maxsize,
wait_interval: int = 12,
validation_threshold: int = 2,
) -> None:
"""
Wait for the index to have no pending index update operations.
Expand Down Expand Up @@ -714,7 +714,7 @@ async def wait_for_index_completion(
# Wait interval between polling
(
index_stub,
wait_interval,
wait_interval_float,
start_time,
unmerged_record_initialized,
validation_count,
Expand All @@ -740,7 +740,7 @@ async def wait_for_index_completion(
validation_count += 1
else:
validation_count = 0
await asyncio.sleep(wait_interval)
await asyncio.sleep(wait_interval_float)

async def close(self):
"""
Expand Down
11 changes: 4 additions & 7 deletions src/aerospike_vector_search/aio/internal/channel_provider.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,10 @@
import re
import asyncio
import logging
import jwt
from jwt.exceptions import InvalidTokenError
from typing import Optional, Union

import google.protobuf.empty_pb2
import grpc
import random

from ... import types
from ...shared.proto_generated import vector_db_pb2
Expand All @@ -18,7 +15,7 @@

logger = logging.getLogger(__name__)

TEND_INTERVAL = 1
TEND_INTERVAL : int = 1


class ChannelProvider(base_channel_provider.BaseChannelProvider):
Expand Down Expand Up @@ -72,7 +69,7 @@ def __init__(
self._tend_exception: Exception = None

async def _is_ready(self):
# Wait 1 round of cluster tending, auth token initialization, and server client compatiblity verfication
# Wait 1 round of cluster tending, auth token initialization, and server client compatibility verification
await self._ready.wait()

# This propogates any fatal/unexpected errors from client initialization/tending to the client.
Expand All @@ -92,7 +89,7 @@ async def _tend(self):

self._ready.set()

except Exception as e:
except Exception as e:
# Set all event to prevent hanging if initial tend fails with error
self._tend_ended.set()
self._ready.set()
Expand Down Expand Up @@ -185,7 +182,7 @@ async def _tend_token(self):
try:
if not self._token:
return
elif self._token != True:
elif not self._token:
await asyncio.sleep((self._ttl * self._ttl_threshold))

await self._update_token_and_ttl()
Expand Down
15 changes: 8 additions & 7 deletions src/aerospike_vector_search/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import warnings

import grpc
import numpy as np

from . import types
from .internal import channel_provider
Expand Down Expand Up @@ -199,7 +200,7 @@ def upsert(
self,
*,
namespace: str,
key: Union[int, str, bytes, bytearray],
key: Union[int, str, bytes, bytearray, np.generic, np.ndarray],
record_data: dict[str, Any],
set_name: Optional[str] = None,
ignore_mem_queue_full: Optional[bool] = False,
Expand Down Expand Up @@ -369,7 +370,7 @@ def exists(
exists_request, credentials=self._channel_provider.get_token(), **kwargs
)
except grpc.RpcError as e:
logger.error("Failed to verfiy vector existence with error: %s", e)
logger.error("Failed to verify vector existence with error: %s", e)
raise types.AVSServerError(rpc_error=e)

return self._respond_exists(response)
Expand Down Expand Up @@ -661,9 +662,9 @@ def wait_for_index_completion(
*,
namespace: str,
name: str,
timeout: Optional[int] = sys.maxsize,
wait_interval: Optional[int] = 12,
validation_threshold: Optional[int] = 2,
timeout: int = sys.maxsize,
wait_interval: int = 12,
validation_threshold: int = 2,
) -> None:
"""
Wait for the index to have no pending index update operations.
Expand Down Expand Up @@ -693,7 +694,7 @@ def wait_for_index_completion(
"""
(
index_stub,
wait_interval,
wait_interval_float,
start_time,
unmerged_record_initialized,
validation_count,
Expand All @@ -719,7 +720,7 @@ def wait_for_index_completion(
validation_count += 1
else:
validation_count = 0
time.sleep(wait_interval)
time.sleep(wait_interval_float)

def close(self):
"""
Expand Down
Loading
Loading