diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml new file mode 100644 index 0000000..5c06779 --- /dev/null +++ b/.github/workflows/ci.yml @@ -0,0 +1,34 @@ +name: CI + +on: [push, pull_request] + +jobs: + build: + runs-on: ${{ matrix.os }} + strategy: + matrix: + os: [ubuntu-latest, macos-latest, windows-latest] + python-version: ['3.8', '3.9', '3.10', '3.11'] + + steps: + - uses: actions/checkout@v2 + + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v2 + with: + python-version: ${{ matrix.python-version }} + + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install setuptools wheel setuptools_scm Cython numpy + + - name: Build and install package + run: | + python setup.py build_ext --inplace + python -m pip install . + + - name: Test installation + run: | + python -m unittest discover -s tests + diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml new file mode 100644 index 0000000..60564e7 --- /dev/null +++ b/.github/workflows/publish.yml @@ -0,0 +1,35 @@ +name: Publish to PyPI + +on: + push: + tags: + - 'v*.*.*' + +jobs: + build-and-publish: + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v2 + + - name: Set up Python + uses: actions/setup-python@v2 + with: + python-version: '3.9' + + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install setuptools wheel twine setuptools_scm Cython numpy + + - name: Build package + run: | + python setup.py sdist bdist_wheel + + - name: Publish package to PyPI + env: + TWINE_USERNAME: __token__ + TWINE_PASSWORD: ${{ secrets.PYPI_API_TOKEN }} + run: | + twine upload dist/* + diff --git a/MANIFEST.in b/MANIFEST.in index 3e5661d..86a19fa 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -1,3 +1,5 @@ include LICENSE include README.md recursive-include wordllama *.py *.toml *.json +include wordllama/algorithms/*.pyx +include wordllama/algorithms/*.pxd diff --git a/README.md b/README.md index 06139a4..1916316 100644 --- a/README.md +++ b/README.md @@ -1,5 +1,5 @@ -# Word Llama +# WordLlama The power of 15 trillion tokens of training, extracted, flogged and minimized into a cute little package for word embedding. @@ -59,7 +59,7 @@ wl.topk(query, candidates, k=3) # return topk strings based on query ## What is it? -WordLlama is a word embedding model that recycles components from large language models (LLMs) to create efficient and compact word representations (such as GloVe, Word2Vec or FastText). +WordLlama is a utility for NLP and word embedding model that recycles components from large language models (LLMs) to create efficient and compact word representations (such as GloVe, Word2Vec or FastText). WordLlama begins by extracting the token embedding codebook from a state-of-the-art LLM (e.g., LLama3 70B), and training a small context-less model in a general purpose embedding framework. WordLlama improves on all MTEB benchmarks above word models like GloVe 300d, while being substantially smaller in size (**16MB default model** @ 256-dim vs >2GB). @@ -96,8 +96,6 @@ Because of its fast and portable size, it makes a good "Swiss-Army Knife" utilit The [l2_supercat](https://huggingface.co/dleemiller/word-llama-l2-supercat) is a Llama2-vocabulary model. To train this model, I concatenated codebooks from several models, including Llama2 70B and phi3 medium (after removing additional special tokens). Because several models have used the Llama2 tokenizer, their codebooks can be concatenated and trained together. 
Performance of the resulting model is comparable to training the Llama3 70B codebook, while being 4x smaller (32k vs 128k vocabulary). -I anticipate the best results will come from training using the Llama3 405B codebook, when released. - ## Embed Text Here’s how you can load pre-trained embeddings and use them to embed text: @@ -134,7 +132,7 @@ ranked_docs = wl.rank("i went to the car", ["van", "truck"]) wl.binary = False # turn off hamming and use cosine # load a different model class -wl = WordLlama.load(config="llama3_400B", dim=1024) # downloads model from HF +wl = WordLlama.load(config="l3_supercat", dim=1024) # downloads model from HF ``` ## Training Notes @@ -145,8 +143,11 @@ L2 Supercat was trained using a batch size of 512 on a single A100 for 12 hours. ## Roadmap -- Test distillation training from a larger embedding model -- Retrain on llama3 405B (waiting on release...), concat with llama guard 2, llama3 70B +- Working on adding inference features: + - Semantic text splitting +- Add example notebooks + - DSPy evaluators + - RAG pipelines ## Extracting Token Embeddings @@ -180,7 +181,7 @@ If you use WordLlama in your research or project, please consider citing it as f title = {WordLlama: Recycled Token Embeddings from Large Language Models}, year = {2024}, url = {https://github.com/dleemiller/wordllama}, - version = {0.2.1} + version = {0.2.3} } ``` diff --git a/pyproject.toml b/pyproject.toml index 8eaba0e..8da1c6d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,5 +1,5 @@ [build-system] -requires = ["setuptools>=42", "wheel", "setuptools_scm"] +requires = ["setuptools>=42", "wheel", "setuptools_scm", "Cython", "numpy"] build-backend = "setuptools.build_meta" [project] @@ -37,7 +37,7 @@ Repository = "https://github.com/dleemiller/WordLlama" packages = ["wordllama"] [tool.setuptools.package-data] -wordllama = ["**/*.toml", "tokenizers/*.json", "weights/*.safetensors"] +wordllama = ["algorithms/*.so", "algorithms/*.pyd", "**/*.pyx", "**/*.pyd", "**/*.toml", "tokenizers/*.json", "weights/*.safetensors"] [tool.setuptools.dynamic] classifiers = { file = "classifiers.txt" } @@ -46,3 +46,4 @@ classifiers = { file = "classifiers.txt" } write_to = "wordllama/_version.py" version_scheme = "post-release" local_scheme = "no-local-version" + diff --git a/setup.py b/setup.py new file mode 100644 index 0000000..666c7a9 --- /dev/null +++ b/setup.py @@ -0,0 +1,81 @@ +from setuptools import setup, Extension +from Cython.Build import cythonize +import numpy as np +import platform +import sys + +numpy_include = np.get_include() + +extra_compile_args = [] +extra_link_args = [] + +if platform.system() == "Darwin": + if platform.machine() == "arm64": + extra_compile_args.extend(["-arch", "arm64", "-O3", "-ffast-math"]) + extra_link_args.extend(["-arch", "arm64"]) + else: + extra_compile_args.extend(["-arch", "x86_64", "-O3", "-ffast-math"]) + extra_link_args.extend(["-arch", "x86_64"]) +elif platform.system() == "Windows": + extra_compile_args.extend(["/O2"]) +else: # Linux and others + if platform.machine().startswith("arm"): + if platform.architecture()[0] == "32bit": + extra_compile_args.extend(["-march=armv7-a", "-mfpu=neon"]) + extra_link_args.extend(["-march=armv7-a", "-mfpu=neon"]) + else: # 64-bit ARM + extra_compile_args.extend(["-march=armv8-a"]) + extra_link_args.extend(["-march=armv8-a"]) + elif platform.machine() in ["x86_64", "AMD64"]: + extra_compile_args.extend(["-march=native", "-mpopcnt"]) + extra_link_args.extend(["-march=native", "-mpopcnt"]) + 
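+# Notes on the platform blocks above: -march=native tunes the extensions for the
+# CPU of the build machine, so binaries built this way are not guaranteed to run
+# on older processors, and -mpopcnt lets __builtin_popcount in the
+# hamming_distance extension compile down to a single POPCNT instruction.
+# The -O3/-ffast-math pair below trades strict IEEE 754 semantics for speed.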
+extra_compile_args.extend(["-O3", "-ffast-math"]) + +extensions = [ + Extension( + "wordllama.algorithms.splitter", + ["wordllama/algorithms/splitter.pyx"], + extra_compile_args=extra_compile_args, + extra_link_args=extra_link_args, + ), + Extension( + "wordllama.algorithms.hamming_distance", + ["wordllama/algorithms/hamming_distance.pyx"], + include_dirs=[numpy_include], + define_macros=[("NPY_NO_DEPRECATED_API", "NPY_1_7_API_VERSION")], + extra_compile_args=extra_compile_args, + extra_link_args=extra_link_args, + ), + Extension( + "wordllama.algorithms.deduplicate_helpers", + ["wordllama/algorithms/deduplicate_helpers.pyx"], + include_dirs=[numpy_include], + define_macros=[("NPY_NO_DEPRECATED_API", "NPY_1_7_API_VERSION")], + extra_compile_args=extra_compile_args, + extra_link_args=extra_link_args, + ), + Extension( + "wordllama.algorithms.kmeans_helpers", + ["wordllama/algorithms/kmeans_helpers.pyx"], + include_dirs=[numpy_include], + define_macros=[("NPY_NO_DEPRECATED_API", "NPY_1_7_API_VERSION")], + extra_compile_args=extra_compile_args, + extra_link_args=extra_link_args, + ), +] + +setup( + name="Text Processing Tools", + ext_modules=cythonize( + extensions, + compiler_directives={ + "language_level": "3", + "boundscheck": False, + "wraparound": False, + }, + annotate=True, + ), + zip_safe=False, + install_requires=["numpy"], +) diff --git a/tests/test_inference.py b/tests/test_inference.py index d01da06..47714d8 100644 --- a/tests/test_inference.py +++ b/tests/test_inference.py @@ -82,11 +82,11 @@ def setUp(self, mock_tokenizer): return_value=np.array([[0.1] * 64, [0.1] * 64, np.random.rand(64), [0.1] * 64]), ) def test_deduplicate_cosine(self, mock_embed): - docs = ["doc1", "doc1_dup", "doc2", "doc1_dup2"] + docs = ["doc1", "doc1_dup", "a second document that is different", "doc1_dup2"] deduplicated_docs = self.model.deduplicate(docs, threshold=0.9) self.assertEqual(len(deduplicated_docs), 2) self.assertIn("doc1", deduplicated_docs) - self.assertIn("doc2", deduplicated_docs) + self.assertIn("a second document that is different", deduplicated_docs) @patch.object( WordLlamaInference, diff --git a/wordllama/adapters/binarizer.py b/wordllama/adapters/binarizer.py index 8267bc0..aa8b17b 100644 --- a/wordllama/adapters/binarizer.py +++ b/wordllama/adapters/binarizer.py @@ -76,7 +76,6 @@ def approximate_function(x, o): class Binarizer(nn.Module): - def __init__(self, ste="tanh"): super().__init__() assert ste in ["ste", "reste", "stochastic", "tanh"] diff --git a/wordllama/algorithms/__init__.py b/wordllama/algorithms/__init__.py index 5d88188..1832c65 100644 --- a/wordllama/algorithms/__init__.py +++ b/wordllama/algorithms/__init__.py @@ -1 +1,3 @@ from .kmeans import kmeans_clustering +from .hamming_distance import hamming_distance +from .deduplicate_helpers import process_batches_cy diff --git a/wordllama/algorithms/deduplicate_helpers.pyx b/wordllama/algorithms/deduplicate_helpers.pyx new file mode 100644 index 0000000..b29be40 --- /dev/null +++ b/wordllama/algorithms/deduplicate_helpers.pyx @@ -0,0 +1,42 @@ +# cython: language_level=3, boundscheck=False, wraparound=False +# distutils: define_macros=NPY_NO_DEPRECATED_API=NPY_1_7_API_VERSION +import numpy as np +cimport numpy as np +from numpy cimport PyArray_DIMS + +ctypedef fused embedding_dtype: + np.uint32_t + np.float32_t + np.float64_t + +def process_batches_cy(np.ndarray[embedding_dtype, ndim=2] doc_embeddings, + double threshold, int batch_size, vector_similarity): + cdef int num_embeddings = PyArray_DIMS(doc_embeddings)[0] + 
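+    # Work through the embeddings block by block so that only a
+    # (batch_size x batch_size) similarity matrix is materialized at a time.
+    # The inner loop starts at the outer block (j = i), so block pairs below
+    # the diagonal are never recomputed; when a pair of distinct documents
+    # exceeds the threshold, one index is kept in seen_docs and the other is
+    # added to duplicate_indices, so a single representative survives.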
    cdef set duplicate_indices = set()
+    cdef set seen_docs = set()
+    cdef int i, j, start_i, end_i, start_j, end_j
+    cdef np.ndarray[embedding_dtype, ndim=2] batch_i, batch_j
+    cdef np.ndarray[double, ndim=2] sim_matrix
+    cdef np.ndarray[np.int64_t, ndim=2] sim_indices
+    cdef int doc_idx_1, doc_idx_2
+
+    for i in range(0, num_embeddings, batch_size):
+        start_i = i
+        end_i = min(i + batch_size, num_embeddings)
+        batch_i = doc_embeddings[start_i:end_i]
+        for j in range(i, num_embeddings, batch_size):
+            start_j = j
+            end_j = min(j + batch_size, num_embeddings)
+            batch_j = doc_embeddings[start_j:end_j]
+            sim_matrix = vector_similarity(batch_i, batch_j)
+            sim_indices = np.argwhere(sim_matrix > threshold)
+            for idx in sim_indices:
+                if idx[0] + start_i != idx[1] + start_j:
+                    doc_idx_1 = idx[0] + start_i
+                    doc_idx_2 = idx[1] + start_j
+                    if doc_idx_2 not in seen_docs:
+                        seen_docs.add(doc_idx_1)
+                        duplicate_indices.add(doc_idx_2)
+
+    return duplicate_indices
+
diff --git a/wordllama/algorithms/hamming_distance.pyx b/wordllama/algorithms/hamming_distance.pyx
new file mode 100644
index 0000000..8e4a9e6
--- /dev/null
+++ b/wordllama/algorithms/hamming_distance.pyx
@@ -0,0 +1,53 @@
+# cython: language_level=3, boundscheck=False, wraparound=False
+# distutils: define_macros=NPY_NO_DEPRECATED_API=NPY_1_7_API_VERSION
+
+import numpy as np
+cimport numpy as np
+from numpy cimport int32_t, uint32_t, uint8_t, PyArrayObject, PyArray_DIMS
+from libc.stdint cimport uint32_t, uint8_t
+
+np.import_array()
+
+cdef extern from *:
+    """
+    #if defined(__GNUC__) && (defined(__x86_64__) || defined(__i386__))
+    #include <immintrin.h>
+    static inline int popcount(uint32_t x) {
+        return __builtin_popcount(x);
+    }
+    #elif defined(__GNUC__) && (defined(__ARM_NEON) || defined(__aarch64__))
+    #include <arm_neon.h>
+    static inline int popcount(uint32_t x) {
+        return vaddv_u8(vcnt_u8(vcreate_u8(x)));
+    }
+    #else
+    static inline int popcount(uint32_t x) {
+        x = x - ((x >> 1) & 0x55555555);
+        x = (x & 0x33333333) + ((x >> 2) & 0x33333333);
+        x = (x + (x >> 4)) & 0x0F0F0F0F;
+        x = x + (x >> 8);
+        x = x + (x >> 16);
+        return x & 0x0000003F;
+    }
+    #endif
+    """
+    int popcount(uint32_t x) nogil
+
+cpdef np.ndarray[int32_t, ndim=2] hamming_distance(np.ndarray[uint32_t, ndim=2] a, np.ndarray[uint32_t, ndim=2] b):
+    cdef Py_ssize_t i, j, k
+    cdef int dist
+    cdef Py_ssize_t n = PyArray_DIMS(a)[0]
+    cdef Py_ssize_t m = PyArray_DIMS(b)[0]
+    cdef Py_ssize_t width = PyArray_DIMS(a)[1]
+    cdef np.ndarray[int32_t, ndim=2] distance = np.zeros((n, m), dtype=np.int32)
+
+    # Calculate Hamming distance
+    for i in range(n):
+        for j in range(m):
+            dist = 0
+            for k in range(width):
+                dist += popcount(a[i, k] ^ b[j, k])
+            distance[i, j] = dist
+
+    return distance
+
diff --git a/wordllama/algorithms/kmeans.py b/wordllama/algorithms/kmeans.py
index d37561c..3ce79e2 100644
--- a/wordllama/algorithms/kmeans.py
+++ b/wordllama/algorithms/kmeans.py
@@ -1,5 +1,6 @@
 import numpy as np
 from typing import List, Tuple
+from .kmeans_helpers import compute_distances, update_centroids
 
 
 def kmeans_plusplus_initialization(
@@ -39,27 +40,6 @@
     return centroids
 
-
-def calculate_inertia(
-    embeddings: np.ndarray, labels: np.ndarray, centroids: np.ndarray
-) -> float:
-    """
-    Calculate the inertia (sum of squared distances to the closest centroid).
-
-    Parameters:
-    embeddings (np.ndarray): The input data points (embeddings) to cluster.
-    labels (np.ndarray): The cluster labels for each point.
-    centroids (np.ndarray): The cluster centroids.
-
-    Returns:
-    float: The calculated inertia.
-    """
-    inertia = 0.0
-    for i, centroid in enumerate(centroids):
-        cluster_points = embeddings[labels == i]
-        inertia += np.sum((cluster_points - centroid) ** 2)
-    return inertia
-
-
 def kmeans_clustering(
     embeddings: np.ndarray,
     k: int,
@@ -84,6 +64,7 @@
     Returns:
     Tuple[List[int], List[float]]: A tuple containing the cluster labels and the list of loss values for each iteration.
     """
+
     if random_state is None:
         random_state = np.random.RandomState()
     elif isinstance(random_state, int):
@@ -95,21 +76,16 @@
     for init_run in range(n_init):
         centroids = kmeans_plusplus_initialization(embeddings, k, random_state)
-        prev_inertia = float("inf")
         losses = []
 
         for iteration in range(max_iterations):
-            # Step 2: Assign each point to the nearest centroid
-            distances = np.sqrt(
-                ((embeddings[:, np.newaxis, :] - centroids[np.newaxis, :, :]) ** 2).sum(
-                    axis=2
-                )
-            )
+            # Step 2: Assign each point to the nearest centroid using the Cython optimized function
+            distances = compute_distances(embeddings, centroids)
             labels = np.argmin(distances, axis=1)
 
-            # Step 2: Calculate inertia
-            inertia = calculate_inertia(embeddings, labels, centroids)
+            # Calculate inertia using distances directly
+            inertia = np.sum(np.min(distances, axis=1) ** 2)
             losses.append(inertia)
 
             # Check for convergence based on inertia
@@ -118,24 +94,17 @@
             prev_inertia = inertia
 
-            # Step 3: Update centroids to the mean of the points in each cluster
-            new_centroids = np.array(
-                [
-                    embeddings[labels == i].mean(axis=0)
-                    if np.sum(labels == i) > 0
-                    else centroids[i]
-                    for i in range(k)
-                ]
-            )
+            # Step 3: Update centroids using the Cython optimized function
+            new_centroids = update_centroids(embeddings, labels, k, embeddings.shape[1])
 
             # Check for convergence based on centroids
             if iteration >= min_iterations and np.allclose(
                 centroids, new_centroids, atol=tolerance
             ):
                 break
 
             centroids = new_centroids
 
         # Check if this initialization run has the best result
         if inertia < best_inertia:
             best_inertia = inertia
diff --git a/wordllama/algorithms/kmeans_helpers.pyx b/wordllama/algorithms/kmeans_helpers.pyx
new file mode 100644
index 0000000..7fcc647
--- /dev/null
+++ b/wordllama/algorithms/kmeans_helpers.pyx
@@ -0,0 +1,49 @@
+# cython: language_level=3, boundscheck=False, wraparound=False
+# distutils: define_macros=NPY_NO_DEPRECATED_API=NPY_1_7_API_VERSION
+import numpy as np
+cimport numpy as np
+from libc.math cimport sqrt
+
+ctypedef np.npy_intp DTYPE_t
+
+cdef inline double squared_euclidean_distance(const double[:] vec1, const double[:] vec2, Py_ssize_t dim) nogil:
+    cdef Py_ssize_t i
+    cdef double dist = 0.0
+    for i in range(dim):
+        dist += (vec1[i] - vec2[i]) ** 2
+    return dist
+
+def compute_distances(const double[:, :] embeddings, const double[:, :] centroids):
+    cdef Py_ssize_t num_points = embeddings.shape[0]
+    cdef Py_ssize_t num_centroids = centroids.shape[0]
+    cdef Py_ssize_t dim = embeddings.shape[1]
+    cdef double[:, :] distances = np.empty((num_points, num_centroids), dtype=np.float64)
+    cdef Py_ssize_t i, j
+
+    for i in range(num_points):
+        for j in range(num_centroids):
+            distances[i, j] = sqrt(squared_euclidean_distance(embeddings[i], centroids[j], dim))
+
+    return np.asarray(distances)
+
+def update_centroids(const double[:, :] embeddings, const DTYPE_t[:] labels, Py_ssize_t num_clusters, Py_ssize_t dim):
+    cdef double[:, :]
new_centroids = np.zeros((num_clusters, dim), dtype=np.float64) + cdef DTYPE_t[:] count = np.zeros(num_clusters, dtype=np.intp) + cdef Py_ssize_t i, j + cdef DTYPE_t label + + # Accumulate sums and counts for each cluster + for i in range(labels.shape[0]): + label = labels[i] + for j in range(dim): + new_centroids[label, j] += embeddings[i, j] + count[label] += 1 + + # Calculate the mean for each cluster + for i in range(num_clusters): + if count[i] > 0: + for j in range(dim): + new_centroids[i, j] /= count[i] + + return np.asarray(new_centroids) + diff --git a/wordllama/algorithms/splitter.pyx b/wordllama/algorithms/splitter.pyx new file mode 100644 index 0000000..97c6b80 --- /dev/null +++ b/wordllama/algorithms/splitter.pyx @@ -0,0 +1,43 @@ +# cython: language_level=3, infer_types=True, binding=True +import cython +from typing import List + +@cython.boundscheck(False) +@cython.wraparound(False) +def split_sentences(str text, set punct_chars=None) -> List[str]: + cdef int i, start = 0, text_len = len(text) + cdef list sentences = [] + cdef bint seen_period = False + cdef str current_char + cdef set punct_chars_c + + if punct_chars is None: + punct_chars = {'.', '!', '?', '։', '؟', '۔', '܀', '܁', '܂', '߹', '।', '॥', '၊', '။', '።', '፧', '፨', + '᙮', '᜵', '᜶', '᠃', '᠉', '᥄', '᥅', '᪨', '᪩', '᪪', '᪫', '᭚', '᭛', '᭞', '᭟', + '᰻', '᰼', '᱾', '᱿', '‼', '‽', '⁇', '⁈', '⁉', '⸮', '⸼', '꓿', '꘎', '꘏', '꛳', + '꛷', '꡶', '꡷', '꣎', '꣏', '꤯', '꧈', '꧉', '꩝', '꩞', '꩟', '꫰', '꫱', '꯫', '﹒', + '﹖', '﹗', '!', '.', '?', '੖', '੗', '၇', '၈', 'Ⴞ', 'Ⴟ', 'Ⴠ', 'Ⴡ', 'ᅁ', + 'ᅂ', 'ᅃ', 'ᇅ', 'ᇆ', 'ᇍ', 'ᇞ', 'ᇟ', 'ሸ', 'ሹ', 'ሻ', 'ሼ', 'ኩ', 'ᑋ', + 'ᑌ', 'ᗂ', 'ᗃ', 'ᗉ', 'ᗊ', 'ᗋ', 'ᗌ', 'ᗍ', 'ᗎ', 'ᗏ', 'ᗐ', 'ᗑ', 'ᗒ', + 'ᗓ', 'ᗔ', 'ᗕ', 'ᗖ', 'ᗗ', '遁', '遂', '᜼', '᜽', '᜾', 'ᩂ', 'ᩃ', 'ꛝ', + 'ꛞ', '᱁', '᱂', '橮', '橯', '櫵', '欷', '欸', '歄', '벟', '?', '。', '。'} + + punct_chars_c = set(ord(c) for c in punct_chars) + + if not any(ord(char) in punct_chars_c for char in text): + return [text] + + for i in range(text_len): + current_char = text[i] + if ord(current_char) in punct_chars_c: + seen_period = True + elif seen_period and (current_char == ' ' or current_char == '\n'): + if i + 1 < text_len and (text[i+1].isupper() or text[i+1] == '\n'): + sentences.append(text[start:i+1].strip()) + start = i + 1 + seen_period = False + + if start < text_len: + sentences.append(text[start:].strip()) + + return sentences diff --git a/wordllama/inference.py b/wordllama/inference.py index 722f6f4..491cd33 100644 --- a/wordllama/inference.py +++ b/wordllama/inference.py @@ -1,9 +1,9 @@ import numpy as np from tokenizers import Tokenizer -from typing import Union, List, Tuple +from typing import Union, List, Tuple, Optional import logging -from .algorithms import kmeans_clustering +from .algorithms import kmeans_clustering, hamming_distance, process_batches_cy from .config import WordLlamaConfig # Set up logging @@ -145,9 +145,7 @@ def hamming_similarity(a: np.ndarray, b: np.ndarray) -> np.ndarray: max_dist = a.shape[1] * 32 # Calculate Hamming distance - xor_result = np.bitwise_xor(a[:, np.newaxis], b) - dist = np.sum(np.unpackbits(xor_result.view(np.uint8), axis=2), axis=2) - + dist = hamming_distance(a, b) return 1.0 - 2.0 * (dist / max_dist) @staticmethod @@ -230,7 +228,7 @@ def rank(self, query: str, docs: List[str]) -> List[tuple]: return similarities def deduplicate( - self, docs: List[str], threshold: float = 0.9, batch_size: int = 100 + self, docs: List[str], threshold: float = 0.9, batch_size: Optional[int] = None ) -> List[str]: """ Deduplicate a list of 
documents based on similarity threshold. @@ -238,43 +236,19 @@ def deduplicate( Args: docs (List[str]): List of document texts to deduplicate. threshold (float): Similarity threshold for deduplication. - batch_size (int): Batch size for processing embeddings. + batch_size (Optional[int]): Batch size for processing embeddings. Returns: List[str]: Deduplicated list of document texts. """ - # Embed all documents doc_embeddings = self.embed(docs, norm=not self.binary) - num_embeddings = doc_embeddings.shape[0] - duplicate_indices = set() - seen_docs = set() - - for i in range(0, num_embeddings, batch_size): - start_i = i - end_i = min(i + batch_size, num_embeddings) - batch_i = doc_embeddings[start_i:end_i] - - for j in range( - i, num_embeddings, batch_size - ): # Start from i to avoid redundant comparisons - start_j = j - end_j = min(j + batch_size, num_embeddings) - batch_j = doc_embeddings[start_j:end_j] - - sim_matrix = self.vector_similarity(batch_i, batch_j) - - # Find indices where similarity exceeds the threshold - sim_indices = np.argwhere(sim_matrix > threshold) - for idx in sim_indices: - if idx[0] + start_i != idx[1] + start_j: # Ignore self-comparison - doc_idx_1 = idx[0] + start_i - doc_idx_2 = idx[1] + start_j - if doc_idx_2 not in seen_docs: - seen_docs.add(doc_idx_1) - duplicate_indices.add(doc_idx_2) - - # Filter out embeddings that are not in duplicate_indices + if batch_size is None: + batch_size = 500 if self.binary else 5000 + duplicate_indices = process_batches_cy( + doc_embeddings, threshold, batch_size, self.vector_similarity + ) + unique_docs = [ doc for idx, doc in enumerate(docs) if idx not in duplicate_indices ]
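
A quick sanity-check sketch of the new Cython-backed paths (illustrative only: it assumes a local build as in the CI workflow above, the package-level `WordLlama.load()` entry point shown in the README, and made-up toy inputs):

```python
import numpy as np

from wordllama import WordLlama
from wordllama.algorithms import hamming_distance, kmeans_clustering
from wordllama.algorithms.splitter import split_sentences

# deduplicate() now picks its own batch size (500 for binary embeddings,
# 5000 for dense ones) and delegates the pairwise pass to process_batches_cy.
wl = WordLlama.load()
docs = ["the cat sat on the mat", "the cat sat on a mat", "something unrelated"]
print(wl.deduplicate(docs, threshold=0.9))

# Binarized embeddings are packed into uint32 words, so hamming_distance takes
# 2-D uint32 arrays and returns an int32 matrix of bit-count distances.
a = np.random.randint(0, 2**32, size=(4, 8), dtype=np.uint32)
b = np.random.randint(0, 2**32, size=(3, 8), dtype=np.uint32)
print(hamming_distance(a, b).shape)  # (4, 3)

# kmeans_clustering keeps its public signature; only the distance and
# centroid-update inner loops moved into the typed Cython helpers, which
# expect float64 embeddings. Per its docstring it returns (labels, losses).
embeddings = np.random.rand(100, 64).astype(np.float64)
labels, losses = kmeans_clustering(embeddings, k=5)

# The sentence splitter is a plain Cython function over unicode text.
print(split_sentences("First sentence. Second one! And a third?"))
```

Note that `hamming_distance` and `process_batches_cy` are re-exported from `wordllama.algorithms` (see the `__init__.py` change above), so `inference.py` only needs the package-level import.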