diff --git a/qdrant_client/async_qdrant_client.py b/qdrant_client/async_qdrant_client.py
index 98fd9404..db81ca42 100644
--- a/qdrant_client/async_qdrant_client.py
+++ b/qdrant_client/async_qdrant_client.py
@@ -392,7 +392,7 @@ async def query_batch_points(
         requests = self._resolve_query_batch_request(requests)
         requires_inference = self._inference_inspector.inspect(requests)
         if requires_inference and (not self.cloud_inference):
-            requests = [self._embed_models(request) for request in requests]
+            requests = self._embed_models(requests)
         return await self._client.query_batch_points(
             collection_name=collection_name,
             requests=requests,
@@ -1501,10 +1501,7 @@ async def upsert(
         )
         requires_inference = self._inference_inspector.inspect(points)
         if requires_inference and (not self.cloud_inference):
-            if isinstance(points, list):
-                points = [self._embed_models(point, is_query=False) for point in points]
-            else:
-                points = self._embed_models(points, is_query=False)
+            points = self._embed_models(points, is_query=False)
         return await self._client.upsert(
             collection_name=collection_name,
             points=points,
@@ -1555,7 +1552,7 @@ async def update_vectors(
         assert len(kwargs) == 0, f"Unknown arguments: {list(kwargs.keys())}"
         requires_inference = self._inference_inspector.inspect(points)
         if requires_inference and (not self.cloud_inference):
-            points = [self._embed_models(point, is_query=False) for point in points]
+            points = self._embed_models(points, is_query=False)
         return await self._client.update_vectors(
             collection_name=collection_name,
             points=points,
@@ -1995,9 +1992,7 @@ async def batch_update_points(
         assert len(kwargs) == 0, f"Unknown arguments: {list(kwargs.keys())}"
         requires_inference = self._inference_inspector.inspect(update_operations)
         if requires_inference and (not self.cloud_inference):
-            update_operations = [
-                self._embed_models(op, is_query=False) for op in update_operations
-            ]
+            update_operations = self._embed_models(update_operations, is_query=False)
         return await self._client.batch_update_points(
             collection_name=collection_name,
             update_operations=update_operations,
diff --git a/qdrant_client/async_qdrant_fastembed.py b/qdrant_client/async_qdrant_fastembed.py
index 31ee96bd..15fb6036 100644
--- a/qdrant_client/async_qdrant_fastembed.py
+++ b/qdrant_client/async_qdrant_fastembed.py
@@ -14,7 +14,6 @@
 from itertools import tee
 from typing import Any, Iterable, Optional, Sequence, Union, get_args
 from copy import deepcopy
-from pathlib import Path
 import numpy as np
 from pydantic import BaseModel
 from qdrant_client.async_client_base import AsyncQdrantBase
@@ -22,7 +21,7 @@
 from qdrant_client.conversions.conversion import GrpcToRest
 from qdrant_client.embed.common import INFERENCE_OBJECT_TYPES
 from qdrant_client.embed.embed_inspector import InspectorEmbed
-from qdrant_client.embed.models import NumericVector, NumericVectorStruct
+from qdrant_client.embed.models import NumericVector
 from qdrant_client.embed.schema_parser import ModelSchemaParser
 from qdrant_client.embed.utils import FieldPath
 from qdrant_client.fastembed_common import QueryResponse
@@ -92,6 +91,8 @@ def __init__(self, parser: ModelSchemaParser, **kwargs: Any):
         self._embedding_model_name: Optional[str] = None
         self._sparse_embedding_model_name: Optional[str] = None
         self._embed_inspector = InspectorEmbed(parser=parser)
+        self._batch_accumulator: dict[str, list[Any]] = {}
+        self._embed_storage: dict[str, list[NumericVector]] = {}

         try:
             from fastembed import SparseTextEmbedding, TextEmbedding
@@ -841,7 +842,39 @@ def _resolve_query_batch_request(
         return [self._resolve_query_request(query) for query in requests]

     def _embed_models(
-        self, model: BaseModel, paths: Optional[list[FieldPath]] = None, is_query: bool = False
+        self, raw_models: Union[BaseModel, Sequence[BaseModel]], is_query: bool = False
+    ) -> Union[BaseModel, NumericVector, list[BaseModel]]:
+        """Embed raw data fields in models and return models with vectors
+
+        If any of the model fields require inference, a deepcopy of the model with computed embeddings is returned,
+        otherwise the original models are returned.
+        Args:
+            raw_models: Union[BaseModel, Sequence[BaseModel]] - models which can contain fields with raw data
+            is_query: bool - flag to determine which embed method to use. Defaults to False.
+
+        Returns:
+            Union[BaseModel, NumericVector, list[BaseModel]]: models with embedded fields
+        """
+        if isinstance(raw_models, list):
+            for raw_model in raw_models:
+                self._embed_model(raw_model, is_query=is_query, accumulating=True)
+        else:
+            self._embed_model(raw_models, is_query=is_query, accumulating=True)
+        if not self._batch_accumulator:
+            return raw_models
+        if isinstance(raw_models, list):
+            return [
+                self._embed_model(raw_model, is_query=is_query, accumulating=False)
+                for raw_model in raw_models
+            ]
+        return self._embed_model(raw_models, is_query=is_query, accumulating=False)
+
+    def _embed_model(
+        self,
+        model: BaseModel,
+        paths: Optional[list[FieldPath]] = None,
+        is_query: bool = False,
+        accumulating: bool = False,
     ) -> Union[BaseModel, NumericVector]:
         """Embed model's fields requiring inference

@@ -849,14 +882,18 @@ def _embed_models(
             model: Qdrant http model containing fields to embed
             paths: Path to fields to embed. E.g. [FieldPath(current="recommend", tail=[FieldPath(current="negative", tail=None)])]
             is_query: Flag to determine which embed method to use. Defaults to False.
+            accumulating: Flag to determine if we are accumulating models for batch embedding. Defaults to False.

         Returns:
             A deepcopy of the method with embedded fields
         """
+        if isinstance(model, INFERENCE_OBJECT_TYPES):
+            if not accumulating:
+                return self._drain_accumulator(model)
+            else:
+                self._accumulate(model)
         if paths is None:
-            if isinstance(model, INFERENCE_OBJECT_TYPES):
-                return self._embed_raw_data(model, is_query=is_query)
-            model = deepcopy(model)
+            model = deepcopy(model) if not accumulating else model
             paths = self._embed_inspector.inspect(model)
         for path in paths:
             list_model = [model] if not isinstance(model, list) else model
@@ -865,142 +902,195 @@ def _embed_models(
                 if current_model is None:
                     continue
                 if path.tail:
-                    self._embed_models(current_model, path.tail, is_query=is_query)
+                    self._embed_model(
+                        current_model, path.tail, is_query=is_query, accumulating=accumulating
+                    )
                 else:
                     was_list = isinstance(current_model, list)
-                    current_model = (
-                        [current_model] if not isinstance(current_model, list) else current_model
-                    )
-                    embeddings = [
-                        self._embed_raw_data(data, is_query=is_query) for data in current_model
-                    ]
-                    if was_list:
-                        setattr(item, path.current, embeddings)
+                    current_model = current_model if was_list else [current_model]
+                    if not accumulating:
+                        embeddings = [self._drain_accumulator(data) for data in current_model]
+                        if was_list:
+                            setattr(item, path.current, embeddings)
+                        else:
+                            setattr(item, path.current, embeddings[0])
                     else:
-                        setattr(item, path.current, embeddings[0])
+                        for data in current_model:
+                            self._accumulate(data)
         return model

-    @staticmethod
-    def _resolve_inference_object(data: models.VectorStruct) -> models.VectorStruct:
-        """Resolve inference object into a model
+    def _accumulate(self, data: models.VectorStruct) -> None:
+        """Add data to batch accumulator

         Args:
-            data: models.VectorStruct - data to resolve, if it's an inference object, convert it to a proper type,
-                otherwise - keep unchanged
+            data: models.VectorStruct - any vector struct data; if there are inference object instances in `data`,
+                add them to the accumulator, otherwise do nothing. `InferenceObject` instances are converted to proper types.

         Returns:
-            models.VectorStruct: resolved data
+            None
         """
-        if not isinstance(data, models.InferenceObject):
-            return data
-        model_name = data.model
-        value = data.object
-        options = data.options
-        if model_name in (
-            *SUPPORTED_EMBEDDING_MODELS.keys(),
-            *SUPPORTED_SPARSE_EMBEDDING_MODELS.keys(),
-            *_LATE_INTERACTION_EMBEDDING_MODELS.keys(),
-        ):
-            return models.Document(model=model_name, text=value, options=options)
-        if model_name in _IMAGE_EMBEDDING_MODELS:
-            return models.Image(model=model_name, image=value, options=options)
-        raise ValueError(f"{model_name} is not among supported models")
+        if isinstance(data, dict):
+            for value in data.values():
+                self._accumulate(value)
+            return None
+        if isinstance(data, list):
+            for value in data:
+                if not isinstance(value, INFERENCE_OBJECT_TYPES):
+                    return None
+                self._accumulate(value)
+        if not isinstance(data, INFERENCE_OBJECT_TYPES):
+            return None
+        data = self._resolve_inference_object(data)
+        if data.model not in self._batch_accumulator:
+            self._batch_accumulator[data.model] = []
+        self._batch_accumulator[data.model].append(data)

-    def _embed_raw_data(
-        self, data: models.VectorStruct, is_query: bool = False
-    ) -> NumericVectorStruct:
-        """Iterates over the data and calls inference on the fields requiring it
+    def _drain_accumulator(self, data: models.VectorStruct) -> models.VectorStruct:
+        """Drain the accumulator and replace inference objects with computed embeddings
+        It is assumed objects are traversed in the same order as they were added to the accumulator

         Args:
-            data: models.VectorStruct - data to embed, if it's not a field which requires inference, leave it as is
-            is_query: Flag to determine which embed method to use. Defaults to False.
+            data: models.VectorStruct - any vector struct data; if there are inference object instances in `data`,
+                replace them with computed embeddings. If embeddings have not been computed yet, compute them and
+                then replace the inference objects.

         Returns:
-            NumericVectorStruct: Embedded data
+            models.VectorStruct: data with replaced inference objects
         """
-        data = self._resolve_inference_object(data)
-        if isinstance(data, models.Document):
-            return self._embed_document(data, is_query=is_query)
-        elif isinstance(data, models.Image):
-            return self._embed_image(data)
-        elif isinstance(data, dict):
-            return {
-                key: self._embed_raw_data(value, is_query=is_query)
-                for (key, value) in data.items()
-            }
-        elif isinstance(data, list):
-            if data and isinstance(data[0], float):
-                return data
-            return [self._embed_raw_data(value, is_query=is_query) for value in data]
-        return data
-
-    def _embed_document(self, document: models.Document, is_query: bool = False) -> NumericVector:
-        """Embed a document using the specified embedding model
+        if isinstance(data, dict):
+            for key, value in data.items():
+                data[key] = self._drain_accumulator(value)
+            return data
+        if isinstance(data, list):
+            for i, value in enumerate(data):
+                if not isinstance(value, INFERENCE_OBJECT_TYPES):
+                    return data
+                data[i] = self._drain_accumulator(value)
+            return data
+        if not isinstance(data, INFERENCE_OBJECT_TYPES):
+            return data
+        if not self._embed_storage or not self._embed_storage.get(data.model, None):
+            self._embed_accumulator()
+        return self._next_embed(data.model)
+
+    def _embed_accumulator(self, is_query: bool = False) -> None:
+        """Embed all accumulated objects for all models

         Args:
-            document: Document to embed
-            is_query: Flag to determine which embed method to use. Defaults to False.
+            is_query: bool - flag to determine which embed method to use. Defaults to False.

         Returns:
-            NumericVector: Document's embedding
-
-        Raises:
-            ValueError: If model is not supported
+            None
         """
-        model_name = document.model
-        text = document.text
-        options = document.options or {}
-        if model_name in SUPPORTED_EMBEDDING_MODELS:
-            embedding_model_inst = self._get_or_init_model(model_name=model_name, **options)
-            if not is_query:
-                embedding = list(embedding_model_inst.embed(documents=[text]))[0].tolist()
-            else:
-                embedding = list(embedding_model_inst.query_embed(query=text))[0].tolist()
-            return embedding
-        elif model_name in SUPPORTED_SPARSE_EMBEDDING_MODELS:
-            sparse_embedding_model_inst = self._get_or_init_sparse_model(
-                model_name=model_name, **options
-            )
-            if not is_query:
-                sparse_embedding = list(sparse_embedding_model_inst.embed(documents=[text]))[0]
-            else:
-                sparse_embedding = list(sparse_embedding_model_inst.query_embed(query=text))[0]
-            return models.SparseVector(
-                indices=sparse_embedding.indices.tolist(), values=sparse_embedding.values.tolist()
-            )
-        elif model_name in _LATE_INTERACTION_EMBEDDING_MODELS:
-            li_embedding_model_inst = self._get_or_init_late_interaction_model(
-                model_name=model_name, **options
-            )
-            if not is_query:
-                embedding = list(li_embedding_model_inst.embed(documents=[text]))[0].tolist()
+        for model_name, objects in self._batch_accumulator.items():
+            if model_name not in (
+                *SUPPORTED_EMBEDDING_MODELS.keys(),
+                *SUPPORTED_SPARSE_EMBEDDING_MODELS.keys(),
+                *_LATE_INTERACTION_EMBEDDING_MODELS.keys(),
+                *_IMAGE_EMBEDDING_MODELS,
+            ):
+                raise ValueError(f"{model_name} is not among supported models")
+            options = next(iter(objects)).options
+            for obj in objects:
+                if options != obj.options:
+                    raise ValueError(
+                        f"Options for {model_name} model should be the same for all objects in one request"
+                    )
+        for model_name, objects in self._batch_accumulator.items():
+            options = next(iter(objects)).options or {}
+            if model_name in SUPPORTED_EMBEDDING_MODELS.keys():
+                texts = [obj.text for obj in objects]
+                embedding_model_inst = self._get_or_init_model(model_name=model_name, **options)
+                if not is_query:
+                    embeddings = [
+                        embedding.tolist()
+                        for embedding in embedding_model_inst.embed(documents=texts)
+                    ]
+                else:
+                    embeddings = [
+                        embedding.tolist()
+                        for embedding in embedding_model_inst.query_embed(query=texts)
+                    ]
+            elif model_name in SUPPORTED_SPARSE_EMBEDDING_MODELS.keys():
+                texts = [obj.text for obj in objects]
+                embedding_model_inst = self._get_or_init_sparse_model(
+                    model_name=model_name, **options
+                )
+                if not is_query:
+                    embeddings = [
+                        models.SparseVector(
+                            indices=sparse_embedding.indices.tolist(),
+                            values=sparse_embedding.values.tolist(),
+                        )
+                        for sparse_embedding in embedding_model_inst.embed(documents=texts)
+                    ]
+                else:
+                    embeddings = [
+                        models.SparseVector(
+                            indices=sparse_embedding.indices.tolist(),
+                            values=sparse_embedding.values.tolist(),
+                        )
+                        for sparse_embedding in embedding_model_inst.query_embed(query=texts)
+                    ]
+            elif model_name in _LATE_INTERACTION_EMBEDDING_MODELS.keys():
+                texts = [obj.text for obj in objects]
+                embedding_model_inst = self._get_or_init_late_interaction_model(
+                    model_name=model_name, **options
+                )
+                if not is_query:
+                    embeddings = [
+                        embedding.tolist()
+                        for embedding in embedding_model_inst.embed(documents=texts)
+                    ]
+                else:
+                    embeddings = [
+                        embedding.tolist()
+                        for embedding in embedding_model_inst.query_embed(query=texts)
+                    ]
             else:
-                embedding = list(li_embedding_model_inst.query_embed(query=text))[0].tolist()
-            return embedding
-        else:
-            raise ValueError(f"{model_name} is not among supported models")
+                images = [obj.image for obj in objects]
+                embedding_model_inst = self._get_or_init_image_model(
+                    model_name=model_name, **options
+                )
+                embeddings = [
+                    embedding.tolist() for embedding in embedding_model_inst.embed(images=images)
+                ]
+            self._embed_storage[model_name] = embeddings
+        self._batch_accumulator.clear()

-    def _embed_image(self, image: models.Image) -> NumericVector:
-        """Embed an image using the specified embedding model
+    def _next_embed(self, model_name: str) -> NumericVector:
+        """Get the next computed embedding from the embedded batch

         Args:
-            image: Image to embed
+            model_name: str - retrieve embedding from the storage by this model name

         Returns:
-            NumericVector: Image's embedding
+            NumericVector: computed embedding
+        """
+        return self._embed_storage[model_name].pop(0)

-        Raises:
-            ValueError: If model is not supported
+    @staticmethod
+    def _resolve_inference_object(data: models.VectorStruct) -> models.VectorStruct:
+        """Resolve inference object into a model
+
+        Args:
+            data: models.VectorStruct - data to resolve, if it's an inference object, convert it to a proper type,
+                otherwise - keep unchanged
+
+        Returns:
+            models.VectorStruct: resolved data
         """
-        model_name = image.model
+        if not isinstance(data, models.InferenceObject):
+            return data
+        model_name = data.model
+        value = data.object
+        options = data.options
+        if model_name in (
+            *SUPPORTED_EMBEDDING_MODELS.keys(),
+            *SUPPORTED_SPARSE_EMBEDDING_MODELS.keys(),
+            *_LATE_INTERACTION_EMBEDDING_MODELS.keys(),
+        ):
+            return models.Document(model=model_name, text=value, options=options)
         if model_name in _IMAGE_EMBEDDING_MODELS:
-            embedding_model_inst = self._get_or_init_image_model(
-                model_name=model_name, **image.options or {}
-            )
-            if not isinstance(image.image, (str, Path, PilImage.Image)):
-                raise ValueError(
-                    f"Unsupported image type: {type(image.image)}. Image: {image.image}"
-                )
-            embedding = list(embedding_model_inst.embed(images=[image.image]))[0].tolist()
-            return embedding
+            return models.Image(model=model_name, image=value, options=options)
         raise ValueError(f"{model_name} is not among supported models")
diff --git a/qdrant_client/qdrant_client.py b/qdrant_client/qdrant_client.py
index 4d3dab4a..798f1be0 100644
--- a/qdrant_client/qdrant_client.py
+++ b/qdrant_client/qdrant_client.py
@@ -449,7 +449,7 @@ def query_batch_points(
         requests = self._resolve_query_batch_request(requests)
         requires_inference = self._inference_inspector.inspect(requests)
         if requires_inference and not self.cloud_inference:
-            requests = [self._embed_models(request) for request in requests]
+            requests = self._embed_models(requests)

         return self._client.query_batch_points(
             collection_name=collection_name,
@@ -1588,10 +1588,7 @@ def upsert(

         requires_inference = self._inference_inspector.inspect(points)
         if requires_inference and not self.cloud_inference:
-            if isinstance(points, list):
-                points = [self._embed_models(point, is_query=False) for point in points]
-            else:
-                points = self._embed_models(points, is_query=False)
+            points = self._embed_models(points, is_query=False)

         return self._client.upsert(
             collection_name=collection_name,
@@ -1644,7 +1641,7 @@ def update_vectors(

         requires_inference = self._inference_inspector.inspect(points)
         if requires_inference and not self.cloud_inference:
-            points = [self._embed_models(point, is_query=False) for point in points]
+            points = self._embed_models(points, is_query=False)

         return self._client.update_vectors(
             collection_name=collection_name,
@@ -2094,9 +2091,7 @@ def batch_update_points(
         assert len(kwargs) == 0, f"Unknown arguments: {list(kwargs.keys())}"
         requires_inference = self._inference_inspector.inspect(update_operations)
         if requires_inference and not self.cloud_inference:
-            update_operations = [
-                self._embed_models(op, is_query=False) for op in update_operations
-            ]
+            update_operations = self._embed_models(update_operations, is_query=False)

         return self._client.batch_update_points(
             collection_name=collection_name,
diff --git a/qdrant_client/qdrant_fastembed.py b/qdrant_client/qdrant_fastembed.py
index c29904c8..8fefbc57 100644
--- a/qdrant_client/qdrant_fastembed.py
+++ b/qdrant_client/qdrant_fastembed.py
@@ -3,7 +3,6 @@
 from itertools import tee
 from typing import Any, Iterable, Optional, Sequence, Union, get_args
 from copy import deepcopy
-from pathlib import Path

 import numpy as np

@@ -14,7 +13,7 @@
 from qdrant_client.conversions.conversion import GrpcToRest
 from qdrant_client.embed.common import INFERENCE_OBJECT_TYPES
 from qdrant_client.embed.embed_inspector import InspectorEmbed
-from qdrant_client.embed.models import NumericVector, NumericVectorStruct
+from qdrant_client.embed.models import NumericVector
 from qdrant_client.embed.schema_parser import ModelSchemaParser
 from qdrant_client.embed.utils import FieldPath
 from qdrant_client.fastembed_common import QueryResponse
@@ -91,6 +90,8 @@ def __init__(self, parser: ModelSchemaParser, **kwargs: Any):
         self._embedding_model_name: Optional[str] = None
         self._sparse_embedding_model_name: Optional[str] = None
         self._embed_inspector = InspectorEmbed(parser=parser)
+        self._batch_accumulator: dict[str, list[Any]] = {}  # lists of inference object types
+        self._embed_storage: dict[str, list[NumericVector]] = {}

         try:
             from fastembed import SparseTextEmbedding, TextEmbedding
@@ -928,10 +929,42 @@ def _resolve_query_batch_request(
         return [self._resolve_query_request(query) for query in requests]

     def _embed_models(
+        self, raw_models: Union[BaseModel, Sequence[BaseModel]], is_query: bool = False
+    ) -> Union[BaseModel, NumericVector, list[BaseModel]]:
+        """Embed raw data fields in models and return models with vectors
+
+        If any of the model fields require inference, a deepcopy of the model with computed embeddings is returned,
+        otherwise the original models are returned.
+        Args:
+            raw_models: Union[BaseModel, Sequence[BaseModel]] - models which can contain fields with raw data
+            is_query: bool - flag to determine which embed method to use. Defaults to False.
+
+        Returns:
+            Union[BaseModel, NumericVector, list[BaseModel]]: models with embedded fields
+        """
+        if isinstance(raw_models, list):
+            for raw_model in raw_models:
+                self._embed_model(raw_model, is_query=is_query, accumulating=True)
+        else:
+            self._embed_model(raw_models, is_query=is_query, accumulating=True)
+
+        if not self._batch_accumulator:
+            return raw_models
+
+        if isinstance(raw_models, list):
+            return [
+                self._embed_model(raw_model, is_query=is_query, accumulating=False)
+                for raw_model in raw_models
+            ]
+
+        return self._embed_model(raw_models, is_query=is_query, accumulating=False)
+
+    def _embed_model(
         self,
         model: BaseModel,
         paths: Optional[list[FieldPath]] = None,
         is_query: bool = False,
+        accumulating: bool = False,
     ) -> Union[BaseModel, NumericVector]:
         """Embed model's fields requiring inference

@@ -939,15 +972,21 @@ def _embed_models(
             model: Qdrant http model containing fields to embed
             paths: Path to fields to embed. E.g. [FieldPath(current="recommend", tail=[FieldPath(current="negative", tail=None)])]
             is_query: Flag to determine which embed method to use. Defaults to False.
+            accumulating: Flag to determine if we are accumulating models for batch embedding. Defaults to False.

         Returns:
             A deepcopy of the method with embedded fields
         """
+        if isinstance(model, INFERENCE_OBJECT_TYPES):
+            if not accumulating:
+                return self._drain_accumulator(model)
+            else:
+                self._accumulate(model)
+
         if paths is None:
-            if isinstance(model, INFERENCE_OBJECT_TYPES):
-                return self._embed_raw_data(model, is_query=is_query)
-            model = deepcopy(model)
+            model = deepcopy(model) if not accumulating else model
             paths = self._embed_inspector.inspect(model)
+
         for path in paths:
             list_model = [model] if not isinstance(model, list) else model
@@ -955,152 +994,213 @@ def _embed_models(
                 if current_model is None:
                     continue
                 if path.tail:
-                    self._embed_models(current_model, path.tail, is_query=is_query)
+                    self._embed_model(
+                        current_model, path.tail, is_query=is_query, accumulating=accumulating
+                    )
                 else:
                     was_list = isinstance(current_model, list)
-                    current_model = (
-                        [current_model] if not isinstance(current_model, list) else current_model
-                    )
-                    embeddings = [
-                        self._embed_raw_data(data, is_query=is_query) for data in current_model
-                    ]
-                    if was_list:
-                        setattr(item, path.current, embeddings)
+                    current_model = current_model if was_list else [current_model]
+
+                    if not accumulating:
+                        embeddings = [self._drain_accumulator(data) for data in current_model]
+                        if was_list:
+                            setattr(item, path.current, embeddings)
+                        else:
+                            setattr(item, path.current, embeddings[0])
                     else:
-                        setattr(item, path.current, embeddings[0])
+                        for data in current_model:
+                            self._accumulate(data)
         return model

-    @staticmethod
-    def _resolve_inference_object(data: models.VectorStruct) -> models.VectorStruct:
-        """Resolve inference object into a model
+    def _accumulate(self, data: models.VectorStruct) -> None:
+        """Add data to batch accumulator

         Args:
-            data: models.VectorStruct - data to resolve, if it's an inference object, convert it to a proper type,
-                otherwise - keep unchanged
+            data: models.VectorStruct - any vector struct data; if there are inference object instances in `data`,
+                add them to the accumulator, otherwise do nothing. `InferenceObject` instances are converted to proper types.

         Returns:
-            models.VectorStruct: resolved data
+            None
         """
+        if isinstance(data, dict):
+            for value in data.values():
+                self._accumulate(value)
+            return None

-        if not isinstance(data, models.InferenceObject):
-            return data
+        if isinstance(data, list):
+            for value in data:
+                if not isinstance(value, INFERENCE_OBJECT_TYPES):  # if value is a vector
+                    return None
+                self._accumulate(value)

-        model_name = data.model
-        value = data.object
-        options = data.options
-        if model_name in (
-            *SUPPORTED_EMBEDDING_MODELS.keys(),
-            *SUPPORTED_SPARSE_EMBEDDING_MODELS.keys(),
-            *_LATE_INTERACTION_EMBEDDING_MODELS.keys(),
-        ):
-            return models.Document(model=model_name, text=value, options=options)
-        if model_name in _IMAGE_EMBEDDING_MODELS:
-            return models.Image(model=model_name, image=value, options=options)
+        if not isinstance(data, INFERENCE_OBJECT_TYPES):
+            return None

-        raise ValueError(f"{model_name} is not among supported models")
+        data = self._resolve_inference_object(data)
+        if data.model not in self._batch_accumulator:
+            self._batch_accumulator[data.model] = []
+        self._batch_accumulator[data.model].append(data)

-    def _embed_raw_data(
-        self,
-        data: models.VectorStruct,
-        is_query: bool = False,
-    ) -> NumericVectorStruct:
-        """Iterates over the data and calls inference on the fields requiring it
+    def _drain_accumulator(self, data: models.VectorStruct) -> models.VectorStruct:
+        """Drain the accumulator and replace inference objects with computed embeddings
+        It is assumed objects are traversed in the same order as they were added to the accumulator

         Args:
-            data: models.VectorStruct - data to embed, if it's not a field which requires inference, leave it as is
-            is_query: Flag to determine which embed method to use. Defaults to False.
+            data: models.VectorStruct - any vector struct data; if there are inference object instances in `data`,
+                replace them with computed embeddings. If embeddings have not been computed yet, compute them and
+                then replace the inference objects.

         Returns:
-            NumericVectorStruct: Embedded data
+            models.VectorStruct: data with replaced inference objects
         """
-        data = self._resolve_inference_object(data)
+        if isinstance(data, dict):
+            for key, value in data.items():
+                data[key] = self._drain_accumulator(value)
+            return data
+
+        if isinstance(data, list):
+            for i, value in enumerate(data):
+                if not isinstance(value, INFERENCE_OBJECT_TYPES):  # if value is vector
+                    return data
+
+                data[i] = self._drain_accumulator(value)
+            return data
+
+        if not isinstance(data, INFERENCE_OBJECT_TYPES):
+            return data
+
+        if not self._embed_storage or not self._embed_storage.get(data.model, None):
+            self._embed_accumulator()

-        if isinstance(data, models.Document):
-            return self._embed_document(data, is_query=is_query)
-        elif isinstance(data, models.Image):
-            return self._embed_image(data)
-        elif isinstance(data, dict):
-            return {
-                key: self._embed_raw_data(value, is_query=is_query) for key, value in data.items()
-            }
-        elif isinstance(data, list):
-            # we don't want to iterate over a vector
-            if data and isinstance(data[0], float):
-                return data
-            return [self._embed_raw_data(value, is_query=is_query) for value in data]
-        return data
-
-    def _embed_document(self, document: models.Document, is_query: bool = False) -> NumericVector:
-        """Embed a document using the specified embedding model
+        return self._next_embed(data.model)
+
+    def _embed_accumulator(self, is_query: bool = False) -> None:
+        """Embed all accumulated objects for all models

         Args:
-            document: Document to embed
-            is_query: Flag to determine which embed method to use. Defaults to False.
+            is_query: bool - flag to determine which embed method to use. Defaults to False.

         Returns:
-            NumericVector: Document's embedding
-
-        Raises:
-            ValueError: If model is not supported
+            None
         """
-        model_name = document.model
-        text = document.text
-        options = document.options or {}
-        if model_name in SUPPORTED_EMBEDDING_MODELS:
-            embedding_model_inst = self._get_or_init_model(model_name=model_name, **options)
-            if not is_query:
-                embedding = list(embedding_model_inst.embed(documents=[text]))[0].tolist()
-            else:
-                embedding = list(embedding_model_inst.query_embed(query=text))[0].tolist()
-            return embedding
-        elif model_name in SUPPORTED_SPARSE_EMBEDDING_MODELS:
-            sparse_embedding_model_inst = self._get_or_init_sparse_model(
-                model_name=model_name, **options
-            )
-            if not is_query:
-                sparse_embedding = list(sparse_embedding_model_inst.embed(documents=[text]))[0]
-            else:
-                sparse_embedding = list(sparse_embedding_model_inst.query_embed(query=text))[0]
+        for model_name, objects in self._batch_accumulator.items():
+            if model_name not in (
+                *SUPPORTED_EMBEDDING_MODELS.keys(),
+                *SUPPORTED_SPARSE_EMBEDDING_MODELS.keys(),
+                *_LATE_INTERACTION_EMBEDDING_MODELS.keys(),
+                *_IMAGE_EMBEDDING_MODELS,
+            ):
+                raise ValueError(f"{model_name} is not among supported models")
+
+            options = next(iter(objects)).options
+            for obj in objects:
+                if options != obj.options:
+                    raise ValueError(
+                        f"Options for {model_name} model should be the same for all objects in one request"
+                    )

-            return models.SparseVector(
-                indices=sparse_embedding.indices.tolist(), values=sparse_embedding.values.tolist()
-            )
-        elif model_name in _LATE_INTERACTION_EMBEDDING_MODELS:
-            li_embedding_model_inst = self._get_or_init_late_interaction_model(
-                model_name=model_name, **options
-            )
-            if not is_query:
-                embedding = list(li_embedding_model_inst.embed(documents=[text]))[0].tolist()
+        for model_name, objects in self._batch_accumulator.items():
+            options = next(iter(objects)).options or {}
+
+            if model_name in SUPPORTED_EMBEDDING_MODELS.keys():
+                texts = [obj.text for obj in objects]
+                embedding_model_inst = self._get_or_init_model(model_name=model_name, **options)
+                if not is_query:
+                    embeddings = [
+                        embedding.tolist()
+                        for embedding in embedding_model_inst.embed(documents=texts)
+                    ]
+                else:
+                    embeddings = [
+                        embedding.tolist()
+                        for embedding in embedding_model_inst.query_embed(query=texts)
+                    ]
+
+            elif model_name in SUPPORTED_SPARSE_EMBEDDING_MODELS.keys():
+                texts = [obj.text for obj in objects]
+                embedding_model_inst = self._get_or_init_sparse_model(
+                    model_name=model_name, **options
+                )
+                if not is_query:
+                    embeddings = [
+                        models.SparseVector(
+                            indices=sparse_embedding.indices.tolist(),
+                            values=sparse_embedding.values.tolist(),
+                        )
+                        for sparse_embedding in embedding_model_inst.embed(documents=texts)
+                    ]
+                else:
+                    embeddings = [
+                        models.SparseVector(
+                            indices=sparse_embedding.indices.tolist(),
+                            values=sparse_embedding.values.tolist(),
+                        )
+                        for sparse_embedding in embedding_model_inst.query_embed(query=texts)
+                    ]
+
+            elif model_name in _LATE_INTERACTION_EMBEDDING_MODELS.keys():
+                texts = [obj.text for obj in objects]
+                embedding_model_inst = self._get_or_init_late_interaction_model(
+                    model_name=model_name, **options
+                )
+                if not is_query:
+                    embeddings = [
+                        embedding.tolist()
+                        for embedding in embedding_model_inst.embed(documents=texts)
+                    ]
+                else:
+                    embeddings = [
+                        embedding.tolist()
+                        for embedding in embedding_model_inst.query_embed(query=texts)
+                    ]
             else:
-                embedding = list(li_embedding_model_inst.query_embed(query=text))[0].tolist()
-            return embedding
-        else:
-            raise ValueError(f"{model_name} is not among supported models")
+                images = [obj.image for obj in objects]
+                embedding_model_inst = self._get_or_init_image_model(
+                    model_name=model_name, **options
+                )
+                embeddings = [
+                    embedding.tolist() for embedding in embedding_model_inst.embed(images=images)
+                ]
+
+            self._embed_storage[model_name] = embeddings
+        self._batch_accumulator.clear()

-    def _embed_image(self, image: models.Image) -> NumericVector:
-        """Embed an image using the specified embedding model
+    def _next_embed(self, model_name: str) -> NumericVector:
+        """Get the next computed embedding from the embedded batch

         Args:
-            image: Image to embed
+            model_name: str - retrieve embedding from the storage by this model name

         Returns:
-            NumericVector: Image's embedding
+            NumericVector: computed embedding
+        """
+        return self._embed_storage[model_name].pop(0)

-        Raises:
-            ValueError: If model is not supported
+    @staticmethod
+    def _resolve_inference_object(data: models.VectorStruct) -> models.VectorStruct:
+        """Resolve inference object into a model
+
+        Args:
+            data: models.VectorStruct - data to resolve, if it's an inference object, convert it to a proper type,
+                otherwise - keep unchanged
+
+        Returns:
+            models.VectorStruct: resolved data
         """
-        model_name = image.model
+
+        if not isinstance(data, models.InferenceObject):
+            return data
+
+        model_name = data.model
+        value = data.object
+        options = data.options
+        if model_name in (
+            *SUPPORTED_EMBEDDING_MODELS.keys(),
+            *SUPPORTED_SPARSE_EMBEDDING_MODELS.keys(),
+            *_LATE_INTERACTION_EMBEDDING_MODELS.keys(),
+        ):
+            return models.Document(model=model_name, text=value, options=options)
         if model_name in _IMAGE_EMBEDDING_MODELS:
-            embedding_model_inst = self._get_or_init_image_model(
-                model_name=model_name, **(image.options or {})
-            )
-            if not isinstance(image.image, (str, Path, PilImage.Image)):  # type: ignore
-                # PilImage is None if PIL is not installed,
-                # but we'll fail earlier if it's not installed.
-                raise ValueError(
-                    f"Unsupported image type: {type(image.image)}. Image: {image.image}"
-                )
-            embedding = list(embedding_model_inst.embed(images=[image.image]))[0].tolist()
-            return embedding
+            return models.Image(model=model_name, image=value, options=options)
         raise ValueError(f"{model_name} is not among supported models")
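
A short usage sketch of the batched local-inference path introduced above. All Document (or Image) objects that reference the same model are now accumulated per request and embedded in a single fastembed batch instead of one inference call per object. The snippet is a minimal, hedged example: the URL, collection name, and model name are illustrative, and it assumes fastembed is installed and that the collection already exists with a vector configuration matching the chosen model.

# Minimal sketch, assuming a local Qdrant at localhost:6333, fastembed installed,
# and a collection named "docs" whose vector size matches the model below.
from qdrant_client import QdrantClient, models

client = QdrantClient(url="http://localhost:6333")
model_name = "sentence-transformers/all-MiniLM-L6-v2"  # illustrative fastembed-supported model

# Every Document below references the same model, so local inference runs as one batch
# (objects are accumulated on the first pass, computed embeddings drained on the second)
# before the upsert request is sent.
client.upsert(
    collection_name="docs",
    points=[
        models.PointStruct(
            id=i,
            vector=models.Document(text=text, model=model_name),
            payload={"text": text},
        )
        for i, text in enumerate(["first passage", "second passage", "third passage"])
    ],
)

# Query-side inference goes through the same accumulator before the request is sent.
hits = client.query_points(
    collection_name="docs",
    query=models.Document(text="passage about batching", model=model_name),
    limit=3,
)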