Skip to content

Commit

Permalink
Fix: Exceptions for missing distances
Browse files Browse the repository at this point in the history
  • Loading branch information
ashvardanian committed Mar 31, 2024
1 parent 267fbb5 commit 349c02d
Show file tree
Hide file tree
Showing 8 changed files with 130 additions and 2 deletions.
4 changes: 4 additions & 0 deletions c/lib.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,10 @@ USEARCH_EXPORT usearch_index_t usearch_init(usearch_init_options_t* options, use
reinterpret_cast<std::uintptr_t>(options->metric), //
metric_punned_signature_t::array_array_k, //
metric_kind, scalar_kind);
if (!metric) {
*error = "Unknown metric kind!";
return NULL;
}

index_dense_t index = index_dense_t::make(metric, config);
index_dense_t* result_ptr = new index_dense_t(std::move(index));
Expand Down
4 changes: 4 additions & 0 deletions include/usearch/index_plugins.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -1401,8 +1401,12 @@ class metric_punned_t {
inline std::size_t dimensions() const noexcept { return dimensions_; }
inline metric_kind_t metric_kind() const noexcept { return metric_kind_; }
inline scalar_kind_t scalar_kind() const noexcept { return scalar_kind_; }
explicit inline operator bool() const noexcept { return metric_routed_ && metric_ptr_; }

inline char const* isa_name() const noexcept {
if (!*this)
return "uninitialized";

#if USEARCH_USE_SIMSIMD
switch (isa_kind_) {
case simsimd_cap_serial_k: return "serial";
Expand Down
8 changes: 8 additions & 0 deletions java/cloud/unum/usearch/cloud_unum_usearch_Index.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,13 @@ JNIEXPORT jlong JNICALL Java_cloud_unum_usearch_Index_c_1create( //
index_dense_config_t config(static_cast<std::size_t>(connectivity), static_cast<std::size_t>(expansion_add),
static_cast<std::size_t>(expansion_search));
metric_punned_t metric(static_cast<std::size_t>(dimensions), metric_kind, quantization);
if (!metric) {
jclass jc = (*env).FindClass("java/lang/Error");
if (jc)
(*env).ThrowNew(jc, "Failed to initialize the metric!");
goto cleanup;
}

index_dense_t index = index_dense_t::make(metric, config);
if (!index.reserve(static_cast<std::size_t>(capacity))) {
jclass jc = (*env).FindClass("java/lang/Error");
Expand All @@ -47,6 +54,7 @@ JNIEXPORT jlong JNICALL Java_cloud_unum_usearch_Index_c_1create( //
(*env).ThrowNew(jc, "Failed to initialize the vector index!");
}

cleanup:
(*env).ReleaseStringUTFChars(metric, metric_cstr);
(*env).ReleaseStringUTFChars(quantization, quantization_cstr);
return result;
Expand Down
11 changes: 10 additions & 1 deletion javascript/lib.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -86,9 +86,13 @@ CompiledIndex::CompiledIndex(Napi::CallbackInfo const& ctx) : Napi::ObjectWrap<C
bool multi = ctx[6].As<Napi::Boolean>().Value();

metric_punned_t metric(dimensions, metric_kind, quantization);
if (!metric) {
Napi::TypeError::New(ctx.Env(), "Failed to initialize the metric!").ThrowAsJavaScriptException();
return;
}

index_dense_config_t config(connectivity, expansion_add, expansion_search);
config.multi = multi;

native_.reset(new index_dense_t(index_dense_t::make(metric, config)));
if (!native_)
Napi::Error::New(ctx.Env(), "Out of memory!").ThrowAsJavaScriptException();
Expand Down Expand Up @@ -289,6 +293,11 @@ Napi::Value exactSearch(Napi::CallbackInfo const& ctx) {
}

metric_punned_t metric(dimensions, metric_kind, quantization);
if (!metric) {
Napi::TypeError::New(env, "Failed to initialize the metric!").ThrowAsJavaScriptException();
return;
}

executor_default_t executor;
exact_search_t search;

Expand Down
6 changes: 6 additions & 0 deletions objc/USearchObjective.mm
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,12 @@ + (instancetype)make:(USearchMetric)metricKind dimensions:(UInt32)dimensions con

index_config_t config(static_cast<std::size_t>(connectivity));
metric_punned_t metric(dims, to_native_metric(metricKind), to_native_scalar(quantization));
if (!metric) {
@throw [NSException exceptionWithName:@"Can't create an index"
reason:@"The metric is not supported"
userInfo:nil];
}

shared_index_dense_t ptr = std::make_shared<index_dense_t>(index_dense_t::make(metric, config));
return [[USearchIndex alloc] initWithIndex:ptr];
}
Expand Down
8 changes: 7 additions & 1 deletion python/lib.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,9 @@ static dense_index_py_t make_index( //
metric_uintptr //
? metric_t(dimensions, metric_uintptr, metric_signature, metric_kind, scalar_kind)
: metric_t(dimensions, metric_kind, scalar_kind);
if (!metric)
throw std::invalid_argument("Unsupported metric!");

return index_dense_t::make(metric, config);
}

Expand Down Expand Up @@ -485,6 +488,8 @@ static py::tuple search_many_brute_force( //
metric_uintptr //
? metric_t(dimensions, metric_uintptr, metric_signature, metric_kind, queries_kind)
: metric_t(dimensions, metric_kind, queries_kind);
if (!metric)
throw std::invalid_argument("Unsupported metric!");

py::array_t<dense_key_t> keys_py({static_cast<Py_ssize_t>(queries_count), static_cast<Py_ssize_t>(wanted)});
py::array_t<distance_t> distances_py({static_cast<Py_ssize_t>(queries_count), static_cast<Py_ssize_t>(wanted)});
Expand Down Expand Up @@ -1098,6 +1103,8 @@ PYBIND11_MODULE(compiled, m) {
metric_uintptr //
? metric_t(dimensions, metric_uintptr, metric_signature, metric_kind, scalar_kind)
: metric_t(dimensions, metric_kind, scalar_kind);
if (!metric)
throw std::invalid_argument("Unsupported metric kind!");
index.change_metric(std::move(metric));
},
py::arg("metric_kind") = metric_kind_t::cos_k, //
Expand Down Expand Up @@ -1231,4 +1238,3 @@ PYBIND11_MODULE(compiled, m) {
py::arg("progress") = nullptr //
);
}

89 changes: 89 additions & 0 deletions python/scripts/test_distances.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
import pytest
import numpy as np

from usearch.eval import random_vectors
from usearch.index import search

from usearch.index import (
Index,
MetricKind,
ScalarKind,
)


@pytest.mark.parametrize(
"metric",
[
MetricKind.Cos,
MetricKind.L2sq,
MetricKind.Divergence,
MetricKind.Pearson,
],
)
@pytest.mark.parametrize(
"quantization",
[
ScalarKind.F32,
ScalarKind.F16,
ScalarKind.I8,
],
)
@pytest.mark.parametrize(
"dtype",
[
np.float32,
np.float64,
np.float16,
np.int8,
],
)
def test_distances_continuous(metric, quantization, dtype):
ndim = 1024
try:
index = Index(ndim=ndim, metric=metric, dtype=quantization)
vectors = random_vectors(count=2, ndim=ndim, dtype=dtype)
keys = np.arange(2)
index.add(keys, vectors)
except ValueError:
pytest.skip(f"Unsupported metric `{metric}`, quantization `{quantization}`, dtype `{dtype}`")
return

rtol = 1e-2
atol = 1e-2

distance_itself_first = index.pairwise_distance([0], [0])
distance_itself_second = index.pairwise_distance([1], [1])
distance_different = index.pairwise_distance([0], [1])

assert not np.allclose(distance_different, 0)
assert np.allclose(distance_itself_first, 0, rtol=rtol, atol=atol) and np.allclose(
distance_itself_second, 0, rtol=rtol, atol=atol
)


@pytest.mark.parametrize(
"metric",
[
MetricKind.Hamming,
MetricKind.Tanimoto,
MetricKind.Sorensen,
],
)
def test_distances_sparse(metric):
ndim = 1024
index = Index(ndim=ndim, metric=metric, dtype=ScalarKind.B1)
vectors = random_vectors(count=2, ndim=ndim, dtype=ScalarKind.B1)
keys = np.arange(2)
index.add(keys, vectors)

rtol = 1e-2
atol = 1e-2

distance_itself_first = index.pairwise_distance([0], [0])
distance_itself_second = index.pairwise_distance([1], [1])
distance_different = index.pairwise_distance([0], [1])

assert not np.allclose(distance_different, 0)
assert np.allclose(distance_itself_first, 0, rtol=rtol, atol=atol) and np.allclose(
distance_itself_second, 0, rtol=rtol, atol=atol
)
2 changes: 2 additions & 0 deletions rust/lib.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,8 @@ std::unique_ptr<NativeIndex> new_native_index(IndexOptions const& options) {
metric_kind_t metric_kind = rust_to_cpp_metric(options.metric);
scalar_kind_t scalar_kind = rust_to_cpp_scalar(options.quantization);
metric_punned_t metric(options.dimensions, metric_kind, scalar_kind);
if (!metric)
throw std::invalid_argument("Unsupported metric or scalar type");
index_dense_config_t config(options.connectivity, options.expansion_add, options.expansion_search);
config.multi = options.multi;
return wrap(index_t::make(metric, config));
Expand Down

0 comments on commit 349c02d

Please sign in to comment.