Skip to content

Commit

Permalink
add auto backend for NN search
Browse files Browse the repository at this point in the history
  • Loading branch information
fjxmlzn committed Nov 25, 2024
1 parent 9a1a8f9 commit 1106f22
Show file tree
Hide file tree
Showing 6 changed files with 54 additions and 7 deletions.
8 changes: 8 additions & 0 deletions doc/source/api/pe.histogram.nearest_neighbor_backend.rst
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,14 @@ pe.histogram.nearest\_neighbor\_backend package
Submodules
----------

pe.histogram.nearest\_neighbor\_backend.auto module
---------------------------------------------------

.. automodule:: pe.histogram.nearest_neighbor_backend.auto
:members:
:undoc-members:
:show-inheritance:

pe.histogram.nearest\_neighbor\_backend.faiss module
----------------------------------------------------

Expand Down
1 change: 0 additions & 1 deletion example/image/camelyon17.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,6 @@
lookahead_log_folder=os.path.join(exp_folder, "lookahead"),
voting_details_log_folder=os.path.join(exp_folder, "voting_details"),
api=api,
backend="faiss",
)
population = PEPopulation(api=api, histogram_threshold=4)

Expand Down
1 change: 0 additions & 1 deletion example/image/cat.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,6 @@
lookahead_log_folder=os.path.join(exp_folder, "lookahead"),
voting_details_log_folder=os.path.join(exp_folder, "voting_details"),
api=api,
backend="faiss",
)
population = PEPopulation(api=api, histogram_threshold=2)

Expand Down
1 change: 0 additions & 1 deletion example/image/cifar10.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,6 @@
lookahead_log_folder=os.path.join(exp_folder, "lookahead"),
voting_details_log_folder=os.path.join(exp_folder, "voting_details"),
api=api,
backend="faiss",
)
population = PEPopulation(api=api, histogram_threshold=10)

Expand Down
37 changes: 37 additions & 0 deletions pe/histogram/nearest_neighbor_backend/auto.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
import traceback
from pe.logging import execution_logger


def search(syn_embedding, priv_embedding, num_nearest_neighbors, mode):
"""Compute the nearest neighbors of the private embedding in the synthetic embedding using Faiss. If Faiss is not
installed or an error occurs, fall back to the sklearn backend.
:param syn_embedding: The synthetic embedding
:type syn_embedding: np.ndarray
:param priv_embedding: The private embedding
:type priv_embedding: np.ndarray
:param num_nearest_neighbors: The number of nearest neighbors to search
:type num_nearest_neighbors: int
:param mode: The distance metric to use for finding the nearest neighbors. It should be one of the following:
"l2" (l2 distance), "cos_sim" (cosine similarity), "ip" (inner product, not supported by sklearn)
:type mode: str
:raises ValueError: If the mode is unknown
:return: The distances and indices of the nearest neighbors
:rtype: tuple[np.ndarray, np.ndarray]
"""
try:
execution_logger.info("Using faiss backend for nearest neighbor search")
from pe.histogram.nearest_neighbor_backend.faiss import search

return search(syn_embedding, priv_embedding, num_nearest_neighbors, mode)
except Exception as e:
execution_logger.error(f"Error using faiss backend for nearest neighbor search: {e}")
execution_logger.error(traceback.format_exc())
execution_logger.info(
"Please check the installation of the Faiss library: "
"https://microsoft.github.io/DPSDA/getting_started/installation.html#faiss"
)
execution_logger.info("Using sklearn backend for nearest neighbor search")
from pe.histogram.nearest_neighbor_backend.sklearn import search

return search(syn_embedding, priv_embedding, num_nearest_neighbors, mode)
13 changes: 9 additions & 4 deletions pe/histogram/nearest_neighbors.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def __init__(
voting_details_log_folder=None,
api=None,
num_nearest_neighbors=1,
backend="sklearn",
backend="auto",
):
"""Constructor.
Expand Down Expand Up @@ -53,9 +53,10 @@ def __init__(
1
:type num_nearest_neighbors: int, optional
:param backend: The backend to use for finding the nearest neighbors. It should be one of the following:
"faiss" (FAISS), "sklearn" (scikit-learn). Defaults to "sklearn". FAISS supports GPU and is much faster
when the number of synthetic samples and/or private samples is large. It requires the installation of
`faiss-gpu` or `faiss-cpu` package. See https://faiss.ai/
"faiss" (FAISS), "sklearn" (scikit-learn), "auto" (using FAISS if available, otherwise scikit-learn).
Defaults to "auto". FAISS supports GPU and is much faster when the number of synthetic samples and/or
private samples is large. It requires the installation of `faiss-gpu` or `faiss-cpu` package. See
https://faiss.ai/
:type backend: str, optional
:raises ValueError: If the `api` is not provided when `lookahead_degree` is greater than 0
:raises ValueError: If the `backend` is unknown
Expand All @@ -77,6 +78,10 @@ def __init__(
elif backend.lower() == "sklearn":
from pe.histogram.nearest_neighbor_backend.sklearn import search

self._search = search
elif backend.lower() == "auto":
from pe.histogram.nearest_neighbor_backend.auto import search

self._search = search
else:
raise ValueError(f"Unknown backend: {backend}")
Expand Down

0 comments on commit 1106f22

Please sign in to comment.