Skip to content

Commit

Permalink
Refactor version parsing
Browse files Browse the repository at this point in the history
  • Loading branch information
tellet-q committed Dec 5, 2024
1 parent d188327 commit bc47410
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 11 deletions.
30 changes: 25 additions & 5 deletions qdrant_client/common/version_check.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,21 @@
import asyncio
import importlib.metadata
import inspect
import logging
from typing import Union, TYPE_CHECKING
from collections import namedtuple


Version = namedtuple("Version", ["major", "minor", "rest"])

from packaging import version
from packaging.version import Version

if TYPE_CHECKING:
from qdrant_client.qdrant_remote import QdrantRemote
from qdrant_client.async_qdrant_remote import AsyncQdrantRemote


def is_server_version_compatible(client: Union["QdrantRemote", "AsyncQdrantRemote"]) -> bool:
client_version = version.parse(importlib.metadata.version("qdrant-client"))
client_version = importlib.metadata.version("qdrant-client")

get_info = client.info()
if inspect.iscoroutine(get_info):
Expand All @@ -22,14 +25,31 @@ def is_server_version_compatible(client: Union["QdrantRemote", "AsyncQdrantRemot
info_version = get_info.version
else:
raise ValueError("Unable to retrieve server version")
server_version = version.parse(info_version)
server_version = info_version

return check_version(client_version, server_version)


def check_version(client_version: Version, server_version: Version) -> bool:
def parse_version(version: str) -> Version:
try:
major, minor, *rest = version.split(".")
return Version(int(major), int(minor), rest)
except ValueError as er:
raise ValueError(
f"Unable to parse version, expected format: x.y.z, found: {version}"
) from er


def check_version(client_version: str, server_version: str) -> bool:
if client_version == server_version:
return True

try:
server_version = parse_version(server_version)
client_version = parse_version(client_version)
except ValueError as er:
logging.warning(f"Unable to parse version: {er}")
return False
major_dif = abs(server_version.major - client_version.major)
if major_dif >= 1:
return False
Expand Down
31 changes: 25 additions & 6 deletions tests/test_common.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,19 @@
import pytest
from packaging.version import Version

from qdrant_client.common.version_check import check_version
from qdrant_client.common.version_check import check_version, parse_version


@pytest.mark.parametrize(
"test_data",
[
("1.9.3.dev0", "2.0.1", False, "Diff between major versions = 1, minor versions differ"),
(
"1.9",
"2.0",
False,
"Diff between major versions = 1, minor versions differ, only major and patch",
),
("1", "2", False, "Diff between major versions = 1, minor versions differ, only major"),
("1.9.0", "2.9.0", False, "Diff between major versions = 1, minor versions are the same"),
(
"1.1.0",
Expand Down Expand Up @@ -45,7 +51,20 @@
],
)
def test_check_versions(test_data):
assert (
check_version(client_version=Version(test_data[0]), server_version=Version(test_data[1]))
is test_data[2]
)
assert check_version(client_version=test_data[0], server_version=test_data[1]) is test_data[2]


@pytest.mark.parametrize(
"test_data",
[
("1", "Only major version"),
("1.", "Only major version"),
(".1", "Only minor version"),
(".1.", "Only minor version"),
("1.None.1", "Minor version is not a number"),
("None.0.1", "Major version is not a number"),
],
)
def test_parse_versions_value_error(test_data):
with pytest.raises(ValueError):
parse_version(test_data[0])
3 changes: 3 additions & 0 deletions tests/test_qdrant_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,9 @@ def test_client_init():
client = QdrantClient(check_version=True)
assert isinstance(client._client, QdrantRemote)

client = QdrantClient(check_version=True, prefer_grpc=True)
assert isinstance(client._client, QdrantRemote)

client = QdrantClient(https=True)
assert isinstance(client._client, QdrantRemote)
assert client._client.rest_uri == "https://localhost:6333"
Expand Down

0 comments on commit bc47410

Please sign in to comment.