Skip to content

Commit

Permalink
Fix for KMeans raw pointers API (#758)
Browse files Browse the repository at this point in the history
Authors:
  - Micka (https://github.com/lowener)

Approvers:
  - Corey J. Nolet (https://github.com/cjnolet)

URL: #758
  • Loading branch information
lowener authored Jul 27, 2022
1 parent 78fd5a7 commit 2e575ef
Showing 1 changed file with 19 additions and 18 deletions.
37 changes: 19 additions & 18 deletions cpp/include/raft/cluster/detail/kmeans.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -901,14 +901,15 @@ void kmeans_fit(handle_t const& handle,
DataT& inertia,
IndexT& n_iter)
{
auto XView = raft::make_device_matrix_view<DataT, IndexT>(X, n_samples, n_features);
auto XView = raft::make_device_matrix_view<const DataT, IndexT>(X, n_samples, n_features);
auto centroidsView =
raft::make_device_matrix_view<DataT, IndexT>(centroids, params.n_clusters, n_features);
std::optional<raft::device_vector_view<const DataT>> sample_weightView = std::nullopt;
if (sample_weight)
sample_weightView = raft::make_device_vector_view<DataT, IndexT>(sample_weight, n_samples);
auto inertiaView = raft::make_host_scalar_view<DataT, IndexT>(&inertia);
auto n_iterView = raft::make_host_scalar_view<DataT, IndexT>(&n_iter);
sample_weightView =
raft::make_device_vector_view<const DataT, IndexT>(sample_weight, n_samples);
auto inertiaView = raft::make_host_scalar_view(&inertia);
auto n_iterView = raft::make_host_scalar_view(&n_iter);

detail::kmeans_fit<DataT, IndexT>(
handle, params, XView, sample_weightView, centroidsView, inertiaView, n_iterView);
Expand Down Expand Up @@ -1034,14 +1035,14 @@ void kmeans_predict(handle_t const& handle,
bool normalize_weight,
DataT& inertia)
{
auto XView = raft::make_device_matrix_view<DataT, IndexT>(X, n_samples, n_features);
auto XView = raft::make_device_matrix_view<const DataT, IndexT>(X, n_samples, n_features);
auto centroidsView =
raft::make_device_matrix_view<DataT, IndexT>(centroids, params.n_clusters, n_features);
raft::make_device_matrix_view<const DataT, IndexT>(centroids, params.n_clusters, n_features);
std::optional<raft::device_vector_view<const DataT, IndexT>> sample_weightView{std::nullopt};
if (sample_weight)
sample_weightView.emplace(
raft::make_device_vector_view<DataT, IndexT>(sample_weight, n_samples));
auto labelsView = raft::make_device_vector_view<DataT, IndexT>(labels, n_samples);
raft::make_device_vector_view<const DataT, IndexT>(sample_weight, n_samples));
auto labelsView = raft::make_device_vector_view<IndexT, IndexT>(labels, n_samples);
auto inertiaView = raft::make_host_scalar_view(&inertia);

detail::kmeans_predict<DataT, IndexT>(handle,
Expand Down Expand Up @@ -1092,18 +1093,18 @@ void kmeans_fit_predict(handle_t const& handle,
DataT& inertia,
IndexT& n_iter)
{
auto XView = raft::make_device_matrix_view<DataT, IndexT>(X, n_samples, n_features);
auto XView = raft::make_device_matrix_view<const DataT, IndexT>(X, n_samples, n_features);
std::optional<raft::device_vector_view<const DataT, IndexT>> sample_weightView{std::nullopt};
if (sample_weight)
sample_weightView.emplace(
raft::make_device_vector_view<DataT, IndexT>(sample_weight, n_samples));
raft::make_device_vector_view<const DataT, IndexT>(sample_weight, n_samples));
std::optional<raft::device_matrix_view<DataT, IndexT>> centroidsView{std::nullopt};
if (centroids)
centroidsView.emplace(
raft::make_device_matrix_view<DataT, IndexT>(centroids, params.n_clusters, n_features));
auto labelsView = raft::make_device_vector_view<DataT, IndexT>(labels, n_samples);
auto inertiaView = raft::make_host_scalar_view<DataT, IndexT>(&inertia);
auto n_iterView = raft::make_host_scalar_view<DataT, IndexT>(&n_iter);
auto labelsView = raft::make_device_vector_view<IndexT, IndexT>(labels, n_samples);
auto inertiaView = raft::make_host_scalar_view(&inertia);
auto n_iterView = raft::make_host_scalar_view(&n_iter);

detail::kmeans_fit_predict<DataT, IndexT>(
handle, params, XView, sample_weightView, centroidsView, labelsView, inertiaView, n_iterView);
Expand Down Expand Up @@ -1146,12 +1147,12 @@ void kmeans_transform(const raft::handle_t& handle,

// datasetView [ns x n_features] - view representing the current batch of
// input dataset
auto datasetView =
raft::make_device_matrix_view<DataT, IndexT>(X.data() + n_features * dIdx, ns, n_features);
auto datasetView = raft::make_device_matrix_view<const DataT, IndexT>(
X.data_handle() + n_features * dIdx, ns, n_features);

// pairwiseDistanceView [ns x n_clusters]
auto pairwiseDistanceView = raft::make_device_matrix_view<DataT, IndexT>(
X_new.data() + n_clusters * dIdx, ns, n_clusters);
X_new.data_handle() + n_clusters * dIdx, ns, n_clusters);

// calculate pairwise distance between cluster centroids and current batch
// of input dataset
Expand All @@ -1169,9 +1170,9 @@ void kmeans_transform(const raft::handle_t& handle,
IndexT n_features,
DataT* X_new)
{
auto XView = raft::make_device_matrix_view<DataT, IndexT>(X, n_samples, n_features);
auto XView = raft::make_device_matrix_view<const DataT, IndexT>(X, n_samples, n_features);
auto centroidsView =
raft::make_device_matrix_view<DataT, IndexT>(centroids, params.n_clusters, n_features);
raft::make_device_matrix_view<const DataT, IndexT>(centroids, params.n_clusters, n_features);
auto X_newView = raft::make_device_matrix_view<DataT, IndexT>(X_new, n_samples, n_features);

detail::kmeans_transform<DataT, IndexT>(handle, params, XView, centroidsView, X_newView);
Expand Down

0 comments on commit 2e575ef

Please sign in to comment.