
Commit

Merge pull request #6 from dleemiller/feature/cython-extensions
Feature/cython extensions
dleemiller authored Jul 28, 2024
2 parents 194b1a3 + 54526c9 commit 9046060
Showing 15 changed files with 377 additions and 92 deletions.
34 changes: 34 additions & 0 deletions .github/workflows/ci.yml
@@ -0,0 +1,34 @@
name: CI

on: [push, pull_request]

jobs:
  build:
    runs-on: ${{ matrix.os }}
    strategy:
      matrix:
        os: [ubuntu-latest, macos-latest, windows-latest]
        python-version: ['3.8', '3.9', '3.10', '3.11']

    steps:
    - uses: actions/checkout@v2

    - name: Set up Python ${{ matrix.python-version }}
      uses: actions/setup-python@v2
      with:
        python-version: ${{ matrix.python-version }}

    - name: Install dependencies
      run: |
        python -m pip install --upgrade pip
        pip install setuptools wheel setuptools_scm Cython numpy
    - name: Build and install package
      run: |
        python setup.py build_ext --inplace
        python -m pip install .
    - name: Test installation
      run: |
        python -m unittest discover -s tests
35 changes: 35 additions & 0 deletions .github/workflows/publish.yml
@@ -0,0 +1,35 @@
name: Publish to PyPI

on:
  push:
    tags:
      - 'v*.*.*'

jobs:
  build-and-publish:
    runs-on: ubuntu-latest

    steps:
    - uses: actions/checkout@v2

    - name: Set up Python
      uses: actions/setup-python@v2
      with:
        python-version: '3.9'

    - name: Install dependencies
      run: |
        python -m pip install --upgrade pip
        pip install setuptools wheel twine setuptools_scm Cython numpy
    - name: Build package
      run: |
        python setup.py sdist bdist_wheel
    - name: Publish package to PyPI
      env:
        TWINE_USERNAME: __token__
        TWINE_PASSWORD: ${{ secrets.PYPI_API_TOKEN }}
      run: |
        twine upload dist/*
2 changes: 2 additions & 0 deletions MANIFEST.in
@@ -1,3 +1,5 @@
include LICENSE
include README.md
recursive-include wordllama *.py *.toml *.json
include wordllama/algorithms/*.pyx
include wordllama/algorithms/*.pxd
17 changes: 9 additions & 8 deletions README.md
@@ -1,5 +1,5 @@

# Word Llama
# WordLlama

The power of 15 trillion tokens of training, extracted, flogged and minimized into a cute little package for word embedding.

@@ -59,7 +59,7 @@ wl.topk(query, candidates, k=3) # return topk strings based on query

## What is it?

WordLlama is a word embedding model that recycles components from large language models (LLMs) to create efficient and compact word representations (such as GloVe, Word2Vec or FastText).
WordLlama is an NLP utility and word embedding model that recycles components from large language models (LLMs) to create efficient and compact word representations (such as GloVe, Word2Vec or FastText).
WordLlama begins by extracting the token embedding codebook from a state-of-the-art LLM (e.g., Llama3 70B), and training a small context-less model in a general-purpose embedding framework.

WordLlama outperforms word models such as GloVe 300d on all MTEB benchmarks, while being substantially smaller in size (**16MB default model** @ 256-dim vs. >2GB).
@@ -96,8 +96,6 @@ Because of its fast and portable size, it makes a good "Swiss-Army Knife" utilit
The [l2_supercat](https://huggingface.co/dleemiller/word-llama-l2-supercat) is a Llama2-vocabulary model. To train this model, I concatenated codebooks from several models, including Llama2 70B and phi3 medium (after removing additional special tokens).
Because several models have used the Llama2 tokenizer, their codebooks can be concatenated and trained together. Performance of the resulting model is comparable to training the Llama3 70B codebook, while being 4x smaller (32k vs 128k vocabulary).

I anticipate the best results will come from training using the Llama3 405B codebook, when released.

## Embed Text

Here’s how you can load pre-trained embeddings and use them to embed text:
@@ -134,7 +132,7 @@ ranked_docs = wl.rank("i went to the car", ["van", "truck"])
wl.binary = False # turn off hamming and use cosine

# load a different model class
wl = WordLlama.load(config="llama3_400B", dim=1024) # downloads model from HF
wl = WordLlama.load(config="l3_supercat", dim=1024) # downloads model from HF
```
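
A minimal usage sketch pulling together the API calls referenced in this README diff (`WordLlama.load`, `wl.topk`, `wl.rank`) and the `deduplicate` method exercised in `tests/test_inference.py`; the no-argument `load()` call and the `embed` method name are assumptions based on the surrounding README text, not something this commit shows directly:

```python
from wordllama import WordLlama

# Load the default 256-dim model (assumed default; the snippet above shows
# overrides such as config="l3_supercat", dim=1024)
wl = WordLlama.load()

# Embed a list of strings (method name assumed from the "Embed Text" section)
embeddings = wl.embed(["the quick brown fox", "jumps over the lazy dog"])

# Rank candidate documents against a query, as in the snippet above
ranked_docs = wl.rank("i went to the car", ["van", "truck", "bicycle"])

# Fuzzy-deduplicate documents by cosine similarity, as in tests/test_inference.py
unique_docs = wl.deduplicate(
    ["doc1", "doc1", "a second document that is different"], threshold=0.9
)
```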

## Training Notes
@@ -145,8 +143,11 @@ L2 Supercat was trained using a batch size of 512 on a single A100 for 12 hours.

## Roadmap

- Test distillation training from a larger embedding model
- Retrain on llama3 405B (waiting on release...), concat with llama guard 2, llama3 70B
- Working on adding inference features:
  - Semantic text splitting
- Add example notebooks
  - DSPy evaluators
  - RAG pipelines

## Extracting Token Embeddings

@@ -180,7 +181,7 @@ If you use WordLlama in your research or project, please consider citing it as f
title = {WordLlama: Recycled Token Embeddings from Large Language Models},
year = {2024},
url = {https://github.com/dleemiller/wordllama},
version = {0.2.1}
version = {0.2.3}
}
```

5 changes: 3 additions & 2 deletions pyproject.toml
@@ -1,5 +1,5 @@
[build-system]
requires = ["setuptools>=42", "wheel", "setuptools_scm"]
requires = ["setuptools>=42", "wheel", "setuptools_scm", "Cython", "numpy"]
build-backend = "setuptools.build_meta"

[project]
@@ -37,7 +37,7 @@ Repository = "https://github.com/dleemiller/WordLlama"
packages = ["wordllama"]

[tool.setuptools.package-data]
wordllama = ["**/*.toml", "tokenizers/*.json", "weights/*.safetensors"]
wordllama = ["algorithms/*.so", "algorithms/*.pyd", "**/*.pyx", "**/*.pyd", "**/*.toml", "tokenizers/*.json", "weights/*.safetensors"]

[tool.setuptools.dynamic]
classifiers = { file = "classifiers.txt" }
@@ -46,3 +46,4 @@ classifiers = { file = "classifiers.txt" }
write_to = "wordllama/_version.py"
version_scheme = "post-release"
local_scheme = "no-local-version"

81 changes: 81 additions & 0 deletions setup.py
@@ -0,0 +1,81 @@
from setuptools import setup, Extension
from Cython.Build import cythonize
import numpy as np
import platform
import sys

numpy_include = np.get_include()

extra_compile_args = []
extra_link_args = []

if platform.system() == "Darwin":
    if platform.machine() == "arm64":
        extra_compile_args.extend(["-arch", "arm64", "-O3", "-ffast-math"])
        extra_link_args.extend(["-arch", "arm64"])
    else:
        extra_compile_args.extend(["-arch", "x86_64", "-O3", "-ffast-math"])
        extra_link_args.extend(["-arch", "x86_64"])
elif platform.system() == "Windows":
    extra_compile_args.extend(["/O2"])
else:  # Linux and others
    if platform.machine().startswith("arm"):
        if platform.architecture()[0] == "32bit":
            extra_compile_args.extend(["-march=armv7-a", "-mfpu=neon"])
            extra_link_args.extend(["-march=armv7-a", "-mfpu=neon"])
        else:  # 64-bit ARM
            extra_compile_args.extend(["-march=armv8-a"])
            extra_link_args.extend(["-march=armv8-a"])
    elif platform.machine() in ["x86_64", "AMD64"]:
        extra_compile_args.extend(["-march=native", "-mpopcnt"])
        extra_link_args.extend(["-march=native", "-mpopcnt"])

    extra_compile_args.extend(["-O3", "-ffast-math"])

extensions = [
    Extension(
        "wordllama.algorithms.splitter",
        ["wordllama/algorithms/splitter.pyx"],
        extra_compile_args=extra_compile_args,
        extra_link_args=extra_link_args,
    ),
    Extension(
        "wordllama.algorithms.hamming_distance",
        ["wordllama/algorithms/hamming_distance.pyx"],
        include_dirs=[numpy_include],
        define_macros=[("NPY_NO_DEPRECATED_API", "NPY_1_7_API_VERSION")],
        extra_compile_args=extra_compile_args,
        extra_link_args=extra_link_args,
    ),
    Extension(
        "wordllama.algorithms.deduplicate_helpers",
        ["wordllama/algorithms/deduplicate_helpers.pyx"],
        include_dirs=[numpy_include],
        define_macros=[("NPY_NO_DEPRECATED_API", "NPY_1_7_API_VERSION")],
        extra_compile_args=extra_compile_args,
        extra_link_args=extra_link_args,
    ),
    Extension(
        "wordllama.algorithms.kmeans_helpers",
        ["wordllama/algorithms/kmeans_helpers.pyx"],
        include_dirs=[numpy_include],
        define_macros=[("NPY_NO_DEPRECATED_API", "NPY_1_7_API_VERSION")],
        extra_compile_args=extra_compile_args,
        extra_link_args=extra_link_args,
    ),
]

setup(
    name="Text Processing Tools",
    ext_modules=cythonize(
        extensions,
        compiler_directives={
            "language_level": "3",
            "boundscheck": False,
            "wraparound": False,
        },
        annotate=True,
    ),
    zip_safe=False,
    install_requires=["numpy"],
)
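
A quick post-build smoke check (a sketch, not part of this commit): the import names come from the updated `wordllama/algorithms/__init__.py` below, and the check assumes `python setup.py build_ext --inplace` has already been run, as in the CI workflow above.

```python
# Hypothetical smoke test: confirm the compiled extensions import after an
# in-place build (python setup.py build_ext --inplace).
from wordllama.algorithms import (
    kmeans_clustering,    # k-means entry point re-exported from .kmeans
    hamming_distance,     # compiled popcount-based binary distance
    process_batches_cy,   # compiled deduplication helper
)

print(hamming_distance.__module__, process_batches_cy.__module__)
```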
4 changes: 2 additions & 2 deletions tests/test_inference.py
@@ -82,11 +82,11 @@ def setUp(self, mock_tokenizer):
        return_value=np.array([[0.1] * 64, [0.1] * 64, np.random.rand(64), [0.1] * 64]),
    )
    def test_deduplicate_cosine(self, mock_embed):
        docs = ["doc1", "doc1_dup", "doc2", "doc1_dup2"]
        docs = ["doc1", "doc1_dup", "a second document that is different", "doc1_dup2"]
        deduplicated_docs = self.model.deduplicate(docs, threshold=0.9)
        self.assertEqual(len(deduplicated_docs), 2)
        self.assertIn("doc1", deduplicated_docs)
        self.assertIn("doc2", deduplicated_docs)
        self.assertIn("a second document that is different", deduplicated_docs)

    @patch.object(
        WordLlamaInference,
1 change: 0 additions & 1 deletion wordllama/adapters/binarizer.py
@@ -76,7 +76,6 @@ def approximate_function(x, o):


class Binarizer(nn.Module):

    def __init__(self, ste="tanh"):
        super().__init__()
        assert ste in ["ste", "reste", "stochastic", "tanh"]
2 changes: 2 additions & 0 deletions wordllama/algorithms/__init__.py
@@ -1 +1,3 @@
from .kmeans import kmeans_clustering
from .hamming_distance import hamming_distance
from .deduplicate_helpers import process_batches_cy
42 changes: 42 additions & 0 deletions wordllama/algorithms/deduplicate_helpers.pyx
@@ -0,0 +1,42 @@
# cython: language_level=3, boundscheck=False, wraparound=False
# distutils: define_macros=NPY_NO_DEPRECATED_API=NPY_1_7_API_VERSION
import numpy as np
cimport numpy as np
from numpy cimport PyArray_DIMS

ctypedef fused embedding_dtype:
np.uint32_t
np.float32_t
np.float64_t

def process_batches_cy(np.ndarray[embedding_dtype, ndim=2] doc_embeddings,
                       double threshold, int batch_size, vector_similarity):
    cdef int num_embeddings = PyArray_DIMS(doc_embeddings)[0]
    cdef set duplicate_indices = set()
    cdef set seen_docs = set()
    cdef int i, j, start_i, end_i, start_j, end_j
    cdef np.ndarray[embedding_dtype, ndim=2] batch_i, batch_j
    cdef np.ndarray[double, ndim=2] sim_matrix
    cdef np.ndarray[np.int64_t, ndim=2] sim_indices
    cdef int doc_idx_1, doc_idx_2

    for i in range(0, num_embeddings, batch_size):
        start_i = i
        end_i = min(i + batch_size, num_embeddings)
        batch_i = doc_embeddings[start_i:end_i]
        for j in range(i, num_embeddings, batch_size):
            start_j = j
            end_j = min(j + batch_size, num_embeddings)
            batch_j = doc_embeddings[start_j:end_j]
            sim_matrix = vector_similarity(batch_i, batch_j)
            sim_indices = np.argwhere(sim_matrix > threshold)
            for idx in sim_indices:
                if idx[0] + start_i != idx[1] + start_j:
                    doc_idx_1 = idx[0] + start_i
                    doc_idx_2 = idx[1] + start_j
                    if doc_idx_2 not in seen_docs:
                        seen_docs.add(doc_idx_1)
                        duplicate_indices.add(doc_idx_2)

    return duplicate_indices
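
A usage sketch for `process_batches_cy` (assumes the extension has been built; the NumPy `cosine_similarity` helper below is an illustrative stand-in for whatever `vector_similarity` callable the caller passes in, not the project's own function):

```python
import numpy as np

from wordllama.algorithms import process_batches_cy


def cosine_similarity(x, y):
    # Row-normalize both batches and return a float64 matrix, matching the
    # ndarray[double, ndim=2] declaration of sim_matrix above.
    x = x / np.linalg.norm(x, axis=1, keepdims=True)
    y = y / np.linalg.norm(y, axis=1, keepdims=True)
    return (x @ y.T).astype(np.float64)


rng = np.random.default_rng(0)
docs = rng.standard_normal((10, 64)).astype(np.float32)
docs[3] = docs[0]  # plant an exact duplicate of document 0

# Returns the set of indices flagged as duplicates; index 0 is kept as the
# first-seen copy, so only index 3 comes back here.
dups = process_batches_cy(docs, 0.9, 4, cosine_similarity)
print(sorted(dups))  # [3]
```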

53 changes: 53 additions & 0 deletions wordllama/algorithms/hamming_distance.pyx
@@ -0,0 +1,53 @@
# cython: language_level=3, boundscheck=False, wraparound=False
# distutils: define_macros=NPY_NO_DEPRECATED_API=NPY_1_7_API_VERSION

import numpy as np
cimport numpy as np
from numpy cimport int32_t, uint32_t, uint8_t, PyArrayObject, PyArray_DIMS
from libc.stdint cimport uint32_t, uint8_t

np.import_array()

cdef extern from *:
    """
    #if defined(__GNUC__) && (defined(__x86_64__) || defined(__i386__))
    #include <x86intrin.h>
    static inline int popcount(uint32_t x) {
        return __builtin_popcount(x);
    }
    #elif defined(__GNUC__) && (defined(__ARM_NEON) || defined(__aarch64__))
    #include <arm_neon.h>
    static inline int popcount(uint32_t x) {
        return vaddv_u8(vcnt_u8(vcreate_u8(x)));
    }
    #else
    static inline int popcount(uint32_t x) {
        x = x - ((x >> 1) & 0x55555555);
        x = (x & 0x33333333) + ((x >> 2) & 0x33333333);
        x = (x + (x >> 4)) & 0x0F0F0F0F;
        x = x + (x >> 8);
        x = x + (x >> 16);
        return x & 0x0000003F;
    }
    #endif
    """
    int popcount(uint32_t x) nogil

cpdef np.ndarray[int32_t, ndim=2] hamming_distance(np.ndarray[uint32_t, ndim=2] a, np.ndarray[uint32_t, ndim=2] b):
    cdef Py_ssize_t i, j, k
    cdef int dist
    cdef Py_ssize_t n = PyArray_DIMS(a)[0]
    cdef Py_ssize_t m = PyArray_DIMS(b)[0]
    cdef Py_ssize_t width = PyArray_DIMS(a)[1]
    cdef np.ndarray[int32_t, ndim=2] distance = np.zeros((n, m), dtype=np.int32)

    # Calculate Hamming distance
    for i in range(n):
        for j in range(m):
            dist = 0
            for k in range(width):
                dist += popcount(a[i, k] ^ b[j, k])
            distance[i, j] = dist

    return distance
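
A small sanity check for `hamming_distance` (a sketch; assumes the extension is compiled and importable via `wordllama.algorithms`). Each row is a binary code packed into `uint32` words, so a 64-bit code occupies two words:

```python
import numpy as np

from wordllama.algorithms import hamming_distance

# One 64-bit code in `a`, two in `b`, packed as two uint32 words per row.
a = np.array([[0xFFFFFFFF, 0x00000000]], dtype=np.uint32)
b = np.array([[0xFFFFFFFF, 0x00000000],
              [0x00000000, 0x00000000]], dtype=np.uint32)

# Row 0 of `a` matches row 0 of `b` exactly (distance 0) and differs from
# row 1 of `b` in the 32 set bits of the first word.
print(hamming_distance(a, b))  # [[ 0 32]]
```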
