From 1106f22f0de76db094ba78519fd80a1a09e2122f Mon Sep 17 00:00:00 2001 From: Zinan Lin Date: Mon, 25 Nov 2024 00:20:33 -0800 Subject: [PATCH] add auto backend for NN search --- .../pe.histogram.nearest_neighbor_backend.rst | 8 ++++ example/image/camelyon17.py | 1 - example/image/cat.py | 1 - example/image/cifar10.py | 1 - pe/histogram/nearest_neighbor_backend/auto.py | 37 +++++++++++++++++++ pe/histogram/nearest_neighbors.py | 13 +++++-- 6 files changed, 54 insertions(+), 7 deletions(-) create mode 100644 pe/histogram/nearest_neighbor_backend/auto.py diff --git a/doc/source/api/pe.histogram.nearest_neighbor_backend.rst b/doc/source/api/pe.histogram.nearest_neighbor_backend.rst index 6377013..ac831ec 100644 --- a/doc/source/api/pe.histogram.nearest_neighbor_backend.rst +++ b/doc/source/api/pe.histogram.nearest_neighbor_backend.rst @@ -9,6 +9,14 @@ pe.histogram.nearest\_neighbor\_backend package Submodules ---------- +pe.histogram.nearest\_neighbor\_backend.auto module +--------------------------------------------------- + +.. automodule:: pe.histogram.nearest_neighbor_backend.auto + :members: + :undoc-members: + :show-inheritance: + pe.histogram.nearest\_neighbor\_backend.faiss module ---------------------------------------------------- diff --git a/example/image/camelyon17.py b/example/image/camelyon17.py index 79f6674..a04970e 100644 --- a/example/image/camelyon17.py +++ b/example/image/camelyon17.py @@ -37,7 +37,6 @@ lookahead_log_folder=os.path.join(exp_folder, "lookahead"), voting_details_log_folder=os.path.join(exp_folder, "voting_details"), api=api, - backend="faiss", ) population = PEPopulation(api=api, histogram_threshold=4) diff --git a/example/image/cat.py b/example/image/cat.py index 0516419..3e51e4c 100644 --- a/example/image/cat.py +++ b/example/image/cat.py @@ -38,7 +38,6 @@ lookahead_log_folder=os.path.join(exp_folder, "lookahead"), voting_details_log_folder=os.path.join(exp_folder, "voting_details"), api=api, - backend="faiss", ) population = PEPopulation(api=api, histogram_threshold=2) diff --git a/example/image/cifar10.py b/example/image/cifar10.py index fbe38c9..4daafd7 100644 --- a/example/image/cifar10.py +++ b/example/image/cifar10.py @@ -37,7 +37,6 @@ lookahead_log_folder=os.path.join(exp_folder, "lookahead"), voting_details_log_folder=os.path.join(exp_folder, "voting_details"), api=api, - backend="faiss", ) population = PEPopulation(api=api, histogram_threshold=10) diff --git a/pe/histogram/nearest_neighbor_backend/auto.py b/pe/histogram/nearest_neighbor_backend/auto.py new file mode 100644 index 0000000..139cddb --- /dev/null +++ b/pe/histogram/nearest_neighbor_backend/auto.py @@ -0,0 +1,37 @@ +import traceback +from pe.logging import execution_logger + + +def search(syn_embedding, priv_embedding, num_nearest_neighbors, mode): + """Compute the nearest neighbors of the private embedding in the synthetic embedding using Faiss. If Faiss is not + installed or an error occurs, fall back to the sklearn backend. + + :param syn_embedding: The synthetic embedding + :type syn_embedding: np.ndarray + :param priv_embedding: The private embedding + :type priv_embedding: np.ndarray + :param num_nearest_neighbors: The number of nearest neighbors to search + :type num_nearest_neighbors: int + :param mode: The distance metric to use for finding the nearest neighbors. It should be one of the following: + "l2" (l2 distance), "cos_sim" (cosine similarity), "ip" (inner product, not supported by sklearn) + :type mode: str + :raises ValueError: If the mode is unknown + :return: The distances and indices of the nearest neighbors + :rtype: tuple[np.ndarray, np.ndarray] + """ + try: + execution_logger.info("Using faiss backend for nearest neighbor search") + from pe.histogram.nearest_neighbor_backend.faiss import search + + return search(syn_embedding, priv_embedding, num_nearest_neighbors, mode) + except Exception as e: + execution_logger.error(f"Error using faiss backend for nearest neighbor search: {e}") + execution_logger.error(traceback.format_exc()) + execution_logger.info( + "Please check the installation of the Faiss library: " + "https://microsoft.github.io/DPSDA/getting_started/installation.html#faiss" + ) + execution_logger.info("Using sklearn backend for nearest neighbor search") + from pe.histogram.nearest_neighbor_backend.sklearn import search + + return search(syn_embedding, priv_embedding, num_nearest_neighbors, mode) diff --git a/pe/histogram/nearest_neighbors.py b/pe/histogram/nearest_neighbors.py index d8182c3..67e3dc8 100644 --- a/pe/histogram/nearest_neighbors.py +++ b/pe/histogram/nearest_neighbors.py @@ -25,7 +25,7 @@ def __init__( voting_details_log_folder=None, api=None, num_nearest_neighbors=1, - backend="sklearn", + backend="auto", ): """Constructor. @@ -53,9 +53,10 @@ def __init__( 1 :type num_nearest_neighbors: int, optional :param backend: The backend to use for finding the nearest neighbors. It should be one of the following: - "faiss" (FAISS), "sklearn" (scikit-learn). Defaults to "sklearn". FAISS supports GPU and is much faster - when the number of synthetic samples and/or private samples is large. It requires the installation of - `faiss-gpu` or `faiss-cpu` package. See https://faiss.ai/ + "faiss" (FAISS), "sklearn" (scikit-learn), "auto" (using FAISS if available, otherwise scikit-learn). + Defaults to "auto". FAISS supports GPU and is much faster when the number of synthetic samples and/or + private samples is large. It requires the installation of `faiss-gpu` or `faiss-cpu` package. See + https://faiss.ai/ :type backend: str, optional :raises ValueError: If the `api` is not provided when `lookahead_degree` is greater than 0 :raises ValueError: If the `backend` is unknown @@ -77,6 +78,10 @@ def __init__( elif backend.lower() == "sklearn": from pe.histogram.nearest_neighbor_backend.sklearn import search + self._search = search + elif backend.lower() == "auto": + from pe.histogram.nearest_neighbor_backend.auto import search + self._search = search else: raise ValueError(f"Unknown backend: {backend}")