Separating more namespaces into easier-to-consume sections (#1091)
Authors:
  - Corey J. Nolet (https://github.com/cjnolet)

Approvers:
  - Ben Frederickson (https://github.com/benfred)
  - Dante Gama Dessavre (https://github.com/dantegd)

URL: #1091
cjnolet authored Dec 22, 2022
1 parent 107f6e3 commit a67a1db
Showing 74 changed files with 1,196 additions and 686 deletions.
19 changes: 16 additions & 3 deletions README.md
@@ -37,7 +37,7 @@ While not exhaustive, the following general categories help summarize the accele
All of RAFT's C++ APIs can be used header-only, and optional pre-compiled shared libraries can 1) speed up compile times and 2) enable the APIs to be used without a CUDA-enabled compiler.

In addition to the C++ library, RAFT also provides 2 Python libraries:
- `pylibraft` - lightweight low-level Python wrappers around RAFT's host-accessible APIs.
- `pylibraft` - lightweight low-level Python wrappers around RAFT's host-accessible "runtime" APIs.
- `raft-dask` - multi-node multi-GPU communicator infrastructure for building distributed algorithms on the GPU with Dask.
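
A minimal sketch of standing up `raft-dask` communications on a single machine, assuming `dask-cuda` is installed and that `raft_dask.common.Comms` accepts a Dask client (as in RAFT's Dask examples):
```python
from dask.distributed import Client
from dask_cuda import LocalCUDACluster

from raft_dask.common import Comms

# One Dask worker per local GPU
cluster = LocalCUDACluster()
client = Client(cluster)

# Bootstrap RAFT's NCCL-based comms across all workers
comms = Comms(client=client)
comms.init()
```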

## Getting started
@@ -142,7 +142,7 @@ in2 = cp.random.random_sample((n_samples, n_features), dtype=cp.float32)
output = pairwise_distance(in1, in2, metric="euclidean")
```

The `output` array supports [__cuda_array_interface__](https://numba.pydata.org/numba-doc/dev/cuda/cuda_array_interface.html#cuda-array-interface-version-2) so it is interoperable with other libraries like CuPy, Numba, and PyTorch that also support it.
The `output` array in the above example is of type `pylibraft.common.device_ndarray`, which supports [__cuda_array_interface__](https://numba.pydata.org/numba-doc/dev/cuda/cuda_array_interface.html#cuda-array-interface-version-2), making it interoperable with other libraries like CuPy, Numba, and PyTorch that also support it. Since CuPy supports DLPack, this also enables zero-copy conversion from `pylibraft.common.device_ndarray` to JAX and TensorFlow.
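
For example, since `output` already lives on the device, a zero-copy hop into JAX can route through CuPy and DLPack. A minimal sketch, assuming `cupy` and `jax` are installed (`cp.asarray`, CuPy's `toDlpack()`, and `jax.dlpack.from_dlpack` are the assumed conversion points):
```python
import cupy as cp
import jax.dlpack

# device_ndarray -> CuPy: zero-copy via __cuda_array_interface__
cupy_array = cp.asarray(output)

# CuPy -> JAX: zero-copy via DLPack
jax_array = jax.dlpack.from_dlpack(cupy_array.toDlpack())
```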

Below is an example of converting the output `pylibraft.common.device_ndarray` to a CuPy array:
```python
cupy_array = cp.asarray(output)
```

@@ -156,6 +156,18 @@
And converting to a PyTorch tensor:
```python
import torch

torch_tensor = torch.as_tensor(output, device='cuda')
```

When the corresponding library is installed and available in your environment, this conversion can also be performed automatically by all RAFT compute APIs by setting a global configuration option:
```python
import pylibraft.config
pylibraft.config.set_output_as("cupy") # All compute APIs will return cupy arrays
pylibraft.config.set_output_as("torch") # All compute APIs will return torch tensors
```
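
For instance, with the `torch` option set (and assuming PyTorch is installed), the `pairwise_distance` call from the earlier example hands back a `torch.Tensor` directly. A minimal sketch:
```python
import pylibraft.config
pylibraft.config.set_output_as("torch")

# Reusing in1/in2 from the pairwise_distance example above; the result
# now comes back as a torch.Tensor instead of a device_ndarray.
output = pairwise_distance(in1, in2, metric="euclidean")
print(type(output))  # <class 'torch.Tensor'>
```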

You can also specify a `callable` that accepts a `pylibraft.common.device_ndarray` and performs a custom conversion. The following example converts all output to `numpy` arrays:
```python
pylibraft.config.set_output_as(lambda device_ndarray: device_ndarray.copy_to_host())
```

`pylibraft` also supports writing to a pre-allocated output array, so any array that supports `__cuda_array_interface__` can be written to in-place:

```python
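# A minimal sketch: assumes pairwise_distance accepts a pre-allocated `out`
# array, per the prose above.
import cupy as cp

from pylibraft.distance import pairwise_distance

n_samples = 5000
n_features = 50

in1 = cp.random.random_sample((n_samples, n_features), dtype=cp.float32)
in2 = cp.random.random_sample((n_samples, n_features), dtype=cp.float32)
output = cp.empty((n_samples, n_samples), dtype=cp.float32)

# Distances are written directly into the pre-allocated `output` array
pairwise_distance(in1, in2, out=output, metric="euclidean")
```
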
@@ -257,7 +269,8 @@ Several CMake targets can be made available by adding components in the table be
| --- | --- | --- | --- |
| n/a | `raft::raft` | Full RAFT header library | CUDA toolkit library, RMM, Thrust (optional), NVTools (optional) |
| distance | `raft::distance` | Pre-compiled template specializations for raft::distance | raft::raft, cuCollections (optional) |
| nn | `raft::nn` | Pre-compiled template specializations for raft::spatial::knn | raft::raft, FAISS (optional) |
| nn | `raft::nn` | Pre-compiled template specializations for raft::neighbors | raft::raft, FAISS (optional) |
| distributed | `raft::distributed` | No specializations | raft::raft, UCX, NCCL |

### Source

102 changes: 0 additions & 102 deletions cpp/include/raft/comms/helper.hpp

This file was deleted.

35 changes: 35 additions & 0 deletions cpp/include/raft/comms/mpi_comms.hpp
@@ -24,12 +24,47 @@ namespace comms {

using mpi_comms = detail::mpi_comms;

/**
* @defgroup mpi_comms_factory MPI Comms Factory Functions
* @{
*/

/**
* Given a properly initialized MPI_Comm, construct an instance of RAFT's
* MPI Communicator and inject it into the given RAFT handle instance
* @param handle raft handle for managing expensive resources
* @param comm an initialized MPI communicator
*
* @code{.cpp}
* #include <raft/comms/mpi_comms.hpp>
* #include <raft/core/device_mdarray.hpp>
*
* MPI_Comm mpi_comm;
* raft::handle_t handle;
*
* initialize_mpi_comms(&handle, mpi_comm);
* ...
* const auto& comm = handle.get_comms();
* auto gather_data = raft::make_device_vector<float>(handle, comm.get_size());
* ...
* comm.allgather(gather_data.data_handle() + comm.get_rank(),
* gather_data.data_handle(),
* 1,
* handle.get_stream());
*
* comm.sync_stream(handle.get_stream());
* @endcode
*/
inline void initialize_mpi_comms(handle_t* handle, MPI_Comm comm)
{
auto communicator = std::make_shared<comms_t>(
std::unique_ptr<comms_iface>(new mpi_comms(comm, false, handle->get_stream())));
handle->set_comms(communicator);
};

/**
* @}
*/

}  // end namespace comms
}  // end namespace raft
59 changes: 55 additions & 4 deletions cpp/include/raft/comms/std_comms.hpp
@@ -31,13 +31,38 @@ namespace comms {
using std_comms = detail::std_comms;

/**
* Function to construct comms_t and inject it on a handle_t. This
* is used for convenience in the Python layer.
* @defgroup std_comms_factory std_comms Factory functions
* @{
*/

/**
* Factory function to construct a RAFT NCCL communicator and inject it into a
* RAFT handle.
*
* @param handle raft::handle_t for injecting the comms
* @param nccl_comm initialized NCCL communicator to use for collectives
* @param num_ranks number of ranks in communicator clique
* @param rank rank of local instance
*
* @code{.cpp}
* #include <raft/comms/std_comms.hpp>
* #include <raft/core/device_mdarray.hpp>
*
* ncclComm_t nccl_comm;
* raft::handle_t handle;
*
* build_comms_nccl_only(&handle, nccl_comm, 5, 0);
* ...
* const auto& comm = handle.get_comms();
* auto gather_data = raft::make_device_vector<float>(handle, comm.get_size());
* ...
* comm.allgather(gather_data.data_handle() + comm.get_rank(),
* gather_data.data_handle(),
* 1,
* handle.get_stream());
*
* comm.sync_stream(handle.get_stream());
* @endcode
*/
void build_comms_nccl_only(handle_t* handle, ncclComm_t nccl_comm, int num_ranks, int rank)
{
@@ -49,8 +74,8 @@ void build_comms_nccl_only(handle_t* handle, ncclComm_t nccl_comm, int num_ranks
}

/**
* Function to construct comms_t and inject it on a handle_t. This
* is used for convenience in the Python layer.
* Factory function to construct a RAFT NCCL+UCX communicator and inject it
* into a RAFT handle.
*
* @param handle raft::handle_t for injecting the comms
* @param nccl_comm initialized NCCL communicator to use for collectives
@@ -62,6 +87,28 @@ void build_comms_nccl_only(handle_t* handle, ncclComm_t nccl_comm, int num_ranks
* the ucp_ep_h doesn't need to be exposed through the cython layer.
* @param num_ranks number of ranks in communicator clique
* @param rank rank of local instance
*
* @code{.cpp}
* #include <raft/comms/std_comms.hpp>
* #include <raft/core/device_mdarray.hpp>
*
* ncclComm_t nccl_comm;
* raft::handle_t handle;
* ucp_worker_h ucp_worker;
* ucp_ep_h *ucp_endpoints_arr;
*
* build_comms_nccl_ucx(&handle, nccl_comm, &ucp_worker, ucp_endpoints_arr, 5, 0);
* ...
* const auto& comm = handle.get_comms();
* auto gather_data = raft::make_device_vector<float>(handle, comm.get_size());
* ...
* comm.allgather(gather_data.data_handle() + comm.get_rank(),
* gather_data.data_handle(),
* 1,
* handle.get_stream());
*
* comm.sync_stream(handle.get_stream());
* @endcode
*/
void build_comms_nccl_ucx(
handle_t* handle, ncclComm_t nccl_comm, void* ucp_worker, void* eps, int num_ranks, int rank)
@@ -90,6 +137,10 @@ void build_comms_nccl_ucx(
handle->set_comms(communicator);
}

/**
* @}
*/

inline void nccl_unique_id_from_char(ncclUniqueId* id, char* uniqueId, int size)
{
memcpy(id->internal, uniqueId, size);
27 changes: 27 additions & 0 deletions cpp/include/raft/core/comms.hpp
@@ -23,6 +23,11 @@
namespace raft {
namespace comms {

/**
* @defgroup comms_types Common MNMG comms types
* @{
*/

typedef unsigned int request_t;
enum class datatype_t { CHAR, UINT8, INT32, UINT32, INT64, UINT64, FLOAT32, FLOAT64 };
enum class op_t { SUM, PROD, MIN, MAX };
@@ -105,6 +110,15 @@ get_type<double>()
return datatype_t::FLOAT64;
}

/**
* @}
*/

/**
* @defgroup comms_iface MNMG Communicator Interface
* @{
*/

class comms_iface {
public:
virtual ~comms_iface() {}
@@ -215,6 +229,15 @@ class comms_iface {
virtual void group_end() const = 0;
};

/**
* @}
*/

/**
* @defgroup comms_t Base Communicator Proxy
* @{
*/

class comms_t {
public:
comms_t(std::unique_ptr<comms_iface> impl) : impl_(impl.release())
@@ -647,5 +670,9 @@ class comms_t {
std::unique_ptr<comms_iface> impl_;
};

/**
* @}
*/

} // namespace comms
} // namespace raft
17 changes: 17 additions & 0 deletions cpp/include/raft/core/cublas_macros.hpp
@@ -32,6 +32,11 @@

namespace raft {

/**
* @ingroup error_handling
* @{
*/

/**
* @brief Exception thrown when a cuBLAS error is encountered.
*/
@@ -40,6 +45,10 @@ struct cublas_error : public raft::exception {
explicit cublas_error(std::string const& message) : raft::exception(message) {}
};

/**
* @}
*/

namespace linalg {
namespace detail {

@@ -66,6 +75,11 @@ inline const char* cublas_error_to_string(cublasStatus_t err)

#undef _CUBLAS_ERR_TO_STR

/**
* @ingroup assertion
* @{
*/

/**
* @brief Error checking macro for cuBLAS runtime API functions.
*
@@ -108,6 +122,9 @@ inline const char* cublas_error_to_string(cublasStatus_t err)
} \
} while (0)

/**
* @}
*/
/** FIXME: remove after cuml rename */
#ifndef CUBLAS_CHECK
#define CUBLAS_CHECK(call) CUBLAS_TRY(call)
