From 278bde0f116967f379b2f1e8f0bdbaec207133b1 Mon Sep 17 00:00:00 2001 From: Arun Brahma Date: Wed, 11 Dec 2024 04:28:35 +0530 Subject: [PATCH] fix: Add input validation and retry mechanism to HuggingFaceEmbedding (#17207) --- .../embeddings/huggingface/base.py | 122 +++++++++++++----- .../tests/test_embeddings_huggingface.py | 27 ++++ 2 files changed, 116 insertions(+), 33 deletions(-) diff --git a/llama-index-integrations/embeddings/llama-index-embeddings-huggingface/llama_index/embeddings/huggingface/base.py b/llama-index-integrations/embeddings/llama-index-embeddings-huggingface/llama_index/embeddings/huggingface/base.py index 40eb31c41cf5c..1deaaf35fdb7d 100644 --- a/llama-index-integrations/embeddings/llama-index-embeddings-huggingface/llama_index/embeddings/huggingface/base.py +++ b/llama-index-integrations/embeddings/llama-index-embeddings-huggingface/llama_index/embeddings/huggingface/base.py @@ -26,13 +26,15 @@ get_text_instruct_for_model_name, ) from sentence_transformers import SentenceTransformer +from tenacity import retry, stop_after_attempt, wait_exponential DEFAULT_HUGGINGFACE_LENGTH = 512 logger = logging.getLogger(__name__) class HuggingFaceEmbedding(BaseEmbedding): - """HuggingFace class for text embeddings. + """ + HuggingFace class for text embeddings. Args: model_name (str, optional): If it is a filepath on disc, it loads the model from that path. @@ -183,46 +185,95 @@ def __init__( def class_name(cls) -> str: return "HuggingFaceEmbedding" - def _embed( + def _validate_input(self, text: str) -> None: + """ + Validate input text. + + Args: + text: Input text to validate + + Raises: + ValueError: If text is empty + """ + if not text.strip(): + raise ValueError("Input text cannot be empty or whitespace") + + @retry( + stop=stop_after_attempt(3), + wait=wait_exponential(multiplier=1, min=4, max=10), + reraise=True, + ) + def _embed_with_retry( self, sentences: List[str], prompt_name: Optional[str] = None, ) -> List[List[float]]: - """Generates Embeddings either multiprocess or single process. + """ + Generates embeddings with retry mechanism. Args: - sentences (List[str]): Texts or Sentences to embed - prompt_name (Optional[str], optional): The name of the prompt to use for encoding. Must be a key in the `prompts` dictionary i.e. "query" or "text" If ``prompt`` is also set, this argument is ignored. Defaults to None. + sentences: List of texts to embed + prompt_name: Optional prompt type Returns: - List[List[float]]: a 2d numpy array with shape [num_inputs, output_dimension] is returned. 
- If only one string input is provided, then the output is a 1d array with shape [output_dimension] + List of embedding vectors + + Raises: + Exception: If embedding fails after retries """ - if self._parallel_process: - pool = self._model.start_multi_process_pool( - target_devices=self._target_devices - ) - emb = self._model.encode_multi_process( - sentences=sentences, - pool=pool, - batch_size=self.embed_batch_size, - prompt_name=prompt_name, - normalize_embeddings=self.normalize, - ) - self._model.stop_multi_process_pool(pool=pool) + try: + if self._parallel_process: + pool = self._model.start_multi_process_pool( + target_devices=self._target_devices + ) + emb = self._model.encode_multi_process( + sentences=sentences, + pool=pool, + batch_size=self.embed_batch_size, + prompt_name=prompt_name, + normalize_embeddings=self.normalize, + ) + self._model.stop_multi_process_pool(pool=pool) + else: + emb = self._model.encode( + sentences, + batch_size=self.embed_batch_size, + prompt_name=prompt_name, + normalize_embeddings=self.normalize, + ) + return emb.tolist() + except Exception as e: + logger.warning(f"Embedding attempt failed: {e!s}") + raise - else: - emb = self._model.encode( - sentences, - batch_size=self.embed_batch_size, - prompt_name=prompt_name, - normalize_embeddings=self.normalize, - ) + def _embed( + self, + sentences: List[str], + prompt_name: Optional[str] = None, + ) -> List[List[float]]: + """ + Generates Embeddings with input validation and retry mechanism. + + Args: + sentences: Texts or Sentences to embed + prompt_name: The name of the prompt to use for encoding + + Returns: + List of embedding vectors + + Raises: + ValueError: If any input text is invalid + Exception: If embedding fails after retries + """ + # Validate all inputs + for text in sentences: + self._validate_input(text) - return emb.tolist() + return self._embed_with_retry(sentences, prompt_name) def _get_query_embedding(self, query: str) -> List[float]: - """Generates Embeddings for Query. + """ + Generates Embeddings for Query. Args: query (str): Query text/sentence @@ -233,7 +284,8 @@ def _get_query_embedding(self, query: str) -> List[float]: return self._embed(query, prompt_name="query") async def _aget_query_embedding(self, query: str) -> List[float]: - """Generates Embeddings for Query Asynchronously. + """ + Generates Embeddings for Query Asynchronously. Args: query (str): Query text/sentence @@ -244,7 +296,8 @@ async def _aget_query_embedding(self, query: str) -> List[float]: return self._get_query_embedding(query) async def _aget_text_embedding(self, text: str) -> List[float]: - """Generates Embeddings for text Asynchronously. + """ + Generates Embeddings for text Asynchronously. Args: text (str): Text/Sentence @@ -255,7 +308,8 @@ async def _aget_text_embedding(self, text: str) -> List[float]: return self._get_text_embedding(text) def _get_text_embedding(self, text: str) -> List[float]: - """Generates Embeddings for text. + """ + Generates Embeddings for text. Args: text (str): Text/sentences @@ -266,7 +320,8 @@ def _get_text_embedding(self, text: str) -> List[float]: return self._embed(text, prompt_name="text") def _get_text_embeddings(self, texts: List[str]) -> List[List[float]]: - """Generates Embeddings for text. + """ + Generates Embeddings for text. Args: texts (List[str]): Texts / Sentences @@ -353,7 +408,8 @@ def _get_inference_client_kwargs(self) -> Dict[str, Any]: } def __init__(self, **kwargs: Any) -> None: - """Initialize. + """ + Initialize. Args: kwargs: See the class-level Fields. 
diff --git a/llama-index-integrations/embeddings/llama-index-embeddings-huggingface/tests/test_embeddings_huggingface.py b/llama-index-integrations/embeddings/llama-index-embeddings-huggingface/tests/test_embeddings_huggingface.py index ca784f8fdbde6..4363c44a975a7 100644 --- a/llama-index-integrations/embeddings/llama-index-embeddings-huggingface/tests/test_embeddings_huggingface.py +++ b/llama-index-integrations/embeddings/llama-index-embeddings-huggingface/tests/test_embeddings_huggingface.py @@ -3,6 +3,7 @@ HuggingFaceEmbedding, HuggingFaceInferenceAPIEmbedding, ) +import pytest def test_huggingfaceembedding_class(): @@ -15,3 +16,29 @@ def test_huggingfaceapiembedding_class(): b.__name__ for b in HuggingFaceInferenceAPIEmbedding.__mro__ ] assert BaseEmbedding.__name__ in names_of_base_classes + + +def test_input_validation(): + embed_model = HuggingFaceEmbedding(model_name="BAAI/bge-small-en") + + # Test empty input + with pytest.raises(ValueError, match="Input text cannot be empty or whitespace"): + embed_model._validate_input("") + + # Test whitespace input + with pytest.raises(ValueError, match="Input text cannot be empty or whitespace"): + embed_model._validate_input(" ") + + # Test valid input + embed_model._validate_input("This is a valid input") # Should not raise + + +def test_embedding_retry(): + embed_model = HuggingFaceEmbedding(model_name="BAAI/bge-small-en") + + # Test successful embedding + result = embed_model._embed(["This is a test sentence"]) + assert isinstance(result, list) + assert len(result) == 1 + assert isinstance(result[0], list) + assert all(isinstance(x, float) for x in result[0])
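For reference, the retry policy added to _embed_with_retry above is supplied entirely by the tenacity decorator. Below is a minimal, standalone sketch of the same policy applied to a hypothetical flaky_encode function (not part of this patch), showing that the first two failures are retried and the third attempt's result is returned.

# Standalone sketch of the retry policy used in _embed_with_retry; flaky_encode
# and its attempt counter are hypothetical stand-ins for a transient encoding error.
from tenacity import retry, stop_after_attempt, wait_exponential

attempts = {"count": 0}


@retry(
    stop=stop_after_attempt(3),  # at most 3 attempts in total
    wait=wait_exponential(multiplier=1, min=4, max=10),  # wait roughly 4-10s between attempts
    reraise=True,  # surface the original exception, not tenacity's RetryError
)
def flaky_encode(texts):
    attempts["count"] += 1
    if attempts["count"] < 3:
        raise RuntimeError("transient failure")  # attempts 1 and 2 fail
    return [[0.0, 0.0, 0.0] for _ in texts]  # attempt 3 succeeds


print(flaky_encode(["hello"]))  # [[0.0, 0.0, 0.0]] after two retried failures
print(attempts["count"])  # 3

If all three attempts fail, reraise=True propagates the last exception to the caller rather than tenacity's RetryError. Note that the sketch genuinely sleeps between attempts because of the exponential wait, so running it takes several seconds.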
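A rough usage sketch of the new validation seen from the public API (not part of the patch; it assumes the BAAI/bge-small-en weights used by the new tests can be downloaded). It sticks to the batch path, since _get_text_embeddings passes a list into _embed, whereas _get_query_embedding and _get_text_embedding still pass a bare str, which the new "for text in sentences" loop would iterate character by character.

# Rough usage sketch; requires network access to fetch BAAI/bge-small-en.
from llama_index.embeddings.huggingface import HuggingFaceEmbedding

embed_model = HuggingFaceEmbedding(model_name="BAAI/bge-small-en")

# Each item of the batch goes through _validate_input before encoding.
vectors = embed_model.get_text_embedding_batch(["first sentence", "second sentence"])
print(len(vectors), len(vectors[0]))  # 2 vectors of 384 floats for bge-small-en

# A whitespace-only item now fails fast with a clear error instead of being encoded.
try:
    embed_model.get_text_embedding_batch(["   "])
except ValueError as err:
    print(err)  # Input text cannot be empty or whitespace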
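The new test_embedding_retry only covers the success path. A hypothetical extra test (not in this patch) could exercise the retry itself by monkeypatching the model's encode method to fail twice before delegating to the real implementation; it inherits the exponential wait above, so it sleeps several seconds while running.

# Hypothetical addition to the test module (not part of this patch).
from llama_index.embeddings.huggingface import HuggingFaceEmbedding


def test_embed_retries_transient_failures(monkeypatch):
    embed_model = HuggingFaceEmbedding(model_name="BAAI/bge-small-en")
    real_encode = embed_model._model.encode  # already bound to the model
    calls = {"count": 0}

    def flaky_encode(*args, **kwargs):
        calls["count"] += 1
        if calls["count"] < 3:
            raise RuntimeError("transient failure")  # first two calls fail
        return real_encode(*args, **kwargs)  # third call runs the real encoder

    monkeypatch.setattr(embed_model._model, "encode", flaky_encode)

    result = embed_model._embed(["This should succeed on the third attempt"])

    assert calls["count"] == 3
    assert isinstance(result[0], list)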