diff --git a/.github/workflows/extensive_vector_search_tests.yml b/.github/workflows/extensive_vector_search_tests.yml index 2647ae46..22ad4ed6 100644 --- a/.github/workflows/extensive_vector_search_tests.yml +++ b/.github/workflows/extensive_vector_search_tests.yml @@ -3,7 +3,7 @@ name: Run long running vector search tests on: push: branches: - - dev + - main jobs: test-exhaustive-vector-search: @@ -18,10 +18,10 @@ jobs: steps: - name: Checkout code - uses: actions/checkout@v3 + uses: actions/checkout@v4 - name: Set up Python - uses: actions/setup-python@v2 + uses: actions/setup-python@v5 with: python-version: ${{ matrix.python-version }} diff --git a/.github/workflows/integration_test.yml b/.github/workflows/integration_test.yml index f9c1bded..fbfb1517 100644 --- a/.github/workflows/integration_test.yml +++ b/.github/workflows/integration_test.yml @@ -21,10 +21,10 @@ jobs: steps: - name: Checkout code - uses: actions/checkout@v3 + uses: actions/checkout@v4 - name: Set up Python - uses: actions/setup-python@v2 + uses: actions/setup-python@v5 with: python-version: ${{ matrix.python-version }} diff --git a/src/aerospike_vector_search/admin.py b/src/aerospike_vector_search/admin.py index 4782c858..89ca5b0b 100644 --- a/src/aerospike_vector_search/admin.py +++ b/src/aerospike_vector_search/admin.py @@ -93,7 +93,7 @@ def index_create( index_params: Optional[types.HnswParams] = None, index_labels: Optional[dict[str, str]] = None, index_storage: Optional[types.IndexStorage] = None, - timeout: Optional[int] = None, + timeout: Optional[int] = 100_000, ) -> None: """ Create an index. @@ -172,6 +172,56 @@ def index_create( logger.error("Failed waiting for creation with error: %s", e) raise types.AVSServerError(rpc_error=e) + def index_update( + self, + *, + namespace: str, + name: str, + index_labels: Optional[dict[str, str]] = None, + hnsw_update_params: Optional[types.HnswIndexUpdate] = None, + timeout: Optional[int] = 100_000, + ) -> None: + """ + Update an existing index. + + :param namespace: The namespace for the index. + :type namespace: str + + :param name: The name of the index. + :type name: str + + :param index_labels: Optional labels associated with the index. Defaults to None. + :type index_labels: Optional[dict[str, str]] + + :param hnsw_update_params: Parameters for updating HNSW index settings. + :type hnsw_update_params: Optional[types.HnswIndexUpdate] + + :param timeout: Time in seconds (default 100_000) this operation will wait before raising an error. + :type timeout: int + + Raises: + AVSServerError: Raised if an error occurs during the RPC communication with the server while attempting to update the index. + """ + (index_stub, index_update_request, kwargs) = self._prepare_index_update( + namespace = namespace, + name = name, + index_labels = index_labels, + hnsw_update_params = hnsw_update_params, + timeout = timeout, + logger = logger, + ) + + try: + index_stub.Update( + index_update_request, + credentials=self._channel_provider.get_token(), + **kwargs, + ) + except grpc.RpcError as e: + logger.error("Failed to update index with error: %s", e) + raise types.AVSServerError(rpc_error=e) + + def index_drop( self, *, namespace: str, name: str, timeout: Optional[int] = None ) -> None: diff --git a/src/aerospike_vector_search/aio/admin.py b/src/aerospike_vector_search/aio/admin.py index d8ae628e..e4e369cc 100644 --- a/src/aerospike_vector_search/aio/admin.py +++ b/src/aerospike_vector_search/aio/admin.py @@ -95,7 +95,7 @@ async def index_create( index_params: Optional[types.HnswParams] = None, index_labels: Optional[dict[str, str]] = None, index_storage: Optional[types.IndexStorage] = None, - timeout: Optional[int] = None, + timeout: Optional[int] = 100_000, ) -> None: """ Create an index. @@ -140,7 +140,7 @@ async def index_create( Note: This method creates an index with the specified parameters and waits for the index creation to complete. - It waits for up to 100,000 seconds for the index creation to complete. + It waits for up to 100,000 seconds or the specified timeout for the index creation to complete. """ await self._channel_provider._is_ready() @@ -176,6 +176,60 @@ async def index_create( logger.error("Failed waiting for creation with error: %s", e) raise types.AVSServerError(rpc_error=e) + + async def index_update( + self, + *, + namespace: str, + name: str, + index_labels: Optional[dict[str, str]] = None, + hnsw_update_params: Optional[types.HnswIndexUpdate] = None, + timeout: Optional[int] = 100_000, + ) -> None: + """ + Update an existing index. + + :param namespace: The namespace for the index. + :type namespace: str + + :param name: The name of the index. + :type name: str + + :param index_labels: Optional labels associated with the index. Defaults to None. + :type index_labels: Optional[dict[str, str]] + + :param hnsw_update_params: Parameters for updating HNSW index settings. + :type hnsw_update_params: Optional[types.HnswIndexUpdate] + + :param timeout: Timeout in seconds for internal index update tasks. Defaults to 100_000. + :type timeout: int + + Raises: + AVSServerError: Raised if an error occurs during the RPC communication with the server while attempting to update the index. + """ + + await self._channel_provider._is_ready() + + (index_stub, index_update_request, kwargs) = self._prepare_index_update( + namespace = namespace, + name = name, + index_labels = index_labels, + hnsw_update_params = hnsw_update_params, + logger = logger, + timeout = timeout + ) + + try: + await index_stub.Update( + index_update_request, + credentials=self._channel_provider.get_token(), + **kwargs, + ) + except grpc.RpcError as e: + logger.error("Failed to update index with error: %s", e) + raise types.AVSServerError(rpc_error=e) + + async def index_drop( self, *, namespace: str, name: str, timeout: Optional[int] = None ) -> None: diff --git a/src/aerospike_vector_search/shared/admin_helpers.py b/src/aerospike_vector_search/shared/admin_helpers.py index 4ac71ac9..ca49c750 100644 --- a/src/aerospike_vector_search/shared/admin_helpers.py +++ b/src/aerospike_vector_search/shared/admin_helpers.py @@ -1,6 +1,6 @@ import asyncio import logging -from typing import Any, Optional, Union +from typing import Any, Optional, Tuple, Dict import time import google.protobuf.empty_pb2 @@ -83,6 +83,50 @@ def _prepare_index_create( index_create_request = index_pb2.IndexCreateRequest(definition=index_definition) return (index_stub, index_create_request, kwargs) + def _prepare_index_update( + self, + namespace: str, + name: str, + index_labels: Optional[Dict[str, str]], + hnsw_update_params: Optional[types.HnswIndexUpdate], + timeout: Optional[int], + logger: logging.Logger + ) -> tuple[index_pb2_grpc.IndexServiceStub, index_pb2.IndexUpdateRequest, dict[str, Any]]: + """ + Prepares the index update request for updating an existing index. + """ + + logger.debug( + "Updating index: namespace=%s, name=%s, labels=%s, hnsw_update_params=%s, timeout=%s", + namespace, + name, + index_labels, + hnsw_update_params, + timeout, + ) + + kwargs = {} + if timeout is not None: + kwargs["timeout"] = timeout + + index_stub = self._get_index_stub() + index_id = self._get_index_id(namespace, name) + + # Prepare HNSW update parameters if provided + hnsw_update = None + if hnsw_update_params is not None: + hnsw_update = hnsw_update_params._to_pb2() + + # Create the IndexUpdateRequest with optional fields + index_update_request = index_pb2.IndexUpdateRequest( + indexId=index_id, + labels=index_labels, + hnswIndexUpdate=hnsw_update, + ) + + return (index_stub, index_update_request, kwargs) + + def _prepare_index_drop(self, namespace, name, timeout, logger) -> None: logger.debug( diff --git a/src/aerospike_vector_search/types.py b/src/aerospike_vector_search/types.py index 67510250..386ec529 100644 --- a/src/aerospike_vector_search/types.py +++ b/src/aerospike_vector_search/types.py @@ -666,6 +666,126 @@ def _to_pb2(self): return params +class HnswIndexUpdate: + """ + Represents parameters for updating HNSW index settings. + + :param batching_params: Configures batching behavior for batch-based index update. + :type batching_params: Optional[HnswBatchingParams] + + :param max_mem_queue_size: Maximum size of in-memory queue for inserted/updated vector records. + :type max_mem_queue_size: Optional[int] + + :param index_caching_params: Configures caching for HNSW index. + :type index_caching_params: Optional[HnswCachingParams] + + :param healer_params: Configures index healer parameters. + :type healer_params: Optional[HnswHealerParams] + + :param merge_params: Configures merging of batch indices to the main index. + :type merge_params: Optional[HnswIndexMergeParams] + + :param enable_vector_integrity_check: Verifies if the underlying vector has changed before returning the kANN result. + :type enable_vector_integrity_check: Optional[bool] + + :param record_caching_params: Configures caching for vector records. + :type record_caching_params: Optional[HnswCachingParams] + """ + + def __init__( + self, + *, + batching_params: Optional[HnswBatchingParams] = None, + max_mem_queue_size: Optional[int] = None, + index_caching_params: Optional[HnswCachingParams] = None, + healer_params: Optional[HnswHealerParams] = None, + merge_params: Optional[HnswIndexMergeParams] = None, + enable_vector_integrity_check: Optional[bool] = True, + record_caching_params: Optional[HnswCachingParams] = None, + ) -> None: + self.batching_params = batching_params + self.max_mem_queue_size = max_mem_queue_size + self.index_caching_params = index_caching_params + self.healer_params = healer_params + self.merge_params = merge_params + self.enable_vector_integrity_check = enable_vector_integrity_check + self.record_caching_params = record_caching_params + + def _to_pb2(self) -> types_pb2.HnswIndexUpdate: + """ + Converts the HnswIndexUpdate instance to its protobuf representation. + """ + params: types_pb2.HnswIndexUpdate = types_pb2.HnswIndexUpdate() + + if self.batching_params: + params.batchingParams.CopyFrom(self.batching_params._to_pb2()) + + if self.max_mem_queue_size is not None: + params.maxMemQueueSize = self.max_mem_queue_size + + if self.index_caching_params: + params.indexCachingParams.CopyFrom(self.index_caching_params._to_pb2()) + + if self.healer_params: + params.healerParams.CopyFrom(self.healer_params._to_pb2()) + + if self.merge_params: + params.mergeParams.CopyFrom(self.merge_params._to_pb2()) + + if self.enable_vector_integrity_check is not None: + params.enableVectorIntegrityCheck = self.enable_vector_integrity_check + + if self.record_caching_params: + params.recordCachingParams.CopyFrom(self.record_caching_params._to_pb2()) + + return params + + def __repr__(self) -> str: + return ( + f"HnswIndexUpdate(batching_params={self.batching_params}, " + f"max_mem_queue_size={self.max_mem_queue_size}, " + f"index_caching_params={self.index_caching_params}, " + f"healer_params={self.healer_params}, " + f"merge_params={self.merge_params}, " + f"enable_vector_integrity_check={self.enable_vector_integrity_check}, " + f"record_caching_params={self.record_caching_params})" + ) + + def __str__(self) -> str: + return ( + f"HnswIndexUpdate {{\n" + f" batching_params: {self.batching_params},\n" + f" max_mem_queue_size: {self.max_mem_queue_size},\n" + f" index_caching_params: {self.index_caching_params},\n" + f" healer_params: {self.healer_params},\n" + f" merge_params: {self.merge_params},\n" + f" enable_vector_integrity_check: {self.enable_vector_integrity_check},\n" + f" record_caching_params: {self.record_caching_params}\n" + f"}}" + ) + + def __eq__(self, other) -> bool: + if not isinstance(other, HnswIndexUpdate): + return NotImplemented + return ( + self.batching_params == other.batching_params + and self.max_mem_queue_size == other.max_mem_queue_size + and self.index_caching_params == other.index_caching_params + and self.healer_params == other.healer_params + and self.merge_params == other.merge_params + and self.enable_vector_integrity_check == other.enable_vector_integrity_check + and self.record_caching_params == other.record_caching_params + ) + + def __getitem__(self, key): + if not hasattr(self, key): + raise AttributeError(f"'HnswIndexUpdate' object has no attribute '{key}'") + return getattr(self, key) + + def __setitem__(self, key, value): + return setattr(self, key, value) + + class IndexStorage(object): """ Helper class primarily used to specify which namespace and set to build the index on. diff --git a/tests/standard/aio/test_admin_client_index_update.py b/tests/standard/aio/test_admin_client_index_update.py new file mode 100644 index 00000000..9f1bdc39 --- /dev/null +++ b/tests/standard/aio/test_admin_client_index_update.py @@ -0,0 +1,133 @@ +import time +import pytest +from aerospike_vector_search import types, AVSServerError +import grpc + +from .aio_utils import drop_specified_index + +server_defaults = { + "m": 16, + "ef_construction": 100, + "ef": 100, + "batching_params": { + "max_index_records": 10000, + "index_interval": 10000, + } +} + + +class index_update_test_case: + def __init__( + self, + *, + namespace, + vector_field, + dimensions, + initial_labels, + update_labels, + hnsw_index_update, + timeout + ): + self.namespace = namespace + self.vector_field = vector_field + self.dimensions = dimensions + self.initial_labels = initial_labels + self.update_labels = update_labels + self.hnsw_index_update = hnsw_index_update + self.timeout = timeout + + +@pytest.mark.parametrize( + "test_case", + [ + index_update_test_case( + namespace="test", + vector_field="update_2", + dimensions=256, + initial_labels={"status": "active"}, + update_labels={"status": "inactive", "region": "us-west"}, + hnsw_index_update=types.HnswIndexUpdate( + batching_params=types.HnswBatchingParams( + max_index_records=2000, + index_interval=20000, + max_reindex_records=1500, + reindex_interval=70000 + ), + max_mem_queue_size=1000030, + index_caching_params=types.HnswCachingParams(max_entries=10, expiry=3000), + merge_params=types.HnswIndexMergeParams(index_parallelism=10, reindex_parallelism=3), + healer_params=types.HnswHealerParams(max_scan_rate_per_node=80), + ), + timeout=None, + ), + ], +) +async def test_index_update_async(session_admin_client, test_case): + # Create a unique index name for each test run + trimmed_random = "aBEd-1" + + # Drop any pre-existing index with the same name + try: + session_admin_client.index_drop(namespace="test", name=trimmed_random) + except AVSServerError as se: + if se.rpc_error.code() != grpc.StatusCode.NOT_FOUND: + pass + + # Create the index + await session_admin_client.index_create( + namespace=test_case.namespace, + name=trimmed_random, + vector_field=test_case.vector_field, + dimensions=test_case.dimensions, + index_labels=test_case.initial_labels, + timeout=test_case.timeout, + ) + + # Update the index with new labels and parameters + await session_admin_client.index_update( + namespace=test_case.namespace, + name=trimmed_random, + index_labels=test_case.update_labels, + hnsw_update_params=test_case.hnsw_index_update + ) + + # Allow time for update to be applied + time.sleep(10) + + # Verify the update + result = await session_admin_client.index_get(namespace=test_case.namespace, name=trimmed_random, apply_defaults=True) + assert result, "Expected result to be non-empty but got an empty dictionary." + + assert result["id"]["namespace"] == test_case.namespace + + # Assertions based on provided parameters + if test_case.hnsw_index_update.batching_params: + assert result["hnsw_params"]["batching_params"][ + "max_index_records"] == test_case.hnsw_index_update.batching_params.max_index_records + assert result["hnsw_params"]["batching_params"][ + "index_interval"] == test_case.hnsw_index_update.batching_params.index_interval + assert result["hnsw_params"]["batching_params"][ + "max_reindex_records"] == test_case.hnsw_index_update.batching_params.max_reindex_records + assert result["hnsw_params"]["batching_params"][ + "reindex_interval"] == test_case.hnsw_index_update.batching_params.reindex_interval + + assert result["hnsw_params"]["max_mem_queue_size"] == test_case.hnsw_index_update.max_mem_queue_size + + if test_case.hnsw_index_update.index_caching_params: + assert result["hnsw_params"]["index_caching_params"][ + "max_entries"] == test_case.hnsw_index_update.index_caching_params.max_entries + assert result["hnsw_params"]["index_caching_params"][ + "expiry"] == test_case.hnsw_index_update.index_caching_params.expiry + + if test_case.hnsw_index_update.merge_params: + assert result["hnsw_params"]["merge_params"][ + "index_parallelism"] == test_case.hnsw_index_update.merge_params.index_parallelism + assert result["hnsw_params"]["merge_params"][ + "reindex_parallelism"] == test_case.hnsw_index_update.merge_params.reindex_parallelism + + if test_case.hnsw_index_update.healer_params: + assert result["hnsw_params"]["healer_params"][ + "max_scan_rate_per_node"] == test_case.hnsw_index_update.healer_params.max_scan_rate_per_node + + # Clean up by dropping the index after the test + await drop_specified_index(session_admin_client, test_case.namespace, trimmed_random) diff --git a/tests/standard/sync/test_admin_client_index_update.py b/tests/standard/sync/test_admin_client_index_update.py new file mode 100644 index 00000000..a5b5722e --- /dev/null +++ b/tests/standard/sync/test_admin_client_index_update.py @@ -0,0 +1,106 @@ +import time +import pytest +from aerospike_vector_search import types, AVSServerError +import grpc + +from .sync_utils import drop_specified_index + + +class index_update_test_case: + def __init__( + self, + *, + namespace, + vector_field, + dimensions, + initial_labels, + update_labels, + hnsw_index_update, + timeout + ): + self.namespace = namespace + self.vector_field = vector_field + self.dimensions = dimensions + self.initial_labels = initial_labels + self.update_labels = update_labels + self.hnsw_index_update = hnsw_index_update + self.timeout = timeout + + +@pytest.mark.parametrize( + "test_case", + [ + index_update_test_case( + namespace="test", + vector_field="update_2", + dimensions=256, + initial_labels={"status": "active"}, + update_labels={"status": "inactive", "region": "us-west"}, + hnsw_index_update=types.HnswIndexUpdate( + batching_params=types.HnswBatchingParams(max_index_records=2000, index_interval=20000, max_reindex_records=1500, reindex_interval=70000), + max_mem_queue_size=1000030, + index_caching_params=types.HnswCachingParams(max_entries=10, expiry=3000), + merge_params=types.HnswIndexMergeParams(index_parallelism=10,reindex_parallelism=3), + healer_params=types.HnswHealerParams(max_scan_rate_per_node=80), + ), + timeout=None, + ), + ], +) +def test_index_update(session_admin_client, test_case): + trimmed_random = "saUEN1-" + + # Drop any pre-existing index with the same name + try: + session_admin_client.index_drop(namespace="test", name=trimmed_random) + except AVSServerError as se: + if se.rpc_error.code() != grpc.StatusCode.NOT_FOUND: + pass + + # Create the index + session_admin_client.index_create( + namespace=test_case.namespace, + name=trimmed_random, + vector_field=test_case.vector_field, + dimensions=test_case.dimensions, + index_labels=test_case.initial_labels, + timeout=test_case.timeout, + ) + + # Update the index with parameters based on the test case + session_admin_client.index_update( + namespace=test_case.namespace, + name=trimmed_random, + index_labels=test_case.update_labels, + hnsw_update_params=test_case.hnsw_index_update, + timeout=100_000, + ) + #wait for index to get updated + time.sleep(10) + + # Verify the update + result = session_admin_client.index_get(namespace=test_case.namespace, name=trimmed_random, apply_defaults=True) + assert result, "Expected result to be non-empty but got an empty dictionary." + + # Assertions + if test_case.hnsw_index_update.batching_params: + assert result["hnsw_params"]["batching_params"]["max_index_records"] == test_case.hnsw_index_update.batching_params.max_index_records + assert result["hnsw_params"]["batching_params"]["index_interval"] == test_case.hnsw_index_update.batching_params.index_interval + assert result["hnsw_params"]["batching_params"]["max_reindex_records"] == test_case.hnsw_index_update.batching_params.max_reindex_records + assert result["hnsw_params"]["batching_params"]["reindex_interval"] == test_case.hnsw_index_update.batching_params.reindex_interval + + assert result["hnsw_params"]["max_mem_queue_size"] == test_case.hnsw_index_update.max_mem_queue_size + + if test_case.hnsw_index_update.index_caching_params: + assert result["hnsw_params"]["index_caching_params"]["max_entries"] == test_case.hnsw_index_update.index_caching_params.max_entries + assert result["hnsw_params"]["index_caching_params"]["expiry"] == test_case.hnsw_index_update.index_caching_params.expiry + + if test_case.hnsw_index_update.merge_params: + assert result["hnsw_params"]["merge_params"]["index_parallelism"] == test_case.hnsw_index_update.merge_params.index_parallelism + assert result["hnsw_params"]["merge_params"]["reindex_parallelism"] == test_case.hnsw_index_update.merge_params.reindex_parallelism + + if test_case.hnsw_index_update.healer_params: + assert result["hnsw_params"]["healer_params"]["max_scan_rate_per_node"] == test_case.hnsw_index_update.healer_params.max_scan_rate_per_node + + # Clean up by dropping the index after the test + drop_specified_index(session_admin_client, test_case.namespace, trimmed_random)