Skip to content

Commit

Permalink
- qdrant remote
Browse files Browse the repository at this point in the history
- qdrant base
- qdrant client
- conversions
  • Loading branch information
coszio committed Sep 16, 2024
1 parent 3a736a1 commit da8ce9f
Show file tree
Hide file tree
Showing 7 changed files with 215 additions and 9 deletions.
11 changes: 11 additions & 0 deletions qdrant_client/client_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,17 @@ def count(
) -> types.CountResult:
raise NotImplementedError()

def facet(
self,
collection_name: str,
key: str,
facet_filter: Optional[types.Filter] = None,
limit: int = 10,
exact: bool = False,
**kwargs: Any,
) -> types.FacetResponse:
raise NotImplementedError()

def upsert(
self,
collection_name: str,
Expand Down
2 changes: 2 additions & 0 deletions qdrant_client/conversions/common_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,8 @@ def get_args_subscribed(tp): # type: ignore
GroupsResult: TypeAlias = rest.GroupsResult
QueryResponse: TypeAlias = rest.QueryResponse

FacetResponse: TypeAlias = rest.FacetResponse

VersionInfo: TypeAlias = rest.VersionInfo

# we can't use `nptyping` package due to numpy/python-version incompatibilities
Expand Down
62 changes: 60 additions & 2 deletions qdrant_client/conversions/conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,8 @@ class GrpcToRest:
@classmethod
def convert_condition(cls, model: grpc.Condition) -> rest.Condition:
name = model.WhichOneof("condition_one_of")
if name is None:
raise ValueError(f"invalid Condition model: {model}")
val = getattr(model, name)

if name == "field":
Expand Down Expand Up @@ -487,6 +489,8 @@ def convert_create_alias(cls, model: grpc.CreateAlias) -> rest.CreateAlias:
@classmethod
def convert_order_value(cls, model: grpc.OrderValue) -> rest.OrderValue:
name = model.WhichOneof("variant")
if name is None:
raise ValueError(f"invalid OrderValue model: {model}")
val = getattr(model, name)

if name == "int":
Expand Down Expand Up @@ -582,6 +586,8 @@ def convert_field_condition(cls, model: grpc.FieldCondition) -> rest.FieldCondit
@classmethod
def convert_match(cls, model: grpc.Match) -> rest.Match:
name = model.WhichOneof("match_value")
if name is None:
raise ValueError(f"invalid Match model: {model}")
val = getattr(model, name)

if name == "integer":
Expand Down Expand Up @@ -712,6 +718,8 @@ def convert_geo_point(cls, model: grpc.GeoPoint) -> rest.GeoPoint:
@classmethod
def convert_alias_operations(cls, model: grpc.AliasOperations) -> rest.AliasOperations:
name = model.WhichOneof("action")
if name is None:
raise ValueError(f"invalid AliasOperations model: {model}")
val = getattr(model, name)

if name == "rename_alias":
Expand All @@ -735,6 +743,8 @@ def convert_points_selector(
cls, model: grpc.PointsSelector, shard_key_selector: Optional[grpc.ShardKeySelector] = None
) -> rest.PointsSelector:
name = model.WhichOneof("points_selector_one_of")
if name is None:
raise ValueError(f"invalid PointsSelector model: {model}")
val = getattr(model, name)

if name == "points":
Expand All @@ -754,6 +764,8 @@ def convert_with_payload_selector(
cls, model: grpc.WithPayloadSelector
) -> rest.WithPayloadInterface:
name = model.WhichOneof("selector_options")
if name is None:
raise ValueError(f"invalid WithPayloadSelector model: {model}")
val = getattr(model, name)

if name == "enable":
Expand Down Expand Up @@ -862,6 +874,8 @@ def convert_multivector_comparator(
@classmethod
def convert_vectors_config(cls, model: grpc.VectorsConfig) -> rest.VectorsConfig:
name = model.WhichOneof("config")
if name is None:
raise ValueError(f"invalid VectorsConfig model: {model}")
val = getattr(model, name)

if name == "params":
Expand Down Expand Up @@ -896,7 +910,10 @@ def convert_named_vectors(cls, model: grpc.NamedVectors) -> Dict[str, rest.Vecto
@classmethod
def convert_vectors(cls, model: grpc.Vectors) -> rest.VectorStruct:
name = model.WhichOneof("vectors_options")
if name is None:
raise ValueError(f"invalid Vectors model: {model}")
val = getattr(model, name)

if name == "vector":
return cls.convert_vector(val)
if name == "vectors":
Expand All @@ -918,6 +935,8 @@ def convert_multi_dense_vector(cls, model: grpc.MultiDenseVector) -> List[List[f
@classmethod
def convert_vector_input(cls, model: grpc.VectorInput) -> rest.VectorInput:
name = model.WhichOneof("variant")
if name is None:
raise ValueError(f"invalid VectorInput model: {model}")
val = getattr(model, name)

if name == "id":
Expand Down Expand Up @@ -978,6 +997,8 @@ def convert_sample(cls, model: grpc.Sample) -> rest.Sample:
@classmethod
def convert_query(cls, model: grpc.Query) -> rest.Query:
name = model.WhichOneof("variant")
if name is None:
raise ValueError(f"invalid Query model: {model}")
val = getattr(model, name)

if name == "nearest":
Expand Down Expand Up @@ -1027,7 +1048,10 @@ def convert_vectors_selector(cls, model: grpc.VectorsSelector) -> List[str]:
@classmethod
def convert_with_vectors_selector(cls, model: grpc.WithVectorsSelector) -> rest.WithVector:
name = model.WhichOneof("selector_options")
if name is None:
raise ValueError(f"invalid WithVectorsSelector model: {model}")
val = getattr(model, name)

if name == "enable":
return val
if name == "include":
Expand Down Expand Up @@ -1319,6 +1343,8 @@ def convert_write_ordering(cls, model: grpc.WriteOrdering) -> rest.WriteOrdering
@classmethod
def convert_read_consistency(cls, model: grpc.ReadConsistency) -> rest.ReadConsistency:
name = model.WhichOneof("value")
if name is None:
raise ValueError(f"invalid ReadConsistency model: {model}")
val = getattr(model, name)
if name == "factor":
return val
Expand Down Expand Up @@ -1384,6 +1410,8 @@ def convert_quantization_config(
cls, model: grpc.QuantizationConfig
) -> rest.QuantizationConfig:
name = model.WhichOneof("quantization")
if name is None:
raise ValueError(f"invalid QuantizationConfig model: {model}")
val = getattr(model, name)
if name == "scalar":
return rest.ScalarQuantization(scalar=cls.convert_scalar_quantization_config(val))
Expand Down Expand Up @@ -1427,6 +1455,8 @@ def convert_point_group(cls, model: grpc.PointGroup) -> rest.PointGroup:
@classmethod
def convert_group_id(cls, model: grpc.GroupId) -> rest.GroupId:
name = model.WhichOneof("kind")
if name is None:
raise ValueError(f"invalid GroupId model: {model}")
val = getattr(model, name)
return val

Expand All @@ -1451,6 +1481,8 @@ def convert_quantization_config_diff(
cls, model: grpc.QuantizationConfigDiff
) -> rest.QuantizationConfigDiff:
name = model.WhichOneof("quantization")
if name is None:
raise ValueError(f"invalid QuantizationConfigDiff model: {model}")
val = getattr(model, name)
if name == "scalar":
return rest.ScalarQuantization(scalar=cls.convert_scalar_quantization_config(val))
Expand Down Expand Up @@ -1481,6 +1513,8 @@ def convert_vector_params_diff(cls, model: grpc.VectorParamsDiff) -> rest.Vector
@classmethod
def convert_vectors_config_diff(cls, model: grpc.VectorsConfigDiff) -> rest.VectorsConfigDiff:
name = model.WhichOneof("config")
if name is None:
raise ValueError(f"invalid VectorsConfigDiff model: {model}")
val = getattr(model, name)

if name == "params":
Expand All @@ -1497,6 +1531,8 @@ def convert_points_update_operation(
cls, model: grpc.PointsUpdateOperation
) -> rest.UpdateOperation:
name = model.WhichOneof("operation")
if name is None:
raise ValueError(f"invalid PointsUpdateOperation model: {model}")
val = getattr(model, name)

if name == "upsert":
Expand Down Expand Up @@ -1698,6 +1734,8 @@ def convert_sparse_vector_config(
@classmethod
def convert_shard_key(cls, model: grpc.ShardKey) -> rest.ShardKey:
name = model.WhichOneof("key")
if name is None:
raise ValueError(f"invalid ShardKey model: {model}")
val = getattr(model, name)
return val

Expand Down Expand Up @@ -1747,6 +1785,28 @@ def convert_order_by(cls, model: grpc.OrderBy) -> rest.OrderBy:
),
)

@classmethod
def convert_facet_value(cls, model: grpc.FacetValue) -> rest.FacetValue:
name = model.WhichOneof("variant")
if name is None:
raise ValueError(f"invalid FacetValue model: {model}")

val = getattr(model, name)
return val

@classmethod
def convert_facet_value_hit(cls, model: grpc.FacetHit) -> rest.FacetValueHit:
return rest.FacetValueHit(
value=cls.convert_facet_value(model.value),
count=model.count,
)

@classmethod
def convert_facet_response(cls, model: grpc.FacetResponse) -> rest.FacetResponse:
return rest.FacetResponse(
hits=[cls.convert_facet_value_hit(hit) for hit in model.hits],
)

@classmethod
def convert_health_check_reply(cls, model: grpc.HealthCheckReply) -> rest.VersionInfo:
return rest.VersionInfo(
Expand Down Expand Up @@ -3272,8 +3332,6 @@ def convert_update_operation(cls, model: rest.UpdateOperation) -> grpc.PointsUpd
def convert_points_update_operation(
cls, model: rest.UpdateOperation
) -> grpc.PointsUpdateOperation:
points_selector: rest.PointsSelector

if isinstance(model, rest.UpsertOperation):
shard_key_selector = (
cls.convert_shard_key_selector(model.upsert.shard_key)
Expand Down
25 changes: 19 additions & 6 deletions qdrant_client/local/local_collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -1085,7 +1085,12 @@ def search_groups(

return models.GroupsResult(groups=groups_result)

def facet(self, key: str, facet_filter: Optional[types.Filter] = None) -> models.FacetResult:
def facet(
self,
key: str,
facet_filter: Optional[types.Filter] = None,
limit: int = 10,
) -> models.FacetResponse:
facet_hits = defaultdict(int)

mask = self._payload_and_non_deleted_mask(facet_filter)
Expand All @@ -1098,16 +1103,24 @@ def facet(self, key: str, facet_filter: Optional[types.Filter] = None) -> models
continue

value = value_by_key(payload, key)

if value is None:
continue

if isinstance(value, list):
for v in value:
for v in value:
if isinstance(v, get_args(models.FacetValue)):
facet_hits[v] += 1
else:
facet_hits[value] += 1

return models.FacetResult(hits=facet_hits)
hits = [
models.FacetValueHit(value=value, count=count)
for value, count in sorted(
facet_hits.items(),
# order by count descending, then by value ascending
key=lambda x: (-x[1], x[0]),
)[:limit]
]

return models.FacetResponse(hits=hits)

def retrieve(
self,
Expand Down
12 changes: 12 additions & 0 deletions qdrant_client/local/qdrant_local.py
Original file line number Diff line number Diff line change
Expand Up @@ -718,6 +718,18 @@ def count(
collection = self._get_collection(collection_name)
return collection.count(count_filter=count_filter)

def facet(
self,
collection_name: str,
key: str,
facet_filter: Optional[types.Filter] = None,
limit: int = 10,
exact: bool = False,
**kwargs: Any,
):
collection = self._get_collection(collection_name)
return collection.facet(key=key, facet_filter=facet_filter, limit=limit)

def upsert(
self, collection_name: str, points: types.Points, **kwargs: Any
) -> types.UpdateResult:
Expand Down
51 changes: 51 additions & 0 deletions qdrant_client/qdrant_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -1324,6 +1324,57 @@ def count(
**kwargs,
)

def facet(
self,
collection_name: str,
key: str,
facet_filter: Optional[types.Filter] = None,
limit: int = 10,
exact: bool = False,
read_consistency: Optional[types.ReadConsistency] = None,
timeout: Optional[int] = None,
shard_key_selector: Optional[types.ShardKeySelector] = None,
**kwargs: Any,
) -> types.FacetResponse:
"""Facet counts for the collection. For a specific payload key, returns unique values along with their counts.
Higher counts come first in the results.
Args:
collection_name: Name of the collection
key: Payload field to facet
facet_filter: Filter to apply
limit: Maximum number of hits to return
exact: If `True` - provide the exact count of points matching the filter. If `False` - provide the approximate count of points matching the filter. Works faster.
read_consistency:
Read consistency of the search. Defines how many replicas should be queried before returning the result. Values:
- int - number of replicas to query, values should present in all queried replicas
- 'majority' - query all replicas, but return values present in the majority of replicas
- 'quorum' - query the majority of replicas, return values present in all of them
- 'all' - query all replicas, and return values present in all replicas
timeout: Overrides global timeout for this search. Unit is seconds.
shard_key_selector:
This parameter allows to specify which shards should be queried.
If `None` - query all shards. Only works for collections with `custom` sharding method.
Returns:
Unique values in the facet and the amount of points that they cover.
"""
assert len(kwargs) == 0, f"Unknown arguments: {list(kwargs.keys())}"

return self._client.facet(
collection_name=collection_name,
key=key,
facet_filter=facet_filter,
limit=limit,
exact=exact,
read_consistency=read_consistency,
timeout=timeout,
shard_key_selector=shard_key_selector,
**kwargs,
)

def upsert(
self,
collection_name: str,
Expand Down
Loading

0 comments on commit da8ce9f

Please sign in to comment.