diff --git a/include/knowhere/gpu/gpu_res_mgr.h b/include/knowhere/gpu/gpu_res_mgr.h index 155167bfb..db5363d87 100644 --- a/include/knowhere/gpu/gpu_res_mgr.h +++ b/include/knowhere/gpu/gpu_res_mgr.h @@ -82,18 +82,6 @@ class GPUResMgr { LOG_KNOWHERE_DEBUG_ << "InitDevice gpu_id " << gpu_id_ << ", resource count " << gpu_params_.res_num_ << ", tmp_mem_sz " << gpu_params_.tmp_mem_sz_ / MB << "MB, pin_mem_sz " << gpu_params_.pin_mem_sz_ / MB << "MB"; -#ifdef KNOWHERE_WITH_RAFT - if (gpu_id >= std::numeric_limits::min() && gpu_id <= std::numeric_limits::max()) { - auto rmm_id = rmm::cuda_device_id{int(gpu_id)}; - rmm_memory_resources_.push_back( - std::make_unique>( - rmm::mr::get_per_device_resource(rmm_id))); - rmm::mr::set_per_device_resource(rmm_id, rmm_memory_resources_.back().get()); - } else { - LOG_KNOWHERE_WARNING_ << "Could not init pool memory resource on GPU " << gpu_id_ - << ". ID is outside expected range."; - } -#endif } void @@ -125,11 +113,6 @@ class GPUResMgr { res_bq_.Take(); } init_ = false; -#ifdef KNOWHERE_WITH_RAFT - for (auto&& rmm_res : rmm_memory_resources_) { - rmm_res.release(); - } -#endif } ResPtr @@ -156,9 +139,6 @@ class GPUResMgr { int64_t gpu_id_ = 0; GPUParams gpu_params_; ResBQ res_bq_; -#ifdef KNOWHERE_WITH_RAFT - std::vector>> rmm_memory_resources_; -#endif }; class ResScope { diff --git a/src/index/ivf_raft/ivf_flat_raft.cu b/src/index/ivf_raft/ivf_flat_raft.cu index 006e2b479..7c26cf226 100644 --- a/src/index/ivf_raft/ivf_flat_raft.cu +++ b/src/index/ivf_raft/ivf_flat_raft.cu @@ -23,4 +23,5 @@ KNOWHERE_REGISTER_GLOBAL(RAFT_IVF_FLAT, [](const Object& object) { return Index::Create( std::make_unique>(object)); }); + } // namespace knowhere diff --git a/src/index/ivf_raft/ivf_raft.cuh b/src/index/ivf_raft/ivf_raft.cuh index 9c9f4703e..daa767ddc 100644 --- a/src/index/ivf_raft/ivf_raft.cuh +++ b/src/index/ivf_raft/ivf_raft.cuh @@ -37,6 +37,20 @@ namespace knowhere { +namespace memory_pool { +static rmm::mr::device_memory_resource* +resources() { + auto device_id = rmm::detail::current_device(); + auto& map_ = rmm::mr::detail::get_map(); + auto const found = map_.find(device_id.value()); + if (found == map_.end()) + map_[device_id.value()] = + new rmm::mr::pool_memory_resource(rmm::mr::detail::initial_resource()); + return map_[device_id.value()]; +} + +}; // namespace memory_pool + namespace detail { using raft_ivf_flat_index = raft::neighbors::ivf_flat::index; using raft_ivf_pq_index = raft::neighbors::ivf_pq::index; @@ -100,14 +114,16 @@ auto static constexpr const CUDA_R_8F_E5M2 = "CUDA_R_8F_E5M2"; inline expected str_to_cuda_dtype(std::string const& str) { static const std::unordered_map name_map = { - {cuda_type::CUDA_R_16F, CUDA_R_16F}, {cuda_type::CUDA_C_16F, CUDA_C_16F}, - {cuda_type::CUDA_R_16BF, CUDA_R_16BF}, {cuda_type::CUDA_C_16BF, CUDA_C_16BF}, - {cuda_type::CUDA_R_32F, CUDA_R_32F}, {cuda_type::CUDA_C_32F, CUDA_C_32F}, - {cuda_type::CUDA_R_64F, CUDA_R_64F}, {cuda_type::CUDA_C_64F, CUDA_C_64F}, - {cuda_type::CUDA_R_8I, CUDA_R_8I}, {cuda_type::CUDA_C_8I, CUDA_C_8I}, - {cuda_type::CUDA_R_8U, CUDA_R_8U}, {cuda_type::CUDA_C_8U, CUDA_C_8U}, - {cuda_type::CUDA_R_32I, CUDA_R_32I}, {cuda_type::CUDA_C_32I, CUDA_C_32I}, - {cuda_type::CUDA_R_8F_E4M3, CUDA_R_8F_E4M3}, {cuda_type::CUDA_R_8F_E5M2, CUDA_R_8F_E5M2}, + {cuda_type::CUDA_R_16F, CUDA_R_16F}, {cuda_type::CUDA_C_16F, CUDA_C_16F}, + {cuda_type::CUDA_R_16BF, CUDA_R_16BF}, {cuda_type::CUDA_C_16BF, CUDA_C_16BF}, + {cuda_type::CUDA_R_32F, CUDA_R_32F}, {cuda_type::CUDA_C_32F, CUDA_C_32F}, + {cuda_type::CUDA_R_64F, CUDA_R_64F}, {cuda_type::CUDA_C_64F, CUDA_C_64F}, + {cuda_type::CUDA_R_8I, CUDA_R_8I}, {cuda_type::CUDA_C_8I, CUDA_C_8I}, + {cuda_type::CUDA_R_8U, CUDA_R_8U}, {cuda_type::CUDA_C_8U, CUDA_C_8U}, + {cuda_type::CUDA_R_32I, CUDA_R_32I}, {cuda_type::CUDA_C_32I, CUDA_C_32I}, + // not support, when we use cuda 11.6 + //{cuda_type::CUDA_R_8F_E4M3, CUDA_R_8F_E4M3}, {cuda_type::CUDA_R_8F_E5M2, CUDA_R_8F_E5M2}, + }; auto it = name_map.find(str); @@ -133,7 +149,7 @@ struct KnowhereConfigType { template class RaftIvfIndexNode : public IndexNode { public: - RaftIvfIndexNode(const Object& object) : devs_{}, res_{std::make_unique()}, gpu_index_{} { + RaftIvfIndexNode(const Object& object) : devs_{}, gpu_index_{} { } virtual Status @@ -158,15 +174,15 @@ class RaftIvfIndexNode : public IndexNode { return metric.error(); } if (metric.value() != raft::distance::DistanceType::L2Expanded && - metric.value() != raft::distance::DistanceType::L2Unexpanded && metric.value() != raft::distance::DistanceType::InnerProduct) { LOG_KNOWHERE_WARNING_ << "selected metric not supported in RAFT IVF indexes: " << ivf_raft_cfg.metric_type; return Status::invalid_metric_type; } - + devs_.insert(devs_.begin(), ivf_raft_cfg.gpu_ids.begin(), ivf_raft_cfg.gpu_ids.end()); auto scoped_device = detail::device_setter{*ivf_raft_cfg.gpu_ids.begin()}; - res_ = std::make_unique(); + auto res_ = std::make_unique(rmm::cuda_stream_per_thread, nullptr, + memory_pool::resources()); auto rows = dataset.GetRows(); auto dim = dataset.GetDim(); auto* data = reinterpret_cast(dataset.GetTensor()); @@ -204,6 +220,10 @@ class RaftIvfIndexNode : public IndexNode { } else { static_assert(std::is_same_v); } + dim_ = dim; + counts_ = rows; + stream.synchronize(); + } catch (std::exception& e) { LOG_KNOWHERE_WARNING_ << "RAFT inner error, " << e.what(); return Status::raft_inner_error; @@ -225,6 +245,10 @@ class RaftIvfIndexNode : public IndexNode { auto rows = dataset.GetRows(); auto dim = dataset.GetDim(); auto* data = reinterpret_cast(dataset.GetTensor()); + auto scoped_device = detail::device_setter{devs_[0]}; + + auto res_ = std::make_unique(rmm::cuda_stream_per_thread, nullptr, + memory_pool::resources()); auto stream = res_->get_stream(); // TODO(wphicks): Clean up transfer with raft @@ -245,6 +269,8 @@ class RaftIvfIndexNode : public IndexNode { } else { static_assert(std::is_same_v); } + dim_ = dim; + counts_ = rows; } catch (std::exception& e) { LOG_KNOWHERE_WARNING_ << "RAFT inner error, " << e.what(); result = Status::raft_inner_error; @@ -265,6 +291,9 @@ class RaftIvfIndexNode : public IndexNode { auto dis = std::unique_ptr(new float[output_size]); try { + auto scoped_device = detail::device_setter{devs_[0]}; + auto res_ = std::make_unique(rmm::cuda_stream_per_thread, nullptr, + memory_pool::resources()); auto stream = res_->get_stream(); // TODO(wphicks): Clean up transfer with raft // buffer objects when available @@ -344,12 +373,150 @@ class RaftIvfIndexNode : public IndexNode { virtual Status Serialize(BinarySet& binset) const override { - return Status::not_implemented; + if (!gpu_index_.has_value()) + return Status::empty_index; + std::stringbuf buf; + + std::ostream os(&buf); + + os.write((char*)(&this->dim_), sizeof(this->dim_)); + os.write((char*)(&this->counts_), sizeof(this->counts_)); + os.write((char*)(&this->devs_[0]), sizeof(this->devs_[0])); + + auto scoped_device = detail::device_setter{devs_[0]}; + + auto res_ = + std::make_unique(rmm::cuda_stream_per_thread, nullptr, memory_pool::resources()); + if constexpr (std::is_same_v) { + raft::serialize_scalar(*res_, os, gpu_index_->size()); + raft::serialize_scalar(*res_, os, gpu_index_->dim()); + raft::serialize_scalar(*res_, os, gpu_index_->n_lists()); + raft::serialize_scalar(*res_, os, gpu_index_->metric()); + raft::serialize_scalar(*res_, os, gpu_index_->veclen()); + raft::serialize_scalar(*res_, os, gpu_index_->adaptive_centers()); + raft::serialize_mdspan(*res_, os, gpu_index_->data()); + raft::serialize_mdspan(*res_, os, gpu_index_->indices()); + raft::serialize_mdspan(*res_, os, gpu_index_->list_sizes()); + raft::serialize_mdspan(*res_, os, gpu_index_->list_offsets()); + raft::serialize_mdspan(*res_, os, gpu_index_->centers()); + if (gpu_index_->center_norms()) { + bool has_norms = true; + serialize_scalar(*res_, os, has_norms); + serialize_mdspan(*res_, os, *gpu_index_->center_norms()); + } else { + bool has_norms = false; + serialize_scalar(*res_, os, has_norms); + } + } + if constexpr (std::is_same_v) { + raft::serialize_scalar(*res_, os, gpu_index_->size()); + raft::serialize_scalar(*res_, os, gpu_index_->dim()); + raft::serialize_scalar(*res_, os, gpu_index_->pq_bits()); + raft::serialize_scalar(*res_, os, gpu_index_->pq_dim()); + + raft::serialize_scalar(*res_, os, gpu_index_->metric()); + raft::serialize_scalar(*res_, os, gpu_index_->codebook_kind()); + raft::serialize_scalar(*res_, os, gpu_index_->n_lists()); + raft::serialize_scalar(*res_, os, gpu_index_->n_nonempty_lists()); + + raft::serialize_mdspan(*res_, os, gpu_index_->pq_centers()); + raft::serialize_mdspan(*res_, os, gpu_index_->pq_dataset()); + raft::serialize_mdspan(*res_, os, gpu_index_->indices()); + raft::serialize_mdspan(*res_, os, gpu_index_->rotation_matrix()); + raft::serialize_mdspan(*res_, os, gpu_index_->list_offsets()); + raft::serialize_mdspan(*res_, os, gpu_index_->list_sizes()); + raft::serialize_mdspan(*res_, os, gpu_index_->centers()); + raft::serialize_mdspan(*res_, os, gpu_index_->centers_rot()); + } + + os.flush(); + std::shared_ptr index_binary(new (std::nothrow) uint8_t[buf.str().size()]); + + memcpy(index_binary.get(), buf.str().c_str(), buf.str().size()); + binset.Append(this->Type(), index_binary, buf.str().size()); + return Status::success; } virtual Status Deserialize(const BinarySet& binset) override { - return Status::not_implemented; + std::stringbuf buf; + auto binary = binset.GetByName(this->Type()); + buf.sputn((char*)binary->data.get(), binary->size); + std::istream is(&buf); + + is.read((char*)(&this->dim_), sizeof(this->dim_)); + is.read((char*)(&this->counts_), sizeof(this->counts_)); + this->devs_.resize(1); + is.read((char*)(&this->devs_[0]), sizeof(this->devs_[0])); + auto scoped_device = detail::device_setter{devs_[0]}; + + auto res_ = + std::make_unique(rmm::cuda_stream_per_thread, nullptr, memory_pool::resources()); + + if constexpr (std::is_same_v) { + auto n_rows = raft::deserialize_scalar(*res_, is); + auto dim = raft::deserialize_scalar(*res_, is); + auto n_lists = raft::deserialize_scalar(*res_, is); + auto metric = raft::deserialize_scalar(*res_, is); + auto veclen = raft::deserialize_scalar(*res_, is); + bool adaptive_centers = raft::deserialize_scalar(*res_, is); + + T index_ = T(*res_, metric, n_lists, adaptive_centers, dim); + + index_.allocate(*res_, n_rows); + raft::deserialize_mdspan(*res_, is, index_.data()); + raft::deserialize_mdspan(*res_, is, index_.indices()); + raft::deserialize_mdspan(*res_, is, index_.list_sizes()); + raft::deserialize_mdspan(*res_, is, index_.list_offsets()); + raft::deserialize_mdspan(*res_, is, index_.centers()); + bool has_norms = raft::deserialize_scalar(*res_, is); + if (has_norms) { + if (!index_.center_norms()) { + RAFT_FAIL("Error inconsistent center norms"); + } else { + auto center_norms = *index_.center_norms(); + raft::deserialize_mdspan(*res_, is, center_norms); + } + } + res_->sync_stream(); + is.sync(); + gpu_index_ = T(std::move(index_)); + } + if constexpr (std::is_same_v) { + auto n_rows = raft::deserialize_scalar(*res_, is); + auto dim = raft::deserialize_scalar(*res_, is); + auto pq_bits = raft::deserialize_scalar(*res_, is); + auto pq_dim = raft::deserialize_scalar(*res_, is); + + auto metric = raft::deserialize_scalar(*res_, is); + auto codebook_kind = raft::deserialize_scalar(*res_, is); + auto n_lists = raft::deserialize_scalar(*res_, is); + auto n_nonempty_lists = raft::deserialize_scalar(*res_, is); + + T index_ = T(*res_, metric, codebook_kind, n_lists, dim, pq_bits, pq_dim, n_nonempty_lists); + index_.allocate(*res_, n_rows); + + raft::deserialize_mdspan(*res_, is, index_.pq_centers()); + raft::deserialize_mdspan(*res_, is, index_.pq_dataset()); + raft::deserialize_mdspan(*res_, is, index_.indices()); + raft::deserialize_mdspan(*res_, is, index_.rotation_matrix()); + raft::deserialize_mdspan(*res_, is, index_.list_offsets()); + raft::deserialize_mdspan(*res_, is, index_.list_sizes()); + raft::deserialize_mdspan(*res_, is, index_.centers()); + raft::deserialize_mdspan(*res_, is, index_.centers_rot()); + res_->sync_stream(); + is.sync(); + gpu_index_ = T(std::move(index_)); + } + // TODO(yusheng.ma):support no raw data mode + /* +#define RAW_DATA "RAW_DATA" + auto data = binset.GetByName(RAW_DATA); + raft_gpu::raw_data_copy(*this->index_, data->data.get(), data->size); + */ + is.sync(); + + return Status::success; } virtual std::unique_ptr @@ -359,11 +526,7 @@ class RaftIvfIndexNode : public IndexNode { virtual int64_t Dim() const override { - auto result = std::int64_t{}; - if (gpu_index_) { - result = gpu_index_->dim(); - } - return result; + return dim_; } virtual int64_t @@ -373,11 +536,7 @@ class RaftIvfIndexNode : public IndexNode { virtual int64_t Count() const override { - auto result = std::int64_t{}; - if (gpu_index_) { - result = gpu_index_->size(); - } - return result; + return counts_; } virtual std::string @@ -392,7 +551,8 @@ class RaftIvfIndexNode : public IndexNode { private: std::vector devs_; - std::unique_ptr res_; + int64_t dim_ = 0; + int64_t counts_ = 0; std::optional gpu_index_; }; } // namespace knowhere