Skip to content

Commit

Permalink
Use KMeans from Raft (#4713)
Browse files Browse the repository at this point in the history
Authors:
  - Micka (https://github.com/lowener)
  - Corey J. Nolet (https://github.com/cjnolet)

Approvers:
  - Corey J. Nolet (https://github.com/cjnolet)

URL: #4713
  • Loading branch information
lowener authored Nov 8, 2022
1 parent 0f7e801 commit b4967bf
Show file tree
Hide file tree
Showing 19 changed files with 733 additions and 2,375 deletions.
3 changes: 3 additions & 0 deletions build.sh
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,9 @@ if hasArg clean; then
CLEAN=1
fi

if hasArg cpp-mgtests; then
BUILD_CUML_MG_TESTS=ON
fi

# Long arguments
LONG_ARGUMENT_LIST=(
Expand Down
9 changes: 6 additions & 3 deletions cpp/bench/sg/kmeans.cu
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@
#include "benchmark.cuh"
#include <cuml/cluster/kmeans.hpp>
#include <cuml/common/logger.hpp>
#include <raft/cluster/specializations.cuh>
#include <raft/distance/distance_type.hpp>
#include <raft/random/rng_state.hpp>
#include <utility>

namespace ML {
Expand Down Expand Up @@ -86,9 +89,9 @@ std::vector<Params> getInputs()
p.kmeans.init = ML::kmeans::KMeansParams::InitMethod(0);
p.kmeans.max_iter = 300;
p.kmeans.tol = 1e-4;
p.kmeans.verbosity = CUML_LEVEL_INFO;
p.kmeans.seed = int(p.blobs.seed);
p.kmeans.metric = 0; // L2
p.kmeans.verbosity = RAFT_LEVEL_INFO;
p.kmeans.metric = raft::distance::DistanceType::L2Expanded;
p.kmeans.rng_state = raft::random::RngState(p.blobs.seed);
p.kmeans.inertia_check = true;
std::vector<std::pair<int, int>> rowcols = {
{160000, 64},
Expand Down
5 changes: 2 additions & 3 deletions cpp/examples/kmeans/kmeans_example.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,8 @@

#include <cuda_runtime.h>

#include <raft/core/handle.hpp>

#include <cuml/cluster/kmeans.hpp>
#include <raft/core/handle.hpp>

#ifndef CUDA_RT_CALL
#define CUDA_RT_CALL(call) \
Expand Down Expand Up @@ -112,7 +111,7 @@ int main(int argc, char* argv[])
params.max_iter = 300;
params.tol = 0.05;
}
params.metric = 1;
params.metric = raft::distance::DistanceType::L2SqrtExpanded;
params.init = ML::kmeans::KMeansParams::InitMethod::Random;

// Inputs copied from kmeans_test.cu
Expand Down
56 changes: 3 additions & 53 deletions cpp/include/cuml/cluster/kmeans.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2019-2021, NVIDIA CORPORATION.
* Copyright (c) 2019-2022, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -17,6 +17,7 @@
#pragma once

#include <cuml/common/log_levels.hpp>
#include <raft/cluster/kmeans_types.hpp>

namespace raft {
class handle_t;
Expand All @@ -26,54 +27,7 @@ namespace ML {

namespace kmeans {

struct KMeansParams {
enum InitMethod { KMeansPlusPlus, Random, Array };

// The number of clusters to form as well as the number of centroids to
// generate (default:8).
int n_clusters = 8;

/*
* Method for initialization, defaults to k-means++:
* - InitMethod::KMeansPlusPlus (k-means++): Use scalable k-means++ algorithm
* to select the initial cluster centers.
* - InitMethod::Random (random): Choose 'n_clusters' observations (rows) at
* random from the input data for the initial centroids.
* - InitMethod::Array (ndarray): Use 'centroids' as initial cluster centers.
*/
InitMethod init = KMeansPlusPlus;

// Maximum number of iterations of the k-means algorithm for a single run.
int max_iter = 300;

// Relative tolerance with regards to inertia to declare convergence.
double tol = 1e-4;

// verbosity level.
int verbosity = CUML_LEVEL_INFO;

// Seed to the random number generator.
int seed = 0;

// Metric to use for distance computation. Any metric from
// raft::distance::DistanceType can be used
int metric = 0;

// Number of instance k-means algorithm will be run with different seeds.
int n_init = 1;

// Oversampling factor for use in the k-means|| algorithm.
double oversampling_factor = 2.0;

// batch_samples and batch_centroids are used to tile 1NN computation which is
// useful to optimize/control the memory footprint
// Default tile is [batch_samples x n_clusters] i.e. when batch_centroids is 0
// then don't tile the centroids
int batch_samples = 1 << 15;
int batch_centroids = 0; // if 0 then batch_centroids = n_clusters

bool inertia_check = false;
};
using KMeansParams = raft::cluster::KMeansParams;

/**
* @brief Compute k-means clustering and predicts cluster index for each sample
Expand Down Expand Up @@ -222,8 +176,6 @@ void predict(const raft::handle_t& handle,
* @param[in] n_features Number of features or the dimensions of each
* sample in 'X' (it should be same as the dimension for each cluster centers in
* 'centroids').
* @param[in] metric Metric to use for distance computation. Any
* metric from raft::distance::DistanceType can be used
* @param[out] X_new X transformed in the new space..
*/
void transform(const raft::handle_t& handle,
Expand All @@ -232,7 +184,6 @@ void transform(const raft::handle_t& handle,
const float* X,
int n_samples,
int n_features,
int metric,
float* X_new);

void transform(const raft::handle_t& handle,
Expand All @@ -241,7 +192,6 @@ void transform(const raft::handle_t& handle,
const double* X,
int n_samples,
int n_features,
int metric,
double* X_new);

}; // end namespace kmeans
Expand Down
5 changes: 3 additions & 2 deletions cpp/include/cuml/cluster/kmeans_mg.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2019-2021, NVIDIA CORPORATION.
* Copyright (c) 2019-2022, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -16,13 +16,14 @@

#pragma once

#include <cuml/cluster/kmeans.hpp>

namespace raft {
class handle_t;
}

namespace ML {
namespace kmeans {
struct KMeansParams;
namespace opg {

/**
Expand Down
185 changes: 0 additions & 185 deletions cpp/src/common/tensor.hpp

This file was deleted.

Loading

0 comments on commit b4967bf

Please sign in to comment.