diff --git a/cpp/include/cuvs/core/c_api.h b/cpp/include/cuvs/core/c_api.h index d931d6c13..4db7fd12c 100644 --- a/cpp/include/cuvs/core/c_api.h +++ b/cpp/include/cuvs/core/c_api.h @@ -127,6 +127,27 @@ cuvsError_t cuvsRMMAlloc(cuvsResources_t res, void** ptr, size_t bytes); */ cuvsError_t cuvsRMMFree(cuvsResources_t res, void* ptr, size_t bytes); +/** + * @brief Switches the working memory resource to use the RMM pool memory resource, which will + * bypass unnecessary synchronizations by allocating a chunk of device memory up front and carving + * that up for temporary memory allocations within algorithms. Be aware that this function will + * change the memory resource for the whole process and the new memory resource will be used until + * explicitly changed. The pool object is held in thread-local storage of the calling thread. + * + * @param[in] initial_pool_size_percent The initial pool size as a percentage of the device + * memory that is free at the time of the call + * @param[in] max_pool_size_percent The maximum pool size as a percentage of the device + * memory that is free at the time of the call + * @return cuvsError_t + */ +cuvsError_t cuvsRMMPoolMemoryResourceEnable(int initial_pool_size_percent, + int max_pool_size_percent); +/** + * @brief Resets the memory resource to use the default memory resource (cuda_memory_resource) + * @return cuvsError_t + */ +cuvsError_t cuvsRMMMemoryResourceReset(); + /** @} */ #ifdef __cplusplus diff --git a/cpp/src/core/c_api.cpp b/cpp/src/core/c_api.cpp index 96504a2fe..a75e5a1dd 100644 --- a/cpp/src/core/c_api.cpp +++ b/cpp/src/core/c_api.cpp @@ -22,6 +22,7 @@ #include #include #include +#include #include extern "C" cuvsError_t cuvsResourcesCreate(cuvsResources_t* res) @@ -82,6 +83,34 @@ extern "C" cuvsError_t cuvsRMMFree(cuvsResources_t res, void* ptr, size_t bytes) }); } +thread_local std::unique_ptr> pool_mr; + +extern "C" cuvsError_t cuvsRMMPoolMemoryResourceEnable(int initial_pool_size_percent, + int max_pool_size_percent) +{ + return cuvs::core::translate_exceptions([=] { + // Upstream memory resource needs to be a cuda_memory_resource + 
auto cuda_mr = rmm::mr::get_current_device_resource(); + auto* cuda_mr_casted = dynamic_cast(cuda_mr); + if (cuda_mr_casted == nullptr) { + throw std::runtime_error("Current memory resource is not a cuda_memory_resource"); + } + auto initial_size = rmm::percent_of_free_device_memory(initial_pool_size_percent); + auto max_size = rmm::percent_of_free_device_memory(max_pool_size_percent); + pool_mr = std::make_unique>( + cuda_mr_casted, initial_size, max_size); + rmm::mr::set_current_device_resource(pool_mr.get()); + }); +} + +extern "C" cuvsError_t cuvsRMMMemoryResourceReset() +{ + return cuvs::core::translate_exceptions([=] { + rmm::mr::set_current_device_resource(nullptr); + pool_mr.reset(); + }); +} + thread_local std::string last_error_text = ""; extern "C" const char* cuvsGetLastErrorText() diff --git a/cpp/test/core/c_api.c b/cpp/test/core/c_api.c index a5b73d8fb..27973c2dd 100644 --- a/cpp/test/core/c_api.c +++ b/cpp/test/core/c_api.c @@ -31,6 +31,33 @@ int main() cuvsError_t stream_error = cuvsStreamSet(res, stream); if (stream_error == CUVS_ERROR) { exit(EXIT_FAILURE); } + // Allocate memory + void* ptr; + size_t bytes = 1024; + cuvsError_t alloc_error = cuvsRMMAlloc(res, &ptr, bytes); + if (alloc_error == CUVS_ERROR) { exit(EXIT_FAILURE); } + + // Free memory + cuvsError_t free_error = cuvsRMMFree(res, ptr, bytes); + if (free_error == CUVS_ERROR) { exit(EXIT_FAILURE); } + + // Enable pool memory resource (initial 10%, max 100% of free device memory) + cuvsError_t pool_error = cuvsRMMPoolMemoryResourceEnable(10, 100); + if (pool_error == CUVS_ERROR) { exit(EXIT_FAILURE); } + + // Allocate memory again + void* ptr2; + cuvsError_t alloc_error_pool = cuvsRMMAlloc(res, &ptr2, 1024); + if (alloc_error_pool == CUVS_ERROR) { exit(EXIT_FAILURE); } + + // Free memory + cuvsError_t free_error_pool = cuvsRMMFree(res, ptr2, 1024); + if (free_error_pool == CUVS_ERROR) { exit(EXIT_FAILURE); } + + // Reset pool memory resource + cuvsError_t reset_error = cuvsRMMMemoryResourceReset(); + if (reset_error == CUVS_ERROR) { 
exit(EXIT_FAILURE); } + // Destroy resources cuvsError_t destroy_error = cuvsResourcesDestroy(res); if (destroy_error == CUVS_ERROR) { exit(EXIT_FAILURE); }