Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix for KMeans raw pointers API #758

Merged
merged 4 commits into from
Jul 27, 2022
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 19 additions & 18 deletions cpp/include/raft/cluster/detail/kmeans.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -901,14 +901,15 @@ void kmeans_fit(handle_t const& handle,
DataT& inertia,
IndexT& n_iter)
{
auto XView = raft::make_device_matrix_view<DataT, IndexT>(X, n_samples, n_features);
auto XView = raft::make_device_matrix_view<const DataT, IndexT>(X, n_samples, n_features);
auto centroidsView =
raft::make_device_matrix_view<DataT, IndexT>(centroids, params.n_clusters, n_features);
std::optional<raft::device_vector_view<const DataT>> sample_weightView = std::nullopt;
if (sample_weight)
sample_weightView = raft::make_device_vector_view<DataT, IndexT>(sample_weight, n_samples);
auto inertiaView = raft::make_host_scalar_view<DataT, IndexT>(&inertia);
auto n_iterView = raft::make_host_scalar_view<DataT, IndexT>(&n_iter);
sample_weightView =
raft::make_device_vector_view<const DataT, IndexT>(sample_weight, n_samples);
auto inertiaView = raft::make_host_scalar_view(&inertia);
auto n_iterView = raft::make_host_scalar_view(&n_iter);

detail::kmeans_fit<DataT, IndexT>(
handle, params, XView, sample_weightView, centroidsView, inertiaView, n_iterView);
Expand Down Expand Up @@ -1034,14 +1035,14 @@ void kmeans_predict(handle_t const& handle,
bool normalize_weight,
DataT& inertia)
{
auto XView = raft::make_device_matrix_view<DataT, IndexT>(X, n_samples, n_features);
auto XView = raft::make_device_matrix_view<const DataT, IndexT>(X, n_samples, n_features);
auto centroidsView =
raft::make_device_matrix_view<DataT, IndexT>(centroids, params.n_clusters, n_features);
raft::make_device_matrix_view<const DataT, IndexT>(centroids, params.n_clusters, n_features);
std::optional<raft::device_vector_view<const DataT, IndexT>> sample_weightView{std::nullopt};
if (sample_weight)
sample_weightView.emplace(
raft::make_device_vector_view<DataT, IndexT>(sample_weight, n_samples));
auto labelsView = raft::make_device_vector_view<DataT, IndexT>(labels, n_samples);
raft::make_device_vector_view<const DataT, IndexT>(sample_weight, n_samples));
auto labelsView = raft::make_device_vector_view<IndexT, IndexT>(labels, n_samples);
auto inertiaView = raft::make_host_scalar_view(&inertia);

detail::kmeans_predict<DataT, IndexT>(handle,
Expand Down Expand Up @@ -1092,18 +1093,18 @@ void kmeans_fit_predict(handle_t const& handle,
DataT& inertia,
IndexT& n_iter)
{
auto XView = raft::make_device_matrix_view<DataT, IndexT>(X, n_samples, n_features);
auto XView = raft::make_device_matrix_view<const DataT, IndexT>(X, n_samples, n_features);
std::optional<raft::device_vector_view<const DataT, IndexT>> sample_weightView{std::nullopt};
if (sample_weight)
sample_weightView.emplace(
raft::make_device_vector_view<DataT, IndexT>(sample_weight, n_samples));
raft::make_device_vector_view<const DataT, IndexT>(sample_weight, n_samples));
std::optional<raft::device_matrix_view<DataT, IndexT>> centroidsView{std::nullopt};
if (centroids)
centroidsView.emplace(
raft::make_device_matrix_view<DataT, IndexT>(centroids, params.n_clusters, n_features));
auto labelsView = raft::make_device_vector_view<DataT, IndexT>(labels, n_samples);
auto inertiaView = raft::make_host_scalar_view<DataT, IndexT>(&inertia);
auto n_iterView = raft::make_host_scalar_view<DataT, IndexT>(&n_iter);
auto labelsView = raft::make_device_vector_view<IndexT, IndexT>(labels, n_samples);
auto inertiaView = raft::make_host_scalar_view(&inertia);
auto n_iterView = raft::make_host_scalar_view(&n_iter);

detail::kmeans_fit_predict<DataT, IndexT>(
handle, params, XView, sample_weightView, centroidsView, labelsView, inertiaView, n_iterView);
Expand Down Expand Up @@ -1146,12 +1147,12 @@ void kmeans_transform(const raft::handle_t& handle,

// datasetView [ns x n_features] - view representing the current batch of
// input dataset
auto datasetView =
raft::make_device_matrix_view<DataT, IndexT>(X.data() + n_features * dIdx, ns, n_features);
auto datasetView = raft::make_device_matrix_view<const DataT, IndexT>(
X.data_handle() + n_features * dIdx, ns, n_features);

// pairwiseDistanceView [ns x n_clusters]
auto pairwiseDistanceView = raft::make_device_matrix_view<DataT, IndexT>(
X_new.data() + n_clusters * dIdx, ns, n_clusters);
X_new.data_handle() + n_clusters * dIdx, ns, n_clusters);

// calculate pairwise distance between cluster centroids and current batch
// of input dataset
Expand All @@ -1169,9 +1170,9 @@ void kmeans_transform(const raft::handle_t& handle,
IndexT n_features,
DataT* X_new)
{
auto XView = raft::make_device_matrix_view<DataT, IndexT>(X, n_samples, n_features);
auto XView = raft::make_device_matrix_view<const DataT, IndexT>(X, n_samples, n_features);
auto centroidsView =
raft::make_device_matrix_view<DataT, IndexT>(centroids, params.n_clusters, n_features);
raft::make_device_matrix_view<const DataT, IndexT>(centroids, params.n_clusters, n_features);
auto X_newView = raft::make_device_matrix_view<DataT, IndexT>(X_new, n_samples, n_features);

detail::kmeans_transform<DataT, IndexT>(handle, params, XView, centroidsView, X_newView);
Expand Down