Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Beta/return empty dict on no return refs #719

Merged
merged 7 commits into from
Jan 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/main.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ on:

env:
OLD_WEAVIATE_VERSION: 1.22.6
NEW_WEAVIATE_VERSION: preview-fix-uuid-casing-for-references-added-with-object-single-refs-and-bat-3e69acc
NEW_WEAVIATE_VERSION: preview-allow-discerning-between-nil-and-in-ref-props-20ad8e5

jobs:
lint-and-format:
Expand Down
8 changes: 7 additions & 1 deletion integration/test_collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -957,7 +957,13 @@ def test_add_reference(collection_factory: CollectionFactory) -> None:
uuid2, return_properties=["name"], return_references=FromReference(link_on="self")
)
assert "name" in obj1.properties
assert obj1.references is None
assert (
obj1.references == {}
if collection._connection._weaviate_version.is_at_least(
1, 23, 2
) # TODO: change to 1.23.3 when released
else obj1.references is None
)
assert "name" in obj2.properties
assert "self" in obj2.references

Expand Down
32 changes: 31 additions & 1 deletion integration/test_collection_references.py
Original file line number Diff line number Diff line change
Expand Up @@ -568,7 +568,11 @@ def test_insert_many_with_refs(collection_factory: CollectionFactory) -> None:
return_properties=["name"], return_references=FromReference(link_on="self")
).objects:
if obj.properties["name"] in ["A", "B"]:
assert obj.references is None
assert (
obj.references == {}
if collection._connection._weaviate_version.is_at_least(1, 23, 2)
else obj.references is None
) # TODO: change to 1.23.3 when released
else:
assert obj.references is not None

Expand Down Expand Up @@ -753,3 +757,29 @@ def test_ref_case_sensitivity(collection_factory: CollectionFactory) -> None:
uid, return_references=[QueryReference(link_on="ref")]
)
assert "ref" in obj.references


def test_empty_return_reference(collection_factory: CollectionFactory) -> None:
to = collection_factory(name="To", vectorizer_config=Configure.Vectorizer.none())
source = collection_factory(
name="From",
references=[
ReferenceProperty(name="ref", target_collection=to.name),
],
vectorizer_config=Configure.Vectorizer.none(),
)
if not source._connection._weaviate_version.is_at_least(
1, 23, 2
): # TODO: change this to 1.23.3 when it is released
pytest.skip("references return empty object only supported in 1.23.3+")
uuid_source = source.data.insert(properties={})
obj = source.query.fetch_object_by_id(
uuid_source, return_references=[QueryReference(link_on="ref")]
)
assert (
obj.references == {}
if source._connection._weaviate_version.is_at_least(
1, 23, 2
) # TODO: change to 1.23.3 when released
else obj.references is None
)
30 changes: 15 additions & 15 deletions weaviate/collections/classes/filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

from weaviate.collections.classes.types import _WeaviateInput
from weaviate.types import UUID
from weaviate.proto.v1 import search_get_pb2
from weaviate.proto.v1 import base_pb2
from weaviate.util import get_valid_uuid


Expand All @@ -27,33 +27,33 @@ class _Operator(str, Enum):
AND = "And"
OR = "Or"

def _to_grpc(self) -> search_get_pb2.Filters.Operator:
def _to_grpc(self) -> base_pb2.Filters.Operator:
if self == _Operator.EQUAL:
return search_get_pb2.Filters.OPERATOR_EQUAL
return base_pb2.Filters.OPERATOR_EQUAL
elif self == _Operator.NOT_EQUAL:
return search_get_pb2.Filters.OPERATOR_NOT_EQUAL
return base_pb2.Filters.OPERATOR_NOT_EQUAL
elif self == _Operator.LESS_THAN:
return search_get_pb2.Filters.OPERATOR_LESS_THAN
return base_pb2.Filters.OPERATOR_LESS_THAN
elif self == _Operator.LESS_THAN_EQUAL:
return search_get_pb2.Filters.OPERATOR_LESS_THAN_EQUAL
return base_pb2.Filters.OPERATOR_LESS_THAN_EQUAL
elif self == _Operator.GREATER_THAN:
return search_get_pb2.Filters.OPERATOR_GREATER_THAN
return base_pb2.Filters.OPERATOR_GREATER_THAN
elif self == _Operator.GREATER_THAN_EQUAL:
return search_get_pb2.Filters.OPERATOR_GREATER_THAN_EQUAL
return base_pb2.Filters.OPERATOR_GREATER_THAN_EQUAL
elif self == _Operator.LIKE:
return search_get_pb2.Filters.OPERATOR_LIKE
return base_pb2.Filters.OPERATOR_LIKE
elif self == _Operator.IS_NULL:
return search_get_pb2.Filters.OPERATOR_IS_NULL
return base_pb2.Filters.OPERATOR_IS_NULL
elif self == _Operator.CONTAINS_ANY:
return search_get_pb2.Filters.OPERATOR_CONTAINS_ANY
return base_pb2.Filters.OPERATOR_CONTAINS_ANY
elif self == _Operator.CONTAINS_ALL:
return search_get_pb2.Filters.OPERATOR_CONTAINS_ALL
return base_pb2.Filters.OPERATOR_CONTAINS_ALL
elif self == _Operator.WITHIN_GEO_RANGE:
return search_get_pb2.Filters.OPERATOR_WITHIN_GEO_RANGE
return base_pb2.Filters.OPERATOR_WITHIN_GEO_RANGE
elif self == _Operator.AND:
return search_get_pb2.Filters.OPERATOR_AND
return base_pb2.Filters.OPERATOR_AND
elif self == _Operator.OR:
return search_get_pb2.Filters.OPERATOR_OR
return base_pb2.Filters.OPERATOR_OR
else:
raise ValueError(f"Unknown operator {self}")

Expand Down
34 changes: 17 additions & 17 deletions weaviate/collections/filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
)
from weaviate.util import _datetime_to_string
from weaviate.types import TIME
from weaviate.proto.v1 import search_get_pb2
from weaviate.proto.v1 import base_pb2


class _FilterToGRPC:
Expand All @@ -22,11 +22,11 @@ def convert(weav_filter: Literal[None]) -> None:

@overload
@staticmethod
def convert(weav_filter: _Filters) -> search_get_pb2.Filters:
def convert(weav_filter: _Filters) -> base_pb2.Filters:
...

@staticmethod
def convert(weav_filter: Optional[_Filters]) -> Optional[search_get_pb2.Filters]:
def convert(weav_filter: Optional[_Filters]) -> Optional[base_pb2.Filters]:
if weav_filter is None:
return None
if isinstance(weav_filter, _FilterValue):
Expand All @@ -35,8 +35,8 @@ def convert(weav_filter: Optional[_Filters]) -> Optional[search_get_pb2.Filters]
return _FilterToGRPC.__and_or_filter(weav_filter)

@staticmethod
def __value_filter(weav_filter: _FilterValue) -> search_get_pb2.Filters:
return search_get_pb2.Filters(
def __value_filter(weav_filter: _FilterValue) -> base_pb2.Filters:
return base_pb2.Filters(
operator=weav_filter.operator._to_grpc(),
value_text=_FilterToGRPC.__filter_to_text(weav_filter.value),
value_int=weav_filter.value if isinstance(weav_filter.value, int) else None,
Expand All @@ -51,11 +51,11 @@ def __value_filter(weav_filter: _FilterValue) -> search_get_pb2.Filters:
)

@staticmethod
def __filter_to_geo(value: FilterValues) -> Optional[search_get_pb2.GeoCoordinatesFilter]:
def __filter_to_geo(value: FilterValues) -> Optional[base_pb2.GeoCoordinatesFilter]:
if not (isinstance(value, _GeoCoordinateFilter)):
return None

return search_get_pb2.GeoCoordinatesFilter(
return base_pb2.GeoCoordinatesFilter(
latitude=value.latitude, longitude=value.longitude, distance=value.distance
)

Expand All @@ -75,7 +75,7 @@ def __filter_to_text(value: FilterValues) -> Optional[str]:
return _datetime_to_string(value)

@staticmethod
def __filter_to_text_list(value: FilterValues) -> Optional[search_get_pb2.TextArray]:
def __filter_to_text_list(value: FilterValues) -> Optional[base_pb2.TextArray]:
if not isinstance(value, list) or not (
isinstance(value[0], TIME)
or isinstance(value[0], str)
Expand All @@ -91,33 +91,33 @@ def __filter_to_text_list(value: FilterValues) -> Optional[search_get_pb2.TextAr
dates = cast(List[TIME], value)
value_list = [_datetime_to_string(date) for date in dates]

return search_get_pb2.TextArray(values=cast(List[str], value_list))
return base_pb2.TextArray(values=cast(List[str], value_list))

@staticmethod
def __filter_to_bool_list(value: FilterValues) -> Optional[search_get_pb2.BooleanArray]:
def __filter_to_bool_list(value: FilterValues) -> Optional[base_pb2.BooleanArray]:
if not isinstance(value, list) or not isinstance(value[0], bool):
return None

return search_get_pb2.BooleanArray(values=cast(List[bool], value))
return base_pb2.BooleanArray(values=cast(List[bool], value))

@staticmethod
def __filter_to_float_list(value: FilterValues) -> Optional[search_get_pb2.NumberArray]:
def __filter_to_float_list(value: FilterValues) -> Optional[base_pb2.NumberArray]:
if not isinstance(value, list) or not isinstance(value[0], float):
return None

return search_get_pb2.NumberArray(values=cast(List[float], value))
return base_pb2.NumberArray(values=cast(List[float], value))

@staticmethod
def __filter_to_int_list(value: FilterValues) -> Optional[search_get_pb2.IntArray]:
def __filter_to_int_list(value: FilterValues) -> Optional[base_pb2.IntArray]:
if not isinstance(value, list) or not isinstance(value[0], int):
return None

return search_get_pb2.IntArray(values=cast(List[int], value))
return base_pb2.IntArray(values=cast(List[int], value))

@staticmethod
def __and_or_filter(weav_filter: _Filters) -> Optional[search_get_pb2.Filters]:
def __and_or_filter(weav_filter: _Filters) -> Optional[base_pb2.Filters]:
assert isinstance(weav_filter, _FilterAnd) or isinstance(weav_filter, _FilterOr)
return search_get_pb2.Filters(
return base_pb2.Filters(
operator=weav_filter.operator._to_grpc(),
filters=[
filter_
Expand Down
17 changes: 9 additions & 8 deletions weaviate/collections/queries/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
from typing import Any, Dict, Generic, List, Optional, Type, Union, cast

from google.protobuf import struct_pb2
from google.protobuf.internal.containers import RepeatedCompositeFieldContainer
from requests.exceptions import ConnectionError as RequestsConnectionError
from typing_extensions import is_typeddict

Expand Down Expand Up @@ -213,10 +212,12 @@ def __parse_nonref_properties_result(
}

def __parse_ref_properties_result(
self, properties: RepeatedCompositeFieldContainer[search_get_pb2.RefPropertiesResult]
self,
properties: search_get_pb2.PropertiesResult,
) -> Optional[dict]:
if len(properties) == 0:
return None
if len(properties.ref_props) == 0:
return {} if properties.ref_props_requested else None

return {
ref_prop.prop_name: _CrossReference._from(
[
Expand All @@ -226,7 +227,7 @@ def __parse_ref_properties_result(
for prop in ref_prop.properties
]
)
for ref_prop in properties
for ref_prop in properties.ref_props
}

def __deserialize_primitive_122(self, value: Any) -> Any:
Expand Down Expand Up @@ -325,7 +326,7 @@ def __result_to_query_object(
if options.include_metadata
else _MetadataReturn(),
references=(
self.__parse_ref_properties_result(props.ref_props)
self.__parse_ref_properties_result(props)
if self._is_weaviate_version_123
else self.__parse_ref_properties_result_122(props)
)
Expand Down Expand Up @@ -354,7 +355,7 @@ def __result_to_generative_object(
if options.include_metadata
else _MetadataReturn(),
references=(
self.__parse_ref_properties_result(props.ref_props)
self.__parse_ref_properties_result(props)
if self._is_weaviate_version_123
else self.__parse_ref_properties_result_122(props)
)
Expand Down Expand Up @@ -419,7 +420,7 @@ def __result_to_group_by_object(
if options.include_metadata
else _GroupByMetadataReturn(),
references=(
self.__parse_ref_properties_result(props.ref_props)
self.__parse_ref_properties_result(props)
if self._is_weaviate_version_123
else self.__parse_ref_properties_result_122(props)
)
Expand Down
Loading