Skip to content

Commit

Permalink
Make ucx linkage explicit and add a new CMake target for it (#1032)
Browse files Browse the repository at this point in the history
This PR removes the dlopen logic for libucp in ucp_helper.hpp in favor of calling the relevant APIs directly. It also adds a new CMake component `distributed`, which provides the target `raft::distributed`, that dependent libraries can use to indicate a dependency on the parts of raft that require UCX.

While it does not change any public APIs, I have marked this PR as breaking since it does mean that any C++ code using the UCX-dependent parts of raft must now ensure that UCX is available at link time. It is no longer sufficient to make the library available at runtime.

Resolves #1031.

Authors:
  - Vyas Ramasubramani (https://github.com/vyasr)

Approvers:
  - Corey J. Nolet (https://github.com/cjnolet)

URL: #1032
  • Loading branch information
vyasr authored Nov 18, 2022
1 parent d6df557 commit c743916
Show file tree
Hide file tree
Showing 5 changed files with 56 additions and 104 deletions.
53 changes: 47 additions & 6 deletions cpp/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -468,6 +468,21 @@ target_link_libraries(
raft_nn INTERFACE raft::raft $<TARGET_NAME_IF_EXISTS:raft::raft_nn_lib> nvidia::cutlass::cutlass
)

# ##################################################################################################
# * raft_distributed -------------------------------------------------------------------------------
# INTERFACE target that carries the UCX link requirement to whatever links against it.
add_library(raft_distributed INTERFACE)
target_link_libraries(raft_distributed INTERFACE ucx::ucp)

# Exported under the short name `distributed`.
set_property(TARGET raft_distributed PROPERTY EXPORT_NAME distributed)

# Offer the namespaced alias inside this build unless something has already defined it.
if(NOT TARGET raft::distributed)
  add_library(raft::distributed ALIAS raft_distributed)
endif()

# Register ucx with the raft-distributed-exports export set, once for the build tree and once for
# the install tree.
rapids_export_package(BUILD ucx raft-distributed-exports)
rapids_export_package(INSTALL ucx raft-distributed-exports)

# ##################################################################################################
# * install targets-----------------------------------------------------------
rapids_cmake_install_lib_dir(lib_dir)
Expand Down Expand Up @@ -518,6 +533,13 @@ if(TARGET raft_nn_lib)
)
endif()

# Install raft_distributed as part of the `distributed` component and record it in the
# raft-distributed-exports export set.
install(
  TARGETS raft_distributed
  EXPORT raft-distributed-exports
  COMPONENT distributed
  DESTINATION ${lib_dir}
)

install(
DIRECTORY include/raft
COMPONENT raft
Expand All @@ -542,8 +564,8 @@ install(

include("${rapids-cmake-dir}/export/write_dependencies.cmake")

set(raft_components distance nn)
set(raft_install_comp raft raft)
set(raft_components distance nn distributed)
set(raft_install_comp raft raft raft)
if(TARGET raft_distance_lib)
list(APPEND raft_components distance-lib)
list(APPEND raft_install_comp distance)
Expand Down Expand Up @@ -588,11 +610,13 @@ for data science and machine learning.
Optional Components:
- nn
- distance
- distributed

Imported Targets:
- raft::raft
- raft::nn brought in by the `nn` optional component
- raft::distance brought in by the `distance` optional component
- raft::distributed brought in by the `distributed` optional component

]=]
)
Expand Down Expand Up @@ -634,15 +658,32 @@ endif()
# Use `rapids_export` for 22.04 as it will have COMPONENT support
include(cmake/modules/raft_export.cmake)
raft_export(
INSTALL raft COMPONENTS nn distance EXPORT_SET raft-exports GLOBAL_TARGETS raft nn distance
NAMESPACE raft:: DOCUMENTATION doc_string FINAL_CODE_BLOCK code_string
INSTALL raft COMPONENTS nn distance distributed EXPORT_SET raft-exports GLOBAL_TARGETS raft nn
distance distributed NAMESPACE raft:: DOCUMENTATION doc_string FINAL_CODE_BLOCK code_string
)

# ##################################################################################################
# * build export -------------------------------------------------------------
raft_export(
BUILD raft EXPORT_SET raft-exports COMPONENTS nn distance GLOBAL_TARGETS raft raft_distance
raft_nn DOCUMENTATION doc_string NAMESPACE raft:: FINAL_CODE_BLOCK code_string
BUILD
raft
EXPORT_SET
raft-exports
COMPONENTS
nn
distance
distributed
GLOBAL_TARGETS
raft
raft_distance
distributed
raft_nn
DOCUMENTATION
doc_string
NAMESPACE
raft::
FINAL_CODE_BLOCK
code_string
)

# ##################################################################################################
Expand Down
2 changes: 1 addition & 1 deletion cpp/include/raft/comms/detail/std_comms.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -266,7 +266,7 @@ class std_comms : public comms_iface {
bool restart = false; // resets the timeout when any progress was made

// Causes UCP to progress through the send/recv message queue
while (ucp_handler_.ucp_progress(ucp_worker_) != 0) {
while (ucp_worker_progress(ucp_worker_) != 0) {
restart = true;
}

Expand Down
97 changes: 4 additions & 93 deletions cpp/include/raft/comms/detail/ucp_helper.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@

#pragma once

#include <dlfcn.h>
#include <raft/util/cudart_utils.hpp>
#include <stdio.h>
#include <ucp/api/ucp.h>
Expand All @@ -26,23 +25,6 @@ namespace raft {
namespace comms {
namespace detail {

typedef void (*dlsym_print_info)(ucp_ep_h, FILE*);

typedef void (*dlsym_rec_free)(void*);

typedef int (*dlsym_worker_progress)(ucp_worker_h);

typedef ucs_status_ptr_t (*dlsym_send)(
ucp_ep_h, const void*, size_t, ucp_datatype_t, ucp_tag_t, ucp_send_callback_t);

typedef ucs_status_ptr_t (*dlsym_recv)(ucp_worker_h,
void*,
size_t count,
ucp_datatype_t datatype,
ucp_tag_t,
ucp_tag_t,
ucp_tag_recv_callback_t);

/**
* Standard UCX request object that will be passed
* around asynchronously. This object is really
Expand Down Expand Up @@ -90,96 +72,25 @@ static void recv_callback(void* request, ucs_status_t status, ucp_tag_recv_info_
}

/**
* Helper class for managing `dlopen` state and
* interacting with ucp.
* Helper class for interacting with ucp.
*/
class comms_ucp_handler {
public:
// Opens the UCP library and resolves every entry point this helper calls through dlsym.
comms_ucp_handler()
{
// The library handle has to be obtained first: each loader below looks its symbol up through it.
load_ucp_handle();
load_send_func();
load_recv_func();
load_free_req_func();
load_print_info_func();
load_worker_progress_func();
}

// Closes the library handle that was obtained with dlopen.
~comms_ucp_handler()
{
  dlclose(ucp_handle);
}

private:
void* ucp_handle;

dlsym_print_info print_info_func;
dlsym_rec_free req_free_func;
dlsym_worker_progress worker_progress_func;
dlsym_send send_func;
dlsym_recv recv_func;

// Obtains a handle to libucp.so, first trying a copy that is already loaded into the process
// and only then asking the loader to bring it in.
void load_ucp_handle()
{
  constexpr int common_flags = RTLD_LAZY | RTLD_NODELETE;

  // RTLD_NOLOAD makes this attempt succeed only when the library is already resident.
  ucp_handle = dlopen("libucp.so", common_flags | RTLD_NOLOAD);
  if (ucp_handle == nullptr) {
    ucp_handle = dlopen("libucp.so", common_flags);
    ASSERT(ucp_handle, "Cannot open UCX library: %s\n", dlerror());
  }

  // Reset any potential error
  dlerror();
}

void assert_dlerror()
{
char* error = dlerror();
ASSERT(error == NULL, "Error loading function symbol: %s\n", error);
}

void load_send_func()
{
send_func = (dlsym_send)dlsym(ucp_handle, "ucp_tag_send_nb");
assert_dlerror();
}

void load_free_req_func()
{
req_free_func = (dlsym_rec_free)dlsym(ucp_handle, "ucp_request_free");
assert_dlerror();
}

void load_print_info_func()
{
print_info_func = (dlsym_print_info)dlsym(ucp_handle, "ucp_ep_print_info");
assert_dlerror();
}

void load_worker_progress_func()
{
worker_progress_func = (dlsym_worker_progress)dlsym(ucp_handle, "ucp_worker_progress");
assert_dlerror();
}

void load_recv_func()
{
recv_func = (dlsym_recv)dlsym(ucp_handle, "ucp_tag_recv_nb");
assert_dlerror();
}

/**
 * @brief Combines a rank and a user tag into a single UCP message tag.
 *
 * The rank is placed in the low bits and the tag is shifted above it by 31 bits.
 *
 * @param rank rank to encode in the low bits of the result
 * @param tag  user tag to encode above the rank
 * @return the combined message tag
 */
ucp_tag_t build_message_tag(int rank, int tag) const
{
  // Widen to ucp_tag_t before shifting. Shifting a 32-bit value left by 31 keeps only its
  // lowest bit, so any two tags of equal parity used to produce the same message tag.
  // NOTE(review): relies on ucp_tag_t being wider than 32 bits (64-bit in the UCX API) — confirm.
  const ucp_tag_t tag_bits  = static_cast<ucp_tag_t>(static_cast<uint32_t>(tag)) << 31;
  const ucp_tag_t rank_bits = static_cast<ucp_tag_t>(static_cast<uint32_t>(rank));

  // keeping the rank in the lower bits enables debugging.
  return tag_bits | rank_bits;
}

public:
// Invokes the resolved worker-progress function on `worker` and returns its result.
int ucp_progress(ucp_worker_h worker) const
{
  return worker_progress_func(worker);
}

/**
* @brief Frees any memory underlying the given ucp request object
*/
void free_ucp_request(ucp_request* request) const
{
if (request->needs_release) {
request->req->completed = 0;
(*(req_free_func))(request->req);
ucp_request_free(request->req);
}
free(request);
}
Expand All @@ -198,7 +109,7 @@ class comms_ucp_handler {
ucp_tag_t ucp_tag = build_message_tag(rank, tag);

ucs_status_ptr_t send_result =
(*(send_func))(ep_ptr, buf, size, ucp_dt_make_contig(1), ucp_tag, send_callback);
ucp_tag_send_nb(ep_ptr, buf, size, ucp_dt_make_contig(1), ucp_tag, send_callback);
struct ucx_context* ucp_req = (struct ucx_context*)send_result;

if (UCS_PTR_IS_ERR(send_result)) {
Expand Down Expand Up @@ -240,7 +151,7 @@ class comms_ucp_handler {
ucp_tag_t ucp_tag = build_message_tag(sender_rank, tag);

ucs_status_ptr_t recv_result =
(*(recv_func))(worker, buf, size, ucp_dt_make_contig(1), ucp_tag, tag_mask, recv_callback);
ucp_tag_recv_nb(worker, buf, size, ucp_dt_make_contig(1), ucp_tag, tag_mask, recv_callback);

struct ucx_context* ucp_req = (struct ucx_context*)recv_result;

Expand Down
5 changes: 3 additions & 2 deletions python/raft-dask/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ option(FIND_RAFT_CPP "Search for existing RAFT C++ installations before defaulti

# If the user requested it we attempt to find RAFT.
if(FIND_RAFT_CPP)
find_package(raft ${raft_dask_version} REQUIRED)
find_package(raft ${raft_dask_version} REQUIRED COMPONENTS distributed)
else()
set(raft_FOUND OFF)
endif()
Expand All @@ -47,7 +47,8 @@ if(NOT raft_FOUND)
enable_language(CUDA)
# Since raft-dask only enables CUDA optionally we need to manually include the file that
# rapids_cuda_init_architectures relies on `project` including.
include("${CMAKE_PROJECT_raft_dask_INCLUDE}")
include("${CMAKE_PROJECT_raft-dask_INCLUDE}")
find_package(ucx REQUIRED)

# raft-dask doesn't actually use raft libraries, it just needs the headers, so we can turn off all
# library compilation and we don't need to install anything here.
Expand Down
3 changes: 1 addition & 2 deletions python/raft-dask/raft_dask/common/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,9 @@
# =============================================================================

include(${raft-dask-python_SOURCE_DIR}/cmake/thirdparty/get_nccl.cmake)
find_package(ucx REQUIRED)

set(cython_sources comms_utils.pyx nccl.pyx)
set(linked_libraries raft::raft NCCL::NCCL ucx::ucp)
set(linked_libraries raft::raft raft::distributed NCCL::NCCL)
rapids_cython_create_modules(
SOURCE_FILES "${cython_sources}" LINKED_LIBRARIES "${linked_libraries}" CXX
)

0 comments on commit c743916

Please sign in to comment.