Skip to content

Commit

Permalink
Add RMM Pool memory resource to C API (#285)
Browse files Browse the repository at this point in the history
Enables switching to the pool allocator for better performance

Authors:
  - Ajit Mistry (https://github.com/ajit283)
  - Ray Douglass (https://github.com/raydouglass)

Approvers:
  - Corey J. Nolet (https://github.com/cjnolet)

URL: #285
  • Loading branch information
ajit283 authored Aug 22, 2024
1 parent d113e7c commit 108ee25
Show file tree
Hide file tree
Showing 3 changed files with 77 additions and 0 deletions.
21 changes: 21 additions & 0 deletions cpp/include/cuvs/core/c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,27 @@ cuvsError_t cuvsRMMAlloc(cuvsResources_t res, void** ptr, size_t bytes);
*/
cuvsError_t cuvsRMMFree(cuvsResources_t res, void* ptr, size_t bytes);

/**
* @brief Switches the working memory resource to use the RMM pool memory resource, which will
* bypass unnecessary synchronizations by allocating a chunk of device memory up front and carving
* that up for temporary memory allocations within algorithms. Be aware that this function will
* change the memory resource for the whole process and the new memory resource will be used until
* explicitly changed.
*
* @param[in] initial_pool_size_percent The initial pool size as a percentage of the total
* available memory
* @param[in] max_pool_size_percent The maximum pool size as a percentage of the total
* available memory
* @return cuvsError_t
*/
cuvsError_t cuvsRMMPoolMemoryResourceEnable(int initial_pool_size_percent,
int max_pool_size_percent);
/**
* @brief Resets the memory resource to use the default memory resource (cuda_memory_resource)
* @return cuvsError_t
*/
cuvsError_t cuvsRMMMemoryResourceReset();

/** @} */

#ifdef __cplusplus
Expand Down
29 changes: 29 additions & 0 deletions cpp/src/core/c_api.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
#include <raft/core/resources.hpp>
#include <rmm/cuda_stream_view.hpp>
#include <rmm/mr/device/per_device_resource.hpp>
#include <rmm/mr/device/pool_memory_resource.hpp>
#include <thread>

extern "C" cuvsError_t cuvsResourcesCreate(cuvsResources_t* res)
Expand Down Expand Up @@ -82,6 +83,34 @@ extern "C" cuvsError_t cuvsRMMFree(cuvsResources_t res, void* ptr, size_t bytes)
});
}

thread_local std::unique_ptr<rmm::mr::pool_memory_resource<rmm::mr::cuda_memory_resource>> pool_mr;

extern "C" cuvsError_t cuvsRMMPoolMemoryResourceEnable(int initial_pool_size_percent,
int max_pool_size_percent)
{
return cuvs::core::translate_exceptions([=] {
// Upstream memory resource needs to be a cuda_memory_resource
auto cuda_mr = rmm::mr::get_current_device_resource();
auto* cuda_mr_casted = dynamic_cast<rmm::mr::cuda_memory_resource*>(cuda_mr);
if (cuda_mr_casted == nullptr) {
throw std::runtime_error("Current memory resource is not a cuda_memory_resource");
}
auto initial_size = rmm::percent_of_free_device_memory(initial_pool_size_percent);
auto max_size = rmm::percent_of_free_device_memory(max_pool_size_percent);
pool_mr = std::make_unique<rmm::mr::pool_memory_resource<rmm::mr::cuda_memory_resource>>(
cuda_mr_casted, initial_size, max_size);
rmm::mr::set_current_device_resource(pool_mr.get());
});
}

extern "C" cuvsError_t cuvsRMMMemoryResourceReset()
{
return cuvs::core::translate_exceptions([=] {
rmm::mr::set_current_device_resource(nullptr);
pool_mr.reset();
});
}

thread_local std::string last_error_text = "";

extern "C" const char* cuvsGetLastErrorText()
Expand Down
27 changes: 27 additions & 0 deletions cpp/test/core/c_api.c
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,33 @@ int main()
cuvsError_t stream_error = cuvsStreamSet(res, stream);
if (stream_error == CUVS_ERROR) { exit(EXIT_FAILURE); }

// Allocate memory
void* ptr;
size_t bytes = 1024;
cuvsError_t alloc_error = cuvsRMMAlloc(res, &ptr, bytes);
if (alloc_error == CUVS_ERROR) { exit(EXIT_FAILURE); }

// Free memory
cuvsError_t free_error = cuvsRMMFree(res, ptr, bytes);
if (free_error == CUVS_ERROR) { exit(EXIT_FAILURE); }

// Enable pool memory resource
cuvsError_t pool_error = cuvsRMMPoolMemoryResourceEnable(10, 100);
if (pool_error == CUVS_ERROR) { exit(EXIT_FAILURE); }

// Allocate memory again
void* ptr2;
cuvsError_t alloc_error_pool = cuvsRMMAlloc(res, &ptr2, 1024);
if (alloc_error_pool == CUVS_ERROR) { exit(EXIT_FAILURE); }

// Free memory
cuvsError_t free_error_pool = cuvsRMMFree(res, ptr2, 1024);
if (free_error_pool == CUVS_ERROR) { exit(EXIT_FAILURE); }

// Reset pool memory resource
cuvsError_t reset_error = cuvsRMMMemoryResourceReset();
if (reset_error == CUVS_ERROR) { exit(EXIT_FAILURE); }

// Destroy resources
cuvsError_t destroy_error = cuvsResourcesDestroy(res);
if (destroy_error == CUVS_ERROR) { exit(EXIT_FAILURE); }
Expand Down

0 comments on commit 108ee25

Please sign in to comment.