Skip to content

Commit

Permalink
Add feature in IndexIDMap.cpp range_search with Parameters. (#3213)
Browse files Browse the repository at this point in the history
Summary:
for example:
#include <cstdio>
#include <cstdlib>
#include <iostream>
#include <random>

#include <faiss/IndexFlat.h>
#include <faiss/IndexIDMap.h>
#include <faiss/impl/AuxIndexStructures.h>
#include <faiss/impl/IDSelector.h>

// 64-bit int
using idx_t = faiss::idx_t;

int main() {
    int d = 64;     // dimension
    int nb = 10000; // database size
    int nq = 5;     // nb of queries

    std::mt19937 rng;
    std::uniform_real_distribution<> distrib;

    float* xb = new float[d * nb];
    float* xq = new float[d * nq];

    for (int i = 0; i < nb; i++) {
        for (int j = 0; j < d; j++)
            xb[d * i + j] = distrib(rng);
        xb[d * i] += i / 1000.;
    }

    for (int i = 0; i < nq; i++) {
        for (int j = 0; j < d; j++)
            xq[d * i + j] = distrib(rng);
        xq[d * i] += i / 1000.;
    }

    faiss::IndexFlatL2 index(d);
    faiss::IndexIDMap2 index_id_map2(&index);

    idx_t* xids = new idx_t[nb]();
    for (int i = 0; i < nb; i++) {
        xids[i] = i + nb;           // add ids
    }

    index_id_map2.add_with_ids(nb, xb, xids);

    faiss::SearchParameters params;

    std::vector<faiss::idx_t> ids;
    ids.reserve(nb / 2);
    for (faiss::idx_t i = 0; i < nb / 2; i++) {
        ids.push_back(i + nb);   // search ids
    }

    faiss::IDSelectorArray id_selector_array(ids.size(), ids.data());

    params.sel = &id_selector_array;

    // range search with param
    {
        float radius = 7.0f;
        faiss::RangeSearchResult* result = new faiss::RangeSearchResult(nq);

        index_id_map2.range_search(nq, xb, radius, result, &params);

        size_t off = 0;
        for (size_t i = 0; i < result->nq; i++) {
            size_t n = (result->lims[i + 1] - result->lims[i]);
            std::cout << "i : " << i << std::endl;
            for (size_t j = 0; j < n; j++) {
                std::cout << "\t label : " << result->labels[off + j]
                          << " distance : " << result->distances[off + j]
                          << std::endl;
            }
            off += n;
        }

        delete result;
    }

    delete[] xb;
    delete[] xq;
    delete[] xids;

    return 0;
}

outputs:
server@dingo11 cpp [main] $ ./6-Range-Search
i : 0
         label : 10000 distance : 0
i : 1
         label : 10001 distance : 0
         label : 10136 distance : 6.72638
         label : 10183 distance : 6.73293
         label : 10223 distance : 6.76569
         label : 10555 distance : 6.93339
         label : 10995 distance : 5.78548
i : 2
         label : 10002 distance : 0
         label : 10253 distance : 6.84876
         label : 10312 distance : 5.07469
i : 3
         label : 10003 distance : 0
         label : 10983 distance : 6.77275
i : 4
         label : 10004 distance : 0
         label : 10112 distance : 6.89793
         label : 10403 distance : 6.84196

Pull Request resolved: #3213

Reviewed By: mdouze

Differential Revision: D53704072

Pulled By: algoriddle

fbshipit-source-id: ca7f03f5a474a59089ebdf9685fb83e54ae198b0
  • Loading branch information
yuhaijun999 authored and facebook-github-bot committed Feb 13, 2024
1 parent ebb5f84 commit 8898eab
Showing 1 changed file with 10 additions and 3 deletions.
13 changes: 10 additions & 3 deletions faiss/IndexIDMap.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -146,9 +146,16 @@ void IndexIDMapTemplate<IndexT>::range_search(
typename IndexT::distance_t radius,
RangeSearchResult* result,
const SearchParameters* params) const {
FAISS_THROW_IF_NOT_MSG(
!params, "search params not supported for this index");
index->range_search(n, x, radius, result);
if (params) {
SearchParameters internal_search_parameters;
IDSelectorTranslated id_selector_translated(id_map, params->sel);
internal_search_parameters.sel = &id_selector_translated;

index->range_search(n, x, radius, result, &internal_search_parameters);
} else {
index->range_search(n, x, radius, result);
}

#pragma omp parallel for
for (idx_t i = 0; i < result->lims[result->nq]; i++) {
result->labels[i] = result->labels[i] < 0 ? result->labels[i]
Expand Down

0 comments on commit 8898eab

Please sign in to comment.