-
Notifications
You must be signed in to change notification settings - Fork 51
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #6 from dleemiller/feature/cython-extensions
Feature/cython extensions
- Loading branch information
Showing
15 changed files
with
377 additions
and
92 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
name: CI | ||
|
||
on: [push, pull_request] | ||
|
||
jobs: | ||
build: | ||
runs-on: ${{ matrix.os }} | ||
strategy: | ||
matrix: | ||
os: [ubuntu-latest, macos-latest, windows-latest] | ||
python-version: ['3.8', '3.9', '3.10', '3.11'] | ||
|
||
steps: | ||
- uses: actions/checkout@v2 | ||
|
||
- name: Set up Python ${{ matrix.python-version }} | ||
uses: actions/setup-python@v2 | ||
with: | ||
python-version: ${{ matrix.python-version }} | ||
|
||
- name: Install dependencies | ||
run: | | ||
python -m pip install --upgrade pip | ||
pip install setuptools wheel setuptools_scm Cython numpy | ||
- name: Build and install package | ||
run: | | ||
python setup.py build_ext --inplace | ||
python -m pip install . | ||
- name: Test installation | ||
run: | | ||
python -m unittest discover -s tests | ||
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,35 @@ | ||
name: Publish to PyPI | ||
|
||
on: | ||
push: | ||
tags: | ||
- 'v*.*.*' | ||
|
||
jobs: | ||
build-and-publish: | ||
runs-on: ubuntu-latest | ||
|
||
steps: | ||
- uses: actions/checkout@v2 | ||
|
||
- name: Set up Python | ||
uses: actions/setup-python@v2 | ||
with: | ||
python-version: '3.9' | ||
|
||
- name: Install dependencies | ||
run: | | ||
python -m pip install --upgrade pip | ||
pip install setuptools wheel twine setuptools_scm Cython numpy | ||
- name: Build package | ||
run: | | ||
python setup.py sdist bdist_wheel | ||
- name: Publish package to PyPI | ||
env: | ||
TWINE_USERNAME: __token__ | ||
TWINE_PASSWORD: ${{ secrets.PYPI_API_TOKEN }} | ||
run: | | ||
twine upload dist/* | ||
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,5 @@ | ||
include LICENSE | ||
include README.md | ||
recursive-include wordllama *.py *.toml *.json | ||
include wordllama/algorithms/*.pyx | ||
include wordllama/algorithms/*.pxd |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,81 @@ | ||
from setuptools import setup, Extension | ||
from Cython.Build import cythonize | ||
import numpy as np | ||
import platform | ||
import sys | ||
|
||
numpy_include = np.get_include() | ||
|
||
extra_compile_args = [] | ||
extra_link_args = [] | ||
|
||
if platform.system() == "Darwin": | ||
if platform.machine() == "arm64": | ||
extra_compile_args.extend(["-arch", "arm64", "-O3", "-ffast-math"]) | ||
extra_link_args.extend(["-arch", "arm64"]) | ||
else: | ||
extra_compile_args.extend(["-arch", "x86_64", "-O3", "-ffast-math"]) | ||
extra_link_args.extend(["-arch", "x86_64"]) | ||
elif platform.system() == "Windows": | ||
extra_compile_args.extend(["/O2"]) | ||
else: # Linux and others | ||
if platform.machine().startswith("arm"): | ||
if platform.architecture()[0] == "32bit": | ||
extra_compile_args.extend(["-march=armv7-a", "-mfpu=neon"]) | ||
extra_link_args.extend(["-march=armv7-a", "-mfpu=neon"]) | ||
else: # 64-bit ARM | ||
extra_compile_args.extend(["-march=armv8-a"]) | ||
extra_link_args.extend(["-march=armv8-a"]) | ||
elif platform.machine() in ["x86_64", "AMD64"]: | ||
extra_compile_args.extend(["-march=native", "-mpopcnt"]) | ||
extra_link_args.extend(["-march=native", "-mpopcnt"]) | ||
|
||
extra_compile_args.extend(["-O3", "-ffast-math"]) | ||
|
||
extensions = [ | ||
Extension( | ||
"wordllama.algorithms.splitter", | ||
["wordllama/algorithms/splitter.pyx"], | ||
extra_compile_args=extra_compile_args, | ||
extra_link_args=extra_link_args, | ||
), | ||
Extension( | ||
"wordllama.algorithms.hamming_distance", | ||
["wordllama/algorithms/hamming_distance.pyx"], | ||
include_dirs=[numpy_include], | ||
define_macros=[("NPY_NO_DEPRECATED_API", "NPY_1_7_API_VERSION")], | ||
extra_compile_args=extra_compile_args, | ||
extra_link_args=extra_link_args, | ||
), | ||
Extension( | ||
"wordllama.algorithms.deduplicate_helpers", | ||
["wordllama/algorithms/deduplicate_helpers.pyx"], | ||
include_dirs=[numpy_include], | ||
define_macros=[("NPY_NO_DEPRECATED_API", "NPY_1_7_API_VERSION")], | ||
extra_compile_args=extra_compile_args, | ||
extra_link_args=extra_link_args, | ||
), | ||
Extension( | ||
"wordllama.algorithms.kmeans_helpers", | ||
["wordllama/algorithms/kmeans_helpers.pyx"], | ||
include_dirs=[numpy_include], | ||
define_macros=[("NPY_NO_DEPRECATED_API", "NPY_1_7_API_VERSION")], | ||
extra_compile_args=extra_compile_args, | ||
extra_link_args=extra_link_args, | ||
), | ||
] | ||
|
||
setup( | ||
name="Text Processing Tools", | ||
ext_modules=cythonize( | ||
extensions, | ||
compiler_directives={ | ||
"language_level": "3", | ||
"boundscheck": False, | ||
"wraparound": False, | ||
}, | ||
annotate=True, | ||
), | ||
zip_safe=False, | ||
install_requires=["numpy"], | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1,3 @@ | ||
from .kmeans import kmeans_clustering | ||
from .hamming_distance import hamming_distance | ||
from .deduplicate_helpers import process_batches_cy |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,42 @@ | ||
# cython: language_level=3, boundscheck=False, wraparound=False | ||
# distutils: define_macros=NPY_NO_DEPRECATED_API=NPY_1_7_API_VERSION | ||
import numpy as np | ||
cimport numpy as np | ||
from numpy cimport PyArray_DIMS | ||
|
||
ctypedef fused embedding_dtype: | ||
np.uint32_t | ||
np.float32_t | ||
np.float64_t | ||
|
||
def process_batches_cy(np.ndarray[embedding_dtype, ndim=2] doc_embeddings, | ||
double threshold, int batch_size, vector_similarity): | ||
cdef int num_embeddings = PyArray_DIMS(doc_embeddings)[0] | ||
cdef set duplicate_indices = set() | ||
cdef set seen_docs = set() | ||
cdef int i, j, start_i, end_i, start_j, end_j | ||
cdef np.ndarray[embedding_dtype, ndim=2] batch_i, batch_j | ||
cdef np.ndarray[double, ndim=2] sim_matrix | ||
cdef np.ndarray[np.int64_t, ndim=2] sim_indices | ||
cdef int doc_idx_1, doc_idx_2 | ||
|
||
for i in range(0, num_embeddings, batch_size): | ||
start_i = i | ||
end_i = min(i + batch_size, num_embeddings) | ||
batch_i = doc_embeddings[start_i:end_i] | ||
for j in range(i, num_embeddings, batch_size): | ||
start_j = j | ||
end_j = min(j + batch_size, num_embeddings) | ||
batch_j = doc_embeddings[start_j:end_j] | ||
sim_matrix = vector_similarity(batch_i, batch_j) | ||
sim_indices = np.argwhere(sim_matrix > threshold) | ||
for idx in sim_indices: | ||
if idx[0] + start_i != idx[1] + start_j: | ||
doc_idx_1 = idx[0] + start_i | ||
doc_idx_2 = idx[1] + start_j | ||
if doc_idx_2 not in seen_docs: | ||
seen_docs.add(doc_idx_1) | ||
duplicate_indices.add(doc_idx_2) | ||
|
||
return duplicate_indices | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,53 @@ | ||
# cython: language_level=3, boundscheck=False, wraparound=False | ||
# distutils: define_macros=NPY_NO_DEPRECATED_API=NPY_1_7_API_VERSION | ||
|
||
import numpy as np | ||
cimport numpy as np | ||
from numpy cimport int32_t, uint32_t, uint8_t, PyArrayObject, PyArray_DIMS | ||
from libc.stdint cimport uint32_t, uint8_t | ||
|
||
np.import_array() | ||
|
||
cdef extern from *: | ||
""" | ||
#if defined(__GNUC__) && (defined(__x86_64__) || defined(__i386__)) | ||
#include <x86intrin.h> | ||
static inline int popcount(uint32_t x) { | ||
return __builtin_popcount(x); | ||
} | ||
#elif defined(__GNUC__) && (defined(__ARM_NEON) || defined(__aarch64__)) | ||
#include <arm_neon.h> | ||
static inline int popcount(uint32_t x) { | ||
return vaddv_u8(vcnt_u8(vcreate_u8(x))); | ||
} | ||
#else | ||
static inline int popcount(uint32_t x) { | ||
x = x - ((x >> 1) & 0x55555555); | ||
x = (x & 0x33333333) + ((x >> 2) & 0x33333333); | ||
x = (x + (x >> 4)) & 0x0F0F0F0F; | ||
x = x + (x >> 8); | ||
x = x + (x >> 16); | ||
return x & 0x0000003F; | ||
} | ||
#endif | ||
""" | ||
int popcount(uint32_t x) nogil | ||
|
||
cpdef np.ndarray[int32_t, ndim=2] hamming_distance(np.ndarray[uint32_t, ndim=2] a, np.ndarray[uint32_t, ndim=2] b): | ||
cdef Py_ssize_t i, j, k | ||
cdef int dist | ||
cdef Py_ssize_t n = PyArray_DIMS(a)[0] | ||
cdef Py_ssize_t m = PyArray_DIMS(b)[0] | ||
cdef Py_ssize_t width = PyArray_DIMS(a)[1] | ||
cdef np.ndarray[int32_t, ndim=2] distance = np.zeros((n, m), dtype=np.int32) | ||
|
||
# Calculate Hamming distance | ||
for i in range(n): | ||
for j in range(m): | ||
dist = 0 | ||
for k in range(width): | ||
dist += popcount(a[i, k] ^ b[j, k]) | ||
distance[i, j] = dist | ||
|
||
return distance | ||
|
Oops, something went wrong.