Merge pull request #39 from dleemiller/numpy-bitwise-count
Numpy bitwise count
dleemiller authored Oct 28, 2024
2 parents 0ae94a6 + aac96c1 commit dd4cad8
Showing 7 changed files with 51 additions and 95 deletions.
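
The core of this change swaps a hand-rolled C popcount in wordllama/algorithms/vector_similarity.pyx for NumPy 2.0's np.bitwise_count ufunc, which is why numpy>=2 is pinned below. A minimal sketch of that ufunc, assuming NumPy >= 2.0 is installed:

import numpy as np

x = np.array([0b1011, 0xFFFF], dtype=np.uint64)
print(np.bitwise_count(x))  # element-wise population count -> [ 3 16] (uint8)
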
8 changes: 4 additions & 4 deletions pyproject.toml
@@ -1,17 +1,17 @@
 [build-system]
-requires = ["setuptools", "wheel", "setuptools_scm[toml]", "Cython", "numpy"]
+requires = ["setuptools", "wheel", "setuptools_scm[toml]", "Cython", "numpy>=2"]
 build-backend = "setuptools.build_meta"
 
 [project]
 name = "wordllama"
 dynamic = ["version"]
-description = "WordLlama Embedding Utility"
+description = "WordLlama NLP Utility"
 readme = { file = "README.md", content-type = "text/markdown" }
 license = { file = "LICENSE" }
-requires-python = ">=3.8"
+requires-python = ">=3.9"
 authors = [{ name = "Lee Miller", email = "[email protected]" }]
 dependencies = [
-    "numpy",
+    "numpy>=2",
     "safetensors",
     "tokenizers",
     "toml",
44 changes: 10 additions & 34 deletions setup.py
@@ -5,31 +5,9 @@
 
 numpy_include = np.get_include()
 
-extra_compile_args = []
+extra_compile_args = ["-O3", "-ffast-math"]
 extra_link_args = []
 
-if platform.system() == "Darwin":
-    if platform.machine() == "arm64":
-        extra_compile_args.extend(["-arch", "arm64", "-O3", "-ffast-math"])
-        extra_link_args.extend(["-arch", "arm64"])
-    else:
-        extra_compile_args.extend(["-arch", "x86_64", "-O3", "-ffast-math"])
-        extra_link_args.extend(["-arch", "x86_64"])
-elif platform.system() == "Windows":
-    extra_compile_args.extend(["/O2"])
-else:  # Linux and others
-    if platform.machine().startswith("arm"):
-        if platform.architecture()[0] == "32bit":
-            extra_compile_args.extend(["-march=armv7-a", "-mfpu=neon"])
-            extra_link_args.extend(["-march=armv7-a", "-mfpu=neon"])
-        else:  # 64-bit ARM
-            extra_compile_args.extend(["-march=armv8-a"])
-            extra_link_args.extend(["-march=armv8-a"])
-    elif platform.machine() in ["x86_64", "AMD64"]:
-        extra_compile_args.extend(["-march=native", "-mpopcnt"])
-        extra_link_args.extend(["-march=native", "-mpopcnt"])
-
-extra_compile_args.extend(["-O3", "-ffast-math"])
+define_macros = [("NPY_NO_DEPRECATED_API", "NPY_2_0_API_VERSION")]
 
 extensions = [
     Extension(
@@ -42,15 +20,15 @@
         "wordllama.algorithms.deduplicate_helpers",
         ["wordllama/algorithms/deduplicate_helpers.pyx"],
         include_dirs=[numpy_include],
-        define_macros=[("NPY_NO_DEPRECATED_API", "NPY_1_7_API_VERSION")],
+        define_macros=define_macros,
         extra_compile_args=extra_compile_args,
         extra_link_args=extra_link_args,
     ),
     Extension(
         "wordllama.algorithms.kmeans",
         ["wordllama/algorithms/kmeans.pyx"],
         include_dirs=[numpy_include],
-        define_macros=[("NPY_NO_DEPRECATED_API", "NPY_1_7_API_VERSION")],
+        define_macros=define_macros,
         extra_compile_args=extra_compile_args,
         extra_link_args=extra_link_args,
     ),
@@ -61,33 +39,31 @@
         define_macros=[],
         extra_compile_args=extra_compile_args,
         extra_link_args=extra_link_args,
-        language="c++"
+        language="c++",
     ),
     Extension(
         "wordllama.algorithms.find_local_minima",
         ["wordllama/algorithms/find_local_minima.pyx"],
         include_dirs=[numpy_include],
-        define_macros=[("NPY_NO_DEPRECATED_API", "NPY_1_7_API_VERSION")],
+        define_macros=define_macros,
         extra_compile_args=extra_compile_args,
         extra_link_args=extra_link_args,
-        language="c++"
+        language="c++",
     ),
     Extension(
         "wordllama.algorithms.vector_similarity",
         ["wordllama/algorithms/vector_similarity.pyx"],
         include_dirs=[numpy_include],
-        define_macros=[("NPY_NO_DEPRECATED_API", "NPY_1_7_API_VERSION")],
+        define_macros=define_macros,
         extra_compile_args=extra_compile_args,
         extra_link_args=extra_link_args,
-    )
-
-
+    ),
 ]
 
 setup(
     name="Embedding and lightweight NLP utility.",
     use_scm_version=True,
-    setup_requires=['setuptools_scm'],
+    setup_requires=["setuptools_scm"],
     ext_modules=cythonize(
         extensions,
         compiler_directives={
10 changes: 9 additions & 1 deletion tests/test_functional.py
@@ -3,7 +3,15 @@
 
 
 class TestFunctional(unittest.TestCase):
-
     def test_function_clustering(self):
         wl = WordLlama.load()
         wl.cluster(["a", "b"], k=2)
+
+    def test_function_similarity(self):
+        wl = WordLlama.load()
+        wl.similarity("a", "b")
+
+    def test_function_similarity_binary(self):
+        wl = WordLlama.load()
+        wl.binary = True
+        wl.similarity("a", "b")
1 change: 0 additions & 1 deletion wordllama/algorithms/deduplicate_helpers.pyx
@@ -1,5 +1,4 @@
 # cython: language_level=3, boundscheck=False, wraparound=False, cdivision=True
-# distutils: define_macros=NPY_NO_DEPRECATED_API=NPY_1_7_API_VERSION
 
 import numpy as np
 cimport numpy as np
1 change: 0 additions & 1 deletion wordllama/algorithms/kmeans.pyx
@@ -1,5 +1,4 @@
 # cython: language_level=3, boundscheck=False, wraparound=False, cdivision=True, fastmath=True
-# distutils: define_macros=NPY_NO_DEPRECATED_API=NPY_1_7_API_VERSION
 
 import numpy as np
 from numpy.random import RandomState
69 changes: 20 additions & 49 deletions wordllama/algorithms/vector_similarity.pyx
@@ -1,5 +1,4 @@
 # cython: language_level=3, boundscheck=False, wraparound=False, cdivision=True, nonecheck=False
-# distutils: define_macros=NPY_NO_DEPRECATED_API=NPY_1_7_API_VERSION
 
 import numpy as np
 cimport numpy as np
@@ -12,36 +11,8 @@ from numpy cimport (
 
 np.import_array()
 
-cdef extern from *:
-    """
-    #if defined(__GNUC__) && (defined(__x86_64__) || defined(__i386__))
-    #include <x86intrin.h>
-    static inline int popcount(uint64_t x) {
-        return __builtin_popcountll(x);
-    }
-    #elif defined(__GNUC__) && (defined(__ARM_NEON) || defined(__aarch64__))
-    #include <arm_neon.h>
-    static inline int popcount(uint64_t x) {
-        // No direct 64-bit popcount in NEON, need to split into two 32-bit parts
-        uint32_t lo = (uint32_t)x;
-        uint32_t hi = (uint32_t)(x >> 32);
-        return vaddv_u8(vcnt_u8(vcreate_u8(lo))) + vaddv_u8(vcnt_u8(vcreate_u8(hi)));
-    }
-    #else
-    static inline int popcount(uint64_t x) {
-        x = x - ((x >> 1) & 0x5555555555555555);
-        x = (x & 0x3333333333333333) + ((x >> 2) & 0x3333333333333333);
-        x = (x + (x >> 4)) & 0x0F0F0F0F0F0F0F0F;
-        x = x + (x >> 8);
-        x = x + (x >> 16);
-        x = x + (x >> 32);
-        return x & 0x0000007F;
-    }
-    #endif
-    """
-    int popcount(uint64_t x) nogil
-
-cpdef object hamming_distance(object a, object b):
+cpdef object hamming_distance(np.ndarray[np.uint64_t, ndim=2, mode='c'] a,
+                              np.ndarray[np.uint64_t, ndim=2, mode='c'] b):
     """
     Compute the Hamming distance between two arrays of binary vectors.
@@ -52,32 +23,32 @@ cpdef object hamming_distance(object a, object b):
     Returns:
         np.ndarray: A 2D array containing the Hamming distances.
     """
-    cdef Py_ssize_t i, j, k
-    cdef int dist
+    cdef Py_ssize_t i
     cdef Py_ssize_t n = a.shape[0]
     cdef Py_ssize_t m = b.shape[0]
-    cdef Py_ssize_t width = a.shape[1]
 
-    # Allocate distance array
-    distance = np.zeros((n, m), dtype=np.uint32)
-
-    # Create a typed memoryview
-    cdef uint32_t[:, :] distance_view = distance
-
-    # Ensure contiguous
 
     if not a.flags.c_contiguous or not b.flags.c_contiguous:
         raise ValueError("Input arrays must be C-contiguous")
 
-    # Create typed memoryviews
-    cdef uint64_t[:, :] a_view = a
-    cdef uint64_t[:, :] b_view = b
+    cdef np.ndarray[np.uint32_t, ndim=2, mode='c'] distance = np.zeros((n, m), dtype=np.uint32)
+    cdef np.ndarray[np.uint64_t, ndim=1] a_row
+    cdef np.ndarray[np.uint64_t, ndim=2] xor_result
+    cdef np.ndarray[np.uint8_t, ndim=2] popcounts
+    cdef np.ndarray[np.uint32_t, ndim=1] distances_i
 
     for i in range(n):
-        for j in range(m):
-            dist = 0
-            for k in range(width):
-                dist += popcount(a_view[i, k] ^ b_view[j, k])
-            distance_view[i, j] = dist
+        a_row = a[i, :]
+
+        # XOR 'a_row' and all rows in 'b'
+        xor_result = np.bitwise_xor(a_row[np.newaxis, :], b)
+
+        # Compute popcounts
+        popcounts = np.bitwise_count(xor_result)
+
+        # Sum to get Hamming distance
+        distances_i = np.sum(popcounts, axis=1, dtype=np.uint32)
+        distance[i, :] = distances_i
 
     return distance
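
For readers outside Cython, a pure-NumPy sketch of the same approach (hamming_distance_np is a hypothetical name, not part of the package). The Cython kernel above loops over rows of a precisely so it never materializes the full (n, m, width) XOR array; this sketch trades that memory for brevity:

import numpy as np  # numpy>=2 for np.bitwise_count

def hamming_distance_np(a, b):
    # a: (n, w) uint64, b: (m, w) uint64 -> (n, m) uint32 Hamming distances
    xor = np.bitwise_xor(a[:, None, :], b[None, :, :])  # broadcast to (n, m, w)
    return np.bitwise_count(xor).sum(axis=-1, dtype=np.uint32)
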

Expand Down
13 changes: 8 additions & 5 deletions wordllama/inference.py
Expand Up @@ -94,7 +94,10 @@ def embed(
num_texts = len(texts)
embedding_dim = self.embedding.shape[1]
np_type = np.float32 if not self.binary else np.uint64
pooled_embeddings = np.empty((num_texts, embedding_dim), dtype=np_type)
pooled_embeddings = np.empty(
(num_texts, embedding_dim if not self.binary else embedding_dim // 64),
dtype=np_type,
)

for i in range(0, num_texts, batch_size):
chunk = texts[i : i + batch_size]
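
The binary branch allocates embedding_dim // 64 uint64 columns per text, consistent with packing one bit per embedding dimension into 64-bit words. A hypothetical packing sketch (pack_sign_bits and the sign-threshold rule are assumptions for illustration, not the library's exact code):

import numpy as np

def pack_sign_bits(emb):
    # emb: (num_texts, embedding_dim) float32, embedding_dim divisible by 64.
    # Threshold each dimension to one bit, pack 8 bits per byte, then view each
    # run of 8 bytes as a uint64 -> shape (num_texts, embedding_dim // 64).
    bits = (emb > 0).astype(np.uint8)
    return np.packbits(bits, axis=1).view(np.uint64)

Hamming distance is invariant to the exact bit layout, so any consistent packing works with the vector_similarity kernel above.
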
@@ -209,10 +212,10 @@ def rank(
 
     def deduplicate(
         self,
-        docs: List[str],
-        threshold: float = 0.9,
-        return_indices: bool = False,
-        batch_size: Optional[int] = None
+        docs: List[str],
+        threshold: float = 0.9,
+        return_indices: bool = False,
+        batch_size: Optional[int] = None,
     ) -> List[Union[str, int]]:
         """Deduplicate documents based on a similarity threshold.
