diff --git a/qdrant_client/async_qdrant_remote.py b/qdrant_client/async_qdrant_remote.py index 3a0523bc..99884d6f 100644 --- a/qdrant_client/async_qdrant_remote.py +++ b/qdrant_client/async_qdrant_remote.py @@ -9,7 +9,6 @@ # # ****** WARNING: THIS FILE IS AUTOGENERATED ****** -import asyncio import importlib.metadata import logging import math @@ -160,8 +159,10 @@ def __init__( self._grpc_snapshots_client: Optional[grpc.SnapshotsStub] = None self._grpc_root_client: Optional[grpc.QdrantStub] = None self._closed: bool = False - if check_version and (not is_server_version_compatible(self)): - warnings.warn("Qdrant client version may be incompatible with server version.") + if check_version and (not is_server_version_compatible(self.rest_uri, self._rest_args)): + warnings.warn( + "Qdrant client version may be incompatible with server version. Set check_version=False to skip version check." + ) @property def closed(self) -> bool: diff --git a/qdrant_client/common/version_check.py b/qdrant_client/common/version_check.py index 1ab6585d..3eaa770f 100644 --- a/qdrant_client/common/version_check.py +++ b/qdrant_client/common/version_check.py @@ -1,36 +1,46 @@ -import asyncio import importlib.metadata -import inspect import logging -from typing import Union, TYPE_CHECKING +from typing import Union from collections import namedtuple +from qdrant_client.http import SyncApis, ApiClient +from qdrant_client.http.models import models Version = namedtuple("Version", ["major", "minor", "rest"]) -if TYPE_CHECKING: - from qdrant_client.qdrant_remote import QdrantRemote - from qdrant_client.async_qdrant_remote import AsyncQdrantRemote +def is_server_version_compatible(rest_uri, rest_args): + def get_server_info(): + openapi_client: SyncApis[ApiClient] = SyncApis( + host=rest_uri, + **rest_args, + ) + return openapi_client.client.request( + type_=models.VersionInfo, + method="GET", + url="/", + headers=None, + ) + def get_server_version() -> Union[str, None]: + try: + version_info = get_server_info() + except Exception as er: + logging.warning(f"Unable to get server version: {er}, default to None") + return None -def is_server_version_compatible(client: Union["QdrantRemote", "AsyncQdrantRemote"]) -> bool: - client_version = importlib.metadata.version("qdrant-client") - - get_info = client.info() - if inspect.iscoroutine(get_info): - loop = asyncio.get_event_loop() - info_version = loop.run_until_complete(get_info).version - elif hasattr(get_info, "version"): - info_version = get_info.version - else: - raise ValueError("Unable to retrieve server version") - server_version = info_version + if not version_info: + return None + return version_info.version - return check_version(client_version, server_version) + client_version = importlib.metadata.version("qdrant-client") + server_version = get_server_version() + return compare_versions(client_version, server_version) def parse_version(version: str) -> Version: + if not version: + raise ValueError("Version is None") try: major, minor, *rest = version.split(".") return Version(int(major), int(minor), rest) @@ -40,7 +50,11 @@ def parse_version(version: str) -> Version: ) from er -def check_version(client_version: str, server_version: str) -> bool: +def compare_versions(client_version: str, server_version: str) -> bool: + if not client_version or not server_version: + logging.warning(f"Unable to compare: {client_version} vs {server_version}") + return False + if client_version == server_version: return True diff --git a/qdrant_client/qdrant_remote.py b/qdrant_client/qdrant_remote.py index 12fa4c75..11879e1f 100644 --- a/qdrant_client/qdrant_remote.py +++ b/qdrant_client/qdrant_remote.py @@ -1,4 +1,5 @@ import asyncio +import importlib.metadata import logging import math import platform @@ -195,8 +196,10 @@ def __init__( self._closed: bool = False - if check_version and not is_server_version_compatible(self): - warnings.warn("Qdrant client version may be incompatible with server version.") + if check_version and not is_server_version_compatible(self.rest_uri, self._rest_args): + warnings.warn( + "Qdrant client version may be incompatible with server version. Set check_version=False to skip version check." + ) @property def closed(self) -> bool: diff --git a/tests/test_async_qdrant_client.py b/tests/test_async_qdrant_client.py index 4174f6fd..38eb4eb6 100644 --- a/tests/test_async_qdrant_client.py +++ b/tests/test_async_qdrant_client.py @@ -98,7 +98,7 @@ async def test_async_grpc(): async def test_async_qdrant_client(prefer_grpc): major, minor, patch, dev = read_version() - client = AsyncQdrantClient(prefer_grpc=prefer_grpc, timeout=15) + client = AsyncQdrantClient(prefer_grpc=prefer_grpc, timeout=15, check_version=True) collection_params = dict( collection_name=COLLECTION_NAME, vectors_config=models.VectorParams(size=10, distance=models.Distance.EUCLID), diff --git a/tests/test_common.py b/tests/test_common.py index e69c179d..d6026eba 100644 --- a/tests/test_common.py +++ b/tests/test_common.py @@ -1,6 +1,6 @@ import pytest -from qdrant_client.common.version_check import check_version, parse_version +from qdrant_client.common.version_check import compare_versions, parse_version @pytest.mark.parametrize( @@ -48,10 +48,15 @@ ), ("1.9.0", "3.7.0", False, "Diff between major versions > 1 (server > client)"), ("3.0.0", "1.0.0", False, "Diff between major versions > 1 (client > server)"), + (None, "1.0.0", False, "Client version is None"), + ("1.0.0", None, False, "Server version is None"), + (None, None, False, "Both versions are None"), ], ) def test_check_versions(test_data): - assert check_version(client_version=test_data[0], server_version=test_data[1]) is test_data[2] + assert ( + compare_versions(client_version=test_data[0], server_version=test_data[1]) is test_data[2] + ) @pytest.mark.parametrize( @@ -63,6 +68,8 @@ def test_check_versions(test_data): (".1.", "Only minor version"), ("1.None.1", "Minor version is not a number"), ("None.0.1", "Major version is not a number"), + (None, "Version is None"), + ("", "Version is empty"), ], ) def test_parse_versions_value_error(test_data):