From 46cd699b0b423ea288f547d3a244f59dbb76f0ab Mon Sep 17 00:00:00 2001
From: "Jesus M. Gonzalez-Barahona"
Date: Mon, 14 Oct 2024 17:51:54 +0200
Subject: [PATCH 1/2] inference.py: add batch_size argument to rank()

rank() calls embed() to embed the list of documents to be ranked by
similarity. Giving rank() a batch_size argument lets it pass that
argument through to embed(), which helps a lot when working with
limited RAM.
---
 wordllama/inference.py | 13 +++++++------
 1 file changed, 7 insertions(+), 6 deletions(-)

diff --git a/wordllama/inference.py b/wordllama/inference.py
index 4efae36..0ac7003 100644
--- a/wordllama/inference.py
+++ b/wordllama/inference.py
@@ -178,7 +178,7 @@ def similarity(self, text1: str, text2: str) -> float:
         return self.vector_similarity(embedding1[0], embedding2[0]).item()
 
     def rank(
-        self, query: str, docs: List[str], sort: bool = True
+        self, query: str, docs: List[str], sort: bool = True, batch_size: int = 64
     ) -> List[Tuple[str, float]]:
         """Rank documents based on their similarity to a query.
 
@@ -188,6 +188,7 @@ def rank(
             query (str): The query text.
             docs (List[str]): The list of document texts to rank.
             sort (bool): Sort documents by similarity, or not (respect the order in `docs`)
+            batch_size (int, optional): Number of texts to process in each batch. Defaults to 32.
 
         Returns:
             List[Tuple[str, float]]: A list of tuples `(doc, score)`.
@@ -197,7 +198,7 @@ def rank(
             isinstance(docs, list) and len(docs) > 1
         ), "Docs must be a list of 2 more more strings."
         query_embedding = self.embed(query)
-        doc_embeddings = self.embed(docs)
+        doc_embeddings = self.embed(docs, batch_size=batch_size)
         scores = self.vector_similarity(query_embedding[0], doc_embeddings)
 
         scores = np.atleast_1d(scores.squeeze())
@@ -208,10 +209,10 @@ def rank(
 
     def deduplicate(
         self,
-        docs: List[str],
-        threshold: float = 0.9,
-        return_indices: bool = False,
-        batch_size: Optional[int] = None,
+        docs: List[str],
+        threshold: float = 0.9,
+        return_indices: bool = False,
+        batch_size: Optional[int] = None
     ) -> List[Union[str, int]]:
         """Deduplicate documents based on a similarity threshold.

From f699a5963e59f6176065f755abb043be9d7b1122 Mon Sep 17 00:00:00 2001
From: Lee Miller <80222060+dleemiller@users.noreply.github.com>
Date: Mon, 14 Oct 2024 14:00:54 -0600
Subject: [PATCH 2/2] Update inference.py

---
 wordllama/inference.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/wordllama/inference.py b/wordllama/inference.py
index 0ac7003..bf4ac3c 100644
--- a/wordllama/inference.py
+++ b/wordllama/inference.py
@@ -77,7 +77,7 @@ def embed(
             norm (bool, optional): If True, normalize embeddings to unit vectors. Defaults to False.
             return_np (bool, optional): If True, return embeddings as a NumPy array; otherwise, return as a list. Defaults to True.
             pool_embeddings (bool, optional): If True, apply average pooling to token embeddings. Defaults to True.
-            batch_size (int, optional): Number of texts to process in each batch. Defaults to 32.
+            batch_size (int, optional): Number of texts to process in each batch. Defaults to 64.
 
         Returns:
             Union[np.ndarray, List]: Embeddings as a NumPy array or a list, depending on `return_np`.
@@ -188,7 +188,7 @@ def rank(
             query (str): The query text.
             docs (List[str]): The list of document texts to rank.
             sort (bool): Sort documents by similarity, or not (respect the order in `docs`)
-            batch_size (int, optional): Number of texts to process in each batch. Defaults to 32.
+            batch_size (int, optional): Number of texts to process in each batch. Defaults to 64.
 
         Returns:
             List[Tuple[str, float]]: A list of tuples `(doc, score)`.
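
A minimal usage sketch of the new argument (not part of the patches; it assumes
the library's WordLlama.load() entry point, and the query, documents, and batch
size of 16 are illustrative):

# Sketch: ranking with a reduced batch_size to bound peak memory.
from wordllama import WordLlama

wl = WordLlama.load()

query = "machine learning on small devices"
docs = [
    "Efficient inference for edge hardware",
    "A history of medieval trade routes",
    "Quantization methods for embedded ML",
]

# With this patch, batch_size is forwarded to embed(), so the documents
# are embedded at most 16 at a time instead of all at once.
ranked = wl.rank(query, docs, sort=True, batch_size=16)
for doc, score in ranked:
    print(f"{score:.3f}  {doc}")

Smaller batches trade some throughput for a lower peak memory footprint, which
is the point of the patch when RAM is limited.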