Skip to content

Commit

Permalink
Feat: bearer token authentication support (#591)
Browse files Browse the repository at this point in the history
* bearer token authentication provider support

* add tests and checks, move auth file to separate dir

* fix error message

* remove locks

* rename var

* refactoring: refactor exceptions, fix mypy

* fix: regen async

* tests: extend token tests to check token updates

* new: add warning when auth token provider is used with an insecure connection

* fix: propagate auth token to rest client even with prefer_grpc set

---------

Co-authored-by: George Panchuk <[email protected]>
  • Loading branch information
skvark and joein committed Apr 16, 2024
1 parent efb0309 commit ef8e772
Show file tree
Hide file tree
Showing 9 changed files with 363 additions and 86 deletions.
19 changes: 18 additions & 1 deletion qdrant_client/async_qdrant_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,19 @@
# ****** WARNING: THIS FILE IS AUTOGENERATED ******

import warnings
from typing import Any, Dict, Iterable, List, Mapping, Optional, Sequence, Tuple, Union
from typing import (
Any,
Awaitable,
Callable,
Dict,
Iterable,
List,
Mapping,
Optional,
Sequence,
Tuple,
Union,
)

from qdrant_client import grpc as grpc
from qdrant_client.async_client_base import AsyncQdrantBase
Expand Down Expand Up @@ -68,6 +80,7 @@ class AsyncQdrantClient(AsyncQdrantFastembedMixin):
force_disable_check_same_thread:
For QdrantLocal, force disable check_same_thread. Default: `False`
Only use this if you can guarantee that you can resolve the thread safety outside QdrantClient.
auth_token_provider: Callback function to get Bearer access token. If given, the function will be called before each request to get the token.
**kwargs: Additional arguments passed directly into REST client initialization
"""
Expand All @@ -87,6 +100,9 @@ def __init__(
path: Optional[str] = None,
force_disable_check_same_thread: bool = False,
grpc_options: Optional[Dict[str, Any]] = None,
auth_token_provider: Optional[
Union[Callable[[], str], Callable[[], Awaitable[str]]]
] = None,
**kwargs: Any,
):
super().__init__(**kwargs)
Expand Down Expand Up @@ -117,6 +133,7 @@ def __init__(
timeout=timeout,
host=host,
grpc_options=grpc_options,
auth_token_provider=auth_token_provider,
**kwargs,
)

Expand Down
15 changes: 14 additions & 1 deletion qdrant_client/async_qdrant_remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
from multiprocessing import get_all_start_methods
from typing import (
Any,
Awaitable,
Callable,
Dict,
Iterable,
List,
Expand All @@ -35,6 +37,7 @@
from qdrant_client import grpc as grpc
from qdrant_client._pydantic_compat import construct
from qdrant_client.async_client_base import AsyncQdrantBase
from qdrant_client.auth import BearerAuth
from qdrant_client.connection import get_async_channel as get_channel
from qdrant_client.conversions import common_types as types
from qdrant_client.conversions.common_types import get_args_subscribed
Expand Down Expand Up @@ -63,6 +66,9 @@ def __init__(
timeout: Optional[int] = None,
host: Optional[str] = None,
grpc_options: Optional[Dict[str, Any]] = None,
auth_token_provider: Optional[
Union[Callable[[], str], Callable[[], Awaitable[str]]]
] = None,
**kwargs: Any,
):
super().__init__(**kwargs)
Expand Down Expand Up @@ -100,6 +106,7 @@ def __init__(
self._port = port
self._timeout = math.ceil(timeout) if timeout is not None else None
self._api_key = api_key
self._auth_token_provider = auth_token_provider
limits = kwargs.pop("limits", None)
if limits is None:
if self._host in ["localhost", "127.0.0.1"]:
Expand All @@ -109,7 +116,7 @@ def __init__(
self._rest_headers = kwargs.pop("metadata", {})
if api_key is not None:
if self._scheme == "http":
warnings.warn("Api key is used with unsecure connection.")
warnings.warn("Api key is used with an insecure connection.")
self._rest_headers["api-key"] = api_key
self._grpc_headers.append(("api-key", api_key))
grpc_compression: Optional[Compression] = kwargs.pop("grpc_compression", None)
Expand All @@ -129,6 +136,11 @@ def __init__(
self._rest_args["limits"] = limits
if self._timeout is not None:
self._rest_args["timeout"] = self._timeout
if self._auth_token_provider is not None:
if self._scheme == "http":
warnings.warn("Auth token provider is used with an insecure connection.")
bearer_auth = BearerAuth(self._auth_token_provider)
self._rest_args["auth"] = bearer_auth
self.openapi_client: AsyncApis[AsyncApiClient] = AsyncApis(
host=self.rest_uri, **self._rest_args
)
Expand Down Expand Up @@ -182,6 +194,7 @@ def _init_grpc_channel(self) -> None:
metadata=self._grpc_headers,
options=self._grpc_options,
compression=self._grpc_compression,
auth_token_provider=self._auth_token_provider,
)

def _init_grpc_points_client(self) -> None:
Expand Down
1 change: 1 addition & 0 deletions qdrant_client/auth/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from qdrant_client.auth.bearer_auth import BearerAuth
42 changes: 42 additions & 0 deletions qdrant_client/auth/bearer_auth.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
import asyncio
from typing import Awaitable, Callable, Optional, Union

import httpx


class BearerAuth(httpx.Auth):
def __init__(
self,
auth_token_provider: Union[Callable[[], str], Callable[[], Awaitable[str]]],
):
self.async_token: Optional[Callable[[], Awaitable[str]]] = None
self.sync_token: Optional[Callable[[], str]] = None

if asyncio.iscoroutinefunction(auth_token_provider):
self.async_token = auth_token_provider
else:
if callable(auth_token_provider):
self.sync_token = auth_token_provider # type: ignore
else:
raise ValueError("auth_token_provider must be a callable or awaitable")

def _sync_get_token(self) -> str:
if self.sync_token is None:
raise ValueError("Synchronous token provider is not set.")
return self.sync_token()

def sync_auth_flow(self, request: httpx.Request) -> httpx.Request:
token = self._sync_get_token()
request.headers["Authorization"] = f"Bearer {token}"
yield request

async def _async_get_token(self) -> str:
if self.async_token is not None:
return await self.async_token() # type: ignore
# Fallback to synchronous token if asynchronous token is not available
return self._sync_get_token()

async def async_auth_flow(self, request: httpx.Request) -> httpx.Request:
token = await self._async_get_token()
request.headers["Authorization"] = f"Bearer {token}"
yield request
120 changes: 57 additions & 63 deletions qdrant_client/connection.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import asyncio
import collections
from typing import Any, Callable, Dict, List, Optional, Tuple
from typing import Any, Awaitable, Callable, Dict, List, Optional, Tuple, Union

import grpc

Expand Down Expand Up @@ -64,7 +65,7 @@ def __init__(self, interceptor_function: Callable):
async def intercept_unary_unary(
self, continuation: Any, client_call_details: Any, request: Any
) -> Any:
new_details, new_request_iterator, postprocess = self._fn(
new_details, new_request_iterator, postprocess = await self._fn(
client_call_details, iter((request,)), False, False
)
next_request = next(new_request_iterator)
Expand All @@ -74,7 +75,7 @@ async def intercept_unary_unary(
async def intercept_unary_stream(
self, continuation: Any, client_call_details: Any, request: Any
) -> Any:
new_details, new_request_iterator, postprocess = self._fn(
new_details, new_request_iterator, postprocess = await self._fn(
client_call_details, iter((request,)), False, True
)
response_it = await continuation(new_details, next(new_request_iterator))
Expand All @@ -83,7 +84,7 @@ async def intercept_unary_stream(
async def intercept_stream_unary(
self, continuation: Any, client_call_details: Any, request_iterator: Any
) -> Any:
new_details, new_request_iterator, postprocess = self._fn(
new_details, new_request_iterator, postprocess = await self._fn(
client_call_details, request_iterator, True, False
)
response = await continuation(new_details, new_request_iterator)
Expand All @@ -92,7 +93,7 @@ async def intercept_stream_unary(
async def intercept_stream_stream(
self, continuation: Any, client_call_details: Any, request_iterator: Any
) -> Any:
new_details, new_request_iterator, postprocess = self._fn(
new_details, new_request_iterator, postprocess = await self._fn(
client_call_details, request_iterator, True, True
)
response_it = await continuation(new_details, new_request_iterator)
Expand Down Expand Up @@ -125,14 +126,18 @@ class _ClientAsyncCallDetails(
pass


def header_adder_interceptor(new_metadata: List[Tuple[str, str]]) -> _GenericClientInterceptor:
def header_adder_interceptor(
new_metadata: List[Tuple[str, str]],
auth_token_provider: Optional[Callable[[], str]] = None,
) -> _GenericClientInterceptor:
def intercept_call(
client_call_details: _ClientCallDetails,
request_iterator: Any,
_request_streaming: Any,
_response_streaming: Any,
) -> Tuple[_ClientCallDetails, Any, Any]:
metadata = []

if client_call_details.metadata is not None:
metadata = list(client_call_details.metadata)
for header, value in new_metadata:
Expand All @@ -142,6 +147,13 @@ def intercept_call(
value,
)
)

if auth_token_provider:
if not asyncio.iscoroutinefunction(auth_token_provider):
metadata.append(("authorization", f"Bearer {auth_token_provider()}"))
else:
raise ValueError("Synchronous channel requires synchronous auth token provider.")

client_call_details = _ClientCallDetails(
client_call_details.method,
client_call_details.timeout,
Expand All @@ -154,9 +166,10 @@ def intercept_call(


def header_adder_async_interceptor(
new_metadata: List[Tuple[str, str]]
new_metadata: List[Tuple[str, str]],
auth_token_provider: Optional[Union[Callable[[], str], Callable[[], Awaitable[str]]]] = None,
) -> _GenericAsyncClientInterceptor:
def intercept_call(
async def intercept_call(
client_call_details: grpc.aio.ClientCallDetails,
request_iterator: Any,
_request_streaming: Any,
Expand All @@ -172,6 +185,14 @@ def intercept_call(
value,
)
)

if auth_token_provider:
if asyncio.iscoroutinefunction(auth_token_provider):
token = await auth_token_provider()
else:
token = auth_token_provider()
metadata.append(("authorization", f"Bearer {token}"))

client_call_details = client_call_details._replace(metadata=metadata)
return client_call_details, request_iterator, None

Expand Down Expand Up @@ -200,38 +221,21 @@ def get_channel(
metadata: Optional[List[Tuple[str, str]]] = None,
options: Optional[Dict[str, Any]] = None,
compression: Optional[grpc.Compression] = None,
auth_token_provider: Optional[Callable[[], str]] = None,
) -> grpc.Channel:
# gRPC client options
# Parse gRPC client options
_options = parse_channel_options(options)
metadata_interceptor = header_adder_interceptor(
new_metadata=metadata or [], auth_token_provider=auth_token_provider
)

if ssl:
if metadata:

def metadata_callback(context: Any, callback: Any) -> None:
# for more info see grpc docs
callback(metadata, None)

# build ssl credentials using the cert the same as before
cert_creds = grpc.ssl_channel_credentials()

# now build meta data credentials
auth_creds = grpc.metadata_call_credentials(metadata_callback)

# combine the cert credentials and the macaroon auth credentials
# such that every call is properly encrypted and authenticated
creds = grpc.composite_channel_credentials(cert_creds, auth_creds)
else:
creds = grpc.ssl_channel_credentials()

# finally pass in the combined credentials when creating a channel
return grpc.secure_channel(f"{host}:{port}", creds, _options, compression)
ssl_creds = grpc.ssl_channel_credentials()
channel = grpc.secure_channel(f"{host}:{port}", ssl_creds, _options, compression)
return grpc.intercept_channel(channel, metadata_interceptor)
else:
if metadata:
metadata_interceptor = header_adder_interceptor(metadata)
channel = grpc.insecure_channel(f"{host}:{port}", _options, compression)
return grpc.intercept_channel(channel, metadata_interceptor)
else:
return grpc.insecure_channel(f"{host}:{port}", _options, compression)
channel = grpc.insecure_channel(f"{host}:{port}", _options, compression)
return grpc.intercept_channel(channel, metadata_interceptor)


def get_async_channel(
Expand All @@ -241,36 +245,26 @@ def get_async_channel(
metadata: Optional[List[Tuple[str, str]]] = None,
options: Optional[Dict[str, Any]] = None,
compression: Optional[grpc.Compression] = None,
auth_token_provider: Optional[Union[Callable[[], str], Callable[[], Awaitable[str]]]] = None,
) -> grpc.aio.Channel:
# gRPC client options
# Parse gRPC client options
_options = parse_channel_options(options)

if ssl:
if metadata:

def metadata_callback(context: Any, callback: Any) -> None:
# for more info see grpc docs
callback(metadata, None)

# build ssl credentials using the cert the same as before
cert_creds = grpc.ssl_channel_credentials()
# Create metadata interceptor
metadata_interceptor = header_adder_async_interceptor(
new_metadata=metadata or [], auth_token_provider=auth_token_provider
)

# now build meta data credentials
auth_creds = grpc.metadata_call_credentials(metadata_callback)

# combine the cert credentials and the macaroon auth credentials
# such that every call is properly encrypted and authenticated
creds = grpc.composite_channel_credentials(cert_creds, auth_creds)
else:
creds = grpc.ssl_channel_credentials()

# finally pass in the combined credentials when creating a channel
return grpc.aio.secure_channel(f"{host}:{port}", creds, _options, compression)
if ssl:
ssl_creds = grpc.ssl_channel_credentials()
return grpc.aio.secure_channel(
f"{host}:{port}",
ssl_creds,
_options,
compression,
interceptors=[metadata_interceptor],
)
else:
if metadata:
metadata_interceptor = header_adder_async_interceptor(metadata)
return grpc.aio.insecure_channel(
f"{host}:{port}", _options, compression, interceptors=[metadata_interceptor]
)
else:
return grpc.aio.insecure_channel(f"{host}:{port}", _options, compression)
return grpc.aio.insecure_channel(
f"{host}:{port}", _options, compression, interceptors=[metadata_interceptor]
)
Loading

0 comments on commit ef8e772

Please sign in to comment.