Skip to content
This repository has been archived by the owner on Aug 16, 2023. It is now read-only.

Commit

Permalink
adapt for raft ivf
Browse files Browse the repository at this point in the history
Signed-off-by: Yusheng.Ma <[email protected]>
  • Loading branch information
Presburger committed Mar 13, 2023
1 parent d9d0f92 commit 391d9af
Show file tree
Hide file tree
Showing 3 changed files with 186 additions and 45 deletions.
20 changes: 0 additions & 20 deletions include/knowhere/gpu/gpu_res_mgr.h
Original file line number Diff line number Diff line change
Expand Up @@ -82,18 +82,6 @@ class GPUResMgr {
LOG_KNOWHERE_DEBUG_ << "InitDevice gpu_id " << gpu_id_ << ", resource count " << gpu_params_.res_num_
<< ", tmp_mem_sz " << gpu_params_.tmp_mem_sz_ / MB << "MB, pin_mem_sz "
<< gpu_params_.pin_mem_sz_ / MB << "MB";
#ifdef KNOWHERE_WITH_RAFT
if (gpu_id >= std::numeric_limits<int>::min() && gpu_id <= std::numeric_limits<int>::max()) {
auto rmm_id = rmm::cuda_device_id{int(gpu_id)};
rmm_memory_resources_.push_back(
std::make_unique<rmm::mr::pool_memory_resource<rmm::mr::device_memory_resource>>(
rmm::mr::get_per_device_resource(rmm_id)));
rmm::mr::set_per_device_resource(rmm_id, rmm_memory_resources_.back().get());
} else {
LOG_KNOWHERE_WARNING_ << "Could not init pool memory resource on GPU " << gpu_id_
<< ". ID is outside expected range.";
}
#endif
}

void
Expand Down Expand Up @@ -125,11 +113,6 @@ class GPUResMgr {
res_bq_.Take();
}
init_ = false;
#ifdef KNOWHERE_WITH_RAFT
for (auto&& rmm_res : rmm_memory_resources_) {
rmm_res.release();
}
#endif
}

ResPtr
Expand All @@ -156,9 +139,6 @@ class GPUResMgr {
int64_t gpu_id_ = 0;
GPUParams gpu_params_;
ResBQ res_bq_;
#ifdef KNOWHERE_WITH_RAFT
std::vector<std::unique_ptr<rmm::mr::pool_memory_resource<rmm::mr::device_memory_resource>>> rmm_memory_resources_;
#endif
};

class ResScope {
Expand Down
1 change: 1 addition & 0 deletions src/index/ivf_raft/ivf_flat_raft.cu
Original file line number Diff line number Diff line change
Expand Up @@ -23,4 +23,5 @@ KNOWHERE_REGISTER_GLOBAL(RAFT_IVF_FLAT, [](const Object& object) {
return Index<IndexNodeThreadPoolWrapper>::Create(
std::make_unique<RaftIvfIndexNode<detail::raft_ivf_flat_index>>(object));
});

} // namespace knowhere
210 changes: 185 additions & 25 deletions src/index/ivf_raft/ivf_raft.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,20 @@

namespace knowhere {

namespace memory_pool {

// Returns the RMM memory resource to use for allocations on the current CUDA
// device. On first use for a device, wraps that device's initial resource in a
// pool_memory_resource and caches it in RMM's per-device map; the pool is
// intentionally never destroyed (it lives for the lifetime of the process).
//
// NOTE(review): this reaches into rmm::mr::detail internals and performs an
// unsynchronized lazy insert into the shared per-device map — confirm callers
// serialize first use per device (or add a lock) before relying on this from
// multiple threads.
static rmm::mr::device_memory_resource*
resources() {
    auto device_id = rmm::detail::current_device();
    auto& resource_map = rmm::mr::detail::get_map();
    auto const found = resource_map.find(device_id.value());
    if (found != resource_map.end()) {
        return found->second;
    }
    // First request for this device: create the pool once and cache it.
    auto* pool =
        new rmm::mr::pool_memory_resource<rmm::mr::device_memory_resource>(rmm::mr::detail::initial_resource());
    resource_map[device_id.value()] = pool;
    return pool;
}

}  // namespace memory_pool

namespace detail {
using raft_ivf_flat_index = raft::neighbors::ivf_flat::index<float, std::int64_t>;
using raft_ivf_pq_index = raft::neighbors::ivf_pq::index<std::int64_t>;
Expand Down Expand Up @@ -100,14 +114,16 @@ auto static constexpr const CUDA_R_8F_E5M2 = "CUDA_R_8F_E5M2";
inline expected<cudaDataType_t, Status>
str_to_cuda_dtype(std::string const& str) {
static const std::unordered_map<std::string, cudaDataType_t> name_map = {
{cuda_type::CUDA_R_16F, CUDA_R_16F}, {cuda_type::CUDA_C_16F, CUDA_C_16F},
{cuda_type::CUDA_R_16BF, CUDA_R_16BF}, {cuda_type::CUDA_C_16BF, CUDA_C_16BF},
{cuda_type::CUDA_R_32F, CUDA_R_32F}, {cuda_type::CUDA_C_32F, CUDA_C_32F},
{cuda_type::CUDA_R_64F, CUDA_R_64F}, {cuda_type::CUDA_C_64F, CUDA_C_64F},
{cuda_type::CUDA_R_8I, CUDA_R_8I}, {cuda_type::CUDA_C_8I, CUDA_C_8I},
{cuda_type::CUDA_R_8U, CUDA_R_8U}, {cuda_type::CUDA_C_8U, CUDA_C_8U},
{cuda_type::CUDA_R_32I, CUDA_R_32I}, {cuda_type::CUDA_C_32I, CUDA_C_32I},
{cuda_type::CUDA_R_8F_E4M3, CUDA_R_8F_E4M3}, {cuda_type::CUDA_R_8F_E5M2, CUDA_R_8F_E5M2},
{cuda_type::CUDA_R_16F, CUDA_R_16F}, {cuda_type::CUDA_C_16F, CUDA_C_16F},
{cuda_type::CUDA_R_16BF, CUDA_R_16BF}, {cuda_type::CUDA_C_16BF, CUDA_C_16BF},
{cuda_type::CUDA_R_32F, CUDA_R_32F}, {cuda_type::CUDA_C_32F, CUDA_C_32F},
{cuda_type::CUDA_R_64F, CUDA_R_64F}, {cuda_type::CUDA_C_64F, CUDA_C_64F},
{cuda_type::CUDA_R_8I, CUDA_R_8I}, {cuda_type::CUDA_C_8I, CUDA_C_8I},
{cuda_type::CUDA_R_8U, CUDA_R_8U}, {cuda_type::CUDA_C_8U, CUDA_C_8U},
{cuda_type::CUDA_R_32I, CUDA_R_32I}, {cuda_type::CUDA_C_32I, CUDA_C_32I},
// not support, when we use cuda 11.6
//{cuda_type::CUDA_R_8F_E4M3, CUDA_R_8F_E4M3}, {cuda_type::CUDA_R_8F_E5M2, CUDA_R_8F_E5M2},

};

auto it = name_map.find(str);
Expand All @@ -133,7 +149,7 @@ struct KnowhereConfigType<detail::raft_ivf_pq_index> {
template <typename T>
class RaftIvfIndexNode : public IndexNode {
public:
// Construction does no GPU work: raft::device_resources are created on demand
// inside Build/Train/Add/Search/Serialize, so an index node is cheap to make.
// (The pre-commit constructor initializing the removed `res_` member was diff
// residue and has been dropped.)
RaftIvfIndexNode(const Object& object) : devs_{}, gpu_index_{} {
}

virtual Status
Expand All @@ -158,15 +174,15 @@ class RaftIvfIndexNode : public IndexNode {
return metric.error();
}
if (metric.value() != raft::distance::DistanceType::L2Expanded &&
metric.value() != raft::distance::DistanceType::L2Unexpanded &&
metric.value() != raft::distance::DistanceType::InnerProduct) {
LOG_KNOWHERE_WARNING_ << "selected metric not supported in RAFT IVF indexes: "
<< ivf_raft_cfg.metric_type;
return Status::invalid_metric_type;
}

devs_.insert(devs_.begin(), ivf_raft_cfg.gpu_ids.begin(), ivf_raft_cfg.gpu_ids.end());
auto scoped_device = detail::device_setter{*ivf_raft_cfg.gpu_ids.begin()};
res_ = std::make_unique<raft::device_resources>();
auto res_ = std::make_unique<raft::device_resources>(rmm::cuda_stream_per_thread, nullptr,
memory_pool::resources());
auto rows = dataset.GetRows();
auto dim = dataset.GetDim();
auto* data = reinterpret_cast<float const*>(dataset.GetTensor());
Expand Down Expand Up @@ -204,6 +220,10 @@ class RaftIvfIndexNode : public IndexNode {
} else {
static_assert(std::is_same_v<detail::raft_ivf_flat_index, T>);
}
dim_ = dim;
counts_ = rows;
stream.synchronize();

} catch (std::exception& e) {
LOG_KNOWHERE_WARNING_ << "RAFT inner error, " << e.what();
return Status::raft_inner_error;
Expand All @@ -225,6 +245,10 @@ class RaftIvfIndexNode : public IndexNode {
auto rows = dataset.GetRows();
auto dim = dataset.GetDim();
auto* data = reinterpret_cast<float const*>(dataset.GetTensor());
auto scoped_device = detail::device_setter{devs_[0]};

auto res_ = std::make_unique<raft::device_resources>(rmm::cuda_stream_per_thread, nullptr,
memory_pool::resources());

auto stream = res_->get_stream();
// TODO(wphicks): Clean up transfer with raft
Expand All @@ -245,6 +269,8 @@ class RaftIvfIndexNode : public IndexNode {
} else {
static_assert(std::is_same_v<detail::raft_ivf_flat_index, T>);
}
dim_ = dim;
counts_ = rows;
} catch (std::exception& e) {
LOG_KNOWHERE_WARNING_ << "RAFT inner error, " << e.what();
result = Status::raft_inner_error;
Expand All @@ -265,6 +291,9 @@ class RaftIvfIndexNode : public IndexNode {
auto dis = std::unique_ptr<float[]>(new float[output_size]);

try {
auto scoped_device = detail::device_setter{devs_[0]};
auto res_ = std::make_unique<raft::device_resources>(rmm::cuda_stream_per_thread, nullptr,
memory_pool::resources());
auto stream = res_->get_stream();
// TODO(wphicks): Clean up transfer with raft
// buffer objects when available
Expand Down Expand Up @@ -344,12 +373,150 @@ class RaftIvfIndexNode : public IndexNode {

// Serializes the index into `binset` under the key this->Type().
// Layout: [dim_][counts_][device id] header, then the RAFT index state
// (scalars followed by mdspans) — exactly what Deserialize() reads back.
// Returns empty_index when nothing has been built yet.
virtual Status
Serialize(BinarySet& binset) const override {
    if (!gpu_index_.has_value())
        return Status::empty_index;
    std::stringbuf buf;

    std::ostream os(&buf);

    // Header consumed by Deserialize() before any RAFT state.
    os.write(reinterpret_cast<const char*>(&this->dim_), sizeof(this->dim_));
    os.write(reinterpret_cast<const char*>(&this->counts_), sizeof(this->counts_));
    os.write(reinterpret_cast<const char*>(&this->devs_[0]), sizeof(this->devs_[0]));

    auto scoped_device = detail::device_setter{devs_[0]};

    auto res_ =
        std::make_unique<raft::device_resources>(rmm::cuda_stream_per_thread, nullptr, memory_pool::resources());
    if constexpr (std::is_same_v<T, detail::raft_ivf_flat_index>) {
        raft::serialize_scalar(*res_, os, gpu_index_->size());
        raft::serialize_scalar(*res_, os, gpu_index_->dim());
        raft::serialize_scalar(*res_, os, gpu_index_->n_lists());
        raft::serialize_scalar(*res_, os, gpu_index_->metric());
        raft::serialize_scalar(*res_, os, gpu_index_->veclen());
        raft::serialize_scalar(*res_, os, gpu_index_->adaptive_centers());
        raft::serialize_mdspan(*res_, os, gpu_index_->data());
        raft::serialize_mdspan(*res_, os, gpu_index_->indices());
        raft::serialize_mdspan(*res_, os, gpu_index_->list_sizes());
        raft::serialize_mdspan(*res_, os, gpu_index_->list_offsets());
        raft::serialize_mdspan(*res_, os, gpu_index_->centers());
        // center_norms() is optional; a bool flag tells Deserialize() whether
        // a norms mdspan follows.
        if (gpu_index_->center_norms()) {
            bool has_norms = true;
            raft::serialize_scalar(*res_, os, has_norms);
            raft::serialize_mdspan(*res_, os, *gpu_index_->center_norms());
        } else {
            bool has_norms = false;
            raft::serialize_scalar(*res_, os, has_norms);
        }
    }
    if constexpr (std::is_same_v<T, detail::raft_ivf_pq_index>) {
        raft::serialize_scalar(*res_, os, gpu_index_->size());
        raft::serialize_scalar(*res_, os, gpu_index_->dim());
        raft::serialize_scalar(*res_, os, gpu_index_->pq_bits());
        raft::serialize_scalar(*res_, os, gpu_index_->pq_dim());

        raft::serialize_scalar(*res_, os, gpu_index_->metric());
        raft::serialize_scalar(*res_, os, gpu_index_->codebook_kind());
        raft::serialize_scalar(*res_, os, gpu_index_->n_lists());
        raft::serialize_scalar(*res_, os, gpu_index_->n_nonempty_lists());

        raft::serialize_mdspan(*res_, os, gpu_index_->pq_centers());
        raft::serialize_mdspan(*res_, os, gpu_index_->pq_dataset());
        raft::serialize_mdspan(*res_, os, gpu_index_->indices());
        raft::serialize_mdspan(*res_, os, gpu_index_->rotation_matrix());
        raft::serialize_mdspan(*res_, os, gpu_index_->list_offsets());
        raft::serialize_mdspan(*res_, os, gpu_index_->list_sizes());
        raft::serialize_mdspan(*res_, os, gpu_index_->centers());
        raft::serialize_mdspan(*res_, os, gpu_index_->centers_rot());
    }

    os.flush();
    // buf.str() returns a fresh copy of the whole stream on every call, so
    // take it exactly once instead of three times.
    auto const serialized = buf.str();
    std::shared_ptr<uint8_t[]> index_binary(new (std::nothrow) uint8_t[serialized.size()]);
    if (index_binary == nullptr) {
        LOG_KNOWHERE_WARNING_ << "out of memory while serializing RAFT IVF index";
        return Status::raft_inner_error;
    }
    memcpy(index_binary.get(), serialized.data(), serialized.size());
    binset.Append(this->Type(), index_binary, serialized.size());
    return Status::success;
}

// Rebuilds the index from `binset`; exact inverse of Serialize().
// Restores dim_/counts_/devs_ from the header, then reconstructs the RAFT
// index on the recorded device and moves it into gpu_index_.
virtual Status
Deserialize(const BinarySet& binset) override {
    auto binary = binset.GetByName(this->Type());
    // GetByName returns null when the key is absent; guard before use.
    if (binary == nullptr) {
        LOG_KNOWHERE_WARNING_ << "invalid binary set";
        return Status::invalid_binary_set;
    }
    std::stringbuf buf;
    buf.sputn(reinterpret_cast<char*>(binary->data.get()), binary->size);
    std::istream is(&buf);

    // Header written by Serialize() ahead of any RAFT state.
    is.read(reinterpret_cast<char*>(&this->dim_), sizeof(this->dim_));
    is.read(reinterpret_cast<char*>(&this->counts_), sizeof(this->counts_));
    this->devs_.resize(1);
    is.read(reinterpret_cast<char*>(&this->devs_[0]), sizeof(this->devs_[0]));
    auto scoped_device = detail::device_setter{devs_[0]};

    auto res_ =
        std::make_unique<raft::device_resources>(rmm::cuda_stream_per_thread, nullptr, memory_pool::resources());

    if constexpr (std::is_same_v<T, detail::raft_ivf_flat_index>) {
        auto n_rows = raft::deserialize_scalar<std::int64_t>(*res_, is);
        auto dim = raft::deserialize_scalar<std::uint32_t>(*res_, is);
        auto n_lists = raft::deserialize_scalar<std::uint32_t>(*res_, is);
        auto metric = raft::deserialize_scalar<raft::distance::DistanceType>(*res_, is);
        // veclen is read only to advance the stream; it is not passed to the
        // index constructor below.
        auto veclen = raft::deserialize_scalar<std::uint32_t>(*res_, is);
        bool adaptive_centers = raft::deserialize_scalar<bool>(*res_, is);

        T index_ = T(*res_, metric, n_lists, adaptive_centers, dim);

        index_.allocate(*res_, n_rows);
        raft::deserialize_mdspan(*res_, is, index_.data());
        raft::deserialize_mdspan(*res_, is, index_.indices());
        raft::deserialize_mdspan(*res_, is, index_.list_sizes());
        raft::deserialize_mdspan(*res_, is, index_.list_offsets());
        raft::deserialize_mdspan(*res_, is, index_.centers());
        bool has_norms = raft::deserialize_scalar<bool>(*res_, is);
        if (has_norms) {
            if (!index_.center_norms()) {
                RAFT_FAIL("Error inconsistent center norms");
            } else {
                // The mdspan is a view onto the index's storage, so
                // deserializing into the copy fills the index itself.
                auto center_norms = *index_.center_norms();
                raft::deserialize_mdspan(*res_, is, center_norms);
            }
        }
        res_->sync_stream();
        is.sync();
        gpu_index_ = T(std::move(index_));
    }
    if constexpr (std::is_same_v<T, detail::raft_ivf_pq_index>) {
        auto n_rows = raft::deserialize_scalar<std::int64_t>(*res_, is);
        auto dim = raft::deserialize_scalar<std::uint32_t>(*res_, is);
        auto pq_bits = raft::deserialize_scalar<std::uint32_t>(*res_, is);
        auto pq_dim = raft::deserialize_scalar<std::uint32_t>(*res_, is);

        auto metric = raft::deserialize_scalar<raft::distance::DistanceType>(*res_, is);
        auto codebook_kind = raft::deserialize_scalar<raft::neighbors::ivf_pq::codebook_gen>(*res_, is);
        auto n_lists = raft::deserialize_scalar<std::uint32_t>(*res_, is);
        auto n_nonempty_lists = raft::deserialize_scalar<std::uint32_t>(*res_, is);

        T index_ = T(*res_, metric, codebook_kind, n_lists, dim, pq_bits, pq_dim, n_nonempty_lists);
        index_.allocate(*res_, n_rows);

        raft::deserialize_mdspan(*res_, is, index_.pq_centers());
        raft::deserialize_mdspan(*res_, is, index_.pq_dataset());
        raft::deserialize_mdspan(*res_, is, index_.indices());
        raft::deserialize_mdspan(*res_, is, index_.rotation_matrix());
        raft::deserialize_mdspan(*res_, is, index_.list_offsets());
        raft::deserialize_mdspan(*res_, is, index_.list_sizes());
        raft::deserialize_mdspan(*res_, is, index_.centers());
        raft::deserialize_mdspan(*res_, is, index_.centers_rot());
        res_->sync_stream();
        is.sync();
        gpu_index_ = T(std::move(index_));
    }
    // TODO(yusheng.ma):support no raw data mode
    /*
#define RAW_DATA "RAW_DATA"
    auto data = binset.GetByName(RAW_DATA);
    raft_gpu::raw_data_copy(*this->index_, data->data.get(), data->size);
    */
    is.sync();

    return Status::success;
}

virtual std::unique_ptr<BaseConfig>
Expand All @@ -359,11 +526,7 @@ class RaftIvfIndexNode : public IndexNode {

// Dimensionality of the indexed vectors. dim_ is cached when the index is
// built, added to, or deserialized, so no GPU index query is needed here.
// (The pre-commit body querying gpu_index_->dim() was diff residue.)
virtual int64_t
Dim() const override {
    return dim_;
}

virtual int64_t
Expand All @@ -373,11 +536,7 @@ class RaftIvfIndexNode : public IndexNode {

// Number of indexed vectors. counts_ is cached when the index is built,
// added to, or deserialized, so no GPU index query is needed here.
// (The pre-commit body querying gpu_index_->size() was diff residue.)
virtual int64_t
Count() const override {
    return counts_;
}

virtual std::string
Expand All @@ -392,7 +551,8 @@ class RaftIvfIndexNode : public IndexNode {

private:
std::vector<int32_t> devs_;
std::unique_ptr<raft::device_resources> res_;
int64_t dim_ = 0;
int64_t counts_ = 0;
std::optional<T> gpu_index_;
};
} // namespace knowhere
Expand Down

0 comments on commit 391d9af

Please sign in to comment.