Skip to content

Commit

Permalink
Integrating RAFT handle updates (rapidsai#4313)
Browse files (browse the repository at this point in the history)
Authors:
  - Divye Gala (https://github.com/divyegala)
  - Corey J. Nolet (https://github.com/cjnolet)

Approvers:
  - Corey J. Nolet (https://github.com/cjnolet)

URL: rapidsai#4313
  • Loading branch information
divyegala authored Dec 14, 2021
1 parent ea684a0 commit f16ff77
Show file tree
Hide file tree
Showing 50 changed files with 438 additions and 486 deletions.
3 changes: 1 addition & 2 deletions cpp/bench/prims/fused_l2_nn.cu
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,7 @@ struct FusedL2NN : public Fixture {
alloc(out, params.m);
alloc(workspace, params.m);
raft::random::Rng r(123456ULL);
raft::handle_t handle;
handle.set_stream(stream);
raft::handle_t handle{stream};

r.uniform(x, params.m * params.k, T(-1.0), T(1.0), stream);
r.uniform(y, params.n * params.k, T(-1.0), T(1.0), stream);
Expand Down
4 changes: 2 additions & 2 deletions cpp/bench/sg/benchmark.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,9 @@ class Fixture : public MLCommon::Bench::Fixture {

void SetUp(const ::benchmark::State& state) override
{
handle.reset(new raft::handle_t(NumStreams));
auto stream_pool = std::make_shared<rmm::cuda_stream_pool>(NumStreams);
handle.reset(new raft::handle_t{stream, stream_pool});
MLCommon::Bench::Fixture::SetUp(state);
handle->set_stream(stream);
}

void TearDown(const ::benchmark::State& state) override
Expand Down
4 changes: 1 addition & 3 deletions cpp/examples/dbscan/dbscan_example.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -136,8 +136,6 @@ int main(int argc, char* argv[])
}
}

raft::handle_t handle;

std::vector<float> h_inputData;

if (input == "") {
Expand Down Expand Up @@ -177,7 +175,7 @@ int main(int argc, char* argv[])

cudaStream_t stream;
CUDA_RT_CALL(cudaStreamCreate(&stream));
handle.set_stream(stream);
raft::handle_t handle{stream};

std::vector<int> h_labels(nRows);
int* d_labels = nullptr;
Expand Down
4 changes: 1 addition & 3 deletions cpp/examples/kmeans/kmeans_example.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -127,11 +127,9 @@ int main(int argc, char* argv[])
std::cout << "Run KMeans with k=" << params.n_clusters << ", max_iterations=" << params.max_iter
<< std::endl;

raft::handle_t handle;

cudaStream_t stream;
CUDA_RT_CALL(cudaStreamCreate(&stream));
handle.set_stream(stream);
raft::handle_t handle{stream};

// srcdata size n_samples * n_features
double* d_srcdata = nullptr;
Expand Down
5 changes: 1 addition & 4 deletions cpp/examples/symreg/symreg_example.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -198,12 +198,10 @@ int main(int argc, char* argv[])

/* ======================= Begin GPU memory allocation ======================= */
std::cout << "***************************************" << std::endl;
raft::handle_t handle;
std::shared_ptr<raft::mr::device::allocator> allocator(new raft::mr::device::default_allocator());

cudaStream_t stream;
CUDA_RT_CALL(cudaStreamCreate(&stream));
handle.set_stream(stream);
raft::handle_t handle{stream};

// Begin recording time
cudaEventRecord(start, stream);
Expand Down Expand Up @@ -342,6 +340,5 @@ int main(int argc, char* argv[])
raft::deallocate(d_finalprogs, stream);
CUDA_RT_CALL(cudaEventDestroy(start));
CUDA_RT_CALL(cudaEventDestroy(stop));
CUDA_RT_CALL(cudaStreamDestroy(stream));
return 0;
}
11 changes: 2 additions & 9 deletions cpp/include/cuml/cuml_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -53,9 +53,10 @@ const char* cumlGetErrorString(cumlError_t error);
* @brief Creates a cumlHandle_t
*
* @param[inout] handle pointer to the handle to create.
* @param[in] stream the stream to which cuML work should be ordered.
* @return CUML_SUCCESS on success, @todo: add more error codes
*/
cumlError_t cumlCreate(cumlHandle_t* handle);
cumlError_t cumlCreate(cumlHandle_t* handle, cudaStream_t stream);

/**
* @brief sets the stream to which all cuML work issued via the passed handle should be ordered.
Expand All @@ -64,14 +65,6 @@ cumlError_t cumlCreate(cumlHandle_t* handle);
* @param[in] stream the stream to which cuML work should be ordered.
* @return CUML_SUCCESS on success, @todo: add more error codes
*/
cumlError_t cumlSetStream(cumlHandle_t handle, cudaStream_t stream);
/**
* @brief gets the stream to which all cuML work issued via the passed handle should be ordered.
*
* @param[inout] handle handle to get the stream of.
* @param[out] stream pointer to the stream to which cuML work should be ordered.
* @return CUML_SUCCESS on success, @todo: add more error codes
*/
cumlError_t cumlGetStream(cumlHandle_t handle, cudaStream_t* stream);

/**
Expand Down
4 changes: 2 additions & 2 deletions cpp/src/common/cumlHandle.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,12 @@ namespace ML {

HandleMap handleMap;

std::pair<cumlHandle_t, cumlError_t> HandleMap::createAndInsertHandle()
std::pair<cumlHandle_t, cumlError_t> HandleMap::createAndInsertHandle(cudaStream_t stream)
{
cumlError_t status = CUML_SUCCESS;
cumlHandle_t chosen_handle;
try {
auto handle_ptr = new raft::handle_t();
auto handle_ptr = new raft::handle_t{stream};
bool inserted;
{
std::lock_guard<std::mutex> guard(_mapMutex);
Expand Down
3 changes: 2 additions & 1 deletion cpp/src/common/cumlHandle.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,11 @@ class HandleMap {
/**
* @brief Creates new handle object with associated handle ID and insert into map.
*
* @param[in] stream the stream to which cuML work should be ordered.
* @return std::pair with handle and error code. If error code is not CUML_SUCCESS
* the handle is INVALID_HANDLE.
*/
std::pair<cumlHandle_t, cumlError_t> createAndInsertHandle();
std::pair<cumlHandle_t, cumlError_t> createAndInsertHandle(cudaStream_t stream);

/**
* @brief Lookup pointer to handle object for handle ID in map.
Expand Down
26 changes: 2 additions & 24 deletions cpp/src/common/cuml_api.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -89,32 +89,10 @@ extern "C" const char* cumlGetErrorString(cumlError_t error)
}
}

extern "C" cumlError_t cumlCreate(cumlHandle_t* handle)
extern "C" cumlError_t cumlCreate(cumlHandle_t* handle, cudaStream_t stream)
{
cumlError_t status;
std::tie(*handle, status) = ML::handleMap.createAndInsertHandle();
return status;
}

extern "C" cumlError_t cumlSetStream(cumlHandle_t handle, cudaStream_t stream)
{
cumlError_t status;
raft::handle_t* handle_ptr;
std::tie(handle_ptr, status) = ML::handleMap.lookupHandlePointer(handle);
if (status == CUML_SUCCESS) {
try {
handle_ptr->set_stream(stream);
}
// TODO: Implement this
// catch (const MLCommon::Exception& e)
//{
// //log e.what()?
// status = e.getErrorCode();
//}
catch (...) {
status = CUML_ERROR_UNKNOWN;
}
}
std::tie(*handle, status) = ML::handleMap.createAndInsertHandle(stream);
return status;
}

Expand Down
7 changes: 3 additions & 4 deletions cpp/src/knn/knn.cu
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ void knn_classify(raft::handle_t& handle,
}

MLCommon::Selection::knn_classify(
out, knn_indices, y, n_index_rows, n_query_rows, k, uniq_labels, n_unique, stream);
handle, out, knn_indices, y, n_index_rows, n_query_rows, k, uniq_labels, n_unique);
}

void knn_regress(raft::handle_t& handle,
Expand All @@ -139,8 +139,7 @@ void knn_regress(raft::handle_t& handle,
size_t n_query_rows,
int k)
{
MLCommon::Selection::knn_regress(
out, knn_indices, y, n_index_rows, n_query_rows, k, handle.get_stream());
MLCommon::Selection::knn_regress(handle, out, knn_indices, y, n_index_rows, n_query_rows, k);
}

void knn_class_proba(raft::handle_t& handle,
Expand All @@ -164,7 +163,7 @@ void knn_class_proba(raft::handle_t& handle,
}

MLCommon::Selection::class_probs(
out, knn_indices, y, n_index_rows, n_query_rows, k, uniq_labels, n_unique, stream);
handle, out, knn_indices, y, n_index_rows, n_query_rows, k, uniq_labels, n_unique);
}

}; // END NAMESPACE ML
2 changes: 0 additions & 2 deletions cpp/src/knn/knn_api.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -71,8 +71,6 @@ cumlError_t knn_search(const cumlHandle_t handle,
raft::distance::DistanceType metric_distance_type =
static_cast<raft::distance::DistanceType>(metric_type);

std::vector<cudaStream_t> int_streams = handle_ptr->get_internal_streams();

std::vector<float*> input_vec(n_params);
std::vector<int> sizes_vec(n_params);
for (int i = 0; i < n_params; i++) {
Expand Down
27 changes: 8 additions & 19 deletions cpp/src/knn/knn_opg_common.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -910,15 +910,8 @@ void perform_local_operation(opg_knn_param<in_t, ind_t, dist_t, out_t>& params,
y[o] = reinterpret_cast<out_t*>(labels) + (o * n_labels);
}

MLCommon::Selection::knn_regress<float, 32, true>(outputs,
nullptr,
y,
n_labels,
batch_size,
params.k,
handle.get_stream(),
handle.get_internal_streams().data(),
handle.get_num_internal_streams());
MLCommon::Selection::knn_regress<float, 32, true>(
handle, outputs, nullptr, y, n_labels, batch_size, params.k);
}

/*!
Expand Down Expand Up @@ -952,30 +945,26 @@ void perform_local_operation(opg_knn_param<in_t, ind_t, dist_t, out_t>& params,

switch (params.knn_op) {
case knn_operation::classification:
MLCommon::Selection::knn_classify<32, true>(outputs,
MLCommon::Selection::knn_classify<32, true>(handle,
outputs,
nullptr,
y,
n_labels,
batch_size,
params.k,
*(params.uniq_labels),
*(params.n_unique),
handle.get_stream(),
handle.get_internal_streams().data(),
handle.get_num_internal_streams());
*(params.n_unique));
break;
case knn_operation::class_proba:
MLCommon::Selection::class_probs<32, true>(probas_with_offsets,
MLCommon::Selection::class_probs<32, true>(handle,
probas_with_offsets,
nullptr,
y,
n_labels,
batch_size,
params.k,
*(params.uniq_labels),
*(params.n_unique),
handle.get_stream(),
handle.get_internal_streams().data(),
handle.get_num_internal_streams());
*(params.n_unique));
break;
default: CUML_LOG_DEBUG("FAILURE!");
}
Expand Down
15 changes: 6 additions & 9 deletions cpp/src/randomforest/randomforest.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -137,10 +137,10 @@ class RandomForest {
n_sampled_rows = n_rows;
}
int n_streams = this->rf_params.n_streams;
ASSERT(n_streams <= handle.get_num_internal_streams(),
"rf_params.n_streams (=%d) should be <= raft::handle_t.n_streams (=%d)",
ASSERT(static_cast<std::size_t>(n_streams) <= handle.get_stream_pool_size(),
"rf_params.n_streams (=%d) should be <= raft::handle_t.n_streams (=%lu)",
n_streams,
handle.get_num_internal_streams());
handle.get_stream_pool_size());

// Select n_sampled_rows (with replacement) numbers from [0, n_rows) per tree.
// selected_rows: randomly generated IDs for bootstrapped samples (w/ replacement); a device
Expand All @@ -149,7 +149,7 @@ class RandomForest {
// constructor
std::deque<rmm::device_uvector<int>> selected_rows;
for (int i = 0; i < n_streams; i++) {
selected_rows.emplace_back(n_sampled_rows, handle.get_internal_stream(i));
selected_rows.emplace_back(n_sampled_rows, handle.get_stream_from_stream_pool(i));
}

auto global_quantiles =
Expand All @@ -159,7 +159,7 @@ class RandomForest {
#pragma omp parallel for num_threads(n_streams)
for (int i = 0; i < this->rf_params.n_trees; i++) {
int stream_id = omp_get_thread_num();
auto s = handle.get_internal_stream(stream_id);
auto s = handle.get_stream_from_stream_pool(i);

this->get_row_sample(i, n_rows, &selected_rows[stream_id], s);

Expand All @@ -186,10 +186,7 @@ class RandomForest {
i);
}
// Cleanup
for (int i = 0; i < n_streams; i++) {
auto s = handle.get_internal_stream(i);
CUDA_CHECK(cudaStreamSynchronize(s));
}
handle.sync_stream_pool();
CUDA_CHECK(cudaStreamSynchronize(handle.get_stream()));
ML::POP_RANGE();
}
Expand Down
16 changes: 8 additions & 8 deletions cpp/src/svm/linear.cu
Original file line number Diff line number Diff line change
Expand Up @@ -301,10 +301,10 @@ class WorkerHandle {
}

WorkerHandle(const raft::handle_t& h, int stream_id)
: handle_ptr(new raft::handle_t(h, stream_id, 0)),
: handle_ptr{new raft::handle_t{h.get_next_usable_stream(stream_id)}},
stream_id(stream_id),
handle(*handle_ptr),
stream(h.get_internal_stream(stream_id))
stream(h.get_next_usable_stream(stream_id))
{
}

Expand All @@ -322,7 +322,7 @@ LinearSVMModel<T> LinearSVMModel<T>::allocate(const raft::handle_t& handle,
const std::size_t nCols,
const std::size_t nClasses)
{
auto stream = handle.get_stream_view();
auto stream = handle.get_stream();
auto res = rmm::mr::get_current_device_resource();
const std::size_t coefRows = nCols + params.fit_intercept;
const std::size_t coefCols = nClasses <= 2 ? 1 : nClasses;
Expand All @@ -340,7 +340,7 @@ LinearSVMModel<T> LinearSVMModel<T>::allocate(const raft::handle_t& handle,
template <typename T>
void LinearSVMModel<T>::free(const raft::handle_t& handle, LinearSVMModel<T>& model)
{
auto stream = handle.get_stream_view();
auto stream = handle.get_stream();
auto res = rmm::mr::get_current_device_resource();
const std::size_t coefRows = model.coefRows;
const std::size_t coefCols = model.coefCols();
Expand Down Expand Up @@ -427,7 +427,7 @@ LinearSVMModel<T> LinearSVMModel<T>::fit(const raft::handle_t& handle,
// one-vs-rest logic goes over each class
std::vector<T> targets(coefCols);
std::vector<int> num_iters(coefCols);
const int n_streams = coefCols > 1 ? handle.get_num_internal_streams() : 1;
const int n_streams = coefCols > 1 ? handle.get_stream_pool_size() : 1;
bool parallel = n_streams > 1;
#pragma omp parallel for num_threads(n_streams) if (parallel)
for (int class_i = 0; class_i < coefCols; class_i++) {
Expand Down Expand Up @@ -496,7 +496,7 @@ LinearSVMModel<T> LinearSVMModel<T>::fit(const raft::handle_t& handle,
worker.stream,
(T*)sampleWeight);
}
if (parallel) handle.wait_on_internal_streams();
if (parallel) handle.sync_stream_pool();

if (coefCols > 1) {
raft::linalg::transpose(handle, w1, model.w, coefRows, coefCols, stream);
Expand All @@ -517,7 +517,7 @@ void LinearSVMModel<T>::predict(const raft::handle_t& handle,
const std::size_t nCols,
T* out)
{
auto stream = handle.get_stream_view();
auto stream = handle.get_stream();
const auto coefCols = model.coefCols();
if (isRegression(params.loss))
return predictLinear(
Expand Down Expand Up @@ -547,7 +547,7 @@ void LinearSVMModel<T>::predictProba(const raft::handle_t& handle,
ASSERT(model.probScale != nullptr,
"The model was not trained to output probabilities (model.probScale == nullptr).");

auto stream = handle.get_stream_view();
auto stream = handle.get_stream();
const auto coefCols = model.coefCols();
rmm::device_uvector<T> temp(nRows * coefCols, stream);

Expand Down
6 changes: 3 additions & 3 deletions cpp/src_prims/linalg/lstsq.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -255,15 +255,15 @@ void lstsqEig(const raft::handle_t& handle,
rmm::cuda_stream_view multAbStream = mainStream;
bool concurrent = false;
{
int sp_size = handle.get_num_internal_streams();
int sp_size = handle.get_stream_pool_size();
if (sp_size > 0) {
multAbStream = handle.get_internal_stream_view(0);
multAbStream = handle.get_stream_from_stream_pool(0);
// check if the two streams can run concurrently
if (!are_implicitly_synchronized(mainStream, multAbStream)) {
concurrent = true;
} else if (sp_size > 1) {
mainStream = multAbStream;
multAbStream = handle.get_internal_stream_view(1);
multAbStream = handle.get_stream_from_stream_pool(1);
concurrent = true;
}
}
Expand Down
Loading

0 comments on commit f16ff77

Please sign in to comment.