Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

raft index supports cosine similarity by normalizing the input data. #924

Merged
merged 1 commit into from
Nov 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 22 additions & 51 deletions src/common/raft/integration/raft_knowhere_index.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
#include "common/raft/integration/raft_knowhere_index.hpp"
#include "common/raft/proto/raft_index.cuh"
#include "common/raft/proto/raft_index_kind.hpp"
#include "knowhere/comp/index_param.h"

namespace raft_knowhere {
namespace detail {
Expand Down Expand Up @@ -117,6 +118,8 @@ metric_string_to_raft_distance_type(std::string const& metric_string) {
auto result = raft::distance::DistanceType::L2Expanded;
if (metric_string == "L2") {
result = raft::distance::DistanceType::L2Expanded;
} else if (metric_string == "COSINE") {
result = raft::distance::DistanceType::InnerProduct;
} else if (metric_string == "L2SqrtExpanded") {
result = raft::distance::DistanceType::L2SqrtExpanded;
} else if (metric_string == "CosineExpanded") {
Expand Down Expand Up @@ -404,6 +407,17 @@ struct raft_knowhere_index<IndexKind>::impl {
}
auto const& res = get_device_resources_without_mempool();
auto host_data = raft::make_host_matrix_view(data, row_count, feature_count);
if (config.metric_type == knowhere::metric::COSINE) {
auto device_data = raft::make_device_matrix<data_type, input_indexing_type>(res, row_count, feature_count);
auto device_data_view = device_data.view();
raft::copy(res, device_data_view, host_data);
raft::linalg::row_normalize(res, raft::make_const_mdspan(device_data_view), device_data_view,
raft::linalg::NormType::L2Norm);
auto host_data_view = raft::make_host_matrix_view(const_cast<data_type*>(data), row_count, feature_count);
raft::copy(res, host_data_view, device_data_view);
res.sync_stream();
}

if (config.cache_dataset_on_device) {
device_dataset_storage =
raft::make_device_matrix<data_type, input_indexing_type>(res, row_count, feature_count);
Expand All @@ -417,51 +431,6 @@ struct raft_knowhere_index<IndexKind>::impl {
}
}

// Adds a batch of host-resident vectors to the index held in `index_`.
//
// data          - host pointer to row_count * feature_count values; it is
//                 wrapped in a host matrix view and copied to the device.
// row_count     - number of vectors in `data`.
// feature_count - number of values per vector.
// new_ids       - optional host pointer to row_count ids for the new vectors;
//                 when nullptr, an empty optional is handed to `extend`.
//
// For brute force, CAGRA and IVFPQ index kinds the call is rejected through
// RAFT_FAIL once `index_` is set, and does nothing when it is not. For every
// other index kind, `index_` must already be set; otherwise RAFT_FAIL is
// raised with "Index has not yet been trained".
//
// Note: on the extend path `device_dataset_storage` is overwritten with the
// new batch.
void
add(data_type const* data, knowhere_indexing_type row_count, knowhere_indexing_type feature_count,
    knowhere_indexing_type const* new_ids) {
    if constexpr (index_kind == raft_proto::raft_index_kind::brute_force) {
        if (index_) {
            RAFT_FAIL("RAFT brute force does not support adding vectors after training");
        }
    } else if constexpr (index_kind == raft_proto::raft_index_kind::cagra) {
        if (index_) {
            RAFT_FAIL("CAGRA does not support adding vectors after training");
        }
    } else if constexpr (index_kind == raft_proto::raft_index_kind::ivf_pq) {
        if (index_) {
            RAFT_FAIL("IVFPQ does not support adding vectors after training");
        }
    } else {
        if (!index_) {
            RAFT_FAIL("Index has not yet been trained");
        } else {
            auto const& handle = get_device_resources_without_mempool();

            // Stage the incoming vectors on the device.
            auto host_vectors = raft::make_host_matrix_view(data, row_count, feature_count);
            device_dataset_storage =
                raft::make_device_matrix<data_type, input_indexing_type>(handle, row_count, feature_count);
            auto device_vectors = device_dataset_storage->view();
            raft::copy(handle, device_vectors, host_vectors);

            if (new_ids != nullptr) {
                // Caller supplied explicit ids: upload them and pass them along to extend.
                auto device_ids_storage = std::optional<raft::device_vector<indexing_type, input_indexing_type>>{};
                auto host_ids = raft::make_host_vector_view(new_ids, row_count);
                device_ids_storage = raft::make_device_vector<indexing_type, input_indexing_type>(handle, row_count);
                raft::copy(handle, device_ids_storage->view(), host_ids);

                index_ = raft_index_type::extend(
                    handle, raft::make_const_mdspan(device_vectors),
                    std::make_optional(raft::make_const_mdspan(device_ids_storage->view())), *index_);
            } else {
                // No ids supplied: hand extend an empty optional.
                index_ = raft_index_type::extend(
                    handle, raft::make_const_mdspan(device_vectors),
                    std::optional<raft::device_vector_view<indexing_type const, input_indexing_type>>{}, *index_);
            }
        }
    }
}

auto
search(raft_knowhere_config const& config, data_type const* data, knowhere_indexing_type row_count,
knowhere_indexing_type feature_count, knowhere_bitset_data_type const* bitset_data,
Expand All @@ -475,6 +444,13 @@ struct raft_knowhere_index<IndexKind>::impl {
auto device_data_storage =
raft::make_device_matrix<data_type, input_indexing_type>(res, row_count, feature_count);
raft::copy(res, device_data_storage.view(), host_data);

if (config.metric_type == knowhere::metric::COSINE) {
auto device_data_view = device_data_storage.view();
raft::linalg::row_normalize(res, raft::make_const_mdspan(device_data_view), device_data_view,
raft::linalg::NormType::L2Norm);
}

auto device_bitset =
std::optional<raft::core::bitset<knowhere_bitset_data_type, knowhere_bitset_indexing_type>>{};
auto k_tmp = k;
Expand Down Expand Up @@ -714,12 +690,7 @@ raft_knowhere_index<IndexKind>::train(raft_knowhere_config const& config, data_t
knowhere_indexing_type row_count, knowhere_indexing_type feature_count) {
return pimpl->train(config, data, row_count, feature_count);
}
// Public entry point for adding vectors: forwards every argument unchanged to
// the pimpl implementation's add().
template <raft_proto::raft_index_kind IndexKind>
void
raft_knowhere_index<IndexKind>::add(data_type const* data, knowhere_indexing_type row_count,
                                    knowhere_indexing_type feature_count, knowhere_indexing_type const* new_ids) {
    pimpl->add(data, row_count, feature_count, new_ids);
}

template <raft_proto::raft_index_kind IndexKind>
std::tuple<knowhere_indexing_type*, knowhere_data_type*>
raft_knowhere_index<IndexKind>::search(raft_knowhere_config const& config, data_type const* data,
Expand Down
3 changes: 0 additions & 3 deletions src/common/raft/integration/raft_knowhere_index.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,9 +45,6 @@ struct raft_knowhere_index {
dim() const;
void
train(raft_knowhere_config const&, data_type const*, knowhere_indexing_type, knowhere_indexing_type);
void
add(data_type const* data, knowhere_indexing_type row_count, knowhere_indexing_type feature_count,
knowhere_indexing_type const* new_ids = nullptr);
std::tuple<knowhere_indexing_type*, knowhere_data_type*>
search(raft_knowhere_config const& config, data_type const* data, knowhere_indexing_type row_count,
knowhere_indexing_type feature_count, knowhere_bitset_data_type const* bitset_data = nullptr,
Expand Down
50 changes: 0 additions & 50 deletions src/common/raft_metric.h

This file was deleted.

2 changes: 1 addition & 1 deletion src/index/gpu_raft/gpu_raft_brute_force_config.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ struct GpuRaftBruteForceConfig : public BaseConfig {
Status
CheckAndAdjust(PARAM_TYPE param_type, std::string* err_msg) override {
if (param_type == PARAM_TYPE::TRAIN) {
constexpr std::array<std::string_view, 2> legal_metric_list{"L2", "IP"};
constexpr std::array<std::string_view, 3> legal_metric_list{"L2", "IP", "COSINE"};
std::string metric = metric_type.value();
if (std::find(legal_metric_list.begin(), legal_metric_list.end(), metric) == legal_metric_list.end()) {
if (err_msg) {
Expand Down
11 changes: 11 additions & 0 deletions src/index/gpu_raft/gpu_raft_cagra_config.h
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,17 @@ struct GpuRaftCagraConfig : public BaseConfig {

Status
CheckAndAdjust(PARAM_TYPE param_type, std::string* err_msg) override {
if (param_type == PARAM_TYPE::TRAIN) {
constexpr std::array<std::string_view, 3> legal_metric_list{"L2", "IP", "COSINE"};
std::string metric = metric_type.value();
if (std::find(legal_metric_list.begin(), legal_metric_list.end(), metric) == legal_metric_list.end()) {
if (err_msg) {
*err_msg = "metric type " + metric + " not found or not supported, supported: [L2 IP COSINE]";
}
return Status::invalid_metric_type;
}
}

if (param_type == PARAM_TYPE::SEARCH) {
// auto align itopk_size
auto itopk_v = itopk_size.value_or(std::max(k.value(), kItopkSize));
Expand Down
2 changes: 1 addition & 1 deletion src/index/gpu_raft/gpu_raft_ivf_flat_config.h
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ struct GpuRaftIvfFlatConfig : public IvfFlatConfig {
Status
CheckAndAdjust(PARAM_TYPE param_type, std::string* err_msg) override {
if (param_type == PARAM_TYPE::TRAIN) {
constexpr std::array<std::string_view, 2> legal_metric_list{"L2", "IP"};
constexpr std::array<std::string_view, 3> legal_metric_list{"L2", "IP", "COSINE"};
std::string metric = metric_type.value();
if (std::find(legal_metric_list.begin(), legal_metric_list.end(), metric) == legal_metric_list.end()) {
if (err_msg) {
Expand Down
4 changes: 2 additions & 2 deletions src/index/gpu_raft/gpu_raft_ivf_pq_config.h
Original file line number Diff line number Diff line change
Expand Up @@ -96,11 +96,11 @@ struct GpuRaftIvfPqConfig : public IvfPqConfig {
Status
CheckAndAdjust(PARAM_TYPE param_type, std::string* err_msg) override {
if (param_type == PARAM_TYPE::TRAIN) {
constexpr std::array<std::string_view, 2> legal_metric_list{"L2", "IP"};
constexpr std::array<std::string_view, 3> legal_metric_list{"L2", "IP", "COSINE"};
std::string metric = metric_type.value();
if (std::find(legal_metric_list.begin(), legal_metric_list.end(), metric) == legal_metric_list.end()) {
if (err_msg) {
*err_msg = "metric type " + metric + " not found or not supported, supported: [L2 IP]";
*err_msg = "metric type " + metric + " not found or not supported, supported: [L2 IP COSINE]";
}
return Status::invalid_metric_type;
}
Expand Down
Loading