Skip to content

Commit

Permalink
Address review
Browse files Browse the repository at this point in the history
  • Loading branch information
tellet-q committed Dec 5, 2024
1 parent 1d0c171 commit 0c819b1
Show file tree
Hide file tree
Showing 7 changed files with 99 additions and 108 deletions.
4 changes: 2 additions & 2 deletions qdrant_client/async_qdrant_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ def __init__(
Union[Callable[[], str], Callable[[], Awaitable[str]]]
] = None,
cloud_inference: bool = False,
check_version: Optional[bool] = None,
check_compatibility: Optional[bool] = True,
**kwargs: Any,
):
self._inference_inspector = Inspector()
Expand Down Expand Up @@ -133,7 +133,7 @@ def __init__(
host=host,
grpc_options=grpc_options,
auth_token_provider=auth_token_provider,
check_version=check_version,
check_compatibility=check_compatibility,
**kwargs,
)
if isinstance(self._client, AsyncQdrantLocal) and cloud_inference:
Expand Down
15 changes: 9 additions & 6 deletions qdrant_client/async_qdrant_remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
from qdrant_client._pydantic_compat import construct
from qdrant_client.auth import BearerAuth
from qdrant_client.async_client_base import AsyncQdrantBase
from qdrant_client.common.version_check import is_server_version_compatible
from qdrant_client.common.version_check import is_versions_compatible, get_server_version
from qdrant_client.connection import get_async_channel as get_channel
from qdrant_client.conversions import common_types as types
from qdrant_client.conversions.common_types import get_args_subscribed
Expand Down Expand Up @@ -69,7 +69,7 @@ def __init__(
auth_token_provider: Optional[
Union[Callable[[], str], Callable[[], Awaitable[str]]]
] = None,
check_version: Optional[bool] = None,
check_compatibility: Optional[bool] = True,
**kwargs: Any,
):
super().__init__(**kwargs)
Expand Down Expand Up @@ -159,10 +159,13 @@ def __init__(
self._grpc_snapshots_client: Optional[grpc.SnapshotsStub] = None
self._grpc_root_client: Optional[grpc.QdrantStub] = None
self._closed: bool = False
if check_version and (not is_server_version_compatible(self.rest_uri, **self._rest_args)):
warnings.warn(
"Qdrant client version may be incompatible with server version. Set check_version=False to skip version check."
)
if check_compatibility:
client_version = importlib.metadata.version("qdrant-client")
server_version = get_server_version(self.rest_uri, **self._rest_args)
if not is_versions_compatible(client_version, server_version):
warnings.warn(
f"Qdrant client version {client_version} is incompatible with server version {server_version}. Major versions should mathc and minor version difference must not exceed 1. Set check_version=False to skip version check."
)

@property
def closed(self) -> bool:
Expand Down
52 changes: 24 additions & 28 deletions qdrant_client/common/version_check.py
Original file line number Diff line number Diff line change
@@ -1,41 +1,32 @@
import importlib.metadata
import logging
import warnings
from typing import Union, Any
from collections import namedtuple

from qdrant_client.http import SyncApis, ApiClient
from qdrant_client.http.models import models

Version = namedtuple("Version", ["major", "minor", "rest"])


def is_server_version_compatible(rest_uri: str, **kwargs: Any) -> bool:
def get_server_info() -> Any:
def get_server_version(rest_uri: str, **kwargs: Any) -> Union[str, None]:
try:
openapi_client: SyncApis[ApiClient] = SyncApis(
host=rest_uri,
**kwargs,
)
return openapi_client.client.request(
type_=models.VersionInfo,
method="GET",
url="/",
headers=None,
)
version_info = openapi_client.service_api.root()

def get_server_version() -> Union[str, None]:
try:
version_info = get_server_info()
except Exception as er:
logging.warning(f"Unable to get server version: {er}, default to None")
return None
openapi_client.close()
except Exception:
logging.warning(
"Unable to close http connection. Connection was interrupted on the server side"
)

if not version_info:
return None
return version_info.version

client_version = importlib.metadata.version("qdrant-client")
server_version = get_server_version()
return compare_versions(client_version, server_version)
except Exception as er:
warnings.warn(f"Unable to get server version: {er}, default to None")
return None


def parse_version(version: str) -> Version:
Expand All @@ -50,9 +41,15 @@ def parse_version(version: str) -> Version:
) from er


def compare_versions(client_version: Union[str, None], server_version: Union[str, None]) -> bool:
if not client_version or not server_version:
logging.warning(f"Unable to compare: {client_version} vs {server_version}")
def is_versions_compatible(
client_version: Union[str, None], server_version: Union[str, None]
) -> bool:
if not client_version:
warnings.warn(f"Unable to compare with client version {client_version}")
return False

if not server_version:
warnings.warn(f"Unable to compare with server version {server_version}")
return False

if client_version == server_version:
Expand All @@ -62,11 +59,10 @@ def compare_versions(client_version: Union[str, None], server_version: Union[str
parsed_server_version = parse_version(server_version)
parsed_client_version = parse_version(client_version)
except ValueError as er:
logging.warning(f"Unable to parse version: {er}")
warnings.warn(f"Unable to compare versions: {er}")
return False

major_dif = abs(parsed_server_version.major - parsed_client_version.major)
if major_dif >= 1:
return False
elif major_dif == 0:
return abs(parsed_server_version.minor - parsed_client_version.minor) <= 1
return False
return abs(parsed_server_version.minor - parsed_client_version.minor) <= 1
4 changes: 2 additions & 2 deletions qdrant_client/qdrant_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ def __init__(
Union[Callable[[], str], Callable[[], Awaitable[str]]]
] = None,
cloud_inference: bool = False,
check_version: Optional[bool] = None,
check_compatibility: Optional[bool] = True,
**kwargs: Any,
):
self._inference_inspector = Inspector()
Expand Down Expand Up @@ -145,7 +145,7 @@ def __init__(
host=host,
grpc_options=grpc_options,
auth_token_provider=auth_token_provider,
check_version=check_version,
check_compatibility=check_compatibility,
**kwargs,
)

Expand Down
15 changes: 9 additions & 6 deletions qdrant_client/qdrant_remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
from qdrant_client._pydantic_compat import construct
from qdrant_client.auth import BearerAuth
from qdrant_client.client_base import QdrantBase
from qdrant_client.common.version_check import is_server_version_compatible
from qdrant_client.common.version_check import is_versions_compatible, get_server_version
from qdrant_client.connection import get_async_channel, get_channel
from qdrant_client.conversions import common_types as types
from qdrant_client.conversions.common_types import get_args_subscribed
Expand Down Expand Up @@ -61,7 +61,7 @@ def __init__(
auth_token_provider: Optional[
Union[Callable[[], str], Callable[[], Awaitable[str]]]
] = None,
check_version: Optional[bool] = None,
check_compatibility: Optional[bool] = True,
**kwargs: Any,
):
super().__init__(**kwargs)
Expand Down Expand Up @@ -196,10 +196,13 @@ def __init__(

self._closed: bool = False

if check_version and not is_server_version_compatible(self.rest_uri, **self._rest_args):
warnings.warn(
"Qdrant client version may be incompatible with server version. Set check_version=False to skip version check."
)
if check_compatibility:
client_version = importlib.metadata.version("qdrant-client")
server_version = get_server_version(self.rest_uri, **self._rest_args)
if not is_versions_compatible(client_version, server_version):
warnings.warn(
f"Qdrant client version {client_version} is incompatible with server version {server_version}. Major versions should mathc and minor version difference must not exceed 1. Set check_version=False to skip version check."
)

@property
def closed(self) -> bool:
Expand Down
111 changes: 50 additions & 61 deletions tests/test_common.py
Original file line number Diff line number Diff line change
@@ -1,77 +1,66 @@
import pytest

from qdrant_client.common.version_check import compare_versions, parse_version
from qdrant_client.common.version_check import is_versions_compatible, parse_version


@pytest.mark.parametrize(
"test_data",
"client_version, server_version, expected_result",
[
("1.9.3.dev0", "2.0.1", False, "Diff between major versions = 1, minor versions differ"),
(
"1.9",
"2.0",
False,
"Diff between major versions = 1, minor versions differ, only major and patch",
),
("1", "2", False, "Diff between major versions = 1, minor versions differ, only major"),
("1.9.0", "2.9.0", False, "Diff between major versions = 1, minor versions are the same"),
(
"1.1.0",
"1.2.9",
True,
"Diff between major versions == 0, diff between minor versions == 1 (server > client)",
),
(
"1.2.7",
"1.1.8.dev0",
True,
"Diff between major versions == 0, diff between minor versions == 1 (client > server)",
),
(
"1.2.1",
"1.2.29",
True,
"Diff between major versions == 0, diff between minor versions == 0",
),
("1.2.0", "1.2.0", True, "Same versions"),
(
"1.2.0",
"1.4.0",
False,
"Diff between major versions == 0, diff between minor versions > 1 (server > client)",
),
(
"1.4.0",
"1.2.0",
False,
"Diff between major versions == 0, diff between minor versions > 1 (client > server)",
),
("1.9.0", "3.7.0", False, "Diff between major versions > 1 (server > client)"),
("3.0.0", "1.0.0", False, "Diff between major versions > 1 (client > server)"),
(None, "1.0.0", False, "Client version is None"),
("1.0.0", None, False, "Server version is None"),
(None, None, False, "Both versions are None"),
("1.9.3.dev0", "2.8.1.dev12-something", False),
("1.9", "2.8", False),
("1", "2", False),
("1.9.0", "2.9.0", False),
("1.1.0", "1.2.9", True),
("1.2.7", "1.1.8.dev0", True),
("1.2.1", "1.2.29", True),
("1.2.0", "1.2.0", True),
("1.2.0", "1.4.0", False),
("1.4.0", "1.2.0", False),
("1.9.0", "3.7.0", False),
("3.0.0", "1.0.0", False),
(None, "1.0.0", False),
("1.0.0", None, False),
(None, None, False),
],
ids=[
"Diff between major versions = 1, negative",
"Diff between major versions = 1, only major and minor, negative",
"Diff between major versions = 1, only major, negative",
"Diff between major versions = 1, minor versions are the same, negative",
"Diff between major versions == 0, diff between minor versions == 1 (server > client), positive",
"Diff between major versions == 0, diff between minor versions == 1 (client > server), positive",
"Diff between major versions == 0, diff between minor versions == 0, positive",
"Same versions, positive",
"Diff between major versions == 0, diff between minor versions > 1 (server > client), negative",
"Diff between major versions == 0, diff between minor versions > 1 (client > server), negative",
"Diff between major versions > 1 (server > client), negative",
"Diff between major versions > 1 (client > server), negative",
"Client version is None, negative",
"Server version is None, negative",
"Both versions are None, negative",
],
)
def test_check_versions(test_data):
def test_check_versions(client_version, server_version, expected_result):
assert (
compare_versions(client_version=test_data[0], server_version=test_data[1]) is test_data[2]
is_versions_compatible(client_version=client_version, server_version=server_version)
is expected_result
)


@pytest.mark.parametrize(
"test_data",
[
("1", "Only major version"),
("1.", "Only major version"),
(".1", "Only minor version"),
(".1.", "Only minor version"),
("1.None.1", "Minor version is not a number"),
("None.0.1", "Major version is not a number"),
(None, "Version is None"),
("", "Version is empty"),
"input_version",
["1", "1.", ".1", ".1.", "1.None.1", "None.0.1", None, ""],
ids=[
"Only major part",
"Only major part with dot",
"Only minor part",
"Only minor part with dot",
"Minor part is not a number",
"Major part is not a number",
"Version is None",
"Version is empty",
],
)
def test_parse_versions_value_error(test_data):
def test_parse_versions_value_error(input_version):
with pytest.raises(ValueError):
parse_version(test_data[0])
parse_version(input_version)
6 changes: 3 additions & 3 deletions tests/test_qdrant_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,13 +100,13 @@ def test_client_init():
assert isinstance(client._client, QdrantRemote)
assert client._client.rest_uri == "http://localhost:6333"

client = QdrantClient(":memory:", check_version=True)
client = QdrantClient(":memory:", check_compatibility=True)
assert isinstance(client._client, QdrantLocal)

client = QdrantClient(check_version=True)
client = QdrantClient(check_compatibility=True)
assert isinstance(client._client, QdrantRemote)

client = QdrantClient(check_version=True, prefer_grpc=True)
client = QdrantClient(check_compatibility=True, prefer_grpc=True)
assert isinstance(client._client, QdrantRemote)

client = QdrantClient(https=True)
Expand Down

0 comments on commit 0c819b1

Please sign in to comment.