inference.py: add batch_size argument to rank() #36

Merged: 2 commits, Oct 14, 2024
15 changes: 8 additions & 7 deletions wordllama/inference.py
@@ -77,7 +77,7 @@ def embed(
             norm (bool, optional): If True, normalize embeddings to unit vectors. Defaults to False.
             return_np (bool, optional): If True, return embeddings as a NumPy array; otherwise, return as a list. Defaults to True.
             pool_embeddings (bool, optional): If True, apply average pooling to token embeddings. Defaults to True.
-            batch_size (int, optional): Number of texts to process in each batch. Defaults to 32.
+            batch_size (int, optional): Number of texts to process in each batch. Defaults to 64.
 
         Returns:
             Union[np.ndarray, List]: Embeddings as a NumPy array or a list, depending on `return_np`.
@@ -178,7 +178,7 @@ def similarity(self, text1: str, text2: str) -> float:
         return self.vector_similarity(embedding1[0], embedding2[0]).item()
 
     def rank(
-        self, query: str, docs: List[str], sort: bool = True
+        self, query: str, docs: List[str], sort: bool = True, batch_size: int = 64
     ) -> List[Tuple[str, float]]:
         """Rank documents based on their similarity to a query.
 
@@ -188,6 +188,7 @@ def rank(
             query (str): The query text.
             docs (List[str]): The list of document texts to rank.
             sort (bool): Sort documents by similarity, or not (respect the order in `docs`)
+            batch_size (int, optional): Number of texts to process in each batch. Defaults to 64.
 
         Returns:
             List[Tuple[str, float]]: A list of tuples `(doc, score)`.
@@ -197,7 +198,7 @@ def rank(
             isinstance(docs, list) and len(docs) > 1
         ), "Docs must be a list of 2 or more strings."
         query_embedding = self.embed(query)
-        doc_embeddings = self.embed(docs)
+        doc_embeddings = self.embed(docs, batch_size=batch_size)
         scores = self.vector_similarity(query_embedding[0], doc_embeddings)
 
         scores = np.atleast_1d(scores.squeeze())
@@ -208,10 +209,10 @@ def rank(
     def deduplicate(
         self,
-        docs: List[str],
-        threshold: float = 0.9,
-        return_indices: bool = False,
-        batch_size: Optional[int] = None,
+        docs: List[str],
+        threshold: float = 0.9,
+        return_indices: bool = False,
+        batch_size: Optional[int] = None
     ) -> List[Union[str, int]]:
         """Deduplicate documents based on a similarity threshold.
 
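For context, here is a minimal usage sketch of the new `batch_size` argument to `rank()`. The `from wordllama import WordLlama` import and `WordLlama.load()` loader follow the project README and are assumptions here; the `rank()` call itself matches the signature in this diff.

```python
# Usage sketch for the batch_size argument added to rank() in this PR.
# Assumes the WordLlama.load() loader described in the project README.
from wordllama import WordLlama

wl = WordLlama.load()

query = "machine learning on small devices"
docs = [
    "Quantized models run efficiently on embedded hardware.",
    "The recipe calls for two cups of flour.",
    "Edge inference benefits from compact embedding models.",
]

# rank() now forwards batch_size to embed() when encoding the documents;
# smaller batches lower peak memory, larger batches improve throughput.
ranked = wl.rank(query, docs, sort=True, batch_size=16)
for doc, score in ranked:
    print(f"{score:.3f}  {doc}")
```

Since `batch_size` is only forwarded to `embed()` for the document encoding step, it changes memory and speed characteristics but not the returned scores.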