Skip to content

Commit

Permalink
Scaling workspace resources (#2322)
Browse files Browse the repository at this point in the history
### Brief

Add another workspace memory resource that does not have the explicit memory limit. That is, after the change we have the following:

1. `rmm::mr::get_current_device_resource()` is default for all allocations, as before. It is used for the allocations with unlimited lifetime, e.g. returned to the user.
2. `raft::get_workspace_resource()` is for temporary allocations and forced to have fixed size, as before. However, it becomes smaller and should be used only for allocations, which do not scale with problem size. It defaults to a thin layer on top of the `current_device_resource`.
3. `raft::get_large_workspace_resource()` _(new)_  is for temporary allocations, which can scale with the problem size. Unlike `workspace_resource`, its size is not fixed. By default, it points to the `current_device_resource`, but the user can set it to something backed by the host memory (e.g. managed memory) to avoid OOM exceptions when there's not enough device memory left.

## Problem

We have a list of issues/preference/requirements, some of which contradict others

1. We rely on RMM to handle all allocations and we often use [`rmm::mr::pool_memory_resource`](https://github.com/rapidsai/raft/blob/9fb05a2ab3d72760a09f1b7051e711d773682ef1/cpp/bench/ann/src/raft/raft_ann_bench_utils.h#L73) for performance reasons (to avoid lots of cudaMalloc calls in the loops)
2. Historically, we've used managed memory allocators as a workaround to [avoid OOM errors](https://github.com/rapidsai/raft/blob/5e80c1d2159e00a204ab5db0f5ca3f9ec43187c7/cpp/include/raft/neighbors/detail/ivf_pq_build.cuh#L1788-L1795) or [improve speed (by increasing batch sizes)](https://github.com/rapidsai/raft/blob/5e80c1d2159e00a204ab5db0f5ca3f9ec43187c7/cpp/include/raft/neighbors/detail/ivf_pq_build.cuh#L1596-L1603).
3. However, the design goal is to avoid setting allocators on our own and to give the full control to the user (hence the workaround in 2 [was removed](addb059#diff-f7f070424d71da5321d470416d1a4ca3605c4290c34c4a1c1d8b2240747000d2)).
4. We introduced the [workspace resource](#1356) earlier to allow querying the available memory reliably and maximize the batch sizes accordingly (see also issue [#1310](#1310)). Without this, some of our batched algorithms either fail with OOM or severely underperform due to small batch sizes.
5. However, we cannot just put all of RAFT temporary allocations into the limited `workspace_resource`, because some of them scale with the problem size and would inevitably fail with OOM at some point.
6. Setting the workspace resource to the managed memory is not advisable as well for performance reasons: we have lots of small allocations in performance critical sections, so we need a pool, but a pool in the managed memory inevitably outgrows the device memory and makes the whole program slow. 

## Solution
I propose to split the workspace memory into two:

1. small, fixed-size workspace for small, frequent allocations
2. large workspace for the allocations that scale with the problem size

Notes:
- We still leave the full control over the allocator types to the user. 
- Neither of the workspace resource should have unlimited lifetime / returned to the user. As a result, if the user sets the managed memory as the large workspace resource, the memory is guaranteed to be released after the function call.
- We have the option to use the slow managed memory without a pool for large allocations, while still using a fast pool for small allocations.
- We have more flexible control over which allocations are "large" and which are "small", so hopefully using the managed memory is not so bad for performance.

Authors:
  - Artem M. Chirkin (https://github.com/achirkin)

Approvers:
  - Corey J. Nolet (https://github.com/cjnolet)

URL: #2322
  • Loading branch information
achirkin authored May 21, 2024
1 parent 5a8224c commit efcd11f
Show file tree
Hide file tree
Showing 4 changed files with 86 additions and 23 deletions.
20 changes: 17 additions & 3 deletions cpp/bench/ann/src/raft/raft_ann_bench_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
#include <rmm/device_uvector.hpp>
#include <rmm/mr/device/device_memory_resource.hpp>
#include <rmm/mr/device/failure_callback_resource_adaptor.hpp>
#include <rmm/mr/device/managed_memory_resource.hpp>
#include <rmm/mr/device/pool_memory_resource.hpp>

#include <memory>
Expand Down Expand Up @@ -74,13 +75,14 @@ inline auto rmm_oom_callback(std::size_t bytes, void*) -> bool
*/
class shared_raft_resources {
public:
using pool_mr_type = rmm::mr::pool_memory_resource<rmm::mr::device_memory_resource>;
using mr_type = rmm::mr::failure_callback_resource_adaptor<pool_mr_type>;
using pool_mr_type = rmm::mr::pool_memory_resource<rmm::mr::device_memory_resource>;
using mr_type = rmm::mr::failure_callback_resource_adaptor<pool_mr_type>;
using large_mr_type = rmm::mr::managed_memory_resource;

shared_raft_resources()
try : orig_resource_{rmm::mr::get_current_device_resource()},
pool_resource_(orig_resource_, 1024 * 1024 * 1024ull),
resource_(&pool_resource_, rmm_oom_callback, nullptr) {
resource_(&pool_resource_, rmm_oom_callback, nullptr), large_mr_() {
rmm::mr::set_current_device_resource(&resource_);
} catch (const std::exception& e) {
auto cuda_status = cudaGetLastError();
Expand All @@ -103,10 +105,16 @@ class shared_raft_resources {

~shared_raft_resources() noexcept { rmm::mr::set_current_device_resource(orig_resource_); }

auto get_large_memory_resource() noexcept
{
return static_cast<rmm::mr::device_memory_resource*>(&large_mr_);
}

private:
rmm::mr::device_memory_resource* orig_resource_;
pool_mr_type pool_resource_;
mr_type resource_;
large_mr_type large_mr_;
};

/**
Expand All @@ -129,6 +137,12 @@ class configured_raft_resources {
res_{std::make_unique<raft::device_resources>(
rmm::cuda_stream_view(get_stream_from_global_pool()))}
{
// set the large workspace resource to the raft handle, but without the deleter
// (this resource is managed by the shared_res).
raft::resource::set_large_workspace_resource(
*res_,
std::shared_ptr<rmm::mr::device_memory_resource>(shared_res_->get_large_memory_resource(),
raft::void_op{}));
}

/** Default constructor creates all resources anew. */
Expand Down
52 changes: 50 additions & 2 deletions cpp/include/raft/core/resource/device_memory_resource.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2022-2023, NVIDIA CORPORATION.
* Copyright (c) 2022-2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -35,6 +35,16 @@ namespace raft::resource {
* @{
*/

class device_memory_resource : public resource {
public:
explicit device_memory_resource(std::shared_ptr<rmm::mr::device_memory_resource> mr) : mr_(mr) {}
~device_memory_resource() override = default;
auto get_resource() -> void* override { return mr_.get(); }

private:
std::shared_ptr<rmm::mr::device_memory_resource> mr_;
};

class limiting_memory_resource : public resource {
public:
limiting_memory_resource(std::shared_ptr<rmm::mr::device_memory_resource> mr,
Expand Down Expand Up @@ -66,6 +76,29 @@ class limiting_memory_resource : public resource {
}
};

/**
* Factory that knows how to construct a specific raft::resource to populate
* the resources instance.
*/
class large_workspace_resource_factory : public resource_factory {
public:
explicit large_workspace_resource_factory(
std::shared_ptr<rmm::mr::device_memory_resource> mr = {nullptr})
: mr_{mr ? mr
: std::shared_ptr<rmm::mr::device_memory_resource>{
rmm::mr::get_current_device_resource(), void_op{}}}
{
}
auto get_resource_type() -> resource_type override
{
return resource_type::LARGE_WORKSPACE_RESOURCE;
}
auto make_resource() -> resource* override { return new device_memory_resource(mr_); }

private:
std::shared_ptr<rmm::mr::device_memory_resource> mr_;
};

/**
* Factory that knows how to construct a specific raft::resource to populate
* the resources instance.
Expand Down Expand Up @@ -144,7 +177,7 @@ class workspace_resource_factory : public resource_factory {
// Note, the workspace does not claim all this memory from the start, so it's still usable by
// the main resource as well.
// This limit is merely an order for algorithm internals to plan the batching accordingly.
return total_size / 2;
return total_size / 4;
}
};

Expand Down Expand Up @@ -241,6 +274,21 @@ inline void set_workspace_to_global_resource(
workspace_resource_factory::default_plain_resource(), allocation_limit, std::nullopt));
};

inline auto get_large_workspace_resource(resources const& res) -> rmm::mr::device_memory_resource*
{
if (!res.has_resource_factory(resource_type::LARGE_WORKSPACE_RESOURCE)) {
res.add_resource_factory(std::make_shared<large_workspace_resource_factory>());
}
return res.get_resource<rmm::mr::device_memory_resource>(resource_type::LARGE_WORKSPACE_RESOURCE);
};

inline void set_large_workspace_resource(resources const& res,
std::shared_ptr<rmm::mr::device_memory_resource> mr = {
nullptr})
{
res.add_resource_factory(std::make_shared<large_workspace_resource_factory>(mr));
};

/** @} */

} // namespace raft::resource
35 changes: 18 additions & 17 deletions cpp/include/raft/core/resource/resource_types.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,23 +28,24 @@ namespace raft::resource {
*/
enum resource_type {
// device-specific resource types
CUBLAS_HANDLE = 0, // cublas handle
CUSOLVER_DN_HANDLE, // cusolver dn handle
CUSOLVER_SP_HANDLE, // cusolver sp handle
CUSPARSE_HANDLE, // cusparse handle
CUDA_STREAM_VIEW, // view of a cuda stream
CUDA_STREAM_POOL, // cuda stream pool
CUDA_STREAM_SYNC_EVENT, // cuda event for syncing streams
COMMUNICATOR, // raft communicator
SUB_COMMUNICATOR, // raft sub communicator
DEVICE_PROPERTIES, // cuda device properties
DEVICE_ID, // cuda device id
STREAM_VIEW, // view of a cuda stream or a placeholder in
// CUDA-free builds
THRUST_POLICY, // thrust execution policy
WORKSPACE_RESOURCE, // rmm device memory resource
CUBLASLT_HANDLE, // cublasLt handle
CUSTOM, // runtime-shared default-constructible resource
CUBLAS_HANDLE = 0, // cublas handle
CUSOLVER_DN_HANDLE, // cusolver dn handle
CUSOLVER_SP_HANDLE, // cusolver sp handle
CUSPARSE_HANDLE, // cusparse handle
CUDA_STREAM_VIEW, // view of a cuda stream
CUDA_STREAM_POOL, // cuda stream pool
CUDA_STREAM_SYNC_EVENT, // cuda event for syncing streams
COMMUNICATOR, // raft communicator
SUB_COMMUNICATOR, // raft sub communicator
DEVICE_PROPERTIES, // cuda device properties
DEVICE_ID, // cuda device id
STREAM_VIEW, // view of a cuda stream or a placeholder in
// CUDA-free builds
THRUST_POLICY, // thrust execution policy
WORKSPACE_RESOURCE, // rmm device memory resource for small temporary allocations
CUBLASLT_HANDLE, // cublasLt handle
CUSTOM, // runtime-shared default-constructible resource
LARGE_WORKSPACE_RESOURCE, // rmm device memory resource for somewhat large temporary allocations

LAST_KEY // reserved for the last key
};
Expand Down
2 changes: 1 addition & 1 deletion cpp/include/raft/matrix/detail/select_radix.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -894,7 +894,7 @@ void radix_topk(const T* in,
unsigned grid_dim,
int sm_cnt,
rmm::cuda_stream_view stream,
rmm::device_async_resource_ref mr = rmm::mr::get_current_device_resource())
rmm::device_async_resource_ref mr)
{
// TODO: is it possible to relax this restriction?
static_assert(calc_num_passes<T, BitsPerPass>() > 1);
Expand Down

0 comments on commit efcd11f

Please sign in to comment.