Merge pull request #35 from dleemiller/return-idx-deduplicate
Return idx deduplicate
dleemiller authored Oct 14, 2024
2 parents c2d6261 + c1cc604 commit ab66bdf
Showing 3 changed files with 31 additions and 4 deletions.
README.md (4 changes: 2 additions & 2 deletions)
@@ -178,7 +178,7 @@ print(ranked_docs)
Remove duplicate texts based on a similarity threshold:

```python
-deduplicated_docs = wl.deduplicate(candidates, threshold=0.5)
+deduplicated_docs = wl.deduplicate(candidates, return_indices=False, threshold=0.5)
print(deduplicated_docs)
# Output:
# ['I went to the park',
```

@@ -294,7 +294,7 @@ If you use WordLlama in your research or project, please consider citing it as follows:

```bibtex
title = {WordLlama: Recycled Token Embeddings from Large Language Models},
year = {2024},
url = {https://github.com/dleemiller/wordllama},
-  version = {0.3.1}
+  version = {0.3.2}
}
```

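Alongside the default mode shown in the README hunk above, the PR adds an index mode. A small usage sketch, assuming the `wl` instance and `candidates` list from the README's earlier examples:

```python
# Index mode added by this PR: instead of the surviving texts, return the
# sorted positions of the entries flagged as duplicates.
duplicate_idx = wl.deduplicate(candidates, return_indices=True, threshold=0.5)

# The indices make it easy to filter a parallel list, e.g. per-document metadata.
kept_docs = [doc for i, doc in enumerate(candidates) if i not in set(duplicate_idx)]
```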
tests/test_inference.py (18 changes: 18 additions & 0 deletions)
@@ -11,10 +11,14 @@
    TokenizerInferenceConfig,
)

+np.random.seed(42)


class TestWordLlamaInference(unittest.TestCase):
    @patch("wordllama.inference.Tokenizer.from_pretrained")
    def setUp(self, mock_tokenizer):
+        np.random.seed(42)

        # Mock the tokenizer
        self.mock_tokenizer = MagicMock()
@@ -123,6 +127,20 @@ def test_deduplicate_all_duplicates(self, mock_embed):
        self.assertEqual(len(deduplicated_docs), 1)
        self.assertIn("doc1", deduplicated_docs)

+    @patch.object(
+        WordLlamaInference,
+        "embed",
+        return_value=np.array([[0.1] * 64, [0.1] * 64, [0.1] * 64], dtype=np.float32),
+    )
+    def test_deduplicate_return_indices(self, mock_embed):
+        docs = ["doc1", "doc1_dup", "doc1_dup2"]
+        duplicated_idx = self.model.deduplicate(
+            docs, return_indices=True, threshold=0.9
+        )
+        self.assertEqual(len(duplicated_idx), 2)
+        self.assertIn(1, duplicated_idx)
+        self.assertIn(2, duplicated_idx)

    def test_tokenize(self):
        tokens = self.model.tokenize("test string")
        self.mock_tokenizer.encode_batch.assert_called_with(
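The mocked `embed` in the new test returns three identical vectors, so rows 1 and 2 are flagged as duplicates of row 0. A minimal NumPy sketch of that duplicate-index idea (an illustration only, not the library's `deduplicate_embeddings` kernel): keep the first occurrence, and flag any later row whose cosine similarity to a kept row meets the threshold.

```python
import numpy as np

def duplicate_indices(embeddings: np.ndarray, threshold: float = 0.9) -> list:
    """Greedy duplicate detection by cosine similarity (illustrative sketch)."""
    normed = embeddings / np.linalg.norm(embeddings, axis=1, keepdims=True)
    kept, duplicates = [], []
    for i in range(len(normed)):
        # A row is a duplicate if it is too similar to any row already kept.
        if kept and np.max(normed[kept] @ normed[i]) >= threshold:
            duplicates.append(i)
        else:
            kept.append(i)
    return duplicates

emb = np.array([[0.1] * 64, [0.1] * 64, [0.1] * 64], dtype=np.float32)
print(duplicate_indices(emb))  # [1, 2], matching the assertions in the test
```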
wordllama/inference.py (13 changes: 11 additions & 2 deletions)
@@ -207,13 +207,18 @@ def rank(
        return similarities

    def deduplicate(
-        self, docs: List[str], threshold: float = 0.9, batch_size: Optional[int] = None
-    ) -> List[str]:
+        self,
+        docs: List[str],
+        threshold: float = 0.9,
+        return_indices: bool = False,
+        batch_size: Optional[int] = None,
+    ) -> List[Union[str, int]]:
        """Deduplicate documents based on a similarity threshold.
        Args:
            docs (List[str]): List of documents to deduplicate.
            threshold (float, optional): Similarity threshold above which documents are considered duplicates. Defaults to 0.9.
+            return_indices (bool, optional): Return the indices of duplicated documents instead of the deduplicated list of documents. Defaults to False.
            batch_size (Optional[int], optional): Batch size for processing embeddings. Defaults to None.
        Returns:
@@ -226,6 +231,10 @@ def deduplicate(
        duplicate_indices = deduplicate_embeddings(
            doc_embeddings, threshold, batch_size
        )
+        if return_indices:
+            # Convert the set of numpy ints into a sorted list of Python ints.
+            duplicate_indices = list(map(lambda x: x.item(), duplicate_indices))
+            return sorted(duplicate_indices)

        unique_docs = [
            doc for idx, doc in enumerate(docs) if idx not in duplicate_indices
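Putting the two return modes side by side, a minimal end-to-end sketch (assuming WordLlama's documented `load()` entry point; the example documents are hypothetical):

```python
from wordllama import WordLlama

wl = WordLlama.load()
docs = ["I went to the park", "I visited the park", "The weather is nice"]

# Default mode: the deduplicated documents themselves.
unique_docs = wl.deduplicate(docs, threshold=0.9)

# Index mode: sorted indices of the entries flagged as duplicates,
# converted to plain Python ints by the branch added above.
dup_idx = wl.deduplicate(docs, threshold=0.9, return_indices=True)
```

Returning indices rather than texts keeps the caller in control when documents carry parallel metadata, which appears to be the motivation for the new flag.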
