
Commit
inference.py: add batch_size argument to rank()
rank() calls embed() to embed the list of documents to be ranked by similarity. Giving rank() a batch_size argument lets it pass that argument through to embed(), which helps a lot when working with limited RAM.
jgbarah authored Oct 14, 2024
1 parent ab66bdf commit 67b1db9
Showing 1 changed file with 5 additions and 13 deletions.
wordllama/inference.py (5 additions, 13 deletions)

@@ -178,7 +178,7 @@ def similarity(self, text1: str, text2: str) -> float:
         return self.vector_similarity(embedding1[0], embedding2[0]).item()
 
     def rank(
-        self, query: str, docs: List[str], sort: bool = True
+        self, query: str, docs: List[str], sort: bool = True, batch_size: int = 64
     ) -> List[Tuple[str, float]]:
         """Rank documents based on their similarity to a query.
@@ -188,6 +188,7 @@ def rank(
             query (str): The query text.
             docs (List[str]): The list of document texts to rank.
             sort (bool): Sort documents by similarity, or not (respect the order in `docs`)
+            batch_size (int, optional): Number of texts to process in each batch. Defaults to 64.
 
         Returns:
             List[Tuple[str, float]]: A list of tuples `(doc, score)`.
@@ -197,7 +198,7 @@ def rank(
             isinstance(docs, list) and len(docs) > 1
         ), "Docs must be a list of 2 or more strings."
         query_embedding = self.embed(query)
-        doc_embeddings = self.embed(docs)
+        doc_embeddings = self.embed(docs, batch_size=batch_size)
         scores = self.vector_similarity(query_embedding[0], doc_embeddings)
 
         scores = np.atleast_1d(scores.squeeze())
@@ -207,18 +208,13 @@ def rank(
         return similarities
 
     def deduplicate(
-        self,
-        docs: List[str],
-        threshold: float = 0.9,
-        return_indices: bool = False,
-        batch_size: Optional[int] = None,
-    ) -> List[Union[str, int]]:
+        self, docs: List[str], threshold: float = 0.9, batch_size: Optional[int] = None
+    ) -> List[str]:
         """Deduplicate documents based on a similarity threshold.
 
         Args:
             docs (List[str]): List of documents to deduplicate.
             threshold (float, optional): Similarity threshold above which documents are considered duplicates. Defaults to 0.9.
-            return_indices (bool, optional): Return indices of duplicated documents, rather than deduplicated list of documents.
             batch_size (Optional[int], optional): Batch size for processing embeddings. Defaults to None.
 
         Returns:
@@ -231,10 +227,6 @@ def deduplicate(
         duplicate_indices = deduplicate_embeddings(
             doc_embeddings, threshold, batch_size
         )
-        if return_indices:
-            # turn set of numpy int into sorted list of python int
-            duplicate_indices = list(map(lambda x: x.item(), duplicate_indices))
-            return sorted(duplicate_indices)
 
         unique_docs = [
             doc for idx, doc in enumerate(docs) if idx not in duplicate_indices
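
A minimal usage sketch of the new argument follows. It assumes the WordLlama.load() entry point described in the project README; the only behavior confirmed by this commit is that rank() now accepts batch_size and forwards it to embed().

    # Usage sketch: WordLlama.load() is taken from the project README and is an
    # assumption here; only the batch_size parameter on rank() is confirmed by
    # this commit.
    from wordllama import WordLlama

    wl = WordLlama.load()

    query = "renewable energy storage"
    docs = [
        "Grid-scale batteries smooth out solar and wind supply.",
        "The novel follows three generations of a family.",
        "Pumped hydro remains the largest form of energy storage.",
    ]

    # A smaller batch keeps peak memory lower while the documents are embedded;
    # the signature default is 64.
    ranked = wl.rank(query, docs, sort=True, batch_size=16)
    for doc, score in ranked:
        print(f"{score:.3f}  {doc}")

Smaller batch sizes trade some throughput for a lower peak memory footprint during the embed() call, which is the situation the commit message describes.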
