From ef8e772542093609a343c460c2f0db1ef2d3ae49 Mon Sep 17 00:00:00 2001 From: Olli-Pekka Heinisuo Date: Tue, 16 Apr 2024 16:42:21 +0300 Subject: [PATCH] Feat: bearer token authentication support (#591) * bearer token authentication provider support * add tests and checks, move auth file to separate dir * fix error message * remove locks * rename var * refactoring: refactor exceptions, fix mypy * fix: regen async * tests: extend token tests to check token updates * new: add warning when auth token provider is used with an insecure connection * fix: propagate auth token to rest client even with prefer_grpc set --------- Co-authored-by: George Panchuk --- qdrant_client/async_qdrant_client.py | 19 ++++- qdrant_client/async_qdrant_remote.py | 15 +++- qdrant_client/auth/__init__.py | 1 + qdrant_client/auth/bearer_auth.py | 42 ++++++++++ qdrant_client/connection.py | 120 +++++++++++++-------------- qdrant_client/qdrant_client.py | 24 +++++- qdrant_client/qdrant_remote.py | 39 +++++++-- tests/test_async_qdrant_client.py | 69 +++++++++++++++ tests/test_qdrant_client.py | 120 ++++++++++++++++++++++++--- 9 files changed, 363 insertions(+), 86 deletions(-) create mode 100644 qdrant_client/auth/__init__.py create mode 100644 qdrant_client/auth/bearer_auth.py diff --git a/qdrant_client/async_qdrant_client.py b/qdrant_client/async_qdrant_client.py index a54c5ad4..4ac2b1b8 100644 --- a/qdrant_client/async_qdrant_client.py +++ b/qdrant_client/async_qdrant_client.py @@ -10,7 +10,19 @@ # ****** WARNING: THIS FILE IS AUTOGENERATED ****** import warnings -from typing import Any, Dict, Iterable, List, Mapping, Optional, Sequence, Tuple, Union +from typing import ( + Any, + Awaitable, + Callable, + Dict, + Iterable, + List, + Mapping, + Optional, + Sequence, + Tuple, + Union, +) from qdrant_client import grpc as grpc from qdrant_client.async_client_base import AsyncQdrantBase @@ -68,6 +80,7 @@ class AsyncQdrantClient(AsyncQdrantFastembedMixin): force_disable_check_same_thread: For QdrantLocal, force disable check_same_thread. Default: `False` Only use this if you can guarantee that you can resolve the thread safety outside QdrantClient. + auth_token_provider: Callback function to get Bearer access token. If given, the function will be called before each request to get the token. **kwargs: Additional arguments passed directly into REST client initialization """ @@ -87,6 +100,9 @@ def __init__( path: Optional[str] = None, force_disable_check_same_thread: bool = False, grpc_options: Optional[Dict[str, Any]] = None, + auth_token_provider: Optional[ + Union[Callable[[], str], Callable[[], Awaitable[str]]] + ] = None, **kwargs: Any, ): super().__init__(**kwargs) @@ -117,6 +133,7 @@ def __init__( timeout=timeout, host=host, grpc_options=grpc_options, + auth_token_provider=auth_token_provider, **kwargs, ) diff --git a/qdrant_client/async_qdrant_remote.py b/qdrant_client/async_qdrant_remote.py index 81b378d1..d0133301 100644 --- a/qdrant_client/async_qdrant_remote.py +++ b/qdrant_client/async_qdrant_remote.py @@ -15,6 +15,8 @@ from multiprocessing import get_all_start_methods from typing import ( Any, + Awaitable, + Callable, Dict, Iterable, List, @@ -35,6 +37,7 @@ from qdrant_client import grpc as grpc from qdrant_client._pydantic_compat import construct from qdrant_client.async_client_base import AsyncQdrantBase +from qdrant_client.auth import BearerAuth from qdrant_client.connection import get_async_channel as get_channel from qdrant_client.conversions import common_types as types from qdrant_client.conversions.common_types import get_args_subscribed @@ -63,6 +66,9 @@ def __init__( timeout: Optional[int] = None, host: Optional[str] = None, grpc_options: Optional[Dict[str, Any]] = None, + auth_token_provider: Optional[ + Union[Callable[[], str], Callable[[], Awaitable[str]]] + ] = None, **kwargs: Any, ): super().__init__(**kwargs) @@ -100,6 +106,7 @@ def __init__( self._port = port self._timeout = math.ceil(timeout) if timeout is not None else None self._api_key = api_key + self._auth_token_provider = auth_token_provider limits = kwargs.pop("limits", None) if limits is None: if self._host in ["localhost", "127.0.0.1"]: @@ -109,7 +116,7 @@ def __init__( self._rest_headers = kwargs.pop("metadata", {}) if api_key is not None: if self._scheme == "http": - warnings.warn("Api key is used with unsecure connection.") + warnings.warn("Api key is used with an insecure connection.") self._rest_headers["api-key"] = api_key self._grpc_headers.append(("api-key", api_key)) grpc_compression: Optional[Compression] = kwargs.pop("grpc_compression", None) @@ -129,6 +136,11 @@ def __init__( self._rest_args["limits"] = limits if self._timeout is not None: self._rest_args["timeout"] = self._timeout + if self._auth_token_provider is not None: + if self._scheme == "http": + warnings.warn("Auth token provider is used with an insecure connection.") + bearer_auth = BearerAuth(self._auth_token_provider) + self._rest_args["auth"] = bearer_auth self.openapi_client: AsyncApis[AsyncApiClient] = AsyncApis( host=self.rest_uri, **self._rest_args ) @@ -182,6 +194,7 @@ def _init_grpc_channel(self) -> None: metadata=self._grpc_headers, options=self._grpc_options, compression=self._grpc_compression, + auth_token_provider=self._auth_token_provider, ) def _init_grpc_points_client(self) -> None: diff --git a/qdrant_client/auth/__init__.py b/qdrant_client/auth/__init__.py new file mode 100644 index 00000000..f5e1c98d --- /dev/null +++ b/qdrant_client/auth/__init__.py @@ -0,0 +1 @@ +from qdrant_client.auth.bearer_auth import BearerAuth diff --git a/qdrant_client/auth/bearer_auth.py b/qdrant_client/auth/bearer_auth.py new file mode 100644 index 00000000..effaaccc --- /dev/null +++ b/qdrant_client/auth/bearer_auth.py @@ -0,0 +1,42 @@ +import asyncio +from typing import Awaitable, Callable, Optional, Union + +import httpx + + +class BearerAuth(httpx.Auth): + def __init__( + self, + auth_token_provider: Union[Callable[[], str], Callable[[], Awaitable[str]]], + ): + self.async_token: Optional[Callable[[], Awaitable[str]]] = None + self.sync_token: Optional[Callable[[], str]] = None + + if asyncio.iscoroutinefunction(auth_token_provider): + self.async_token = auth_token_provider + else: + if callable(auth_token_provider): + self.sync_token = auth_token_provider # type: ignore + else: + raise ValueError("auth_token_provider must be a callable or awaitable") + + def _sync_get_token(self) -> str: + if self.sync_token is None: + raise ValueError("Synchronous token provider is not set.") + return self.sync_token() + + def sync_auth_flow(self, request: httpx.Request) -> httpx.Request: + token = self._sync_get_token() + request.headers["Authorization"] = f"Bearer {token}" + yield request + + async def _async_get_token(self) -> str: + if self.async_token is not None: + return await self.async_token() # type: ignore + # Fallback to synchronous token if asynchronous token is not available + return self._sync_get_token() + + async def async_auth_flow(self, request: httpx.Request) -> httpx.Request: + token = await self._async_get_token() + request.headers["Authorization"] = f"Bearer {token}" + yield request diff --git a/qdrant_client/connection.py b/qdrant_client/connection.py index 4ea9afe3..1091bbbe 100644 --- a/qdrant_client/connection.py +++ b/qdrant_client/connection.py @@ -1,5 +1,6 @@ +import asyncio import collections -from typing import Any, Callable, Dict, List, Optional, Tuple +from typing import Any, Awaitable, Callable, Dict, List, Optional, Tuple, Union import grpc @@ -64,7 +65,7 @@ def __init__(self, interceptor_function: Callable): async def intercept_unary_unary( self, continuation: Any, client_call_details: Any, request: Any ) -> Any: - new_details, new_request_iterator, postprocess = self._fn( + new_details, new_request_iterator, postprocess = await self._fn( client_call_details, iter((request,)), False, False ) next_request = next(new_request_iterator) @@ -74,7 +75,7 @@ async def intercept_unary_unary( async def intercept_unary_stream( self, continuation: Any, client_call_details: Any, request: Any ) -> Any: - new_details, new_request_iterator, postprocess = self._fn( + new_details, new_request_iterator, postprocess = await self._fn( client_call_details, iter((request,)), False, True ) response_it = await continuation(new_details, next(new_request_iterator)) @@ -83,7 +84,7 @@ async def intercept_unary_stream( async def intercept_stream_unary( self, continuation: Any, client_call_details: Any, request_iterator: Any ) -> Any: - new_details, new_request_iterator, postprocess = self._fn( + new_details, new_request_iterator, postprocess = await self._fn( client_call_details, request_iterator, True, False ) response = await continuation(new_details, new_request_iterator) @@ -92,7 +93,7 @@ async def intercept_stream_unary( async def intercept_stream_stream( self, continuation: Any, client_call_details: Any, request_iterator: Any ) -> Any: - new_details, new_request_iterator, postprocess = self._fn( + new_details, new_request_iterator, postprocess = await self._fn( client_call_details, request_iterator, True, True ) response_it = await continuation(new_details, new_request_iterator) @@ -125,7 +126,10 @@ class _ClientAsyncCallDetails( pass -def header_adder_interceptor(new_metadata: List[Tuple[str, str]]) -> _GenericClientInterceptor: +def header_adder_interceptor( + new_metadata: List[Tuple[str, str]], + auth_token_provider: Optional[Callable[[], str]] = None, +) -> _GenericClientInterceptor: def intercept_call( client_call_details: _ClientCallDetails, request_iterator: Any, @@ -133,6 +137,7 @@ def intercept_call( _response_streaming: Any, ) -> Tuple[_ClientCallDetails, Any, Any]: metadata = [] + if client_call_details.metadata is not None: metadata = list(client_call_details.metadata) for header, value in new_metadata: @@ -142,6 +147,13 @@ def intercept_call( value, ) ) + + if auth_token_provider: + if not asyncio.iscoroutinefunction(auth_token_provider): + metadata.append(("authorization", f"Bearer {auth_token_provider()}")) + else: + raise ValueError("Synchronous channel requires synchronous auth token provider.") + client_call_details = _ClientCallDetails( client_call_details.method, client_call_details.timeout, @@ -154,9 +166,10 @@ def intercept_call( def header_adder_async_interceptor( - new_metadata: List[Tuple[str, str]] + new_metadata: List[Tuple[str, str]], + auth_token_provider: Optional[Union[Callable[[], str], Callable[[], Awaitable[str]]]] = None, ) -> _GenericAsyncClientInterceptor: - def intercept_call( + async def intercept_call( client_call_details: grpc.aio.ClientCallDetails, request_iterator: Any, _request_streaming: Any, @@ -172,6 +185,14 @@ def intercept_call( value, ) ) + + if auth_token_provider: + if asyncio.iscoroutinefunction(auth_token_provider): + token = await auth_token_provider() + else: + token = auth_token_provider() + metadata.append(("authorization", f"Bearer {token}")) + client_call_details = client_call_details._replace(metadata=metadata) return client_call_details, request_iterator, None @@ -200,38 +221,21 @@ def get_channel( metadata: Optional[List[Tuple[str, str]]] = None, options: Optional[Dict[str, Any]] = None, compression: Optional[grpc.Compression] = None, + auth_token_provider: Optional[Callable[[], str]] = None, ) -> grpc.Channel: - # gRPC client options + # Parse gRPC client options _options = parse_channel_options(options) + metadata_interceptor = header_adder_interceptor( + new_metadata=metadata or [], auth_token_provider=auth_token_provider + ) if ssl: - if metadata: - - def metadata_callback(context: Any, callback: Any) -> None: - # for more info see grpc docs - callback(metadata, None) - - # build ssl credentials using the cert the same as before - cert_creds = grpc.ssl_channel_credentials() - - # now build meta data credentials - auth_creds = grpc.metadata_call_credentials(metadata_callback) - - # combine the cert credentials and the macaroon auth credentials - # such that every call is properly encrypted and authenticated - creds = grpc.composite_channel_credentials(cert_creds, auth_creds) - else: - creds = grpc.ssl_channel_credentials() - - # finally pass in the combined credentials when creating a channel - return grpc.secure_channel(f"{host}:{port}", creds, _options, compression) + ssl_creds = grpc.ssl_channel_credentials() + channel = grpc.secure_channel(f"{host}:{port}", ssl_creds, _options, compression) + return grpc.intercept_channel(channel, metadata_interceptor) else: - if metadata: - metadata_interceptor = header_adder_interceptor(metadata) - channel = grpc.insecure_channel(f"{host}:{port}", _options, compression) - return grpc.intercept_channel(channel, metadata_interceptor) - else: - return grpc.insecure_channel(f"{host}:{port}", _options, compression) + channel = grpc.insecure_channel(f"{host}:{port}", _options, compression) + return grpc.intercept_channel(channel, metadata_interceptor) def get_async_channel( @@ -241,36 +245,26 @@ def get_async_channel( metadata: Optional[List[Tuple[str, str]]] = None, options: Optional[Dict[str, Any]] = None, compression: Optional[grpc.Compression] = None, + auth_token_provider: Optional[Union[Callable[[], str], Callable[[], Awaitable[str]]]] = None, ) -> grpc.aio.Channel: - # gRPC client options + # Parse gRPC client options _options = parse_channel_options(options) - if ssl: - if metadata: - - def metadata_callback(context: Any, callback: Any) -> None: - # for more info see grpc docs - callback(metadata, None) - - # build ssl credentials using the cert the same as before - cert_creds = grpc.ssl_channel_credentials() + # Create metadata interceptor + metadata_interceptor = header_adder_async_interceptor( + new_metadata=metadata or [], auth_token_provider=auth_token_provider + ) - # now build meta data credentials - auth_creds = grpc.metadata_call_credentials(metadata_callback) - - # combine the cert credentials and the macaroon auth credentials - # such that every call is properly encrypted and authenticated - creds = grpc.composite_channel_credentials(cert_creds, auth_creds) - else: - creds = grpc.ssl_channel_credentials() - - # finally pass in the combined credentials when creating a channel - return grpc.aio.secure_channel(f"{host}:{port}", creds, _options, compression) + if ssl: + ssl_creds = grpc.ssl_channel_credentials() + return grpc.aio.secure_channel( + f"{host}:{port}", + ssl_creds, + _options, + compression, + interceptors=[metadata_interceptor], + ) else: - if metadata: - metadata_interceptor = header_adder_async_interceptor(metadata) - return grpc.aio.insecure_channel( - f"{host}:{port}", _options, compression, interceptors=[metadata_interceptor] - ) - else: - return grpc.aio.insecure_channel(f"{host}:{port}", _options, compression) + return grpc.aio.insecure_channel( + f"{host}:{port}", _options, compression, interceptors=[metadata_interceptor] + ) diff --git a/qdrant_client/qdrant_client.py b/qdrant_client/qdrant_client.py index 07836bc7..e2a3d2c3 100644 --- a/qdrant_client/qdrant_client.py +++ b/qdrant_client/qdrant_client.py @@ -1,5 +1,17 @@ import warnings -from typing import Any, Dict, Iterable, List, Mapping, Optional, Sequence, Tuple, Union +from typing import ( + Any, + Awaitable, + Callable, + Dict, + Iterable, + List, + Mapping, + Optional, + Sequence, + Tuple, + Union, +) from qdrant_client import grpc as grpc from qdrant_client.client_base import QdrantBase @@ -58,6 +70,7 @@ class QdrantClient(QdrantFastembedMixin): force_disable_check_same_thread: For QdrantLocal, force disable check_same_thread. Default: `False` Only use this if you can guarantee that you can resolve the thread safety outside QdrantClient. + auth_token_provider: Callback function to get Bearer access token. If given, the function will be called before each request to get the token. **kwargs: Additional arguments passed directly into REST client initialization """ @@ -77,6 +90,9 @@ def __init__( path: Optional[str] = None, force_disable_check_same_thread: bool = False, grpc_options: Optional[Dict[str, Any]] = None, + auth_token_provider: Optional[ + Union[Callable[[], str], Callable[[], Awaitable[str]]] + ] = None, **kwargs: Any, ): super().__init__( @@ -116,6 +132,7 @@ def __init__( timeout=timeout, host=host, grpc_options=grpc_options, + auth_token_provider=auth_token_provider, **kwargs, ) @@ -2102,7 +2119,10 @@ def delete_snapshot( assert len(kwargs) == 0, f"Unknown arguments: {list(kwargs.keys())}" return self._client.delete_snapshot( - collection_name=collection_name, snapshot_name=snapshot_name, wait=wait, **kwargs + collection_name=collection_name, + snapshot_name=snapshot_name, + wait=wait, + **kwargs, ) def list_full_snapshots(self, **kwargs: Any) -> List[types.SnapshotDescription]: diff --git a/qdrant_client/qdrant_remote.py b/qdrant_client/qdrant_remote.py index b8d184b4..44596e68 100644 --- a/qdrant_client/qdrant_remote.py +++ b/qdrant_client/qdrant_remote.py @@ -5,6 +5,8 @@ from multiprocessing import get_all_start_methods from typing import ( Any, + Awaitable, + Callable, Dict, Iterable, List, @@ -24,6 +26,7 @@ from qdrant_client import grpc as grpc from qdrant_client._pydantic_compat import construct +from qdrant_client.auth import BearerAuth from qdrant_client.client_base import QdrantBase from qdrant_client.connection import get_async_channel, get_channel from qdrant_client.conversions import common_types as types @@ -53,6 +56,9 @@ def __init__( timeout: Optional[int] = None, host: Optional[str] = None, grpc_options: Optional[Dict[str, Any]] = None, + auth_token_provider: Optional[ + Union[Callable[[], str], Callable[[], Awaitable[str]]] + ] = None, **kwargs: Any, ): super().__init__(**kwargs) @@ -107,6 +113,7 @@ def __init__( ) # it has been changed from float to int. # convert it to the closest greater or equal int value (e.g. 0.5 -> 1) self._api_key = api_key + self._auth_token_provider = auth_token_provider limits = kwargs.pop("limits", None) if limits is None: @@ -120,7 +127,7 @@ def __init__( self._rest_headers = kwargs.pop("metadata", {}) if api_key is not None: if self._scheme == "http": - warnings.warn("Api key is used with unsecure connection.") + warnings.warn("Api key is used with an insecure connection.") # http2 = True @@ -151,7 +158,17 @@ def __init__( if self._timeout is not None: self._rest_args["timeout"] = self._timeout - self.openapi_client: SyncApis[ApiClient] = SyncApis(host=self.rest_uri, **self._rest_args) + if self._auth_token_provider is not None: + if self._scheme == "http": + warnings.warn("Auth token provider is used with an insecure connection.") + + bearer_auth = BearerAuth(self._auth_token_provider) + self._rest_args["auth"] = bearer_auth + + self.openapi_client: SyncApis[ApiClient] = SyncApis( + host=self.rest_uri, + **self._rest_args, + ) self._grpc_channel = None self._grpc_points_client: Optional[grpc.PointsStub] = None @@ -221,6 +238,9 @@ def _init_grpc_channel(self) -> None: metadata=self._grpc_headers, options=self._grpc_options, compression=self._grpc_compression, + # sync get_channel does not accept coroutine functions, + # but we can't check type here, since it'll get into async client as well + auth_token_provider=self._auth_token_provider, # type: ignore ) def _init_async_grpc_channel(self) -> None: @@ -235,6 +255,7 @@ def _init_async_grpc_channel(self) -> None: metadata=self._grpc_headers, options=self._grpc_options, compression=self._grpc_compression, + auth_token_provider=self._auth_token_provider, ) def _init_grpc_points_client(self) -> None: @@ -381,7 +402,7 @@ def search_batch( ] else: requests = [ - GrpcToRest.convert_search_points(r) if isinstance(r, grpc.SearchPoints) else r + (GrpcToRest.convert_search_points(r) if isinstance(r, grpc.SearchPoints) else r) for r in requests ] http_res: List[List[models.ScoredPoint]] = self.http.points_api.search_batch_points( @@ -1146,7 +1167,11 @@ def discover_batch( ] else: requests = [ - GrpcToRest.convert_discover_points(r) if isinstance(r, grpc.DiscoverPoints) else r + ( + GrpcToRest.convert_discover_points(r) + if isinstance(r, grpc.DiscoverPoints) + else r + ) for r in requests ] http_res: List[List[models.ScoredPoint]] = self.http.points_api.discover_batch_points( @@ -1507,7 +1532,7 @@ def retrieve( with_payload = GrpcToRest.convert_with_payload_selector(with_payload) ids = [ - GrpcToRest.convert_point_id(idx) if isinstance(idx, grpc.PointId) else idx + (GrpcToRest.convert_point_id(idx) if isinstance(idx, grpc.PointId) else idx) for idx in ids ] @@ -1566,7 +1591,7 @@ def _try_argument_to_rest_selector( ) -> models.PointsSelector: if isinstance(points, list): _points = [ - GrpcToRest.convert_point_id(idx) if isinstance(idx, grpc.PointId) else idx + (GrpcToRest.convert_point_id(idx) if isinstance(idx, grpc.PointId) else idx) for idx in points ] points_selector = construct( @@ -1613,7 +1638,7 @@ def _try_argument_to_rest_points_and_filter( _filter = None if isinstance(points, list): _points = [ - GrpcToRest.convert_point_id(idx) if isinstance(idx, grpc.PointId) else idx + (GrpcToRest.convert_point_id(idx) if isinstance(idx, grpc.PointId) else idx) for idx in points ] elif isinstance(points, grpc.PointsSelector): diff --git a/tests/test_async_qdrant_client.py b/tests/test_async_qdrant_client.py index 4e7befac..e45028e7 100644 --- a/tests/test_async_qdrant_client.py +++ b/tests/test_async_qdrant_client.py @@ -518,3 +518,72 @@ async def test_async_qdrant_client_local(): assert all(collection.name != COLLECTION_NAME for collection in collections.collections) await client.close() # endregion + + +@pytest.mark.asyncio +async def test_async_auth(): + """Test that the auth token provider is called and the token in all modes.""" + token = "" + call_num = 0 + + async def async_auth_token_provider(): + nonlocal token + nonlocal call_num + await asyncio.sleep(0.1) + token = f"token_{call_num}" + call_num += 1 + return token + + client = AsyncQdrantClient(timeout=3, auth_token_provider=async_auth_token_provider) + await client.get_collections() + assert token == "token_0" + + await client.get_collections() + assert token == "token_1" + + token = "" + call_num = 0 + + client = AsyncQdrantClient( + prefer_grpc=True, timeout=3, auth_token_provider=async_auth_token_provider + ) + await client.get_collections() + assert token == "token_0" + + await client.get_collections() + assert token == "token_1" + + await client.unlock_storage() + assert token == "token_2" + + sync_token = "" + call_num = 0 + + def auth_token_provider(): + nonlocal sync_token + nonlocal call_num + sync_token = f"token_{call_num}" + call_num += 1 + return sync_token + + client = AsyncQdrantClient(timeout=3, auth_token_provider=auth_token_provider) + await client.get_collections() + assert sync_token == "token_0" + + await client.get_collections() + assert sync_token == "token_1" + + sync_token = "" + call_num = 0 + + client = AsyncQdrantClient( + prefer_grpc=True, timeout=3, auth_token_provider=auth_token_provider + ) + await client.get_collections() + assert sync_token == "token_0" + + await client.get_collections() + assert sync_token == "token_1" + + await client.unlock_storage() + assert sync_token == "token_2" diff --git a/tests/test_qdrant_client.py b/tests/test_qdrant_client.py index 123cb835..9b3d5324 100644 --- a/tests/test_qdrant_client.py +++ b/tests/test_qdrant_client.py @@ -1,3 +1,4 @@ +import asyncio import os import uuid from pprint import pprint @@ -181,7 +182,11 @@ def test_records_upload(prefer_grpc, parallel): warnings.simplefilter("ignore", category=DeprecationWarning) records = ( - Record(id=idx, vector=np.random.rand(DIM).tolist(), payload=one_random_payload_please(idx)) + Record( + id=idx, + vector=np.random.rand(DIM).tolist(), + payload=one_random_payload_please(idx), + ) for idx in range(NUM_VECTORS) ) @@ -241,7 +246,9 @@ def test_records_upload(prefer_grpc, parallel): def test_point_upload(prefer_grpc, parallel): points = ( PointStruct( - id=idx, vector=np.random.rand(DIM).tolist(), payload=one_random_payload_please(idx) + id=idx, + vector=np.random.rand(DIM).tolist(), + payload=one_random_payload_please(idx), ) for idx in range(NUM_VECTORS) ) @@ -626,7 +633,10 @@ def test_qdrant_client_integration(prefer_grpc, numpy_upload, local_mode): print(hit) got_points = client.retrieve( - collection_name=COLLECTION_NAME, ids=[1, 2, 3], with_payload=True, with_vectors=True + collection_name=COLLECTION_NAME, + ids=[1, 2, 3], + with_payload=True, + with_vectors=True, ) # ------------------ Test for full-text filtering ------------------ @@ -760,11 +770,16 @@ def test_qdrant_client_integration(prefer_grpc, numpy_upload, local_mode): assert len(got_points) == 3 client.delete( - collection_name=COLLECTION_NAME, wait=True, points_selector=PointIdsList(points=[2, 3]) + collection_name=COLLECTION_NAME, + wait=True, + points_selector=PointIdsList(points=[2, 3]), ) got_points = client.retrieve( - collection_name=COLLECTION_NAME, ids=[1, 2, 3], with_payload=True, with_vectors=True + collection_name=COLLECTION_NAME, + ids=[1, 2, 3], + with_payload=True, + with_vectors=True, ) assert len(got_points) == 1 @@ -776,17 +791,26 @@ def test_qdrant_client_integration(prefer_grpc, numpy_upload, local_mode): ) got_points = client.retrieve( - collection_name=COLLECTION_NAME, ids=[1, 2, 3], with_payload=True, with_vectors=True + collection_name=COLLECTION_NAME, + ids=[1, 2, 3], + with_payload=True, + with_vectors=True, ) assert len(got_points) == 2 client.set_payload( - collection_name=COLLECTION_NAME, payload={"new_key": 123}, points=[1, 2], wait=True + collection_name=COLLECTION_NAME, + payload={"new_key": 123}, + points=[1, 2], + wait=True, ) got_points = client.retrieve( - collection_name=COLLECTION_NAME, ids=[1, 2], with_payload=True, with_vectors=True + collection_name=COLLECTION_NAME, + ids=[1, 2], + with_payload=True, + with_vectors=True, ) for point in got_points: @@ -811,7 +835,10 @@ def test_qdrant_client_integration(prefer_grpc, numpy_upload, local_mode): ) got_points = client.retrieve( - collection_name=COLLECTION_NAME, ids=[1, 2], with_payload=True, with_vectors=True + collection_name=COLLECTION_NAME, + ids=[1, 2], + with_payload=True, + with_vectors=True, ) for point in got_points: @@ -1039,7 +1066,9 @@ def test_points_crud(prefer_grpc): # Update a single point client.set_payload( - collection_name=COLLECTION_NAME, payload={"test2": ["value2", "value3"]}, points=[123] + collection_name=COLLECTION_NAME, + payload={"test2": ["value2", "value3"]}, + points=[123], ) # Delete a single point @@ -1204,7 +1233,9 @@ def init_collection(): ] client.upload_points( - collection_name=COLLECTION_NAME, points=cat_points, shard_key_selector=cats_shard_key + collection_name=COLLECTION_NAME, + points=cat_points, + shard_key_selector=cats_shard_key, ) res = client.search( @@ -1588,7 +1619,11 @@ def test_locks(): client.upsert( collection_name=COLLECTION_NAME, points=[ - PointStruct(id=123, payload={"test": "value"}, vector=np.random.rand(DIM).tolist()) + PointStruct( + id=123, + payload={"test": "value"}, + vector=np.random.rand(DIM).tolist(), + ) ], wait=True, ) @@ -1823,6 +1858,67 @@ def test_grpc_compression(): QdrantClient(prefer_grpc=True, grpc_compression="gzip") +def test_auth_token_provider(): + """Check that the token provided is called for both http and grpc clients.""" + token = "" + call_num = 0 + + def auth_token_provider(): + nonlocal token + nonlocal call_num + + token = f"token_{call_num}" + call_num += 1 + return token + + client = QdrantClient(auth_token_provider=auth_token_provider) + client.get_collections() + assert token == "token_0" + client.get_collections() + assert token == "token_1" + + token = "" + call_num = 0 + + client = QdrantClient(prefer_grpc=True, auth_token_provider=auth_token_provider) + client.get_collections() + assert token == "token_0" + client.get_collections() + assert token == "token_1" + + client.unlock_storage() + assert token == "token_2" + + +def test_async_auth_token_provider(): + """Check that initialization fails if async auth_token_provider is provided to sync client.""" + token = "" + + async def auth_token_provider(): + nonlocal token + await asyncio.sleep(0.1) + token = "test_token" + return token + + client = QdrantClient(auth_token_provider=auth_token_provider) + + with pytest.raises( + qdrant_client.http.exceptions.ResponseHandlingException, + match="Synchronous token provider is not set.", + ): + client.get_collections() + + assert token == "" + + client = QdrantClient(auth_token_provider=auth_token_provider, prefer_grpc=True) + with pytest.raises( + ValueError, match="Synchronous channel requires synchronous auth token provider." + ): + client.get_collections() + + assert token == "" + + @pytest.mark.parametrize("prefer_grpc", [True, False]) def test_read_consistency(prefer_grpc): fixture_points = generate_fixtures(vectors_sizes=DIM, num=NUM_VECTORS)