Skip to content

Commit

Permalink
VEC-399: Add typehints (#63)
Browse files Browse the repository at this point in the history
Add typehints
  • Loading branch information
rahul-aerospike authored Nov 14, 2024
1 parent 0f7cfc5 commit c99344f
Show file tree
Hide file tree
Showing 12 changed files with 259 additions and 258 deletions.
23 changes: 12 additions & 11 deletions src/aerospike_vector_search/admin.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from .internal import channel_provider
from .shared.admin_helpers import BaseClient
from .shared.conversions import fromIndexStatusResponse
from .types import IndexDefinition, Role

logger = logging.getLogger(__name__)

Expand All @@ -19,7 +20,7 @@ class Client(BaseClient):
This client is designed to conduct Aerospike Vector Search administrative operation such as creating indexes, querying index information, and dropping indexes.
:param seeds: Defines the AVS nodes to which you want AVS to connect. AVS iterates through the seed nodes. After connecting to a node, AVS discovers all of the nodes in the cluster.
:param seeds: Defines the AVS nodes to which you want AVS to connect. AVS iterates through the seed nodes. After connecting to a node, AVS discovers all the nodes in the cluster.
:type seeds: Union[types.HostPort, tuple[types.HostPort, ...]]
:param listener_name: An external (NATed) address and port combination that differs from the actual address and port where AVS is listening. Clients can access AVS on a node using the advertised listener address and port. Defaults to None.
Expand Down Expand Up @@ -86,7 +87,7 @@ def index_create(
name: str,
vector_field: str,
dimensions: int,
vector_distance_metric: Optional[types.VectorDistanceMetric] = (
vector_distance_metric: types.VectorDistanceMetric = (
types.VectorDistanceMetric.SQUARED_EUCLIDEAN
),
sets: Optional[str] = None,
Expand All @@ -113,7 +114,7 @@ def index_create(
:param vector_distance_metric:
The distance metric used to compare when performing a vector search.
Defaults to :class:`VectorDistanceMetric.SQUARED_EUCLIDEAN`.
:type vector_distance_metric: Optional[types.VectorDistanceMetric]
:type vector_distance_metric: types.VectorDistanceMetric
:param sets: The set used for the index. Defaults to None.
:type sets: Optional[str]
Expand All @@ -123,7 +124,7 @@ def index_create(
specified for :class:`types.HnswParams` will be used.
:type index_params: Optional[types.HnswParams]
:param index_labels: Meta data associated with the index. Defaults to None.
:param index_labels: Metadata associated with the index. Defaults to None.
:type index_labels: Optional[dict[str, str]]
:param index_storage: Namespace and set where index overhead (non-vector data) is stored.
Expand Down Expand Up @@ -269,7 +270,7 @@ def index_drop(

def index_list(
self, timeout: Optional[int] = None, apply_defaults: Optional[bool] = True
) -> list[dict]:
) -> list[IndexDefinition]:
"""
List all indices.
Expand Down Expand Up @@ -308,7 +309,7 @@ def index_get(
name: str,
timeout: Optional[int] = None,
apply_defaults: Optional[bool] = True,
) -> dict[str, Union[int, str]]:
) -> IndexDefinition:
"""
Retrieve the information related with an index.
Expand Down Expand Up @@ -630,7 +631,7 @@ def revoke_roles(
logger.error("Failed to revoke roles with error: %s", e)
raise types.AVSServerError(rpc_error=e)

def list_roles(self, timeout: Optional[int] = None) -> list[dict]:
def list_roles(self, timeout: Optional[int] = None) -> list[Role]:
"""
List roles available on the AVS server.
Expand Down Expand Up @@ -663,8 +664,8 @@ def _wait_for_index_creation(
*,
namespace: str,
name: str,
timeout: Optional[int] = sys.maxsize,
wait_interval: Optional[int] = 0.1,
timeout: int = sys.maxsize,
wait_interval: float = 0.1,
) -> None:
"""
Wait for the index to be created.
Expand Down Expand Up @@ -697,8 +698,8 @@ def _wait_for_index_deletion(
*,
namespace: str,
name: str,
timeout: Optional[int] = sys.maxsize,
wait_interval: Optional[int] = 0.1,
timeout: int = sys.maxsize,
wait_interval: float = 0.1,
) -> None:
"""
Wait for the index to be deleted.
Expand Down
37 changes: 18 additions & 19 deletions src/aerospike_vector_search/aio/admin.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,15 @@
import asyncio
import logging
import sys

from typing import Optional, Union

import grpc

from .internal import channel_provider
from .. import types
from ..shared.conversions import fromIndexStatusResponse
from ..shared.admin_helpers import BaseClient

from ..shared.conversions import fromIndexStatusResponse
from ..types import Role, IndexDefinition

logger = logging.getLogger(__name__)

Expand All @@ -21,7 +20,7 @@ class Client(BaseClient):
This client is designed to conduct Aerospike Vector Search administrative operation such as creating indexes, querying index information, and dropping indexes.
:param seeds: Defines the AVS nodes to which you want AVS to connect. AVS iterates through the seed nodes. After connecting to a node, AVS discovers all of the nodes in the cluster.
:param seeds: Defines the AVS nodes to which you want AVS to connect. AVS iterates through the seed nodes. After connecting to a node, AVS discovers all the nodes in the cluster.
:type seeds: Union[types.HostPort, tuple[types.HostPort, ...]]
:param listener_name: An external (NATed) address and port combination that differs from the actual address and port where AVS is listening. Clients can access AVS on a node using the advertised listener address and port. Defaults to None.
Expand Down Expand Up @@ -88,7 +87,7 @@ async def index_create(
name: str,
vector_field: str,
dimensions: int,
vector_distance_metric: Optional[types.VectorDistanceMetric] = (
vector_distance_metric: types.VectorDistanceMetric = (
types.VectorDistanceMetric.SQUARED_EUCLIDEAN
),
sets: Optional[str] = None,
Expand All @@ -115,7 +114,7 @@ async def index_create(
:param vector_distance_metric:
The distance metric used to compare when performing a vector search.
Defaults to :class:`VectorDistanceMetric.SQUARED_EUCLIDEAN`.
:type vector_distance_metric: Optional[types.VectorDistanceMetric]
:type vector_distance_metric: types.VectorDistanceMetric
:param sets: The set used for the index. Defaults to None.
:type sets: Optional[str]
Expand All @@ -125,7 +124,7 @@ async def index_create(
specified for :class:`types.HnswParams` will be used.
:type index_params: Optional[types.HnswParams]
:param index_labels: Meta data associated with the index. Defaults to None.
:param index_labels: Metadata associated with the index. Defaults to None.
:type index_labels: Optional[dict[str, str]]
:param index_storage: Namespace and set where index overhead (non-vector data) is stored.
Expand Down Expand Up @@ -278,7 +277,7 @@ async def index_drop(

async def index_list(
self, timeout: Optional[int] = None, apply_defaults: Optional[bool] = True
) -> list[dict]:
) -> list[IndexDefinition]:
"""
List all indices.
Expand Down Expand Up @@ -320,7 +319,7 @@ async def index_get(
name: str,
timeout: Optional[int] = None,
apply_defaults: Optional[bool] = True,
) -> dict[str, Union[int, str]]:
) -> IndexDefinition:
"""
Retrieve the information related with an index.
Expand Down Expand Up @@ -594,14 +593,14 @@ async def list_users(self, timeout: Optional[int] = None) -> list[types.User]:

async def grant_roles(
self, *, username: str, roles: list[str], timeout: Optional[int] = None
) -> int:
) :
"""
Grant roles to existing AVS Users.
:param username: Username of the user which will receive the roles.
:type username: str
:param roles: Roles the specified user will recieved.
:param roles: Roles the specified user will receive.
:type roles: list[str]
:param timeout: Time in seconds this operation will wait before raising an :class:`AVSServerError <aerospike_vector_search.types.AVSServerError>`. Defaults to None.
Expand Down Expand Up @@ -630,7 +629,7 @@ async def grant_roles(

async def revoke_roles(
self, *, username: str, roles: list[str], timeout: Optional[int] = None
) -> int:
) :
"""
Revoke roles from existing AVS Users.
Expand Down Expand Up @@ -664,7 +663,7 @@ async def revoke_roles(
logger.error("Failed to revoke roles with error: %s", e)
raise types.AVSServerError(rpc_error=e)

async def list_roles(self, timeout: Optional[int] = None) -> None:
async def list_roles(self, timeout: Optional[int] = None) -> list[Role]:
"""
list roles of existing AVS Users.
Expand Down Expand Up @@ -699,8 +698,8 @@ async def _wait_for_index_creation(
*,
namespace: str,
name: str,
timeout: Optional[int] = sys.maxsize,
wait_interval: Optional[int] = 0.1,
timeout: int = sys.maxsize,
wait_interval: float = 0.1,
) -> None:
"""
Wait for the index to be created.
Expand All @@ -717,7 +716,7 @@ async def _wait_for_index_creation(
index_creation_request,
credentials=self._channel_provider.get_token(),
)
logger.debug("Index created succesfully")
logger.debug("Index created successfully")
# Index has been created
return
except grpc.RpcError as e:
Expand All @@ -734,8 +733,8 @@ async def _wait_for_index_deletion(
*,
namespace: str,
name: str,
timeout: Optional[int] = sys.maxsize,
wait_interval: Optional[int] = 0.1,
timeout: int = sys.maxsize,
wait_interval: float = 0.1,
) -> None:
"""
Wait for the index to be deleted.
Expand All @@ -759,7 +758,7 @@ async def _wait_for_index_deletion(
await asyncio.sleep(wait_interval)
except grpc.RpcError as e:
if e.code() == grpc.StatusCode.NOT_FOUND:
logger.debug("Index deleted succesfully")
logger.debug("Index deleted successfully")
# Index has been created
return
else:
Expand Down
12 changes: 6 additions & 6 deletions src/aerospike_vector_search/aio/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -385,7 +385,7 @@ async def exists(
exists_request, credentials=self._channel_provider.get_token(), **kwargs
)
except grpc.RpcError as e:
logger.error("Failed to verfiy vector existence with error: %s", e)
logger.error("Failed to verify vector existence with error: %s", e)
raise types.AVSServerError(rpc_error=e)

return self._respond_exists(response)
Expand Down Expand Up @@ -679,9 +679,9 @@ async def wait_for_index_completion(
*,
namespace: str,
name: str,
timeout: Optional[int] = sys.maxsize,
wait_interval: Optional[int] = 12,
validation_threshold: Optional[int] = 2,
timeout: int = sys.maxsize,
wait_interval: int = 12,
validation_threshold: int = 2,
) -> None:
"""
Wait for the index to have no pending index update operations.
Expand Down Expand Up @@ -714,7 +714,7 @@ async def wait_for_index_completion(
# Wait interval between polling
(
index_stub,
wait_interval,
wait_interval_float,
start_time,
unmerged_record_initialized,
validation_count,
Expand All @@ -740,7 +740,7 @@ async def wait_for_index_completion(
validation_count += 1
else:
validation_count = 0
await asyncio.sleep(wait_interval)
await asyncio.sleep(wait_interval_float)

async def close(self):
"""
Expand Down
11 changes: 4 additions & 7 deletions src/aerospike_vector_search/aio/internal/channel_provider.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,10 @@
import re
import asyncio
import logging
import jwt
from jwt.exceptions import InvalidTokenError
from typing import Optional, Union

import google.protobuf.empty_pb2
import grpc
import random

from ... import types
from ...shared.proto_generated import vector_db_pb2
Expand All @@ -18,7 +15,7 @@

logger = logging.getLogger(__name__)

TEND_INTERVAL = 1
TEND_INTERVAL : int = 1


class ChannelProvider(base_channel_provider.BaseChannelProvider):
Expand Down Expand Up @@ -72,7 +69,7 @@ def __init__(
self._tend_exception: Exception = None

async def _is_ready(self):
# Wait 1 round of cluster tending, auth token initialization, and server client compatiblity verfication
# Wait 1 round of cluster tending, auth token initialization, and server client compatibility verification
await self._ready.wait()

# This propogates any fatal/unexpected errors from client initialization/tending to the client.
Expand All @@ -92,7 +89,7 @@ async def _tend(self):

self._ready.set()

except Exception as e:
except Exception as e:
# Set all event to prevent hanging if initial tend fails with error
self._tend_ended.set()
self._ready.set()
Expand Down Expand Up @@ -185,7 +182,7 @@ async def _tend_token(self):
try:
if not self._token:
return
elif self._token != True:
elif not self._token:
await asyncio.sleep((self._ttl * self._ttl_threshold))

await self._update_token_and_ttl()
Expand Down
15 changes: 8 additions & 7 deletions src/aerospike_vector_search/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import warnings

import grpc
import numpy as np

from . import types
from .internal import channel_provider
Expand Down Expand Up @@ -199,7 +200,7 @@ def upsert(
self,
*,
namespace: str,
key: Union[int, str, bytes, bytearray],
key: Union[int, str, bytes, bytearray, np.generic, np.ndarray],
record_data: dict[str, Any],
set_name: Optional[str] = None,
ignore_mem_queue_full: Optional[bool] = False,
Expand Down Expand Up @@ -369,7 +370,7 @@ def exists(
exists_request, credentials=self._channel_provider.get_token(), **kwargs
)
except grpc.RpcError as e:
logger.error("Failed to verfiy vector existence with error: %s", e)
logger.error("Failed to verify vector existence with error: %s", e)
raise types.AVSServerError(rpc_error=e)

return self._respond_exists(response)
Expand Down Expand Up @@ -661,9 +662,9 @@ def wait_for_index_completion(
*,
namespace: str,
name: str,
timeout: Optional[int] = sys.maxsize,
wait_interval: Optional[int] = 12,
validation_threshold: Optional[int] = 2,
timeout: int = sys.maxsize,
wait_interval: int = 12,
validation_threshold: int = 2,
) -> None:
"""
Wait for the index to have no pending index update operations.
Expand Down Expand Up @@ -693,7 +694,7 @@ def wait_for_index_completion(
"""
(
index_stub,
wait_interval,
wait_interval_float,
start_time,
unmerged_record_initialized,
validation_count,
Expand All @@ -719,7 +720,7 @@ def wait_for_index_completion(
validation_count += 1
else:
validation_count = 0
time.sleep(wait_interval)
time.sleep(wait_interval_float)

def close(self):
"""
Expand Down
Loading

0 comments on commit c99344f

Please sign in to comment.