Skip to content

Commit

Permalink
Fix: JIT compilation in Python
Browse files Browse the repository at this point in the history
  • Loading branch information
ashvardanian committed Mar 29, 2024
1 parent 4ac3509 commit a3287f1
Show file tree
Hide file tree
Showing 3 changed files with 96 additions and 89 deletions.
6 changes: 4 additions & 2 deletions .github/workflows/prerelease.yml
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,8 @@ jobs:
- name: Build Python
run: |
python -m pip install --upgrade pip
pip install pytest numpy
pip install pytest numpy numba cppyy
pip install --upgrade git+https://github.com/Maratyszcza/PeachPy
python -m pip install .
- name: Test Python
run: pytest python/scripts/ -s -x
Expand Down Expand Up @@ -142,7 +143,8 @@ jobs:
- name: Build Python
run: |
python -m pip install --upgrade pip
pip install pytest numpy
pip install pytest numpy numba cppyy
pip install --upgrade git+https://github.com/Maratyszcza/PeachPy
python -m pip install .
- name: Test Python
run: pytest python/scripts/ -s -x
Expand Down
130 changes: 70 additions & 60 deletions include/usearch/index_plugins.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -1318,15 +1318,17 @@ class metric_punned_t {

private:
/// In the generalized function API all the arguments are pointer-sized.
using punned_arg_t = std::size_t;
using uptr_t = std::size_t;
/// Distance function that takes two arrays and returns a scalar.
using metric_array_array_t = result_t (*)(uptr_t, uptr_t);
/// Distance function that takes two arrays and their length and returns a scalar.
using punned_ptr_t = result_t (*)(punned_arg_t, punned_arg_t, punned_arg_t);
/// Distance function callback, like `punned_ptr_t`, but depends on member variables.
using routed_ptr_t = result_t (metric_punned_t::*)(punned_arg_t, punned_arg_t, punned_arg_t) const;
using metric_array_array_size_t = result_t (*)(uptr_t, uptr_t, uptr_t);
/// Distance function callback, like `metric_array_array_size_t`, but depends on member variables.
using metric_rounted_t = result_t (metric_punned_t::*)(uptr_t, uptr_t) const;

routed_ptr_t routed_ptr_ = nullptr;
punned_ptr_t raw_ptr_ = nullptr;
punned_arg_t raw_size_ = 0;
metric_rounted_t metric_routed_ = nullptr;
uptr_t metric_ptr_ = 0;
uptr_t metric_size_arg_ = 0;

std::size_t dimensions_ = 0;
metric_kind_t metric_kind_ = metric_kind_t::unknown_k;
Expand All @@ -1343,7 +1345,7 @@ class metric_punned_t {
* ! This is the only relevant function in the object. Everything else is just dynamic dispatch logic.
*/
inline result_t operator()(byte_t const* a, byte_t const* b) const noexcept {
return (this->*routed_ptr_)(reinterpret_cast<punned_arg_t>(a), reinterpret_cast<punned_arg_t>(b), raw_size_);
return (this->*metric_routed_)(reinterpret_cast<uptr_t>(a), reinterpret_cast<uptr_t>(b));
}

inline metric_punned_t() noexcept = default;
Expand All @@ -1353,8 +1355,8 @@ class metric_punned_t {
std::size_t dimensions, //
metric_kind_t metric_kind = metric_kind_t::l2sq_k, //
scalar_kind_t scalar_kind = scalar_kind_t::f32_k) noexcept
: routed_ptr_(&metric_punned_t::invoke_autovec), raw_size_(dimensions), dimensions_(dimensions),
metric_kind_(metric_kind), scalar_kind_(scalar_kind) {
: metric_routed_(&metric_punned_t::invoke_array_array_size), metric_size_arg_(dimensions),
dimensions_(dimensions), metric_kind_(metric_kind), scalar_kind_(scalar_kind) {

#if USEARCH_USE_SIMSIMD
if (!configure_with_simsimd())
Expand All @@ -1364,19 +1366,22 @@ class metric_punned_t {
#endif

if (scalar_kind == scalar_kind_t::b1x8_k)
raw_size_ = divide_round_up<CHAR_BIT>(dimensions_);
metric_size_arg_ = divide_round_up<CHAR_BIT>(dimensions_);
}

inline metric_punned_t( //
std::size_t dimensions, //
std::uintptr_t metric_uintptr, metric_punned_signature_t signature, //
metric_kind_t metric_kind, //
scalar_kind_t scalar_kind) noexcept
: routed_ptr_(&metric_punned_t::invoke_autovec), raw_ptr_(reinterpret_cast<punned_ptr_t>(metric_uintptr)),
dimensions_(dimensions), metric_kind_(metric_kind), scalar_kind_(scalar_kind) {
: metric_routed_(signature == metric_punned_signature_t::array_array_k
? &metric_punned_t::invoke_array_array
: &metric_punned_t::invoke_array_array_size),
metric_ptr_(metric_uintptr), metric_size_arg_(dimensions), dimensions_(dimensions), metric_kind_(metric_kind),
scalar_kind_(scalar_kind) {

// We don't need to explicitly parse signature, as all of them are compatible.
(void)signature;
if (scalar_kind == scalar_kind_t::b1x8_k)
metric_size_arg_ = divide_round_up<CHAR_BIT>(dimensions_);
}

inline std::size_t dimensions() const noexcept { return dimensions_; }
Expand Down Expand Up @@ -1436,107 +1441,112 @@ class metric_punned_t {
if (simd_metric == nullptr)
return false;

std::memcpy(&raw_ptr_, &simd_metric, sizeof(simd_metric));
routed_ptr_ = metric_kind_ == metric_kind_t::ip_k
? reinterpret_cast<routed_ptr_t>(&metric_punned_t::invoke_simsimd_reverse)
: reinterpret_cast<routed_ptr_t>(&metric_punned_t::invoke_simsimd);
std::memcpy(&metric_ptr_, &simd_metric, sizeof(simd_metric));
metric_routed_ = metric_kind_ == metric_kind_t::ip_k
? reinterpret_cast<metric_rounted_t>(&metric_punned_t::invoke_simsimd_reverse)
: reinterpret_cast<metric_rounted_t>(&metric_punned_t::invoke_simsimd);
isa_kind_ = simd_kind;
return true;
}
bool configure_with_simsimd() noexcept {
static simsimd_capability_t static_capabilities = simsimd_capabilities();
return configure_with_simsimd(static_capabilities);
}
result_t invoke_simsimd(punned_arg_t a, punned_arg_t b, punned_arg_t size) const noexcept {
result_t invoke_simsimd(uptr_t a, uptr_t b) const noexcept {
simsimd_distance_t result;
// Here `reinterpret_cast` raises warning... we know what we are doing!
auto function_pointer = (simsimd_metric_punned_t)(void*)(raw_ptr_);
function_pointer(reinterpret_cast<void const*>(a), reinterpret_cast<void const*>(b), size, &result);
auto function_pointer = (simsimd_metric_punned_t)(metric_ptr_);
function_pointer(reinterpret_cast<void const*>(a), reinterpret_cast<void const*>(b), metric_size_arg_, &result);
return (result_t)result;
}
result_t invoke_simsimd_reverse(punned_arg_t a, punned_arg_t b, punned_arg_t size) const noexcept {
return 1 - invoke_simsimd(a, b, size);
}
result_t invoke_simsimd_reverse(uptr_t a, uptr_t b) const noexcept { return 1 - invoke_simsimd(a, b); }
#else
bool configure_with_simsimd() noexcept { return false; }
#endif
result_t invoke_autovec(punned_arg_t a, punned_arg_t b, punned_arg_t n) const noexcept { return raw_ptr_(a, b, n); }
result_t invoke_array_array_size(uptr_t a, uptr_t b) const noexcept {
auto function_pointer = (metric_array_array_size_t)(metric_ptr_);
result_t result = function_pointer(a, b, metric_size_arg_);
return result;
}
result_t invoke_array_array(uptr_t a, uptr_t b) const noexcept {
auto function_pointer = (metric_array_array_t)(metric_ptr_);
result_t result = function_pointer(a, b);
return result;
}
void configure_with_autovec() noexcept {
switch (metric_kind_) {
case metric_kind_t::ip_k: {
switch (scalar_kind_) {
case scalar_kind_t::f32_k: raw_ptr_ = (punned_ptr_t)&equidimensional_<metric_ip_gt<f32_t>>; break;
case scalar_kind_t::f16_k: raw_ptr_ = (punned_ptr_t)&equidimensional_<metric_ip_gt<f16_t, f32_t>>; break;
case scalar_kind_t::i8_k: raw_ptr_ = (punned_ptr_t)&equidimensional_<metric_ip_gt<i8_t, f32_t>>; break;
case scalar_kind_t::f64_k: raw_ptr_ = (punned_ptr_t)&equidimensional_<metric_ip_gt<f64_t>>; break;
default: raw_ptr_ = nullptr; break;
case scalar_kind_t::f32_k: metric_ptr_ = (uptr_t)&equidimensional_<metric_ip_gt<f32_t>>; break;
case scalar_kind_t::f16_k: metric_ptr_ = (uptr_t)&equidimensional_<metric_ip_gt<f16_t, f32_t>>; break;
case scalar_kind_t::i8_k: metric_ptr_ = (uptr_t)&equidimensional_<metric_ip_gt<i8_t, f32_t>>; break;
case scalar_kind_t::f64_k: metric_ptr_ = (uptr_t)&equidimensional_<metric_ip_gt<f64_t>>; break;
default: metric_ptr_ = 0; break;
}
break;
}
case metric_kind_t::cos_k: {
switch (scalar_kind_) {
case scalar_kind_t::f32_k: raw_ptr_ = (punned_ptr_t)&equidimensional_<metric_cos_gt<f32_t>>; break;
case scalar_kind_t::f16_k: raw_ptr_ = (punned_ptr_t)&equidimensional_<metric_cos_gt<f16_t, f32_t>>; break;
case scalar_kind_t::i8_k: raw_ptr_ = (punned_ptr_t)&equidimensional_<metric_cos_gt<i8_t, f32_t>>; break;
case scalar_kind_t::f64_k: raw_ptr_ = (punned_ptr_t)&equidimensional_<metric_cos_gt<f64_t>>; break;
default: raw_ptr_ = nullptr; break;
case scalar_kind_t::f32_k: metric_ptr_ = (uptr_t)&equidimensional_<metric_cos_gt<f32_t>>; break;
case scalar_kind_t::f16_k: metric_ptr_ = (uptr_t)&equidimensional_<metric_cos_gt<f16_t, f32_t>>; break;
case scalar_kind_t::i8_k: metric_ptr_ = (uptr_t)&equidimensional_<metric_cos_gt<i8_t, f32_t>>; break;
case scalar_kind_t::f64_k: metric_ptr_ = (uptr_t)&equidimensional_<metric_cos_gt<f64_t>>; break;
default: metric_ptr_ = 0; break;
}
break;
}
case metric_kind_t::l2sq_k: {
switch (scalar_kind_) {
case scalar_kind_t::f32_k: raw_ptr_ = (punned_ptr_t)&equidimensional_<metric_l2sq_gt<f32_t>>; break;
case scalar_kind_t::f16_k: raw_ptr_ = (punned_ptr_t)&equidimensional_<metric_l2sq_gt<f16_t, f32_t>>; break;
case scalar_kind_t::i8_k: raw_ptr_ = (punned_ptr_t)&equidimensional_<metric_l2sq_gt<i8_t, f32_t>>; break;
case scalar_kind_t::f64_k: raw_ptr_ = (punned_ptr_t)&equidimensional_<metric_l2sq_gt<f64_t>>; break;
default: raw_ptr_ = nullptr; break;
case scalar_kind_t::f32_k: metric_ptr_ = (uptr_t)&equidimensional_<metric_l2sq_gt<f32_t>>; break;
case scalar_kind_t::f16_k: metric_ptr_ = (uptr_t)&equidimensional_<metric_l2sq_gt<f16_t, f32_t>>; break;
case scalar_kind_t::i8_k: metric_ptr_ = (uptr_t)&equidimensional_<metric_l2sq_gt<i8_t, f32_t>>; break;
case scalar_kind_t::f64_k: metric_ptr_ = (uptr_t)&equidimensional_<metric_l2sq_gt<f64_t>>; break;
default: metric_ptr_ = 0; break;
}
break;
}
case metric_kind_t::pearson_k: {
switch (scalar_kind_) {
case scalar_kind_t::i8_k: raw_ptr_ = (punned_ptr_t)&equidimensional_<metric_pearson_gt<i8_t, f32_t>>; break;
case scalar_kind_t::f16_k:
raw_ptr_ = (punned_ptr_t)&equidimensional_<metric_pearson_gt<f16_t, f32_t>>;
break;
case scalar_kind_t::f32_k: raw_ptr_ = (punned_ptr_t)&equidimensional_<metric_pearson_gt<f32_t>>; break;
case scalar_kind_t::f64_k: raw_ptr_ = (punned_ptr_t)&equidimensional_<metric_pearson_gt<f64_t>>; break;
default: raw_ptr_ = nullptr; break;
case scalar_kind_t::i8_k: metric_ptr_ = (uptr_t)&equidimensional_<metric_pearson_gt<i8_t, f32_t>>; break;
case scalar_kind_t::f16_k: metric_ptr_ = (uptr_t)&equidimensional_<metric_pearson_gt<f16_t, f32_t>>; break;
case scalar_kind_t::f32_k: metric_ptr_ = (uptr_t)&equidimensional_<metric_pearson_gt<f32_t>>; break;
case scalar_kind_t::f64_k: metric_ptr_ = (uptr_t)&equidimensional_<metric_pearson_gt<f64_t>>; break;
default: metric_ptr_ = 0; break;
}
break;
}
case metric_kind_t::haversine_k: {
switch (scalar_kind_) {
case scalar_kind_t::f16_k:
raw_ptr_ = (punned_ptr_t)&equidimensional_<metric_haversine_gt<f16_t, f32_t>>;
metric_ptr_ = (uptr_t)&equidimensional_<metric_haversine_gt<f16_t, f32_t>>;
break;
case scalar_kind_t::f32_k: raw_ptr_ = (punned_ptr_t)&equidimensional_<metric_haversine_gt<f32_t>>; break;
case scalar_kind_t::f64_k: raw_ptr_ = (punned_ptr_t)&equidimensional_<metric_haversine_gt<f64_t>>; break;
default: raw_ptr_ = nullptr; break;
case scalar_kind_t::f32_k: metric_ptr_ = (uptr_t)&equidimensional_<metric_haversine_gt<f32_t>>; break;
case scalar_kind_t::f64_k: metric_ptr_ = (uptr_t)&equidimensional_<metric_haversine_gt<f64_t>>; break;
default: metric_ptr_ = 0; break;
}
break;
}
case metric_kind_t::divergence_k: {
switch (scalar_kind_) {
case scalar_kind_t::f16_k:
raw_ptr_ = (punned_ptr_t)&equidimensional_<metric_divergence_gt<f16_t, f32_t>>;
metric_ptr_ = (uptr_t)&equidimensional_<metric_divergence_gt<f16_t, f32_t>>;
break;
case scalar_kind_t::f32_k: raw_ptr_ = (punned_ptr_t)&equidimensional_<metric_divergence_gt<f32_t>>; break;
case scalar_kind_t::f64_k: raw_ptr_ = (punned_ptr_t)&equidimensional_<metric_divergence_gt<f64_t>>; break;
default: raw_ptr_ = nullptr; break;
case scalar_kind_t::f32_k: metric_ptr_ = (uptr_t)&equidimensional_<metric_divergence_gt<f32_t>>; break;
case scalar_kind_t::f64_k: metric_ptr_ = (uptr_t)&equidimensional_<metric_divergence_gt<f64_t>>; break;
default: metric_ptr_ = 0; break;
}
break;
}
case metric_kind_t::jaccard_k: // Equivalent to Tanimoto
case metric_kind_t::tanimoto_k: raw_ptr_ = (punned_ptr_t)&equidimensional_<metric_tanimoto_gt<b1x8_t>>; break;
case metric_kind_t::hamming_k: raw_ptr_ = (punned_ptr_t)&equidimensional_<metric_hamming_gt<b1x8_t>>; break;
case metric_kind_t::sorensen_k: raw_ptr_ = (punned_ptr_t)&equidimensional_<metric_sorensen_gt<b1x8_t>>; break;
case metric_kind_t::tanimoto_k: metric_ptr_ = (uptr_t)&equidimensional_<metric_tanimoto_gt<b1x8_t>>; break;
case metric_kind_t::hamming_k: metric_ptr_ = (uptr_t)&equidimensional_<metric_hamming_gt<b1x8_t>>; break;
case metric_kind_t::sorensen_k: metric_ptr_ = (uptr_t)&equidimensional_<metric_sorensen_gt<b1x8_t>>; break;
default: return;
}
}

template <typename typed_at>
inline static result_t equidimensional_(punned_arg_t a, punned_arg_t b, punned_arg_t a_dimensions) noexcept {
inline static result_t equidimensional_(uptr_t a, uptr_t b, uptr_t a_dimensions) noexcept {
using scalar_t = typename typed_at::scalar_t;
return typed_at{}((scalar_t const*)a, (scalar_t const*)b, a_dimensions);
}
Expand Down
49 changes: 22 additions & 27 deletions python/scripts/test_jit.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ def test_index_numba(ndim: int, batch_size: int):
try:
from numba import cfunc, types, carray
except ImportError:
pytest.skip("Numba is not installed.")
return

# Showcases how to use Numba to JIT-compile similarity measures for USearch.
Expand All @@ -33,12 +34,6 @@ def test_index_numba(ndim: int, batch_size: int):
types.CPointer(types.float32),
types.uint64,
)
signature_four_args = types.float32(
types.CPointer(types.float32),
types.uint64,
types.CPointer(types.float32),
types.uint64,
)

@cfunc(signature_two_args)
def python_inner_product_two_args(a, b):
Expand Down Expand Up @@ -72,23 +67,28 @@ def python_inner_product_three_args(a, b, ndim):
kind=MetricKind.IP,
signature=signature,
)
index = Index(ndim=ndim, metric=metric)
index = Index(ndim=ndim, metric=metric, dtype=np.float32)

keys = np.arange(batch_size)
vectors = random_vectors(count=batch_size, ndim=ndim)

index.add(keys, vectors)
matches = index.search(vectors, 10)
matches = index.search(vectors, 10, exact=True)
assert len(matches) == batch_size

matches_keys = [match[0].key for match in matches] if batch_size > 1 else [matches[0].key]
assert all(matches_keys[i] == keys[i] for i in range(batch_size)), f"Received {matches_keys}"


# Just one size for Cppyy to avoid redefining kernels in the global namespace
@pytest.mark.parametrize("ndim", dimensions[-1:])
@pytest.mark.parametrize("batch_size", batch_sizes[-1:])
def test_index_cppyy(ndim: int, batch_size: int):
try:
import cppyy
import cppyy.ll
except ImportError:
pytest.skip("cppyy is not installed.")
return

cppyy.cppdef(
Expand All @@ -107,13 +107,6 @@ def test_index_cppyy(ndim: int, batch_size: int):
result += a[i] * b[i];
return 1 - result;
}
float inner_product_four_args(float *a, size_t an, float *b, size_t) {
float result = 0;
for (size_t i = 0; i != an; ++i)
result += a[i] * b[i];
return 1 - result;
}
""".replace(
"ndim", str(ndim)
)
Expand All @@ -122,28 +115,29 @@ def test_index_cppyy(ndim: int, batch_size: int):
functions = [
cppyy.gbl.inner_product_two_args,
cppyy.gbl.inner_product_three_args,
# cppyy.gbl.inner_product_four_args,
]
signatures = [
MetricSignature.ArrayArray,
MetricSignature.ArrayArraySize,
# MetricSignature.ArraySizeArraySize,
]
for function, signature in zip(functions, signatures):
metric = CompiledMetric(
pointer=cppyy.ll.addressof(function),
kind=MetricKind.IP,
signature=signature,
)
index = Index(ndim=ndim, metric=metric)
index = Index(ndim=ndim, metric=metric, dtype=np.float32)

keys = np.arange(batch_size)
vectors = random_vectors(count=batch_size, ndim=ndim)
vectors = random_vectors(count=batch_size, ndim=ndim, dtype=np.float32)

index.add(keys, vectors)
matches = index.search(vectors, 10)
matches = index.search(vectors, 10, exact=True)
assert len(matches) == batch_size

matches_keys = [match[0].key for match in matches] if batch_size > 1 else [matches[0].key]
assert all(matches_keys[i] == keys[i] for i in range(batch_size)), f"Received {matches_keys}"


@pytest.mark.parametrize("ndim", [8])
@pytest.mark.parametrize("batch_size", batch_sizes)
Expand Down Expand Up @@ -173,14 +167,13 @@ def test_index_peachpy(ndim: int, batch_size: int):
RETURN,
)
except ImportError:
pytest.skip("PeachPy is not installed.")
return

a = Argument(ptr(const_float_), name="a")
b = Argument(ptr(const_float_), name="b")

with Function(
"InnerProduct", (a, b), float_, target=uarch.default + isa.avx + isa.avx2
) as asm_function:
with Function("InnerProduct", (a, b), float_, target=uarch.default + isa.avx + isa.avx2) as asm_function:
# Request two 64-bit general-purpose registers for addresses
reg_a, reg_b = GeneralPurposeRegister64(), GeneralPurposeRegister64()
LOAD.ARGUMENT(reg_a, a)
Expand Down Expand Up @@ -220,12 +213,14 @@ def test_index_peachpy(ndim: int, batch_size: int):
kind=MetricKind.IP,
signature=MetricSignature.ArrayArray,
)
index = Index(ndim=ndim, metric=metric)
index = Index(ndim=ndim, metric=metric, dtype=np.float32)

keys = np.arange(batch_size)
vectors = random_vectors(count=batch_size, ndim=ndim)

index.add(keys, vectors)
matches, distances, count = index.search(vectors, 10)
assert matches.shape[0] == distances.shape[0]
assert count.shape[0] == batch_size
matches = index.search(vectors, 10, exact=True)
assert len(matches) == batch_size

matches_keys = [match[0].key for match in matches] if batch_size > 1 else [matches[0].key]
assert all(matches_keys[i] == keys[i] for i in range(batch_size)), f"Received {matches_keys}"

0 comments on commit a3287f1

Please sign in to comment.