Interruptible execution (rapidsai#4463)
### Cooperative-style interruptible C++ threads.

This proposal attempts to make the cuML experience more responsive by providing an easier way to interrupt/cancel long-running cuML tasks. It replaces calls to `cudaStreamSynchronize` with `raft::interruptible::synchronize`, which serve as cancellation points in the algorithms. With a small extra hook on the Python side, Ctrl+C requests can now interrupt the execution (almost) immediately. At this moment, I have adapted just a few models as a proof of concept.

Example:
```python
import sklearn.datasets
import cuml.svm

X, y = sklearn.datasets.fetch_olivetti_faces(return_X_y=True)
model = cuml.svm.SVC()
print("Data loaded; fitting... (try Ctrl+C now)")
try:
    model.fit(X, y)
    print("Done! Score:", model.score(X, y))
except Exception as e:
    print("Canceled!")
    print(e)
```
#### Implementation details
rapidsai/raft#433

#### Adoption costs
From the changeset in this PR you can see that I introduce two types of changes:
  1. Change `cudaStreamSynchronize` to either `handle.sync_stream` or `raft::interruptible::synchronize`
  2. Wrap the Cython calls with [`cuda_interruptible`](https://github.com/rapidsai/raft/blob/36e8de5f73e9ec7e604b38a4290ac82bc35be4b7/python/raft/common/interruptible.pyx#L28) and `nogil`

Change (1) is straightforward and can mostly be automated.
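
Below is a minimal sketch of what change (1) looks like at a call site. The `read_scalar` helper is hypothetical and the include paths may differ between RAFT versions; only the two synchronization calls mirror the ones used throughout this changeset:

```cpp
// Hypothetical call site illustrating change (1): the raw cudaStreamSynchronize
// becomes a cancellation point by going through the handle or raft::interruptible.
#include <raft/cudart_utils.h>   // raft::update_host (path may vary by RAFT version)
#include <raft/handle.hpp>
#include <raft/interruptible.hpp>

template <typename T>
T read_scalar(const raft::handle_t& handle, const T* d_scalar)
{
  cudaStream_t stream = handle.get_stream();
  T h_scalar;
  raft::update_host(&h_scalar, d_scalar, 1, stream);
  // Before: RAFT_CUDA_TRY(cudaStreamSynchronize(stream));
  handle.sync_stream(stream);                   // when a handle is available
  // raft::interruptible::synchronize(stream);  // equivalent when only the stream is at hand
  return h_scalar;
}
```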

Change (2) is a bit more involved. You definitely have to wrap a C++ call with `interruptibleCpp` to make `Ctrl+C` work, but that is also rather simple. The tricky part is adding `nogil`, because you have to make sure there are no Python objects inside the `with nogil` block. However, `nogil` does not seem to be strictly required for the signal handler to successfully interrupt the C++ thread; it worked in my tests without `nogil` as well. Still, I chose to add `nogil` in the code where possible, because in theory it should reduce the interrupt latency and enable more multithreading.

#### Motivation
In general, this proposal makes executing threads (and thus algos/models) more controllable. The main use cases I see:

  1. Being able to Ctrl+C the running model using signal handlers.
  2. Stopping the thread programmatically, e.g. to create tests of the sort "if running for more than n seconds, stop and fail" (see the sketch after this list).
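
As a sketch of use case 2 (not part of this changeset): assuming the `raft::interruptible::cancel(std::thread::id)` entry point and the `raft::interrupted_exception` type described in rapidsai/raft#433, a test could run a model in a worker thread and cancel it once a time limit expires:

```cpp
// Hypothetical timeout guard. Only synchronize()/cancel() and interrupted_exception
// are assumed from the raft::interruptible proposal (rapidsai/raft#433).
#include <chrono>
#include <thread>

#include <raft/interruptible.hpp>

template <typename Fn>
bool finished_within(Fn&& fit, std::chrono::seconds limit)
{
  bool interrupted = false;
  std::thread worker([&] {
    try {
      fit();  // long-running cuML call; its synchronize() calls act as cancellation points
    } catch (const raft::interrupted_exception&) {
      interrupted = true;
    }
  });
  // A real test would wake up early if the worker finishes before the limit.
  std::this_thread::sleep_for(limit);
  raft::interruptible::cancel(worker.get_id());  // harmless if fit() has already returned
  worker.join();
  return !interrupted;
}
```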

Resolves rapidsai#4384

Authors:
  - Artem M. Chirkin (https://github.com/achirkin)

Approvers:
  - Corey J. Nolet (https://github.com/cjnolet)

URL: rapidsai#4463
achirkin authored Feb 8, 2022
1 parent 0e0bcac commit 39a4b24
Showing 105 changed files with 304 additions and 294 deletions.
cpp/bench/sg/arima_loglikelihood.cu (2 changes: 1 addition & 1 deletion)
@@ -68,7 +68,7 @@ class ArimaLoglikelihood : public TsFixtureRandom<DataT> {
    counting + this->params.batch_size,
    [=] __device__(int bid) { x[(bid + 1) * N - 1] = 1.0; });

-   RAFT_CUDA_TRY(cudaStreamSynchronize(stream));
+   handle.sync_stream(stream);

    // Benchmark loop
    this->loopOnState(state, [this]() {

cpp/bench/sg/fil.cu (2 changes: 1 addition & 1 deletion)
@@ -79,7 +79,7 @@ class FIL : public RegressionFixture<float> {
    auto* mPtr = &rf_model;
    size_t train_nrows = std::min(params.nrows, 1000);
    fit(*handle, mPtr, data.X.data(), train_nrows, params.ncols, data.y.data(), p_rest.rf);
-   RAFT_CUDA_TRY(cudaStreamSynchronize(stream));
+   handle->sync_stream(stream);

    ML::build_treelite_forest(&model, &rf_model, params.ncols);
    ML::fil::treelite_params_t tl_params = {

cpp/bench/sg/rf_classifier.cu (2 changes: 1 addition & 1 deletion)
@@ -68,7 +68,7 @@ class RFClassifier : public BlobsFixture<D> {
    this->data.y.data(),
    this->params.nclasses,
    rfParams);
-   RAFT_CUDA_TRY(cudaStreamSynchronize(this->stream));
+   this->handle->sync_stream(this->stream);
    });
  }

cpp/bench/sg/rf_regressor.cu (2 changes: 1 addition & 1 deletion)
@@ -67,7 +67,7 @@ class RFRegressor : public RegressionFixture<D> {
    this->params.ncols,
    this->data.y,
    rfParams);
-   RAFT_CUDA_TRY(cudaStreamSynchronize(this->stream));
+   handle->sync_stream(this->stream);
    });
  }

cpp/bench/sg/svc.cu (2 changes: 1 addition & 1 deletion)
@@ -68,7 +68,7 @@ class SVC : public BlobsFixture<D, D> {
    this->kernel,
    this->model,
    static_cast<D*>(nullptr));
-   RAFT_CUDA_TRY(cudaStreamSynchronize(this->stream));
+   this->handle->sync_stream(this->stream);
    ML::SVM::svmFreeBuffers(*this->handle, this->model);
    });
  }

cpp/bench/sg/svr.cu (4 changes: 2 additions & 2 deletions)
@@ -67,7 +67,7 @@ class SVR : public RegressionFixture<D> {
    this->svm_param,
    this->kernel,
    *(this->model));
-   RAFT_CUDA_TRY(cudaStreamSynchronize(this->stream));
+   this->handle->sync_stream(this->stream);
    ML::SVM::svmFreeBuffers(*this->handle, *(this->model));
    });
  }
@@ -130,4 +130,4 @@ ML_BENCH_REGISTER(SvrParams<double>, SVR<double>, "regression", getInputs<double

  } // namespace SVM
  } // namespace Bench
- } // end namespace ML
+ } // end namespace ML

cpp/src/dbscan/adjgraph/naive.cuh (4 changes: 2 additions & 2 deletions)
@@ -41,7 +41,7 @@ void launcher(const raft::handle_t& handle,
    ML::pinned_host_vector<Index_> host_ex_scan(batch_size);
    raft::update_host((bool*)host_adj.data(), data.adj, batch_size * N, stream);
    raft::update_host(host_vd.data(), data.vd, batch_size + 1, stream);
-   RAFT_CUDA_TRY(cudaStreamSynchronize(stream));
+   handle.sync_stream(stream);
    size_t adjgraph_size = size_t(host_vd[batch_size]);
    ML::pinned_host_vector<Index_> host_adj_graph(adjgraph_size);
    for (Index_ i = 0; i < batch_size; i++) {
@@ -62,4 +62,4 @@ void launcher(const raft::handle_t& handle,
  } // namespace Naive
  } // namespace AdjGraph
  } // namespace Dbscan
- } // namespace ML
+ } // namespace ML

cpp/src/dbscan/runner.cuh (2 changes: 1 addition & 1 deletion)
@@ -223,7 +223,7 @@ std::size_t run(const raft::handle_t& handle,
    raft::common::nvtx::pop_range();
  }
    raft::update_host(&curradjlen, vd + n_points, 1, stream);
-   RAFT_CUDA_TRY(cudaStreamSynchronize(stream));
+   handle.sync_stream(stream);

    CUML_LOG_DEBUG("--> Computing adjacency graph with %ld nnz.", (unsigned long)curradjlen);
    raft::common::nvtx::push_range("Trace::Dbscan::AdjGraph");

cpp/src/decisiontree/batched-levelalgo/builder.cuh (2 changes: 1 addition & 1 deletion)
@@ -410,7 +410,7 @@ struct Builder {
    RAFT_CUDA_TRY(cudaGetLastError());
    raft::common::nvtx::pop_range();
    raft::update_host(h_splits, splits, work_items.size(), builder_stream);
-   RAFT_CUDA_TRY(cudaStreamSynchronize(builder_stream));
+   handle.sync_stream(builder_stream);
    return std::make_tuple(h_splits, work_items.size());
  }

cpp/src/fil/fil.cu (2 changes: 1 addition & 1 deletion)
@@ -411,7 +411,7 @@ struct dense_forest : forest {
    dispatch_on_fil_template_params(opt_into_arch_dependent_shmem<dense_storage>(max_shm_),
    static_cast<predict_params>(class_ssp_));
    // copy must be finished before freeing the host data
-   RAFT_CUDA_TRY(cudaStreamSynchronize(h.get_stream()));
+   h.sync_stream();
    h_nodes_.clear();
    h_nodes_.shrink_to_fit();
  }

cpp/src/fil/treelite_import.cu (2 changes: 1 addition & 1 deletion)
@@ -636,7 +636,7 @@ struct tl2fil_t {
    handle, pforest, cat_sets_.accessor(), vector_leaf_, roots_.data(), nodes_.data(), &params_);
    // sync is necessary as nodes_ are used in init(),
    // but destructed at the end of this function
-   RAFT_CUDA_TRY(cudaStreamSynchronize(handle.get_stream()));
+   handle.sync_stream(handle.get_stream());
    if (tl_params_.pforest_shape_str) {
    *tl_params_.pforest_shape_str = sprintf_shape(model_, nodes_, roots_, cat_sets_);
    }

cpp/src/genetic/genetic.cu (4 changes: 2 additions & 2 deletions)
@@ -191,7 +191,7 @@ void parallel_evolve(const raft::handle_t& h,
    RAFT_CUDA_TRY(cudaPeekAtLastError());

    // Make sure tournaments have finished running before copying win indices
-   RAFT_CUDA_TRY(cudaStreamSynchronize(stream));
+   h.sync_stream(stream);

    // Perform host mutations

@@ -242,7 +242,7 @@ void parallel_evolve(const raft::handle_t& h,
  }

    // Make sure all copying is done
-   RAFT_CUDA_TRY(cudaStreamSynchronize(stream));
+   h.sync_stream(stream);

    // Update raw fitness for all programs
    set_batched_fitness(

cpp/src/glm/ols_mg.cu (4 changes: 2 additions & 2 deletions)
@@ -153,7 +153,7 @@ void fit_impl(raft::handle_t& handle,
    verbose);

    for (int i = 0; i < n_streams; i++) {
-   RAFT_CUDA_TRY(cudaStreamSynchronize(streams[i]));
+   handle.sync_stream(streams[i]);
    }

    for (int i = 0; i < n_streams; i++) {
@@ -227,7 +227,7 @@ void predict_impl(raft::handle_t& handle,
    handle, input_data, input_desc, coef, intercept, preds_data, streams, n_streams, verbose);

    for (int i = 0; i < n_streams; i++) {
-   RAFT_CUDA_TRY(cudaStreamSynchronize(streams[i]));
+   handle.sync_stream(streams[i]);
    }

    for (int i = 0; i < n_streams; i++) {

cpp/src/glm/qn/glm_base.cuh (2 changes: 1 addition & 1 deletion)
@@ -219,7 +219,7 @@ struct GLMWithData : GLMDims {
    objective->loss_grad(dev_scalar, G, W, *X, *y, *Z, stream);
    T loss_host;
    raft::update_host(&loss_host, dev_scalar, 1, stream);
-   RAFT_CUDA_TRY(cudaStreamSynchronize(stream));
+   raft::interruptible::synchronize(stream);
    return loss_host;
  }
  };

cpp/src/glm/qn/glm_regularizer.cuh (2 changes: 1 addition & 1 deletion)
@@ -81,7 +81,7 @@ struct RegularizedGLM : GLMDims {
    loss->loss_grad(lossVal.data, G, W, Xb, yb, Zb, stream, false);
    raft::update_host(&loss_host, lossVal.data, 1, stream);

-   RAFT_CUDA_TRY(cudaStreamSynchronize(stream));
+   raft::interruptible::synchronize(stream);

    lossVal.fill(loss_host + reg_host, stream);
  }

cpp/src/glm/qn/glm_softmax.cuh (2 changes: 1 addition & 1 deletion)
@@ -152,7 +152,7 @@ void launchLogsoftmax(
    T* loss_val, T* dldZ, const T* Z, const T* labels, int C, int N, cudaStream_t stream)
  {
    RAFT_CUDA_TRY(cudaMemsetAsync(loss_val, 0, sizeof(T), stream));
-   RAFT_CUDA_TRY(cudaStreamSynchronize(stream));
+   raft::interruptible::synchronize(stream);
    if (C <= 4) {
    dim3 bs(4, 64);
    dim3 gs(ceildiv(N, 64));

cpp/src/glm/qn/simple_mat/base.hpp (3 changes: 2 additions & 1 deletion)
@@ -1,5 +1,5 @@
  /*
-  * Copyright (c) 2018-2021, NVIDIA CORPORATION.
+  * Copyright (c) 2018-2022, NVIDIA CORPORATION.
   *
   * Licensed under the Apache License, Version 2.0 (the "License");
   * you may not use this file except in compliance with the License.
@@ -17,6 +17,7 @@

  #include <raft/cuda_utils.cuh>
  #include <raft/handle.hpp>
+ #include <raft/interruptible.hpp>

  namespace ML {

cpp/src/glm/qn/simple_mat/dense.hpp (11 changes: 6 additions & 5 deletions)
@@ -289,7 +289,8 @@ inline T dot(const SimpleVec<T>& u, const SimpleVec<T>& v, T* tmp_dev, cudaStrea
    raft::linalg::mapThenSumReduce(tmp_dev, u.len, f, stream, u.data, v.data);
    T tmp_host;
    raft::update_host(&tmp_host, tmp_dev, 1, stream);
-   cudaStreamSynchronize(stream);
+
+   raft::interruptible::synchronize(stream);
    return tmp_host;
  }

@@ -307,7 +308,7 @@ inline T nrmMax(const SimpleVec<T>& u, T* tmp_dev, cudaStream_t stream)
    raft::linalg::mapThenReduce(tmp_dev, u.len, T(0), f, r, stream, u.data);
    T tmp_host;
    raft::update_host(&tmp_host, tmp_dev, 1, stream);
-   cudaStreamSynchronize(stream);
+   raft::interruptible::synchronize(stream);
    return tmp_host;
  }

@@ -324,7 +325,7 @@ inline T nrm1(const SimpleVec<T>& u, T* tmp_dev, cudaStream_t stream)
    tmp_dev, u.data, u.len, 1, raft::linalg::L1Norm, true, stream, raft::Nop<T>());
    T tmp_host;
    raft::update_host(&tmp_host, tmp_dev, 1, stream);
-   cudaStreamSynchronize(stream);
+   raft::interruptible::synchronize(stream);
    return tmp_host;
  }

@@ -333,7 +334,7 @@ std::ostream& operator<<(std::ostream& os, const SimpleVec<T>& v)
  {
    std::vector<T> out(v.len);
    raft::update_host(&out[0], v.data, v.len, 0);
-   RAFT_CUDA_TRY(cudaStreamSynchronize(0));
+   raft::interruptible::synchronize(rmm::cuda_stream_view());
    int it = 0;
    for (; it < v.len - 1;) {
    os << out[it] << " ";
@@ -349,7 +350,7 @@ std::ostream& operator<<(std::ostream& os, const SimpleDenseMat<T>& mat)
    os << "ord=" << (mat.ord == COL_MAJOR ? "CM" : "RM") << "\n";
    std::vector<T> out(mat.len);
    raft::update_host(&out[0], mat.data, mat.len, rmm::cuda_stream_default);
-   RAFT_CUDA_TRY(cudaStreamSynchronize(0));
+   raft::interruptible::synchronize(rmm::cuda_stream_view());
    if (mat.ord == COL_MAJOR) {
    for (int r = 0; r < mat.m; r++) {
    int idx = r;

cpp/src/glm/qn/simple_mat/sparse.hpp (6 changes: 3 additions & 3 deletions)
@@ -144,7 +144,7 @@ struct SimpleSparseMat : SimpleMat<T> {
    &bufferSize,
    stream));

-   RAFT_CUDA_TRY(cudaStreamSynchronize(stream));
+   raft::interruptible::synchronize(stream);
    rmm::device_uvector<T> tmp(bufferSize, stream);

    RAFT_CUSPARSE_TRY(raft::sparse::cusparsespmm(handle.get_cusparse_handle(),
@@ -170,7 +170,7 @@ inline void check_csr(const SimpleSparseMat<T>& mat, cudaStream_t stream)
  {
    int row_ids_nnz;
    raft::update_host(&row_ids_nnz, &mat.row_ids[mat.m], 1, stream);
-   RAFT_CUDA_TRY(cudaStreamSynchronize(stream));
+   raft::interruptible::synchronize(stream);
    ASSERT(row_ids_nnz == mat.nnz,
    "SimpleSparseMat: the size of CSR row_ids array must be `m + 1`, and "
    "the last element must be equal nnz.");
@@ -188,7 +188,7 @@ std::ostream& operator<<(std::ostream& os, const SimpleSparseMat<T>& mat)
    raft::update_host(&values[0], mat.values, mat.nnz, rmm::cuda_stream_default);
    raft::update_host(&cols[0], mat.cols, mat.nnz, rmm::cuda_stream_default);
    raft::update_host(&row_ids[0], mat.row_ids, mat.m + 1, rmm::cuda_stream_default);
-   RAFT_CUDA_TRY(cudaStreamSynchronize(0));
+   raft::interruptible::synchronize(rmm::cuda_stream_view());

    int i, row_end = 0;
    for (int row = 0; row < mat.m; row++) {

cpp/src/glm/ridge_mg.cu (4 changes: 2 additions & 2 deletions)
@@ -267,7 +267,7 @@ void fit_impl(raft::handle_t& handle,
    verbose);

    for (int i = 0; i < n_streams; i++) {
-   RAFT_CUDA_TRY(cudaStreamSynchronize(streams[i]));
+   handle.sync_stream(streams[i]);
    }

    for (int i = 0; i < n_streams; i++) {
@@ -341,7 +341,7 @@ void predict_impl(raft::handle_t& handle,
    handle, input_data, input_desc, coef, intercept, preds_data, streams, n_streams, verbose);

    for (int i = 0; i < n_streams; i++) {
-   RAFT_CUDA_TRY(cudaStreamSynchronize(streams[i]));
+   handle.sync_stream(streams[i]);
    }

    for (int i = 0; i < n_streams; i++) {

cpp/src/hdbscan/detail/condense.cuh (2 changes: 1 addition & 1 deletion)
@@ -147,7 +147,7 @@ void build_condensed_hierarchy(const raft::handle_t& handle,
    n_elements_to_traverse =
    thrust::reduce(exec_policy, frontier.data(), frontier.data() + root + 1, 0);

-   RAFT_CUDA_TRY(cudaStreamSynchronize(stream));
+   handle.sync_stream(stream);
  }

    condensed_tree.condense(out_parent.data(), out_child.data(), out_lambda.data(), out_size.data());

cpp/src/hdbscan/detail/extract.cuh (4 changes: 2 additions & 2 deletions)
@@ -119,7 +119,7 @@ void do_labelling_on_host(const raft::handle_t& handle,
    raft::update_host(
    lambda_h.data(), condensed_tree.get_lambdas(), condensed_tree.get_n_edges(), stream);

-   RAFT_CUDA_TRY(cudaStreamSynchronize(stream));
+   handle.sync_stream(stream);

    auto parents = thrust::device_pointer_cast(condensed_tree.get_parents());
    auto thrust_policy = handle.get_thrust_policy();
@@ -230,7 +230,7 @@ value_idx extract_clusters(const raft::handle_t& handle,

    std::vector<int> is_cluster_h(is_cluster.size());
    raft::update_host(is_cluster_h.data(), is_cluster.data(), is_cluster_h.size(), stream);
-   RAFT_CUDA_TRY(cudaStreamSynchronize(stream));
+   handle.sync_stream(stream);

    std::set<value_idx> clusters;
    for (std::size_t i = 0; i < is_cluster_h.size(); i++) {

cpp/src/hdbscan/detail/select.cuh (4 changes: 2 additions & 2 deletions)
@@ -93,7 +93,7 @@ void perform_bfs(const raft::handle_t& handle,
    thrust::fill(thrust_policy, next_frontier.begin(), next_frontier.end(), 0);

    n_elements_to_traverse = thrust::reduce(thrust_policy, frontier, frontier + n_clusters, 0);
-   RAFT_CUDA_TRY(cudaStreamSynchronize(stream));
+   handle.sync_stream(stream);
  }
  }

@@ -200,7 +200,7 @@ void excess_of_mass(const raft::handle_t& handle,
    std::vector<value_idx> indptr_h(indptr.size(), 0);
    if (cluster_tree_edges > 0)
    raft::update_host(indptr_h.data(), indptr.data(), indptr.size(), stream);
-   RAFT_CUDA_TRY(cudaStreamSynchronize(stream));
+   handle.sync_stream(stream);

    // Loop through stabilities in "reverse topological order" (e.g. reverse sorted order)
    value_idx tree_top = allow_single_cluster ? 0 : 1;

cpp/src/kmeans/common.cuh (6 changes: 3 additions & 3 deletions)
@@ -233,7 +233,7 @@ Tensor<DataT, 2, IndexT> sampleCentroids(const raft::handle_t& handle,

    int nPtsSampledInRank = 0;
    raft::copy(&nPtsSampledInRank, nSelected.data(), nSelected.numElements(), stream);
-   RAFT_CUDA_TRY(cudaStreamSynchronize(stream));
+   handle.sync_stream(stream);

    int* rawPtr_isSampleCentroid = isSampleCentroid.data();
    thrust::for_each_n(handle.get_thrust_policy(),
@@ -769,7 +769,7 @@ void kmeansPlusPlus(const raft::handle_t& handle,
    // Choose 'n_trials' centroid candidates from X with probability proportional to the squared
    // distance to the nearest existing cluster
    raft::copy(h_wt.data(), minClusterDistance.data(), minClusterDistance.numElements(), stream);
-   RAFT_CUDA_TRY(cudaStreamSynchronize(stream));
+   handle.sync_stream(stream);

    // Note - n_trials is relative small here, we don't need MLCommon::gather call
    std::discrete_distribution<> d(h_wt.begin(), h_wt.end());
@@ -880,7 +880,7 @@ void checkWeights(const raft::handle_t& handle,

    DataT wt_sum = 0;
    raft::copy(&wt_sum, wt_aggr.data(), 1, stream);
-   RAFT_CUDA_TRY(cudaStreamSynchronize(stream));
+   handle.sync_stream(stream);

    if (wt_sum != n_samples) {
    LOG(handle,

cpp/src/kmeans/kmeans_mg_impl.cuh (4 changes: 2 additions & 2 deletions)
@@ -426,7 +426,7 @@ void checkWeights(const raft::handle_t& handle,
    raft::comms::op_t::SUM,
    stream);
    DataT wt_sum = wt_aggr.value(stream);
-   RAFT_CUDA_TRY(cudaStreamSynchronize(stream));
+   handle.sync_stream(stream);

    if (wt_sum != n_samples) {
    LOG(handle,
@@ -662,7 +662,7 @@ void fit(const raft::handle_t& handle,
    priorClusteringCost = curClusteringCost;
  }

-   RAFT_CUDA_TRY(cudaStreamSynchronize(stream));
+   handle.sync_stream(stream);
    if (sqrdNormError < params.tol) done = true;

    if (done) {

cpp/src/kmeans/sg_impl.cuh (6 changes: 3 additions & 3 deletions)
@@ -232,7 +232,7 @@ void fit(const raft::handle_t& handle,
    DataT curClusteringCost = 0;
    raft::copy(&curClusteringCost, &(clusterCostD.data()->value), 1, stream);

-   RAFT_CUDA_TRY(cudaStreamSynchronize(stream));
+   handle.sync_stream(stream);
    ASSERT(curClusteringCost != (DataT)0.0,
    "Too few points and centriods being found is getting 0 cost from "
    "centers");
@@ -244,7 +244,7 @@ void fit(const raft::handle_t& handle,
    priorClusteringCost = curClusteringCost;
  }

-   RAFT_CUDA_TRY(cudaStreamSynchronize(stream));
+   handle.sync_stream(stream);
    if (sqrdNormError < params.tol) done = true;

    if (done) {
@@ -425,7 +425,7 @@ void initScalableKMeansPlusPlus(const raft::handle_t& handle,
    // <<< End of Step-2 >>>

    // Scalable kmeans++ paper claims 8 rounds is sufficient
-   RAFT_CUDA_TRY(cudaStreamSynchronize(stream));
+   handle.sync_stream(stream);
    int niter = std::min(8, (int)ceil(log(psi)));
    LOG(handle, "KMeans||: psi = %g, log(psi) = %g, niter = %d ", psi, log(psi), niter);
