Skip to content

Commit

Permalink
Moving kmeans from cuml to Raft (#605)
Browse files Browse the repository at this point in the history
This PR replaces RAFT's current KMeans implementation with cuML's implementation.
Closes #28.
It uses the new `device_*_view` types for the API.

Authors:
  - Micka (https://github.com/lowener)
  - Corey J. Nolet (https://github.com/cjnolet)

Approvers:
  - Corey J. Nolet (https://github.com/cjnolet)

URL: #605
  • Loading branch information
lowener authored Jul 26, 2022
1 parent af4f35c commit 563bb3f
Show file tree
Hide file tree
Showing 16 changed files with 3,155 additions and 1,021 deletions.
1,940 changes: 1,041 additions & 899 deletions cpp/include/raft/cluster/detail/kmeans.cuh

Large diffs are not rendered by default.

683 changes: 683 additions & 0 deletions cpp/include/raft/cluster/detail/kmeans_common.cuh

Large diffs are not rendered by default.

503 changes: 465 additions & 38 deletions cpp/include/raft/cluster/kmeans.cuh

Large diffs are not rendered by default.

73 changes: 73 additions & 0 deletions cpp/include/raft/cluster/kmeans_params.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
/*
* Copyright (c) 2022, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include <raft/core/logger.hpp>
#include <raft/distance/distance_type.hpp>
#include <raft/random/rng_state.hpp>

namespace raft {
namespace cluster {

struct KMeansParams {
enum InitMethod { KMeansPlusPlus, Random, Array };

// The number of clusters to form as well as the number of centroids to
// generate (default:8).
int n_clusters = 8;

/*
* Method for initialization, defaults to k-means++:
* - InitMethod::KMeansPlusPlus (k-means++): Use scalable k-means++ algorithm
* to select the initial cluster centers.
* - InitMethod::Random (random): Choose 'n_clusters' observations (rows) at
* random from the input data for the initial centroids.
* - InitMethod::Array (ndarray): Use 'centroids' as initial cluster centers.
*/
InitMethod init = KMeansPlusPlus;

// Maximum number of iterations of the k-means algorithm for a single run.
int max_iter = 300;

// Relative tolerance with regards to inertia to declare convergence.
double tol = 1e-4;

// verbosity level.
int verbosity = RAFT_LEVEL_INFO;

// Seed to the random number generator.
raft::random::RngState rng_state =
raft::random::RngState(0, raft::random::GeneratorType::GenPhilox);

// Metric to use for distance computation.
raft::distance::DistanceType metric = raft::distance::DistanceType::L2Expanded;

// Number of instance k-means algorithm will be run with different seeds.
int n_init = 1;

// Oversampling factor for use in the k-means|| algorithm.
double oversampling_factor = 2.0;

// batch_samples and batch_centroids are used to tile 1NN computation which is
// useful to optimize/control the memory footprint
// Default tile is [batch_samples x n_clusters] i.e. when batch_centroids is 0
// then don't tile the centroids
int batch_samples = 1 << 15;
int batch_centroids = 0; // if 0 then batch_centroids = n_clusters

bool inertia_check = false;
};
} // namespace cluster
} // namespace raft
2 changes: 1 addition & 1 deletion cpp/include/raft/comms/detail/ucp_helper.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ class ucp_request {
};

// by default, match the whole tag
static const ucp_tag_t default_tag_mask = -1;
static const ucp_tag_t default_tag_mask = (ucp_tag_t)-1;

/**
* @brief Asynchronous send callback sets request to completed
Expand Down
5 changes: 5 additions & 0 deletions cpp/include/raft/core/logger.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@
*/
#pragma once

#ifndef __RAFT_RT_LOGGER
#define __RAFT_RT_LOGGER

#include <stdarg.h>

#include <algorithm>
Expand Down Expand Up @@ -315,3 +318,5 @@ class logger {
#define RAFT_LOG_CRITICAL(fmt, ...) void(0)
#endif
/** @} */

#endif
Loading

0 comments on commit 563bb3f

Please sign in to comment.