Skip to content

Commit

Permalink
[REVIEW] Sparse KNN + UMAP Sparse Inputs (#2836)
Browse files Browse the repository at this point in the history
* Adding brute force knn shell to sparse

* Stubbing out algorithm flow

* Adding initial headers to wrapper

* Performing idx batching

* Starting to full in cusparse calls

* Checking in

* Beginning to add selection kernel

* Finished header

* Updates. Need to finish populating merge buffer

* Using block select for selecting k and using 3-partition merge buffer

* Logic is just about done.

* Checking in changes. Need to swap out cuda 11 cusparse calls for cuda 10.2 version

* Everything is building. Need end-to-end test

* Running clang format

* Updating changelog

* Using raft's cusparse_wrappers.h instead of cuml

* Removing cuda11-required GEMM calls (commenting them out for now, will swap them out shortly)

* Fixing clang style

* Separating distance computation from selection from general brute force algorithm to make pieces more reusable

* Updating clang style

* Adding batcher to help ease batch state management

* Fixing clang style

* MOre clang fixes

* IP distance is computed using search * index.T.

* Making type template for value_t all the way through knn_merge_parts

* Adding simple googletest for sparse pairwise dists. The transpose conversion seems super expensive, but maybe it's necessary.

* Completing test for basic inner product distances

* Removing prints from test

* Cleaning up batching for knn. Ready to gtest

* KNN w/ max inner product is working

* Adding guts of expanded l2 computation.

* Cleaning up some debug prints

* Fixing clang format

* More cleanup and clang style fix

* Fixing style for sparse knn prim test

* Hoping i've captured all the clang updates

* Updating per include_checker

* I feel like I"m bouncing back and forth between clang and include checker

* Refactoring sparse pairwise dists to return dense outputs

* Beginning python layer

* iAdding python layer for sparse inputs to nearest neighbors

* End to end sparse knn works. Need to finish norms for expanded euclidean and expose it.

* Removing unused file

* Adding gtest for expanded l2.

* Sparse l2 matches sklearn

* Fixing clang format style

* Fixing dstyle in gtests

* Lots of changes and cleanup. Still need to flip the batching

* Progress on tiling. Still a failure when tile sizes don't match up.

* Tiling w/ uneven batch sizes works! Now just need to figure out what to do when the leftover values are <k

* Some further optinmizations are necessary, but this works for now.

* Ready for cleanup

* Parametrizing sparse knn tests

* More cleanup.

* Fixing clang format

* Fixing clang format style

* Fixing flake8 for sparse nn tests

* Fixing googletests

* More cleanup of sparse knn

* Adding sparse support to UMAP by abstracting the inputs

* Everything's building. Have one template issue to fix in the sparse knn

* Updates to API

* Usig a struct to manage the knn graph output state

* C++ side is largely done. Still need to figure out what to do w/ the separate int64_t type in the sparse knn

* Removing examples/comms, which seems to have gotten re-checked in by mistake

* Fixing c++ style

* Fixing include checks

* This darn style checker is going to kill me.....

* Adding template type params for output

* UMAP is officially accepting sparse inputs

* More cleanup

* Cleaning up gtests and making them easier to write

* Fixing up and parametrizing tests

* Fixing style

* Fixing python style

* More clang format style fixes

* Pulled umap inputs classes to more shared location so tsne can use them.

Added kselection gtest

* Updating clang format

* Fixing bad ide refactor

* Updating changelog

* Fixing more clang format

* Fixing flake8 style. Not sure why these didn't show up locally

* Decomposing sparse knn into a class.

* Review feedback

* Better umap sparse test

* More testing updates

* Adding docs to some of the remaining prims in csr.cuh

* Adding gtests for transpose and row slice. Need to add one for todense

* GTest for csr to dense

* Fixing style

* Removing debug logging from new gtests

* Fixing flake8  style

* Getting build to pass

* Running clang-tidy

* Fixing format for sparse gtests

* Adding 'algo_params' to get_param_names()

* Removing cumlarray output in kneighbors

* Finishing review feedback

* Fixing style

* Fixing format

* clang-format

* Style changes

* More review updates

* Style updates

* Running clang format on distance.cuh

* Runing clang format on tests

* Fixing cython style

* Updating RAFT commit

* Updating neighbors from bad merge
  • Loading branch information
cjnolet authored Nov 19, 2020
1 parent d30edd9 commit b205e8f
Show file tree
Hide file tree
Showing 46 changed files with 12,828 additions and 814 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
# cuML 0.17.0 (Date TBD)

## New Features
- PR #2659: Add initial max inner product sparse knn
- PR #2836: Refactor UMAP to accept sparse inputs

## Improvements
- PR #3077: Improve runtime for test_kmeans
Expand Down
1 change: 1 addition & 0 deletions cpp/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -401,6 +401,7 @@ if(BUILD_CUML_CPP_LIBRARY)
src/holtwinters/holtwinters.cu
src/kmeans/kmeans.cu
src/knn/knn.cu
src/knn/knn_sparse.cu
src/metrics/accuracy_score.cu
src/metrics/adjusted_rand_index.cu
src/metrics/completeness_score.cu
Expand Down
20 changes: 10 additions & 10 deletions cpp/bench/sg/umap.cu
Original file line number Diff line number Diff line change
Expand Up @@ -111,8 +111,8 @@ class UmapSupervised : public UmapBase {

protected:
void coreBenchmarkMethod() {
fit(*this->handle, this->data.X, yFloat, this->params.nrows,
this->params.ncols, nullptr, nullptr, &uParams, embeddings);
UMAP::fit(*this->handle, this->data.X, yFloat, this->params.nrows,
this->params.ncols, nullptr, nullptr, &uParams, embeddings);
}
};
ML_BENCH_REGISTER(Params, UmapSupervised, "blobs", getInputs());
Expand All @@ -124,8 +124,8 @@ class UmapUnsupervised : public UmapBase {

protected:
void coreBenchmarkMethod() {
fit(*this->handle, this->data.X, this->params.nrows, this->params.ncols,
nullptr, nullptr, &uParams, embeddings);
UMAP::fit(*this->handle, this->data.X, nullptr, this->params.nrows,
this->params.ncols, nullptr, nullptr, &uParams, embeddings);
}
};
ML_BENCH_REGISTER(Params, UmapUnsupervised, "blobs", getInputs());
Expand All @@ -136,17 +136,17 @@ class UmapTransform : public UmapBase {

protected:
void coreBenchmarkMethod() {
transform(*this->handle, this->data.X, this->params.nrows,
this->params.ncols, nullptr, nullptr, this->data.X,
this->params.nrows, embeddings, this->params.nrows, &uParams,
transformed);
UMAP::transform(*this->handle, this->data.X, this->params.nrows,
this->params.ncols, nullptr, nullptr, this->data.X,
this->params.nrows, embeddings, this->params.nrows,
&uParams, transformed);
}
void allocateBuffers(const ::benchmark::State& state) {
UmapBase::allocateBuffers(state);
auto& handle = *this->handle;
alloc(transformed, this->params.nrows * uParams.n_components);
fit(handle, this->data.X, yFloat, this->params.nrows, this->params.ncols,
nullptr, nullptr, &uParams, embeddings);
UMAP::fit(handle, this->data.X, yFloat, this->params.nrows,
this->params.ncols, nullptr, nullptr, &uParams, embeddings);
}
void deallocateBuffers(const ::benchmark::State& state) {
dealloc(transformed, this->params.nrows * uParams.n_components);
Expand Down
2 changes: 1 addition & 1 deletion cpp/cmake/Dependencies.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ else(DEFINED ENV{RAFT_PATH})

ExternalProject_Add(raft
GIT_REPOSITORY https://github.com/rapidsai/raft.git
GIT_TAG 9b3afe67895fbea397fb2c72375157aadfc132d8
GIT_TAG eebd0e306624b419168b2cd5cd7aa44ebaec51f1
PREFIX ${RAFT_DIR}
CONFIGURE_COMMAND ""
BUILD_COMMAND ""
Expand Down
25 changes: 25 additions & 0 deletions cpp/include/cuml/distance/distance_type.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
#pragma once

namespace ML {
namespace Distance {

/** enum to tell how to compute euclidean distance */
enum DistanceType : unsigned short {
/** evaluate as dist_ij = sum(x_ik^2) + sum(y_ij)^2 - 2*sum(x_ik * y_jk) */
EucExpandedL2 = 0,
/** same as above, but inside the epilogue, perform square root operation */
EucExpandedL2Sqrt = 1,
/** cosine distance */
EucExpandedCosine = 2,
/** L1 distance */
EucUnexpandedL1 = 3,
/** evaluate as dist_ij += (x_ik - y-jk)^2 */
EucUnexpandedL2 = 4,
/** same as above, but inside the epilogue, perform square root operation */
EucUnexpandedL2Sqrt = 5,
/** simple inner product */
InnerProduct = 6
};

}; // end namespace Distance
}; // end namespace ML
123 changes: 123 additions & 0 deletions cpp/include/cuml/manifold/common.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
/*
* Copyright (c) 2020, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#pragma once

namespace ML {

// Dense input uses int64_t until FAISS is updated
typedef int64_t knn_indices_dense_t;

typedef int knn_indices_sparse_t;

/**
* Simple container for KNN graph properties
* @tparam value_idx
* @tparam value_t
*/
template <typename value_idx, typename value_t>
struct knn_graph {
knn_graph(value_idx n_rows_, int n_neighbors_)
: n_rows(n_rows_), n_neighbors(n_neighbors_) {}

knn_graph(value_idx n_rows_, int n_neighbors_, value_idx *knn_indices_,
value_t *knn_dists_)
: n_rows(n_rows_),
n_neighbors(n_neighbors_),
knn_indices(knn_indices_),
knn_dists(knn_dists_) {}

value_idx *knn_indices;
value_t *knn_dists;

value_idx n_rows;
int n_neighbors;
};

/**
* Base struct for representing inputs to manifold learning
* algorithms.
* @tparam T
*/
template <typename T>
struct manifold_inputs_t {
T *y;
int n;
int d;

manifold_inputs_t(T *y_, int n_, int d_) : y(y_), n(n_), d(d_) {}

virtual bool alloc_knn_graph() const;
};

/**
* Dense input to manifold learning algorithms
* @tparam T
*/
template <typename T>
struct manifold_dense_inputs_t : public manifold_inputs_t<T> {
T *X;

manifold_dense_inputs_t(T *x_, T *y_, int n_, int d_)
: manifold_inputs_t<T>(y_, n_, d_), X(x_) {}

bool alloc_knn_graph() const { return true; }
};

/**
* Sparse CSR input to manifold learning algorithms
* @tparam value_idx
* @tparam T
*/
template <typename value_idx, typename T>
struct manifold_sparse_inputs_t : public manifold_inputs_t<T> {
value_idx *indptr;
value_idx *indices;
T *data;

size_t nnz;

manifold_sparse_inputs_t(value_idx *indptr_, value_idx *indices_, T *data_,
T *y_, size_t nnz_, int n_, int d_)
: manifold_inputs_t<T>(y_, n_, d_),
indptr(indptr_),
indices(indices_),
data(data_),
nnz(nnz_) {}

bool alloc_knn_graph() const { return true; }
};

/**
* Precomputed KNN graph input to manifold learning algorithms
* @tparam value_idx
* @tparam value_t
*/
template <typename value_idx, typename value_t>
struct manifold_precomputed_knn_inputs_t
: public manifold_dense_inputs_t<value_t> {
manifold_precomputed_knn_inputs_t<value_idx, value_t>(
value_idx *knn_indices_, value_t *knn_dists_, value_t *X_, value_t *y_,
int n_, int d_, int n_neighbors_)
: manifold_dense_inputs_t<value_t>(X_, y_, n_, d_),
knn_graph(n_, n_neighbors_, knn_indices_, knn_dists_) {}

knn_graph<value_idx, value_t> knn_graph;

bool alloc_knn_graph() const { return false; }
};

}; // end namespace ML
96 changes: 14 additions & 82 deletions cpp/include/cuml/manifold/umap.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,19 @@
#include "umapparams.h"

namespace ML {
namespace UMAP {

void transform(const raft::handle_t &handle, float *X, int n, int d,
int64_t *knn_indices, float *knn_dists, float *orig_X,
int orig_n, float *embedding, int embedding_n,
UMAPParams *params, float *transformed);

void transform_sparse(const raft::handle_t &handle, int *indptr, int *indices,
float *data, size_t nnz, int n, int d, int *orig_x_indptr,
int *orig_x_indices, float *orig_x_data, size_t orig_nnz,
int orig_n, float *embedding, int embedding_n,
UMAPParams *params, float *transformed);

void find_ab(const raft::handle_t &handle, UMAPParams *params);

void fit(const raft::handle_t &handle,
Expand All @@ -34,86 +41,11 @@ void fit(const raft::handle_t &handle,
int n, int d, int64_t *knn_indices, float *knn_dists,
UMAPParams *params, float *embeddings);

void fit(const raft::handle_t &handle,
float *X, // input matrix
int n, // rows
int d, // cols
int64_t *knn_indices, float *knn_dists, UMAPParams *params,
float *embeddings);

class UMAP_API {
float *orig_X;
int orig_n;
raft::handle_t *handle;
UMAPParams *params;

public:
UMAP_API(const raft::handle_t &handle, UMAPParams *params);
~UMAP_API();

/**
* Fits an unsupervised UMAP model
* @param X
* pointer to an array in row-major format (note: this will be col-major soon)
* @param n
* n_samples in X
* @param d
* d_features in X
* @param knn_indices
* an array containing the n_neighbors nearest neighors indices for each sample
* @param knn_dists
* an array containing the n_neighbors nearest neighors distances for each sample
* @param embeddings
* an array to return the output embeddings of size (n_samples, n_components)
*/
void fit(float *X, int n, int d, int64_t *knn_indices, float *knn_dists,
float *embeddings);

/**
* Fits a supervised UMAP model
* @param X
* pointer to an array in row-major format (note: this will be col-major soon)
* @param y
* pointer to an array of labels, shape=n_samples
* @param n
* n_samples in X
* @param d
* d_features in X
* @param knn_indices
* an array containing the n_neighbors nearest neighors indices for each sample
* @param knn_dists
* an array containing the n_neighbors nearest neighors distances for each sample
* @param embeddings
* an array to return the output embeddings of size (n_samples, n_components)
*/
void fit(float *X, float *y, int n, int d, int64_t *knn_indices,
float *knn_dists, float *embeddings);

/**
* Project a set of X vectors into the embedding space.
* @param X
* pointer to an array in row-major format (note: this will be col-major soon)
* @param n
* n_samples in X
* @param d
* d_features in X
* @param knn_indices
* an array containing the n_neighbors nearest neighors indices for each sample
* @param knn_dists
* an array containing the n_neighbors nearest neighors distances for each sample
* @param embedding
* pointer to embedding array of size (embedding_n, n_components) that has been created with fit()
* @param embedding_n
* n_samples in embedding array
* @param out
* pointer to array for storing output embeddings (n, n_components)
*/
void transform(float *X, int n, int d, int64_t *knn_indices, float *knn_dists,
float *embedding, int embedding_n, float *out);

/**
* Get the UMAPParams instance
*/
UMAPParams *get_params();
};
void fit_sparse(const raft::handle_t &handle,
int *indptr, // input matrix
int *indices, float *data, size_t nnz, float *y,
int n, // rows
int d, // cols
UMAPParams *params, float *embeddings);
} // namespace UMAP
} // namespace ML
42 changes: 42 additions & 0 deletions cpp/include/cuml/neighbors/knn_sparse.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
/*
* Copyright (c) 2019-2020, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#pragma once

#include <cuml/cuml.hpp>

#include <cusparse_v2.h>

#include <cuml/neighbors/knn.hpp>

namespace ML {
namespace Sparse {

constexpr int DEFAULT_BATCH_SIZE = 2 << 16;

void brute_force_knn(raft::handle_t &handle, const int *idx_indptr,
const int *idx_indices, const float *idx_data,
size_t idx_nnz, int n_idx_rows, int n_idx_cols,
const int *query_indptr, const int *query_indices,
const float *query_data, size_t query_nnz,
int n_query_rows, int n_query_cols, int *output_indices,
float *output_dists, int k,
size_t batch_size_index = DEFAULT_BATCH_SIZE,
size_t batch_size_query = DEFAULT_BATCH_SIZE,
ML::MetricType metric = ML::MetricType::METRIC_L2,
float metricArg = 0, bool expanded_form = false);
}; // end namespace Sparse
}; // end namespace ML
Loading

0 comments on commit b205e8f

Please sign in to comment.