-
Notifications
You must be signed in to change notification settings - Fork 339
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Reuse GRPC channel for same connections
Signed-off-by: Filip Haltmayer <[email protected]>
- Loading branch information
Filip Haltmayer
committed
May 24, 2023
1 parent
256a523
commit 18fd399
Showing
3 changed files
with
212 additions
and
155 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -78,54 +78,19 @@ def __init__(self): | |
""" | ||
self._alias = {} | ||
self._connected_alias = {} | ||
self._env_uri = None | ||
self._connection_references = {} | ||
self._con_lock = threading.RLock() | ||
|
||
if Config.MILVUS_URI != "": | ||
address, parsed_uri = self.__parse_address_from_uri(Config.MILVUS_URI) | ||
self._env_uri = (address, parsed_uri) | ||
address, user, _, db_name = self.__parse_info(Config.MILVUS_URI) | ||
|
||
default_conn_config = { | ||
"user": parsed_uri.username if parsed_uri.username is not None else "", | ||
"address": address, | ||
} | ||
else: | ||
default_conn_config = { | ||
"user": "", | ||
"address": f"{Config.DEFAULT_HOST}:{Config.DEFAULT_PORT}", | ||
} | ||
default_conn_config = { | ||
"user": user, | ||
"address": address, | ||
"db_name": db_name, | ||
} | ||
|
||
self.add_connection(**{Config.MILVUS_CONN_ALIAS: default_conn_config}) | ||
|
||
def __verify_host_port(self, host, port): | ||
if not is_legal_host(host): | ||
raise ConnectionConfigException(message=ExceptionsMessage.HostType) | ||
if not is_legal_port(port): | ||
raise ConnectionConfigException(message=ExceptionsMessage.PortType) | ||
if not 0 <= int(port) < 65535: | ||
raise ConnectionConfigException(message=f"port number {port} out of range, valid range [0, 65535)") | ||
|
||
def __parse_address_from_uri(self, uri: str) -> (str, parse.ParseResult): | ||
illegal_uri_msg = "Illegal uri: [{}], expected form 'https://user:[email protected]:12345'" | ||
try: | ||
parsed_uri = parse.urlparse(uri) | ||
except (Exception) as e: | ||
raise ConnectionConfigException( | ||
message=f"{illegal_uri_msg.format(uri)}: <{type(e).__name__}, {e}>") from None | ||
|
||
if len(parsed_uri.netloc) == 0: | ||
raise ConnectionConfigException(message=f"{illegal_uri_msg.format(uri)}") from None | ||
|
||
host = parsed_uri.hostname if parsed_uri.hostname is not None else Config.DEFAULT_HOST | ||
port = parsed_uri.port if parsed_uri.port is not None else Config.DEFAULT_PORT | ||
addr = f"{host}:{port}" | ||
|
||
self.__verify_host_port(host, port) | ||
|
||
if not is_legal_address(addr): | ||
raise ConnectionConfigException(message=illegal_uri_msg.format(uri)) | ||
|
||
return addr, parsed_uri | ||
|
||
def add_connection(self, **kwargs): | ||
""" Configures a milvus connection. | ||
|
@@ -157,41 +122,24 @@ def add_connection(self, **kwargs): | |
) | ||
""" | ||
for alias, config in kwargs.items(): | ||
addr, _ = self.__get_full_address( | ||
config.get("address", ""), | ||
config.get("uri", ""), | ||
config.get("host", ""), | ||
config.get("port", "")) | ||
address, user, _, db_name = self.__parse_info(**config) | ||
|
||
if alias in self._connected_alias: | ||
if self._alias[alias].get("address") != addr: | ||
if ( | ||
self._alias[alias].get("address") != address | ||
or self._alias[alias].get("user") != user | ||
or self._alias[alias].get("db_name") != db_name | ||
): | ||
raise ConnectionConfigException(message=ExceptionsMessage.ConnDiffConf % alias) | ||
|
||
alias_config = { | ||
"address": addr, | ||
"user": config.get("user", ""), | ||
"address": address, | ||
"user": user, | ||
"db_name": db_name, | ||
} | ||
|
||
self._alias[alias] = alias_config | ||
|
||
def __get_full_address(self, address: str = "", uri: str = "", host: str = "", port: str = "") -> ( | ||
str, parse.ParseResult): | ||
if address != "": | ||
if not is_legal_address(address): | ||
raise ConnectionConfigException( | ||
message=f"Illegal address: {address}, should be in form 'localhost:19530'") | ||
return address, None | ||
|
||
if uri != "": | ||
address, parsed = self.__parse_address_from_uri(uri) | ||
return address, parsed | ||
|
||
host = host if host != "" else Config.DEFAULT_HOST | ||
port = port if port != "" else Config.DEFAULT_PORT | ||
self.__verify_host_port(host, port) | ||
|
||
return f"{host}:{port}", None | ||
|
||
def disconnect(self, alias: str): | ||
""" Disconnects connection from the registry. | ||
|
@@ -201,8 +149,13 @@ def disconnect(self, alias: str): | |
if not isinstance(alias, str): | ||
raise ConnectionConfigException(message=ExceptionsMessage.AliasType % type(alias)) | ||
|
||
if alias in self._connected_alias: | ||
self._connected_alias.pop(alias).close() | ||
with self._con_lock: | ||
if alias in self._connected_alias: | ||
gh = self._connected_alias.pop(alias) | ||
self._connection_references[id(gh)] -= 1 | ||
if self._connection_references[id(gh)] <= 0: | ||
gh.close() | ||
del self._connection_references[id(gh)] | ||
|
||
def remove_connection(self, alias: str): | ||
""" Removes connection from the registry. | ||
|
@@ -266,96 +219,74 @@ def connect(self, alias=Config.MILVUS_CONN_ALIAS, user="", password="", db_name= | |
>>> from pymilvus import connections | ||
>>> connections.connect("test", host="localhost", port="19530") | ||
""" | ||
# pylint: disable=too-many-statements | ||
|
||
def connect_milvus(**kwargs): | ||
gh = GrpcHandler(**kwargs) | ||
with self._con_lock: | ||
gh = None | ||
for key, connection_details in self._alias.items(): | ||
|
||
if ( | ||
key in self._connected_alias | ||
and connection_details["address"] == kwargs["address"] | ||
and connection_details["user"] == kwargs["user"] | ||
and connection_details["db_name"] == kwargs["db_name"] | ||
): | ||
gh = self._connected_alias[key] | ||
break | ||
|
||
t = kwargs.get("timeout") | ||
timeout = t if isinstance(t, (int, float)) else Config.MILVUS_CONN_TIMEOUT | ||
if gh is None: | ||
gh = GrpcHandler(**kwargs) | ||
t = kwargs.get("timeout") | ||
timeout = t if isinstance(t, (int, float)) else Config.MILVUS_CONN_TIMEOUT | ||
gh._wait_for_channel_ready(timeout=timeout) | ||
|
||
gh._wait_for_channel_ready(timeout=timeout) | ||
kwargs.pop('password') | ||
kwargs.pop('db_name', None) | ||
kwargs.pop('secure', None) | ||
kwargs.pop("db_name", "") | ||
kwargs.pop('password', None) | ||
kwargs.pop('secure', None) | ||
|
||
self._connected_alias[alias] = gh | ||
self._alias[alias] = copy.deepcopy(kwargs) | ||
self._connected_alias[alias] = gh | ||
|
||
def with_config(config: Tuple) -> bool: | ||
for c in config: | ||
if c != "": | ||
return True | ||
self._alias[alias] = copy.deepcopy(kwargs) | ||
|
||
return False | ||
if id(gh) not in self._connection_references: | ||
self._connection_references[id(gh)] = 1 | ||
else: | ||
self._connection_references[id(gh)] += 1 | ||
|
||
if not isinstance(alias, str): | ||
raise ConnectionConfigException(message=ExceptionsMessage.AliasType % type(alias)) | ||
|
||
config = ( | ||
kwargs.pop("address", ""), | ||
kwargs.pop("uri", ""), | ||
kwargs.pop("host", ""), | ||
kwargs.pop("port", "") | ||
) | ||
|
||
# Make sure passed in None doesnt break | ||
user = user or "" | ||
password = password or "" | ||
# Make sure passed in are Strings | ||
user = str(user) | ||
password = str(password) | ||
|
||
# 1st Priority: connection from params | ||
if with_config(config): | ||
in_addr, parsed_uri = self.__get_full_address(*config) | ||
kwargs["address"] = in_addr | ||
|
||
if self.has_connection(alias): | ||
if self._alias[alias].get("address") != in_addr: | ||
raise ConnectionConfigException(message=ExceptionsMessage.ConnDiffConf % alias) | ||
|
||
# uri might take extra info | ||
if parsed_uri is not None: | ||
user = parsed_uri.username if parsed_uri.username is not None else user | ||
password = parsed_uri.password if parsed_uri.password is not None else password | ||
|
||
group = parsed_uri.path.split("/") | ||
db_name = "default" | ||
if len(group) > 1: | ||
db_name = group[1] | ||
|
||
# Set secure=True if username and password are provided | ||
if len(user) > 0 and len(password) > 0: | ||
kwargs["secure"] = True | ||
|
||
address = kwargs.pop("address", "") | ||
uri = kwargs.pop("uri", "") | ||
host = kwargs.pop("host", "") | ||
port = kwargs.pop("port", "") | ||
user = '' if user is None else str(user) | ||
password = '' if password is None else str(password) | ||
db_name = '' if db_name is None else str(db_name) | ||
|
||
if set([address, uri, host, port]) != {''}: | ||
address, user, password, db_name = self.__parse_info(address, uri, host, port, db_name, user, password) | ||
kwargs["address"] = address | ||
|
||
elif alias in self._alias: | ||
kwargs = dict(self._alias[alias].items()) | ||
# If user is passed in, use it, if not, use previous connections user. | ||
prev_user = kwargs.pop("user") | ||
user = user if user != "" else prev_user | ||
# If db_name is passed in, use it, if not, use previous db_name. | ||
prev_db_name = kwargs.pop("db_name") | ||
db_name = db_name if db_name != "" else prev_db_name | ||
|
||
connect_milvus(**kwargs, user=user, password=password, db_name=db_name) | ||
return | ||
|
||
# 2nd Priority, connection configs from env | ||
if self._env_uri is not None: | ||
addr, parsed_uri = self._env_uri | ||
kwargs["address"] = addr | ||
|
||
user = parsed_uri.username if parsed_uri.username is not None else "" | ||
password = parsed_uri.password if parsed_uri.password is not None else "" | ||
# Set secure=True if uri provided user and password | ||
if len(user) > 0 and len(password) > 0: | ||
kwargs["secure"] = True | ||
# No params, env, and cached configs for the alias | ||
else: | ||
raise ConnectionConfigException(message=ExceptionsMessage.ConnLackConf % alias) | ||
|
||
connect_milvus(**kwargs, user=user, password=password, db_name=db_name) | ||
return | ||
# Set secure=True if username and password are provided | ||
if len(user) > 0 and len(password) > 0: | ||
kwargs["secure"] = True | ||
|
||
# 3rd Priority, connect to cached configs with provided user and password | ||
if alias in self._alias: | ||
connect_alias = dict(self._alias[alias].items()) | ||
connect_alias["user"] = user | ||
connect_milvus(**connect_alias, password=password, db_name=db_name, **kwargs) | ||
return | ||
connect_milvus(**kwargs, user=user, password=password, db_name=db_name) | ||
|
||
# No params, env, and cached configs for the alias | ||
raise ConnectionConfigException(message=ExceptionsMessage.ConnLackConf % alias) | ||
|
||
def list_connections(self) -> list: | ||
""" List names of all connections. | ||
|
@@ -369,7 +300,8 @@ def list_connections(self) -> list: | |
>>> connections.list_connections() | ||
// TODO [('default', None), ('test', <pymilvus.client.grpc_handler.GrpcHandler object at 0x7f05003f3e80>)] | ||
""" | ||
return [(k, self._connected_alias.get(k, None)) for k in self._alias] | ||
with self._con_lock: | ||
return [(k, self._connected_alias.get(k, None)) for k in self._alias] | ||
|
||
def get_connection_addr(self, alias: str): | ||
""" | ||
|
@@ -414,7 +346,99 @@ def has_connection(self, alias: str) -> bool: | |
""" | ||
if not isinstance(alias, str): | ||
raise ConnectionConfigException(message=ExceptionsMessage.AliasType % type(alias)) | ||
return alias in self._connected_alias | ||
with self._con_lock: | ||
return alias in self._connected_alias | ||
|
||
def __parse_info( | ||
self, | ||
address: str = "", | ||
uri: str = "", | ||
host: str = "", | ||
port: str = "", | ||
db_name: str = "", | ||
user: str = "", | ||
password: str = "", | ||
**kwargs) -> dict: | ||
|
||
passed_in_address = "" | ||
passed_in_user = "" | ||
passed_in_password = "" | ||
passed_in_db_name = "" | ||
|
||
# If uri | ||
if uri != "": | ||
passed_in_address, passed_in_user, passed_in_password, passed_in_db_name = ( | ||
self.__parse_address_from_uri(uri) | ||
) | ||
|
||
elif address != "": | ||
if not is_legal_address(address): | ||
raise ConnectionConfigException( | ||
message=f"Illegal address: {address}, should be in form 'localhost:19530'") | ||
passed_in_address = address | ||
|
||
else: | ||
if host == "": | ||
host = Config.DEFAULT_HOST | ||
if port == "": | ||
port = Config.DEFAULT_PORT | ||
self.__verify_host_port(host, port) | ||
passed_in_address = f"{host}:{port}" | ||
|
||
passed_in_user = user if passed_in_user == "" else str(passed_in_user) | ||
passed_in_user = Config.MILVUS_USER if passed_in_user == "" else str(passed_in_user) | ||
|
||
passed_in_password = password if passed_in_password == "" else str(passed_in_password) | ||
passed_in_password = Config.MILVUS_PASSWORD if passed_in_password == "" else str(passed_in_password) | ||
|
||
passed_in_db_name = db_name if passed_in_db_name == "" else str(passed_in_db_name) | ||
passed_in_db_name = Config.MILVUS_DB_NAME if passed_in_db_name == "" else str(passed_in_db_name) | ||
|
||
return passed_in_address, passed_in_user, passed_in_password, passed_in_db_name | ||
|
||
def __verify_host_port(self, host, port): | ||
if not is_legal_host(host): | ||
raise ConnectionConfigException(message=ExceptionsMessage.HostType) | ||
if not is_legal_port(port): | ||
raise ConnectionConfigException(message=ExceptionsMessage.PortType) | ||
if not 0 <= int(port) < 65535: | ||
raise ConnectionConfigException(message=f"port number {port} out of range, valid range [0, 65535)") | ||
|
||
def __parse_address_from_uri(self, uri: str) -> Tuple[str, str, str, str]: | ||
illegal_uri_msg = "Illegal uri: [{}], expected form 'https://user:[email protected]:12345'" | ||
try: | ||
parsed_uri = parse.urlparse(uri) | ||
except (Exception) as e: | ||
raise ConnectionConfigException( | ||
message=f"{illegal_uri_msg.format(uri)}: <{type(e).__name__}, {e}>") from None | ||
|
||
if len(parsed_uri.netloc) == 0: | ||
raise ConnectionConfigException(message=f"{illegal_uri_msg.format(uri)}") from None | ||
|
||
group = parsed_uri.path.split("/") | ||
if len(group) > 1: | ||
db_name = group[1] | ||
else: | ||
db_name = "" | ||
|
||
host = parsed_uri.hostname if parsed_uri.hostname is not None else "" | ||
port = parsed_uri.port if parsed_uri.port is not None else "" | ||
user = parsed_uri.username if parsed_uri.username is not None else "" | ||
password = parsed_uri.password if parsed_uri.password is not None else "" | ||
|
||
if host == "": | ||
raise ConnectionConfigException(message=f"Illegal uri: URI is missing host address: {uri}") | ||
if port == "": | ||
raise ConnectionConfigException(message=f"Illegal uri: URI is missing port: {uri}") | ||
|
||
self.__verify_host_port(host, port) | ||
addr = f"{host}:{port}" | ||
|
||
if not is_legal_address(addr): | ||
raise ConnectionConfigException(message=illegal_uri_msg.format(uri)) | ||
|
||
return addr, user, password, db_name | ||
|
||
|
||
def _fetch_handler(self, alias=Config.MILVUS_CONN_ALIAS) -> GrpcHandler: | ||
""" Retrieves a GrpcHandler by alias. """ | ||
|
Oops, something went wrong.