Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

InnerProduct Distance Metric for CAGRA search #2260

Merged
merged 34 commits into from
Apr 30, 2024
Merged
Show file tree
Hide file tree
Changes from 27 commits
Commits
Show all changes
34 commits
Select commit Hold shift + click to select a range
b6d9980
apply updates to 24.06
tarang-jain Apr 5, 2024
fac65b8
Merge branch 'branch-24.06' of https://github.com/rapidsai/raft into …
tarang-jain Apr 5, 2024
ad48b03
remove build errors
tarang-jain Apr 5, 2024
2b8d898
search inputs
tarang-jain Apr 5, 2024
e44ab17
inner product in compute_distance_vpq.cuh
tarang-jain Apr 9, 2024
c506216
Merge branch 'branch-24.06' of https://github.com/rapidsai/raft into …
tarang-jain Apr 9, 2024
cb7fbba
inner product in index build; debug statements
tarang-jain Apr 10, 2024
0029bba
tests passing
tarang-jain Apr 10, 2024
293bc8f
Merge branch 'branch-24.06' of https://github.com/rapidsai/raft into …
tarang-jain Apr 10, 2024
c91d895
style
tarang-jain Apr 10, 2024
7a5f876
update testing
tarang-jain Apr 10, 2024
c37aa27
rm log statements
tarang-jain Apr 10, 2024
890372b
pass CagraSort
tarang-jain Apr 10, 2024
810ddd1
tests passing
tarang-jain Apr 11, 2024
f92e68b
remove dbg statements
tarang-jain Apr 11, 2024
5bbbc70
update docs
tarang-jain Apr 11, 2024
7febd73
metric assertions
tarang-jain Apr 11, 2024
7e19937
add metric as const arg
tarang-jain Apr 12, 2024
e40b967
make metric template:
tarang-jain Apr 12, 2024
b102393
clean up metric template
tarang-jain Apr 15, 2024
6da7d55
update assertion
tarang-jain Apr 15, 2024
1e666e3
Merge branch 'branch-24.06' of https://github.com/rapidsai/raft into …
tarang-jain Apr 17, 2024
c1f4dcd
metric runtime dispatch
tarang-jain Apr 17, 2024
2e9e3fe
Merge branch 'cagra-dists' of https://github.com/tarang-jain/raft int…
tarang-jain Apr 17, 2024
d8a4b39
address all PR reviews
tarang-jain Apr 18, 2024
c5cd0e7
update docs, passing gtests
tarang-jain Apr 19, 2024
eabb031
Merge branch 'branch-24.06' of https://github.com/rapidsai/raft into …
tarang-jain Apr 19, 2024
77ff0c2
add ivf_pq::index_params helper
tarang-jain Apr 23, 2024
b61c1c3
Merge branch 'branch-24.06' of https://github.com/rapidsai/raft into …
tarang-jain Apr 23, 2024
54061d0
tracking issue; styling
tarang-jain Apr 23, 2024
dda1810
make helper static
tarang-jain Apr 24, 2024
743d5bc
update from_dataset verbiage
tarang-jain Apr 24, 2024
5791743
Merge branch 'branch-24.06' of https://github.com/rapidsai/raft into …
tarang-jain Apr 24, 2024
f5f30b7
Merge branch 'branch-24.06' into cagra-dists
cjnolet Apr 30, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 11 additions & 1 deletion cpp/include/raft/neighbors/cagra.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,22 @@
#include <raft/core/host_device_accessor.hpp>
#include <raft/core/mdspan.hpp>
#include <raft/core/resources.hpp>
#include <raft/distance/distance_types.hpp>
#include <raft/neighbors/cagra_types.hpp>
#include <raft/neighbors/dataset.hpp>

#include <rmm/cuda_stream_view.hpp>

namespace raft::neighbors::cagra {

template <typename DataT, typename accessor>
tarang-jain marked this conversation as resolved.
Show resolved Hide resolved
auto get_default_ivf_pq_build_params(
mdspan<const DataT, matrix_extent<int64_t>, row_major, accessor> dataset,
const raft::distance::DistanceType metric = raft::distance::L2Expanded) -> ivf_pq::index_params
{
return detail::get_default_ivf_pq_build_params(dataset, metric);
}

/**
* @defgroup cagra CUDA ANN Graph-based nearest neighbor search
* @{
Expand All @@ -48,12 +57,13 @@ namespace raft::neighbors::cagra {
*
* The following distance metrics are supported:
* - L2Expanded
* - InnerProduct
*
* Usage example:
* @code{.cpp}
* using namespace raft::neighbors;
* // use default index parameters
* ivf_pq::index_params build_params;
* ivf_pq::index_params = cagra::get_default_ivf_pq_build_params(dataset);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure I like this approach as it doesn't follow the standard API design flow that we've been using in hte other algos. I wonder if instead of adding this additional helper function if we could instead just add an overloaded constructor to the index_params object that accepts an mdsan with the dataset and sets the defaults automatically. This would then follow the standard build flow but add an option for the user to establish the ressonable defaultsby passing in the dataset.

Copy link
Contributor Author

@tarang-jain tarang-jain Apr 23, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can't add constructors to the ivf_pq::index_params struct because the aggregate assertion here. A helper is needed. One thing that can be done is that instead of a constructor within the index_params struct, we can have a helper get_default_index_params outside of the struct, but in ivf_pq_types.hpp

Copy link
Member

@cjnolet cjnolet Apr 23, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Honestly, I'd just remove the static assertion there. It's not necessary. My problem with exposing helper functions like this is that they make the APIs harder to use as they aren't immediatley obvious and don't provide for a consistent experience. Keeping everything together within the index_params object itself allows there to be a single entrypoint (e.g. index_params) for construction. The other option would be to have factory functions which would always be used for construction, however I think that will also make things even more confusing unless we used that pattern across all the ANN APis.

Copy link
Contributor

@tfeher tfeher Apr 24, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here we are constructing CAGRA specific ivf_pq::index_params. I do not think the constructor of ivf_pq::index_params shall have the job to define parameters according to CAGRA's requirements.

For consistent experience, the user just need to fill the metric field of cagra::index_params, which is inherited from ann::index_params, and call cagra::build as usual.

Users of build_knn_graph are advanced users who are using build_knn_graph & optimize in a separate step. Other algos don't have these steps, this is custom for CAGRA. Before this PR, the user either passes a manually filled ivf_pq::index_params, or just uses a default. I think a helper function is an improvement.

* ivf_pq::search_params search_params
* auto knn_graph = raft::make_host_matrix<IdxT, IdxT>(dataset.extent(0), 128);
* // create knn graph
Expand Down
40 changes: 31 additions & 9 deletions cpp/include/raft/neighbors/detail/cagra/cagra_build.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,11 @@
#include "../../cagra_types.hpp"
#include "../../vpq_dataset.cuh"
#include "graph_core.cuh"
// #include "raft/core/mdspan_types.hpp"

#include <raft/core/device_mdarray.hpp>
#include <raft/core/device_mdspan.hpp>
#include <raft/core/error.hpp>
#include <raft/core/host_device_accessor.hpp>
#include <raft/core/host_mdarray.hpp>
#include <raft/core/host_mdspan.hpp>
Expand All @@ -40,6 +42,23 @@

namespace raft::neighbors::cagra::detail {

template <typename DataT, typename accessor>
ivf_pq::index_params get_default_ivf_pq_build_params(
mdspan<const DataT, matrix_extent<int64_t>, row_major, accessor> dataset,
const raft::distance::DistanceType metric)
{
auto build_params = ivf_pq::index_params{};
build_params.n_lists = dataset.extent(0) < 4 * 2500 ? 4 : (uint32_t)(dataset.extent(0) / 2500);
tarang-jain marked this conversation as resolved.
Show resolved Hide resolved
build_params.pq_dim = raft::Pow2<8>::roundUp(dataset.extent(1) / 2);
tarang-jain marked this conversation as resolved.
Show resolved Hide resolved
build_params.pq_bits = 8;
build_params.kmeans_trainset_fraction = dataset.extent(0) < 10000 ? 1 : 10;
build_params.kmeans_n_iters = 25;
tarang-jain marked this conversation as resolved.
Show resolved Hide resolved
build_params.add_data_on_build = true;
build_params.metric = metric;

return build_params;
}

template <typename DataT, typename IdxT, typename accessor>
void build_knn_graph(raft::resources const& res,
mdspan<const DataT, matrix_extent<int64_t>, row_major, accessor> dataset,
Expand All @@ -48,8 +67,9 @@ void build_knn_graph(raft::resources const& res,
std::optional<ivf_pq::index_params> build_params = std::nullopt,
std::optional<ivf_pq::search_params> search_params = std::nullopt)
{
RAFT_EXPECTS(!build_params || build_params->metric == distance::DistanceType::L2Expanded,
"Currently only L2Expanded metric is supported");
RAFT_EXPECTS(!build_params || build_params->metric == distance::DistanceType::L2Expanded ||
build_params->metric == distance::DistanceType::InnerProduct,
"Currently only L2Expanded or InnerProduct metric are supported");

uint32_t node_degree = knn_graph.extent(1);
common::nvtx::range<common::nvtx::domain::raft> fun_scope("cagra::build_graph(%zu, %zu, %u)",
Expand All @@ -58,13 +78,7 @@ void build_knn_graph(raft::resources const& res,
node_degree);

if (!build_params) {
build_params = ivf_pq::index_params{};
build_params->n_lists = dataset.extent(0) < 4 * 2500 ? 4 : (uint32_t)(dataset.extent(0) / 2500);
build_params->pq_dim = raft::Pow2<8>::roundUp(dataset.extent(1) / 2);
build_params->pq_bits = 8;
build_params->kmeans_trainset_fraction = dataset.extent(0) < 10000 ? 1 : 10;
build_params->kmeans_n_iters = 25;
build_params->add_data_on_build = true;
build_params = get_default_ivf_pq_build_params(dataset, raft::distance::L2Expanded);
}

// Make model name
Expand Down Expand Up @@ -321,9 +335,15 @@ index<T, IdxT> build(
raft::make_host_matrix<IdxT, int64_t>(dataset.extent(0), intermediate_degree));

if (params.build_algo == graph_build_algo::IVF_PQ) {
if (!pq_build_params) {
pq_build_params = get_default_ivf_pq_build_params(dataset, params.metric);
}
build_knn_graph(res, dataset, knn_graph->view(), refine_rate, pq_build_params, search_params);

} else {
RAFT_EXPECTS(
params.metric == raft::distance::DistanceType::L2Expanded,
"L2Expanded is the only distance metrics supported for CAGRA build with nn_descent");
// Use nn-descent to build CAGRA knn graph
if (!nn_descent_params) {
nn_descent_params = experimental::nn_descent::index_params();
Expand All @@ -346,6 +366,8 @@ index<T, IdxT> build(
// Construct an index from dataset and optimized knn graph.
if (construct_index_with_dataset) {
if (params.compression.has_value()) {
RAFT_EXPECTS(params.metric == raft::distance::DistanceType::L2Expanded,
"VPQ compression is only supported with L2Expanded distance mertric");
index<T, IdxT> idx(res, params.metric);
idx.update_graph(res, raft::make_const_mdspan(cagra_graph.view()));
idx.update_dataset(
Expand Down
35 changes: 26 additions & 9 deletions cpp/include/raft/neighbors/detail/cagra/cagra_search.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
#include <raft/core/nvtx.hpp>
#include <raft/core/resource/cuda_stream.hpp>
#include <raft/core/resources.hpp>
#include <raft/distance/distance_types.hpp>
#include <raft/neighbors/cagra_types.hpp>
#include <raft/neighbors/detail/ivf_common.cuh>
#include <raft/neighbors/detail/ivf_pq_search.cuh>
Expand Down Expand Up @@ -87,7 +88,8 @@ void search_main_core(
raft::device_matrix_view<const typename DatasetDescriptorT::DATA_T, int64_t, row_major> queries,
raft::device_matrix_view<typename DatasetDescriptorT::INDEX_T, int64_t, row_major> neighbors,
raft::device_matrix_view<typename DatasetDescriptorT::DISTANCE_T, int64_t, row_major> distances,
CagraSampleFilterT sample_filter = CagraSampleFilterT())
CagraSampleFilterT sample_filter = CagraSampleFilterT(),
raft::distance::DistanceType metric = raft::distance::DistanceType::L2Expanded)
{
RAFT_LOG_DEBUG("# dataset size = %lu, dim = %lu\n",
static_cast<size_t>(dataset_desc.size),
Expand All @@ -112,7 +114,7 @@ void search_main_core(
using CagraSampleFilterT_s = typename CagraSampleFilterT_Selector<CagraSampleFilterT>::type;
std::unique_ptr<search_plan_impl<DatasetDescriptorT, CagraSampleFilterT_s>> plan =
factory<DatasetDescriptorT, CagraSampleFilterT_s>::create(
res, params, dataset_desc.dim, graph.extent(1), topk);
res, params, dataset_desc.dim, graph.extent(1), topk, metric);

plan->check(topk);

Expand Down Expand Up @@ -163,7 +165,8 @@ void launch_vpq_search_main_core(
raft::device_matrix_view<const T, int64_t, row_major> queries,
raft::device_matrix_view<InternalIdxT, int64_t, row_major> neighbors,
raft::device_matrix_view<DistanceT, int64_t, row_major> distances,
CagraSampleFilterT sample_filter)
CagraSampleFilterT sample_filter,
const raft::distance::DistanceType metric)
{
RAFT_EXPECTS(vpq_dset->pq_bits() == 8, "Only pq_bits = 8 is supported for now");
RAFT_EXPECTS(vpq_dset->pq_len() == 2 || vpq_dset->pq_len() == 4,
Expand Down Expand Up @@ -192,7 +195,7 @@ void launch_vpq_search_main_core(
size_t(vpq_dset->n_rows()),
vpq_dset->dim());
search_main_core(
res, params, dataset_desc, graph, queries, neighbors, distances, sample_filter);
res, params, dataset_desc, graph, queries, neighbors, distances, sample_filter, metric);
} else if (vpq_dset->pq_len() == 4) {
using dataset_desc_t = cagra_q_dataset_descriptor_t<T,
DatasetT,
Expand All @@ -210,7 +213,7 @@ void launch_vpq_search_main_core(
size_t(vpq_dset->n_rows()),
vpq_dset->dim());
search_main_core(
res, params, dataset_desc, graph, queries, neighbors, distances, sample_filter);
res, params, dataset_desc, graph, queries, neighbors, distances, sample_filter, metric);
} else {
RAFT_FAIL("Subspace dimension must be 2 or 4");
}
Expand Down Expand Up @@ -268,17 +271,31 @@ void search_main(raft::resources const& res,
strided_dset->n_rows(),
strided_dset->dim(),
strided_dset->stride());

search_main_core<dataset_desc_t, CagraSampleFilterT>(
res, params, dataset_desc, graph_internal, queries, neighbors, distances, sample_filter);
search_main_core<dataset_desc_t, CagraSampleFilterT>(res,
params,
dataset_desc,
graph_internal,
queries,
neighbors,
distances,
sample_filter,
index.metric());
} else if (auto* vpq_dset = dynamic_cast<const vpq_dataset<float, ds_idx_type>*>(&index.data());
vpq_dset != nullptr) {
// Search using a compressed dataset
RAFT_FAIL("FP32 VPQ dataset support is coming soon");
} else if (auto* vpq_dset = dynamic_cast<const vpq_dataset<half, ds_idx_type>*>(&index.data());
vpq_dset != nullptr) {
launch_vpq_search_main_core<T, half, ds_idx_type, InternalIdxT, DistanceT, CagraSampleFilterT>(
res, vpq_dset, params, graph_internal, queries, neighbors, distances, sample_filter);
res,
vpq_dset,
params,
graph_internal,
queries,
neighbors,
distances,
sample_filter,
index.metric());
} else if (auto* empty_dset = dynamic_cast<const empty_dataset<ds_idx_type>*>(&index.data());
empty_dset != nullptr) {
// Forgot to add a dataset.
Expand Down
65 changes: 56 additions & 9 deletions cpp/include/raft/neighbors/detail/cagra/compute_distance.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
#include "hashmap.hpp"
#include "utils.hpp"

#include <raft/core/operators.hpp>
#include <raft/distance/distance_types.hpp>
#include <raft/spatial/knn/detail/ann_utils.cuh>
#include <raft/util/vectorized.cuh>

Expand Down Expand Up @@ -54,6 +56,7 @@ _RAFT_DEVICE void compute_distance_to_random_nodes(
const uint32_t num_seeds,
INDEX_T* const visited_hash_ptr,
const uint32_t hash_bitlen,
const raft::distance::DistanceType metric,
const uint32_t block_id = 0,
const uint32_t num_blocks = 1)
{
Expand All @@ -78,8 +81,22 @@ _RAFT_DEVICE void compute_distance_to_random_nodes(
}
}

const auto norm2 = dataset_desc.template compute_similarity<DATASET_BLOCK_DIM, TEAM_SIZE>(
query_buffer, seed_index, valid_i);
DISTANCE_T norm2;
switch (metric) {
case raft::distance::L2Expanded:
norm2 = dataset_desc.template compute_similarity<DATASET_BLOCK_DIM,
TEAM_SIZE,
raft::distance::L2Expanded>(
query_buffer, seed_index, valid_i);
break;
case raft::distance::InnerProduct:
norm2 = dataset_desc.template compute_similarity<DATASET_BLOCK_DIM,
TEAM_SIZE,
raft::distance::InnerProduct>(
query_buffer, seed_index, valid_i);
break;
default: break;
}

if (valid_i && (norm2 < best_norm2_team_local)) {
best_norm2_team_local = norm2;
Expand Down Expand Up @@ -121,7 +138,8 @@ _RAFT_DEVICE void compute_distance_to_child_nodes(
const std::uint32_t hash_bitlen,
const INDEX_T* const parent_indices,
const INDEX_T* const internal_topk_list,
const std::uint32_t search_width)
const std::uint32_t search_width,
const raft::distance::DistanceType metric)
{
constexpr INDEX_T index_msb_1_mask = utils::gen_index_msb_1_mask<INDEX_T>::value;
const INDEX_T invalid_index = utils::get_max_value<INDEX_T>();
Expand Down Expand Up @@ -153,8 +171,22 @@ _RAFT_DEVICE void compute_distance_to_child_nodes(
INDEX_T child_id = invalid_index;
if (valid_i) { child_id = result_child_indices_ptr[i]; }

const auto norm2 = dataset_desc.template compute_similarity<DATASET_BLOCK_DIM, TEAM_SIZE>(
query_buffer, child_id, child_id != invalid_index);
DISTANCE_T norm2;
switch (metric) {
case raft::distance::L2Expanded:
norm2 =
dataset_desc
.template compute_similarity<DATASET_BLOCK_DIM, TEAM_SIZE, raft::distance::L2Expanded>(
query_buffer, child_id, child_id != invalid_index);
break;
case raft::distance::InnerProduct:
norm2 = dataset_desc.template compute_similarity<DATASET_BLOCK_DIM,
TEAM_SIZE,
raft::distance::InnerProduct>(
query_buffer, child_id, child_id != invalid_index);
break;
default: break;
}

// Store the distance
const unsigned lane_id = threadIdx.x % TEAM_SIZE;
Expand Down Expand Up @@ -220,7 +252,22 @@ struct standard_dataset_descriptor_t
}
}

template <uint32_t DATASET_BLOCK_DIM, uint32_t TEAM_SIZE>
template <typename T, raft::distance::DistanceType METRIC>
std::enable_if_t<METRIC == raft::distance::DistanceType::L2Expanded, T> __device__
dist_op(T a, T b) const
{
T diff = a - b;
return diff * diff;
}

template <typename T, raft::distance::DistanceType METRIC>
std::enable_if_t<METRIC == raft::distance::DistanceType::InnerProduct, T> __device__
dist_op(T a, T b) const
{
return -a * b;
}

template <uint32_t DATASET_BLOCK_DIM, uint32_t TEAM_SIZE, raft::distance::DistanceType METRIC>
__device__ DISTANCE_T compute_similarity(const QUERY_T* const query_ptr,
const INDEX_T dataset_i,
const bool valid) const
Expand Down Expand Up @@ -252,9 +299,9 @@ struct standard_dataset_descriptor_t
// because:
// - Above the last element (dataset_dim-1), the query array is filled with zeros.
// - The data buffer has to be also padded with zeros.
DISTANCE_T diff = query_ptr[device::swizzling(kv)];
diff -= spatial::knn::detail::utils::mapping<float>{}(dl_buff[e].val.data[v]);
norm2 += diff * diff;
DISTANCE_T d = query_ptr[device::swizzling(kv)];
norm2 += dist_op<DISTANCE_T, METRIC>(
d, spatial::knn::detail::utils::mapping<float>{}(dl_buff[e].val.data[v]));
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

#include "compute_distance.hpp"

#include <raft/distance/distance_types.hpp>
#include <raft/util/integer_utils.hpp>

namespace raft::neighbors::cagra::detail {
Expand Down Expand Up @@ -112,7 +113,7 @@ struct cagra_q_dataset_descriptor_t : public dataset_descriptor_base_t<half, DIS
}
}

template <uint32_t DATASET_BLOCK_DIM, uint32_t TEAM_SIZE>
template <uint32_t DATASET_BLOCK_DIM, uint32_t TEAM_SIZE, raft::distance::DistanceType METRIC>
__device__ DISTANCE_T compute_similarity(const QUERY_T* const query_ptr,
const INDEX_T node_id,
const bool valid) const
Expand Down Expand Up @@ -227,4 +228,4 @@ struct cagra_q_dataset_descriptor_t : public dataset_descriptor_base_t<half, DIS
}
};

} // namespace raft::neighbors::cagra::detail
} // namespace raft::neighbors::cagra::detail
11 changes: 6 additions & 5 deletions cpp/include/raft/neighbors/detail/cagra/factory.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,10 @@ class factory {
search_params const& params,
int64_t dim,
int64_t graph_degree,
uint32_t topk)
uint32_t topk,
const raft::distance::DistanceType metric)
{
search_plan_impl_base plan(params, dim, graph_degree, topk);
search_plan_impl_base plan(params, dim, graph_degree, topk, metric);
switch (plan.dataset_block_dim) {
case 128:
switch (plan.team_size) {
Expand Down Expand Up @@ -77,17 +78,17 @@ class factory {
return std::unique_ptr<search_plan_impl<DATASET_DESCRIPTOR_T, CagraSampleFilterT>>(
new single_cta_search::
search<TEAM_SIZE, DATASET_BLOCK_DIM, DATASET_DESCRIPTOR_T, CagraSampleFilterT>(
res, plan, plan.dim, plan.graph_degree, plan.topk));
res, plan, plan.dim, plan.graph_degree, plan.topk, plan.metric));
} else if (plan.algo == search_algo::MULTI_CTA) {
return std::unique_ptr<search_plan_impl<DATASET_DESCRIPTOR_T, CagraSampleFilterT>>(
new multi_cta_search::
search<TEAM_SIZE, DATASET_BLOCK_DIM, DATASET_DESCRIPTOR_T, CagraSampleFilterT>(
res, plan, plan.dim, plan.graph_degree, plan.topk));
res, plan, plan.dim, plan.graph_degree, plan.topk, plan.metric));
} else {
return std::unique_ptr<search_plan_impl<DATASET_DESCRIPTOR_T, CagraSampleFilterT>>(
new multi_kernel_search::
search<TEAM_SIZE, DATASET_BLOCK_DIM, DATASET_DESCRIPTOR_T, CagraSampleFilterT>(
res, plan, plan.dim, plan.graph_degree, plan.topk));
res, plan, plan.dim, plan.graph_degree, plan.topk, plan.metric));
}
}
};
Expand Down
Loading
Loading