Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Add input validation and retry mechanism to HuggingFaceEmbedding #17207

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -26,13 +26,15 @@
get_text_instruct_for_model_name,
)
from sentence_transformers import SentenceTransformer
from tenacity import retry, stop_after_attempt, wait_exponential

DEFAULT_HUGGINGFACE_LENGTH = 512
logger = logging.getLogger(__name__)


class HuggingFaceEmbedding(BaseEmbedding):
"""HuggingFace class for text embeddings.
"""
HuggingFace class for text embeddings.

Args:
model_name (str, optional): If it is a filepath on disc, it loads the model from that path.
Expand Down Expand Up @@ -183,46 +185,95 @@ def __init__(
def class_name(cls) -> str:
return "HuggingFaceEmbedding"

def _embed(
def _validate_input(self, text: str) -> None:
"""
Validate input text.

Args:
text: Input text to validate

Raises:
ValueError: If text is empty
"""
if not text.strip():
raise ValueError("Input text cannot be empty or whitespace")

@retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=4, max=10),
reraise=True,
)
def _embed_with_retry(
self,
sentences: List[str],
prompt_name: Optional[str] = None,
) -> List[List[float]]:
"""Generates Embeddings either multiprocess or single process.
"""
Generates embeddings with retry mechanism.

Args:
sentences (List[str]): Texts or Sentences to embed
prompt_name (Optional[str], optional): The name of the prompt to use for encoding. Must be a key in the `prompts` dictionary i.e. "query" or "text" If ``prompt`` is also set, this argument is ignored. Defaults to None.
sentences: List of texts to embed
prompt_name: Optional prompt type

Returns:
List[List[float]]: a 2d numpy array with shape [num_inputs, output_dimension] is returned.
If only one string input is provided, then the output is a 1d array with shape [output_dimension]
List of embedding vectors

Raises:
Exception: If embedding fails after retries
"""
if self._parallel_process:
pool = self._model.start_multi_process_pool(
target_devices=self._target_devices
)
emb = self._model.encode_multi_process(
sentences=sentences,
pool=pool,
batch_size=self.embed_batch_size,
prompt_name=prompt_name,
normalize_embeddings=self.normalize,
)
self._model.stop_multi_process_pool(pool=pool)
try:
if self._parallel_process:
pool = self._model.start_multi_process_pool(
target_devices=self._target_devices
)
emb = self._model.encode_multi_process(
sentences=sentences,
pool=pool,
batch_size=self.embed_batch_size,
prompt_name=prompt_name,
normalize_embeddings=self.normalize,
)
self._model.stop_multi_process_pool(pool=pool)
else:
emb = self._model.encode(
sentences,
batch_size=self.embed_batch_size,
prompt_name=prompt_name,
normalize_embeddings=self.normalize,
)
return emb.tolist()
except Exception as e:
logger.warning(f"Embedding attempt failed: {e!s}")
raise

else:
emb = self._model.encode(
sentences,
batch_size=self.embed_batch_size,
prompt_name=prompt_name,
normalize_embeddings=self.normalize,
)
def _embed(
self,
sentences: List[str],
prompt_name: Optional[str] = None,
) -> List[List[float]]:
"""
Generates Embeddings with input validation and retry mechanism.

Args:
sentences: Texts or Sentences to embed
prompt_name: The name of the prompt to use for encoding

Returns:
List of embedding vectors

Raises:
ValueError: If any input text is invalid
Exception: If embedding fails after retries
"""
# Validate all inputs
for text in sentences:
self._validate_input(text)

return emb.tolist()
return self._embed_with_retry(sentences, prompt_name)

def _get_query_embedding(self, query: str) -> List[float]:
"""Generates Embeddings for Query.
"""
Generates Embeddings for Query.

Args:
query (str): Query text/sentence
Expand All @@ -233,7 +284,8 @@ def _get_query_embedding(self, query: str) -> List[float]:
return self._embed(query, prompt_name="query")

async def _aget_query_embedding(self, query: str) -> List[float]:
"""Generates Embeddings for Query Asynchronously.
"""
Generates Embeddings for Query Asynchronously.

Args:
query (str): Query text/sentence
Expand All @@ -244,7 +296,8 @@ async def _aget_query_embedding(self, query: str) -> List[float]:
return self._get_query_embedding(query)

async def _aget_text_embedding(self, text: str) -> List[float]:
"""Generates Embeddings for text Asynchronously.
"""
Generates Embeddings for text Asynchronously.

Args:
text (str): Text/Sentence
Expand All @@ -255,7 +308,8 @@ async def _aget_text_embedding(self, text: str) -> List[float]:
return self._get_text_embedding(text)

def _get_text_embedding(self, text: str) -> List[float]:
"""Generates Embeddings for text.
"""
Generates Embeddings for text.

Args:
text (str): Text/sentences
Expand All @@ -266,7 +320,8 @@ def _get_text_embedding(self, text: str) -> List[float]:
return self._embed(text, prompt_name="text")

def _get_text_embeddings(self, texts: List[str]) -> List[List[float]]:
"""Generates Embeddings for text.
"""
Generates Embeddings for text.

Args:
texts (List[str]): Texts / Sentences
Expand Down Expand Up @@ -353,7 +408,8 @@ def _get_inference_client_kwargs(self) -> Dict[str, Any]:
}

def __init__(self, **kwargs: Any) -> None:
"""Initialize.
"""
Initialize.

Args:
kwargs: See the class-level Fields.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
HuggingFaceEmbedding,
HuggingFaceInferenceAPIEmbedding,
)
import pytest


def test_huggingfaceembedding_class():
Expand All @@ -15,3 +16,29 @@ def test_huggingfaceapiembedding_class():
b.__name__ for b in HuggingFaceInferenceAPIEmbedding.__mro__
]
assert BaseEmbedding.__name__ in names_of_base_classes


def test_input_validation():
embed_model = HuggingFaceEmbedding(model_name="BAAI/bge-small-en")

# Test empty input
with pytest.raises(ValueError, match="Input text cannot be empty or whitespace"):
embed_model._validate_input("")

# Test whitespace input
with pytest.raises(ValueError, match="Input text cannot be empty or whitespace"):
embed_model._validate_input(" ")

# Test valid input
embed_model._validate_input("This is a valid input") # Should not raise


def test_embedding_retry():
embed_model = HuggingFaceEmbedding(model_name="BAAI/bge-small-en")

# Test successful embedding
result = embed_model._embed(["This is a test sentence"])
assert isinstance(result, list)
assert len(result) == 1
assert isinstance(result[0], list)
assert all(isinstance(x, float) for x in result[0])
Loading