Skip to content

Commit

Permalink
Revert back changes in kmeans balanced.
Browse files Browse the repository at this point in the history
  • Loading branch information
abc99lr committed May 8, 2024
1 parent e26ca70 commit bfd0677
Showing 1 changed file with 2 additions and 3 deletions.
5 changes: 2 additions & 3 deletions cpp/include/raft/cluster/detail/kmeans_balanced.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
#include <raft/core/operators.hpp>
#include <raft/core/resource/cuda_stream.hpp>
#include <raft/core/resource/device_memory_resource.hpp>
#include <raft/core/resource/thrust_nosync_policy.hpp>
#include <raft/core/resource/thrust_policy.hpp>
#include <raft/distance/distance.cuh>
#include <raft/distance/distance_types.hpp>
Expand Down Expand Up @@ -103,7 +102,7 @@ inline std::enable_if_t<std::is_floating_point_v<MathT>> predict_core(
auto minClusterAndDistance = raft::make_device_mdarray<raft::KeyValuePair<IdxT, MathT>, IdxT>(
handle, mr, make_extents<IdxT>(n_rows));
raft::KeyValuePair<IdxT, MathT> initial_value(0, std::numeric_limits<MathT>::max());
thrust::fill(resource::get_thrust_nosync_policy(handle),
thrust::fill(resource::get_thrust_policy(handle),
minClusterAndDistance.data_handle(),
minClusterAndDistance.data_handle() + minClusterAndDistance.size(),
initial_value);
Expand All @@ -129,7 +128,7 @@ inline std::enable_if_t<std::is_floating_point_v<MathT>> predict_core(

// todo(lsugy): use KVP + iterator in caller.
// Copy keys to output labels
thrust::transform(resource::get_thrust_nosync_policy(handle),
thrust::transform(resource::get_thrust_policy(handle),
minClusterAndDistance.data_handle(),
minClusterAndDistance.data_handle() + n_rows,
labels,
Expand Down

0 comments on commit bfd0677

Please sign in to comment.