Skip to content

Commit

Permalink
Reuse GRPC channel for same connections
Browse files Browse the repository at this point in the history
Signed-off-by: Filip Haltmayer <[email protected]>
  • Loading branch information
Filip Haltmayer committed May 24, 2023
1 parent 256a523 commit 18fd399
Show file tree
Hide file tree
Showing 3 changed files with 212 additions and 155 deletions.
320 changes: 172 additions & 148 deletions pymilvus/orm/connections.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,54 +78,19 @@ def __init__(self):
"""
self._alias = {}
self._connected_alias = {}
self._env_uri = None
self._connection_references = {}
self._con_lock = threading.RLock()

if Config.MILVUS_URI != "":
address, parsed_uri = self.__parse_address_from_uri(Config.MILVUS_URI)
self._env_uri = (address, parsed_uri)
address, user, _, db_name = self.__parse_info(Config.MILVUS_URI)

default_conn_config = {
"user": parsed_uri.username if parsed_uri.username is not None else "",
"address": address,
}
else:
default_conn_config = {
"user": "",
"address": f"{Config.DEFAULT_HOST}:{Config.DEFAULT_PORT}",
}
default_conn_config = {
"user": user,
"address": address,
"db_name": db_name,
}

self.add_connection(**{Config.MILVUS_CONN_ALIAS: default_conn_config})

def __verify_host_port(self, host, port):
if not is_legal_host(host):
raise ConnectionConfigException(message=ExceptionsMessage.HostType)
if not is_legal_port(port):
raise ConnectionConfigException(message=ExceptionsMessage.PortType)
if not 0 <= int(port) < 65535:
raise ConnectionConfigException(message=f"port number {port} out of range, valid range [0, 65535)")

def __parse_address_from_uri(self, uri: str) -> (str, parse.ParseResult):
illegal_uri_msg = "Illegal uri: [{}], expected form 'https://user:[email protected]:12345'"
try:
parsed_uri = parse.urlparse(uri)
except (Exception) as e:
raise ConnectionConfigException(
message=f"{illegal_uri_msg.format(uri)}: <{type(e).__name__}, {e}>") from None

if len(parsed_uri.netloc) == 0:
raise ConnectionConfigException(message=f"{illegal_uri_msg.format(uri)}") from None

host = parsed_uri.hostname if parsed_uri.hostname is not None else Config.DEFAULT_HOST
port = parsed_uri.port if parsed_uri.port is not None else Config.DEFAULT_PORT
addr = f"{host}:{port}"

self.__verify_host_port(host, port)

if not is_legal_address(addr):
raise ConnectionConfigException(message=illegal_uri_msg.format(uri))

return addr, parsed_uri

def add_connection(self, **kwargs):
""" Configures a milvus connection.
Expand Down Expand Up @@ -157,41 +122,24 @@ def add_connection(self, **kwargs):
)
"""
for alias, config in kwargs.items():
addr, _ = self.__get_full_address(
config.get("address", ""),
config.get("uri", ""),
config.get("host", ""),
config.get("port", ""))
address, user, _, db_name = self.__parse_info(**config)

if alias in self._connected_alias:
if self._alias[alias].get("address") != addr:
if (
self._alias[alias].get("address") != address
or self._alias[alias].get("user") != user
or self._alias[alias].get("db_name") != db_name
):
raise ConnectionConfigException(message=ExceptionsMessage.ConnDiffConf % alias)

alias_config = {
"address": addr,
"user": config.get("user", ""),
"address": address,
"user": user,
"db_name": db_name,
}

self._alias[alias] = alias_config

def __get_full_address(self, address: str = "", uri: str = "", host: str = "", port: str = "") -> (
str, parse.ParseResult):
if address != "":
if not is_legal_address(address):
raise ConnectionConfigException(
message=f"Illegal address: {address}, should be in form 'localhost:19530'")
return address, None

if uri != "":
address, parsed = self.__parse_address_from_uri(uri)
return address, parsed

host = host if host != "" else Config.DEFAULT_HOST
port = port if port != "" else Config.DEFAULT_PORT
self.__verify_host_port(host, port)

return f"{host}:{port}", None

def disconnect(self, alias: str):
""" Disconnects connection from the registry.
Expand All @@ -201,8 +149,13 @@ def disconnect(self, alias: str):
if not isinstance(alias, str):
raise ConnectionConfigException(message=ExceptionsMessage.AliasType % type(alias))

if alias in self._connected_alias:
self._connected_alias.pop(alias).close()
with self._con_lock:
if alias in self._connected_alias:
gh = self._connected_alias.pop(alias)
self._connection_references[id(gh)] -= 1
if self._connection_references[id(gh)] <= 0:
gh.close()
del self._connection_references[id(gh)]

def remove_connection(self, alias: str):
""" Removes connection from the registry.
Expand Down Expand Up @@ -266,96 +219,74 @@ def connect(self, alias=Config.MILVUS_CONN_ALIAS, user="", password="", db_name=
>>> from pymilvus import connections
>>> connections.connect("test", host="localhost", port="19530")
"""
# pylint: disable=too-many-statements

def connect_milvus(**kwargs):
gh = GrpcHandler(**kwargs)
with self._con_lock:
gh = None
for key, connection_details in self._alias.items():

if (
key in self._connected_alias
and connection_details["address"] == kwargs["address"]
and connection_details["user"] == kwargs["user"]
and connection_details["db_name"] == kwargs["db_name"]
):
gh = self._connected_alias[key]
break

t = kwargs.get("timeout")
timeout = t if isinstance(t, (int, float)) else Config.MILVUS_CONN_TIMEOUT
if gh is None:
gh = GrpcHandler(**kwargs)
t = kwargs.get("timeout")
timeout = t if isinstance(t, (int, float)) else Config.MILVUS_CONN_TIMEOUT
gh._wait_for_channel_ready(timeout=timeout)

gh._wait_for_channel_ready(timeout=timeout)
kwargs.pop('password')
kwargs.pop('db_name', None)
kwargs.pop('secure', None)
kwargs.pop("db_name", "")
kwargs.pop('password', None)
kwargs.pop('secure', None)

self._connected_alias[alias] = gh
self._alias[alias] = copy.deepcopy(kwargs)
self._connected_alias[alias] = gh

def with_config(config: Tuple) -> bool:
for c in config:
if c != "":
return True
self._alias[alias] = copy.deepcopy(kwargs)

return False
if id(gh) not in self._connection_references:
self._connection_references[id(gh)] = 1
else:
self._connection_references[id(gh)] += 1

if not isinstance(alias, str):
raise ConnectionConfigException(message=ExceptionsMessage.AliasType % type(alias))

config = (
kwargs.pop("address", ""),
kwargs.pop("uri", ""),
kwargs.pop("host", ""),
kwargs.pop("port", "")
)

# Make sure passed in None doesnt break
user = user or ""
password = password or ""
# Make sure passed in are Strings
user = str(user)
password = str(password)

# 1st Priority: connection from params
if with_config(config):
in_addr, parsed_uri = self.__get_full_address(*config)
kwargs["address"] = in_addr

if self.has_connection(alias):
if self._alias[alias].get("address") != in_addr:
raise ConnectionConfigException(message=ExceptionsMessage.ConnDiffConf % alias)

# uri might take extra info
if parsed_uri is not None:
user = parsed_uri.username if parsed_uri.username is not None else user
password = parsed_uri.password if parsed_uri.password is not None else password

group = parsed_uri.path.split("/")
db_name = "default"
if len(group) > 1:
db_name = group[1]

# Set secure=True if username and password are provided
if len(user) > 0 and len(password) > 0:
kwargs["secure"] = True

address = kwargs.pop("address", "")
uri = kwargs.pop("uri", "")
host = kwargs.pop("host", "")
port = kwargs.pop("port", "")
user = '' if user is None else str(user)
password = '' if password is None else str(password)
db_name = '' if db_name is None else str(db_name)

if set([address, uri, host, port]) != {''}:
address, user, password, db_name = self.__parse_info(address, uri, host, port, db_name, user, password)
kwargs["address"] = address

elif alias in self._alias:
kwargs = dict(self._alias[alias].items())
# If user is passed in, use it, if not, use previous connections user.
prev_user = kwargs.pop("user")
user = user if user != "" else prev_user
# If db_name is passed in, use it, if not, use previous db_name.
prev_db_name = kwargs.pop("db_name")
db_name = db_name if db_name != "" else prev_db_name

connect_milvus(**kwargs, user=user, password=password, db_name=db_name)
return

# 2nd Priority, connection configs from env
if self._env_uri is not None:
addr, parsed_uri = self._env_uri
kwargs["address"] = addr

user = parsed_uri.username if parsed_uri.username is not None else ""
password = parsed_uri.password if parsed_uri.password is not None else ""
# Set secure=True if uri provided user and password
if len(user) > 0 and len(password) > 0:
kwargs["secure"] = True
# No params, env, and cached configs for the alias
else:
raise ConnectionConfigException(message=ExceptionsMessage.ConnLackConf % alias)

connect_milvus(**kwargs, user=user, password=password, db_name=db_name)
return
# Set secure=True if username and password are provided
if len(user) > 0 and len(password) > 0:
kwargs["secure"] = True

# 3rd Priority, connect to cached configs with provided user and password
if alias in self._alias:
connect_alias = dict(self._alias[alias].items())
connect_alias["user"] = user
connect_milvus(**connect_alias, password=password, db_name=db_name, **kwargs)
return
connect_milvus(**kwargs, user=user, password=password, db_name=db_name)

# No params, env, and cached configs for the alias
raise ConnectionConfigException(message=ExceptionsMessage.ConnLackConf % alias)

def list_connections(self) -> list:
""" List names of all connections.
Expand All @@ -369,7 +300,8 @@ def list_connections(self) -> list:
>>> connections.list_connections()
// TODO [('default', None), ('test', <pymilvus.client.grpc_handler.GrpcHandler object at 0x7f05003f3e80>)]
"""
return [(k, self._connected_alias.get(k, None)) for k in self._alias]
with self._con_lock:
return [(k, self._connected_alias.get(k, None)) for k in self._alias]

def get_connection_addr(self, alias: str):
"""
Expand Down Expand Up @@ -414,7 +346,99 @@ def has_connection(self, alias: str) -> bool:
"""
if not isinstance(alias, str):
raise ConnectionConfigException(message=ExceptionsMessage.AliasType % type(alias))
return alias in self._connected_alias
with self._con_lock:
return alias in self._connected_alias

def __parse_info(
self,
address: str = "",
uri: str = "",
host: str = "",
port: str = "",
db_name: str = "",
user: str = "",
password: str = "",
**kwargs) -> dict:

passed_in_address = ""
passed_in_user = ""
passed_in_password = ""
passed_in_db_name = ""

# If uri
if uri != "":
passed_in_address, passed_in_user, passed_in_password, passed_in_db_name = (
self.__parse_address_from_uri(uri)
)

elif address != "":
if not is_legal_address(address):
raise ConnectionConfigException(
message=f"Illegal address: {address}, should be in form 'localhost:19530'")
passed_in_address = address

else:
if host == "":
host = Config.DEFAULT_HOST
if port == "":
port = Config.DEFAULT_PORT
self.__verify_host_port(host, port)
passed_in_address = f"{host}:{port}"

passed_in_user = user if passed_in_user == "" else str(passed_in_user)
passed_in_user = Config.MILVUS_USER if passed_in_user == "" else str(passed_in_user)

passed_in_password = password if passed_in_password == "" else str(passed_in_password)
passed_in_password = Config.MILVUS_PASSWORD if passed_in_password == "" else str(passed_in_password)

passed_in_db_name = db_name if passed_in_db_name == "" else str(passed_in_db_name)
passed_in_db_name = Config.MILVUS_DB_NAME if passed_in_db_name == "" else str(passed_in_db_name)

return passed_in_address, passed_in_user, passed_in_password, passed_in_db_name

def __verify_host_port(self, host, port):
if not is_legal_host(host):
raise ConnectionConfigException(message=ExceptionsMessage.HostType)
if not is_legal_port(port):
raise ConnectionConfigException(message=ExceptionsMessage.PortType)
if not 0 <= int(port) < 65535:
raise ConnectionConfigException(message=f"port number {port} out of range, valid range [0, 65535)")

def __parse_address_from_uri(self, uri: str) -> Tuple[str, str, str, str]:
illegal_uri_msg = "Illegal uri: [{}], expected form 'https://user:[email protected]:12345'"
try:
parsed_uri = parse.urlparse(uri)
except (Exception) as e:
raise ConnectionConfigException(
message=f"{illegal_uri_msg.format(uri)}: <{type(e).__name__}, {e}>") from None

if len(parsed_uri.netloc) == 0:
raise ConnectionConfigException(message=f"{illegal_uri_msg.format(uri)}") from None

group = parsed_uri.path.split("/")
if len(group) > 1:
db_name = group[1]
else:
db_name = ""

host = parsed_uri.hostname if parsed_uri.hostname is not None else ""
port = parsed_uri.port if parsed_uri.port is not None else ""
user = parsed_uri.username if parsed_uri.username is not None else ""
password = parsed_uri.password if parsed_uri.password is not None else ""

if host == "":
raise ConnectionConfigException(message=f"Illegal uri: URI is missing host address: {uri}")
if port == "":
raise ConnectionConfigException(message=f"Illegal uri: URI is missing port: {uri}")

self.__verify_host_port(host, port)
addr = f"{host}:{port}"

if not is_legal_address(addr):
raise ConnectionConfigException(message=illegal_uri_msg.format(uri))

return addr, user, password, db_name


def _fetch_handler(self, alias=Config.MILVUS_CONN_ALIAS) -> GrpcHandler:
""" Retrieves a GrpcHandler by alias. """
Expand Down
Loading

0 comments on commit 18fd399

Please sign in to comment.