From 7c5f49f50e14e964c4adb0af0b540aaec9ef3cf6 Mon Sep 17 00:00:00 2001 From: Onur Yilmaz Date: Fri, 19 Oct 2018 13:06:01 -0400 Subject: [PATCH 01/10] kmeans added. --- cuML/CMakeLists.txt | 1 + cuML/src/kmeans/kmeans.cu | 1427 +++++++++++++++++++++++++ cuML/src/kmeans/kmeans.h | 91 ++ cuML/src/kmeans/kmeans_c.h | 76 ++ cuML/src/kmeans/kmeans_centroids.h | 187 ++++ cuML/src/kmeans/kmeans_general.h | 26 + cuML/src/kmeans/kmeans_impl.h | 256 +++++ cuML/src/kmeans/kmeans_labels.h | 906 ++++++++++++++++ cuML/src/kmeans/logger.h | 51 + cuML/src/kmeans/timer.h | 18 + cuML/src/kmeans/utils.h | 59 + cuML/test/kmeans_test.cu | 109 ++ python/cuML/cuml.pyx | 1 + python/cuML/kmeans/c_kmeans.pxd | 59 + python/cuML/kmeans/kmeans_test.py | 55 + python/cuML/kmeans/kmeans_wrapper.pyx | 230 ++++ setup.py | 2 +- 17 files changed, 3553 insertions(+), 1 deletion(-) create mode 100644 cuML/src/kmeans/kmeans.cu create mode 100644 cuML/src/kmeans/kmeans.h create mode 100644 cuML/src/kmeans/kmeans_c.h create mode 100644 cuML/src/kmeans/kmeans_centroids.h create mode 100644 cuML/src/kmeans/kmeans_general.h create mode 100644 cuML/src/kmeans/kmeans_impl.h create mode 100644 cuML/src/kmeans/kmeans_labels.h create mode 100644 cuML/src/kmeans/logger.h create mode 100644 cuML/src/kmeans/timer.h create mode 100644 cuML/src/kmeans/utils.h create mode 100644 cuML/test/kmeans_test.cu create mode 100644 python/cuML/kmeans/c_kmeans.pxd create mode 100644 python/cuML/kmeans/kmeans_test.py create mode 100644 python/cuML/kmeans/kmeans_wrapper.pyx diff --git a/cuML/CMakeLists.txt b/cuML/CMakeLists.txt index f64ff11657..e57b7ab600 100644 --- a/cuML/CMakeLists.txt +++ b/cuML/CMakeLists.txt @@ -94,6 +94,7 @@ add_executable(ml_test test/pca_test.cu test/tsvd_test.cu test/dbscan_test.cu + test/kmeans_test.cu ) target_link_libraries(ml_test diff --git a/cuML/src/kmeans/kmeans.cu b/cuML/src/kmeans/kmeans.cu new file mode 100644 index 0000000000..5eac1c2eeb --- /dev/null +++ b/cuML/src/kmeans/kmeans.cu @@ -0,0 +1,1427 @@ +/*! + * Copyright 2017-2018 H2O.ai, Inc. 
+ * License Apache License Version 2.0 (see LICENSE for details) + */ +#include +#include +#include +#include +#include +#include +#include "cuda.h" +#include +#include +#include "kmeans_c.h" +#include "kmeans_impl.h" +#include "kmeans_general.h" +#include "kmeans.h" +#include +#include +#include +#include +#include +#include "utils.h" +#include + +cudaStream_t cuda_stream[MAX_NGPUS]; + +/** + * METHODS FOR DATA COPYING AND GENERATION + */ + +template +void random_data(int verbose, thrust::device_vector &array, int m, int n) { + thrust::host_vector host_array(m * n); + for (int i = 0; i < m * n; i++) { + host_array[i] = (T) rand() / (T) RAND_MAX; + } + array = host_array; +} + +/** + * Copies data from srcdata to array + * @tparam T + * @param verbose Logging level + * @param ord Column on row order of data + * @param array Destination array + * @param srcdata Source data + * @param q Shard number (from 0 to n_gpu) + * @param n + * @param npergpu + * @param d + */ +template +void copy_data(int verbose, const char ord, thrust::device_vector &array, + const T *srcdata, int q, int n, size_t npergpu, int d) { + if (ord == 'c') { + thrust::host_vector host_array(npergpu * d); + log_debug(verbose, "Copy data COL ORDER -> ROW ORDER"); + + for (size_t i = 0; i < npergpu * d; i++) { + size_t indexi = i % d; // col + size_t indexj = i / d + q * npergpu; // row (shifted by which gpu) + host_array[i] = srcdata[indexi * n + indexj]; + } + array = host_array; + } else { + log_debug(verbose, "Copy data ROW ORDER not changed"); + thrust::host_vector host_array(srcdata + q * npergpu * d, + srcdata + q * npergpu * d + npergpu * d); + array = host_array; + } +} + +/** + * Like copy_data but shuffles the data according to mapping from v + * @tparam T + * @param verbose + * @param v + * @param ord + * @param array + * @param srcdata + * @param q + * @param n + * @param npergpu + * @param d + */ +template +void copy_data_shuffled(int verbose, std::vector v, const char ord, + thrust::device_vector &array, const T *srcdata, int q, int n, + int npergpu, int d) { + thrust::host_vector host_array(npergpu * d); + if (ord == 'c') { + log_debug(verbose, "Copy data shuffle COL ORDER -> ROW ORDER"); + + for (int i = 0; i < npergpu; i++) { + for (size_t j = 0; j < d; j++) { + host_array[i * d + j] = srcdata[v[q * npergpu + i] + j * n]; // shift by which gpu + } + } + } else { + log_debug(verbose, "Copy data shuffle ROW ORDER not changed"); + + for (int i = 0; i < npergpu; i++) { + for (size_t j = 0; j < d; j++) { + host_array[i * d + j] = srcdata[v[q * npergpu + i] * d + j]; // shift by which gpu + } + } + } + array = host_array; +} + +template +void copy_centroids_shuffled(int verbose, std::vector v, const char ord, + thrust::device_vector &array, const T *srcdata, int n, int k, + int d) { + copy_data_shuffled(verbose, v, ord, array, srcdata, 0, n, k, d); +} + +/** + * Copies centroids from initial training set randomly. 
+ * @tparam T + * @param verbose + * @param seed + * @param ord + * @param array + * @param srcdata + * @param q + * @param n + * @param npergpu + * @param d + * @param k + */ +template +void random_centroids(int verbose, int seed, const char ord, + thrust::device_vector &array, const T *srcdata, int q, int n, + int npergpu, int d, int k) { + thrust::host_vector host_array(k * d); + if (seed < 0) { + std::random_device rd; //Will be used to obtain a seed for the random number engine + seed = rd(); + } + std::mt19937 gen(seed); + std::uniform_int_distribution<> dis(0, n - 1); // random i in range from 0..n-1 (i.e. only 1 gpu gets centroids) + + if (ord == 'c') { + log_debug(verbose, "Random centroids COL ORDER -> ROW ORDER"); + for (int i = 0; i < k; i++) { // clusters + size_t reali = dis(gen); // + q*npergpu; // row sampled (called indexj above) + for (size_t j = 0; j < d; j++) { // cols + host_array[i * d + j] = srcdata[reali + j * n]; + } + } + } else { + log_debug(verbose, "Random centroids ROW ORDER not changed"); + for (int i = 0; i < k; i++) { // rows + size_t reali = dis(gen); // + q*npergpu ; // row sampled + for (size_t j = 0; j < d; j++) { // cols + host_array[i * d + j] = srcdata[reali * d + j]; + } + } + } + array = host_array; +} + +/** + * KMEANS METHODS FIT, PREDICT, TRANSFORM + */ + +#define __HBAR__ \ + "----------------------------------------------------------------------------\n" + +namespace h2o4gpukmeans { + +template +int kmeans_find_clusters(int verbose, const char ord, int seed, + thrust::device_vector **data, thrust::device_vector **labels, + thrust::device_vector **d_centroids, + thrust::device_vector **data_dots, size_t rows, size_t cols, + int init_from_data, int k, int k_max, T threshold, const T *srcdata, + int n_gpu, std::vector dList, T &residual, std::vector v, + int max_iterations); + +template +int kmeans_fit(int verbose, int seed, int gpu_idtry, int n_gputry, size_t rows, + size_t cols, const char ord, int k, int k_max, int max_iterations, + int init_from_data, T threshold, const T *srcdata, T **pred_centroids, + int **pred_labels); + +template +int pick_point_idx_weighted(int seed, std::vector *data, + thrust::host_vector weights) { + T weighted_sum = 0; + + for (int i = 0; i < weights.size(); i++) { + if (data) { + weighted_sum += (data->data()[i] * weights.data()[i]); + } else { + weighted_sum += weights.data()[i]; + } + } + + T best_prob = 0.0; + int best_prob_idx = 0; + + std::mt19937 mt(seed); + std::uniform_real_distribution<> dist(0.0, 1.0); + + int i = 0; + for (i = 0; i <= weights.size(); i++) { + if (weights.size() == i) { + break; + } + + T prob_threshold = (T) dist(mt); + + T data_val = weights.data()[i]; + if (data) { + data_val *= data->data()[i]; + } + + T prob_x = (data_val / weighted_sum); + + if (prob_x > prob_threshold) { + break; + } + + if (prob_x >= best_prob) { + best_prob = prob_x; + best_prob_idx = i; + } + } + + return weights.size() == i ? best_prob_idx : i; +} + +/** + * Copies cols records, starting at position idx*cols from data to centroids. Removes them afterwards from data. + * Removes record from weights at position idx. 
+ * @tparam T + * @param idx + * @param cols + * @param data + * @param weights + * @param centroids + */ +template +void add_centroid(int idx, int cols, thrust::host_vector &data, + thrust::host_vector &weights, std::vector ¢roids) { + for (int i = 0; i < cols; i++) { + centroids.push_back(data[idx * cols + i]); + } + weights[idx] = 0; +} + +struct square_root: public thrust::unary_function { + __host__ __device__ + float operator()(float x) const { + return sqrtf(x); + } +}; + +template +void filterByDot(int d, int k, int *numChosen, thrust::device_vector &dists, + thrust::device_vector ¢roids, + thrust::device_vector ¢roid_dots) { + + float alpha = 1.0f; + float beta = 0.0f; + + CUDACHECK(cudaSetDevice(0)); + kmeans::detail::make_self_dots(k, d, centroids, centroid_dots); + + thrust::transform(centroid_dots.begin(), centroid_dots.begin() + k, + centroid_dots.begin(), square_root()); + + cublasStatus_t stat = + safe_cublas( + cublasSgemm(kmeans::detail::cublas_handle[0], CUBLAS_OP_T, CUBLAS_OP_N, k, k, d, &alpha, thrust::raw_pointer_cast(centroids.data()), d, thrust::raw_pointer_cast(centroids.data()), d, &beta, thrust::raw_pointer_cast(dists.data()), k)) + ; //Has to be k or k + + //Check cosine angle between two vectors, must be < .9 + kmeans::detail::checkCosine(d, k, numChosen, dists, centroids, + centroid_dots); +} + +template +struct min_calc_functor { + T* all_costs_ptr; + T* min_costs_ptr; + T max = std::numeric_limits::max(); + int potential_k_rows; + int rows_per_run; + + min_calc_functor(T* _all_costs_ptr, T* _min_costs_ptr, + int _potential_k_rows, int _rows_per_run) { + all_costs_ptr = _all_costs_ptr; + min_costs_ptr = _min_costs_ptr; + potential_k_rows = _potential_k_rows; + rows_per_run = _rows_per_run; + } + + __host__ __device__ + void operator()(int idx) const { + T best = max; + for (int j = 0; j < potential_k_rows; j++) { + best = min(best, std::abs(all_costs_ptr[j * rows_per_run + idx])); + } + min_costs_ptr[idx] = min(min_costs_ptr[idx], best); + } +}; + +/** + * K-Means|| initialization method implementation as described in "Scalable K-Means++". + * + * This is a probabilistic method, which tries to choose points as much spread out as possible as centroids. + * + * In case it finds more than k centroids a K-Means++ algorithm is ran on potential centroids to pick k best suited ones. + * + * http://theory.stanford.edu/~sergei/papers/vldb12-kmpar.pdf + * + * @tparam T + * @param verbose + * @param seed + * @param ord + * @param data + * @param data_dots + * @param centroids + * @param rows + * @param cols + * @param k + * @param num_gpu + * @param threshold + */ +template +thrust::host_vector kmeans_parallel(int verbose, int seed, const char ord, + thrust::device_vector **data, thrust::device_vector **data_dots, + size_t rows, int cols, int k, int num_gpu, T threshold) { + if (seed < 0) { + std::random_device rd; + int seed = rd(); + } + + size_t rows_per_gpu = rows / num_gpu; + + std::mt19937 gen(seed); + std::uniform_int_distribution<> dis(0, rows - 1); + + // Find the position (GPU idx and idx on that GPU) of the initial centroid + int first_center = dis(gen); + int first_center_idx = first_center % rows_per_gpu; + int first_center_gpu = first_center / rows_per_gpu; + + log_verbose(verbose, "KMeans|| - Initial centroid %d on GPU %d.", + first_center_idx, first_center_gpu); + + // Copies the initial centroid to potential centroids vector. That vector will store all potential centroids found + // in the previous iteration. 
+ thrust::host_vector h_potential_centroids(cols); + std::vector> h_potential_centroids_per_gpu(num_gpu); + + CUDACHECK(cudaSetDevice(first_center_gpu)); + + thrust::copy((*data[first_center_gpu]).begin() + first_center_idx * cols, + (*data[first_center_gpu]).begin() + (first_center_idx + 1) * cols, + h_potential_centroids.begin()); + + thrust::host_vector h_all_potential_centroids = h_potential_centroids; + + // Initial the cost-to-potential-centroids and cost-to-closest-potential-centroid matrices. Initial cost is +infinity + std::vector> d_min_costs(num_gpu); + for (int q = 0; q < num_gpu; q++) { + CUDACHECK(cudaSetDevice(q)); + d_min_costs[q].resize(rows_per_gpu); + thrust::fill(d_min_costs[q].begin(), d_min_costs[q].end(), + std::numeric_limits::max()); + } + + double t0 = timer(); + + int curr_k = h_potential_centroids.size() / cols; + int max_k = k; + + while (curr_k < max_k) { + T total_min_cost = 0.0; + + int new_potential_centroids = 0; +#pragma omp parallel for + for (int i = 0; i < num_gpu; i++) { + CUDACHECK(cudaSetDevice(i)); + + thrust::device_vector d_potential_centroids = + h_potential_centroids; + + int potential_k_rows = d_potential_centroids.size() / cols; + + // Compute all the costs to each potential centroid from previous iteration + thrust::device_vector centroid_dots(potential_k_rows); + + kmeans::detail::batch_calculate_distances(verbose, 0, rows_per_gpu, + cols, potential_k_rows, *data[i], d_potential_centroids, + *data_dots[i], centroid_dots, + [&](int rows_per_run, size_t offset, thrust::device_vector &pairwise_distances) { + // Find the closest potential center cost for each row + auto min_cost_counter = thrust::make_counting_iterator(0); + auto all_costs_ptr = thrust::raw_pointer_cast(pairwise_distances.data()); + auto min_costs_ptr = thrust::raw_pointer_cast(d_min_costs[i].data() + offset); + thrust::for_each(min_cost_counter, + min_cost_counter + rows_per_run, + // Functor instead of a lambda b/c nvcc is complaining about + // nesting a __device__ lambda inside a regular lambda + min_calc_functor(all_costs_ptr, min_costs_ptr, potential_k_rows, rows_per_run)); + }); + } + + for (int i = 0; i < num_gpu; i++) { + CUDACHECK(cudaSetDevice(i)); + total_min_cost += thrust::reduce(d_min_costs[i].begin(), + d_min_costs[i].end()); + } + + log_verbose(verbose, "KMeans|| - Total min cost from centers %g.", + total_min_cost); + + if (total_min_cost == (T) 0.0) { + thrust::host_vector final_centroids(0); + if (verbose) { + fprintf(stderr, + "Too few points and centriods being found is getting 0 cost from centers\n"); + fflush(stderr); + } + + return final_centroids; + } + + std::set copy_from_gpus; +#pragma omp parallel for + for (int i = 0; i < num_gpu; i++) { + CUDACHECK(cudaSetDevice(i)); + + // Count how many potential centroids there are using probabilities + // The further the row is from the closest cluster center the higher the probability + auto pot_cent_filter_counter = thrust::make_counting_iterator(0); + auto min_costs_ptr = thrust::raw_pointer_cast( + d_min_costs[i].data()); +int pot_cent_num = thrust::count_if( + pot_cent_filter_counter, + pot_cent_filter_counter + rows_per_gpu, [=]__device__(int idx){ + thrust::default_random_engine rng(seed); + thrust::uniform_real_distribution<> dist(0.0, 1.0); + int device; + cudaGetDevice(&device); + rng.discard(idx + device * rows_per_gpu); + T prob_threshold = (T) dist(rng); + + T prob_x = (( 2.0 * k * min_costs_ptr[idx]) / total_min_cost); + + return prob_x > prob_threshold; + } + ); + + log_debug(verbose, "KMeans|| 
- Potential centroids on GPU %d = %d.", + i, pot_cent_num); + + if (pot_cent_num > 0) { + copy_from_gpus.insert(i); + + // Copy all potential cluster centers + thrust::device_vector d_new_potential_centroids( + pot_cent_num * cols); + + auto range = thrust::make_counting_iterator(0); + thrust::copy_if( + (*data[i]).begin(), (*data[i]).end(), range, + d_new_potential_centroids.begin(), [=] __device__(int idx) { + int row = idx / cols; + thrust::default_random_engine rng(seed); + thrust::uniform_real_distribution<> dist(0.0, 1.0); + int device; + cudaGetDevice(&device); + rng.discard(row + device * rows_per_gpu); + T prob_threshold = (T) dist(rng); + + T prob_x = (( 2.0 * k * min_costs_ptr[row]) / total_min_cost); + + return prob_x > prob_threshold; + }); + + h_potential_centroids_per_gpu[i].clear(); + h_potential_centroids_per_gpu[i].resize( + d_new_potential_centroids.size()); + + new_potential_centroids += d_new_potential_centroids.size(); + + thrust::copy(d_new_potential_centroids.begin(), + d_new_potential_centroids.end(), + h_potential_centroids_per_gpu[i].begin()); + + } + + } + + log_verbose(verbose, "KMeans|| - New potential centroids %d.", + new_potential_centroids); + + // Gather potential cluster centers from all GPUs + if (new_potential_centroids > 0) { + h_potential_centroids.clear(); + h_potential_centroids.resize(new_potential_centroids); + + int old_pot_centroids_size = h_all_potential_centroids.size(); + h_all_potential_centroids.resize( + old_pot_centroids_size + new_potential_centroids); + + int offset = 0; + for (int i = 0; i < num_gpu; i++) { + if (copy_from_gpus.find(i) != copy_from_gpus.end()) { + thrust::copy(h_potential_centroids_per_gpu[i].begin(), + h_potential_centroids_per_gpu[i].end(), + h_potential_centroids.begin() + offset); + offset += h_potential_centroids_per_gpu[i].size(); + } + } + + CUDACHECK(cudaSetDevice(0)); + thrust::device_vector new_centroids = h_potential_centroids; + + thrust::device_vector new_centroids_dist( + (new_potential_centroids / cols) + * (new_potential_centroids / cols)); + thrust::device_vector new_centroids_dot( + new_potential_centroids / cols); + + int numChosen = new_potential_centroids / cols; + filterByDot(cols, numChosen, &numChosen, new_centroids_dist, + new_centroids, new_centroids_dot); + + thrust::host_vector h_new_centroids = new_centroids; + h_all_potential_centroids.resize( + old_pot_centroids_size + (numChosen * cols)); + thrust::copy(h_new_centroids.begin(), + h_new_centroids.begin() + (numChosen * cols), + h_all_potential_centroids.begin() + old_pot_centroids_size); + curr_k = curr_k + numChosen; + + } else { + thrust::host_vector final_centroids(0); + if (verbose) { + fprintf(stderr, + "Too few points , not able to find centroid candidate \n"); + fflush(stderr); + } + + return final_centroids; + } + } + + thrust::host_vector final_centroids(0); + int potential_centroids_num = h_all_potential_centroids.size() / cols; + + final_centroids.resize(k * cols); + thrust::copy(h_all_potential_centroids.begin(), + h_all_potential_centroids.begin() + (max_k * cols), + final_centroids.begin()); + + return final_centroids; +} + +volatile std::atomic_int flaggpu(0); + +inline void my_function_gpu(int sig) { // can be called asynchronously + fprintf(stderr, "Caught signal %d. 
Terminating shortly.\n", sig); + flaggpu = 1; +} + +std::vector kmeans_init(int verbose, int *final_n_gpu, int n_gputry, + int gpu_idtry, int rows) { + if (rows > std::numeric_limits::max()) { + fprintf(stderr, "rows > %d not implemented\n", + std::numeric_limits::max()); + fflush(stderr); + exit(0); + } + + std::signal(SIGINT, my_function_gpu); + std::signal(SIGTERM, my_function_gpu); + + // no more gpus than visible gpus + int n_gpuvis; + cudaGetDeviceCount(&n_gpuvis); + int n_gpu = std::min(n_gpuvis, n_gputry); + + // no more than rows + n_gpu = std::min(n_gpu, rows); + + if (verbose) { + std::cout << n_gpu << " gpus." << std::endl; + } + + int gpu_id = gpu_idtry % n_gpuvis; + + // setup GPU list to use + std::vector dList(n_gpu); + for (int idx = 0; idx < n_gpu; idx++) { + int device_idx = (gpu_id + idx) % n_gpuvis; + dList[idx] = device_idx; + } + + *final_n_gpu = n_gpu; + return dList; +} + +template +H2O4GPUKMeans::H2O4GPUKMeans(const T *A, int k, int n, int d) { + _A = A; + _k = k; + _n = n; + _d = d; +} + +template +int kmeans_find_clusters(int verbose, const char ord, int seed, + thrust::device_vector **data, thrust::device_vector **labels, + thrust::device_vector **d_centroids, + thrust::device_vector **data_dots, size_t rows, size_t cols, + int init_from_data, int k, int k_max, T threshold, const T *srcdata, + int n_gpu, std::vector dList, T &residual, std::vector v, + int max_iterations) { + int bytecount = cols * k * sizeof(T); + if (0 == init_from_data) { + + log_debug(verbose, "KMeans - Using random initialization."); + + int masterq = 0; + CUDACHECK(cudaSetDevice(dList[masterq])); + // DM: simply copies first k rows data into GPU_0 + copy_centroids_shuffled(verbose, v, ord, *d_centroids[masterq], + &srcdata[0], rows, k, cols); + + // DM: can remove all of this + // Copy centroids to all devices + std::vector streams; + streams.resize(n_gpu); +#pragma omp parallel for + for (int q = 0; q < n_gpu; q++) { + if (q == masterq) + continue; + + CUDACHECK(cudaSetDevice(dList[q])); + if (verbose > 0) { + std::cout << "Copying centroid data to device: " << dList[q] + << std::endl; + } + + streams[q] = reinterpret_cast(malloc( + sizeof(cudaStream_t))); + cudaStreamCreate(streams[q]); + cudaMemcpyPeerAsync(thrust::raw_pointer_cast(&(*d_centroids[q])[0]), + dList[q], + thrust::raw_pointer_cast(&(*d_centroids[masterq])[0]), + dList[masterq], bytecount, *(streams[q])); + } +//#pragma omp parallel for + for (int q = 0; q < n_gpu; q++) { + if (q == masterq) + continue; + cudaSetDevice(dList[q]); + cudaStreamDestroy(*(streams[q])); +#if(DEBUGKMEANS) + thrust::host_vector h_centroidq=*d_centroids[q]; + for(int ii=0;ii final_centroids = kmeans_parallel(verbose, seed, + ord, data, data_dots, rows, cols, k, n_gpu, threshold); + if (final_centroids.size() == 0) { + if (verbose) { + fprintf(stderr, + "kmeans || failed to find %d number of cluster points \n", + k); + fflush(stderr); + } + + residual = 0.0; + return 0; + } + +#pragma omp parallel for + for (int q = 0; q < n_gpu; q++) { + CUDACHECK(cudaSetDevice(dList[q])); + cudaMemcpy(thrust::raw_pointer_cast(&(*d_centroids[q])[0]), + thrust::raw_pointer_cast(&final_centroids[0]), bytecount, + cudaMemcpyHostToDevice); + } + + } + +#pragma omp parallel for + for (int q = 0; q < n_gpu; q++) { + CUDACHECK(cudaSetDevice(dList[q])); + labels[q] = new thrust::device_vector(rows / n_gpu); + } + + double t0 = timer(); + + int iter = kmeans::kmeans(verbose, &flaggpu, rows, cols, k, k_max, data, + labels, d_centroids, data_dots, dList, n_gpu, 
max_iterations, + threshold, true); + + if (iter < 0) { + log_error(verbose, "KMeans algorithm failed."); + return iter; + } + + // Calculate the residual + size_t rows_per_gpu = rows / n_gpu; + std::vector> d_min_costs(n_gpu); + for (int q = 0; q < n_gpu; q++) { + CUDACHECK(cudaSetDevice(q)); + d_min_costs[q].resize(rows_per_gpu); + thrust::fill(d_min_costs[q].begin(), d_min_costs[q].end(), + std::numeric_limits::max()); + } +#pragma omp parallel for + for (int i = 0; i < n_gpu; i++) { + CUDACHECK(cudaSetDevice(i)); + + int potential_k_rows = k; + // Compute all the costs to each potential centroid from previous iteration + thrust::device_vector centroid_dots(potential_k_rows); + + kmeans::detail::batch_calculate_distances(verbose, 0, rows_per_gpu, + cols, k, *data[i], *d_centroids[i], *data_dots[i], + centroid_dots, + [&](int rows_per_run, size_t offset, thrust::device_vector &pairwise_distances) { + // Find the closest potential center cost for each row + auto min_cost_counter = thrust::make_counting_iterator(0); + auto all_costs_ptr = thrust::raw_pointer_cast(pairwise_distances.data()); + auto min_costs_ptr = thrust::raw_pointer_cast(d_min_costs[i].data() + offset); + thrust::for_each(min_cost_counter, + min_cost_counter + rows_per_run, + // Functor instead of a lambda b/c nvcc is complaining about + // nesting a __device__ lambda inside a regular lambda + min_calc_functor(all_costs_ptr, min_costs_ptr, potential_k_rows, rows_per_run)); + }); + } + + residual = 0.0; + for (int i = 0; i < n_gpu; i++) { + CUDACHECK(cudaSetDevice(i)); + residual += thrust::reduce(d_min_costs[i].begin(), + d_min_costs[i].end()); + } + + double timefit = static_cast(timer() - t0); + + if (verbose) { + std::cout << " Time fit: " << timefit << " s" << std::endl; + fprintf(stderr, "Time fir: %g \n", timefit); + fflush(stderr); + } + + return iter; +} + +template +int kmeans_fit(int verbose, int seed, int gpu_idtry, int n_gputry, size_t rows, + size_t cols, const char ord, int k, int k_max, int max_iterations, + int init_from_data, T threshold, const T *srcdata, T **pred_centroids, + int **pred_labels) { + // init random seed if use the C function rand() + if (seed >= 0) { + srand(seed); + } else { + srand(unsigned(time(NULL))); + } + + // no more clusters than rows + if (k_max > rows) { + k_max = static_cast(rows); + fprintf(stderr, + "Number of clusters adjusted to be equal to number of rows.\n"); + fflush(stderr); + } + + int n_gpu; + // only creates a list of GPUs to use. can be removed for single GPU + std::vector dList = kmeans_init(verbose, &n_gpu, n_gputry, gpu_idtry, + rows); + + double t0t = timer(); + thrust::device_vector *data[n_gpu]; + thrust::device_vector *labels[n_gpu]; + thrust::device_vector *d_centroids[n_gpu]; + thrust::device_vector *data_dots[n_gpu]; +#pragma omp parallel for + for (int q = 0; q < n_gpu; q++) { + CUDACHECK(cudaSetDevice(dList[q])); + data[q] = new thrust::device_vector(rows / n_gpu * cols); + d_centroids[q] = new thrust::device_vector(k_max * cols); + data_dots[q] = new thrust::device_vector(rows / n_gpu); + + kmeans::detail::labels_init(); + } + + log_debug(verbose, "Number of points: %d", rows); + log_debug(verbose, "Number of dimensions: %d", cols); + log_debug(verbose, "Number of clusters: %d", k); + log_debug(verbose, "Max number of clusters: %d", k_max); + log_debug(verbose, "Max. 
number of iterations: %d", max_iterations); + log_debug(verbose, "Stopping threshold: %d", threshold); + + std::vector v(rows); + std::iota(std::begin(v), std::end(v), 0); // Fill with 0, 1, ..., rows. + + if (seed >= 0) { + std::shuffle(v.begin(), v.end(), std::default_random_engine(seed)); + } else { + std::random_shuffle(v.begin(), v.end()); + } + + // Copy the data to devices +#pragma omp parallel for + for (int q = 0; q < n_gpu; q++) { + CUDACHECK(cudaSetDevice(dList[q])); + if (verbose) { + std::cout << "Copying data to device: " << dList[q] << std::endl; + } + + copy_data(verbose, ord, *data[q], &srcdata[0], q, rows, rows / n_gpu, + cols); + + // Pre-compute the data matrix norms + kmeans::detail::make_self_dots(rows / n_gpu, cols, *data[q], + *data_dots[q]); + } + + // Host memory + thrust::host_vector results(k_max + 1, (T) (1e20)); + + // Loop to find *best* k + // Perform k-means in binary search + int left = k; //must be at least 2 + int right = k_max; //int(floor(len(data)/2)) #assumption of clusters of size 2 at least + int mid = int(floor((right + left) / 2)); + int oldmid = mid; + int tests; + int iter = 0; + T objective[2]; // 0= left of mid, 1= right of mid + T minres = 0; + T residual = 0.0; + + if (left == 1) + left = 2; // at least do 2 clusters + // eval left edge + + iter = kmeans_find_clusters(verbose, ord, seed, data, labels, d_centroids, + data_dots, rows, cols, init_from_data, left, k_max, threshold, + srcdata, n_gpu, dList, residual, v, max_iterations); + results[left] = residual; + + if (left != right) { + //eval right edge + residual = 0.0; + iter = kmeans_find_clusters(verbose, ord, seed, data, labels, + d_centroids, data_dots, rows, cols, init_from_data, right, + k_max, threshold, srcdata, n_gpu, dList, residual, v, + max_iterations); + int tmp_left = left; + int tmp_right = right; + T tmp_residual = 0.0; + + while ((residual == 0.0) && (right > 0)) { + right = (tmp_left + tmp_right) / 2; + // This k is already explored and need not be explored again + if (right == tmp_left) { + residual = tmp_residual; + right = tmp_left; + break; + } + iter = kmeans_find_clusters(verbose, ord, seed, data, labels, + d_centroids, data_dots, rows, cols, init_from_data, right, + k_max, threshold, srcdata, n_gpu, dList, residual, v, + max_iterations); + results[right] = residual; + + if (residual == 0.0) { + tmp_right = right; + } else { + tmp_left = right; + tmp_residual = residual; + + if (abs(tmp_left - tmp_right) == 1) { + break; + } + } + // Escape from an infinite loop if we come across + if (tmp_left == tmp_right) { + residual = tmp_residual; + right = tmp_left; + break; + } + + residual = 0.0; + } + results[right] = residual; + minres = residual * 0.9; + mid = int(floor((right + left) / 2)); + oldmid = mid; + } + + // binary search + while (left < right - 1) { + tests = 0; + while (results[mid] > results[left] && tests < 3) { + + iter = kmeans_find_clusters(verbose, ord, seed, data, labels, + d_centroids, data_dots, rows, cols, init_from_data, mid, + k_max, threshold, srcdata, n_gpu, dList, residual, v, + max_iterations); + results[mid] = residual; + if (results[mid] > results[left] && (mid + 1) < right) { + mid += 1; + results[mid] = 1e20; + } else if (results[mid] > results[left] && (mid - 1) > left) { + mid -= 1; + results[mid] = 1e20; + } + tests += 1; + } + objective[0] = abs(results[left] - results[mid]) + / (results[left] - minres); + objective[0] /= mid - left; + objective[1] = abs(results[mid] - results[right]) + / (results[mid] - minres); + objective[1] /= 
right - mid; + if (objective[0] > 1.2 * objective[1]) { //abs(resid_reduction[left]-resid_reduction[mid])/(mid-left)) { + // our point is in the left-of-mid side + right = mid; + } else { + left = mid; + } + oldmid = mid; + mid = int(floor((right + left) / 2)); + } + + int k_final = 0; + k_final = right; + if (results[left] < results[oldmid]) + k_final = left; + + // if k_star isn't what we just ran, re-run to get correct centroids and dist data on return-> this saves memory + if (k_final != oldmid) { + iter = kmeans_find_clusters(verbose, ord, seed, data, labels, + d_centroids, data_dots, rows, cols, init_from_data, k_final, + k_max, threshold, srcdata, n_gpu, dList, residual, v, + max_iterations); + } + + double timetransfer = static_cast(timer() - t0t); + + double t1 = timer(); + + // copy result of centroids (sitting entirely on each device) back to host + // TODO FIXME: When do delete ctr and h_labels memory??? + thrust::host_vector *ctr = new thrust::host_vector(*d_centroids[0]); + *pred_centroids = ctr->data(); + + // copy assigned labels + thrust::host_vector *h_labels = new thrust::host_vector(rows); +//#pragma omp parallel for + for (int q = 0; q < n_gpu; q++) { + int offset = labels[q]->size() * q; + h_labels->insert(h_labels->begin() + offset, labels[q]->begin(), + labels[q]->end()); + } + + *pred_labels = h_labels->data(); + + // debug + if (verbose >= H2O4GPU_LOG_VERBOSE) { + for (unsigned int ii = 0; ii < k; ii++) { + fprintf(stderr, "ii=%d of k=%d ", ii, k); + for (unsigned int jj = 0; jj < cols; jj++) { + fprintf(stderr, "%g ", (*pred_centroids)[cols * ii + jj]); + } + fprintf(stderr, "\n"); + fflush(stderr); + } + + printf("Number of iteration: %d\n", iter); + } + +#pragma omp parallel for + for (int q = 0; q < n_gpu; q++) { + CUDACHECK(cudaSetDevice(dList[q])); + delete (data[q]); + delete (labels[q]); + delete (d_centroids[q]); + delete (data_dots[q]); + kmeans::detail::labels_close(); + } + + double timecleanup = static_cast(timer() - t1); + + if (verbose) { + fprintf(stderr, "Timetransfer: %g Timecleanup: %g\n", timetransfer, + timecleanup); + fflush(stderr); + } + + return k_final; +} + +template +int kmeans_predict(int verbose, int gpu_idtry, int n_gputry, size_t rows, + size_t cols, const char ord, int k, const T *srcdata, + const T *centroids, int **pred_labels) { + // Print centroids + if (verbose >= H2O4GPU_LOG_VERBOSE) { + std::cout << std::endl; + for (int i = 0; i < cols * k; i++) { + std::cout << centroids[i] << " "; + if (i % cols == 1) { + std::cout << std::endl; + } + } + } + + int n_gpu; + std::vector dList = kmeans_init(verbose, &n_gpu, n_gputry, gpu_idtry, + rows); + + thrust::device_vector *d_data[n_gpu]; + thrust::device_vector *d_centroids[n_gpu]; + thrust::device_vector *data_dots[n_gpu]; + thrust::device_vector *centroid_dots[n_gpu]; + thrust::host_vector *h_labels = new thrust::host_vector(0); + std::vector> d_labels(n_gpu); + +#pragma omp parallel for + for (int q = 0; q < n_gpu; q++) { + CUDACHECK(cudaSetDevice(dList[q])); + kmeans::detail::labels_init(); + + data_dots[q] = new thrust::device_vector(rows / n_gpu); + centroid_dots[q] = new thrust::device_vector(k); + + d_centroids[q] = new thrust::device_vector(k * cols); + d_data[q] = new thrust::device_vector(rows / n_gpu * cols); + + copy_data(verbose, 'r', *d_centroids[q], ¢roids[0], 0, k, k, cols); + + copy_data(verbose, ord, *d_data[q], &srcdata[0], q, rows, rows / n_gpu, + cols); + + kmeans::detail::make_self_dots(rows / n_gpu, cols, *d_data[q], + *data_dots[q]); + + 
d_labels[q].resize(rows / n_gpu); + + kmeans::detail::batch_calculate_distances(verbose, q, rows / n_gpu, + cols, k, *d_data[q], *d_centroids[q], *data_dots[q], + *centroid_dots[q], + [&](int n, size_t offset, thrust::device_vector &pairwise_distances) { + kmeans::detail::relabel(n, k, pairwise_distances, d_labels[q], offset); + }); + + } + + for (int q = 0; q < n_gpu; q++) { + h_labels->insert(h_labels->end(), d_labels[q].begin(), + d_labels[q].end()); + } + + *pred_labels = h_labels->data(); + +#pragma omp parallel for + for (int q = 0; q < n_gpu; q++) { + safe_cuda(cudaSetDevice(dList[q])); + kmeans::detail::labels_close(); + delete (data_dots[q]); + delete (centroid_dots[q]); + delete (d_centroids[q]); + delete (d_data[q]); + } + + return 0; +} + +template +int kmeans_transform(int verbose, int gpu_idtry, int n_gputry, size_t rows, + size_t cols, const char ord, int k, const T *srcdata, + const T *centroids, T **preds) { + // Print centroids + if (verbose >= H2O4GPU_LOG_VERBOSE) { + std::cout << std::endl; + for (int i = 0; i < cols * k; i++) { + std::cout << centroids[i] << " "; + if (i % cols == 1) { + std::cout << std::endl; + } + } + } + + int n_gpu; + std::vector dList = kmeans_init(verbose, &n_gpu, n_gputry, gpu_idtry, + rows); + + thrust::device_vector *d_data[n_gpu]; + thrust::device_vector *d_centroids[n_gpu]; + thrust::device_vector *d_pairwise_distances[n_gpu]; + thrust::device_vector *data_dots[n_gpu]; + thrust::device_vector *centroid_dots[n_gpu]; +#pragma omp parallel for + for (int q = 0; q < n_gpu; q++) { + CUDACHECK(cudaSetDevice(dList[q])); + kmeans::detail::labels_init(); + + data_dots[q] = new thrust::device_vector(rows / n_gpu); + centroid_dots[q] = new thrust::device_vector(k); + d_pairwise_distances[q] = new thrust::device_vector( + rows / n_gpu * k); + + d_centroids[q] = new thrust::device_vector(k * cols); + d_data[q] = new thrust::device_vector(rows / n_gpu * cols); + + copy_data(verbose, 'r', *d_centroids[q], ¢roids[0], 0, k, k, cols); + + copy_data(verbose, ord, *d_data[q], &srcdata[0], q, rows, rows / n_gpu, + cols); + + kmeans::detail::make_self_dots(rows / n_gpu, cols, *d_data[q], + *data_dots[q]); + + // TODO batch this + kmeans::detail::calculate_distances(verbose, q, rows / n_gpu, cols, k, + *d_data[q], 0, *d_centroids[q], *data_dots[q], + *centroid_dots[q], *d_pairwise_distances[q]); + } + +#pragma omp parallel for + for (int q = 0; q < n_gpu; q++) { + CUDACHECK(cudaSetDevice(dList[q])); + thrust::transform((*d_pairwise_distances[q]).begin(), + (*d_pairwise_distances[q]).end(), + (*d_pairwise_distances[q]).begin(), square_root()); + } + + // Move the resulting labels into host memory from all devices + thrust::host_vector *h_pairwise_distances = new thrust::host_vector( + 0); +#pragma omp parallel for + for (int q = 0; q < n_gpu; q++) { + h_pairwise_distances->insert(h_pairwise_distances->end(), + d_pairwise_distances[q]->begin(), + d_pairwise_distances[q]->end()); + } + *preds = h_pairwise_distances->data(); + + // Print centroids + if (verbose >= H2O4GPU_LOG_VERBOSE) { + std::cout << std::endl; + for (int i = 0; i < rows * cols; i++) { + std::cout << h_pairwise_distances->data()[i] << " "; + if (i % cols == 1) { + std::cout << std::endl; + } + } + } + +#pragma omp parallel for + for (int q = 0; q < n_gpu; q++) { + safe_cuda(cudaSetDevice(dList[q])); + kmeans::detail::labels_close(); + delete (d_pairwise_distances[q]); + delete (data_dots[q]); + delete (centroid_dots[q]); + delete (d_centroids[q]); + delete (d_data[q]); + } + + return 0; +} + 
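+// Usage sketch for the templated fit/predict entry points in this namespace.
+// Everything below is illustrative: the sizes, the row-major host buffer and
+// the parameter values are placeholders, not defaults. kmeans_fit() returns
+// the selected k and hands back library-owned host pointers for the centroids
+// and labels; a nonzero init_from_data selects the k-means|| initialization.
+//
+//   const float *host_data = ...;                 // rows x cols, row-major
+//   float *centroids = nullptr;
+//   int   *labels    = nullptr;
+//   int k_found = kmeans_fit<float>(0 /*verbose*/, 42 /*seed*/,
+//       0 /*gpu_id*/, 1 /*n_gpu*/, rows, cols, 'r', k, k_max,
+//       300 /*max_iterations*/, 1 /*init_from_data*/, 1e-3f /*threshold*/,
+//       host_data, &centroids, &labels);
+//
+//   // Assign new rows to the centroids found above.
+//   kmeans_predict<float>(0 /*verbose*/, 0 /*gpu_id*/, 1 /*n_gpu*/,
+//       new_rows, cols, 'r', k_found, new_host_data, centroids, &labels);
+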
+template +int makePtr_dense(int dopredict, int verbose, int seed, int gpu_idtry, + int n_gputry, size_t rows, size_t cols, const char ord, int k, + int k_max, int max_iterations, int init_from_data, T threshold, + const T *srcdata, const T *centroids, T **pred_centroids, + int **pred_labels) { + if (dopredict == 0) { + return kmeans_fit(verbose, seed, gpu_idtry, n_gputry, rows, cols, ord, + k, k_max, max_iterations, init_from_data, threshold, srcdata, + pred_centroids, pred_labels); + } else { + return kmeans_predict(verbose, gpu_idtry, n_gputry, rows, cols, ord, k, + srcdata, centroids, pred_labels); + } +} + +template int +makePtr_dense(int dopredict, int verbose, int seed, int gpu_id, + int n_gpu, size_t rows, size_t cols, const char ord, int k, int k_max, + int max_iterations, int init_from_data, float threshold, + const float *srcdata, const float *centroids, float **pred_centroids, + int **pred_labels); + +template int +makePtr_dense(int dopredict, int verbose, int seed, int gpu_id, + int n_gpu, size_t rows, size_t cols, const char ord, int k, int k_max, + int max_iterations, int init_from_data, double threshold, + const double *srcdata, const double *centroids, double **pred_centroids, + int **pred_labels); + +template int kmeans_fit(int verbose, int seed, int gpu_idtry, + int n_gputry, size_t rows, size_t cols, const char ord, int k, + int k_max, int max_iterations, int init_from_data, float threshold, + const float *srcdata, float **pred_centroids, int **pred_labels); + +template int kmeans_fit(int verbose, int seed, int gpu_idtry, + int n_gputry, size_t rows, size_t cols, const char ord, int k, + int k_max, int max_iterations, int init_from_data, double threshold, + const double *srcdata, double **pred_centroids, int **pred_labels); + +template int kmeans_find_clusters(int verbose, const char ord, int seed, + thrust::device_vector **data, + thrust::device_vector **labels, + thrust::device_vector **d_centroids, + thrust::device_vector **data_dots, size_t rows, size_t cols, + int init_from_data, int k, int k_max, float threshold, + const float *srcdata, int n_gpu, std::vector dList, + float &residual, std::vector v, int max_iterations); + +template int kmeans_find_clusters(int verbose, const char ord, int seed, + thrust::device_vector **data, + thrust::device_vector **labels, + thrust::device_vector **d_centroids, + thrust::device_vector **data_dots, size_t rows, size_t cols, + int init_from_data, int k, int k_max, double threshold, + const double *srcdata, int n_gpu, std::vector dList, + double &residual, std::vector v, int max_iterations); + +template int kmeans_predict(int verbose, int gpu_idtry, int n_gputry, + size_t rows, size_t cols, const char ord, int k, const float *srcdata, + const float *centroids, int **pred_labels); + +template int kmeans_predict(int verbose, int gpu_idtry, int n_gputry, + size_t rows, size_t cols, const char ord, int k, const double *srcdata, + const double *centroids, int **pred_labels); + +template int kmeans_transform(int verbose, int gpu_id, int n_gpu, + size_t m, size_t n, const char ord, int k, const float *src_data, + const float *centroids, float **preds); + +template int kmeans_transform(int verbose, int gpu_id, int n_gpu, + size_t m, size_t n, const char ord, int k, const double *src_data, + const double *centroids, double **preds); + +// Explicit template instantiation. 
+#if !defined(H2O4GPU_DOUBLE) || H2O4GPU_DOUBLE == 1 + +template +class H2O4GPUKMeans ; + +#endif + +#if !defined(H2O4GPU_SINGLE) || H2O4GPU_SINGLE == 1 + +template +class H2O4GPUKMeans ; + +#endif + +int get_n_gpus(int n_gputry) { + int nDevices; + cudaGetDeviceCount(&nDevices); + + if (n_gputry < 0) { // get all available GPUs + return nDevices; + } else if (n_gputry > nDevices) { + return nDevices; + } else { + return n_gputry; + } +} + +} // namespace h2o4gpukmeans + +namespace ML { +/* + * Interface for other languages + */ + +// Fit and Predict +void make_ptr_kmeans(int dopredict, int verbose, int seed, int gpu_id, + int n_gpu, size_t mTrain, size_t n, const char ord, int k, int k_max, + int max_iterations, int init_from_data, float threshold, + const float *srcdata, const float *centroids, float *pred_centroids, + int *pred_labels) { + //float *h_srcdata = (float*) malloc(mTrain * n * sizeof(float)); + //cudaMemcpy((void*)h_srcdata, (void*)srcdata, mTrain*n * sizeof(float), cudaMemcpyDeviceToHost); + + const float *h_srcdata = srcdata; + + float *h_centroids = nullptr; + if (dopredict) { + h_centroids = (float*) malloc(k * n * sizeof(float)); + cudaMemcpy((void*) h_centroids, (void*) centroids, + k * n * sizeof(float), cudaMemcpyDeviceToHost); + } + + int *h_pred_labels = nullptr; + float *h_pred_centroids = nullptr; + int actual_n_gpu = h2o4gpukmeans::get_n_gpus(n_gpu); + int actual_k = h2o4gpukmeans::makePtr_dense(dopredict, verbose, seed, + gpu_id, actual_n_gpu, mTrain, n, ord, k, k_max, max_iterations, + init_from_data, threshold, h_srcdata, h_centroids, + &h_pred_centroids, &h_pred_labels); + + if (dopredict == 0) { + cudaMemcpy(pred_centroids, h_pred_centroids, k * n * sizeof(float), + cudaMemcpyHostToDevice); + } + + cudaMemcpy(pred_labels, h_pred_labels, mTrain * sizeof(int), + cudaMemcpyHostToDevice); + + //free(h_srcdata); + if (dopredict) { + free(h_centroids); + } +} + +void make_ptr_kmeans(int dopredict, int verbose, int seed, int gpu_id, + int n_gpu, size_t mTrain, size_t n, const char ord, int k, int k_max, + int max_iterations, int init_from_data, double threshold, + const double *srcdata, const double *centroids, double *pred_centroids, + int *pred_labels) { + + //double *h_srcdata = (double*) malloc(mTrain * n * sizeof(double)); + //cudaMemcpy((void*)h_srcdata, (void*)srcdata, mTrain*n * sizeof(double), cudaMemcpyDeviceToHost); + + const double *h_srcdata = srcdata; + + double *h_centroids = nullptr; + if (dopredict) { + h_centroids = (double*) malloc(k * n * sizeof(double)); + cudaMemcpy((void*) h_centroids, (void*) centroids, + k * n * sizeof(double), cudaMemcpyDeviceToHost); + } + + int *h_pred_labels = nullptr; + double *h_pred_centroids = nullptr; + int actual_n_gpu = h2o4gpukmeans::get_n_gpus(n_gpu); + int actual_k = h2o4gpukmeans::makePtr_dense(dopredict, verbose, + seed, gpu_id, actual_n_gpu, mTrain, n, ord, k, k_max, + max_iterations, init_from_data, threshold, h_srcdata, h_centroids, + &h_pred_centroids, &h_pred_labels); + + if (dopredict == 0) { + cudaMemcpy(pred_centroids, h_pred_centroids, k * n * sizeof(double), + cudaMemcpyHostToDevice); + } + + cudaMemcpy(pred_labels, h_pred_labels, mTrain * sizeof(int), + cudaMemcpyHostToDevice); + + //free(h_srcdata); + if (dopredict) { + free(h_centroids); + } + +} + +// Transform +void kmeans_transform(int verbose, int gpu_id, int n_gpu, size_t m, size_t n, + const char ord, int k, const float *src_data, const float *centroids, + float *preds) { + //float *h_srcdata = (float*) malloc(m * n * sizeof(float)); + 
//cudaMemcpy((void*)h_srcdata, (void*)src_data, m*n * sizeof(float), cudaMemcpyDeviceToHost); + + const float *h_srcdata = src_data; + + float *h_centroids = (float*) malloc(k * n * sizeof(float)); + cudaMemcpy((void*) h_centroids, (void*) centroids, k * n * sizeof(float), + cudaMemcpyDeviceToHost); + + float *h_preds = nullptr; + int actual_n_gpu = h2o4gpukmeans::get_n_gpus(n_gpu); + h2o4gpukmeans::kmeans_transform(verbose, gpu_id, actual_n_gpu, m, n, + ord, k, h_srcdata, h_centroids, &h_preds); + + cudaMemcpy(preds, h_preds, m * k * sizeof(float), cudaMemcpyHostToDevice); + + //free(h_srcdata); + free(h_centroids); +} + +void kmeans_transform(int verbose, int gpu_id, int n_gpu, size_t m, size_t n, + const char ord, int k, const double *src_data, const double *centroids, + double *preds) { + //double *h_srcdata = (double*) malloc(m * n * sizeof(double)); + //cudaMemcpy((void*)h_srcdata, (void*)src_data, m*n * sizeof(double), cudaMemcpyDeviceToHost); + + const double *h_srcdata = src_data; + + double *h_centroids = (double*) malloc(k * n * sizeof(double)); + cudaMemcpy((void*) h_centroids, (void*) centroids, k * n * sizeof(double), + cudaMemcpyDeviceToHost); + + double *h_preds = nullptr; + int actual_n_gpu = h2o4gpukmeans::get_n_gpus(n_gpu); + h2o4gpukmeans::kmeans_transform(verbose, gpu_id, actual_n_gpu, m, n, + ord, k, h_srcdata, h_centroids, &h_preds); + + cudaMemcpy(preds, h_preds, m * k * sizeof(double), cudaMemcpyHostToDevice); + + //free(h_srcdata); + free(h_centroids); +} + +} // end namespace ML diff --git a/cuML/src/kmeans/kmeans.h b/cuML/src/kmeans/kmeans.h new file mode 100644 index 0000000000..e653791285 --- /dev/null +++ b/cuML/src/kmeans/kmeans.h @@ -0,0 +1,91 @@ +/*! + * Copyright 2017-2018 H2O.ai, Inc. + * License Apache License Version 2.0 (see LICENSE for details) + */ +#pragma once +#include +#include +#include +#include "kmeans_labels.h" +#include "kmeans_centroids.h" + +template +struct count_functor { + T* pairwise_distances_ptr; + int* counts_ptr; + int k; + int rows_per_run; + + count_functor(T* _pairwise_distances_ptr, int* _counts_ptr, int _k, + int _rows_per_run) { + pairwise_distances_ptr = _pairwise_distances_ptr; + counts_ptr = _counts_ptr; + k = _k; + rows_per_run = _rows_per_run; + } + + __device__ + void operator()(int idx) const { + int closest_centroid_idx = 0; + T best_distance = pairwise_distances_ptr[idx]; + // FIXME potentially slow due to striding + for (int i = 1; i < k; i++) { + T distance = pairwise_distances_ptr[idx + i * rows_per_run]; + + if (distance < best_distance) { + best_distance = distance; + closest_centroid_idx = i; + } + } + atomicAdd(&counts_ptr[closest_centroid_idx], 1); + } +}; + +/** + * Calculates closest centroid for each record and counts how many points are assigned to each centroid. 
+ * @tparam T + * @param verbose + * @param num_gpu + * @param rows_per_gpu + * @param cols + * @param data + * @param data_dots + * @param centroids + * @param weights + * @param pairwise_distances + * @param labels + */ +template +void count_pts_per_centroid(int verbose, int num_gpu, int rows_per_gpu, + int cols, thrust::device_vector **data, + thrust::device_vector **data_dots, thrust::host_vector centroids, + thrust::host_vector &weights) { + int k = centroids.size() / cols; +#pragma omp parallel for + for (int i = 0; i < num_gpu; i++) { + thrust::host_vector weights_tmp(weights.size()); + + CUDACHECK(cudaSetDevice(i)); + thrust::device_vector centroid_dots(k); + thrust::device_vector d_centroids = centroids; + thrust::device_vector counts(k); + + kmeans::detail::batch_calculate_distances(verbose, 0, rows_per_gpu, + cols, k, *data[i], d_centroids, *data_dots[i], centroid_dots, + [&](int rows_per_run, size_t offset, thrust::device_vector &pairwise_distances) { + auto counting = thrust::make_counting_iterator(0); + auto counts_ptr = thrust::raw_pointer_cast(counts.data()); + auto pairwise_distances_ptr = thrust::raw_pointer_cast(pairwise_distances.data()); + thrust::for_each(counting, + counting + rows_per_run, + count_functor(pairwise_distances_ptr, counts_ptr, k, rows_per_run) + ); + }); + + kmeans::detail::memcpy(weights_tmp, counts); + kmeans::detail::streamsync(i); + for (int p = 0; p < k; p++) { + weights[p] += weights_tmp[p]; + } + } +} diff --git a/cuML/src/kmeans/kmeans_c.h b/cuML/src/kmeans/kmeans_c.h new file mode 100644 index 0000000000..13d10d1599 --- /dev/null +++ b/cuML/src/kmeans/kmeans_c.h @@ -0,0 +1,76 @@ +/*! + * Copyright 2017-2018 H2O.ai, Inc. + * License Apache License Version 2.0 (see LICENSE for details) + */ + +#pragma once +#ifdef __JETBRAINS_IDE__ +#define __host__ +#define __device__ +#endif + +#include +#include +#include "timer.h" + +namespace h2o4gpukmeans { + +template +class H2O4GPUKMeans { +private: + // Data + const M *_A; + int _k; + int _n; + int _d; +public: + H2O4GPUKMeans(const M *A, int k, int n, int d); +}; + +template +class H2O4GPUKMeansCPU { +private: + // Data + const M *_A; + int _k; + int _n; + int _d; +public: + H2O4GPUKMeansCPU(const M *A, int k, int n, int d); +}; + +template +int makePtr_dense(int dopredict, int verbose, int seed, int gpu_id, int n_gpu, + size_t rows, size_t cols, const char ord, int k, int max_iterations, + int init_from_data, T threshold, const T *srcdata, const T *centroids, + T **pred_centroids, int **pred_labels); + +template +int kmeans_transform(int verbose, int gpu_id, int n_gpu, size_t m, size_t n, + const char ord, int k, const T *srcdata, const T *centroids, T **preds); + +} // namespace h2o4gpukmeans + +namespace ML { + +void make_ptr_kmeans(int dopredict, int verbose, int seed, int gpu_id, + int n_gpu, size_t mTrain, size_t n, const char ord, int k, int k_max, + int max_iterations, int init_from_data, float threshold, + const float *srcdata, const float *centroids, float *pred_centroids, + int *pred_labels); + +void make_ptr_kmeans(int dopredict, int verbose, int seed, int gpu_id, + int n_gpu, size_t mTrain, size_t n, const char ord, int k, int k_max, + int max_iterations, int init_from_data, double threshold, + const double *srcdata, const double *centroids, double *pred_centroids, + int *pred_labels); + +void kmeans_transform(int verbose, int gpu_id, int n_gpu, size_t m, size_t n, + const char ord, int k, const float *srcdata, const float *centroids, + float *preds); + +void kmeans_transform(int verbose, 
int gpu_id, int n_gpu, size_t m, size_t n, + const char ord, int k, const double *srcdata, const double *centroids, + double *preds); + +} diff --git a/cuML/src/kmeans/kmeans_centroids.h b/cuML/src/kmeans/kmeans_centroids.h new file mode 100644 index 0000000000..4fecdbdcaf --- /dev/null +++ b/cuML/src/kmeans/kmeans_centroids.h @@ -0,0 +1,187 @@ +/*! + * Modifications Copyright 2017-2018 H2O.ai, Inc. + */ +// original code from https://github.com/NVIDIA/kmeans (Apache V2.0 License) +#pragma once +#include +#include +#include "kmeans_labels.h" + +inline __device__ double my_atomic_add(double *address, double val) { + unsigned long long int *address_as_ull = (unsigned long long int *) address; + unsigned long long int old = *address_as_ull, assumed; + do { + assumed = old; + old = atomicCAS(address_as_ull, assumed, + __double_as_longlong(val + __longlong_as_double(assumed))); + } while (assumed != old); + return __longlong_as_double(old); +} + +inline __device__ float my_atomic_add(float *address, float val) { + return (atomicAdd(address, val)); +} + +namespace kmeans { +namespace detail { + +template +__device__ __forceinline__ +void update_centroid(int label, int dimension, int d, T accumulator, + T *centroids, int count, int *counts) { + int index = label * d + dimension; + T *target = centroids + index; + my_atomic_add(target, accumulator); + if (dimension == 0) { + atomicAdd(counts + label, count); + } +} + +template +__global__ void calculate_centroids(int n, int d, int k, T *data, + int *ordered_labels, int *ordered_indices, T *centroids, int *counts) { + int in_flight = blockDim.y * gridDim.y; + int labels_per_row = (n - 1) / in_flight + 1; + for (int dimension = threadIdx.x; dimension < d; dimension += blockDim.x) { + T accumulator = 0; + int count = 0; + int global_id = threadIdx.y + blockIdx.y * blockDim.y; + int start = global_id * labels_per_row; + int end = (global_id + 1) * labels_per_row; + end = (end > n) ? n : end; + int prior_label; + if (start < n) { + prior_label = ordered_labels[start]; + + for (int label_number = start; label_number < end; label_number++) { + int label = ordered_labels[label_number]; + if (label != prior_label) { + update_centroid(prior_label, dimension, d, accumulator, + centroids, count, counts); + accumulator = 0; + count = 0; + } + + T value = data[dimension + ordered_indices[label_number] * d]; + accumulator += value; + prior_label = label; + count++; + } + update_centroid(prior_label, dimension, d, accumulator, centroids, + count, counts); + } + } +} + +template +__global__ void revert_zeroed_centroids(int d, int k, T *tmp_centroids, + T *centroids, int *counts) { + int global_id_x = threadIdx.x + blockIdx.x * blockDim.x; + int global_id_y = threadIdx.y + blockIdx.y * blockDim.y; + if ((global_id_x < d) && (global_id_y < k)) { + if (counts[global_id_y] < 1) { + centroids[global_id_x + d * global_id_y] = tmp_centroids[global_id_x + + d * global_id_y]; + } + } +} + +template +__global__ void scale_centroids(int d, int k, int *counts, T *centroids) { + int global_id_x = threadIdx.x + blockIdx.x * blockDim.x; + int global_id_y = threadIdx.y + blockIdx.y * blockDim.y; + if ((global_id_x < d) && (global_id_y < k)) { + int count = counts[global_id_y]; + //To avoid introducing divide by zero errors + //If a centroid has no weight, we'll do no normalization + //This will keep its coordinates defined. 
+ if (count < 1) { + count = 1; + } + T scale = 1.0 / T(count); + centroids[global_id_x + d * global_id_y] *= scale; + } +} + +// Scale - should be true when running on a single GPU +template +void find_centroids(int q, int n, int d, int k, int k_max, + thrust::device_vector &data, thrust::device_vector &labels, + thrust::device_vector ¢roids, thrust::device_vector &range, + thrust::device_vector &indices, thrust::device_vector &counts, + bool scale) { + int dev_num; + cudaGetDevice(&dev_num); + + // If no scaling then this step will be handled on the host + // when aggregating centroids from all GPUs + // Cache original centroids in case some centroids are not present in labels + // and would get zeroed + thrust::device_vector tmp_centroids; + if (scale) { + tmp_centroids = thrust::device_vector(k_max * d); + memcpy(tmp_centroids, centroids); + } + + memcpy(indices, range); + // TODO the rest of the algorithm doesn't necessarily require labels/data to be sorted + // but *might* make if faster due to less atomic updates + thrust::sort_by_key(labels.begin(), labels.end(), indices.begin()); + // TODO cub is faster but sort_by_key_int isn't sorting, possibly a bug +// mycub::sort_by_key_int(labels, indices); + +#if(CHECK) + gpuErrchk(cudaGetLastError()); +#endif + + // Need to zero this - the algo uses this array to accumulate values for each centroid + // which are then averaged to get a new centroid + memzero(centroids); + memzero(counts); + + //Calculate centroids + int n_threads_x = 64; // TODO FIXME + int n_threads_y = 16; // TODO FIXME + //XXX Number of blocks here is hard coded at 30 + //This should be taken care of more thoughtfully. + calculate_centroids<<>>(n, d, k, thrust::raw_pointer_cast(data.data()), + thrust::raw_pointer_cast(labels.data()), + thrust::raw_pointer_cast(indices.data()), + thrust::raw_pointer_cast(centroids.data()), + thrust::raw_pointer_cast(counts.data())); + +#if(CHECK) + gpuErrchk(cudaGetLastError()); +#endif + + // Scaling should take place on the GPU if n_gpus=1 so we don't + // move centroids and counts between host and device all the time for nothing + if (scale) { + // Revert only if `scale`, otherwise this will be taken care of in the host + // Revert reverts centroids for which count is equal 0 + revert_zeroed_centroids<<>>(d, k, + thrust::raw_pointer_cast(tmp_centroids.data()), + thrust::raw_pointer_cast(centroids.data()), + thrust::raw_pointer_cast(counts.data())); + +#if(CHECK) + gpuErrchk(cudaGetLastError()); +#endif + + //Averages the centroids + scale_centroids<<>>(d, k, + thrust::raw_pointer_cast(counts.data()), + thrust::raw_pointer_cast(centroids.data())); +#if(CHECK) + gpuErrchk(cudaGetLastError()); +#endif + } +} + +} +} diff --git a/cuML/src/kmeans/kmeans_general.h b/cuML/src/kmeans/kmeans_general.h new file mode 100644 index 0000000000..4c5a558338 --- /dev/null +++ b/cuML/src/kmeans/kmeans_general.h @@ -0,0 +1,26 @@ +/*! + * Copyright 2017-2018 H2O.ai, Inc. + * License Apache License Version 2.0 (see LICENSE for details) + */ +#pragma once +#include "logger.h" +#define MAX_NGPUS 16 + +#define VERBOSE 0 +#define CHECK 1 +#define DEBUGKMEANS 0 + +// TODO(pseudotensor): Avoid throw for python exception handling. Need to avoid all exit's and return exit code all the way back. 
+#define gpuErrchk(ans) { gpu_assert((ans), __FILE__, __LINE__); } +#define safe_cuda(ans) throw_on_cuda_error((ans), __FILE__, __LINE__); +#define safe_cublas(ans) throw_on_cublas_error((ans), __FILE__, __LINE__); + +#define CUDACHECK(cmd) do { \ + cudaError_t e = cmd; \ + if( e != cudaSuccess ) { \ + printf("Cuda failure %s:%d '%s'\n", \ + __FILE__,__LINE__,cudaGetErrorString(e));\ + fflush( stdout ); \ + exit(EXIT_FAILURE); \ + } \ + } while(0) diff --git a/cuML/src/kmeans/kmeans_impl.h b/cuML/src/kmeans/kmeans_impl.h new file mode 100644 index 0000000000..9b3b943188 --- /dev/null +++ b/cuML/src/kmeans/kmeans_impl.h @@ -0,0 +1,256 @@ +/*! + * Modifications Copyright 2017-2018 H2O.ai, Inc. + */ +// original code from https://github.com/NVIDIA/kmeans (Apache V2.0 License) +#pragma once +#include +#include +#include +#include +#include +#include +#include +#include "kmeans_centroids.h" +#include "kmeans_labels.h" +#include "kmeans_general.h" + +namespace kmeans { + +//! kmeans clusters data into k groups +/*! + + \param n Number of data points + \param d Number of dimensions + \param k Number of clusters + \param data Data points, in row-major order. This vector must have + size n * d, and since it's in row-major order, data point x occupies + positions [x * d, (x + 1) * d) in the vector. The vector is passed + by reference since it is shared with the caller and not copied. + \param labels Cluster labels. This vector has size n. + The vector is passed by reference since it is shared with the caller + and not copied. + \param centroids Centroid locations, in row-major order. This + vector must have size k * d, and since it's in row-major order, + centroid x occupies positions [x * d, (x + 1) * d) in the + vector. The vector is passed by reference since it is shared + with the caller and not copied. + \param threshold This controls early termination of the kmeans + iterations. If the ratio of points being reassigned to a different + centroid is less than the threshold, than the iterations are + terminated. Defaults to 1e-3. + \param max_iterations Maximum number of iterations to run + \return The number of iterations actually performed. 
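+
+   A minimal single-GPU invocation, as a sketch only: the sizes below are
+   illustrative, and data, labels, centroids and data_dots are assumed to be
+   per-GPU arrays of device vectors allocated and filled as in kmeans_fit()
+   in kmeans.cu.
+
+     std::vector<int> dList = {0};
+     volatile std::atomic_int flag(0);
+     int iters = kmeans::kmeans<float>(0 /*verbose*/, &flag, n, d, k, k_max,
+                                       data, labels, centroids, data_dots,
+                                       dList, 1 /*n_gpu*/, 300 /*max_iter*/,
+                                       1e-3 /*threshold*/, true);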
+ */ + +template +int kmeans(int verbose, volatile std::atomic_int *flag, int n, int d, int k, + int k_max, thrust::device_vector **data, + thrust::device_vector **labels, + thrust::device_vector **centroids, + thrust::device_vector **data_dots, std::vector dList, int n_gpu, + int max_iterations, double threshold = 1e-3, bool do_per_iter_check = + true) { + + thrust::device_vector *centroid_dots[n_gpu]; + thrust::device_vector *labels_copy[n_gpu]; + thrust::device_vector *range[n_gpu]; + thrust::device_vector *indices[n_gpu]; + thrust::device_vector *counts[n_gpu]; + thrust::device_vector d_old_centroids; + + thrust::host_vector h_counts(k); + thrust::host_vector h_counts_tmp(k); + thrust::host_vector h_centroids(k * d); + h_centroids = *centroids[0]; // all should be equal + thrust::host_vector h_centroids_tmp(k_max * d); + + T *d_distance_sum[n_gpu]; + + bool unable_alloc = false; +#pragma omp parallel for + for (int q = 0; q < n_gpu; q++) { + log_debug(verbose, "Before kmeans() Allocation: gpu: %d", q); + + safe_cuda(cudaSetDevice(dList[q])); + safe_cuda(cudaMalloc(&d_distance_sum[q], sizeof(T))); + + try { + centroid_dots[q] = new thrust::device_vector(k); + labels_copy[q] = new thrust::device_vector(n / n_gpu); + range[q] = new thrust::device_vector(n / n_gpu); + counts[q] = new thrust::device_vector(k); + indices[q] = new thrust::device_vector(n / n_gpu); + } catch (thrust::system_error &e) { + log_error(verbose, + "Unable to allocate memory for gpu: %d | n/n_gpu: %d | k: %d | d: %d | error: %s", + q, n / n_gpu, k, d, e.what()); + unable_alloc = true; + // throw std::runtime_error(ss.str()); + } catch (std::bad_alloc &e) { + log_error(verbose, + "Unable to allocate memory for gpu: %d | n/n_gpu: %d | k: %d | d: %d | error: %s", + q, n / n_gpu, k, d, e.what()); + unable_alloc = true; + //throw std::runtime_error(ss.str()); + } + + if (!unable_alloc) { + //Create and save "range" for initializing labels + thrust::copy(thrust::counting_iterator(0), + thrust::counting_iterator(n / n_gpu), + (*range[q]).begin()); + } + } + + if (unable_alloc) + return (-1); + + log_debug(verbose, "Before kmeans() Iterations"); + + int i = 0; + bool done = false; + for (; i < max_iterations; i++) { + log_verbose(verbose, "KMeans - Iteration %d/%d", i, max_iterations); + + if (*flag) + continue; + + safe_cuda(cudaSetDevice(dList[0])); + d_old_centroids = *centroids[dList[0]]; + +#pragma omp parallel for + for (int q = 0; q < n_gpu; q++) { + safe_cuda(cudaSetDevice(dList[q])); + + detail::batch_calculate_distances(verbose, q, n / n_gpu, d, k, + *data[q], *centroids[q], *data_dots[q], *centroid_dots[q], + [&](int n, size_t offset, thrust::device_vector &pairwise_distances) { + detail::relabel(n, k, pairwise_distances, *labels[q], offset); + }); + + log_verbose(verbose, "KMeans - Relabeled."); + + detail::memcpy(*labels_copy[q], *labels[q]); + detail::find_centroids(q, n / n_gpu, d, k, k_max, *data[q], + *labels_copy[q], *centroids[q], *range[q], *indices[q], + *counts[q], n_gpu <= 1); + } + + // Scale the centroids on host + if (n_gpu > 1) { + //Average the centroids from each device + for (int p = 0; p < k; p++) + h_counts[p] = 0.0; + for (int q = 0; q < n_gpu; q++) { + safe_cuda(cudaSetDevice(dList[q])); + detail::memcpy(h_counts_tmp, *counts[q]); + detail::streamsync(dList[q]); + for (int p = 0; p < k; p++) + h_counts[p] += h_counts_tmp[p]; + } + + // Zero the centroids only if any of the GPUs actually updated them + for (int p = 0; p < k; p++) { + for (int r = 0; r < d; r++) { + if (h_counts[p] != 0) { + 
h_centroids[p * d + r] = 0.0; + } + } + } + + for (int q = 0; q < n_gpu; q++) { + safe_cuda(cudaSetDevice(dList[q])); + detail::memcpy(h_centroids_tmp, *centroids[q]); + detail::streamsync(dList[q]); + for (int p = 0; p < k; p++) { + for (int r = 0; r < d; r++) { + if (h_counts[p] != 0) { + h_centroids[p * d + r] += + h_centroids_tmp[p * d + r]; + } + } + } + } + + for (int p = 0; p < k; p++) { + for (int r = 0; r < d; r++) { + // If 0 counts that means we leave the original centroids + if (h_counts[p] == 0) { + h_counts[p] = 1; + } + h_centroids[p * d + r] /= h_counts[p]; + } + } + + //Copy the averaged centroids to each device +#pragma omp parallel for + for (int q = 0; q < n_gpu; q++) { + safe_cuda(cudaSetDevice(dList[q])); + detail::memcpy(*centroids[q], h_centroids); + } + } + + // whether to perform per iteration check + if (do_per_iter_check) { + safe_cuda(cudaSetDevice(dList[0])); + + T + squared_norm = thrust::inner_product( + d_old_centroids.begin(), d_old_centroids.end(), + (*centroids[0]).begin(), + (T) 0.0, + thrust::plus(), + [=]__device__(T left, T right) { + T diff = left - right; + return diff * diff; + } + ); + + if (squared_norm < threshold) { + if (verbose) { + std::cout << "Threshold triggered. Terminating early." + << std::endl; + } + done = true; + } + } + + if (*flag) { + fprintf(stderr, "Signal caught. Terminated early.\n"); + fflush(stderr); + *flag = 0; // set flag + done = true; + } + + if (done || i == max_iterations - 1) { + // Final relabeling - uses final centroids +#pragma omp parallel for + for (int q = 0; q < n_gpu; q++) { + safe_cuda(cudaSetDevice(dList[q])); + detail::batch_calculate_distances(verbose, q, n / n_gpu, d, k, + *data[q], *centroids[q], *data_dots[q], + *centroid_dots[q], + [&](int n, size_t offset, thrust::device_vector &pairwise_distances) { + detail::relabel(n, k, pairwise_distances, *labels[q], offset); + }); + } + break; + } + } + +//#pragma omp parallel for + for (int q = 0; q < n_gpu; q++) { + safe_cuda(cudaSetDevice(dList[q])); + delete (centroid_dots[q]); + delete (labels_copy[q]); + delete (range[q]); + delete (counts[q]); + delete (indices[q]); + } + + log_debug(verbose, "Iterations: %d", i); + + return i; +} + +} diff --git a/cuML/src/kmeans/kmeans_labels.h b/cuML/src/kmeans/kmeans_labels.h new file mode 100644 index 0000000000..0a829ed76e --- /dev/null +++ b/cuML/src/kmeans/kmeans_labels.h @@ -0,0 +1,906 @@ +/*! + * Modifications Copyright 2017-2018 H2O.ai, Inc. 
+ */ +// original code from https://github.com/NVIDIA/kmeans (Apache V2.0 License) +#pragma once +#include +#include "cub/cub.cuh" +#include +#include +#include +#include +#include +#include "kmeans_general.h" +#include +#include + +inline void gpu_assert(cudaError_t code, const char *file, int line, + bool abort = true) { + if (code != cudaSuccess) { + fprintf(stderr, "GPUassert: %s %s %d\n", cudaGetErrorString(code), file, + line); + std::stringstream ss; + ss << file << "(" << line << ")"; + std::string file_and_line; + ss >> file_and_line; + thrust::system_error(code, thrust::cuda_category(), file_and_line); + } +} + +inline cudaError_t throw_on_cuda_error(cudaError_t code, const char *file, + int line) { + if (code != cudaSuccess) { + std::stringstream ss; + ss << file << "(" << line << ")"; + std::string file_and_line; + ss >> file_and_line; + thrust::system_error(code, thrust::cuda_category(), file_and_line); + } + + return code; +} + +#ifdef CUBLAS_API_H_ +// cuBLAS API errors +static const char *cudaGetErrorEnum(cublasStatus_t error) { + switch (error) { + case CUBLAS_STATUS_SUCCESS: + return "CUBLAS_STATUS_SUCCESS"; + + case CUBLAS_STATUS_NOT_INITIALIZED: + return "CUBLAS_STATUS_NOT_INITIALIZED"; + + case CUBLAS_STATUS_ALLOC_FAILED: + return "CUBLAS_STATUS_ALLOC_FAILED"; + + case CUBLAS_STATUS_INVALID_VALUE: + return "CUBLAS_STATUS_INVALID_VALUE"; + + case CUBLAS_STATUS_ARCH_MISMATCH: + return "CUBLAS_STATUS_ARCH_MISMATCH"; + + case CUBLAS_STATUS_MAPPING_ERROR: + return "CUBLAS_STATUS_MAPPING_ERROR"; + + case CUBLAS_STATUS_EXECUTION_FAILED: + return "CUBLAS_STATUS_EXECUTION_FAILED"; + + case CUBLAS_STATUS_INTERNAL_ERROR: + return "CUBLAS_STATUS_INTERNAL_ERROR"; + } + + return ""; +} +#endif + +inline cublasStatus_t throw_on_cublas_error(cublasStatus_t code, + const char *file, int line) { + + if (code != CUBLAS_STATUS_SUCCESS) { + fprintf(stderr, "cublas error: %s %s %d\n", cudaGetErrorEnum(code), + file, line); + std::stringstream ss; + ss << file << "(" << line << ")"; + std::string file_and_line; + ss >> file_and_line; + thrust::system_error(code, thrust::cuda_category(), file_and_line); + } + + return code; +} + +extern cudaStream_t cuda_stream[MAX_NGPUS]; + +template +extern __global__ void debugMark() { +} +; + +namespace kmeans { +namespace detail { + +void labels_init(); +void labels_close(); + +extern cublasHandle_t cublas_handle[MAX_NGPUS]; +template +void memcpy(thrust::host_vector > &H, + thrust::device_vector > &D) { + int dev_num; + safe_cuda(cudaGetDevice(&dev_num)); + safe_cuda( + cudaMemcpyAsync(thrust::raw_pointer_cast(H.data()), + thrust::raw_pointer_cast(D.data()), sizeof(T) * D.size(), + cudaMemcpyDeviceToHost, cuda_stream[dev_num])); +} + +template +void memcpy(thrust::device_vector > &D, + thrust::host_vector > &H) { + int dev_num; + safe_cuda(cudaGetDevice(&dev_num)); + safe_cuda( + cudaMemcpyAsync(thrust::raw_pointer_cast(D.data()), + thrust::raw_pointer_cast(H.data()), sizeof(T) * H.size(), + cudaMemcpyHostToDevice, cuda_stream[dev_num])); +} +template +void memcpy(thrust::device_vector > &Do, + thrust::device_vector > &Di) { + int dev_num; + safe_cuda(cudaGetDevice(&dev_num)); + safe_cuda( + cudaMemcpyAsync(thrust::raw_pointer_cast(Do.data()), + thrust::raw_pointer_cast(Di.data()), sizeof(T) * Di.size(), + cudaMemcpyDeviceToDevice, cuda_stream[dev_num])); +} +template +void memzero(thrust::device_vector >& D) { + int dev_num; + safe_cuda(cudaGetDevice(&dev_num)); + safe_cuda( + cudaMemsetAsync(thrust::raw_pointer_cast(D.data()), 0, + sizeof(T) * 
D.size(), cuda_stream[dev_num])); +} +void streamsync(int dev_num); + +//n: number of points +//d: dimensionality of points +//data: points, laid out in row-major order (n rows, d cols) +//dots: result vector (n rows) +// NOTE: +//Memory accesses in this function are uncoalesced!! +//This is because data is in row major order +//However, in k-means, it's called outside the optimization loop +//on the large data array, and inside the optimization loop it's +//called only on a small array, so it doesn't really matter. +//If this becomes a performance limiter, transpose the data somewhere +template +__global__ void self_dots(int n, int d, T* data, T* dots) { + T accumulator = 0; + int global_id = blockDim.x * blockIdx.x + threadIdx.x; + + if (global_id < n) { + for (int i = 0; i < d; i++) { + T value = data[i + global_id * d]; + accumulator += value * value; + } + dots[global_id] = accumulator; + } +} + +template +void make_self_dots(int n, int d, thrust::device_vector& data, + thrust::device_vector& dots) { + int dev_num; +#define MAX_BLOCK_THREADS0 256 + const int GRID_SIZE = (n - 1) / MAX_BLOCK_THREADS0 + 1; + safe_cuda(cudaGetDevice(&dev_num)); + self_dots<<>>(n, d, + thrust::raw_pointer_cast(data.data()), + thrust::raw_pointer_cast(dots.data())); +#if(CHECK) + gpuErrchk(cudaGetLastError()); +#endif + +} + +#define MAX_BLOCK_THREADS 32 +template +__global__ void all_dots(int n, int k, T* data_dots, T* centroid_dots, + T* dots) { + __shared__ T local_data_dots[MAX_BLOCK_THREADS]; + __shared__ T local_centroid_dots[MAX_BLOCK_THREADS]; + // if(threadIdx.x==0 && threadIdx.y==0 && blockIdx.x==0) printf("inside %d %d %d\n",threadIdx.x,blockIdx.x,blockDim.x); + + int data_index = threadIdx.x + blockIdx.x * blockDim.x; + if ((data_index < n) && (threadIdx.y == 0)) { + local_data_dots[threadIdx.x] = data_dots[data_index]; + } + + int centroid_index = threadIdx.x + blockIdx.y * blockDim.y; + if ((centroid_index < k) && (threadIdx.y == 1)) { + local_centroid_dots[threadIdx.x] = centroid_dots[centroid_index]; + } + + __syncthreads(); + + centroid_index = threadIdx.y + blockIdx.y * blockDim.y; + // printf("data_index=%d centroid_index=%d\n",data_index,centroid_index); + if ((data_index < n) && (centroid_index < k)) { + dots[data_index + centroid_index * n] = local_data_dots[threadIdx.x] + + local_centroid_dots[threadIdx.y]; + } +} + +template +void make_all_dots(int n, int k, size_t offset, + thrust::device_vector& data_dots, + thrust::device_vector& centroid_dots, + thrust::device_vector& dots) { + int dev_num; + safe_cuda(cudaGetDevice(&dev_num)); + const int BLOCK_THREADSX = MAX_BLOCK_THREADS; // BLOCK_THREADSX*BLOCK_THREADSY<=1024 on modern arch's (sm_61) + const int BLOCK_THREADSY = MAX_BLOCK_THREADS; + const int GRID_SIZEX = (n - 1) / BLOCK_THREADSX + 1; // on old arch's this has to be less than 2^16=65536 + const int GRID_SIZEY = (k - 1) / BLOCK_THREADSY + 1; // this has to be less than 2^16=65536 + // printf("pre all_dots: %d %d %d %d\n",GRID_SIZEX,GRID_SIZEY,BLOCK_THREADSX,BLOCK_THREADSY); fflush(stdout); + all_dots<<>>(n, + k, thrust::raw_pointer_cast(data_dots.data() + offset), + thrust::raw_pointer_cast(centroid_dots.data()), + thrust::raw_pointer_cast(dots.data())); +#if(CHECK) + gpuErrchk(cudaGetLastError()); +#endif +} +; + +#define WARP_SIZE 32 +#define BLOCK_SIZE 1024 +#define BSIZE_DIV_WSIZE (BLOCK_SIZE/WARP_SIZE) +#define IDX(i,j,lda) ((i)+(j)*(lda)) +template +__global__ void rejectByCosines(int k, int *accept, T *dists, + T *centroid_dots) { + + // Global indices + int gidx, gidy; 
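+    // Note: each thread walks candidate pairs with gidx < gidy. dists holds
+    // dot products between candidates; dividing by the lengths cached in
+    // centroid_dots turns them into cosine similarities, and candidate gidy
+    // is rejected (accept[gidy] = 0) when it is nearly parallel to an earlier
+    // candidate of comparable length.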
+ + // Lengths from cosine_dots + float lenA, lenB; + // Threshold + float thresh = 0.9; + + // Observation vector is determined by global y-index + gidy = threadIdx.y + blockIdx.y * blockDim.y; + while (gidy < k) { + // Get lengths from global memory, stored in centroid_dots + lenA = centroid_dots[gidy]; + + gidx = threadIdx.x + blockIdx.x * blockDim.x; + while (gidx < gidy) { + lenB = centroid_dots[gidx]; + if (lenA > 1e-8 && lenB > 1e-8) + dists[IDX(gidx, gidy, k)] /= lenA * lenB; + if (dists[IDX(gidx, gidy, k)] > thresh + && ((lenA < 2.0 * lenB) && (lenB < 2.0 * lenA))) + accept[gidy] = 0; + gidx += blockDim.x * gridDim.x; + } + // Move to another centroid + gidy += blockDim.y * gridDim.y; + } + +} + +template +void checkCosine(int d, int k, int *numChosen, thrust::device_vector &dists, + thrust::device_vector ¢roids, + thrust::device_vector ¢roid_dots) { + + dim3 blockDim, gridDim; + + int h_accept[k]; + + thrust::device_vector accept(k); + thrust::fill(accept.begin(), accept.begin() + k, 1); + //printf("after fill accept\n"); + // Divide dists by centroid lengths to get cosine matrix + blockDim.x = WARP_SIZE; + blockDim.y = BLOCK_SIZE / WARP_SIZE; + blockDim.z = 1; + gridDim.x = min((k + WARP_SIZE - 1) / WARP_SIZE, 65535); + gridDim.y = min((k + BSIZE_DIV_WSIZE - 1) / BSIZE_DIV_WSIZE, 65535); + gridDim.z = 1; + + rejectByCosines<<>>(k, + thrust::raw_pointer_cast(accept.data()), + thrust::raw_pointer_cast(dists.data()), + thrust::raw_pointer_cast(centroid_dots.data())); + +#if(CHECK) + gpuErrchk(cudaGetLastError()); +#endif + + *numChosen = thrust::reduce(accept.begin(), accept.begin() + k); + + CUDACHECK( + cudaMemcpy(h_accept, thrust::raw_pointer_cast(accept.data()), + k * sizeof(int), cudaMemcpyDeviceToHost)); + + int skipcopy = 1; + for (int z = 0; z < *numChosen; ++z) { + if (h_accept[z] == 0) + skipcopy = 0; + } + + if (!skipcopy && (*numChosen > 1 && *numChosen < k)) { + int i, j; + int candidate_map[d * (*numChosen)]; + j = 0; + for (i = 0; i < k; ++i) { + if (h_accept[i]) { + for (int m = 0; m < d; ++m) + candidate_map[j * d + m] = i * d + m; + j += 1; + } + } + + thrust::device_vector d_candidate_map(d * (*numChosen)); + CUDACHECK(cudaMemcpy(thrust::raw_pointer_cast(d_candidate_map.data()), candidate_map, d*(*numChosen)*sizeof(int), cudaMemcpyHostToDevice)) + ; + + thrust::device_vector cent_copy(dists); + + thrust::copy_n(centroids.begin(), d * k, cent_copy.begin()); + +#if(CHECK) + gpuErrchk(cudaGetLastError()); +#endif + + // Gather accepted centroid candidates into centroid memory + thrust::gather(d_candidate_map.begin(), + d_candidate_map.begin() + d * (*numChosen), cent_copy.begin(), + centroids.begin()); + } +} + +template +void calculate_distances(int verbose, int q, size_t n, int d, int k, + thrust::device_vector& data, size_t data_offset, + thrust::device_vector& centroids, + thrust::device_vector& data_dots, + thrust::device_vector& centroid_dots, + thrust::device_vector& pairwise_distances); + +template +void batch_calculate_distances(int verbose, int q, size_t n, int d, int k, + thrust::device_vector &data, thrust::device_vector ¢roids, + thrust::device_vector &data_dots, + thrust::device_vector ¢roid_dots, F functor) { + int fudges_size = 4; + double fudges[] = { 1.0, 0.75, 0.5, 0.25 }; + for (const double fudge : fudges) { + try { + // Get info about available memory + // This part of the algo can be very memory consuming + // We might need to batch it + size_t free_byte; + size_t total_byte; + CUDACHECK(cudaMemGetInfo(&free_byte, &total_byte)); + free_byte = 
free_byte * fudge; + + size_t required_byte = n * k * sizeof(T); + + size_t runs = std::ceil(required_byte / (double) free_byte); + + log_verbose(verbose, + "Batch calculate distance - Rows %ld | K %ld | Data size %d", + n, k, sizeof(T)); + + log_verbose(verbose, + "Batch calculate distance - Free memory %zu | Required memory %zu | Runs %d", + free_byte, required_byte, runs); + + size_t offset = 0; + size_t rows_per_run = n / runs; + thrust::device_vector pairwise_distances(rows_per_run * k); + + for (int run = 0; run < runs; run++) { + if (run + 1 == runs && n % rows_per_run != 0) { + rows_per_run = n % rows_per_run; + } + + thrust::fill_n(pairwise_distances.begin(), + pairwise_distances.size(), (T) 0.0); + + log_verbose(verbose, "Batch calculate distance - Allocated"); + + kmeans::detail::calculate_distances(verbose, 0, rows_per_run, d, + k, data, offset, centroids, data_dots, centroid_dots, + pairwise_distances); + + log_verbose(verbose, + "Batch calculate distance - Distances calculated"); + + functor(rows_per_run, offset, pairwise_distances); + + log_verbose(verbose, "Batch calculate distance - Functor ran"); + + offset += rows_per_run; + } + } catch (const std::bad_alloc& e) { + cudaGetLastError(); + if (fudges[fudges_size - 1] != fudge) { + log_warn(verbose, + "Batch calculate distance - Failed to allocate memory for pairwise distances - retrying."); + continue; + } else { + log_error(verbose, + "Batch calculate distance - Failed to allocate memory for pairwise distances - exiting."); + throw e; + } + } + + return; + } +} + +template +__global__ void make_new_labels(int n, int k, T* pairwise_distances, + int* labels) { + T min_distance = FLT_MAX; //std::numeric_limits::max(); // might be ok TODO FIXME + T min_idx = -1; + int global_id = threadIdx.x + blockIdx.x * blockDim.x; + if (global_id < n) { + for (int c = 0; c < k; c++) { + T distance = pairwise_distances[c * n + global_id]; + if (distance < min_distance) { + min_distance = distance; + min_idx = c; + } + } + labels[global_id] = min_idx; + } +} + +template +void relabel(int n, int k, thrust::device_vector& pairwise_distances, + thrust::device_vector& labels, size_t offset) { + int dev_num; + safe_cuda(cudaGetDevice(&dev_num)); +#define MAX_BLOCK_THREADS2 256 + const int GRID_SIZE = (n - 1) / MAX_BLOCK_THREADS2 + 1; + make_new_labels<<>>( + n, k, thrust::raw_pointer_cast(pairwise_distances.data()), + thrust::raw_pointer_cast(labels.data() + offset)); +#if(CHECK) + gpuErrchk(cudaGetLastError()); +#endif +} + +} +} +namespace mycub { + +extern void *d_key_alt_buf[MAX_NGPUS]; +extern unsigned int key_alt_buf_bytes[MAX_NGPUS]; +extern void *d_value_alt_buf[MAX_NGPUS]; +extern unsigned int value_alt_buf_bytes[MAX_NGPUS]; +extern void *d_temp_storage[MAX_NGPUS]; +extern size_t temp_storage_bytes[MAX_NGPUS]; +extern void *d_temp_storage2[MAX_NGPUS]; +extern size_t temp_storage_bytes2[MAX_NGPUS]; +extern bool cub_initted; + +void sort_by_key_int(thrust::device_vector& keys, + thrust::device_vector& values); + +template +void sort_by_key(thrust::device_vector& keys, + thrust::device_vector& values) { + int dev_num; + safe_cuda(cudaGetDevice(&dev_num)); + cudaStream_t this_stream = cuda_stream[dev_num]; + int SIZE = keys.size(); + if (key_alt_buf_bytes[dev_num] < sizeof(T) * SIZE) { + if (d_key_alt_buf[dev_num]) + safe_cuda(cudaFree(d_key_alt_buf[dev_num])); + safe_cuda(cudaMalloc(&d_key_alt_buf[dev_num], sizeof(T) * SIZE)); + key_alt_buf_bytes[dev_num] = sizeof(T) * SIZE; + std::cout << "Malloc key_alt_buf" << std::endl; + } + if 
(value_alt_buf_bytes[dev_num] < sizeof(U) * SIZE) { + if (d_value_alt_buf[dev_num]) + safe_cuda(cudaFree(d_value_alt_buf[dev_num])); + safe_cuda(cudaMalloc(&d_value_alt_buf[dev_num], sizeof(U) * SIZE)); + value_alt_buf_bytes[dev_num] = sizeof(U) * SIZE; + std::cout << "Malloc value_alt_buf" << std::endl; + } + cub::DoubleBuffer d_keys(thrust::raw_pointer_cast(keys.data()), + (T*) d_key_alt_buf[dev_num]); + cub::DoubleBuffer d_values(thrust::raw_pointer_cast(values.data()), + (U*) d_value_alt_buf[dev_num]); + cudaError_t err; + + // Determine temporary device storage requirements for sorting operation + //if (temp_storage_bytes[dev_num] == 0) { + void *d_temp; + size_t temp_bytes; + err = cub::DeviceRadixSort::SortPairs(d_temp_storage[dev_num], temp_bytes, + d_keys, d_values, SIZE, 0, sizeof(T) * 8, this_stream); + // Allocate temporary storage for sorting operation + safe_cuda(cudaMalloc(&d_temp, temp_bytes)); + d_temp_storage[dev_num] = d_temp; + temp_storage_bytes[dev_num] = temp_bytes; + std::cout << "Malloc temp_storage. " << temp_storage_bytes[dev_num] + << " bytes" << std::endl; + std::cout << "d_temp_storage[" << dev_num << "] = " + << d_temp_storage[dev_num] << std::endl; + if (err) { + std::cout << "Error " << err << " in SortPairs 1" << std::endl; + std::cout << cudaGetErrorString(err) << std::endl; + } + //} + // Run sorting operation + err = cub::DeviceRadixSort::SortPairs(d_temp, temp_bytes, d_keys, d_values, + SIZE, 0, sizeof(T) * 8, this_stream); + if (err) + std::cout << "Error in SortPairs 2" << std::endl; + //cub::DeviceRadixSort::SortPairs(d_temp_storage[dev_num], temp_storage_bytes[dev_num], d_keys, + // d_values, SIZE, 0, sizeof(T)*8, this_stream); + +} +template +void sum_reduce(thrust::device_vector& values, T* sum) { + int dev_num; + safe_cuda(cudaGetDevice(&dev_num)); + if (!d_temp_storage2[dev_num]) { + cub::DeviceReduce::Sum(d_temp_storage2[dev_num], + temp_storage_bytes2[dev_num], + thrust::raw_pointer_cast(values.data()), sum, values.size(), + cuda_stream[dev_num]); + // Allocate temporary storage for sorting operation + safe_cuda( + cudaMalloc(&d_temp_storage2[dev_num], + temp_storage_bytes2[dev_num])); + } + cub::DeviceReduce::Sum(d_temp_storage2[dev_num], + temp_storage_bytes2[dev_num], + thrust::raw_pointer_cast(values.data()), sum, values.size(), + cuda_stream[dev_num]); +} +void cub_init(); +void cub_close(); + +void cub_init(int dev); +void cub_close(int dev); +} + +namespace kmeans { +namespace detail { + +template +struct absolute_value { + __host__ __device__ + void operator()(T &x) const { + x = (x > 0 ? 
x : -x); + } +}; + +cublasHandle_t cublas_handle[MAX_NGPUS]; + +void labels_init() { + cublasStatus_t stat; + cudaError_t err; + int dev_num; + safe_cuda(cudaGetDevice(&dev_num)); + stat = cublasCreate(&detail::cublas_handle[dev_num]); + if (stat != CUBLAS_STATUS_SUCCESS) { + std::cout << "CUBLAS initialization failed" << std::endl; + exit(1); + } + err = safe_cuda(cudaStreamCreate(&cuda_stream[dev_num])) + ; + if (err != cudaSuccess) { + std::cout << "Stream creation failed" << std::endl; + + } + cublasSetStream(cublas_handle[dev_num], cuda_stream[dev_num]); + mycub::cub_init(dev_num); +} + +void labels_close() { + int dev_num; + safe_cuda(cudaGetDevice(&dev_num)); + safe_cublas(cublasDestroy(cublas_handle[dev_num])); + safe_cuda(cudaStreamDestroy(cuda_stream[dev_num])); + mycub::cub_close(dev_num); +} + +void streamsync(int dev_num) { + cudaStreamSynchronize(cuda_stream[dev_num]); +} + +/** + * Matrix multiplication: alpha * A^T * B + beta * C + * Optimized for tall and skinny matrices + * + * @tparam float_t + * @param A + * @param B + * @param C + * @param alpha + * @param beta + * @param n + * @param d + * @param k + * @param max_block_rows + * @return + */ +template +__global__ void matmul(const float_t *A, const float_t *B, float_t *C, + const float_t alpha, const float_t beta, int n, int d, int k, + int max_block_rows) { + + extern __shared__ __align__(sizeof(float_t)) unsigned char my_smem[]; + float_t *shared = reinterpret_cast(my_smem); + + float_t *s_A = shared; + float_t *s_B = shared + max_block_rows * d; + + for (int i = threadIdx.x; i < d * k; i += blockDim.x) { + s_B[i] = B[i]; + } + + size_t block_start_row_index = blockIdx.x * max_block_rows; + size_t block_rows = max_block_rows; + + if (blockIdx.x == gridDim.x - 1 && n % max_block_rows != 0) { + block_rows = n % max_block_rows; + } + + for (size_t i = threadIdx.x; i < d * block_rows; i += blockDim.x) { + s_A[i] = alpha * A[d * block_start_row_index + i]; + } + + __syncthreads(); + + float_t elem_c = 0; + + int col_c = threadIdx.x % k; + size_t abs_row_c = block_start_row_index + threadIdx.x / k; + int row_c = threadIdx.x / k; + + // Thread/Block combination either too far for data array + // Or is calculating for index that should be calculated in a different blocks - in some edge cases + // "col_c * n + abs_row_c" can yield same result in different thread/block combinations + if (abs_row_c >= n || threadIdx.x >= block_rows * k) { + return; + } + + for (size_t i = 0; i < d; i++) { + elem_c += s_B[d * col_c + i] * s_A[d * row_c + i]; + } + + C[col_c * n + abs_row_c] = beta * C[col_c * n + abs_row_c] + elem_c; + +} + +template<> +void calculate_distances(int verbose, int q, size_t n, int d, int k, + thrust::device_vector &data, size_t data_offset, + thrust::device_vector ¢roids, + thrust::device_vector &data_dots, + thrust::device_vector ¢roid_dots, + thrust::device_vector &pairwise_distances) { + detail::make_self_dots(k, d, centroids, centroid_dots); + detail::make_all_dots(n, k, data_offset, data_dots, centroid_dots, + pairwise_distances); + + //||x-y||^2 = ||x||^2 + ||y||^2 - 2 x . 
y + //pairwise_distances has ||x||^2 + ||y||^2, so beta = 1 + //The dgemm calculates x.y for all x and y, so alpha = -2.0 + double alpha = -2.0; + double beta = 1.0; + //If the data were in standard column major order, we'd do a + //centroids * data ^ T + //But the data is in row major order, so we have to permute + //the arguments a little + int dev_num; + safe_cuda(cudaGetDevice(&dev_num)); + + bool do_cublas = true; + if (k <= 16 && d <= 64) { + const int BLOCK_SIZE_MUL = 128; + int block_rows = std::min((size_t) BLOCK_SIZE_MUL / k, n); + int grid_size = std::ceil(static_cast(n) / block_rows); + + int shared_size_B = d * k * sizeof(double); + size_t shared_size_A = block_rows * d * sizeof(double); + if (shared_size_B + shared_size_A < (1 << 15)) { + + matmul<<>>( + thrust::raw_pointer_cast(data.data() + data_offset * d), + thrust::raw_pointer_cast(centroids.data()), + thrust::raw_pointer_cast(pairwise_distances.data()), alpha, + beta, n, d, k, block_rows); + do_cublas = false; + } + } + + if (do_cublas) { + cublasStatus_t stat = + safe_cublas( + cublasDgemm(detail::cublas_handle[dev_num], CUBLAS_OP_T, CUBLAS_OP_N, n, k, d, &alpha, thrust::raw_pointer_cast(data.data() + data_offset * d), d, //Has to be n or d + thrust::raw_pointer_cast(centroids.data()), d,//Has to be k or d + &beta, thrust::raw_pointer_cast(pairwise_distances.data()), n)) + ; //Has to be n or k + + if (stat != CUBLAS_STATUS_SUCCESS) { + std::cout << "Invalid Dgemm" << std::endl; + exit(1); + } + } + + thrust::for_each(pairwise_distances.begin(), pairwise_distances.end(), + absolute_value()); // in-place transformation to ensure all distances are positive indefinite + +#if(CHECK) + gpuErrchk(cudaGetLastError()); +#endif +} + +template<> +void calculate_distances(int verbose, int q, size_t n, int d, int k, + thrust::device_vector &data, size_t data_offset, + thrust::device_vector ¢roids, + thrust::device_vector &data_dots, + thrust::device_vector ¢roid_dots, + thrust::device_vector &pairwise_distances) { + detail::make_self_dots(k, d, centroids, centroid_dots); + detail::make_all_dots(n, k, data_offset, data_dots, centroid_dots, + pairwise_distances); + + //||x-y||^2 = ||x||^2 + ||y||^2 - 2 x . 
y + //pairwise_distances has ||x||^2 + ||y||^2, so beta = 1 + //The dgemm calculates x.y for all x and y, so alpha = -2.0 + float alpha = -2.0; + float beta = 1.0; + //If the data were in standard column major order, we'd do a + //centroids * data ^ T + //But the data is in row major order, so we have to permute + //the arguments a little + int dev_num; + safe_cuda(cudaGetDevice(&dev_num)); + + if (k <= 16 && d <= 64) { + const int BLOCK_SIZE_MUL = 128; + int block_rows = std::min((size_t) BLOCK_SIZE_MUL / k, n); + int grid_size = std::ceil(static_cast(n) / block_rows); + + int shared_size_B = d * k * sizeof(float); + int shared_size_A = block_rows * d * sizeof(float); + + matmul<<>>( + thrust::raw_pointer_cast(data.data() + data_offset * d), + thrust::raw_pointer_cast(centroids.data()), + thrust::raw_pointer_cast(pairwise_distances.data()), alpha, + beta, n, d, k, block_rows); + } else { + cublasStatus_t stat = + safe_cublas( + cublasSgemm(detail::cublas_handle[dev_num], CUBLAS_OP_T, CUBLAS_OP_N, n, k, d, &alpha, thrust::raw_pointer_cast(data.data() + data_offset * d), d, //Has to be n or d + thrust::raw_pointer_cast(centroids.data()), d,//Has to be k or d + &beta, thrust::raw_pointer_cast(pairwise_distances.data()), n)) + ; //Has to be n or k + + if (stat != CUBLAS_STATUS_SUCCESS) { + std::cout << "Invalid Sgemm" << std::endl; + exit(1); + } + } + + thrust::for_each(pairwise_distances.begin(), pairwise_distances.end(), + absolute_value()); // in-place transformation to ensure all distances are positive indefinite + +#if(CHECK) + gpuErrchk(cudaGetLastError()); +#endif +} + +} +} + +namespace mycub { + +void *d_key_alt_buf[MAX_NGPUS]; +unsigned int key_alt_buf_bytes[MAX_NGPUS]; +void *d_value_alt_buf[MAX_NGPUS]; +unsigned int value_alt_buf_bytes[MAX_NGPUS]; +void *d_temp_storage[MAX_NGPUS]; +size_t temp_storage_bytes[MAX_NGPUS]; +void *d_temp_storage2[MAX_NGPUS]; +size_t temp_storage_bytes2[MAX_NGPUS]; +bool cub_initted; +void cub_init() { + // std::cout <<"CUB init" << std::endl; + for (int q = 0; q < MAX_NGPUS; q++) { + d_key_alt_buf[q] = NULL; + key_alt_buf_bytes[q] = 0; + d_value_alt_buf[q] = NULL; + value_alt_buf_bytes[q] = 0; + d_temp_storage[q] = NULL; + temp_storage_bytes[q] = 0; + d_temp_storage2[q] = NULL; + temp_storage_bytes2[q] = 0; + } + cub_initted = true; +} + +void cub_init(int dev) { + d_key_alt_buf[dev] = NULL; + key_alt_buf_bytes[dev] = 0; + d_value_alt_buf[dev] = NULL; + value_alt_buf_bytes[dev] = 0; + d_temp_storage[dev] = NULL; + temp_storage_bytes[dev] = 0; + d_temp_storage2[dev] = NULL; + temp_storage_bytes2[dev] = 0; +} + +void cub_close() { + for (int q = 0; q < MAX_NGPUS; q++) { + if (d_key_alt_buf[q]) + safe_cuda(cudaFree(d_key_alt_buf[q])); + if (d_value_alt_buf[q]) + safe_cuda(cudaFree(d_value_alt_buf[q])); + if (d_temp_storage[q]) + safe_cuda(cudaFree(d_temp_storage[q])); + if (d_temp_storage2[q]) + safe_cuda(cudaFree(d_temp_storage2[q])); + d_temp_storage[q] = NULL; + d_temp_storage2[q] = NULL; + } + cub_initted = false; +} + +void cub_close(int dev) { + if (d_key_alt_buf[dev]) + safe_cuda(cudaFree(d_key_alt_buf[dev])); + if (d_value_alt_buf[dev]) + safe_cuda(cudaFree(d_value_alt_buf[dev])); + if (d_temp_storage[dev]) + safe_cuda(cudaFree(d_temp_storage[dev])); + if (d_temp_storage2[dev]) + safe_cuda(cudaFree(d_temp_storage2[dev])); + d_temp_storage[dev] = NULL; + d_temp_storage2[dev] = NULL; +} + +void sort_by_key_int(thrust::device_vector &keys, + thrust::device_vector &values) { + int dev_num; + safe_cuda(cudaGetDevice(&dev_num)); + cudaStream_t 
this_stream = cuda_stream[dev_num]; + int SIZE = keys.size(); + //int *d_key_alt_buf, *d_value_alt_buf; + if (key_alt_buf_bytes[dev_num] < sizeof(int) * SIZE) { + if (d_key_alt_buf[dev_num]) + safe_cuda(cudaFree(d_key_alt_buf[dev_num])); + safe_cuda(cudaMalloc(&d_key_alt_buf[dev_num], sizeof(int) * SIZE)); + key_alt_buf_bytes[dev_num] = sizeof(int) * SIZE; + } + if (value_alt_buf_bytes[dev_num] < sizeof(int) * SIZE) { + if (d_value_alt_buf[dev_num]) + safe_cuda(cudaFree(d_value_alt_buf[dev_num])); + safe_cuda(cudaMalloc(&d_value_alt_buf[dev_num], sizeof(int) * SIZE)); + value_alt_buf_bytes[dev_num] = sizeof(int) * SIZE; + } + cub::DoubleBuffer d_keys(thrust::raw_pointer_cast(keys.data()), + (int *) d_key_alt_buf[dev_num]); + cub::DoubleBuffer d_values(thrust::raw_pointer_cast(values.data()), + (int *) d_value_alt_buf[dev_num]); + + // Determine temporary device storage requirements for sorting operation + if (!d_temp_storage[dev_num]) { + cub::DeviceRadixSort::SortPairs(d_temp_storage[dev_num], + temp_storage_bytes[dev_num], d_keys, d_values, SIZE, 0, + sizeof(int) * 8, this_stream); + // Allocate temporary storage for sorting operation + safe_cuda( + cudaMalloc(&d_temp_storage[dev_num], + temp_storage_bytes[dev_num])); + } + // Run sorting operation + cub::DeviceRadixSort::SortPairs(d_temp_storage[dev_num], + temp_storage_bytes[dev_num], d_keys, d_values, SIZE, 0, + sizeof(int) * 8, this_stream); + // Sorted keys and values are referenced by d_keys.Current() and d_values.Current() + + keys.data() = thrust::device_pointer_cast(d_keys.Current()); + values.data() = thrust::device_pointer_cast(d_values.Current()); +} + +} diff --git a/cuML/src/kmeans/logger.h b/cuML/src/kmeans/logger.h new file mode 100644 index 0000000000..75b6067aa6 --- /dev/null +++ b/cuML/src/kmeans/logger.h @@ -0,0 +1,51 @@ +/*! + * Copyright 2017-2018 H2O.ai, Inc. + * License Apache License Version 2.0 (see LICENSE for details) + */ +#pragma once + +#include +#include +#include +#include +#include + +#define H2O4GPU_LOG_NOTHING 0 // Fatals are errors terminating the program immediately +#define H2O4GPU_LOG_FATAL 100 // Fatals are errors terminating the program immediately +#define H2O4GPU_LOG_ERROR 200 // Errors are when the program may not exit +#define H2O4GPU_LOG_INFO 300 // Info +#define H2O4GPU_LOG_WARN 400 // Warns about unwanted, but not dangerous, state/behaviour +#define H2O4GPU_LOG_DEBUG 500 // Most basic debug information +#define H2O4GPU_LOG_VERBOSE 600 // Everything possible + +#define log_fatal(desired_level, ...) log(desired_level, H2O4GPU_LOG_FATAL, __FILE__, __LINE__, __VA_ARGS__) +#define log_error(desired_level, ...) log(desired_level, H2O4GPU_LOG_ERROR, __FILE__, __LINE__, __VA_ARGS__) +#define log_info(desired_level, ...) log(desired_level, H2O4GPU_LOG_INFO, __FILE__, __LINE__, __VA_ARGS__) +#define log_warn(desired_level, ...) log(desired_level, H2O4GPU_LOG_WARN, __FILE__, __LINE__, __VA_ARGS__) +#define log_debug(desired_level, ...) log(desired_level, H2O4GPU_LOG_DEBUG, __FILE__, __LINE__, __VA_ARGS__) +#define log_verbose(desired_level, ...) log(desired_level, H2O4GPU_LOG_VERBOSE, __FILE__, __LINE__, __VA_ARGS__) + +static const char *levels[] = { "NOTHING", "FATAL", "ERROR", "INFO", "WARN", + "DEBUG", "VERBOSE" }; + +bool should_log(const int desired_lvl, const int verbosity) { + return verbosity > H2O4GPU_LOG_NOTHING && verbosity <= desired_lvl; +} + +void log(int desired_level, int level, const char *file, int line, + const char *fmt, ...) 
{ + if (should_log(desired_level, level)) { + time_t now = time(NULL); + struct tm *local_time = localtime(&now); + + va_list args; + char buf[16]; + buf[strftime(buf, sizeof(buf), "%H:%M:%S", local_time)] = '\0'; + fprintf(stderr, "%s %-5s %s:%d: ", buf, levels[level / 100], file, + line); + va_start(args, fmt); + vfprintf(stderr, fmt, args); + va_end(args); + fprintf(stderr, "\n"); + } +} diff --git a/cuML/src/kmeans/timer.h b/cuML/src/kmeans/timer.h new file mode 100644 index 0000000000..14daabba64 --- /dev/null +++ b/cuML/src/kmeans/timer.h @@ -0,0 +1,18 @@ +/*! + * Modifications Copyright 2017-2018 H2O.ai, Inc. + */ +#ifndef TIMER_H_ +#define TIMER_H_ + +#include +#include + +template +T timer() { + struct timeval tv; + gettimeofday(&tv, NULL); + return static_cast(tv.tv_sec) + + static_cast(tv.tv_usec) * static_cast(1e-6); +} + +#endif // UTIL_H_ diff --git a/cuML/src/kmeans/utils.h b/cuML/src/kmeans/utils.h new file mode 100644 index 0000000000..a3b7b25b3d --- /dev/null +++ b/cuML/src/kmeans/utils.h @@ -0,0 +1,59 @@ +/*! + * Copyright 2017-2018 H2O.ai, Inc. + * License Apache License Version 2.0 (see LICENSE for details) + */ +#pragma once +#include +// #include "cblas.h" +#include "logger.h" + +template +void self_dot(std::vector array_in, int n, int dim, std::vector& dots) { + for (int pt = 0; pt < n; pt++) { + T sum = 0.0; + for (int i = 0; i < dim; i++) { + sum += array_in[pt * dim + i] * array_in[pt * dim + i]; + } + dots[pt] = sum; + } +} + +// void compute_distances(std::vector data_in, +// std::vector centroids_in, +// std::vector &pairwise_distances, +// int n, int dim, int k) { +// std::vector data_dots(n); +// std::vector centroid_dots(k); +// self_dot(data_in, n, dim, data_dots); +// self_dot(centroids_in, k, dim, centroid_dots); +// for (int nn=0; nn data_in, +// std::vector centroids_in, +// std::vector &pairwise_distances, +// int n, int dim, int k) { +// std::vector data_dots(n); +// std::vector centroid_dots(k); +// self_dot(data_in, n, dim, data_dots); +// self_dot(centroids_in, k, dim, centroid_dots); +// for (int nn=0; nn +#include +#include + +namespace ML { + +using namespace MLCommon; + +template +struct KmeansInputs { + int n_clusters; + T tol; + int n_row; + int n_col; +}; + +template +class KmeansTest: public ::testing::TestWithParam > { +protected: + void basicTest() { + params = ::testing::TestWithParam>::GetParam(); + int m = params.n_row; + int n = params.n_col; + int k = params.n_clusters; + + // make testdata on host + T h_srcdata[n * m] = + {1.0,1.0,3.0,4.0, 1.0,2.0,2.0,3.0}; + + // make space for outputs : pred_centroids, pred_labels + // and reference output : labels_ref + allocate(labels_fit, m); + allocate(labels_ref_fit, m); + allocate(pred_centroids, k * n); + allocate(centroids_ref, k * n); + + // make and assign reference output + int h_labels_ref_fit[m] = {1, 1, 0, 0}; + updateDevice(labels_ref_fit, h_labels_ref_fit, m); + + T h_centroids_ref[k * n] = {3.5,2.5, 1.0,1.5}; + updateDevice(centroids_ref, h_centroids_ref, k * n); + + // The actual kmeans api calls + // fit + make_ptr_kmeans(0, verbose, seed, gpu_id, n_gpu, m, n, + ord, k, k, max_iterations, + init_from_data, params.tol, h_srcdata, nullptr, pred_centroids, labels_fit); + } + + void SetUp() override { + basicTest(); + } + + void TearDown() override { + CUDA_CHECK(cudaFree(labels_fit)); + CUDA_CHECK(cudaFree(pred_centroids)); + CUDA_CHECK(cudaFree(labels_ref_fit)); + CUDA_CHECK(cudaFree(centroids_ref)); + + } + +protected: + KmeansInputs params; + T *d_srcdata; + int *labels_fit, 
*labels_ref_fit; + T *pred_centroids, *centroids_ref; + int verbose = 0; + int seed = 1; + int gpu_id = 0; + int n_gpu = -1; + char ord = 'c'; // here c means col order, NOT C (vs F) order + int max_iterations = 300; + int init_from_data = 0; +}; + +const std::vector > inputsf2 = { + { 2, 0.05f, 4, 2 }}; + +const std::vector > inputsd2 = { + { 2, 0.05, 4, 2 }}; + +typedef KmeansTest KmeansTestF; +TEST_P(KmeansTestF, Fit) { + ASSERT_TRUE( + devArrMatch(labels_ref_fit, labels_fit, params.n_row, + CompareApproxAbs(params.tol))); + ASSERT_TRUE( + devArrMatch(centroids_ref, pred_centroids, params.n_clusters * params.n_col, + CompareApproxAbs(params.tol))); +} + +typedef KmeansTest KmeansTestD; +TEST_P(KmeansTestD, Fit) { + ASSERT_TRUE( + devArrMatch(labels_ref_fit, labels_fit, params.n_row, + CompareApproxAbs(params.tol))); + ASSERT_TRUE( + devArrMatch(centroids_ref, pred_centroids, params.n_clusters * params.n_col, + CompareApproxAbs(params.tol))); +} + +INSTANTIATE_TEST_CASE_P(KmeansTests, KmeansTestF, ::testing::ValuesIn(inputsf2)); + +INSTANTIATE_TEST_CASE_P(KmeansTests, KmeansTestD, ::testing::ValuesIn(inputsd2)); + +} // end namespace ML diff --git a/python/cuML/cuml.pyx b/python/cuML/cuml.pyx index ba7944af3f..0bf5999450 100644 --- a/python/cuML/cuml.pyx +++ b/python/cuML/cuml.pyx @@ -2,3 +2,4 @@ include "pca/pca_wrapper.pyx" include "tsvd/tsvd_wrapper.pyx" include "dbscan/dbscan_wrapper.pyx" include "knn/knn_wrapper.py" +include "kmeans/kmeans_wrapper.pyx" diff --git a/python/cuML/kmeans/c_kmeans.pxd b/python/cuML/kmeans/c_kmeans.pxd new file mode 100644 index 0000000000..533efb9c71 --- /dev/null +++ b/python/cuML/kmeans/c_kmeans.pxd @@ -0,0 +1,59 @@ +import numpy as np +cimport numpy as np +np.import_array + +cdef extern from "kmeans/kmeans_c.h" namespace "ML": + + cdef void make_ptr_kmeans( + int dopredict, + int verbose, + int seed, + int gpu_id, + int n_gpu, + size_t mTrain, + size_t n, + const char ord, + int k, + int k_max, + int max_iterations, + int init_from_data, + float threshold, + const float *srcdata, + const float *centroids, + float *pred_centroids, + int *pred_labels + ) + + cdef void make_ptr_kmeans( + int dopredict, + int verbose, + int seed, + int gpu_id, + int n_gpu, + size_t mTrain, + size_t n, + const char ord, + int k, + int k_max, + int max_iterations, + int init_from_data, + double threshold, + const double *srcdata, + const double *centroids, + double *pred_centroids, + int *pred_labels + ) + + + cdef void kmeans_transform(int verbose, + int gpu_id, int n_gpu, + size_t m, size_t n, const char ord, int k, + const float *src_data, const float *centroids, + float *preds) + + cdef void kmeans_transform(int verbose, + int gpu_id, int n_gpu, + size_t m, size_t n, const char ord, int k, + const double *src_data, const double *centroids, + double *preds) + diff --git a/python/cuML/kmeans/kmeans_test.py b/python/cuML/kmeans/kmeans_test.py new file mode 100644 index 0000000000..fb373b2ebf --- /dev/null +++ b/python/cuML/kmeans/kmeans_test.py @@ -0,0 +1,55 @@ +from cuML import KMeans +import pygdf +import numpy as np +import pandas as pd +print("\n***********TESTING FOR FLOAT DATATYPE***********") + +#gdf_float = pygdf.DataFrame() +#gdf_float['x']=np.asarray([1.0,1.0,3.0,4.0],dtype=np.float32) +#gdf_float['y']=np.asarray([1.0,2.0,2.0,3.0],dtype=np.float32) + +def np2pygdf(df): + # convert numpy array to pygdf dataframe + df = pd.DataFrame({'fea%d'%i:df[:,i] for i in range(df.shape[1])}) + pdf = pygdf.DataFrame() + for c,column in enumerate(df): + pdf[str(c)] = df[column] 
+ return pdf + +y=np.asarray([[1.0,2.0],[1.0,4.0],[1.0,0.0],[4.0,2.0],[4.0,4.0],[4.0,0.0]],dtype=np.float32) +x=np2pygdf(y) +q=np.asarray([[0, 0], [4, 4]],dtype=np.float32) +p=np2pygdf(q) +a=np.asarray([[1.0, 1.0], [1.0, 2.0], [3.0, 2.0], [4.0, 3.0]],dtype=np.float32) +b=np2pygdf(a) +print("input:") +print(b) + +print("\nCalling fit") +kmeans_float = KMeans(n_clusters=2, n_gpu=1) +kmeans_float.fit(b) +print("labels:") +print(kmeans_float.labels_) +print("cluster_centers:") +print(kmeans_float.cluster_centers_) + +''' +print("\nCalling Predict") +print("labels:") +print(kmeans_float.predict(p)) +print("cluster_centers:") +print(kmeans_float.cluster_centers_) +''' + + +print("\nCalling fit_predict") +kmeans_float2 = KMeans(n_clusters=2, n_gpu=1) +print("labels:") +print(kmeans_float2.fit_predict(b)) +print("cluster_centers:") +print(kmeans_float2.cluster_centers_) + + +print("\nCalling transform") +print("\ntransform result:") +print(kmeans_float2.transform(b)) diff --git a/python/cuML/kmeans/kmeans_wrapper.pyx b/python/cuML/kmeans/kmeans_wrapper.pyx new file mode 100644 index 0000000000..724883cf74 --- /dev/null +++ b/python/cuML/kmeans/kmeans_wrapper.pyx @@ -0,0 +1,230 @@ +cimport c_kmeans +import numpy as np +from numba import cuda +import pygdf +from libcpp cimport bool +import ctypes +from libc.stdint cimport uintptr_t +from c_kmeans cimport * + + +class KMeans: + + def __init__(self, n_clusters=8, max_iter=300, tol=1e-4, verbose=0, random_state=1, precompute_distances='auto', init='kmeans++', n_init=1, algorithm='auto', n_gpu=1, gpu_id=0): + self.n_clusters = n_clusters + self.verbose = verbose + self.random_state = random_state + self.precompute_distances = precompute_distances + self.init = init + self.n_init = n_init + self.copy_x = None + self.n_jobs = None + self.algorithm = algorithm + self.max_iter = max_iter + self.tol = tol + self.labels_ = None + self.cluster_centers_ = None + self.n_gpu = n_gpu + self.gpu_id = gpu_id + + def _get_ctype_ptr(self, obj): + # The manner to access the pointers in the gdf's might change, so + # encapsulating access in the following 3 methods. They might also be + # part of future gdf versions. 
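+        # The returned value is the raw device address as a Python int; the
+        # calls below cast it to uintptr_t before passing it to the C++ kmeans
+        # entry points.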
+ return obj.device_ctypes_pointer.value + + def _get_column_ptr(self, obj): + return self._get_ctype_ptr(obj._column._data.to_gpu_array()) + + def _get_gdf_as_matrix_ptr(self, gdf): + c = gdf.as_gpu_matrix(order='C').shape + return self._get_ctype_ptr(gdf.as_gpu_matrix(order='C')) + + def fit(self, input_gdf): + x = [] + for col in input_gdf.columns: + x.append(input_gdf[col]._column.dtype) + break + + self.gdf_datatype = np.dtype(x[0]) + self.n_rows = len(input_gdf) + self.n_cols = len(input_gdf._cols) + + cdef np.ndarray[np.float32_t, ndim=2, mode = 'c', cast=True] host_ary = input_gdf.as_gpu_matrix(order='C').copy_to_host() + + self.labels_ = pygdf.Series(np.zeros(self.n_rows, dtype=np.int32)) + cdef uintptr_t labels_ptr = self._get_column_ptr(self.labels_) + + self.cluster_centers_ = cuda.to_device(np.zeros(self.n_clusters* self.n_cols, dtype=self.gdf_datatype)) + cdef uintptr_t cluster_centers_ptr = self._get_ctype_ptr(self.cluster_centers_) + + if self.gdf_datatype.type == np.float32: + c_kmeans.make_ptr_kmeans( + 0, # dopredict + self.verbose, # verbose + self.random_state, # seed + self.gpu_id, # gpu_id + self.n_gpu, # n_gpu + self.n_rows, # mTrain (rows) + self.n_cols, # n (cols) + 'r', # ord + self.n_clusters, # k + self.n_clusters, # k_max + self.max_iter, # max_iterations + 1, # init_from_data TODO: can use kmeans++ + self.tol, # threshold + host_ary.data, # srcdata + # ptr2, # srcdata + 0, # centroids + cluster_centers_ptr, # pred_centroids + # 0, # pred_centroids + labels_ptr) # pred_labels + else: + c_kmeans.make_ptr_kmeans( + 0, # dopredict + self.verbose, # verbose + self.random_state, # seed + self.gpu_id, # gpu_id + self.n_gpu, # n_gpu + self.n_rows, # mTrain (rows) + self.n_cols, # n (cols) + 'r', # ord + self.n_clusters, # k + self.n_clusters, # k_max + self.max_iter, # max_iterations + 1, # init_from_data TODO: can use kmeans++ + self.tol, # threshold + host_ary.data, # srcdata + 0, # centroids + cluster_centers_ptr, # pred_centroids + labels_ptr) # pred_labels + + cluster_centers_gdf = pygdf.DataFrame() + for i in range(0, self.n_cols): + cluster_centers_gdf[str(i)] = self.cluster_centers_[i:self.n_clusters*self.n_cols:self.n_cols] + self.cluster_centers_ = cluster_centers_gdf + + return self + + def fit_predict(self, input_gdf): + return self.fit(input_gdf).labels_ + + def predict(self, input_gdf): + x = [] + for col in input_gdf.columns: + x.append(input_gdf[col]._column.dtype) + break + + self.gdf_datatype = np.dtype(x[0]) + self.n_rows = len(input_gdf) + self.n_cols = len(input_gdf._cols) + + #cdef uintptr_t input_ptr = self._get_gdf_as_matrix_ptr(input_gdf) + cdef np.ndarray[np.float32_t, ndim=2, mode = 'c', cast=True] host_ary = input_gdf.as_gpu_matrix(order='C').copy_to_host() + self.labels_ = pygdf.Series(np.zeros(self.n_rows, dtype=np.int32)) + cdef uintptr_t labels_ptr = self._get_column_ptr(self.labels_) + + #pred_centers = pygdf.Series(np.zeros(self.n_clusters* self.n_cols, dtype=self.gdf_datatype)) + cdef uintptr_t cluster_centers_ptr = self._get_gdf_as_matrix_ptr(self.cluster_centers_) + + if self.gdf_datatype.type == np.float32: + c_kmeans.make_ptr_kmeans( + 1, # dopredict + self.verbose, # verbose + self.random_state, # seed + self.gpu_id, # gpu_id + self.n_gpu, # n_gpu + self.n_rows, # mTrain (rows) + self.n_cols, # n (cols) + 'r', # ord + self.n_clusters, # k + self.n_clusters, # k_max + self.max_iter, # max_iterations + 0, # init_from_data TODO: can use kmeans++ + self.tol, # threshold + # input_ptr, # srcdata + host_ary.data, # srcdata + # ptr2, 
# srcdata + cluster_centers_ptr, # centroids + 0, # pred_centroids + labels_ptr) # pred_labels + else: + c_kmeans.make_ptr_kmeans( + 1, # dopredict + self.verbose, # verbose + self.random_state, # seed + self.gpu_id, # gpu_id + self.n_gpu, # n_gpu + self.n_rows, # mTrain (rows) + self.n_cols, # n (cols) + 'r', # ord + self.n_clusters, # k + self.n_clusters, # k_max + self.max_iter, # max_iterations + 0, # init_from_data TODO: can use kmeans++ + self.tol, # threshold + host_ary.data, # srcdata + cluster_centers_ptr, # centroids + 0, # pred_centroids + labels_ptr) # pred_labels + + return self.labels_ + + + def transform(self, input_gdf): + x = [] + for col in input_gdf.columns: + x.append(input_gdf[col]._column.dtype) + break + + self.gdf_datatype = np.dtype(x[0]) + self.n_rows = len(input_gdf) + self.n_cols = len(input_gdf._cols) + + cdef np.ndarray[np.float32_t, ndim=2, mode = 'c', cast=True] host_ary = input_gdf.as_gpu_matrix(order='C').copy_to_host() + preds_data = cuda.to_device(np.zeros(self.n_clusters*self.n_rows, + dtype=self.gdf_datatype.type)) + + cdef uintptr_t preds_ptr = self._get_ctype_ptr(preds_data) + + + ary=np.array([1.0,1.5,3.5,2.5],dtype=np.float32) + dary=cuda.to_device(ary) + cdef uintptr_t ptr2 = dary.device_ctypes_pointer.value + cdef uintptr_t cluster_centers_ptr = self._get_gdf_as_matrix_ptr(self.cluster_centers_) + + if self.gdf_datatype.type == np.float32: + c_kmeans.kmeans_transform( + self.verbose, # verbose + self.gpu_id, # gpu_id + self.n_gpu, # n_gpu + self.n_rows, # mTrain (rows) + self.n_cols, # n (cols) + 'r', # ord + self.n_clusters, # k + host_ary.data, # srcdata + cluster_centers_ptr, # centroids + preds_ptr) # preds + + else: + c_kmeans.kmeans_transform( + self.verbose, # verbose + self.gpu_id, # gpu_id + self.n_gpu, # n_gpu + self.n_rows, # mTrain (rows) + self.n_cols, # n (cols) + 'r', # ord + self.n_clusters, # k + host_ary.data, # srcdata + cluster_centers_ptr, # centroids + preds_ptr) # preds + + preds_gdf = pygdf.DataFrame() + for i in range(0, self.n_clusters): + preds_gdf[str(i)] = preds_data[i*self.n_rows:(i+1)*self.n_rows] + + return preds_gdf + + + def fit_transform(self, input_gdf): + return self.fit(input_gdf).transform(input_gdf) diff --git a/setup.py b/setup.py index 73407ea930..2e5d56efa1 100644 --- a/setup.py +++ b/setup.py @@ -119,7 +119,7 @@ def build_extensions(self): ext = Extension('cuML', - sources=['cuML/src/pca/pca.cu', 'cuML/src/tsvd/tsvd.cu', 'cuML/src/dbscan/dbscan.cu', 'python/cuML/cuml.pyx'], + sources=['cuML/src/pca/pca.cu', 'cuML/src/tsvd/tsvd.cu', 'cuML/src/dbscan/dbscan.cu', 'cuML/src//kmeans/kmeans.cu', 'python/cuML/cuml.pyx'], depends=['cuML/src/tsvd/tsvd.cu'], library_dirs=[CUDA['lib64']], libraries=['cudart','cublas','cusolver'], From 4f463290833118376b1a4e38c664792ea4b4c54b Mon Sep 17 00:00:00 2001 From: Onur Yilmaz Date: Fri, 19 Oct 2018 13:41:16 -0400 Subject: [PATCH 02/10] readme update for kmeans and KNN. --- README.md | 25 ++++++++++++++----------- 1 file changed, 14 insertions(+), 11 deletions(-) diff --git a/README.md b/README.md index 0ed2cd2e05..fbc4178580 100644 --- a/README.md +++ b/README.md @@ -1,17 +1,18 @@ # cuML (v0.1 Alpha) -Machine learning is a fundamental capability of RAPIDS. cuML is a suite of libraries that implements a machine learning algorithms within the RAPIDS data science ecosystem. cuML enables data scientists, researchers, and software engineers to run traditional ML tasks on GPUs without going into the details of CUDA programming. 
+Machine learning is a fundamental capability of RAPIDS. cuML is a suite of libraries that implements a machine learning algorithms within the RAPIDS data science ecosystem. cuML enables data scientists, researchers, and software engineers to run traditional ML tasks on GPUs without going into the details of CUDA programming. The cuML repository contains: 1. ***python***: Python based GPU Dataframe (GDF) machine learning package that takes [cuDF](https://github.com/rapidsai/cudf-alpha) dataframes as input. cuML connects the data to C++/CUDA based cuML and ml-prims libraries without ever leaving GPU memory. -2. ***cuML***: C++/CUDA machine learning algorithms. This library currently includes the following five algorithms; +2. ***cuML***: C++/CUDA machine learning algorithms. This library currently includes the following six algorithms; a. Single GPU Truncated Singular Value Decomposition (tSVD), b. Single GPU Principal Component Analysis (PCA), c. Single GPU Density-based Spatial Clustering of Applications with Noise (DBSCAN), d. Single GPU Kalman Filtering, - e. Multi-GPU K-Means Clustering. + e. Multi-GPU K-Means Clustering, + f. Multi-GPU K-Nearest Neighbors (Uses [Faiss](https://github.com/facebookresearch/faiss)). 3. ***ml-prims***: Low level machine learning primitives used in cuML. ml-prims is comprised of the following components; a. Linear Algebra, @@ -22,19 +23,21 @@ The cuML repository contains: #### Available Algorithms for version 0.1alpha: -- Truncated Singular Value Decomposition (tSVD) +- Truncated Singular Value Decomposition (tSVD). -- Principal Component Analysis (PCA) +- Principal Component Analysis (PCA). -- Density-based spatial clustering of applications with noise (DBSCAN) +- Density-based spatial clustering of applications with noise (DBSCAN). -Upcoming algorithms for version 0.1: +- K-Means Clustering. + +- K-Nearest Neighbors (Requires [Faiss](https://github.com/facebookresearch/faiss) installation to use). -- K-Means Clustering +Upcoming algorithms for version 0.1: -- Kalman Filter +- Kalman Filter. -More ML algorithms in cuML and more ML primitives in ml-prims are being added currently. Example notebooks are provided in the python folder to test the functionality and performance of this v0.1 alpha version. Goals for future versions include more algorithms and multi-gpu versions of the algorithms and primitives. +More ML algorithms in cuML and more ML primitives in ml-prims are being added currently. Example notebooks are provided in the python folder to test the functionality and performance of this v0.1 alpha version. Goals for future versions include more algorithms and multi-gpu versions of the algorithms and primitives. The installation option provided currently consists on building from source. Upcoming versions will add `pip` and `conda` options, along docker containers. They will be available in the coming weeks. @@ -48,7 +51,7 @@ To use cuML, it must be cloned and built in an environment that already has the List of dependencies: 1. zlib -2. cmake (>= 3.8, version 3.11.4 is recommended and there are issues with version 3.12) +2. cmake for gtests (>= 3.8, version 3.11.4 is recommended and there are issues with version 3.12) 3. CUDA (>= 9.0) 4. Cython (>= 0.28) 5. gcc (>=5.4.0) From 472d2e088ade807066941b1c0b09a1e57aaa1c03 Mon Sep 17 00:00:00 2001 From: Onur Yilmaz Date: Fri, 19 Oct 2018 13:43:21 -0400 Subject: [PATCH 03/10] readme update. 
--- README.md | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/README.md b/README.md index fbc4178580..6f2bcbc6cd 100644 --- a/README.md +++ b/README.md @@ -7,12 +7,13 @@ The cuML repository contains: 1. ***python***: Python based GPU Dataframe (GDF) machine learning package that takes [cuDF](https://github.com/rapidsai/cudf-alpha) dataframes as input. cuML connects the data to C++/CUDA based cuML and ml-prims libraries without ever leaving GPU memory. 2. ***cuML***: C++/CUDA machine learning algorithms. This library currently includes the following six algorithms; - a. Single GPU Truncated Singular Value Decomposition (tSVD), - b. Single GPU Principal Component Analysis (PCA), - c. Single GPU Density-based Spatial Clustering of Applications with Noise (DBSCAN), - d. Single GPU Kalman Filtering, - e. Multi-GPU K-Means Clustering, - f. Multi-GPU K-Nearest Neighbors (Uses [Faiss](https://github.com/facebookresearch/faiss)). + +a) Single GPU Truncated Singular Value Decomposition (tSVD), +b) Single GPU Principal Component Analysis (PCA), +c) Single GPU Density-based Spatial Clustering of Applications with Noise (DBSCAN), +d) Single GPU Kalman Filtering, +e) Multi-GPU K-Means Clustering, +f) Multi-GPU K-Nearest Neighbors (Uses [Faiss](https://github.com/facebookresearch/faiss)). 3. ***ml-prims***: Low level machine learning primitives used in cuML. ml-prims is comprised of the following components; a. Linear Algebra, From e3bced2650ea2969bb34668631a2a703d905bea8 Mon Sep 17 00:00:00 2001 From: Onur Yilmaz Date: Fri, 19 Oct 2018 13:44:38 -0400 Subject: [PATCH 04/10] readme update --- README.md | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/README.md b/README.md index 6f2bcbc6cd..c5f5df3b7a 100644 --- a/README.md +++ b/README.md @@ -9,10 +9,15 @@ The cuML repository contains: 2. ***cuML***: C++/CUDA machine learning algorithms. This library currently includes the following six algorithms; a) Single GPU Truncated Singular Value Decomposition (tSVD), + b) Single GPU Principal Component Analysis (PCA), + c) Single GPU Density-based Spatial Clustering of Applications with Noise (DBSCAN), + d) Single GPU Kalman Filtering, + e) Multi-GPU K-Means Clustering, + f) Multi-GPU K-Nearest Neighbors (Uses [Faiss](https://github.com/facebookresearch/faiss)). 3. ***ml-prims***: Low level machine learning primitives used in cuML. ml-prims is comprised of the following components; From fbea425e91003eb969cdcca0ae9aad1459316b4e Mon Sep 17 00:00:00 2001 From: Onur Yilmaz Date: Fri, 19 Oct 2018 13:46:25 -0400 Subject: [PATCH 05/10] readme update --- README.md | 28 +++++++++++----------------- 1 file changed, 11 insertions(+), 17 deletions(-) diff --git a/README.md b/README.md index c5f5df3b7a..fc8fffc72d 100644 --- a/README.md +++ b/README.md @@ -7,25 +7,19 @@ The cuML repository contains: 1. ***python***: Python based GPU Dataframe (GDF) machine learning package that takes [cuDF](https://github.com/rapidsai/cudf-alpha) dataframes as input. cuML connects the data to C++/CUDA based cuML and ml-prims libraries without ever leaving GPU memory. 2. ***cuML***: C++/CUDA machine learning algorithms. 
This library currently includes the following six algorithms; - -a) Single GPU Truncated Singular Value Decomposition (tSVD), - -b) Single GPU Principal Component Analysis (PCA), - -c) Single GPU Density-based Spatial Clustering of Applications with Noise (DBSCAN), - -d) Single GPU Kalman Filtering, - -e) Multi-GPU K-Means Clustering, - -f) Multi-GPU K-Nearest Neighbors (Uses [Faiss](https://github.com/facebookresearch/faiss)). + a) Single GPU Truncated Singular Value Decomposition (tSVD), + b) Single GPU Principal Component Analysis (PCA), + c) Single GPU Density-based Spatial Clustering of Applications with Noise (DBSCAN), + d) Single GPU Kalman Filtering, + e) Multi-GPU K-Means Clustering, + f) Multi-GPU K-Nearest Neighbors (Uses [Faiss](https://github.com/facebookresearch/faiss)). 3. ***ml-prims***: Low level machine learning primitives used in cuML. ml-prims is comprised of the following components; - a. Linear Algebra, - b. Statistics, - c. Basic Matrix Operations, - d. Distance Functions, - e. Random Number Generation. + a) Linear Algebra, + b) Statistics, + c) Basic Matrix Operations, + d) Distance Functions, + e) Random Number Generation. #### Available Algorithms for version 0.1alpha: From 93f01a41c6db0039c9080bfee8280a13999558ac Mon Sep 17 00:00:00 2001 From: Devavret Makkar Date: Tue, 23 Oct 2018 15:06:21 -0700 Subject: [PATCH 06/10] Multi GPU now works in python. --- cuML/src/kmeans/kmeans.cu | 11 +++++++++++ python/cuML/kmeans/kmeans_test.py | 2 +- setup.py | 4 ++-- 3 files changed, 14 insertions(+), 3 deletions(-) diff --git a/cuML/src/kmeans/kmeans.cu b/cuML/src/kmeans/kmeans.cu index 5eac1c2eeb..877ad7d084 100644 --- a/cuML/src/kmeans/kmeans.cu +++ b/cuML/src/kmeans/kmeans.cu @@ -1322,6 +1322,8 @@ void make_ptr_kmeans(int dopredict, int verbose, int seed, int gpu_id, init_from_data, threshold, h_srcdata, h_centroids, &h_pred_centroids, &h_pred_labels); + cudaSetDevice(gpu_id); + if (dopredict == 0) { cudaMemcpy(pred_centroids, h_pred_centroids, k * n * sizeof(float), cudaMemcpyHostToDevice); @@ -1362,6 +1364,11 @@ void make_ptr_kmeans(int dopredict, int verbose, int seed, int gpu_id, max_iterations, init_from_data, threshold, h_srcdata, h_centroids, &h_pred_centroids, &h_pred_labels); + cudaSetDevice(gpu_id); + // int dev = -1; + // cudaGetDevice(&dev); + // printf("device: %d\n", dev); + if (dopredict == 0) { cudaMemcpy(pred_centroids, h_pred_centroids, k * n * sizeof(double), cudaMemcpyHostToDevice); @@ -1395,6 +1402,8 @@ void kmeans_transform(int verbose, int gpu_id, int n_gpu, size_t m, size_t n, h2o4gpukmeans::kmeans_transform(verbose, gpu_id, actual_n_gpu, m, n, ord, k, h_srcdata, h_centroids, &h_preds); + cudaSetDevice(gpu_id); + cudaMemcpy(preds, h_preds, m * k * sizeof(float), cudaMemcpyHostToDevice); //free(h_srcdata); @@ -1418,6 +1427,8 @@ void kmeans_transform(int verbose, int gpu_id, int n_gpu, size_t m, size_t n, h2o4gpukmeans::kmeans_transform(verbose, gpu_id, actual_n_gpu, m, n, ord, k, h_srcdata, h_centroids, &h_preds); + cudaSetDevice(gpu_id); + cudaMemcpy(preds, h_preds, m * k * sizeof(double), cudaMemcpyHostToDevice); //free(h_srcdata); diff --git a/python/cuML/kmeans/kmeans_test.py b/python/cuML/kmeans/kmeans_test.py index fb373b2ebf..5384651877 100644 --- a/python/cuML/kmeans/kmeans_test.py +++ b/python/cuML/kmeans/kmeans_test.py @@ -26,7 +26,7 @@ def np2pygdf(df): print(b) print("\nCalling fit") -kmeans_float = KMeans(n_clusters=2, n_gpu=1) +kmeans_float = KMeans(n_clusters=2, n_gpu=-1) kmeans_float.fit(b) print("labels:") 
print(kmeans_float.labels_) diff --git a/setup.py b/setup.py index 2e5d56efa1..47ee1e4b07 100644 --- a/setup.py +++ b/setup.py @@ -126,10 +126,10 @@ def build_extensions(self): language='c++', runtime_library_dirs=[CUDA['lib64']], # this syntax is specific to this build system - extra_compile_args={'gcc': ['-std=c++11'], + extra_compile_args={'gcc': ['-std=c++11','-fopenmp'], 'nvcc': ['-arch=sm_60', '--ptxas-options=-v', '-c', '--compiler-options', "'-fPIC'",'-std=c++11','--expt-extended-lambda']}, include_dirs = [numpy_include, CUDA['include'], 'cuML/src', 'cuML/external/ml-prims/src','cuML/external/ml-prims/external/cutlass', 'cuML/external/cutlass','cuML/external/ml-prims/external/cub'], - extra_link_args=["-std=c++11"]) + extra_link_args=["-std=c++11",'-fopenmp']) From 045b1b84eafe79a3aeb1eea6a0f8d6233e4405f7 Mon Sep 17 00:00:00 2001 From: Onur Yilmaz Date: Tue, 30 Oct 2018 12:50:17 -0400 Subject: [PATCH 07/10] Added the sphinx friendly docstrings to kmeans_wrapper.pyx --- python/cuML/kmeans/kmeans_wrapper.pyx | 108 ++++++++++++++++++++++++++ 1 file changed, 108 insertions(+) diff --git a/python/cuML/kmeans/kmeans_wrapper.pyx b/python/cuML/kmeans/kmeans_wrapper.pyx index 724883cf74..0ecea88365 100644 --- a/python/cuML/kmeans/kmeans_wrapper.pyx +++ b/python/cuML/kmeans/kmeans_wrapper.pyx @@ -10,6 +10,67 @@ from c_kmeans cimport * class KMeans: + """ + Create a DataFrame, fill it with data, and compute Kmeans: + + .. code-block:: python + + from cuML import KMeans + import pygdf + import numpy as np + import pandas as pd + + def np2pygdf(df): + # convert numpy array to pygdf dataframe + df = pd.DataFrame({'fea%d'%i:df[:,i] for i in range(df.shape[1])}) + pdf = pygdf.DataFrame() + for c,column in enumerate(df): + pdf[str(c)] = df[column] + return pdf + + + a = np.asarray([[1.0, 1.0], [1.0, 2.0], [3.0, 2.0], [4.0, 3.0]],dtype=np.float32) + b = np2pygdf(a) + print("input:") + print(b) + + print("\nCalling fit") + kmeans_float = KMeans(n_clusters=2, n_gpu=-1) + kmeans_float.fit(b) + print("labels:") + print(kmeans_float.labels_) + print("cluster_centers:") + print(kmeans_float.cluster_centers_) + + + Output: + + .. code-block:: python + + input: + 0 1 + 0 1.0 1.0 + 1 1.0 2.0 + 2 3.0 2.0 + 3 4.0 3.0 + + Calling fit + labels: + 0 1 + 1 1 + 2 0 + 3 0 + + cluster_centers: + 0 1 + 0 3.5 2.5 + 1 1.0 1.5 + + + For an additional example see `the PCA notebook `_. For additional docs, see `scikitlearn's Kmeans `_. + + """ + def __init__(self, n_clusters=8, max_iter=300, tol=1e-4, verbose=0, random_state=1, precompute_distances='auto', init='kmeans++', n_init=1, algorithm='auto', n_gpu=1, gpu_id=0): self.n_clusters = n_clusters self.verbose = verbose @@ -40,6 +101,16 @@ class KMeans: c = gdf.as_gpu_matrix(order='C').shape return self._get_ctype_ptr(gdf.as_gpu_matrix(order='C')) + + """ + Compute k-means clustering with input_gdf. + + Parameters + ---------- + input_gdf : PyGDF DataFrame + Dense matrix (floats or doubles) of shape (n_samples, n_features) + + """ def fit(self, input_gdf): x = [] for col in input_gdf.columns: @@ -106,9 +177,28 @@ class KMeans: return self + """ + Compute cluster centers and predict cluster index for each sample. + + Parameters + ---------- + input_gdf : PyGDF DataFrame + Dense matrix (floats or doubles) of shape (n_samples, n_features) + + """ def fit_predict(self, input_gdf): return self.fit(input_gdf).labels_ + + """ + Predict the closest cluster each sample in input_gdf belongs to. 
+ + Parameters + ---------- + input_gdf : PyGDF DataFrame + Dense matrix (floats or doubles) of shape (n_samples, n_features) + + """ def predict(self, input_gdf): x = [] for col in input_gdf.columns: @@ -171,6 +261,15 @@ class KMeans: return self.labels_ + """ + Transform input_gdf to a cluster-distance space. + + Parameters + ---------- + input_gdf : PyGDF DataFrame + Dense matrix (floats or doubles) of shape (n_samples, n_features) + + """ def transform(self, input_gdf): x = [] for col in input_gdf.columns: @@ -226,5 +325,14 @@ class KMeans: return preds_gdf + """ + Compute clustering and transform input_gdf to cluster-distance space. + + Parameters + ---------- + input_gdf : PyGDF DataFrame + Dense matrix (floats or doubles) of shape (n_samples, n_features) + + """ def fit_transform(self, input_gdf): return self.fit(input_gdf).transform(input_gdf) From fce158484c3c235a4edf1b68493e4a82b5a5b411 Mon Sep 17 00:00:00 2001 From: Dante Gama Dessavre Date: Tue, 30 Oct 2018 14:09:46 -0500 Subject: [PATCH 08/10] FIX fixes for cluster_centers core dump, lack of fused types and small docstring corrections --- cuML/src/kmeans/kmeans.cu | 22 +- python/cuML/kmeans/kmeans_test.py | 18 +- python/cuML/kmeans/kmeans_wrapper.pyx | 466 ++++++++++++++++---------- 3 files changed, 303 insertions(+), 203 deletions(-) diff --git a/cuML/src/kmeans/kmeans.cu b/cuML/src/kmeans/kmeans.cu index 877ad7d084..1d1b26fefa 100644 --- a/cuML/src/kmeans/kmeans.cu +++ b/cuML/src/kmeans/kmeans.cu @@ -898,7 +898,7 @@ int kmeans_fit(int verbose, int seed, int gpu_idtry, int n_gputry, size_t rows, break; } } - // Escape from an infinite loop if we come across + // Escape from an infinite loop if we come across if (tmp_left == tmp_right) { residual = tmp_residual; right = tmp_left; @@ -1388,14 +1388,9 @@ void make_ptr_kmeans(int dopredict, int verbose, int seed, int gpu_id, void kmeans_transform(int verbose, int gpu_id, int n_gpu, size_t m, size_t n, const char ord, int k, const float *src_data, const float *centroids, float *preds) { - //float *h_srcdata = (float*) malloc(m * n * sizeof(float)); - //cudaMemcpy((void*)h_srcdata, (void*)src_data, m*n * sizeof(float), cudaMemcpyDeviceToHost); const float *h_srcdata = src_data; - - float *h_centroids = (float*) malloc(k * n * sizeof(float)); - cudaMemcpy((void*) h_centroids, (void*) centroids, k * n * sizeof(float), - cudaMemcpyDeviceToHost); + const float *h_centroids = centroids; float *h_preds = nullptr; int actual_n_gpu = h2o4gpukmeans::get_n_gpus(n_gpu); @@ -1405,22 +1400,14 @@ void kmeans_transform(int verbose, int gpu_id, int n_gpu, size_t m, size_t n, cudaSetDevice(gpu_id); cudaMemcpy(preds, h_preds, m * k * sizeof(float), cudaMemcpyHostToDevice); - - //free(h_srcdata); - free(h_centroids); } void kmeans_transform(int verbose, int gpu_id, int n_gpu, size_t m, size_t n, const char ord, int k, const double *src_data, const double *centroids, double *preds) { - //double *h_srcdata = (double*) malloc(m * n * sizeof(double)); - //cudaMemcpy((void*)h_srcdata, (void*)src_data, m*n * sizeof(double), cudaMemcpyDeviceToHost); const double *h_srcdata = src_data; - - double *h_centroids = (double*) malloc(k * n * sizeof(double)); - cudaMemcpy((void*) h_centroids, (void*) centroids, k * n * sizeof(double), - cudaMemcpyDeviceToHost); + const double *h_centroids = centroids; double *h_preds = nullptr; int actual_n_gpu = h2o4gpukmeans::get_n_gpus(n_gpu); @@ -1430,9 +1417,6 @@ void kmeans_transform(int verbose, int gpu_id, int n_gpu, size_t m, size_t n, cudaSetDevice(gpu_id); 
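    // cudaSetDevice(gpu_id) above makes the caller's GPU current again, so the copy of
    // the predictions into the caller-provided device buffer below runs on that device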
cudaMemcpy(preds, h_preds, m * k * sizeof(double), cudaMemcpyHostToDevice); - - //free(h_srcdata); - free(h_centroids); } } // end namespace ML diff --git a/python/cuML/kmeans/kmeans_test.py b/python/cuML/kmeans/kmeans_test.py index 5384651877..9eb8edf09e 100644 --- a/python/cuML/kmeans/kmeans_test.py +++ b/python/cuML/kmeans/kmeans_test.py @@ -1,27 +1,27 @@ from cuML import KMeans -import pygdf +import cudf import numpy as np import pandas as pd print("\n***********TESTING FOR FLOAT DATATYPE***********") -#gdf_float = pygdf.DataFrame() +#gdf_float = cudf.DataFrame() #gdf_float['x']=np.asarray([1.0,1.0,3.0,4.0],dtype=np.float32) #gdf_float['y']=np.asarray([1.0,2.0,2.0,3.0],dtype=np.float32) -def np2pygdf(df): - # convert numpy array to pygdf dataframe +def np2cudf(df): + # convert numpy array to cudf dataframe df = pd.DataFrame({'fea%d'%i:df[:,i] for i in range(df.shape[1])}) - pdf = pygdf.DataFrame() + pdf = cudf.DataFrame() for c,column in enumerate(df): pdf[str(c)] = df[column] return pdf y=np.asarray([[1.0,2.0],[1.0,4.0],[1.0,0.0],[4.0,2.0],[4.0,4.0],[4.0,0.0]],dtype=np.float32) -x=np2pygdf(y) +x=np2cudf(y) q=np.asarray([[0, 0], [4, 4]],dtype=np.float32) -p=np2pygdf(q) +p=np2cudf(q) a=np.asarray([[1.0, 1.0], [1.0, 2.0], [3.0, 2.0], [4.0, 3.0]],dtype=np.float32) -b=np2pygdf(a) +b=np2cudf(a) print("input:") print(b) @@ -43,7 +43,7 @@ def np2pygdf(df): print("\nCalling fit_predict") -kmeans_float2 = KMeans(n_clusters=2, n_gpu=1) +kmeans_float2 = KMeans(n_clusters=2, n_gpu=-1) print("labels:") print(kmeans_float2.fit_predict(b)) print("cluster_centers:") diff --git a/python/cuML/kmeans/kmeans_wrapper.pyx b/python/cuML/kmeans/kmeans_wrapper.pyx index 0ecea88365..6aef826471 100644 --- a/python/cuML/kmeans/kmeans_wrapper.pyx +++ b/python/cuML/kmeans/kmeans_wrapper.pyx @@ -1,7 +1,7 @@ cimport c_kmeans import numpy as np from numba import cuda -import pygdf +import cudf from libcpp cimport bool import ctypes from libc.stdint cimport uintptr_t @@ -16,21 +16,21 @@ class KMeans: .. code-block:: python from cuML import KMeans - import pygdf + import cudf import numpy as np import pandas as pd - def np2pygdf(df): - # convert numpy array to pygdf dataframe + def np2cudf(df): + # convert numpy array to cuDF dataframe df = pd.DataFrame({'fea%d'%i:df[:,i] for i in range(df.shape[1])}) - pdf = pygdf.DataFrame() + pdf = cudf.DataFrame() for c,column in enumerate(df): pdf[str(c)] = df[column] return pdf - + a = np.asarray([[1.0, 1.0], [1.0, 2.0], [3.0, 2.0], [4.0, 3.0]],dtype=np.float32) - b = np2pygdf(a) + b = np2cudf(a) print("input:") print(b) @@ -41,7 +41,7 @@ class KMeans: print(kmeans_float.labels_) print("cluster_centers:") print(kmeans_float.cluster_centers_) - + Output: @@ -102,16 +102,75 @@ class KMeans: return self._get_ctype_ptr(gdf.as_gpu_matrix(order='C')) - """ - Compute k-means clustering with input_gdf. + def fit(self, input_gdf): + """ + Compute k-means clustering with input_gdf. - Parameters - ---------- - input_gdf : PyGDF DataFrame - Dense matrix (floats or doubles) of shape (n_samples, n_features) + Parameters + ---------- + input_gdf : cuDF DataFrame + Dense matrix (floats or doubles) of shape (n_samples, n_features) - """ - def fit(self, input_gdf): + """ + + # TODO: Replace this wrapper functions for fused types if kmeans code + # isn't changed to a pure GPU memory solution. 
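+        # Until then, dispatch on the dtype of the input columns and run a
+        # separate float32 or float64 path (_fit_f32 / _fit_f64 below).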
+ + self.gdf_datatype = np.dtype(input_gdf[input_gdf.columns[0]]._column.dtype) + if self.gdf_datatype.type == np.float32: + return self._fit_f32(input_gdf) + elif self.gdf_datatype.type == np.float64: + return self._fit_f64(input_gdf) + + + def _fit_f32(self, input_gdf): + # x = [] + # for col in input_gdf.columns: + # x.append(input_gdf[col]._column.dtype) + # break + + # self.gdf_datatype = np.dtype(x[0]) + self.n_rows = len(input_gdf) + self.n_cols = len(input_gdf._cols) + + cdef np.ndarray[np.float32_t, ndim=2, mode = 'c', cast=True] host_ary = input_gdf.as_gpu_matrix(order='C').copy_to_host() + + self.labels_ = cudf.Series(np.zeros(self.n_rows, dtype=np.int32)) + cdef uintptr_t labels_ptr = self._get_column_ptr(self.labels_) + + self.cluster_centers_ = cuda.to_device(np.zeros(self.n_clusters* self.n_cols, dtype=self.gdf_datatype)) + cdef uintptr_t cluster_centers_ptr = self._get_ctype_ptr(self.cluster_centers_) + + c_kmeans.make_ptr_kmeans( + 0, # dopredict + self.verbose, # verbose + self.random_state, # seed + self.gpu_id, # gpu_id + self.n_gpu, # n_gpu + self.n_rows, # mTrain (rows) + self.n_cols, # n (cols) + 'r', # ord + self.n_clusters, # k + self.n_clusters, # k_max + self.max_iter, # max_iterations + 1, # init_from_data TODO: can use kmeans++ + self.tol, # threshold + host_ary.data, # srcdata + # ptr2, # srcdata + 0, # centroids + cluster_centers_ptr, # pred_centroids + # 0, # pred_centroids + labels_ptr) # pred_labels + + + cluster_centers_gdf = cudf.DataFrame() + for i in range(0, self.n_cols): + cluster_centers_gdf[str(i)] = self.cluster_centers_[i:self.n_clusters*self.n_cols:self.n_cols] + self.cluster_centers_ = cluster_centers_gdf + + return self + + def _fit_f64(self, input_gdf): x = [] for col in input_gdf.columns: x.append(input_gdf[col]._column.dtype) @@ -120,167 +179,192 @@ class KMeans: self.gdf_datatype = np.dtype(x[0]) self.n_rows = len(input_gdf) self.n_cols = len(input_gdf._cols) - - cdef np.ndarray[np.float32_t, ndim=2, mode = 'c', cast=True] host_ary = input_gdf.as_gpu_matrix(order='C').copy_to_host() - - self.labels_ = pygdf.Series(np.zeros(self.n_rows, dtype=np.int32)) + + cdef np.ndarray[np.float64_t, ndim=2, mode = 'c', cast=True] host_ary = input_gdf.as_gpu_matrix(order='C').copy_to_host() + + self.labels_ = cudf.Series(np.zeros(self.n_rows, dtype=np.int32)) cdef uintptr_t labels_ptr = self._get_column_ptr(self.labels_) - + self.cluster_centers_ = cuda.to_device(np.zeros(self.n_clusters* self.n_cols, dtype=self.gdf_datatype)) cdef uintptr_t cluster_centers_ptr = self._get_ctype_ptr(self.cluster_centers_) - - if self.gdf_datatype.type == np.float32: - c_kmeans.make_ptr_kmeans( - 0, # dopredict - self.verbose, # verbose - self.random_state, # seed - self.gpu_id, # gpu_id - self.n_gpu, # n_gpu - self.n_rows, # mTrain (rows) - self.n_cols, # n (cols) - 'r', # ord - self.n_clusters, # k - self.n_clusters, # k_max - self.max_iter, # max_iterations - 1, # init_from_data TODO: can use kmeans++ - self.tol, # threshold - host_ary.data, # srcdata - # ptr2, # srcdata - 0, # centroids - cluster_centers_ptr, # pred_centroids - # 0, # pred_centroids - labels_ptr) # pred_labels - else: - c_kmeans.make_ptr_kmeans( - 0, # dopredict - self.verbose, # verbose - self.random_state, # seed - self.gpu_id, # gpu_id - self.n_gpu, # n_gpu - self.n_rows, # mTrain (rows) - self.n_cols, # n (cols) - 'r', # ord - self.n_clusters, # k - self.n_clusters, # k_max - self.max_iter, # max_iterations - 1, # init_from_data TODO: can use kmeans++ - self.tol, # threshold - host_ary.data, # 
srcdata - 0, # centroids - cluster_centers_ptr, # pred_centroids - labels_ptr) # pred_labels - - cluster_centers_gdf = pygdf.DataFrame() + + + c_kmeans.make_ptr_kmeans( + 0, # dopredict + self.verbose, # verbose + self.random_state, # seed + self.gpu_id, # gpu_id + self.n_gpu, # n_gpu + self.n_rows, # mTrain (rows) + self.n_cols, # n (cols) + 'r', # ord + self.n_clusters, # k + self.n_clusters, # k_max + self.max_iter, # max_iterations + 1, # init_from_data TODO: can use kmeans++ + self.tol, # threshold + host_ary.data, # srcdata + 0, # centroids + cluster_centers_ptr, # pred_centroids + labels_ptr) # pred_labels + + cluster_centers_gdf = cudf.DataFrame() for i in range(0, self.n_cols): cluster_centers_gdf[str(i)] = self.cluster_centers_[i:self.n_clusters*self.n_cols:self.n_cols] self.cluster_centers_ = cluster_centers_gdf return self - """ - Compute cluster centers and predict cluster index for each sample. - - Parameters - ---------- - input_gdf : PyGDF DataFrame - Dense matrix (floats or doubles) of shape (n_samples, n_features) - """ def fit_predict(self, input_gdf): - return self.fit(input_gdf).labels_ + """ + Compute cluster centers and predict cluster index for each sample. + Parameters + ---------- + input_gdf : cuDF DataFrame + Dense matrix (floats or doubles) of shape (n_samples, n_features) + + """ + return self.fit(input_gdf).labels_ - """ - Predict the closest cluster each sample in input_gdf belongs to. - Parameters - ---------- - input_gdf : PyGDF DataFrame - Dense matrix (floats or doubles) of shape (n_samples, n_features) - """ def predict(self, input_gdf): - x = [] - for col in input_gdf.columns: - x.append(input_gdf[col]._column.dtype) - break + """ + Predict the closest cluster each sample in input_gdf belongs to. - self.gdf_datatype = np.dtype(x[0]) + Parameters + ---------- + input_gdf : cuDF DataFrame + Dense matrix (floats or doubles) of shape (n_samples, n_features) + + """ + self.gdf_datatype = np.dtype(input_gdf[input_gdf.columns[0]]._column.dtype) + if self.gdf_datatype.type == np.float32: + return self._predict_f32(input_gdf) + elif self.gdf_datatype.type == np.float64: + return self._predict_f64(input_gdf) + + + def _predict_f32(self, input_gdf): + # x = [] + # for col in input_gdf.columns: + # x.append(input_gdf[col]._column.dtype) + # break + + # self.gdf_datatype = np.dtype(x[0]) self.n_rows = len(input_gdf) self.n_cols = len(input_gdf._cols) #cdef uintptr_t input_ptr = self._get_gdf_as_matrix_ptr(input_gdf) cdef np.ndarray[np.float32_t, ndim=2, mode = 'c', cast=True] host_ary = input_gdf.as_gpu_matrix(order='C').copy_to_host() - self.labels_ = pygdf.Series(np.zeros(self.n_rows, dtype=np.int32)) + self.labels_ = cudf.Series(np.zeros(self.n_rows, dtype=np.int32)) cdef uintptr_t labels_ptr = self._get_column_ptr(self.labels_) - #pred_centers = pygdf.Series(np.zeros(self.n_clusters* self.n_cols, dtype=self.gdf_datatype)) + #pred_centers = cudf.Series(np.zeros(self.n_clusters* self.n_cols, dtype=self.gdf_datatype)) cdef uintptr_t cluster_centers_ptr = self._get_gdf_as_matrix_ptr(self.cluster_centers_) - if self.gdf_datatype.type == np.float32: - c_kmeans.make_ptr_kmeans( - 1, # dopredict - self.verbose, # verbose - self.random_state, # seed - self.gpu_id, # gpu_id - self.n_gpu, # n_gpu - self.n_rows, # mTrain (rows) - self.n_cols, # n (cols) - 'r', # ord - self.n_clusters, # k - self.n_clusters, # k_max - self.max_iter, # max_iterations - 0, # init_from_data TODO: can use kmeans++ - self.tol, # threshold - # input_ptr, # srcdata - host_ary.data, # srcdata - # 
ptr2, # srcdata - cluster_centers_ptr, # centroids - 0, # pred_centroids - labels_ptr) # pred_labels - else: - c_kmeans.make_ptr_kmeans( - 1, # dopredict - self.verbose, # verbose - self.random_state, # seed - self.gpu_id, # gpu_id - self.n_gpu, # n_gpu - self.n_rows, # mTrain (rows) - self.n_cols, # n (cols) - 'r', # ord - self.n_clusters, # k - self.n_clusters, # k_max - self.max_iter, # max_iterations - 0, # init_from_data TODO: can use kmeans++ - self.tol, # threshold - host_ary.data, # srcdata - cluster_centers_ptr, # centroids - 0, # pred_centroids - labels_ptr) # pred_labels + c_kmeans.make_ptr_kmeans( + 1, # dopredict + self.verbose, # verbose + self.random_state, # seed + self.gpu_id, # gpu_id + self.n_gpu, # n_gpu + self.n_rows, # mTrain (rows) + self.n_cols, # n (cols) + 'r', # ord + self.n_clusters, # k + self.n_clusters, # k_max + self.max_iter, # max_iterations + 0, # init_from_data TODO: can use kmeans++ + self.tol, # threshold + # input_ptr, # srcdata + host_ary.data, # srcdata + # ptr2, # srcdata + cluster_centers_ptr, # centroids + 0, # pred_centroids + labels_ptr) # pred_labels return self.labels_ - """ - Transform input_gdf to a cluster-distance space. + def _predict_f64(self, input_gdf): + # x = [] + # for col in input_gdf.columns: + # x.append(input_gdf[col]._column.dtype) + # break + + # self.gdf_datatype = np.dtype(x[0]) + self.n_rows = len(input_gdf) + self.n_cols = len(input_gdf._cols) + + #cdef uintptr_t input_ptr = self._get_gdf_as_matrix_ptr(input_gdf) + cdef np.ndarray[np.float64_t, ndim=2, mode = 'c', cast=True] host_ary = input_gdf.as_gpu_matrix(order='C').copy_to_host() + self.labels_ = cudf.Series(np.zeros(self.n_rows, dtype=np.int32)) + cdef uintptr_t labels_ptr = self._get_column_ptr(self.labels_) + + #pred_centers = cudf.Series(np.zeros(self.n_clusters* self.n_cols, dtype=self.gdf_datatype)) + cdef uintptr_t cluster_centers_ptr = self._get_gdf_as_matrix_ptr(self.cluster_centers_) + + c_kmeans.make_ptr_kmeans( + 1, # dopredict + self.verbose, # verbose + self.random_state, # seed + self.gpu_id, # gpu_id + self.n_gpu, # n_gpu + self.n_rows, # mTrain (rows) + self.n_cols, # n (cols) + 'r', # ord + self.n_clusters, # k + self.n_clusters, # k_max + self.max_iter, # max_iterations + 0, # init_from_data TODO: can use kmeans++ + self.tol, # threshold + host_ary.data, # srcdata + cluster_centers_ptr, # centroids + 0, # pred_centroids + labels_ptr) # pred_labels + + return self.labels_ + - Parameters - ---------- - input_gdf : PyGDF DataFrame - Dense matrix (floats or doubles) of shape (n_samples, n_features) - """ def transform(self, input_gdf): - x = [] - for col in input_gdf.columns: - x.append(input_gdf[col]._column.dtype) - break + """ + Transform input_gdf to a cluster-distance space. 
- self.gdf_datatype = np.dtype(x[0]) + Parameters + ---------- + input_gdf : cuDF DataFrame + Dense matrix (floats or doubles) of shape (n_samples, n_features) + + """ + + + self.gdf_datatype = np.dtype(input_gdf[input_gdf.columns[0]]._column.dtype) + if self.gdf_datatype.type == np.float32: + return self._transform_f32(input_gdf) + elif self.gdf_datatype.type == np.float64: + return self._transform_f64(input_gdf) + + + def _transform_f32(self, input_gdf): + # x = [] + # for col in input_gdf.columns: + # x.append(input_gdf[col]._column.dtype) + # break + + # self.gdf_datatype = np.dtype(x[0]) self.n_rows = len(input_gdf) self.n_cols = len(input_gdf._cols) + cdef np.ndarray[np.float32_t, ndim=2, mode = 'c', cast=True] host_ary = input_gdf.as_gpu_matrix(order='C').copy_to_host() + + cdef np.ndarray[np.float32_t, ndim=2, mode = 'c', cast=True] cluster_centers_ptr = self.cluster_centers_.as_gpu_matrix(order='C').copy_to_host() + + preds_data = cuda.to_device(np.zeros(self.n_clusters*self.n_rows, dtype=self.gdf_datatype.type)) @@ -290,49 +374,81 @@ class KMeans: ary=np.array([1.0,1.5,3.5,2.5],dtype=np.float32) dary=cuda.to_device(ary) cdef uintptr_t ptr2 = dary.device_ctypes_pointer.value - cdef uintptr_t cluster_centers_ptr = self._get_gdf_as_matrix_ptr(self.cluster_centers_) - if self.gdf_datatype.type == np.float32: - c_kmeans.kmeans_transform( - self.verbose, # verbose - self.gpu_id, # gpu_id - self.n_gpu, # n_gpu - self.n_rows, # mTrain (rows) - self.n_cols, # n (cols) - 'r', # ord - self.n_clusters, # k - host_ary.data, # srcdata - cluster_centers_ptr, # centroids - preds_ptr) # preds - - else: - c_kmeans.kmeans_transform( - self.verbose, # verbose - self.gpu_id, # gpu_id - self.n_gpu, # n_gpu - self.n_rows, # mTrain (rows) - self.n_cols, # n (cols) - 'r', # ord - self.n_clusters, # k - host_ary.data, # srcdata - cluster_centers_ptr, # centroids - preds_ptr) # preds - - preds_gdf = pygdf.DataFrame() + c_kmeans.kmeans_transform( + self.verbose, # verbose + self.gpu_id, # gpu_id + self.n_gpu, # n_gpu + self.n_rows, # mTrain (rows) + self.n_cols, # n (cols) + 'r', # ord + self.n_clusters, # k + host_ary.data, # srcdata + cluster_centers_ptr.data, # centroids + preds_ptr) # preds + + + preds_gdf = cudf.DataFrame() for i in range(0, self.n_clusters): preds_gdf[str(i)] = preds_data[i*self.n_rows:(i+1)*self.n_rows] - + return preds_gdf - """ - Compute clustering and transform input_gdf to cluster-distance space. 
+ def _transform_f64(self, input_gdf): + # x = [] + # for col in input_gdf.columns: + # x.append(input_gdf[col]._column.dtype) + # break + + # self.gdf_datatype = np.dtype(x[0]) + self.n_rows = len(input_gdf) + self.n_cols = len(input_gdf._cols) + + + cdef np.ndarray[np.float64_t, ndim=2, mode = 'c', cast=True] host_ary = input_gdf.as_gpu_matrix(order='C').copy_to_host() + + cdef np.ndarray[np.float64_t, ndim=2, mode = 'c', cast=True] cluster_centers_ptr = self.cluster_centers_.as_gpu_matrix(order='C').copy_to_host() + + + preds_data = cuda.to_device(np.zeros(self.n_clusters*self.n_rows, + dtype=self.gdf_datatype.type)) + + cdef uintptr_t preds_ptr = self._get_ctype_ptr(preds_data) + + + ary=np.array([1.0,1.5,3.5,2.5],dtype=np.float32) + dary=cuda.to_device(ary) + cdef uintptr_t ptr2 = dary.device_ctypes_pointer.value + + c_kmeans.kmeans_transform( + self.verbose, # verbose + self.gpu_id, # gpu_id + self.n_gpu, # n_gpu + self.n_rows, # mTrain (rows) + self.n_cols, # n (cols) + 'r', # ord + self.n_clusters, # k + host_ary.data, # srcdata + cluster_centers_ptr.data, # centroids + preds_ptr) # preds + + preds_gdf = cudf.DataFrame() + for i in range(0, self.n_clusters): + preds_gdf[str(i)] = preds_data[i*self.n_rows:(i+1)*self.n_rows] + + return preds_gdf + - Parameters - ---------- - input_gdf : PyGDF DataFrame - Dense matrix (floats or doubles) of shape (n_samples, n_features) - """ def fit_transform(self, input_gdf): + """ + Compute clustering and transform input_gdf to cluster-distance space. + + Parameters + ---------- + input_gdf : cuDF DataFrame + Dense matrix (floats or doubles) of shape (n_samples, n_features) + + """ return self.fit(input_gdf).transform(input_gdf) From 44ac0c5fa16ee15fab020e8d0a0c979c2521c7d2 Mon Sep 17 00:00:00 2001 From: Dante Gama Dessavre Date: Wed, 31 Oct 2018 12:06:21 -0500 Subject: [PATCH 09/10] FIX fix for variables not being assigned --- python/cuML/dbscan/dbscan_wrapper.pyx | 9 +++--- python/cuML/pca/pca_wrapper.pyx | 42 +++++++++++++++------------ python/cuML/tsvd/tsvd_wrapper.pyx | 10 +++++-- 3 files changed, 35 insertions(+), 26 deletions(-) diff --git a/python/cuML/dbscan/dbscan_wrapper.pyx b/python/cuML/dbscan/dbscan_wrapper.pyx index ba68cdf033..99f97d9f3c 100644 --- a/python/cuML/dbscan/dbscan_wrapper.pyx +++ b/python/cuML/dbscan/dbscan_wrapper.pyx @@ -69,9 +69,6 @@ class DBSCAN: def _get_column_ptr(self, obj): return self._get_ctype_ptr(obj._column._data.to_gpu_array()) - def _get_gdf_as_matrix_ptr(self, gdf): - return self._get_ctype_ptr(gdf.as_gpu_matrix(order='C')) - def fit(self, X): """ Perform DBSCAN clustering from features or distance matrix. @@ -91,7 +88,9 @@ class DBSCAN: self.n_rows = len(X) self.n_cols = len(X._cols) - cdef uintptr_t input_ptr = self._get_gdf_as_matrix_ptr(X) + X_m = X.as_gpu_matrix() + cdef uintptr_t input_ptr = self._get_ctype_ptr(X_m) + self.labels_ = cudf.Series(np.zeros(self.n_rows, dtype=np.int32)) cdef uintptr_t labels_ptr = self._get_column_ptr(self.labels_) @@ -109,7 +108,7 @@ class DBSCAN: self.eps, self.min_samples, labels_ptr) - + del(X_m) def fit_predict(self, X): """ diff --git a/python/cuML/pca/pca_wrapper.pyx b/python/cuML/pca/pca_wrapper.pyx index 478086bb53..253ea1764c 100644 --- a/python/cuML/pca/pca_wrapper.pyx +++ b/python/cuML/pca/pca_wrapper.pyx @@ -1,17 +1,17 @@ - # Copyright (c) 2018, NVIDIA CORPORATION. - # - # Licensed under the Apache License, Version 2.0 (the "License"); - # you may not use this file except in compliance with the License. 
- # You may obtain a copy of the License at - # - # http://www.apache.org/licenses/LICENSE-2.0 - # - # Unless required by applicable law or agreed to in writing, software - # distributed under the License is distributed on an "AS IS" BASIS, - # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - # See the License for the specific language governing permissions and - # limitations under the License. - # +# Copyright (c) 2018, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# cimport c_pca import numpy as np @@ -185,9 +185,6 @@ class PCA: def _get_column_ptr(self, obj): return self._get_ctype_ptr(obj._column._data.to_gpu_array()) - def _get_gdf_as_matrix_ptr(self, gdf): - return self._get_ctype_ptr(gdf.as_gpu_matrix()) - def fit(self, X, _transform=True): """ Fit the model with X. @@ -219,7 +216,8 @@ class PCA: self._initialize_arrays(X, self.params.n_components, self.params.n_rows, self.params.n_cols) - cdef uintptr_t input_ptr = self._get_gdf_as_matrix_ptr(X) + X_m = X.as_gpu_matrix() + cdef uintptr_t input_ptr = self._get_ctype_ptr(X_m) cdef uintptr_t components_ptr = self._get_ctype_ptr(self.components_) @@ -254,6 +252,7 @@ class PCA: noise_vars_ptr, params) else: + if self.gdf_datatype.type == np.float32: c_pca.pcaFitTransform( input_ptr, trans_input_ptr, @@ -287,6 +286,8 @@ class PCA: self.mean_ptr = mean_ptr self.noise_variance_ptr = noise_vars_ptr + del(X_m) + def fit_transform(self, X): """ Fit the model with X and apply the dimensionality reduction on X. 
@@ -401,8 +402,10 @@ class PCA: np.zeros(params.n_rows*params.n_components, dtype=gdf_datatype.type)) + X_m = X.as_gpu_matrix() + cdef uintptr_t input_ptr = self._get_ctype_ptr(X_m) + cdef uintptr_t trans_input_ptr = self._get_ctype_ptr(trans_input_data) - cdef uintptr_t input_ptr = self._get_gdf_as_matrix_ptr(X) cdef uintptr_t components_ptr = self.components_ptr cdef uintptr_t singular_vals_ptr = self.singular_values_ptr cdef uintptr_t mean_ptr = self.mean_ptr @@ -426,5 +429,6 @@ class PCA: for i in range(0, params.n_components): X_new[str(i)] = trans_input_data[i*params.n_rows:(i+1)*params.n_rows] + del(X_m) return X_new diff --git a/python/cuML/tsvd/tsvd_wrapper.pyx b/python/cuML/tsvd/tsvd_wrapper.pyx index 163e08e4ab..d9192488ff 100644 --- a/python/cuML/tsvd/tsvd_wrapper.pyx +++ b/python/cuML/tsvd/tsvd_wrapper.pyx @@ -180,7 +180,8 @@ class TruncatedSVD: self._initialize_arrays(X, self.params.n_components, self.params.n_rows, self.params.n_cols) - cdef uintptr_t input_ptr = self._get_gdf_as_matrix_ptr(X) + X_m = X.as_gpu_matrix() + cdef uintptr_t input_ptr = self._get_ctype_ptr(X_m) cdef uintptr_t components_ptr = self._get_ctype_ptr(self.components_) @@ -232,6 +233,8 @@ class TruncatedSVD: self.explained_variance_ratio_ptr = explained_var_ratio_ptr self.singular_values_ptr = singular_vals_ptr + del(X_m) + def fit_transform(self, X): """ @@ -343,8 +346,10 @@ class TruncatedSVD: np.zeros(params.n_rows*params.n_components, dtype=gdf_datatype.type)) + X_m = X.as_gpu_matrix() + cdef uintptr_t input_ptr = self._get_ctype_ptr(X_m) + cdef uintptr_t trans_input_ptr = self._get_ctype_ptr(trans_input_data) - cdef uintptr_t input_ptr = self._get_gdf_as_matrix_ptr(X) cdef uintptr_t components_ptr = self.components_ptr if gdf_datatype.type == np.float32: @@ -362,5 +367,6 @@ class TruncatedSVD: for i in range(0, params.n_components): X_new[str(i)] = trans_input_data[i*params.n_rows:(i+1)*params.n_rows] + del(X_m) return X_new From f9a93e47c91931e4ef8f2325fc6b70a2b9c5e3e0 Mon Sep 17 00:00:00 2001 From: Dante Gama Dessavre Date: Wed, 31 Oct 2018 13:04:22 -0500 Subject: [PATCH 10/10] FIX fix kmeans variable assign bug and change back all arrys to be in GPU in python --- cuML/src/kmeans/kmeans.cu | 43 ++- python/cuML/kmeans/kmeans_wrapper.pyx | 392 ++++++++++---------------- 2 files changed, 172 insertions(+), 263 deletions(-) diff --git a/cuML/src/kmeans/kmeans.cu b/cuML/src/kmeans/kmeans.cu index 1d1b26fefa..dd05259272 100644 --- a/cuML/src/kmeans/kmeans.cu +++ b/cuML/src/kmeans/kmeans.cu @@ -1302,10 +1302,8 @@ void make_ptr_kmeans(int dopredict, int verbose, int seed, int gpu_id, int max_iterations, int init_from_data, float threshold, const float *srcdata, const float *centroids, float *pred_centroids, int *pred_labels) { - //float *h_srcdata = (float*) malloc(mTrain * n * sizeof(float)); - //cudaMemcpy((void*)h_srcdata, (void*)srcdata, mTrain*n * sizeof(float), cudaMemcpyDeviceToHost); - - const float *h_srcdata = srcdata; + float *h_srcdata = (float*) malloc(mTrain * n * sizeof(float)); + cudaMemcpy((void*)h_srcdata, (void*)srcdata, mTrain*n * sizeof(float), cudaMemcpyDeviceToHost); float *h_centroids = nullptr; if (dopredict) { @@ -1332,7 +1330,7 @@ void make_ptr_kmeans(int dopredict, int verbose, int seed, int gpu_id, cudaMemcpy(pred_labels, h_pred_labels, mTrain * sizeof(int), cudaMemcpyHostToDevice); - //free(h_srcdata); + free(h_srcdata); if (dopredict) { free(h_centroids); } @@ -1344,10 +1342,8 @@ void make_ptr_kmeans(int dopredict, int verbose, int seed, int gpu_id, const double *srcdata, 
const double *centroids, double *pred_centroids, int *pred_labels) { - //double *h_srcdata = (double*) malloc(mTrain * n * sizeof(double)); - //cudaMemcpy((void*)h_srcdata, (void*)srcdata, mTrain*n * sizeof(double), cudaMemcpyDeviceToHost); - - const double *h_srcdata = srcdata; + double *h_srcdata = (double*) malloc(mTrain * n * sizeof(double)); + cudaMemcpy((void*)h_srcdata, (void*)srcdata, mTrain*n * sizeof(double), cudaMemcpyDeviceToHost); double *h_centroids = nullptr; if (dopredict) { @@ -1377,7 +1373,7 @@ void make_ptr_kmeans(int dopredict, int verbose, int seed, int gpu_id, cudaMemcpy(pred_labels, h_pred_labels, mTrain * sizeof(int), cudaMemcpyHostToDevice); - //free(h_srcdata); + free(h_srcdata); if (dopredict) { free(h_centroids); } @@ -1388,9 +1384,15 @@ void make_ptr_kmeans(int dopredict, int verbose, int seed, int gpu_id, void kmeans_transform(int verbose, int gpu_id, int n_gpu, size_t m, size_t n, const char ord, int k, const float *src_data, const float *centroids, float *preds) { + float *h_srcdata = (float*) malloc(m * n * sizeof(float)); + cudaMemcpy((void*)h_srcdata, (void*)src_data, m*n * sizeof(float), cudaMemcpyDeviceToHost); - const float *h_srcdata = src_data; - const float *h_centroids = centroids; + float *h_centroids = (float*) malloc(k * n * sizeof(float)); + cudaMemcpy((void*) h_centroids, (void*) centroids, k * n * sizeof(float), + cudaMemcpyDeviceToHost); + + // const float *h_srcdata = src_data; + // const float *h_centroids = centroids; float *h_preds = nullptr; int actual_n_gpu = h2o4gpukmeans::get_n_gpus(n_gpu); @@ -1400,14 +1402,24 @@ void kmeans_transform(int verbose, int gpu_id, int n_gpu, size_t m, size_t n, cudaSetDevice(gpu_id); cudaMemcpy(preds, h_preds, m * k * sizeof(float), cudaMemcpyHostToDevice); + + free(h_srcdata); + free(h_centroids); } void kmeans_transform(int verbose, int gpu_id, int n_gpu, size_t m, size_t n, const char ord, int k, const double *src_data, const double *centroids, double *preds) { - const double *h_srcdata = src_data; - const double *h_centroids = centroids; + double *h_srcdata = (double*) malloc(m * n * sizeof(double)); + cudaMemcpy((void*)h_srcdata, (void*)src_data, m*n * sizeof(double), cudaMemcpyDeviceToHost); + + double *h_centroids = (double*) malloc(k * n * sizeof(double)); + cudaMemcpy((void*) h_centroids, (void*) centroids, k * n * sizeof(double), + cudaMemcpyDeviceToHost); + + // const double *h_srcdata = src_data; + // const double *h_centroids = centroids; double *h_preds = nullptr; int actual_n_gpu = h2o4gpukmeans::get_n_gpus(n_gpu); @@ -1417,6 +1429,9 @@ void kmeans_transform(int verbose, int gpu_id, int n_gpu, size_t m, size_t n, cudaSetDevice(gpu_id); cudaMemcpy(preds, h_preds, m * k * sizeof(double), cudaMemcpyHostToDevice); + + free(h_srcdata); + free(h_centroids); } } // end namespace ML diff --git a/python/cuML/kmeans/kmeans_wrapper.pyx b/python/cuML/kmeans/kmeans_wrapper.pyx index 6aef826471..600aaa33f1 100644 --- a/python/cuML/kmeans/kmeans_wrapper.pyx +++ b/python/cuML/kmeans/kmeans_wrapper.pyx @@ -102,38 +102,25 @@ class KMeans: return self._get_ctype_ptr(gdf.as_gpu_matrix(order='C')) - def fit(self, input_gdf): + def fit(self, X): """ - Compute k-means clustering with input_gdf. + Compute k-means clustering with X. Parameters ---------- - input_gdf : cuDF DataFrame + X : cuDF DataFrame Dense matrix (floats or doubles) of shape (n_samples, n_features) """ - # TODO: Replace this wrapper functions for fused types if kmeans code - # isn't changed to a pure GPU memory solution. 
- - self.gdf_datatype = np.dtype(input_gdf[input_gdf.columns[0]]._column.dtype) - if self.gdf_datatype.type == np.float32: - return self._fit_f32(input_gdf) - elif self.gdf_datatype.type == np.float64: - return self._fit_f64(input_gdf) - - - def _fit_f32(self, input_gdf): - # x = [] - # for col in input_gdf.columns: - # x.append(input_gdf[col]._column.dtype) - # break + self.gdf_datatype = np.dtype(X[X.columns[0]]._column.dtype) + self.n_rows = len(X) + self.n_cols = len(X._cols) - # self.gdf_datatype = np.dtype(x[0]) - self.n_rows = len(input_gdf) - self.n_cols = len(input_gdf._cols) + # cdef np.ndarray[np.float32_t, ndim=2, mode = 'c', cast=True] host_ary = input_gdf.as_gpu_matrix(order='C').copy_to_host() - cdef np.ndarray[np.float32_t, ndim=2, mode = 'c', cast=True] host_ary = input_gdf.as_gpu_matrix(order='C').copy_to_host() + X_m = X.as_gpu_matrix() + cdef uintptr_t input_ptr = self._get_ctype_ptr(X_m) self.labels_ = cudf.Series(np.zeros(self.n_rows, dtype=np.int32)) cdef uintptr_t labels_ptr = self._get_column_ptr(self.labels_) @@ -141,26 +128,46 @@ class KMeans: self.cluster_centers_ = cuda.to_device(np.zeros(self.n_clusters* self.n_cols, dtype=self.gdf_datatype)) cdef uintptr_t cluster_centers_ptr = self._get_ctype_ptr(self.cluster_centers_) - c_kmeans.make_ptr_kmeans( - 0, # dopredict - self.verbose, # verbose - self.random_state, # seed - self.gpu_id, # gpu_id - self.n_gpu, # n_gpu - self.n_rows, # mTrain (rows) - self.n_cols, # n (cols) - 'r', # ord - self.n_clusters, # k - self.n_clusters, # k_max - self.max_iter, # max_iterations - 1, # init_from_data TODO: can use kmeans++ - self.tol, # threshold - host_ary.data, # srcdata - # ptr2, # srcdata - 0, # centroids - cluster_centers_ptr, # pred_centroids - # 0, # pred_centroids - labels_ptr) # pred_labels + if self.gdf_datatype.type == np.float32: + c_kmeans.make_ptr_kmeans( + 0, # dopredict + self.verbose, # verbose + self.random_state, # seed + self.gpu_id, # gpu_id + self.n_gpu, # n_gpu + self.n_rows, # mTrain (rows) + self.n_cols, # n (cols) + 'r', # ord + self.n_clusters, # k + self.n_clusters, # k_max + self.max_iter, # max_iterations + 1, # init_from_data TODO: can use kmeans++ + self.tol, # threshold + input_ptr, # srcdata + # ptr2, # srcdata + 0, # centroids + cluster_centers_ptr, # pred_centroids + # 0, # pred_centroids + labels_ptr) # pred_labels + else: + c_kmeans.make_ptr_kmeans( + 0, # dopredict + self.verbose, # verbose + self.random_state, # seed + self.gpu_id, # gpu_id + self.n_gpu, # n_gpu + self.n_rows, # mTrain (rows) + self.n_cols, # n (cols) + 'r', # ord + self.n_clusters, # k + self.n_clusters, # k_max + self.max_iter, # max_iterations + 1, # init_from_data TODO: can use kmeans++ + self.tol, # threshold + input_ptr, # srcdata + 0, # centroids + cluster_centers_ptr, # pred_centroids + labels_ptr) # pred_labels cluster_centers_gdf = cudf.DataFrame() @@ -168,202 +175,120 @@ class KMeans: cluster_centers_gdf[str(i)] = self.cluster_centers_[i:self.n_clusters*self.n_cols:self.n_cols] self.cluster_centers_ = cluster_centers_gdf - return self - - def _fit_f64(self, input_gdf): - x = [] - for col in input_gdf.columns: - x.append(input_gdf[col]._column.dtype) - break - - self.gdf_datatype = np.dtype(x[0]) - self.n_rows = len(input_gdf) - self.n_cols = len(input_gdf._cols) - - cdef np.ndarray[np.float64_t, ndim=2, mode = 'c', cast=True] host_ary = input_gdf.as_gpu_matrix(order='C').copy_to_host() - - self.labels_ = cudf.Series(np.zeros(self.n_rows, dtype=np.int32)) - cdef uintptr_t labels_ptr = 
self._get_column_ptr(self.labels_) - - self.cluster_centers_ = cuda.to_device(np.zeros(self.n_clusters* self.n_cols, dtype=self.gdf_datatype)) - cdef uintptr_t cluster_centers_ptr = self._get_ctype_ptr(self.cluster_centers_) - - - c_kmeans.make_ptr_kmeans( - 0, # dopredict - self.verbose, # verbose - self.random_state, # seed - self.gpu_id, # gpu_id - self.n_gpu, # n_gpu - self.n_rows, # mTrain (rows) - self.n_cols, # n (cols) - 'r', # ord - self.n_clusters, # k - self.n_clusters, # k_max - self.max_iter, # max_iterations - 1, # init_from_data TODO: can use kmeans++ - self.tol, # threshold - host_ary.data, # srcdata - 0, # centroids - cluster_centers_ptr, # pred_centroids - labels_ptr) # pred_labels - - cluster_centers_gdf = cudf.DataFrame() - for i in range(0, self.n_cols): - cluster_centers_gdf[str(i)] = self.cluster_centers_[i:self.n_clusters*self.n_cols:self.n_cols] - self.cluster_centers_ = cluster_centers_gdf + del(X_m) return self - def fit_predict(self, input_gdf): + def fit_predict(self, X): """ Compute cluster centers and predict cluster index for each sample. Parameters ---------- - input_gdf : cuDF DataFrame + X : cuDF DataFrame Dense matrix (floats or doubles) of shape (n_samples, n_features) """ - return self.fit(input_gdf).labels_ + return self.fit(X).labels_ - def predict(self, input_gdf): + def predict(self, X): """ - Predict the closest cluster each sample in input_gdf belongs to. + Predict the closest cluster each sample in X belongs to. Parameters ---------- - input_gdf : cuDF DataFrame + X : cuDF DataFrame Dense matrix (floats or doubles) of shape (n_samples, n_features) """ - self.gdf_datatype = np.dtype(input_gdf[input_gdf.columns[0]]._column.dtype) - if self.gdf_datatype.type == np.float32: - return self._predict_f32(input_gdf) - elif self.gdf_datatype.type == np.float64: - return self._predict_f64(input_gdf) - - - def _predict_f32(self, input_gdf): - # x = [] - # for col in input_gdf.columns: - # x.append(input_gdf[col]._column.dtype) - # break - - # self.gdf_datatype = np.dtype(x[0]) - self.n_rows = len(input_gdf) - self.n_cols = len(input_gdf._cols) - - #cdef uintptr_t input_ptr = self._get_gdf_as_matrix_ptr(input_gdf) - cdef np.ndarray[np.float32_t, ndim=2, mode = 'c', cast=True] host_ary = input_gdf.as_gpu_matrix(order='C').copy_to_host() - self.labels_ = cudf.Series(np.zeros(self.n_rows, dtype=np.int32)) - cdef uintptr_t labels_ptr = self._get_column_ptr(self.labels_) - - #pred_centers = cudf.Series(np.zeros(self.n_clusters* self.n_cols, dtype=self.gdf_datatype)) - cdef uintptr_t cluster_centers_ptr = self._get_gdf_as_matrix_ptr(self.cluster_centers_) - - c_kmeans.make_ptr_kmeans( - 1, # dopredict - self.verbose, # verbose - self.random_state, # seed - self.gpu_id, # gpu_id - self.n_gpu, # n_gpu - self.n_rows, # mTrain (rows) - self.n_cols, # n (cols) - 'r', # ord - self.n_clusters, # k - self.n_clusters, # k_max - self.max_iter, # max_iterations - 0, # init_from_data TODO: can use kmeans++ - self.tol, # threshold - # input_ptr, # srcdata - host_ary.data, # srcdata - # ptr2, # srcdata - cluster_centers_ptr, # centroids - 0, # pred_centroids - labels_ptr) # pred_labels - - return self.labels_ - + self.gdf_datatype = np.dtype(X[X.columns[0]]._column.dtype) + self.n_rows = len(X) + self.n_cols = len(X._cols) - def _predict_f64(self, input_gdf): - # x = [] - # for col in input_gdf.columns: - # x.append(input_gdf[col]._column.dtype) - # break + X_m = X.as_gpu_matrix() + cdef uintptr_t input_ptr = self._get_ctype_ptr(X_m) - # self.gdf_datatype = np.dtype(x[0]) - 
self.n_rows = len(input_gdf) - self.n_cols = len(input_gdf._cols) + clust_mat = self.cluster_centers_.as_gpu_matrix(order='C') + cdef uintptr_t cluster_centers_ptr = self._get_ctype_ptr(clust_mat) - #cdef uintptr_t input_ptr = self._get_gdf_as_matrix_ptr(input_gdf) - cdef np.ndarray[np.float64_t, ndim=2, mode = 'c', cast=True] host_ary = input_gdf.as_gpu_matrix(order='C').copy_to_host() self.labels_ = cudf.Series(np.zeros(self.n_rows, dtype=np.int32)) cdef uintptr_t labels_ptr = self._get_column_ptr(self.labels_) - #pred_centers = cudf.Series(np.zeros(self.n_clusters* self.n_cols, dtype=self.gdf_datatype)) - cdef uintptr_t cluster_centers_ptr = self._get_gdf_as_matrix_ptr(self.cluster_centers_) - - c_kmeans.make_ptr_kmeans( - 1, # dopredict - self.verbose, # verbose - self.random_state, # seed - self.gpu_id, # gpu_id - self.n_gpu, # n_gpu - self.n_rows, # mTrain (rows) - self.n_cols, # n (cols) - 'r', # ord - self.n_clusters, # k - self.n_clusters, # k_max - self.max_iter, # max_iterations - 0, # init_from_data TODO: can use kmeans++ - self.tol, # threshold - host_ary.data, # srcdata - cluster_centers_ptr, # centroids - 0, # pred_centroids - labels_ptr) # pred_labels - + if self.gdf_datatype.type == np.float32: + c_kmeans.make_ptr_kmeans( + 1, # dopredict + self.verbose, # verbose + self.random_state, # seed + self.gpu_id, # gpu_id + self.n_gpu, # n_gpu + self.n_rows, # mTrain (rows) + self.n_cols, # n (cols) + 'r', # ord + self.n_clusters, # k + self.n_clusters, # k_max + self.max_iter, # max_iterations + 0, # init_from_data TODO: can use kmeans++ + self.tol, # threshold + # input_ptr, # srcdata + input_ptr, # srcdata + # ptr2, # srcdata + cluster_centers_ptr, # centroids + 0, # pred_centroids + labels_ptr) # pred_labels + else: + c_kmeans.make_ptr_kmeans( + 1, # dopredict + self.verbose, # verbose + self.random_state, # seed + self.gpu_id, # gpu_id + self.n_gpu, # n_gpu + self.n_rows, # mTrain (rows) + self.n_cols, # n (cols) + 'r', # ord + self.n_clusters, # k + self.n_clusters, # k_max + self.max_iter, # max_iterations + 0, # init_from_data TODO: can use kmeans++ + self.tol, # threshold + input_ptr, # srcdata + cluster_centers_ptr, # centroids + 0, # pred_centroids + labels_ptr) # pred_labels + + del(X_m) + del(clust_mat) return self.labels_ - def transform(self, input_gdf): + def transform(self, X): """ - Transform input_gdf to a cluster-distance space. + Transform X to a cluster-distance space. 
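+        The returned DataFrame has one row per input sample and one column per
+        cluster center.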
Parameters ---------- - input_gdf : cuDF DataFrame + X : cuDF DataFrame Dense matrix (floats or doubles) of shape (n_samples, n_features) """ - - self.gdf_datatype = np.dtype(input_gdf[input_gdf.columns[0]]._column.dtype) - if self.gdf_datatype.type == np.float32: - return self._transform_f32(input_gdf) - elif self.gdf_datatype.type == np.float64: - return self._transform_f64(input_gdf) + self.n_rows = len(X) + self.n_cols = len(X._cols) - def _transform_f32(self, input_gdf): - # x = [] - # for col in input_gdf.columns: - # x.append(input_gdf[col]._column.dtype) - # break + # cdef np.ndarray[np.float32_t, ndim=2, mode = 'c', cast=True] host_ary = input_gdf.as_gpu_matrix(order='C').copy_to_host() - # self.gdf_datatype = np.dtype(x[0]) - self.n_rows = len(input_gdf) - self.n_cols = len(input_gdf._cols) + # cdef np.ndarray[np.float32_t, ndim=2, mode = 'c', cast=True] cluster_centers_ptr = self.cluster_centers_.as_gpu_matrix(order='C').copy_to_host() - cdef np.ndarray[np.float32_t, ndim=2, mode = 'c', cast=True] host_ary = input_gdf.as_gpu_matrix(order='C').copy_to_host() - - cdef np.ndarray[np.float32_t, ndim=2, mode = 'c', cast=True] cluster_centers_ptr = self.cluster_centers_.as_gpu_matrix(order='C').copy_to_host() + X_m = X.as_gpu_matrix() + cdef uintptr_t input_ptr = self._get_ctype_ptr(X_m) + clust_mat = self.cluster_centers_.as_gpu_matrix(order='C') + cdef uintptr_t cluster_centers_ptr = self._get_ctype_ptr(clust_mat) preds_data = cuda.to_device(np.zeros(self.n_clusters*self.n_rows, dtype=self.gdf_datatype.type)) @@ -375,72 +300,41 @@ class KMeans: dary=cuda.to_device(ary) cdef uintptr_t ptr2 = dary.device_ctypes_pointer.value - c_kmeans.kmeans_transform( - self.verbose, # verbose - self.gpu_id, # gpu_id - self.n_gpu, # n_gpu - self.n_rows, # mTrain (rows) - self.n_cols, # n (cols) - 'r', # ord - self.n_clusters, # k - host_ary.data, # srcdata - cluster_centers_ptr.data, # centroids - preds_ptr) # preds - - - preds_gdf = cudf.DataFrame() - for i in range(0, self.n_clusters): - preds_gdf[str(i)] = preds_data[i*self.n_rows:(i+1)*self.n_rows] - - return preds_gdf - - - def _transform_f64(self, input_gdf): - # x = [] - # for col in input_gdf.columns: - # x.append(input_gdf[col]._column.dtype) - # break - - # self.gdf_datatype = np.dtype(x[0]) - self.n_rows = len(input_gdf) - self.n_cols = len(input_gdf._cols) - - - cdef np.ndarray[np.float64_t, ndim=2, mode = 'c', cast=True] host_ary = input_gdf.as_gpu_matrix(order='C').copy_to_host() - - cdef np.ndarray[np.float64_t, ndim=2, mode = 'c', cast=True] cluster_centers_ptr = self.cluster_centers_.as_gpu_matrix(order='C').copy_to_host() - - - preds_data = cuda.to_device(np.zeros(self.n_clusters*self.n_rows, - dtype=self.gdf_datatype.type)) - - cdef uintptr_t preds_ptr = self._get_ctype_ptr(preds_data) - - - ary=np.array([1.0,1.5,3.5,2.5],dtype=np.float32) - dary=cuda.to_device(ary) - cdef uintptr_t ptr2 = dary.device_ctypes_pointer.value + if self.gdf_datatype.type == np.float32: + c_kmeans.kmeans_transform( + self.verbose, # verbose + self.gpu_id, # gpu_id + self.n_gpu, # n_gpu + self.n_rows, # mTrain (rows) + self.n_cols, # n (cols) + 'r', # ord + self.n_clusters, # k + input_ptr, # srcdata + cluster_centers_ptr, # centroids + preds_ptr) # preds + else: + c_kmeans.kmeans_transform( + self.verbose, # verbose + self.gpu_id, # gpu_id + self.n_gpu, # n_gpu + self.n_rows, # mTrain (rows) + self.n_cols, # n (cols) + 'r', # ord + self.n_clusters, # k + input_ptr, # srcdata + cluster_centers_ptr, # centroids + preds_ptr) # preds - 
c_kmeans.kmeans_transform( - self.verbose, # verbose - self.gpu_id, # gpu_id - self.n_gpu, # n_gpu - self.n_rows, # mTrain (rows) - self.n_cols, # n (cols) - 'r', # ord - self.n_clusters, # k - host_ary.data, # srcdata - cluster_centers_ptr.data, # centroids - preds_ptr) # preds preds_gdf = cudf.DataFrame() for i in range(0, self.n_clusters): preds_gdf[str(i)] = preds_data[i*self.n_rows:(i+1)*self.n_rows] + del(X_m) + del(clust_mat) return preds_gdf - def fit_transform(self, input_gdf): """ Compute clustering and transform input_gdf to cluster-distance space.
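+        Equivalent to calling fit followed by transform on the same DataFrame.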