Skip to content

Commit

Permalink
fix: add authorization_interceptor and db_interceptor to async ch…
Browse files Browse the repository at this point in the history
…annel (milvus-io#2467)

Signed-off-by: Ruichen Bao <[email protected]>
  • Loading branch information
brcarry authored Dec 18, 2024
1 parent cd4aab7 commit 52c366c
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 20 deletions.
1 change: 1 addition & 0 deletions examples/simple_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
index_params.add_index(field_name = "embeddings", index_type = "HNSW", metric_type="L2", nlist=128)
index_params.add_index(field_name = "embeddings2",index_type = "HNSW", metric_type="L2", nlist=128)

# Always use `await` when you want to guarantee the execution order of tasks.
async def recreate_collection():
print(fmt.format("Start dropping collection"))
await async_milvus_client.drop_collection(collection_name)
Expand Down
46 changes: 27 additions & 19 deletions pymilvus/client/async_grpc_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from pymilvus.grpc_gen import milvus_pb2_grpc
from pymilvus.settings import Config

from . import entity_helper, interceptor, ts_utils, utils
from . import entity_helper, ts_utils, utils
from .abstract import AnnSearchRequest, BaseRanker, CollectionSchema, MutationResult, SearchResult
from .async_interceptor import async_header_adder_interceptor
from .check import (
Expand Down Expand Up @@ -62,8 +62,8 @@ def __init__(
self._request_id = None
self._user = kwargs.get("user")
self._set_authorization(**kwargs)
self._setup_db_interceptor(kwargs.get("db_name"))
self._setup_grpc_channel() # init channel and stub
self._setup_db_name(kwargs.get("db_name"))
self._setup_grpc_channel(**kwargs)
self.callbacks = []

def register_state_change_callback(self, callback: Callable):
Expand Down Expand Up @@ -96,12 +96,7 @@ def _set_authorization(self, **kwargs):
self._server_pem_path = kwargs.get("server_pem_path", "")
self._server_name = kwargs.get("server_name", "")

self._authorization_interceptor = None
self._setup_authorization_interceptor(
kwargs.get("user"),
kwargs.get("password"),
kwargs.get("token"),
)
self._async_authorization_interceptor = None

def __enter__(self):
return self
Expand Down Expand Up @@ -132,7 +127,7 @@ def close(self):
self._async_channel.close()

def reset_db_name(self, db_name: str):
self._setup_db_interceptor(db_name)
self._setup_db_name(db_name)
self._setup_grpc_channel()
self._setup_identifier_interceptor(self._user)

Expand All @@ -148,16 +143,19 @@ def _setup_authorization_interceptor(self, user: str, password: str, token: str)
keys.append("authorization")
values.append(authorization)
if len(keys) > 0 and len(values) > 0:
self._authorization_interceptor = interceptor.header_adder_interceptor(keys, values)
self._async_authorization_interceptor = async_header_adder_interceptor(keys, values)
self._final_channel._unary_unary_interceptors.append(
self._async_authorization_interceptor
)

def _setup_db_interceptor(self, db_name: str):
def _setup_db_name(self, db_name: str):
if db_name is None:
self._db_interceptor = None
self._db_name = None
else:
check_pass_param(db_name=db_name)
self._db_interceptor = interceptor.header_adder_interceptor(["dbname"], [db_name])
self._db_name = db_name

def _setup_grpc_channel(self):
def _setup_grpc_channel(self, **kwargs):
if self._async_channel is None:
opts = [
(cygrpc.ChannelArgKey.max_send_message_length, -1),
Expand Down Expand Up @@ -203,21 +201,31 @@ def _setup_grpc_channel(self):

# avoid to add duplicate headers.
self._final_channel = self._async_channel
if self._log_level:

if self._async_authorization_interceptor:
self._final_channel._unary_unary_interceptors.append(
self._async_authorization_interceptor
)
else:
self._setup_authorization_interceptor(
kwargs.get("user"),
kwargs.get("password"),
kwargs.get("token"),
)
if self._db_name:
async_db_interceptor = async_header_adder_interceptor(["dbname"], [self._db_name])
self._final_channel._unary_unary_interceptors.append(async_db_interceptor)
if self._log_level:
async_log_level_interceptor = async_header_adder_interceptor(
["log_level"], [self._log_level]
)
self._final_channel._unary_unary_interceptors.append(async_log_level_interceptor)

self._log_level = None
if self._request_id:

async_request_id_interceptor = async_header_adder_interceptor(
["client_request_id"], [self._request_id]
)
self._final_channel._unary_unary_interceptors.append(async_request_id_interceptor)

self._request_id = None
self._async_stub = milvus_pb2_grpc.MilvusServiceStub(self._final_channel)

Expand Down
2 changes: 1 addition & 1 deletion pymilvus/client/async_interceptor.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ async def intercept_stream_stream(
return await continuation(new_details, new_request_iterator)


def async_header_adder_interceptor(headers: List[str], values: List[str]):
def async_header_adder_interceptor(headers: List[str], values: Union[List[str], List[bytes]]):
def intercept_call(client_call_details: ClientCallDetails, request: Any):
metadata = []
if client_call_details.metadata:
Expand Down

0 comments on commit 52c366c

Please sign in to comment.