Limiting workspace memory resource #1356

Merged

Changes from 7 commits (42 commits total)
a939784
Wrap the workspace resource into a limiting_resource_adaptor
achirkin Mar 20, 2023
ac5762b
Set the pool memory resource by default and start the ivf-pq use case
achirkin Mar 20, 2023
c33d519
Refactor the resource to more rely on shared_ptr to manage lifetime
achirkin Mar 21, 2023
de5dd84
Preserve the semantics of not transfering the ownership of raw pointe…
achirkin Mar 21, 2023
36365b5
Merge branch 'branch-23.04' into fea-limited-workspace-resource
achirkin Mar 21, 2023
2eeb9e9
Merge branch 'branch-23.04' into fea-limited-workspace-resource
achirkin Mar 28, 2023
b8c5bc3
Merge branch 'branch-23.04' into fea-limited-workspace-resource
achirkin Mar 29, 2023
48f90fe
Merge branch 'branch-23.06' into fea-limited-workspace-resource
achirkin May 9, 2023
79c954e
Fix a missing merge change
achirkin May 9, 2023
6cf1103
Make the resource change not permanent
achirkin May 9, 2023
370b9ed
Don't force use the temp local workspace for all raft allocations
achirkin May 9, 2023
197106b
Merge remote-tracking branch 'rapidsai/branch-23.08' into fea-limited…
achirkin Jun 28, 2023
06cf4ff
Don't use device_resources
achirkin Jun 28, 2023
f27ba86
Using more of workspace memory resource
achirkin Jun 28, 2023
d435855
Let device_uvector_policy keep the memory resource when needed
achirkin Jun 28, 2023
1b62e3a
Make helper to query workspace size
achirkin Jun 28, 2023
5fed631
Tiny unrelated test fix: copy data in a stream.
achirkin Jun 28, 2023
a2e749d
Update the API to always use shared pointers to the resources
achirkin Jun 28, 2023
7736d76
Fix a typo
achirkin Jun 28, 2023
be63f73
Rename limited->limiting resource for consistency
achirkin Jun 29, 2023
c70728a
Add comments
achirkin Jun 29, 2023
be047d4
Remove repeated word in the comment
achirkin Jun 29, 2023
3e151d4
Fix a missing word in the comment
achirkin Jun 29, 2023
2649391
Merge branch 'branch-23.08' into fea-limited-workspace-resource
achirkin Jul 4, 2023
db27247
Merge branch 'branch-23.08' into fea-limited-workspace-resource
achirkin Jul 6, 2023
0900904
Merge branch 'branch-23.08' into fea-limited-workspace-resource
achirkin Jul 7, 2023
4cf6455
Add a deprecation comment to the mr argument
achirkin Jul 7, 2023
127907c
Add function deprecations
achirkin Jul 7, 2023
d6a27c5
Remove ANN reference
achirkin Jul 7, 2023
4530423
Use the plain workspace resource by default and print a warning if ne…
achirkin Jul 7, 2023
f87cbf2
Merge branch 'branch-23.08' into fea-limited-workspace-resource
achirkin Jul 14, 2023
d4f0c78
Merge branch 'branch-23.08' into fea-limited-workspace-resource
achirkin Jul 14, 2023
775f718
Add a note about no deleter
achirkin Jul 18, 2023
d7fcde9
Use the workspace resource size to determine the batch sizes for ivf-pq
achirkin Jul 18, 2023
eaafd3f
Merge branch 'branch-23.08' into fea-limited-workspace-resource
achirkin Jul 19, 2023
8eb9b80
Merge branch 'branch-23.08' into fea-limited-workspace-resource
achirkin Jul 19, 2023
463f409
Merge branch 'branch-23.08' into fea-limited-workspace-resource
achirkin Jul 20, 2023
e85033e
Merge branch 'branch-23.08' into fea-limited-workspace-resource
achirkin Jul 21, 2023
044b6ca
Merge branch 'branch-23.08' into fea-limited-workspace-resource
achirkin Jul 22, 2023
f3edcbc
Merge branch 'branch-23.08' into fea-limited-workspace-resource
achirkin Jul 25, 2023
ae0f469
Use get_workspace_free_bytes and debug-log the usage of the default p…
achirkin Jul 25, 2023
e082d09
Merge branch 'branch-23.08' into fea-limited-workspace-resource
achirkin Jul 26, 2023
17 changes: 10 additions & 7 deletions cpp/include/raft/core/device_resources.hpp
@@ -21,6 +21,7 @@

#include <memory>
#include <mutex>
#include <optional>
#include <string>
#include <unordered_map>
#include <utility>
@@ -60,12 +61,12 @@ namespace raft {
class device_resources : public resources {
public:
device_resources(const device_resources& handle,
rmm::mr::device_memory_resource* workspace_resource)
rmm::mr::device_memory_resource* workspace_resource,
std::optional<std::size_t> allocation_limit = std::nullopt)
: resources{handle}
{
// replace the resource factory for the workspace_resources
resources::add_resource_factory(
std::make_shared<resource::workspace_resource_factory>(workspace_resource));
resource::set_workspace_resource(*this, workspace_resource, allocation_limit);
}

device_resources(const device_resources& handle) : resources{handle} {}
@@ -80,19 +81,21 @@ class device_resources : public resources {
* @param[in] stream_pool the stream pool used (which has default of nullptr if unspecified)
* @param[in] workspace_resource an optional resource used by some functions for allocating
* temporary workspaces.
* @param[in] allocation_limit the total amount of memory in bytes available to the temporary
* workspace resources.
*/
device_resources(rmm::cuda_stream_view stream_view = rmm::cuda_stream_per_thread,
std::shared_ptr<rmm::cuda_stream_pool> stream_pool = {nullptr},
rmm::mr::device_memory_resource* workspace_resource = nullptr)
rmm::mr::device_memory_resource* workspace_resource = nullptr,
std::optional<std::size_t> allocation_limit = std::nullopt)
: resources{}
{
resources::add_resource_factory(std::make_shared<resource::device_id_resource_factory>());
resources::add_resource_factory(
std::make_shared<resource::cuda_stream_resource_factory>(stream_view));
resources::add_resource_factory(
std::make_shared<resource::cuda_stream_pool_resource_factory>(stream_pool));
resources::add_resource_factory(
std::make_shared<resource::workspace_resource_factory>(workspace_resource));
resource::set_workspace_resource(*this, workspace_resource, allocation_limit);
}

/** Destroys all held-up resources */
@@ -255,4 +258,4 @@ class stream_syncer {

} // namespace raft

#endif
#endif
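Note (not part of the diff): a minimal usage sketch of the extended device_resources constructor above. The headers, the local name `upstream`, and the 1 GiB figure are illustrative assumptions, not values taken from this PR.

#include <raft/core/device_resources.hpp>

#include <rmm/cuda_stream_view.hpp>
#include <rmm/mr/device/cuda_memory_resource.hpp>

#include <cstddef>

int main()
{
  // Upstream device memory resource; ownership stays with the caller.
  rmm::mr::cuda_memory_resource upstream;

  // Bound the temporary workspace to roughly 1 GiB (arbitrary illustration).
  raft::device_resources res(rmm::cuda_stream_per_thread,
                             /*stream_pool=*/{nullptr},
                             /*workspace_resource=*/&upstream,
                             /*allocation_limit=*/std::size_t{1} << 30);

  // Algorithms that draw temporary buffers from the workspace now stay within this budget.
  return 0;
}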
132 changes: 115 additions & 17 deletions cpp/include/raft/core/resource/device_memory_resource.hpp
@@ -15,23 +15,48 @@
*/
#pragma once

#include <raft/core/operators.hpp>
#include <raft/core/resource/resource_types.hpp>
#include <raft/core/resources.hpp>
#include <raft/util/cudart_utils.hpp>

#include <rmm/mr/device/device_memory_resource.hpp>
#include <rmm/mr/device/limiting_resource_adaptor.hpp>
#include <rmm/mr/device/pool_memory_resource.hpp>

#include <cstddef>
#include <optional>

namespace raft::resource {
class device_memory_resource : public resource {
class limited_memory_resource : public resource {
public:
device_memory_resource(rmm::mr::device_memory_resource* mr_ = nullptr) : mr(mr_)
limited_memory_resource(std::shared_ptr<rmm::mr::device_memory_resource> mr,
std::size_t allocation_limit,
std::optional<std::size_t> alignment)
: upstream_(mr), mr_(make_adaptor(mr, allocation_limit, alignment))
{
if (mr_ == nullptr) { mr = rmm::mr::get_current_device_resource(); }
}
void* get_resource() override { return mr; }

~device_memory_resource() override {}
auto get_resource() -> void* override { return &mr_; }

~limited_memory_resource() override = default;

private:
rmm::mr::device_memory_resource* mr;
std::shared_ptr<rmm::mr::device_memory_resource> upstream_;
rmm::mr::limiting_resource_adaptor<rmm::mr::device_memory_resource> mr_;

static inline auto make_adaptor(std::shared_ptr<rmm::mr::device_memory_resource> upstream,
std::size_t limit,
std::optional<std::size_t> alignment)
-> rmm::mr::limiting_resource_adaptor<rmm::mr::device_memory_resource>
{
auto p = upstream.get();
if (alignment.has_value()) {
return rmm::mr::limiting_resource_adaptor(p, limit, alignment.value());
} else {
return rmm::mr::limiting_resource_adaptor(p, limit);
}
}
};

/**
@@ -40,36 +65,109 @@ class device_memory_resource : public resource {
*/
class workspace_resource_factory : public resource_factory {
public:
workspace_resource_factory(rmm::mr::device_memory_resource* mr_ = nullptr) : mr(mr_) {}
resource_type get_resource_type() override { return resource_type::WORKSPACE_RESOURCE; }
resource* make_resource() override { return new device_memory_resource(mr); }
explicit workspace_resource_factory(
std::shared_ptr<rmm::mr::device_memory_resource> mr = {nullptr},
std::optional<std::size_t> allocation_limit = std::nullopt,
std::optional<std::size_t> alignment = std::nullopt)
: allocation_limit_(allocation_limit.value_or(default_allocation_limit())),
alignment_(alignment),
mr_(mr ? mr : default_memory_resource(allocation_limit_))
{
}

auto get_resource_type() -> resource_type override { return resource_type::WORKSPACE_RESOURCE; }
auto make_resource() -> resource* override
{
return new limited_memory_resource(mr_, allocation_limit_, alignment_);
}

private:
rmm::mr::device_memory_resource* mr;
std::size_t allocation_limit_;
std::optional<std::size_t> alignment_;
std::shared_ptr<rmm::mr::device_memory_resource> mr_;

// Create a pool memory resource by default
static inline auto default_memory_resource(std::size_t limit)
-> std::shared_ptr<rmm::mr::device_memory_resource>
{
constexpr std::size_t kOneGb = 1024lu * 1024lu * 1024lu;
auto min_size = std::min<std::size_t>(kOneGb, limit / 2);
auto max_size = limit * 3lu / 2lu;
auto upstream = rmm::mr::get_current_device_resource();
return std::make_shared<rmm::mr::pool_memory_resource<rmm::mr::device_memory_resource>>(
upstream, min_size, max_size);
}

// Allow a fraction of available memory by default.
static inline auto default_allocation_limit() -> std::size_t
{
std::size_t free_size{};
std::size_t total_size{};
RAFT_CUDA_TRY(cudaMemGetInfo(&free_size, &total_size));
return free_size / 2;
}
};

/**
* Load a temp workspace resource from a resources instance (and populate it on the res
* if needed).
*
* @param res raft resources object for managing resources
* @return device memory resource object
*/
inline rmm::mr::device_memory_resource* get_workspace_resource(resources const& res)
inline auto get_workspace_resource(resources const& res)
-> rmm::mr::limiting_resource_adaptor<rmm::mr::device_memory_resource>*
{
if (!res.has_resource_factory(resource_type::WORKSPACE_RESOURCE)) {
res.add_resource_factory(std::make_shared<workspace_resource_factory>());
}
return res.get_resource<rmm::mr::device_memory_resource>(resource_type::WORKSPACE_RESOURCE);
return res.get_resource<rmm::mr::limiting_resource_adaptor<rmm::mr::device_memory_resource>>(
resource_type::WORKSPACE_RESOURCE);
};

/**
* Set a temporary workspace resource on a resources instance.
*
* @param res raft resources object for managing resources
* @param mr an optional RMM device_memory_resource
* @param allocation_limit
* the total amount of memory in bytes available to the temporary workspace resources.
* @param alignment optional alignment requirements passed to RMM allocations
*
*/
inline void set_workspace_resource(resources const& res,
std::shared_ptr<rmm::mr::device_memory_resource> mr = {nullptr},
std::optional<std::size_t> allocation_limit = std::nullopt,
std::optional<std::size_t> alignment = std::nullopt)
{
res.add_resource_factory(
std::make_shared<workspace_resource_factory>(mr, allocation_limit, alignment));
};

/**
* Set a temp workspace resource on a resources instance.
* Set a temporary workspace resource on a resources instance.
*
* @param res raft resources object for managing resources
* @param mr a valid rmm device_memory_resource
* @param mr an optional RMM device_memory_resource;
* note, the ownership of the object is not transferred with this raw pointer interface.
* @param allocation_limit
* the total amount of memory in bytes available to the temporary workspace resources.
* @param alignment optional alignment requirements passed to RMM allocations
*
*/
inline void set_workspace_resource(resources const& res, rmm::mr::device_memory_resource* mr)
inline void set_workspace_resource(resources const& res,
rmm::mr::device_memory_resource* mr,
std::optional<std::size_t> allocation_limit = std::nullopt,
std::optional<std::size_t> alignment = std::nullopt)
{
res.add_resource_factory(std::make_shared<workspace_resource_factory>(mr));
// NB: to preserve the semantics of passing memory resource without transferring the ownership,
// we create a shared pointer with a dummy deleter (void_op).
set_workspace_resource(res,
mr != nullptr
? std::shared_ptr<rmm::mr::device_memory_resource>{mr, void_op{}}
: std::shared_ptr<rmm::mr::device_memory_resource>{nullptr},
allocation_limit,
alignment);
};
} // namespace raft::resource

} // namespace raft::resource
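Note (not part of the diff): a hedged sketch of the two set_workspace_resource overloads and the new getter return type introduced above. The pool sizes and the 1 GiB limit are illustrative assumptions; get_allocated_bytes/get_allocation_limit come from rmm::mr::limiting_resource_adaptor.

#include <raft/core/device_resources.hpp>
#include <raft/core/resource/device_memory_resource.hpp>

#include <rmm/mr/device/per_device_resource.hpp>
#include <rmm/mr/device/pool_memory_resource.hpp>

#include <cstddef>
#include <memory>

void configure_workspace(raft::device_resources const& res)
{
  // Owning overload: pass a shared_ptr; the workspace factory keeps it alive.
  auto pool = std::make_shared<
    rmm::mr::pool_memory_resource<rmm::mr::device_memory_resource>>(
    rmm::mr::get_current_device_resource(), std::size_t{256} << 20 /* initial pool size */);
  raft::resource::set_workspace_resource(res, pool, /*allocation_limit=*/std::size_t{1} << 30);

  // Non-owning overload: a raw pointer is wrapped with a no-op deleter internally,
  // so the caller remains responsible for the resource's lifetime:
  //   raft::resource::set_workspace_resource(res, rmm::mr::get_current_device_resource());

  // The getter now returns the limiting adaptor, so the budget can be inspected.
  auto* ws   = raft::resource::get_workspace_resource(res);
  auto used  = ws->get_allocated_bytes();
  auto limit = ws->get_allocation_limit();
  (void)used;
  (void)limit;
}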
7 changes: 6 additions & 1 deletion cpp/include/raft/core/resources.hpp
@@ -94,6 +94,11 @@ class resources {
RAFT_EXPECTS(rtype != resource::resource_type::LAST_KEY,
"LAST_KEY is a placeholder and not a valid resource factory type.");
factories_.at(rtype) = std::make_pair(rtype, factory);
// Clear the corresponding resource, so that on next `get_resource` the new factory is used
if (resources_.at(rtype).first != resource::resource_type::LAST_KEY) {
resources_.at(rtype) = std::make_pair(resource::resource_type::LAST_KEY,
std::make_shared<resource::empty_resource>());
}
}

/**
@@ -128,4 +133,4 @@
mutable std::vector<pair_res_factory> factories_;
mutable std::vector<pair_resource> resources_;
};
} // namespace raft
} // namespace raft
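Note (not part of the diff): an illustrative consequence of the add_resource_factory change above — replacing a factory now also drops the previously materialized resource, so the next accessor call rebuilds it from the new factory. All names below other than the RAFT/RMM calls are local assumptions.

#include <raft/core/device_resources.hpp>
#include <raft/core/operators.hpp>
#include <raft/core/resource/device_memory_resource.hpp>

#include <rmm/mr/device/per_device_resource.hpp>

#include <cstddef>
#include <memory>

void replace_after_use(raft::device_resources const& res)
{
  // First access materializes the default workspace resource from its factory.
  auto* before = raft::resource::get_workspace_resource(res);

  // Replace the factory; the cached resource is cleared at the same time...
  auto upstream = std::shared_ptr<rmm::mr::device_memory_resource>{
    rmm::mr::get_current_device_resource(), raft::void_op{}};  // non-owning
  raft::resource::set_workspace_resource(res, upstream, /*allocation_limit=*/std::size_t{512} << 20);

  // ...so the next access rebuilds the workspace from the new factory instead of
  // returning the stale object.
  auto* after = raft::resource::get_workspace_resource(res);
  (void)before;
  (void)after;
}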
10 changes: 2 additions & 8 deletions cpp/include/raft/neighbors/detail/ivf_pq_build.cuh
@@ -442,9 +442,6 @@ void train_per_subset(raft::device_resources const& handle,
index.pq_len(),
stream);

// clone the handle and attached the device memory resource to it
const device_resources new_handle(handle, device_memory);

// train PQ codebook for this subspace
auto sub_trainset_view =
raft::make_device_matrix_view<const float, IdxT>(sub_trainset.data(), n_rows, index.pq_len());
@@ -458,7 +455,7 @@
raft::cluster::kmeans_balanced_params kmeans_params;
kmeans_params.n_iters = kmeans_n_iters;
kmeans_params.metric = raft::distance::DistanceType::L2Expanded;
raft::cluster::kmeans_balanced::helpers::build_clusters(new_handle,
raft::cluster::kmeans_balanced::helpers::build_clusters(handle,
kmeans_params,
sub_trainset_view,
centers_tmp_view,
@@ -523,9 +520,6 @@ void train_per_cluster(raft::device_resources const& handle,
indices + cluster_offsets[l],
device_memory);

// clone the handle and attached the device memory resource to it
const device_resources new_handle(handle, device_memory);

// limit the cluster size to bound the training time.
// [sic] we interpret the data as pq_len-dimensional
size_t big_enough = 256ul * std::max<size_t>(index.pq_book_size(), index.pq_dim());
@@ -546,7 +540,7 @@
raft::cluster::kmeans_balanced_params kmeans_params;
kmeans_params.n_iters = kmeans_n_iters;
kmeans_params.metric = raft::distance::DistanceType::L2Expanded;
raft::cluster::kmeans_balanced::helpers::build_clusters(new_handle,
raft::cluster::kmeans_balanced::helpers::build_clusters(handle,
kmeans_params,
rot_vectors_view,
centers_tmp_view,
16 changes: 5 additions & 11 deletions cpp/include/raft/neighbors/detail/ivf_pq_search.cuh
@@ -1262,10 +1262,10 @@ void ivfpq_search_worker(raft::device_resources const& handle,
IdxT* neighbors, // [n_queries, topK]
float* distances, // [n_queries, topK]
float scaling_factor,
double preferred_shmem_carveout,
rmm::mr::device_memory_resource* mr)
double preferred_shmem_carveout)
{
auto stream = handle.get_stream();
auto mr = handle.get_workspace_resource();

bool manage_local_topk = is_local_topk_feasible(topK, n_probes, n_queries);
auto topk_len = manage_local_topk ? n_probes * topK : max_samples;
@@ -1554,8 +1554,7 @@ inline void search(raft::device_resources const& handle,
uint32_t n_queries,
uint32_t k,
IdxT* neighbors,
float* distances,
rmm::mr::device_memory_resource* mr = nullptr)
float* distances)
{
static_assert(std::is_same_v<T, float> || std::is_same_v<T, uint8_t> || std::is_same_v<T, int8_t>,
"Unsupported element type.");
@@ -1601,11 +1600,7 @@ inline void search(raft::device_resources const& handle,
max_samples = ms;
}

auto pool_guard = raft::get_pool_memory_resource(mr, n_queries * n_probes * k * 16);
if (pool_guard) {
RAFT_LOG_DEBUG("ivf_pq::search: using pool memory resource with initial size %zu bytes",
pool_guard->pool_size());
}
auto mr = handle.get_workspace_resource();

// Maximum number of query vectors to search at the same time.
const auto max_queries = std::min<uint32_t>(std::max<uint32_t>(n_queries, 1), 4096);
@@ -1669,8 +1664,7 @@ inline void search(raft::device_resources const& handle,
neighbors + uint64_t(k) * (offset_q + offset_b),
distances + uint64_t(k) * (offset_q + offset_b),
utils::config<T>::kDivisor / utils::config<float>::kDivisor,
params.preferred_shmem_carveout,
mr);
params.preferred_shmem_carveout);
}
}
}
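Note (not part of the diff): since the search path above now allocates its temporary buffers from handle.get_workspace_resource(), batch sizes can be derived from the remaining workspace budget. The helper name and heuristic below are illustrative assumptions, not the exact code added by this PR.

#include <raft/core/device_resources.hpp>
#include <raft/core/resource/device_memory_resource.hpp>

#include <algorithm>
#include <cstddef>

// Remaining capacity of the limiting workspace adaptor.
inline std::size_t workspace_free_bytes(raft::device_resources const& handle)
{
  auto* ws = raft::resource::get_workspace_resource(handle);
  return ws->get_allocation_limit() - ws->get_allocated_bytes();
}

// Pick a batch size so that the per-batch scratch fits into the remaining budget.
inline std::size_t pick_max_queries(raft::device_resources const& handle,
                                    std::size_t bytes_per_query,
                                    std::size_t n_queries)
{
  auto budget = workspace_free_bytes(handle) / std::max<std::size_t>(1, bytes_per_query);
  return std::max<std::size_t>(std::size_t{1}, std::min(budget, n_queries));
}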
9 changes: 6 additions & 3 deletions cpp/include/raft/neighbors/ivf_pq.cuh
@@ -23,10 +23,13 @@

#include <raft/core/device_mdspan.hpp>
#include <raft/core/device_resources.hpp>
#include <raft/core/operators.hpp>

#include <rmm/cuda_stream_view.hpp>
#include <rmm/mr/device/per_device_resource.hpp>

#include <memory>

namespace raft::neighbors::ivf_pq {

/**
Expand Down Expand Up @@ -182,8 +185,7 @@ void search(raft::device_resources const& handle,
static_cast<std::uint32_t>(queries.extent(0)),
k,
neighbors.data_handle(),
distances.data_handle(),
handle.get_workspace_resource());
distances.data_handle());
}

/** @} */ // end group ivf_pq
@@ -349,7 +351,8 @@ void search(raft::device_resources const& handle,
float* distances,
rmm::mr::device_memory_resource* mr = nullptr)
{
return detail::search(handle, params, idx, queries, n_queries, k, neighbors, distances, mr);
if (mr != nullptr) { resource::set_workspace_resource(handle, mr); }
return detail::search(handle, params, idx, queries, n_queries, k, neighbors, distances);
}

} // namespace raft::neighbors::ivf_pq
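Note (not part of the diff): a sketch of calling the public ivf_pq::search after this change. Bounding temporary memory now goes through the handle's workspace resource; the deprecated mr argument merely forwards to set_workspace_resource as shown above. The 1 GiB figure and the wrapper name are illustrative assumptions.

#include <raft/core/device_resources.hpp>
#include <raft/core/resource/device_memory_resource.hpp>
#include <raft/neighbors/ivf_pq.cuh>

#include <cstddef>
#include <cstdint>
#include <memory>

template <typename T, typename IdxT>
void search_with_budget(raft::device_resources const& handle,
                        raft::neighbors::ivf_pq::search_params const& params,
                        raft::neighbors::ivf_pq::index<IdxT> const& idx,
                        const T* queries,
                        uint32_t n_queries,
                        uint32_t k,
                        IdxT* neighbors,
                        float* distances)
{
  // Preferred route after this PR: bound the temporary memory once on the handle
  // (a null mr lets the factory build its default pool under this limit).
  raft::resource::set_workspace_resource(
    handle,
    std::shared_ptr<rmm::mr::device_memory_resource>{nullptr},
    /*allocation_limit=*/std::size_t{1} << 30);

  // The deprecated mr argument is simply omitted here.
  raft::neighbors::ivf_pq::search(
    handle, params, idx, queries, n_queries, k, neighbors, distances);
}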