Skip to content

Commit

Permalink
fix: Add input validation and retry mechanism to HuggingFaceEmbedding (
Browse files Browse the repository at this point in the history
  • Loading branch information
iamarunbrahma authored Dec 10, 2024
1 parent acea53c commit 278bde0
Show file tree
Hide file tree
Showing 2 changed files with 116 additions and 33 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -26,13 +26,15 @@
get_text_instruct_for_model_name,
)
from sentence_transformers import SentenceTransformer
from tenacity import retry, stop_after_attempt, wait_exponential

DEFAULT_HUGGINGFACE_LENGTH = 512
logger = logging.getLogger(__name__)


class HuggingFaceEmbedding(BaseEmbedding):
"""HuggingFace class for text embeddings.
"""
HuggingFace class for text embeddings.
Args:
model_name (str, optional): If it is a filepath on disc, it loads the model from that path.
Expand Down Expand Up @@ -183,46 +185,95 @@ def __init__(
def class_name(cls) -> str:
return "HuggingFaceEmbedding"

def _embed(
def _validate_input(self, text: str) -> None:
"""
Validate input text.
Args:
text: Input text to validate
Raises:
ValueError: If text is empty
"""
if not text.strip():
raise ValueError("Input text cannot be empty or whitespace")

@retry(
    stop=stop_after_attempt(3),
    wait=wait_exponential(multiplier=1, min=4, max=10),
    reraise=True,
)
def _embed_with_retry(
    self,
    sentences: List[str],
    prompt_name: Optional[str] = None,
) -> List[List[float]]:
    """
    Generates embeddings with retry mechanism.

    Encodes either through a multi-process pool or in a single process,
    depending on ``self._parallel_process``. Any exception is logged and
    re-raised so the ``@retry`` decorator can try again (three attempts in
    total); with ``reraise=True`` the last exception is propagated as-is.

    Args:
        sentences: Texts or sentences to embed.
        prompt_name: The name of the prompt to use for encoding. Must be a
            key in the `prompts` dictionary, i.e. "query" or "text". If
            ``prompt`` is also set, this argument is ignored. Defaults to
            None.

    Returns:
        The encoder output converted with ``.tolist()``: one embedding
        vector per input. If a single string is passed instead of a list,
        the encoder returns a 1d array and the result is one flat vector.

    Raises:
        Exception: If embedding fails after retries.
    """
    try:
        if self._parallel_process:
            pool = self._model.start_multi_process_pool(
                target_devices=self._target_devices
            )
            try:
                emb = self._model.encode_multi_process(
                    sentences=sentences,
                    pool=pool,
                    batch_size=self.embed_batch_size,
                    prompt_name=prompt_name,
                    normalize_embeddings=self.normalize,
                )
            finally:
                # Always tear the worker pool down. Without this, a failed
                # encode would leak the pool, and every retry would start
                # (and leak) another one.
                self._model.stop_multi_process_pool(pool=pool)
        else:
            emb = self._model.encode(
                sentences,
                batch_size=self.embed_batch_size,
                prompt_name=prompt_name,
                normalize_embeddings=self.normalize,
            )
        return emb.tolist()
    except Exception as e:
        # Log each failed attempt, then re-raise so tenacity can retry.
        logger.warning("Embedding attempt failed: %s", e)
        raise

else:
emb = self._model.encode(
sentences,
batch_size=self.embed_batch_size,
prompt_name=prompt_name,
normalize_embeddings=self.normalize,
)
def _embed(
self,
sentences: List[str],
prompt_name: Optional[str] = None,
) -> List[List[float]]:
"""
Generates Embeddings with input validation and retry mechanism.
Args:
sentences: Texts or Sentences to embed
prompt_name: The name of the prompt to use for encoding
Returns:
List of embedding vectors
Raises:
ValueError: If any input text is invalid
Exception: If embedding fails after retries
"""
# Validate all inputs
for text in sentences:
self._validate_input(text)

return emb.tolist()
return self._embed_with_retry(sentences, prompt_name)

def _get_query_embedding(self, query: str) -> List[float]:
"""Generates Embeddings for Query.
"""
Generates Embeddings for Query.
Args:
query (str): Query text/sentence
Expand All @@ -233,7 +284,8 @@ def _get_query_embedding(self, query: str) -> List[float]:
return self._embed(query, prompt_name="query")

async def _aget_query_embedding(self, query: str) -> List[float]:
"""Generates Embeddings for Query Asynchronously.
"""
Generates Embeddings for Query Asynchronously.
Args:
query (str): Query text/sentence
Expand All @@ -244,7 +296,8 @@ async def _aget_query_embedding(self, query: str) -> List[float]:
return self._get_query_embedding(query)

async def _aget_text_embedding(self, text: str) -> List[float]:
"""Generates Embeddings for text Asynchronously.
"""
Generates Embeddings for text Asynchronously.
Args:
text (str): Text/Sentence
Expand All @@ -255,7 +308,8 @@ async def _aget_text_embedding(self, text: str) -> List[float]:
return self._get_text_embedding(text)

def _get_text_embedding(self, text: str) -> List[float]:
"""Generates Embeddings for text.
"""
Generates Embeddings for text.
Args:
text (str): Text/sentences
Expand All @@ -266,7 +320,8 @@ def _get_text_embedding(self, text: str) -> List[float]:
return self._embed(text, prompt_name="text")

def _get_text_embeddings(self, texts: List[str]) -> List[List[float]]:
"""Generates Embeddings for text.
"""
Generates Embeddings for text.
Args:
texts (List[str]): Texts / Sentences
Expand Down Expand Up @@ -353,7 +408,8 @@ def _get_inference_client_kwargs(self) -> Dict[str, Any]:
}

def __init__(self, **kwargs: Any) -> None:
"""Initialize.
"""
Initialize.
Args:
kwargs: See the class-level Fields.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
HuggingFaceEmbedding,
HuggingFaceInferenceAPIEmbedding,
)
import pytest


def test_huggingfaceembedding_class():
Expand All @@ -15,3 +16,29 @@ def test_huggingfaceapiembedding_class():
b.__name__ for b in HuggingFaceInferenceAPIEmbedding.__mro__
]
assert BaseEmbedding.__name__ in names_of_base_classes


def test_input_validation():
    """Blank inputs are rejected by ``_validate_input``; real text passes."""
    model = HuggingFaceEmbedding(model_name="BAAI/bge-small-en")
    expected_message = "Input text cannot be empty or whitespace"

    # Empty input first, then whitespace-only input.
    for blank in ("", " "):
        with pytest.raises(ValueError, match=expected_message):
            model._validate_input(blank)

    # Valid input must not raise.
    model._validate_input("This is a valid input")


def test_embedding_retry():
    """The happy path of ``_embed`` returns one float vector per input."""
    model = HuggingFaceEmbedding(model_name="BAAI/bge-small-en")

    # Successful embedding of a single sentence.
    vectors = model._embed(["This is a test sentence"])

    assert isinstance(vectors, list)
    assert len(vectors) == 1
    first_vector = vectors[0]
    assert isinstance(first_vector, list)
    assert all(isinstance(component, float) for component in first_vector)

0 comments on commit 278bde0

Please sign in to comment.