Skip to content

Commit

Permalink
fix: allow nan values in payload in local mode
Browse files Browse the repository at this point in the history
  • Loading branch information
joein committed Apr 19, 2024
1 parent 477371d commit c23d52b
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 90 deletions.
23 changes: 2 additions & 21 deletions qdrant_client/local/local_collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,31 +76,12 @@ def _to_jsonable_python(x: Any) -> Any:
return ENCODERS_BY_TYPE[type(x)](x)


def convert_nan_inf_to_null(obj: Any) -> Any:
    """Recursively replace non-finite floats (NaN, +inf, -inf) with ``None``.

    Args:
        obj: Arbitrary value, possibly a nested structure of dicts and
            iterables.

    Returns:
        ``None`` if ``obj`` is a non-finite float. Dicts are rebuilt with the
        same keys and converted values. ``str``, ``bytes`` and ``range``
        objects are returned untouched. Any other iterable is rebuilt as a
        ``list`` of converted items. Everything else is returned as is.
    """
    # not isfinite(x) is exactly (isnan(x) or isinf(x)) for a float
    if isinstance(obj, float) and not np.isfinite(obj):
        return None

    if isinstance(obj, dict):
        return {key: convert_nan_inf_to_null(value) for key, value in obj.items()}

    # These are iterable, but must be kept as a single value rather than
    # being exploded into a list of their elements.
    if isinstance(obj, (str, bytes, range)):
        return obj

    # pydantic converts iterables to lists
    if isinstance(obj, Iterable):
        return [convert_nan_inf_to_null(item) for item in obj]

    return obj


def to_jsonable_python(x: Any) -> Any:
    """Return a JSON-compatible version of ``x``, keeping NaN/Infinity values.

    Args:
        x: Value to make JSON-compatible.

    Returns:
        ``x`` itself when ``json.dumps`` can already serialize it; otherwise
        the result of a JSON round trip in which objects unknown to ``json``
        are encoded by ``_to_jsonable_python``.
    """
    try:
        # Probe only: the encoded string is discarded, we just need to know
        # whether x serializes as is. allow_nan=True lets NaN/Infinity pass.
        json.dumps(x, allow_nan=True)
        return x
    except Exception:
        # Fallback: encode unsupported objects through _to_jsonable_python and
        # parse the result back into plain Python containers.
        return json.loads(json.dumps(x, allow_nan=True, default=_to_jsonable_python))


class LocalCollection:
Expand Down
69 changes: 0 additions & 69 deletions tests/congruence_tests/test_updates.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,75 +258,6 @@ def test_upload_collection_dict_np_arrays(local_client, remote_client):
compare_collections(local_client, remote_client, UPLOAD_NUM_VECTORS)


def test_upload_payload_contain_nan_values():
    """Compare local and remote clients when every payload holds a NaN value.

    The same points are written through ``upload_collection``,
    ``upload_points`` and ``upsert`` (with a ``models.Batch``); after each
    write the two collections are compared with ``compare_collections``.
    Skipped when the installed pydantic is older than 2.7.
    """
    # usual case when payload is extracted from pandas dataframe
    from pydantic.version import VERSION

    # Only major/minor matter for the check. Parsing just those two parts
    # avoids a ValueError on pre-release versions such as "2.7.0b1".
    major, minor = (int(part) for part in VERSION.split(".")[:2])
    if (major, minor) < (2, 7):
        pytest.skip("Test requires pydantic>=2.7.0")

    local_client = init_local()
    remote_client = init_remote()

    vector_size = 2
    nans_collection = "nans_collection"

    def recreate_collections():
        # Start each scenario from an empty collection on both sides.
        for client in (local_client, remote_client):
            client.recreate_collection(
                collection_name=nans_collection,
                vectors_config=models.VectorParams(size=vector_size, distance=models.Distance.DOT),
            )

    def assert_congruent():
        compare_collections(
            local_client, remote_client, UPLOAD_NUM_VECTORS, collection_name=nans_collection
        )

    points = generate_points(
        num_points=UPLOAD_NUM_VECTORS,
        vector_sizes=vector_size,
        with_payload=False,
    )
    ids, vectors, payload = [], [], []
    for point in points:
        point.payload = {"surprise": math.nan}
        ids.append(point.id)
        vectors.append(point.vector)
        payload.append(point.payload)

    points_batch = models.Batch(
        ids=ids,
        vectors=vectors,
        payloads=payload,
    )

    recreate_collections()
    local_client.upload_collection(nans_collection, vectors, payload, ids)
    remote_client.upload_collection(nans_collection, vectors, payload, ids)
    assert_congruent()

    recreate_collections()
    local_client.upload_points(nans_collection, points)
    remote_client.upload_points(nans_collection, points)
    assert_congruent()

    recreate_collections()
    local_client.upsert(nans_collection, points=points_batch)
    remote_client.upsert(nans_collection, points=points_batch)
    assert_congruent()

    local_client.delete_collection(nans_collection)
    remote_client.delete_collection(nans_collection)


def test_upload_wrong_vectors():
local_client = init_local()
remote_client = init_remote()
Expand Down

0 comments on commit c23d52b

Please sign in to comment.