Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improvements of UMAP/TSNE precomputed KNN feature #4865

Merged
merged 20 commits into from
Feb 3, 2023
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions cpp/bench/sg/umap.cu
Original file line number Diff line number Diff line change
Expand Up @@ -163,8 +163,6 @@ class UmapTransform : public UmapBase {
this->data.X.data(),
this->params.nrows,
this->params.ncols,
nullptr,
nullptr,
this->data.X.data(),
this->params.nrows,
embeddings,
Expand Down
14 changes: 4 additions & 10 deletions cpp/include/cuml/manifold/common.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -104,16 +104,10 @@ struct manifold_sparse_inputs_t : public manifold_inputs_t<T> {
* @tparam value_t
*/
template <typename value_idx, typename value_t>
struct manifold_precomputed_knn_inputs_t : public manifold_dense_inputs_t<value_t> {
manifold_precomputed_knn_inputs_t<value_idx, value_t>(value_idx* knn_indices_,
value_t* knn_dists_,
value_t* X_,
value_t* y_,
int n_,
int d_,
int n_neighbors_)
: manifold_dense_inputs_t<value_t>(X_, y_, n_, d_),
knn_graph(n_, n_neighbors_, knn_indices_, knn_dists_)
struct manifold_precomputed_knn_inputs_t : public manifold_inputs_t<value_t> {
manifold_precomputed_knn_inputs_t<value_idx, value_t>(
value_idx* knn_indices_, value_t* knn_dists_, value_t* y_, int n_, int d_, int n_neighbors_)
: manifold_inputs_t<value_t>(y_, n_, d_), knn_graph(n_, n_neighbors_, knn_indices_, knn_dists_)
{
}

Expand Down
8 changes: 4 additions & 4 deletions cpp/include/cuml/manifold/umap.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,8 @@ void fit(const raft::handle_t& handle,
* @param[in] y: pointer to labels array
* @param[in] n: n_samples of input array
* @param[in] d: n_features of input array
* @param[in] knn_indices: pointer to knn_indices of input (optional)
* @param[in] knn_dists: pointer to knn_dists of input (optional)
* @param[in] params: pointer to ML::UMAPParams object
* @param[out] embeddings: pointer to embedding produced through projection
* @param[out] graph: pointer to fuzzy simplicial set graph
Expand All @@ -131,6 +133,8 @@ void fit_sparse(const raft::handle_t& handle,
float* y,
int n,
int d,
int* knn_indices,
float* knn_dists,
UMAPParams* params,
float* embeddings,
raft::sparse::COO<float, int>* graph);
Expand All @@ -142,8 +146,6 @@ void fit_sparse(const raft::handle_t& handle,
* @param[in] X: pointer to input array to be infered
* @param[in] n: n_samples of input array to be infered
* @param[in] d: n_features of input array to be infered
* @param[in] knn_indices: pointer to knn_indices of input (optional)
* @param[in] knn_dists: pointer to knn_dists of input (optional)
* @param[in] orig_X: pointer to original training array
* @param[in] orig_n: number of rows in original training array
* @param[in] embedding: pointer to embedding created during training
Expand All @@ -155,8 +157,6 @@ void transform(const raft::handle_t& handle,
float* X,
int n,
int d,
int64_t* knn_indices,
float* knn_dists,
float* orig_X,
int orig_n,
float* embedding,
Expand Down
59 changes: 34 additions & 25 deletions cpp/src/umap/umap.cu
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ std::unique_ptr<raft::sparse::COO<float, int>> get_graph(
CUML_LOG_DEBUG("Calling UMAP::get_graph() with precomputed KNN");

manifold_precomputed_knn_inputs_t<knn_indices_dense_t, float> inputs(
knn_indices, knn_dists, X, y, n, d, params->n_neighbors);
knn_indices, knn_dists, y, n, d, params->n_neighbors);
if (y != nullptr) {
UMAPAlgo::_get_graph_supervised<knn_indices_dense_t,
float,
Expand Down Expand Up @@ -106,7 +106,7 @@ void fit(const raft::handle_t& handle,
CUML_LOG_DEBUG("Calling UMAP::fit() with precomputed KNN");

manifold_precomputed_knn_inputs_t<knn_indices_dense_t, float> inputs(
knn_indices, knn_dists, X, y, n, d, params->n_neighbors);
knn_indices, knn_dists, y, n, d, params->n_neighbors);
if (y != nullptr) {
UMAPAlgo::_fit_supervised<knn_indices_dense_t,
float,
Expand Down Expand Up @@ -139,48 +139,57 @@ void fit_sparse(const raft::handle_t& handle,
float* y,
int n,
int d,
int* knn_indices,
float* knn_dists,
UMAPParams* params,
float* embeddings,
raft::sparse::COO<float, int>* graph)
{
manifold_sparse_inputs_t<int, float> inputs(indptr, indices, data, y, nnz, n, d);
if (y != nullptr) {
UMAPAlgo::
_fit_supervised<knn_indices_sparse_t, float, manifold_sparse_inputs_t<int, float>, TPB_X>(
handle, inputs, params, embeddings, graph);
if (knn_indices != nullptr && knn_dists != nullptr) {
manifold_precomputed_knn_inputs_t<knn_indices_sparse_t, float> inputs(
knn_indices, knn_dists, y, n, d, params->n_neighbors);
if (y != nullptr) {
UMAPAlgo::_fit_supervised<knn_indices_sparse_t,
float,
manifold_precomputed_knn_inputs_t<knn_indices_sparse_t, float>,
TPB_X>(handle, inputs, params, embeddings, graph);
} else {
UMAPAlgo::_fit<knn_indices_sparse_t,
float,
manifold_precomputed_knn_inputs_t<knn_indices_sparse_t, float>,
TPB_X>(handle, inputs, params, embeddings, graph);
}
} else {
UMAPAlgo::_fit<knn_indices_sparse_t, float, manifold_sparse_inputs_t<int, float>, TPB_X>(
handle, inputs, params, embeddings, graph);
manifold_sparse_inputs_t<int, float> inputs(indptr, indices, data, y, nnz, n, d);
if (y != nullptr) {
UMAPAlgo::_fit_supervised<knn_indices_sparse_t,
float,
manifold_sparse_inputs_t<knn_indices_sparse_t, float>,
TPB_X>(handle, inputs, params, embeddings, graph);
} else {
UMAPAlgo::_fit<knn_indices_sparse_t,
float,
manifold_sparse_inputs_t<knn_indices_sparse_t, float>,
TPB_X>(handle, inputs, params, embeddings, graph);
}
}
}

void transform(const raft::handle_t& handle,
float* X,
int n,
int d,
knn_indices_dense_t* knn_indices,
float* knn_dists,
float* orig_X,
int orig_n,
float* embedding,
int embedding_n,
UMAPParams* params,
float* transformed)
{
if (knn_indices != nullptr && knn_dists != nullptr) {
manifold_precomputed_knn_inputs_t<knn_indices_dense_t, float> inputs(
knn_indices, knn_dists, X, nullptr, n, d, params->n_neighbors);
UMAPAlgo::_transform<knn_indices_dense_t,
float,
manifold_precomputed_knn_inputs_t<knn_indices_dense_t, float>,
TPB_X>(
handle, inputs, inputs, embedding, embedding_n, params, transformed);
} else {
manifold_dense_inputs_t<float> inputs(X, nullptr, n, d);
manifold_dense_inputs_t<float> orig_inputs(orig_X, nullptr, orig_n, d);
UMAPAlgo::_transform<knn_indices_dense_t, float, manifold_dense_inputs_t<float>, TPB_X>(
handle, inputs, orig_inputs, embedding, embedding_n, params, transformed);
}
manifold_dense_inputs_t<float> inputs(X, nullptr, n, d);
manifold_dense_inputs_t<float> orig_inputs(orig_X, nullptr, orig_n, d);
UMAPAlgo::_transform<knn_indices_dense_t, float, manifold_dense_inputs_t<float>, TPB_X>(
handle, inputs, orig_inputs, embedding, embedding_n, params, transformed);
}

void transform_sparse(const raft::handle_t& handle,
Expand Down
2 changes: 0 additions & 2 deletions cpp/test/sg/umap_parametrizable_test.cu
Original file line number Diff line number Diff line change
Expand Up @@ -224,8 +224,6 @@ class UMAPParametrizableTest : public ::testing::Test {
X,
n_samples,
umap_params.n_components,
knn_indices,
knn_dists,
X,
n_samples,
model_embedding,
Expand Down
119 changes: 93 additions & 26 deletions python/cuml/common/sparsefuncs.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#

import math
import numpy as np
import cupy as cp
import cupyx
from cuml.internals.input_utils import input_to_cuml_array
from cuml.internals.input_utils import input_to_cuml_array, input_to_cupy_array
from cuml.internals.memory_utils import with_cupy_rmm
from cuml.internals.import_utils import has_scipy
import cuml.internals
Expand All @@ -26,6 +27,15 @@
coo_matrix as cp_coo_matrix, csc_matrix as cp_csc_matrix


if has_scipy():
from scipy.sparse import csr_matrix, coo_matrix, csc_matrix
else:
from cuml.common.import_utils import DummyClass
csr_matrix = DummyClass
coo_matrix = DummyClass
csc_matrix = DummyClass


def _map_l1_norm_kernel(dtype):
"""Creates cupy RawKernel for csr_raw_normalize_l1 function."""

Expand Down Expand Up @@ -176,20 +186,12 @@ def _insert_zeros(ary, zero_indices):


@with_cupy_rmm
def extract_knn_graph(knn_graph, convert_dtype=True, sparse=False):
def extract_knn_graph(knn_graph):
"""
Converts KNN graph from CSR, COO and CSC formats into separate
distance and indice arrays. Input can be a cupy sparse graph (device)
or a numpy sparse graph (host).
"""
if has_scipy():
from scipy.sparse import csr_matrix, coo_matrix, csc_matrix
else:
from cuml.internals.import_utils import DummyClass
csr_matrix = DummyClass
coo_matrix = DummyClass
csc_matrix = DummyClass

if isinstance(knn_graph, (csc_matrix, cp_csc_matrix)):
knn_graph = cupyx.scipy.sparse.csr_matrix(knn_graph)
n_samples = knn_graph.shape[0]
Expand All @@ -208,25 +210,90 @@ def extract_knn_graph(knn_graph, convert_dtype=True, sparse=False):
knn_indices = knn_graph.col

if knn_indices is not None:
convert_to_dtype = None
if convert_dtype:
convert_to_dtype = np.int32 if sparse else np.int64

knn_dists = knn_graph.data
return knn_indices, knn_dists
else:
return None


@with_cupy_rmm
def extract_pairwise_dists(pw_dists, n_neighbors):
"""
viclafargue marked this conversation as resolved.
Show resolved Hide resolved
Extract the nearest neighbors distances and indices
from a pairwise distance matrix.

Parameters
----------
pw_dists: paiwise distances matrix of shape (n_samples, n_samples)
n_neighbors: number of nearest neighbors

(inspired from Scikit-Learn code)
"""
pw_dists, _, _, _ = input_to_cupy_array(pw_dists)

n_rows = pw_dists.shape[0]
sample_range = cp.arange(n_rows)[:, None]
knn_indices = cp.argpartition(pw_dists, n_neighbors - 1, axis=1)
knn_indices = knn_indices[:, :n_neighbors]
argdist = cp.argsort(pw_dists[sample_range, knn_indices])
knn_indices = knn_indices[sample_range, argdist]
knn_dists = pw_dists[sample_range, knn_indices]
return knn_indices, knn_dists


@with_cupy_rmm
def extract_knn_infos(knn_info, n_neighbors):
"""
Extract the nearest neighbors distances and indices
from the knn_info parameter.

Parameters
----------
knn_info : array / sparse array / tuple, optional (device or host)
Either one of :
- Tuple (indices, distances) of arrays of
shape (n_samples, n_neighbors)
- Pairwise distances dense array of shape (n_samples, n_samples)
- KNN graph sparse array (preferably CSR/COO)
n_neighbors: number of nearest neighbors
"""
if knn_info is None:
viclafargue marked this conversation as resolved.
Show resolved Hide resolved
# no KNN was provided
return None

deepcopy = False
if isinstance(knn_info, tuple):
# dists and indices provided as a tuple
results = knn_info
else:
isaKNNGraph = isinstance(knn_info, (csr_matrix, coo_matrix, csc_matrix,
cp_csr_matrix, cp_coo_matrix,
cp_csc_matrix))
if isaKNNGraph:
# extract dists and indices from a KNN graph
deepcopy = True
results = extract_knn_graph(knn_info)
else:
# extract dists and indices from a pairwise distance matrix
results = extract_pairwise_dists(knn_info, n_neighbors)

if results is not None:
knn_indices, knn_dists = results

knn_indices_m, _, _, _ = \
input_to_cuml_array(knn_indices, order='C',
deepcopy=True,
check_dtype=(np.int64, np.int32),
convert_to_dtype=convert_to_dtype)
input_to_cuml_array(knn_indices.flatten(),
order='C',
deepcopy=deepcopy,
check_dtype=np.int64,
convert_to_dtype=np.int64)

knn_dists_m, _, _, _ = \
input_to_cuml_array(knn_dists, order='C',
deepcopy=True,
input_to_cuml_array(knn_dists.flatten(),
order='C',
deepcopy=deepcopy,
check_dtype=np.float32,
convert_to_dtype=(np.float32
if convert_dtype
else None))
convert_to_dtype=np.float32)

return (knn_indices_m, knn_indices_m.ptr),\
(knn_dists_m, knn_dists_m.ptr)
return (None, None), (None, None)
return knn_indices_m, knn_dists_m
else:
return None
Loading