Skip to content

Commit

Permalink
Changes to precomputed knn API in UMAP
Browse files Browse the repository at this point in the history
  • Loading branch information
viclafargue committed Oct 3, 2022
1 parent fcbf1a5 commit ccc1b50
Show file tree
Hide file tree
Showing 3 changed files with 117 additions and 81 deletions.
32 changes: 26 additions & 6 deletions python/cuml/common/sparsefuncs.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from cuml.common.kernel_utils import cuda_kernel_factory
from cupyx.scipy.sparse import csr_matrix as cp_csr_matrix,\
coo_matrix as cp_coo_matrix, csc_matrix as cp_csc_matrix
from sklearn.neighbors import NearestNeighbors as skNearestNeighbors


def _map_l1_norm_kernel(dtype):
Expand Down Expand Up @@ -176,7 +177,8 @@ def _insert_zeros(ary, zero_indices):


@with_cupy_rmm
def extract_knn_graph(knn_graph, convert_dtype=True, sparse=False):
def extract_knn_graph(knn_graph, n_neighbors, convert_dtype=True,
sparse=False):
"""
Converts KNN graph from CSR, COO and CSC formats into separate
distance and index arrays. Input can be a cupy sparse graph (device)
Expand All @@ -190,6 +192,11 @@ def extract_knn_graph(knn_graph, convert_dtype=True, sparse=False):
coo_matrix = DummyClass
csc_matrix = DummyClass

convert_to_dtype = None
if convert_dtype:
convert_to_dtype = np.int32 if sparse else np.int64

# CSC matrices preprocessing
if isinstance(knn_graph, (csc_matrix, cp_csc_matrix)):
knn_graph = cupyx.scipy.sparse.csr_matrix(knn_graph)
n_samples = knn_graph.shape[0]
Expand All @@ -202,17 +209,30 @@ def extract_knn_graph(knn_graph, convert_dtype=True, sparse=False):
knn_graph.data = knn_graph.data[reordering]

knn_indices = None
# CSR matrices
if isinstance(knn_graph, (csr_matrix, cp_csr_matrix)):
knn_indices = knn_graph.indices
knn_dists = knn_graph.data
# COO matrices
elif isinstance(knn_graph, (coo_matrix, cp_coo_matrix)):
knn_indices = knn_graph.col
knn_dists = knn_graph.data
# Dense distance matrix
else:
distance_matrix = knn_graph
n_samples = distance_matrix.shape[0]
nn = skNearestNeighbors(n_neighbors=n_neighbors, metric='precomputed')
nn.fit(distance_matrix)
knn_dists, knn_indices = nn.kneighbors(n_neighbors=n_neighbors,
return_distance=True)
knn_dists = np.column_stack((np.zeros(n_samples),
knn_dists))[:, :-1].flatten()
knn_indices = np.column_stack((np.arange(n_samples),
knn_indices))[:, :-1].flatten()
knn_dists = knn_dists.astype(np.float64)
knn_indices = knn_indices.astype(np.int32)

if knn_indices is not None:
convert_to_dtype = None
if convert_dtype:
convert_to_dtype = np.int32 if sparse else np.int64

knn_dists = knn_graph.data
knn_indices_m, _, _, _ = \
input_to_cuml_array(knn_indices, order='C',
deepcopy=True,
Expand Down
139 changes: 76 additions & 63 deletions python/cuml/manifold/umap.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -250,6 +250,14 @@ class UMAP(Base,
consistency of trained embeddings, allowing for reproducible results
to 3 digits of precision, but will do so at the expense of potentially
slower training and increased memory usage.
precomputed_knn : sparse array-like (device or host)
shape=(n_samples, n_samples)
A sparse array containing the k-nearest neighbors of X,
or a pairwise distance matrix. This allows the use of
a custom distance function. This function should match
the metric used to train the UMAP embeddings.
Acceptable formats: sparse SciPy ndarray, CuPy device ndarray,
CSR/COO preferred; other formats will go through conversion to CSR.
callback: An instance of GraphBasedDimRedCallback class
Used to intercept the internal state of embeddings while they are being
trained. Example of callback usage:
Expand Down Expand Up @@ -332,6 +340,7 @@ class UMAP(Base,
handle=None,
hash_input=False,
random_state=None,
precomputed_knn=None,
callback=None,
output_type=None):

Expand Down Expand Up @@ -393,6 +402,7 @@ class UMAP(Base,
else:
raise Exception("Invalid target metric: {}" % target_metric)

self.precomputed_knn = precomputed_knn
self.callback = callback # prevent callback destruction
self.X_m = None
self.embedding_ = None
Expand Down Expand Up @@ -486,55 +496,55 @@ class UMAP(Base,
@generate_docstring(convert_dtype_cast='np.float32',
X='dense_sparse',
skip_parameters_heading=True)
def fit(self, X, y=None, precomputed=False,
convert_dtype=True) -> "UMAP":
def fit(self, X, y=None, convert_dtype=True,
knn_graph=None) -> "UMAP":
"""
Fit X into an embedded space.

Parameters
----------
precomputed : boolean, False by default
When set to True, X should be provided in the form of
a sparse array containing the k-nearest neighbors of the input.
This allows the use of a custom metrics whereas UMAP
would otherwise use the euclidean distance.
Acceptable formats for the KNN graph:
Sparse SciPy or CuPy ndarray of shape (n_samples1, n_samples2),
CSR/COO preferred, other formats will go through conversion to CSR
knn_graph : (deprecated) sparse array-like (device or host)
shape=(n_samples, n_samples)
A sparse array containing the k-nearest neighbors of X,
or a pairwise distance matrix. This allows the use of
a custom distance function. This function should match
the metric used to train the UMAP embeddings.
Acceptable formats: sparse SciPy ndarray, CuPy device ndarray,
CSR/COO preferred; other formats will go through conversion to CSR.
"""
if len(X.shape) != 2:
raise ValueError("data should be two dimensional")

if y is not None and precomputed is True\
and self.target_metric != "categorical":
if y is not None and self.target_metric != "categorical":
raise ValueError("Cannot provide a KNN graph when in \
semi-supervised mode with categorical target_metric for now.")

self.sparse_fit = is_sparse(X) and not precomputed

cdef uintptr_t X_ptr = 0

# Precomputed distance matrix in the form of a KNN graph
if precomputed:
knn_indices_m, knn_dists_m =\
extract_knn_graph(X, convert_dtype)
self.n_rows, self.n_dims = X.shape
index = None
if knn_graph is not None:
self.precomputed_knn = knn_graph

# Precomputed knn
if self.precomputed_knn is not None:
self.knn_indices_m, self.knn_dists_m =\
extract_knn_graph(self.precomputed_knn, self.n_neighbors)
self.n_rows, self.n_dims = self.precomputed_knn.shape
self.index = None
# Dense input
elif not is_sparse(X):
self.X_m, self.n_rows, self.n_dims, dtype = \
input_to_cuml_array(X, order='C', check_dtype=np.float32,
convert_to_dtype=(np.float32
if convert_dtype
else None))
if convert_dtype
else None))
X_ptr = self.X_m.ptr
index = self.X_m.index
self.index = self.X_m.index
# Sparse input
else:
self.X_m = SparseCumlArray(X, convert_to_dtype=cupy.float32,
convert_format=False)
convert_format=False)
self.n_rows, self.n_dims = self.X_m.shape
index = self.X_m.index
self.index = self.X_m.index

if self.n_rows <= 1:
raise ValueError("There needs to be more than 1 sample to "
Expand All @@ -545,7 +555,7 @@ class UMAP(Base,
self.embedding_ = CumlArray.zeros((self.n_rows,
self.n_components),
order="C", dtype=np.float32,
index=index)
index=self.index)

if self.hash_input:
with using_output_type("numpy"):
Expand All @@ -569,39 +579,43 @@ class UMAP(Base,
y_raw = y_m.ptr

fss_graph = GraphHolder.new_graph(handle_.get_stream())
if precomputed:
self.sparse_fit = False

# Precomputed knn
if self.precomputed_knn is not None:
fit_preprocessed(handle_[0],
<float*> y_raw,
<int> self.n_rows,
<int> self.n_dims,
<int64_t*><uintptr_t> knn_indices_m.ptr,
<float*><uintptr_t> knn_dists_m.ptr,
<int64_t*><uintptr_t> self.knn_indices_m.ptr,
<float*><uintptr_t> self.knn_dists_m.ptr,
<UMAPParams*>umap_params,
<float*>embed_raw,
<COO*> fss_graph.get())
# Dense input
elif not is_sparse(X):
fit(handle_[0],
<float*> X_ptr,
<float*> y_raw,
<int> self.n_rows,
<int> self.n_dims,
<UMAPParams*>umap_params,
<float*>embed_raw,
<COO*> fss_graph.get())
# Sparse input
else:
if self.sparse_fit:
fit_sparse(handle_[0],
<int*><uintptr_t> self.X_m.indptr.ptr,
<int*><uintptr_t> self.X_m.indices.ptr,
<float*><uintptr_t> self.X_m.data.ptr,
<size_t> self.X_m.nnz,
<float*> y_raw,
<int> self.n_rows,
<int> self.n_dims,
<UMAPParams*> umap_params,
<float*> embed_raw,
<COO*> fss_graph.get())

else:
fit(handle_[0],
<float*> X_ptr,
<float*> y_raw,
<int> self.n_rows,
<int> self.n_dims,
<UMAPParams*>umap_params,
<float*>embed_raw,
<COO*> fss_graph.get())
self.sparse_fit = True
fit_sparse(handle_[0],
<int*><uintptr_t> self.X_m.indptr.ptr,
<int*><uintptr_t> self.X_m.indices.ptr,
<float*><uintptr_t> self.X_m.data.ptr,
<size_t> self.X_m.nnz,
<float*> y_raw,
<int> self.n_rows,
<int> self.n_dims,
<UMAPParams*> umap_params,
<float*> embed_raw,
<COO*> fss_graph.get())

self.graph_ = fss_graph.get_cupy_coo()

Expand All @@ -620,8 +634,8 @@ class UMAP(Base,
low-dimensional space.',
'shape': '(n_samples, n_components)'})
@cuml.internals.api_base_fit_transform()
def fit_transform(self, X, y=None, precomputed=False,
convert_dtype=True) -> CumlArray:
def fit_transform(self, X, y=None, convert_dtype=True,
knn_graph=None) -> CumlArray:
"""
Fit X into an embedded space and return that transformed
output.
Expand All @@ -634,17 +648,16 @@ class UMAP(Base,

Parameters
----------
precomputed : boolean, False by default
When set to True, X should be provided in the form of
a sparse array containing the k-nearest neighbors of the input.
This allows the use of a custom metrics whereas UMAP
would otherwise use the euclidean distance.
Acceptable formats for the KNN graph:
Sparse SciPy or CuPy ndarray of shape (n_samples1, n_samples2),
CSR/COO preferred, other formats will go through conversion to CSR

knn_graph : (deprecated) sparse array-like (device or host)
shape=(n_samples, n_samples)
A sparse array containing the k-nearest neighbors of X,
or a pairwise distance matrix. This allows the use of
a custom distance function. This function should match
the metric used to train the UMAP embeddings.
Acceptable formats: sparse SciPy ndarray, CuPy device ndarray,
CSR/COO preferred; other formats will go through conversion to CSR.
"""
self.fit(X, y, precomputed=precomputed, convert_dtype=convert_dtype)
self.fit(X, y, convert_dtype=convert_dtype, knn_graph=knn_graph)

return self.embedding_

Expand Down
27 changes: 15 additions & 12 deletions python/cuml/tests/test_umap.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@
from sklearn.datasets import make_blobs
from sklearn.manifold import trustworthiness
from sklearn.metrics import adjusted_rand_score
from cuml.metrics.pairwise_distances import pairwise_distances


dataset_names = ['iris', 'digits', 'wine', 'blobs']

Expand Down Expand Up @@ -483,14 +485,12 @@ def test_umap_knn_parameters(n_neighbors):
n_samples=2000, n_features=10, centers=5, random_state=0)
data = data.astype(np.float32)

def fit_transform_embed(knn_graph):
def fit_transform_embed(precomputed_knn):
model = cuUMAP(random_state=42,
init='random',
n_neighbors=n_neighbors)
print('type(knn_graph):', type(knn_graph))
return model.fit_transform(knn_graph,
precomputed=True,
convert_dtype=True)
n_neighbors=n_neighbors,
precomputed_knn=precomputed_knn)
return model.fit_transform(data)

def test_trustworthiness(embedding):
trust = trustworthiness(data, embedding, n_neighbors=n_neighbors)
Expand All @@ -501,18 +501,21 @@ def test_equality(e1, e2):
print("mean diff: %s" % mean_diff)
assert mean_diff < 1.0

precomputed_dists = pairwise_distances(data)
embedding1 = fit_transform_embed(precomputed_dists)

neigh = NearestNeighbors(n_neighbors=n_neighbors)
neigh.fit(data)
knn_graph = neigh.kneighbors_graph(data, mode="distance")

embedding1 = fit_transform_embed(knn_graph.tocsr())
embedding2 = fit_transform_embed(knn_graph.tocoo())
embedding3 = fit_transform_embed(knn_graph.tocsc())
embedding2 = fit_transform_embed(knn_graph.tocsr())
embedding3 = fit_transform_embed(knn_graph.tocoo())
embedding4 = fit_transform_embed(knn_graph.tocsc())

test_trustworthiness(embedding1)

test_equality(embedding1, embedding2)
test_trustworthiness(embedding2)
# test_equality(embedding1, embedding2)
test_equality(embedding2, embedding3)
test_equality(embedding3, embedding4)


def correctness_sparse(a, b, atol=0.1, rtol=0.2, threshold=0.95):
Expand Down

0 comments on commit ccc1b50

Please sign in to comment.