Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

InnerProduct Distance Metric for CAGRA search #2260

Merged
merged 34 commits into from
Apr 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
34 commits
Select commit Hold shift + click to select a range
b6d9980
apply updates to 24.06
tarang-jain Apr 5, 2024
fac65b8
Merge branch 'branch-24.06' of https://github.com/rapidsai/raft into …
tarang-jain Apr 5, 2024
ad48b03
remove build errors
tarang-jain Apr 5, 2024
2b8d898
search inputs
tarang-jain Apr 5, 2024
e44ab17
inner product in compute_distance_vpq.cuh
tarang-jain Apr 9, 2024
c506216
Merge branch 'branch-24.06' of https://github.com/rapidsai/raft into …
tarang-jain Apr 9, 2024
cb7fbba
inner product in index build; debug statements
tarang-jain Apr 10, 2024
0029bba
tests passing
tarang-jain Apr 10, 2024
293bc8f
Merge branch 'branch-24.06' of https://github.com/rapidsai/raft into …
tarang-jain Apr 10, 2024
c91d895
style
tarang-jain Apr 10, 2024
7a5f876
update testing
tarang-jain Apr 10, 2024
c37aa27
rm log statements
tarang-jain Apr 10, 2024
890372b
pass CagraSort
tarang-jain Apr 10, 2024
810ddd1
tests passing
tarang-jain Apr 11, 2024
f92e68b
remove dbg statements
tarang-jain Apr 11, 2024
5bbbc70
update docs
tarang-jain Apr 11, 2024
7febd73
metric assertions
tarang-jain Apr 11, 2024
7e19937
add metric as const arg
tarang-jain Apr 12, 2024
e40b967
make metric template:
tarang-jain Apr 12, 2024
b102393
clean up metric template
tarang-jain Apr 15, 2024
6da7d55
update assertion
tarang-jain Apr 15, 2024
1e666e3
Merge branch 'branch-24.06' of https://github.com/rapidsai/raft into …
tarang-jain Apr 17, 2024
c1f4dcd
metric runtime dispatch
tarang-jain Apr 17, 2024
2e9e3fe
Merge branch 'cagra-dists' of https://github.com/tarang-jain/raft int…
tarang-jain Apr 17, 2024
d8a4b39
address all PR reviews
tarang-jain Apr 18, 2024
c5cd0e7
update docs, passing gtests
tarang-jain Apr 19, 2024
eabb031
Merge branch 'branch-24.06' of https://github.com/rapidsai/raft into …
tarang-jain Apr 19, 2024
77ff0c2
add ivf_pq::index_params helper
tarang-jain Apr 23, 2024
b61c1c3
Merge branch 'branch-24.06' of https://github.com/rapidsai/raft into …
tarang-jain Apr 23, 2024
54061d0
tracking issue; styling
tarang-jain Apr 23, 2024
dda1810
make helper static
tarang-jain Apr 24, 2024
743d5bc
update from_dataset verbiage
tarang-jain Apr 24, 2024
5791743
Merge branch 'branch-24.06' of https://github.com/rapidsai/raft into …
tarang-jain Apr 24, 2024
f5f30b7
Merge branch 'branch-24.06' into cagra-dists
cjnolet Apr 30, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions cpp/include/raft/neighbors/cagra.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
#include <raft/core/host_device_accessor.hpp>
#include <raft/core/mdspan.hpp>
#include <raft/core/resources.hpp>
#include <raft/distance/distance_types.hpp>
#include <raft/neighbors/cagra_types.hpp>
#include <raft/neighbors/dataset.hpp>

Expand All @@ -48,13 +49,14 @@ namespace raft::neighbors::cagra {
*
* The following distance metrics are supported:
* - L2Expanded
* - InnerProduct
*
* Usage example:
* @code{.cpp}
* using namespace raft::neighbors;
* // use default index parameters
* ivf_pq::index_params build_params;
* ivf_pq::search_params search_params
* // use default index parameters based on shape of the dataset
* ivf_pq::index_params build_params = ivf_pq::index_params::from_dataset(dataset);
* ivf_pq::search_params search_params;
* auto knn_graph = raft::make_host_matrix<IdxT, IdxT>(dataset.extent(0), 128);
* // create knn graph
* cagra::build_knn_graph(res, dataset, knn_graph.view(), 2, build_params, search_params);
Expand Down
22 changes: 10 additions & 12 deletions cpp/include/raft/neighbors/detail/cagra/cagra_build.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

#include <raft/core/device_mdarray.hpp>
#include <raft/core/device_mdspan.hpp>
#include <raft/core/error.hpp>
#include <raft/core/host_device_accessor.hpp>
#include <raft/core/host_mdarray.hpp>
#include <raft/core/host_mdspan.hpp>
Expand Down Expand Up @@ -50,24 +51,17 @@ void build_knn_graph(raft::resources const& res,
std::optional<ivf_pq::index_params> build_params = std::nullopt,
std::optional<ivf_pq::search_params> search_params = std::nullopt)
{
RAFT_EXPECTS(!build_params || build_params->metric == distance::DistanceType::L2Expanded,
"Currently only L2Expanded metric is supported");
RAFT_EXPECTS(!build_params || build_params->metric == distance::DistanceType::L2Expanded ||
build_params->metric == distance::DistanceType::InnerProduct,
"Currently only L2Expanded or InnerProduct metric are supported");

uint32_t node_degree = knn_graph.extent(1);
common::nvtx::range<common::nvtx::domain::raft> fun_scope("cagra::build_graph(%zu, %zu, %u)",
size_t(dataset.extent(0)),
size_t(dataset.extent(1)),
node_degree);

if (!build_params) {
build_params = ivf_pq::index_params{};
build_params->n_lists = dataset.extent(0) < 4 * 2500 ? 4 : (uint32_t)(dataset.extent(0) / 2500);
build_params->pq_dim = raft::Pow2<8>::roundUp(dataset.extent(1) / 2);
build_params->pq_bits = 8;
build_params->kmeans_trainset_fraction = dataset.extent(0) < 10000 ? 1 : 10;
build_params->kmeans_n_iters = 25;
build_params->add_data_on_build = true;
}
if (!build_params) { build_params = ivf_pq::index_params::from_dataset(dataset); }

// Make model name
const std::string model_name = [&]() {
Expand Down Expand Up @@ -324,8 +318,10 @@ index<T, IdxT> build(

if (params.build_algo == graph_build_algo::IVF_PQ) {
build_knn_graph(res, dataset, knn_graph->view(), refine_rate, pq_build_params, search_params);

} else {
RAFT_EXPECTS(
params.metric == raft::distance::DistanceType::L2Expanded,
"L2Expanded is the only distance metrics supported for CAGRA build with nn_descent");
// Use nn-descent to build CAGRA knn graph
if (!nn_descent_params) {
nn_descent_params = experimental::nn_descent::index_params();
Expand All @@ -348,6 +344,8 @@ index<T, IdxT> build(
// Construct an index from dataset and optimized knn graph.
if (construct_index_with_dataset) {
if (params.compression.has_value()) {
RAFT_EXPECTS(params.metric == raft::distance::DistanceType::L2Expanded,
"VPQ compression is only supported with L2Expanded distance mertric");
index<T, IdxT> idx(res, params.metric);
idx.update_graph(res, raft::make_const_mdspan(cagra_graph.view()));
idx.update_dataset(
Expand Down
35 changes: 26 additions & 9 deletions cpp/include/raft/neighbors/detail/cagra/cagra_search.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
#include <raft/core/nvtx.hpp>
#include <raft/core/resource/cuda_stream.hpp>
#include <raft/core/resources.hpp>
#include <raft/distance/distance_types.hpp>
#include <raft/neighbors/cagra_types.hpp>
#include <raft/neighbors/detail/ivf_common.cuh>
#include <raft/neighbors/detail/ivf_pq_search.cuh>
Expand Down Expand Up @@ -87,7 +88,8 @@ void search_main_core(
raft::device_matrix_view<const typename DatasetDescriptorT::DATA_T, int64_t, row_major> queries,
raft::device_matrix_view<typename DatasetDescriptorT::INDEX_T, int64_t, row_major> neighbors,
raft::device_matrix_view<typename DatasetDescriptorT::DISTANCE_T, int64_t, row_major> distances,
CagraSampleFilterT sample_filter = CagraSampleFilterT())
CagraSampleFilterT sample_filter = CagraSampleFilterT(),
raft::distance::DistanceType metric = raft::distance::DistanceType::L2Expanded)
{
RAFT_LOG_DEBUG("# dataset size = %lu, dim = %lu\n",
static_cast<size_t>(dataset_desc.size),
Expand All @@ -112,7 +114,7 @@ void search_main_core(
using CagraSampleFilterT_s = typename CagraSampleFilterT_Selector<CagraSampleFilterT>::type;
std::unique_ptr<search_plan_impl<DatasetDescriptorT, CagraSampleFilterT_s>> plan =
factory<DatasetDescriptorT, CagraSampleFilterT_s>::create(
res, params, dataset_desc.dim, graph.extent(1), topk);
res, params, dataset_desc.dim, graph.extent(1), topk, metric);

plan->check(topk);

Expand Down Expand Up @@ -163,7 +165,8 @@ void launch_vpq_search_main_core(
raft::device_matrix_view<const T, int64_t, row_major> queries,
raft::device_matrix_view<InternalIdxT, int64_t, row_major> neighbors,
raft::device_matrix_view<DistanceT, int64_t, row_major> distances,
CagraSampleFilterT sample_filter)
CagraSampleFilterT sample_filter,
const raft::distance::DistanceType metric)
{
RAFT_EXPECTS(vpq_dset->pq_bits() == 8, "Only pq_bits = 8 is supported for now");
RAFT_EXPECTS(vpq_dset->pq_len() == 2 || vpq_dset->pq_len() == 4,
Expand Down Expand Up @@ -192,7 +195,7 @@ void launch_vpq_search_main_core(
size_t(vpq_dset->n_rows()),
vpq_dset->dim());
search_main_core(
res, params, dataset_desc, graph, queries, neighbors, distances, sample_filter);
res, params, dataset_desc, graph, queries, neighbors, distances, sample_filter, metric);
} else if (vpq_dset->pq_len() == 4) {
using dataset_desc_t = cagra_q_dataset_descriptor_t<T,
DatasetT,
Expand All @@ -210,7 +213,7 @@ void launch_vpq_search_main_core(
size_t(vpq_dset->n_rows()),
vpq_dset->dim());
search_main_core(
res, params, dataset_desc, graph, queries, neighbors, distances, sample_filter);
res, params, dataset_desc, graph, queries, neighbors, distances, sample_filter, metric);
} else {
RAFT_FAIL("Subspace dimension must be 2 or 4");
}
Expand Down Expand Up @@ -268,17 +271,31 @@ void search_main(raft::resources const& res,
strided_dset->n_rows(),
strided_dset->dim(),
strided_dset->stride());

search_main_core<dataset_desc_t, CagraSampleFilterT>(
res, params, dataset_desc, graph_internal, queries, neighbors, distances, sample_filter);
search_main_core<dataset_desc_t, CagraSampleFilterT>(res,
params,
dataset_desc,
graph_internal,
queries,
neighbors,
distances,
sample_filter,
index.metric());
} else if (auto* vpq_dset = dynamic_cast<const vpq_dataset<float, ds_idx_type>*>(&index.data());
vpq_dset != nullptr) {
// Search using a compressed dataset
RAFT_FAIL("FP32 VPQ dataset support is coming soon");
} else if (auto* vpq_dset = dynamic_cast<const vpq_dataset<half, ds_idx_type>*>(&index.data());
vpq_dset != nullptr) {
launch_vpq_search_main_core<T, half, ds_idx_type, InternalIdxT, DistanceT, CagraSampleFilterT>(
res, vpq_dset, params, graph_internal, queries, neighbors, distances, sample_filter);
res,
vpq_dset,
params,
graph_internal,
queries,
neighbors,
distances,
sample_filter,
index.metric());
} else if (auto* empty_dset = dynamic_cast<const empty_dataset<ds_idx_type>*>(&index.data());
empty_dset != nullptr) {
// Forgot to add a dataset.
Expand Down
65 changes: 56 additions & 9 deletions cpp/include/raft/neighbors/detail/cagra/compute_distance.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
#include "hashmap.hpp"
#include "utils.hpp"

#include <raft/core/operators.hpp>
#include <raft/distance/distance_types.hpp>
#include <raft/spatial/knn/detail/ann_utils.cuh>
#include <raft/util/vectorized.cuh>

Expand Down Expand Up @@ -54,6 +56,7 @@ _RAFT_DEVICE void compute_distance_to_random_nodes(
const uint32_t num_seeds,
INDEX_T* const visited_hash_ptr,
const uint32_t hash_bitlen,
const raft::distance::DistanceType metric,
const uint32_t block_id = 0,
const uint32_t num_blocks = 1)
{
Expand All @@ -78,8 +81,22 @@ _RAFT_DEVICE void compute_distance_to_random_nodes(
}
}

const auto norm2 = dataset_desc.template compute_similarity<DATASET_BLOCK_DIM, TEAM_SIZE>(
query_buffer, seed_index, valid_i);
DISTANCE_T norm2;
switch (metric) {
case raft::distance::L2Expanded:
norm2 = dataset_desc.template compute_similarity<DATASET_BLOCK_DIM,
TEAM_SIZE,
raft::distance::L2Expanded>(
query_buffer, seed_index, valid_i);
break;
case raft::distance::InnerProduct:
norm2 = dataset_desc.template compute_similarity<DATASET_BLOCK_DIM,
TEAM_SIZE,
raft::distance::InnerProduct>(
query_buffer, seed_index, valid_i);
break;
default: break;
}

if (valid_i && (norm2 < best_norm2_team_local)) {
best_norm2_team_local = norm2;
Expand Down Expand Up @@ -121,7 +138,8 @@ _RAFT_DEVICE void compute_distance_to_child_nodes(
const std::uint32_t hash_bitlen,
const INDEX_T* const parent_indices,
const INDEX_T* const internal_topk_list,
const std::uint32_t search_width)
const std::uint32_t search_width,
const raft::distance::DistanceType metric)
{
constexpr INDEX_T index_msb_1_mask = utils::gen_index_msb_1_mask<INDEX_T>::value;
const INDEX_T invalid_index = utils::get_max_value<INDEX_T>();
Expand Down Expand Up @@ -153,8 +171,22 @@ _RAFT_DEVICE void compute_distance_to_child_nodes(
INDEX_T child_id = invalid_index;
if (valid_i) { child_id = result_child_indices_ptr[i]; }

const auto norm2 = dataset_desc.template compute_similarity<DATASET_BLOCK_DIM, TEAM_SIZE>(
query_buffer, child_id, child_id != invalid_index);
DISTANCE_T norm2;
switch (metric) {
case raft::distance::L2Expanded:
norm2 =
dataset_desc
.template compute_similarity<DATASET_BLOCK_DIM, TEAM_SIZE, raft::distance::L2Expanded>(
query_buffer, child_id, child_id != invalid_index);
break;
case raft::distance::InnerProduct:
norm2 = dataset_desc.template compute_similarity<DATASET_BLOCK_DIM,
TEAM_SIZE,
raft::distance::InnerProduct>(
query_buffer, child_id, child_id != invalid_index);
break;
default: break;
}

// Store the distance
const unsigned lane_id = threadIdx.x % TEAM_SIZE;
Expand Down Expand Up @@ -220,7 +252,22 @@ struct standard_dataset_descriptor_t
}
}

template <uint32_t DATASET_BLOCK_DIM, uint32_t TEAM_SIZE>
/**
 * Element-wise distance term for the L2Expanded metric.
 *
 * This overload participates in overload resolution only when METRIC is
 * raft::distance::DistanceType::L2Expanded (selected through std::enable_if_t).
 *
 * @tparam T      arithmetic type of the operands and of the result
 * @tparam METRIC distance metric used to select the overload at compile time
 * @param a first operand
 * @param b second operand
 * @return the squared difference (a - b) * (a - b)
 */
template <typename T, raft::distance::DistanceType METRIC>
std::enable_if_t<METRIC == raft::distance::DistanceType::L2Expanded, T> __device__
dist_op(T a, T b) const
{
  // Compute the difference once and square it.
  const T delta = a - b;
  return delta * delta;
}

/**
 * Element-wise distance term for the InnerProduct metric.
 *
 * This overload participates in overload resolution only when METRIC is
 * raft::distance::DistanceType::InnerProduct (selected through std::enable_if_t).
 *
 * @tparam T      arithmetic type of the operands and of the result
 * @tparam METRIC distance metric used to select the overload at compile time
 * @param a first operand
 * @param b second operand
 * @return the product of the negated first operand with the second, (-a) * b
 */
template <typename T, raft::distance::DistanceType METRIC>
std::enable_if_t<METRIC == raft::distance::DistanceType::InnerProduct, T> __device__
dist_op(T a, T b) const
{
  // NOTE(review): the sign flip presumably lets callers keep a "smaller is better"
  // comparison for a similarity measure -- confirm against the call sites.
  const T negated_a = -a;
  return negated_a * b;
}

template <uint32_t DATASET_BLOCK_DIM, uint32_t TEAM_SIZE, raft::distance::DistanceType METRIC>
__device__ DISTANCE_T compute_similarity(const QUERY_T* const query_ptr,
const INDEX_T dataset_i,
const bool valid) const
Expand Down Expand Up @@ -252,9 +299,9 @@ struct standard_dataset_descriptor_t
// because:
// - Above the last element (dataset_dim-1), the query array is filled with zeros.
// - The data buffer has to be also padded with zeros.
DISTANCE_T diff = query_ptr[device::swizzling(kv)];
diff -= spatial::knn::detail::utils::mapping<float>{}(dl_buff[e].val.data[v]);
norm2 += diff * diff;
DISTANCE_T d = query_ptr[device::swizzling(kv)];
norm2 += dist_op<DISTANCE_T, METRIC>(
d, spatial::knn::detail::utils::mapping<float>{}(dl_buff[e].val.data[v]));
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

#include "compute_distance.hpp"

#include <raft/distance/distance_types.hpp>
#include <raft/util/integer_utils.hpp>

namespace raft::neighbors::cagra::detail {
Expand Down Expand Up @@ -112,7 +113,7 @@ struct cagra_q_dataset_descriptor_t : public dataset_descriptor_base_t<half, DIS
}
}

template <uint32_t DATASET_BLOCK_DIM, uint32_t TEAM_SIZE>
template <uint32_t DATASET_BLOCK_DIM, uint32_t TEAM_SIZE, raft::distance::DistanceType METRIC>
__device__ DISTANCE_T compute_similarity(const QUERY_T* const query_ptr,
const INDEX_T node_id,
const bool valid) const
Expand Down Expand Up @@ -227,4 +228,4 @@ struct cagra_q_dataset_descriptor_t : public dataset_descriptor_base_t<half, DIS
}
};

} // namespace raft::neighbors::cagra::detail
} // namespace raft::neighbors::cagra::detail
11 changes: 6 additions & 5 deletions cpp/include/raft/neighbors/detail/cagra/factory.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,10 @@ class factory {
search_params const& params,
int64_t dim,
int64_t graph_degree,
uint32_t topk)
uint32_t topk,
const raft::distance::DistanceType metric)
{
search_plan_impl_base plan(params, dim, graph_degree, topk);
search_plan_impl_base plan(params, dim, graph_degree, topk, metric);
switch (plan.dataset_block_dim) {
case 128:
switch (plan.team_size) {
Expand Down Expand Up @@ -77,17 +78,17 @@ class factory {
return std::unique_ptr<search_plan_impl<DATASET_DESCRIPTOR_T, CagraSampleFilterT>>(
new single_cta_search::
search<TEAM_SIZE, DATASET_BLOCK_DIM, DATASET_DESCRIPTOR_T, CagraSampleFilterT>(
res, plan, plan.dim, plan.graph_degree, plan.topk));
res, plan, plan.dim, plan.graph_degree, plan.topk, plan.metric));
} else if (plan.algo == search_algo::MULTI_CTA) {
return std::unique_ptr<search_plan_impl<DATASET_DESCRIPTOR_T, CagraSampleFilterT>>(
new multi_cta_search::
search<TEAM_SIZE, DATASET_BLOCK_DIM, DATASET_DESCRIPTOR_T, CagraSampleFilterT>(
res, plan, plan.dim, plan.graph_degree, plan.topk));
res, plan, plan.dim, plan.graph_degree, plan.topk, plan.metric));
} else {
return std::unique_ptr<search_plan_impl<DATASET_DESCRIPTOR_T, CagraSampleFilterT>>(
new multi_kernel_search::
search<TEAM_SIZE, DATASET_BLOCK_DIM, DATASET_DESCRIPTOR_T, CagraSampleFilterT>(
res, plan, plan.dim, plan.graph_degree, plan.topk));
res, plan, plan.dim, plan.graph_degree, plan.topk, plan.metric));
}
}
};
Expand Down
10 changes: 8 additions & 2 deletions cpp/include/raft/neighbors/detail/cagra/search_multi_cta.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,14 @@
#include "topk_for_cagra/topk_core.cuh" // TODO replace with raft topk if possible
#include "utils.hpp"

#include <raft/core/detail/macros.hpp>
#include <raft/core/device_mdspan.hpp>
#include <raft/core/logger.hpp>
#include <raft/core/resource/cuda_stream.hpp>
#include <raft/core/resource/device_properties.hpp>
#include <raft/core/resources.hpp>
#include <raft/distance/distance_types.hpp>
#include <raft/linalg/map.cuh>
#include <raft/spatial/knn/detail/ann_utils.cuh>
#include <raft/util/cuda_rt_essentials.hpp>
#include <raft/util/cudart_utils.hpp> // RAFT_CUDA_TRY_NOT_THROW is used TODO(tfeher): consider moving this to cuda_rt_essentials.hpp
Expand Down Expand Up @@ -96,8 +99,10 @@ struct search : public search_plan_impl<DATASET_DESCRIPTOR_T, SAMPLE_FILTER_T> {
search_params params,
int64_t dim,
int64_t graph_degree,
uint32_t topk)
: search_plan_impl<DATASET_DESCRIPTOR_T, SAMPLE_FILTER_T>(res, params, dim, graph_degree, topk),
uint32_t topk,
raft::distance::DistanceType metric)
: search_plan_impl<DATASET_DESCRIPTOR_T, SAMPLE_FILTER_T>(
res, params, dim, graph_degree, topk, metric),
intermediate_indices(0, resource::get_cuda_stream(res)),
intermediate_distances(0, resource::get_cuda_stream(res)),
topk_workspace(0, resource::get_cuda_stream(res))
Expand Down Expand Up @@ -235,6 +240,7 @@ struct search : public search_plan_impl<DATASET_DESCRIPTOR_T, SAMPLE_FILTER_T> {
min_iterations,
max_iterations,
sample_filter,
this->metric,
stream);
RAFT_CUDA_TRY(cudaPeekAtLastError());

Expand Down
Loading
Loading