Separating more namespaces into easier-to-consume sections #1091

Merged · 19 commits · Dec 22, 2022
Changes from 17 commits
19 changes: 16 additions & 3 deletions README.md
@@ -37,7 +37,7 @@ While not exhaustive, the following general categories help summarize the accele
All of RAFT's C++ APIs can be accessed header-only, and optional pre-compiled shared libraries can 1) speed up compile times and 2) enable the APIs to be used without CUDA-enabled compilers.

In addition to the C++ library, RAFT also provides 2 Python libraries:
- `pylibraft` - lightweight low-level Python wrappers around RAFT's host-accessible APIs.
- `pylibraft` - lightweight low-level Python wrappers around RAFT's host-accessible "runtime" APIs.
- `raft-dask` - multi-node multi-GPU communicator infrastructure for building distributed algorithms on the GPU with Dask.

## Getting started
@@ -142,7 +142,7 @@ in2 = cp.random.random_sample((n_samples, n_features), dtype=cp.float32)
output = pairwise_distance(in1, in2, metric="euclidean")
```
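
Read in full, the example above amounts to the following self-contained sketch; the lines elided by the diff view are assumptions, and the `n_samples`/`n_features` values are illustrative:

```python
import cupy as cp
from pylibraft.distance import pairwise_distance

n_samples = 5000
n_features = 50

# Two random float32 matrices resident on the GPU
in1 = cp.random.random_sample((n_samples, n_features), dtype=cp.float32)
in2 = cp.random.random_sample((n_samples, n_features), dtype=cp.float32)

# Computes the n_samples x n_samples euclidean distance matrix
output = pairwise_distance(in1, in2, metric="euclidean")
```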

The `output` array supports [__cuda_array_interface__](https://numba.pydata.org/numba-doc/dev/cuda/cuda_array_interface.html#cuda-array-interface-version-2) so it is interoperable with other libraries like CuPy, Numba, and PyTorch that also support it.
The `output` array in the above example is of type `pylibraft.common.device_ndarray`, which supports [__cuda_array_interface__](https://numba.pydata.org/numba-doc/dev/cuda/cuda_array_interface.html#cuda-array-interface-version-2), making it interoperable with other libraries like CuPy, Numba, and PyTorch that also support it. Since CuPy supports DLPack, this also enables zero-copy conversion from `pylibraft.common.device_ndarray` to JAX and TensorFlow.

Below is an example of converting the output `pylibraft.common.device_ndarray` to a CuPy array:
```python
@@ -156,6 +156,18 @@ import torch
torch_tensor = torch.as_tensor(output, device='cuda')
```
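
The CuPy conversion elided just above follows the same zero-copy pattern, and DLPack carries it onward to JAX; a minimal sketch, assuming `cupy` and `jax` are installed, with API names taken from their documentation rather than this PR:

```python
import cupy as cp
import jax.dlpack

# Zero-copy view of `output` through __cuda_array_interface__
cupy_array = cp.asarray(output)

# DLPack capsule hand-off from CuPy to JAX
jax_array = jax.dlpack.from_dlpack(cupy_array.toDlpack())
```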

When the corresponding library is installed and available in your environment, this conversion can also be performed automatically by all RAFT compute APIs by setting a global configuration option:
```python
import pylibraft.config
pylibraft.config.set_output_as("cupy") # All compute APIs will return cupy arrays
pylibraft.config.set_output_as("torch") # All compute APIs will return torch tensors
```
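
As a usage sketch (reusing `in1` and `in2` from the earlier example), the configured output type then applies to every subsequent compute call:

```python
import pylibraft.config
from pylibraft.distance import pairwise_distance

pylibraft.config.set_output_as("torch")

# `output` is now a torch.Tensor on the GPU rather than a device_ndarray
output = pairwise_distance(in1, in2, metric="euclidean")
```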

You can also specify a `callable` that accepts a `pylibraft.common.device_ndarray` and performs a custom conversion. The following example converts all output to `numpy` arrays:
```python
pylibraft.config.set_output_as(lambda device_ndarray: device_ndarray.copy_to_host())
```
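
Note that `copy_to_host()` implies a device-to-host copy, so unlike the `cupy` and `torch` conversions above this option is not zero-copy.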

`pylibraft` also supports writing to a pre-allocated output array, so any array supporting `__cuda_array_interface__` can be written to in place:

```python
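# The actual in-place example is elided by the diff view; the sketch below
# is an assumption based on pylibraft's documented `out=` keyword, not the
# PR text.
import cupy as cp
from pylibraft.distance import pairwise_distance

n_samples = 5000
n_features = 50

in1 = cp.random.random_sample((n_samples, n_features), dtype=cp.float32)
in2 = cp.random.random_sample((n_samples, n_features), dtype=cp.float32)

# Pre-allocate the output; any __cuda_array_interface__ array works here
output = cp.empty((n_samples, n_samples), dtype=cp.float32)

# Distances are written directly into `output` in place
pairwise_distance(in1, in2, out=output, metric="euclidean")
```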
@@ -257,7 +269,8 @@ Several CMake targets can be made available by adding components in the table be
| --- | --- | --- | --- |
| n/a | `raft::raft` | Full RAFT header library | CUDA toolkit library, RMM, Thrust (optional), NVTools (optional) |
| distance | `raft::distance` | Pre-compiled template specializations for raft::distance | raft::raft, cuCollections (optional) |
| nn | `raft::nn` | Pre-compiled template specializations for raft::spatial::knn | raft::raft, FAISS (optional) |
| nn | `raft::nn` | Pre-compiled template specializations for raft::neighbors | raft::raft, FAISS (optional) |
| distributed | `raft::distributed` | No specializations | raft::raft, UCX, NCCL |
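
In a consuming CMake project these imported targets are linked as usual, e.g. `target_link_libraries(my_app PRIVATE raft::raft raft::distance)`, where `my_app` is a placeholder target name.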

### Source

102 changes: 0 additions & 102 deletions cpp/include/raft/comms/helper.hpp

This file was deleted.

35 changes: 35 additions & 0 deletions cpp/include/raft/comms/mpi_comms.hpp
@@ -24,12 +24,47 @@ namespace comms {

using mpi_comms = detail::mpi_comms;

/**
* @defgroup mpi_comms_factory MPI Comms Factory Functions
* @{
*/

/**
* Given a properly initialized MPI_Comm, construct an instance of RAFT's
* MPI Communicator and inject it into the given RAFT handle instance
* @param handle raft handle for managing expensive resources
* @param comm an initialized MPI communicator
*
* @code{.cpp}
* #include <raft/comms/mpi_comms.hpp>
* #include <raft/core/device_mdarray.hpp>
*
* MPI_Comm mpi_comm;
* raft::handle_t handle;
*
* initialize_mpi_comms(&handle, mpi_comm);
* ...
* const auto& comm = handle.get_comms();
* auto gather_data = raft::make_device_vector<float>(handle, comm.get_size());
* ...
* comm.allgather(gather_data.data_handle() + comm.get_rank(),
* gather_data.data_handle(),
* 1,
* handle.get_stream());
*
* comm.sync_stream(handle.get_stream());
* @endcode
*/
inline void initialize_mpi_comms(handle_t* handle, MPI_Comm comm)
{
auto communicator = std::make_shared<comms_t>(
std::unique_ptr<comms_iface>(new mpi_comms(comm, false, handle->get_stream())));
handle->set_comms(communicator);
};

/**
* @}
*/

}; // namespace comms
}; // end namespace raft
59 changes: 55 additions & 4 deletions cpp/include/raft/comms/std_comms.hpp
@@ -31,13 +31,38 @@ namespace comms {
using std_comms = detail::std_comms;

/**
* Function to construct comms_t and inject it on a handle_t. This
* is used for convenience in the Python layer.
* @defgroup std_comms_factory std_comms Factory functions
* @{
*/

/**
* Factory function to construct a RAFT NCCL communicator and inject it into a
* RAFT handle.
*
* @param handle raft::handle_t for injecting the comms
* @param nccl_comm initialized NCCL communicator to use for collectives
* @param num_ranks number of ranks in communicator clique
* @param rank rank of local instance
*
* @code{.cpp}
* #include <raft/comms/std_comms.hpp>
* #include <raft/core/device_mdarray.hpp>
*
* ncclComm_t nccl_comm;
* raft::handle_t handle;
*
* build_comms_nccl_only(&handle, nccl_comm, 5, 0);
* ...
* const auto& comm = handle.get_comms();
* auto gather_data = raft::make_device_vector<float>(handle, comm.get_size());
* ...
* comm.allgather(gather_data.data_handle() + comm.get_rank(),
* gather_data.data_handle(),
* 1,
* handle.get_stream());
*
* comm.sync_stream(handle.get_stream());
* @endcode
*/
void build_comms_nccl_only(handle_t* handle, ncclComm_t nccl_comm, int num_ranks, int rank)
{
@@ -49,8 +74,8 @@ void build_comms_nccl_only(handle_t* handle, ncclComm_t nccl_comm, int num_ranks
}

/**
* Function to construct comms_t and inject it on a handle_t. This
* is used for convenience in the Python layer.
* Factory function to construct a RAFT NCCL+UCX communicator and inject it
* into a RAFT handle.
*
* @param handle raft::handle_t for injecting the comms
* @param nccl_comm initialized NCCL communicator to use for collectives
@@ -62,6 +87,28 @@ void build_comms_nccl_only(handle_t* handle, ncclComm_t nccl_comm, int num_ranks
* the ucp_ep_h doesn't need to be exposed through the cython layer.
* @param num_ranks number of ranks in communicator clique
* @param rank rank of local instance
*
* @code{.cpp}
* #include <raft/comms/std_comms.hpp>
* #include <raft/core/device_mdarray.hpp>
*
* ncclComm_t nccl_comm;
* raft::handle_t handle;
* ucp_worker_h ucp_worker;
* ucp_ep_h *ucp_endpoints_arr;
*
* build_comms_nccl_ucx(&handle, nccl_comm, &ucp_worker, ucp_endpoints_arr, 5, 0);
* ...
* const auto& comm = handle.get_comms();
* auto gather_data = raft::make_device_vector<float>(handle, comm.get_size());
* ...
* comm.allgather(gather_data.data_handle() + comm.get_rank(),
* gather_data.data_handle(),
* 1,
* handle.get_stream());
*
* comm.sync_stream(handle.get_stream());
* @endcode
*/
void build_comms_nccl_ucx(
handle_t* handle, ncclComm_t nccl_comm, void* ucp_worker, void* eps, int num_ranks, int rank)
@@ -90,6 +137,10 @@ void build_comms_nccl_ucx(
handle->set_comms(communicator);
}

/**
* @}
*/

inline void nccl_unique_id_from_char(ncclUniqueId* id, char* uniqueId, int size)
{
memcpy(id->internal, uniqueId, size);
27 changes: 27 additions & 0 deletions cpp/include/raft/core/comms.hpp
@@ -23,6 +23,11 @@
namespace raft {
namespace comms {

/**
* @defgroup comms_types Common MNMG comms types
* @{
*/

typedef unsigned int request_t;
enum class datatype_t { CHAR, UINT8, INT32, UINT32, INT64, UINT64, FLOAT32, FLOAT64 };
enum class op_t { SUM, PROD, MIN, MAX };
@@ -105,6 +110,15 @@ get_type<double>()
return datatype_t::FLOAT64;
}

/**
* @}
*/

/**
* @defgroup comms_iface MNMG Communicator Interface
* @{
*/

class comms_iface {
public:
virtual ~comms_iface() {}
@@ -215,6 +229,15 @@ class comms_iface {
virtual void group_end() const = 0;
};

/**
* @}
*/

/**
* @defgroup comms_t Base Communicator Proxy
* @{
*/

class comms_t {
public:
comms_t(std::unique_ptr<comms_iface> impl) : impl_(impl.release())
@@ -647,5 +670,9 @@ class comms_t {
std::unique_ptr<comms_iface> impl_;
};

/**
* @}
*/

} // namespace comms
} // namespace raft
17 changes: 17 additions & 0 deletions cpp/include/raft/core/cublas_macros.hpp
@@ -32,6 +32,11 @@

namespace raft {

/**
* @ingroup error_handling
* @{
*/

/**
* @brief Exception thrown when a cuBLAS error is encountered.
*/
@@ -40,6 +45,10 @@ struct cublas_error : public raft::exception {
explicit cublas_error(std::string const& message) : raft::exception(message) {}
};

/**
* @}
*/

namespace linalg {
namespace detail {

@@ -66,6 +75,11 @@ inline const char* cublas_error_to_string(cublasStatus_t err)

#undef _CUBLAS_ERR_TO_STR

/**
* @ingroup assertion
* @{
*/

/**
* @brief Error checking macro for cuBLAS runtime API functions.
*
@@ -108,6 +122,9 @@ inline const char* cublas_error_to_string(cublasStatus_t err)
} \
} while (0)

/**
* @}
*/
/** FIXME: remove after cuml rename */
#ifndef CUBLAS_CHECK
#define CUBLAS_CHECK(call) CUBLAS_TRY(call)