From 3206ddb5c0193ffeb9e4c363a40ec7ed76f28ad1 Mon Sep 17 00:00:00 2001 From: Filip Haltmayer Date: Thu, 18 May 2023 12:27:08 -0700 Subject: [PATCH] Reuse GRPC channel for same connections Signed-off-by: Filip Haltmayer --- pymilvus/exceptions.py | 1 + pymilvus/orm/connections.py | 410 +++++++++++++++++++++--------------- pymilvus/settings.py | 7 + tests/test_connections.py | 66 ++++-- 4 files changed, 295 insertions(+), 189 deletions(-) diff --git a/pymilvus/exceptions.py b/pymilvus/exceptions.py index ada78ea13..05a10b71f 100644 --- a/pymilvus/exceptions.py +++ b/pymilvus/exceptions.py @@ -137,6 +137,7 @@ class ExceptionsMessage: HostType = "Type of 'host' must be str." PortType = "Type of 'port' must be str or int." ConnDiffConf = "Alias of %r already creating connections, but the configure is not the same as passed in." + AliasUsed = "Alias %r already exists, please remove first." AliasType = "Alias should be string, but %r is given." ConnLackConf = "You need to pass in the configuration of the connection named %r ." ConnectFirst = "should create connect first." 
diff --git a/pymilvus/orm/connections.py b/pymilvus/orm/connections.py index 49b257243..30db2453f 100644 --- a/pymilvus/orm/connections.py +++ b/pymilvus/orm/connections.py @@ -17,7 +17,6 @@ from ..client.check import is_legal_host, is_legal_port, is_legal_address from ..client.grpc_handler import GrpcHandler -from ..client.utils import get_server_type, ZILLIZ from ..settings import Config from ..exceptions import ExceptionsMessage, ConnectionConfigException, ConnectionNotExistException @@ -80,53 +79,27 @@ def __init__(self): """ self._alias = {} self._connected_alias = {} - self._env_uri = None - - if Config.MILVUS_URI != "": - address, parsed_uri = self.__parse_address_from_uri(Config.MILVUS_URI) - self._env_uri = (address, parsed_uri) - - default_conn_config = { - "user": parsed_uri.username if parsed_uri.username is not None else "", - "address": address, - } - else: - default_conn_config = { - "user": "", - "address": f"{Config.DEFAULT_HOST}:{Config.DEFAULT_PORT}", - } - - self.add_connection(**{Config.MILVUS_CONN_ALIAS: default_conn_config}) - - def __verify_host_port(self, host, port): - if not is_legal_host(host): - raise ConnectionConfigException(message=ExceptionsMessage.HostType) - if not is_legal_port(port): - raise ConnectionConfigException(message=ExceptionsMessage.PortType) - if not 0 <= int(port) < 65535: - raise ConnectionConfigException(message=f"port number {port} out of range, valid range [0, 65535)") - - def __parse_address_from_uri(self, uri: str) -> (str, parse.ParseResult): - illegal_uri_msg = "Illegal uri: [{}], expected form 'https://user:pwd@example.com:12345'" - try: - parsed_uri = parse.urlparse(uri) - except (Exception) as e: - raise ConnectionConfigException( - message=f"{illegal_uri_msg.format(uri)}: <{type(e).__name__}, {e}>") from None - - if len(parsed_uri.netloc) == 0: - raise ConnectionConfigException(message=f"{illegal_uri_msg.format(uri)}") from None - - host = parsed_uri.hostname if parsed_uri.hostname is not None else 
Config.DEFAULT_HOST - port = parsed_uri.port if parsed_uri.port is not None else Config.DEFAULT_PORT - addr = f"{host}:{port}" - - self.__verify_host_port(host, port) - - if not is_legal_address(addr): - raise ConnectionConfigException(message=illegal_uri_msg.format(uri)) - - return addr, parsed_uri + self._connection_references = {} + self._con_lock = threading.RLock() + # info = self.__parse_info( + # uri=Config.MILVUS_URI, + # host=Config.DEFAULT_HOST, + # port=Config.DEFAULT_PORT, + # user = Config.MILVUS_USER, + # password = Config.MILVUS_PASSWORD, + # token = Config.MILVUS_TOKEN, + # secure=Config.DEFAULT_SECURE, + # db_name=Config.MILVUS_DB_NAME + # ) + + # default_conn_config = { + # "user": info["user"], + # "address": info["address"], + # "db_name": info["db_name"], + # "secure": info["secure"], + # } + + # self.add_connection(**{Config.MILVUS_CONN_ALIAS: default_conn_config}) def add_connection(self, **kwargs): """ Configures a milvus connection. @@ -159,41 +132,25 @@ def add_connection(self, **kwargs): ) """ for alias, config in kwargs.items(): - addr, _ = self.__get_full_address( - config.get("address", ""), - config.get("uri", ""), - config.get("host", ""), - config.get("port", "")) + parsed = self.__parse_info(**config) if alias in self._connected_alias: - if self._alias[alias].get("address") != addr: + if ( + self._alias[alias].get("address") != parsed["address"] + or self._alias[alias].get("user") != parsed["user"] + or self._alias[alias].get("db_name") != parsed["db_name"] + or self._alias[alias].get("secure") != parsed["secure"] + ): raise ConnectionConfigException(message=ExceptionsMessage.ConnDiffConf % alias) - alias_config = { - "address": addr, - "user": config.get("user", ""), + "address": parsed["address"], + "user": parsed["user"], + "db_name": parsed["db_name"], + "secure": parsed["secure"], } self._alias[alias] = alias_config - def __get_full_address(self, address: str = "", uri: str = "", host: str = "", port: str = "") -> ( - str, 
parse.ParseResult): - if address != "": - if not is_legal_address(address): - raise ConnectionConfigException( - message=f"Illegal address: {address}, should be in form 'localhost:19530'") - return address, None - - if uri != "": - address, parsed = self.__parse_address_from_uri(uri) - return address, parsed - - host = host if host != "" else Config.DEFAULT_HOST - port = port if port != "" else Config.DEFAULT_PORT - self.__verify_host_port(host, port) - - return f"{host}:{port}", None - def disconnect(self, alias: str): """ Disconnects connection from the registry. @@ -203,8 +160,13 @@ def disconnect(self, alias: str): if not isinstance(alias, str): raise ConnectionConfigException(message=ExceptionsMessage.AliasType % type(alias)) - if alias in self._connected_alias: - self._connected_alias.pop(alias).close() + with self._con_lock: + if alias in self._connected_alias: + gh = self._connected_alias.pop(alias) + self._connection_references[id(gh)] -= 1 + if self._connection_references[id(gh)] <= 0: + gh.close() + del self._connection_references[id(gh)] def remove_connection(self, alias: str): """ Removes connection from the registry. 
@@ -272,107 +234,120 @@ def connect(self, alias=Config.MILVUS_CONN_ALIAS, user="", password="", token="" >>> from pymilvus import connections >>> connections.connect("test", host="localhost", port="19530") """ + # pylint: disable=too-many-statements def connect_milvus(**kwargs): - gh = GrpcHandler(**kwargs) - - t = kwargs.get("timeout") - timeout = t if isinstance(t, (int, float)) else Config.MILVUS_CONN_TIMEOUT - - gh._wait_for_channel_ready(timeout=timeout) - kwargs.pop('password') - kwargs.pop("token", None) - kwargs.pop('db_name', None) - kwargs.pop('secure', None) - kwargs.pop("db_name", "") - - self._connected_alias[alias] = gh - self._alias[alias] = copy.deepcopy(kwargs) - - def with_config(config: Tuple) -> bool: - for c in config: - if c != "": - return True - - return False + with self._con_lock: + # Check if the alias is already connected + if alias in self._connected_alias: + if ( + self._alias[alias]["address"] != kwargs["address"] + or self._alias[alias]["db_name"] != kwargs["db_name"] + or self._alias[alias]["user"] != kwargs["user"] + or self._alias[alias]["secure"] != kwargs["secure"] + ): + raise ConnectionConfigException(message=ExceptionsMessage.AliasUsed % alias) + return + + # Check if alias already made but not connected yet. 
+ if ( + alias in self._alias + and ( + self._alias[alias]["address"] != kwargs["address"] + or self._alias[alias]["db_name"] != kwargs["db_name"]) + # or self._alias[alias]["user"] != kwargs["user"] # Can use different user if using previous alias + # or self._alias[alias]["secure"] != kwargs["secure"] # Can use different secure if using previous alias + ): + raise ConnectionConfigException(message=ExceptionsMessage.AliasUsed % alias) + + gh = None + + # Check if reusable connection already exists + for key, connection_details in self._alias.items(): + + if ( + key in self._connected_alias + and connection_details["address"] == kwargs["address"] + and connection_details["user"] == kwargs["user"] + and connection_details["db_name"] == kwargs["db_name"] + and connection_details["secure"] == kwargs["secure"] + ): + gh = self._connected_alias[key] + break + if gh is None: + gh = GrpcHandler(**kwargs) + t = kwargs.get("timeout", None) + timeout = t if isinstance(t, (int, float)) else Config.MILVUS_CONN_TIMEOUT + gh._wait_for_channel_ready(timeout=timeout) + + kwargs.pop('password', None) + kwargs.pop('token', None) + + self._connected_alias[alias] = gh + + self._alias[alias] = copy.deepcopy(kwargs) + + if id(gh) not in self._connection_references: + self._connection_references[id(gh)] = 1 + else: + self._connection_references[id(gh)] += 1 if not isinstance(alias, str): raise ConnectionConfigException(message=ExceptionsMessage.AliasType % type(alias)) - # Set port if server type is zilliz cloud serverless - uri = kwargs.get("uri") - if uri is not None: - server_type = get_server_type(uri) - if server_type == ZILLIZ and ":" not in token: - kwargs["uri"] = uri+":"+str(VIRTUAL_PORT) - - config = ( - kwargs.pop("address", ""), - kwargs.pop("uri", ""), - kwargs.pop("host", ""), - kwargs.pop("port", "") - ) - - # Make sure passed in None doesnt break - user = user or "" - password = password or "" - token = token or "" - # Make sure passed in are Strings - user = str(user) - 
password = str(password) - token = str(token) - - # 1st Priority: connection from params - if with_config(config): - in_addr, parsed_uri = self.__get_full_address(*config) - kwargs["address"] = in_addr - - if self.has_connection(alias): - if self._alias[alias].get("address") != in_addr: - raise ConnectionConfigException(message=ExceptionsMessage.ConnDiffConf % alias) - - # uri might take extra info - if parsed_uri is not None: - user = parsed_uri.username if parsed_uri.username is not None else user - password = parsed_uri.password if parsed_uri.password is not None else password - - group = parsed_uri.path.split("/") - db_name = "default" - if len(group) > 1: - db_name = group[1] - - # Set secure=True if https scheme - if parsed_uri.scheme == "https": - kwargs["secure"] = True - - - connect_milvus(**kwargs, user=user, password=password, token=token, db_name=db_name) - return - - # 2nd Priority, connection configs from env - if self._env_uri is not None: - addr, parsed_uri = self._env_uri - kwargs["address"] = addr - - user = parsed_uri.username if parsed_uri.username is not None else "" - password = parsed_uri.password if parsed_uri.password is not None else "" - - # Set secure=True if https scheme - if parsed_uri.scheme == "https": - kwargs["secure"] = True - - connect_milvus(**kwargs, user=user, password=password, db_name=db_name) - return - - # 3rd Priority, connect to cached configs with provided user and password + # Grab the relevant info for connection + address = kwargs.pop("address", "") + uri = kwargs.pop("uri", "") + host = kwargs.pop("host", "") + port = kwargs.pop("port", "") + secure = kwargs.pop("secure", None) + + # Clean the connection info + address = '' if address is None else str(address) + uri = '' if uri is None else str(uri) + host = '' if host is None else str(host) + port = '' if port is None else str(port) + user = '' if user is None else str(user) + password = '' if password is None else str(password) + token = '' if token is None else 
(str(token)) + db_name = '' if db_name is None else str(db_name) + + # Replace empties with defaults from enviroment + uri = uri if uri != '' else Config.MILVUS_URI + host = host if host != '' else Config.DEFAULT_HOST + port = port if port != '' else Config.DEFAULT_PORT + user = user if user != '' else Config.MILVUS_USER + password = password if password != '' else Config.MILVUS_PASSWORD + token = token if token != '' else Config.MILVUS_TOKEN + db_name = db_name if db_name != '' else Config.MILVUS_DB_NAME + + # Check if alias exists first if alias in self._alias: - connect_alias = dict(self._alias[alias].items()) - connect_alias["user"] = user - connect_milvus(**connect_alias, password=password, db_name=db_name, **kwargs) - return + kwargs = dict(self._alias[alias].items()) + # If user is passed in, use it, if not, use previous connections user. + prev_user = kwargs.pop("user") + kwargs["user"] = user if user != "" else prev_user + + # If new secure parameter passed in, use that + prev_secure = kwargs.pop("secure") + kwargs["secure"] = secure if secure is not None else prev_secure + + # If db_name is passed in, use it, if not, use previous db_name. + prev_db_name = kwargs.pop("db_name") + kwargs["db_name"] = db_name if db_name != "" else prev_db_name + + # If at least one address info is given, parse it + elif set([address, uri, host, port]) != {''}: + secure = secure if secure is not None else Config.DEFAULT_SECURE + parsed = self.__parse_info(address, uri, host, port, db_name, user, password, token, secure) + kwargs.update(parsed) + + # If no details are given and no alias exists + else: + raise ConnectionConfigException(message=ExceptionsMessage.ConnLackConf % alias) + + connect_milvus(**kwargs) - # No params, env, and cached configs for the alias - raise ConnectionConfigException(message=ExceptionsMessage.ConnLackConf % alias) def list_connections(self) -> list: """ List names of all connections. 
@@ -386,7 +361,8 @@ def list_connections(self) -> list: >>> connections.list_connections() // TODO [('default', None), ('test', )] """ - return [(k, self._connected_alias.get(k, None)) for k in self._alias] + with self._con_lock: + return [(k, self._connected_alias.get(k, None)) for k in self._alias] def get_connection_addr(self, alias: str): """ @@ -431,7 +407,97 @@ def has_connection(self, alias: str) -> bool: """ if not isinstance(alias, str): raise ConnectionConfigException(message=ExceptionsMessage.AliasType % type(alias)) - return alias in self._connected_alias + with self._con_lock: + return alias in self._connected_alias + + def __parse_info( + self, + address: str = "", + uri: str = "", + host: str = "", + port: str = "", + db_name: str = "", + user: str = "", + password: str = "", + token: str = "", + secure: bool = False, + **kwargs) -> dict: + + extracted_address = "" + extracted_user = "" + extracted_password = "" + extracted_db_name = "" + extracted_token = "" + extracted_secure = None + # If URI + if uri != "": + extracted_address, extracted_user, extracted_password, extracted_db_name, extracted_secure = ( + self.__parse_address_from_uri(uri) + ) + # If Address + elif address != "": + if not is_legal_address(address): + raise ConnectionConfigException( + message=f"Illegal address: {address}, should be in form 'localhost:19530'") + extracted_address = address + # If Host port + else: + self.__verify_host_port(host, port) + extracted_address = f"{host}:{port}" + ret = {} + ret["address"] = extracted_address + ret["user"] = user if extracted_user == "" else str(extracted_user) + ret["password"] = password if extracted_password == "" else str(extracted_password) + ret["db_name"] = db_name if extracted_db_name == "" else str(extracted_db_name) + ret["token"] = token if extracted_token == "" else str(extracted_token) + ret["secure"] = secure if extracted_secure is None else extracted_secure + + return ret + + def __verify_host_port(self, host, port): + if 
not is_legal_host(host): + raise ConnectionConfigException(message=ExceptionsMessage.HostType) + if not is_legal_port(port): + raise ConnectionConfigException(message=ExceptionsMessage.PortType) + if not 0 <= int(port) < 65535: + raise ConnectionConfigException(message=f"port number {port} out of range, valid range [0, 65535)") + + def __parse_address_from_uri(self, uri: str) -> Tuple[str, str, str, str, bool]: + illegal_uri_msg = "Illegal uri: [{}], expected form 'https://user:pwd@example.com:12345'" + try: + parsed_uri = parse.urlparse(uri) + except (Exception) as e: + raise ConnectionConfigException( + message=f"{illegal_uri_msg.format(uri)}: <{type(e).__name__}, {e}>") from None + + if len(parsed_uri.netloc) == 0: + raise ConnectionConfigException(message=f"{illegal_uri_msg.format(uri)}") from None + + group = parsed_uri.path.split("/") + if len(group) > 1: + db_name = group[1] + else: + db_name = "" + + host = parsed_uri.hostname if parsed_uri.hostname is not None else "" + port = parsed_uri.port if parsed_uri.port is not None else "" + user = parsed_uri.username if parsed_uri.username is not None else "" + password = parsed_uri.password if parsed_uri.password is not None else "" + secure = parsed_uri.scheme.lower() == "https"  # urlparse() returns the scheme without a trailing ':' + + if host == "": + raise ConnectionConfigException(message=f"Illegal uri: URI is missing host address: {uri}") + if port == "": + raise ConnectionConfigException(message=f"Illegal uri: URI is missing port: {uri}") + + self.__verify_host_port(host, port) + addr = f"{host}:{port}" + + if not is_legal_address(addr): + raise ConnectionConfigException(message=illegal_uri_msg.format(uri)) + + return addr, user, password, db_name, secure + def _fetch_handler(self, alias=Config.MILVUS_CONN_ALIAS) -> GrpcHandler: """ Retrieves a GrpcHandler by alias. 
""" diff --git a/pymilvus/settings.py b/pymilvus/settings.py index d54826731..766ec05c5 100644 --- a/pymilvus/settings.py +++ b/pymilvus/settings.py @@ -15,6 +15,12 @@ class Config: MILVUS_CONN_ALIAS = env.str("MILVUS_CONN_ALIAS", "default") MILVUS_CONN_TIMEOUT = env.float("MILVUS_CONN_TIMEOUT", 10) + MILVUS_USER = env.str("MILVUS_USER", "") + MILVUS_PASSWORD = env.str("MILVUS_PASSWORD", "") + MILVUS_TOKEN = env.str("MILVUS_TOKEN", "") + + MILVUS_DB_NAME = env.str("MILVUS_DB_NAME", "") + # legacy configs: DEFAULT_USING = MILVUS_CONN_ALIAS DEFAULT_CONNECT_TIMEOUT = MILVUS_CONN_TIMEOUT @@ -26,6 +32,7 @@ class Config: DEFAULT_HOST = "localhost" DEFAULT_PORT = "19530" + DEFAULT_SECURE = False WaitTimeDurationWhenLoad = 0.5 # in seconds MaxVarCharLengthKey = "max_length" diff --git a/tests/test_connections.py b/tests/test_connections.py index cb862d51e..52ff4bbe7 100644 --- a/tests/test_connections.py +++ b/tests/test_connections.py @@ -55,7 +55,7 @@ def uri(self, request): def test_connect_with_default_config(self): alias = "default" - default_addr = {"address": "localhost:19530", "user": ""} + default_addr = {"address": "localhost:19530", "user": "", "db_name": "", "secure": False} assert connections.has_connection(alias) is False addr = connections.get_connection_addr(alias) @@ -109,7 +109,7 @@ def test_connect_with_default_config_from_environment(self, env_result): def test_connect_new_alias_with_configs(self): alias = "exist" - addr = {"address": "localhost:19530"} + addr = {"address": "localhost:19530", "db_name": ""} assert connections.has_connection(alias) is False a = connections.get_connection_addr(alias) @@ -123,6 +123,8 @@ def test_connect_new_alias_with_configs(self): a = connections.get_connection_addr(alias) a.pop("user") + print(a) + addr["secure"] = False assert a == addr with mock.patch(f"{mock_prefix}.close", return_value=None): @@ -140,24 +142,24 @@ def test_connect_new_alias_with_configs_NoHostOrPort(self, no_host_or_port): 
connections.connect(alias, **no_host_or_port) assert connections.has_connection(alias) is True - assert connections.get_connection_addr(alias) == {"address": "localhost:19530", "user": ""} + assert connections.get_connection_addr(alias) == {"address": "localhost:19530", "user": "", "db_name": "", "secure": False} with mock.patch(f"{mock_prefix}.close", return_value=None): connections.remove_connection(alias) - def test_connect_new_alias_with_no_config(self): - alias = self.test_connect_new_alias_with_no_config.__name__ + # def test_connect_new_alias_with_no_config(self): + # alias = self.test_connect_new_alias_with_no_config.__name__ - assert connections.has_connection(alias) is False - a = connections.get_connection_addr(alias) - assert a == {} + # assert connections.has_connection(alias) is False + # a = connections.get_connection_addr(alias) + # assert a == {} - with pytest.raises(MilvusException) as excinfo: - connections.connect(alias) + # with pytest.raises(MilvusException) as excinfo: + # connections.connect(alias) - LOGGER.info(f"Exception info: {excinfo.value}") - assert "You need to pass in the configuration" in excinfo.value.message - assert ErrorCode.UNEXPECTED_ERROR == excinfo.value.code + # LOGGER.info(f"Exception info: {excinfo.value}") + # assert "You need to pass in the configuration" in excinfo.value.message + # assert ErrorCode.UNEXPECTED_ERROR == excinfo.value.code def test_connect_with_uri(self, uri): alias = self.test_connect_with_uri.__name__ @@ -193,6 +195,35 @@ def test_add_connection_then_connect(self, uri): with mock.patch(f"{mock_prefix}.close", return_value=None): connections.remove_connection(alias) + def test_connect_with_reuse_grpc(self): + alias = "default" + default_addr = {"address": "localhost:19530", "user": "", "db_name": ""} + check_addr = {"address": "localhost:19530", "user": "", "db_name": "", "secure": False} + + reuse_alias = "reuse" + + assert connections.has_connection(alias) is False + addr = 
connections.get_connection_addr(alias) + assert addr == check_addr + + with mock.patch(f"{mock_prefix}.__init__", return_value=None): + with mock.patch(f"{mock_prefix}._wait_for_channel_ready", return_value=None): + connections.connect(alias=alias, **default_addr) + connections.connect(alias=reuse_alias, **default_addr) + assert connections._connected_alias[alias] == connections._connected_alias[reuse_alias] + print(connections._connected_alias, flush=True) + assert list(connections._connection_references.values())[0] == 2 + + with mock.patch(f"{mock_prefix}.close", return_value=None): + connections.disconnect(alias) + + assert list(connections._connection_references.values())[0] == 1 + + with mock.patch(f"{mock_prefix}.close", return_value=None): + connections.disconnect(reuse_alias) + + assert len(connections._connection_references) == 0 + class TestAddConnection: @pytest.fixture(scope="function", params=[ @@ -301,7 +332,6 @@ def test_add_connection_address_invalid(self, invalid_addr): {"uri": "http://127.0.0.1:19530"}, {"uri": "http://localhost:19530"}, {"uri": "http://example.com:80"}, - {"uri": "http://example.com"}, ]) def test_add_connection_uri(self, valid_uri): alias = self.test_add_connection_uri.__name__ @@ -323,6 +353,8 @@ def test_add_connection_uri(self, valid_uri): {"uri": "http://"}, {"uri": None}, {"uri": -1}, + {"uri": "http://example.com"}, + {"uri": "http://:90"}, ]) def test_add_connection_uri_invalid(self, invalid_uri): alias = self.test_add_connection_uri_invalid.__name__ @@ -359,13 +391,13 @@ def test_issue_1196(self): config = {"alias": alias, "host": "localhost", "port": "19531", "user": "root", "password": 12345, "secure": True} connections.connect(**config) config = connections.get_connection_addr(alias) - assert config == {"address": 'localhost:19531', "user": 'root'} + assert config == {"address": 'localhost:19531', "user": 'root', "db_name": "", "secure": True} connections.add_connection(default={"host": "localhost", "port": 19531}) 
config = connections.get_connection_addr("default") - assert config == {"address": 'localhost:19531', "user": ""} + assert config == {"address": 'localhost:19531', "user": "", "db_name": "", "secure": False} connections.connect("default", user="root", password="12345", secure=True) config = connections.get_connection_addr("default") - assert config == {"address": 'localhost:19531', "user": 'root'} + assert config == {"address": 'localhost:19531', "user": 'root', "db_name": "", "secure": True}