diff --git a/faiss/CMakeLists.txt b/faiss/CMakeLists.txt index 33e1849568..9cb1bb6d32 100644 --- a/faiss/CMakeLists.txt +++ b/faiss/CMakeLists.txt @@ -51,6 +51,7 @@ set(FAISS_SRC VectorTransform.cpp clone_index.cpp index_factory.cpp + FaissHook.cpp impl/AuxIndexStructures.cpp impl/CodePacker.cpp impl/IDSelector.cpp @@ -145,6 +146,7 @@ set(FAISS_HEADERS clone_index.h index_factory.h index_io.h + FaissHook.h impl/AdditiveQuantizer.h impl/AuxIndexStructures.h impl/CodePacker.h diff --git a/faiss/FaissHook.cpp b/faiss/FaissHook.cpp new file mode 100644 index 0000000000..47c5bcfc3f --- /dev/null +++ b/faiss/FaissHook.cpp @@ -0,0 +1,40 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +// -*- c++ -*- + +#include "FaissHook.h" + +namespace faiss { + +extern float fvec_L2sqr_default(const float* x, const float* y, size_t d); + +extern float fvec_inner_product_default( + const float* x, + const float* y, + size_t d); + +FVEC_L2SQR_HOOK fvec_L2sqr_hook = fvec_L2sqr_default; +FVEC_INNER_PRODUCT_HOOK fvec_inner_product_hook = fvec_inner_product_default; + +void set_fvec_L2sqr_hook(FVEC_L2SQR_HOOK_C hook) { + if (nullptr != hook) + fvec_L2sqr_hook = hook; +} +FVEC_L2SQR_HOOK_C get_fvec_L2sqr_hook() { + return fvec_L2sqr_hook; +} + +void set_fvec_inner_product_hook(FVEC_INNER_PRODUCT_HOOK_C hook) { + if (nullptr != hook) + fvec_inner_product_hook = hook; +} +FVEC_INNER_PRODUCT_HOOK_C get_fvec_inner_product_hook() { + return fvec_inner_product_hook; +} + +} // namespace faiss diff --git a/faiss/FaissHook.h b/faiss/FaissHook.h new file mode 100644 index 0000000000..0ae9c41020 --- /dev/null +++ b/faiss/FaissHook.h @@ -0,0 +1,41 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +// -*- c++ -*- + +#pragma once + +#include +#include "faiss/impl/platform_macros.h" + +namespace faiss { + +using FVEC_L2SQR_HOOK = float (*)(const float*, const float*, size_t); + +using FVEC_INNER_PRODUCT_HOOK = float (*)(const float*, const float*, size_t); + +extern FVEC_L2SQR_HOOK fvec_L2sqr_hook; +extern FVEC_INNER_PRODUCT_HOOK fvec_inner_product_hook; + +#ifdef __cplusplus +extern "C" { +#endif + +typedef float (*FVEC_L2SQR_HOOK_C)(const float*, const float*, size_t); +typedef float (*FVEC_INNER_PRODUCT_HOOK_C)(const float*, const float*, size_t); + +FAISS_API void set_fvec_L2sqr_hook(FVEC_L2SQR_HOOK_C hook); +FAISS_API FVEC_L2SQR_HOOK_C get_fvec_L2sqr_hook(); + +FAISS_API void set_fvec_inner_product_hook(FVEC_INNER_PRODUCT_HOOK_C hook); +FAISS_API FVEC_INNER_PRODUCT_HOOK_C get_fvec_inner_product_hook(); + +#ifdef __cplusplus +} +#endif + +} // namespace faiss diff --git a/faiss/utils/distances_simd.cpp b/faiss/utils/distances_simd.cpp index 323859f43b..02b107d73f 100644 --- a/faiss/utils/distances_simd.cpp +++ b/faiss/utils/distances_simd.cpp @@ -15,6 +15,7 @@ #include #include +#include #include #include #include @@ -186,7 +187,7 @@ void fvec_inner_products_ny_ref( */ FAISS_PRAGMA_IMPRECISE_FUNCTION_BEGIN -float fvec_inner_product(const float* x, const float* y, size_t d) { +float fvec_inner_product_default(const float* x, const float* y, size_t d) { float res = 0.F; FAISS_PRAGMA_IMPRECISE_LOOP for (size_t i = 0; i != d; ++i) { @@ -196,6 +197,10 @@ float fvec_inner_product(const float* x, const float* y, size_t d) { } FAISS_PRAGMA_IMPRECISE_FUNCTION_END +float fvec_inner_product(const float* x, const float* y, size_t d) { + return fvec_inner_product_hook(x, y, d); +} + FAISS_PRAGMA_IMPRECISE_FUNCTION_BEGIN float fvec_norm_L2sqr(const float* x, size_t d) { // the double in the _ref is suspected to be a typo. Some of the manual @@ -210,8 +215,12 @@ float fvec_norm_L2sqr(const float* x, size_t d) { } FAISS_PRAGMA_IMPRECISE_FUNCTION_END -FAISS_PRAGMA_IMPRECISE_FUNCTION_BEGIN float fvec_L2sqr(const float* x, const float* y, size_t d) { + return fvec_L2sqr_hook(x, y, d); +} + +FAISS_PRAGMA_IMPRECISE_FUNCTION_BEGIN +float fvec_L2sqr_default(const float* x, const float* y, size_t d) { size_t i; float res = 0; FAISS_PRAGMA_IMPRECISE_LOOP