Skip to content

Commit

Permalink
Merge pull request rapidsai#5195 from rapidsai/branch-23.02
Browse files Browse the repository at this point in the history
Forward-merge branch-23.02 to branch-23.04
  • Loading branch information
GPUtester authored Feb 3, 2023
2 parents cd84dcc + 3e7f21b commit c1c7b6d
Show file tree
Hide file tree
Showing 11 changed files with 332 additions and 194 deletions.
2 changes: 0 additions & 2 deletions cpp/bench/sg/umap.cu
Original file line number Diff line number Diff line change
Expand Up @@ -163,8 +163,6 @@ class UmapTransform : public UmapBase {
this->data.X.data(),
this->params.nrows,
this->params.ncols,
nullptr,
nullptr,
this->data.X.data(),
this->params.nrows,
embeddings,
Expand Down
14 changes: 4 additions & 10 deletions cpp/include/cuml/manifold/common.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -104,16 +104,10 @@ struct manifold_sparse_inputs_t : public manifold_inputs_t<T> {
* @tparam value_t
*/
template <typename value_idx, typename value_t>
struct manifold_precomputed_knn_inputs_t : public manifold_dense_inputs_t<value_t> {
manifold_precomputed_knn_inputs_t<value_idx, value_t>(value_idx* knn_indices_,
value_t* knn_dists_,
value_t* X_,
value_t* y_,
int n_,
int d_,
int n_neighbors_)
: manifold_dense_inputs_t<value_t>(X_, y_, n_, d_),
knn_graph(n_, n_neighbors_, knn_indices_, knn_dists_)
struct manifold_precomputed_knn_inputs_t : public manifold_inputs_t<value_t> {
manifold_precomputed_knn_inputs_t<value_idx, value_t>(
value_idx* knn_indices_, value_t* knn_dists_, value_t* y_, int n_, int d_, int n_neighbors_)
: manifold_inputs_t<value_t>(y_, n_, d_), knn_graph(n_, n_neighbors_, knn_indices_, knn_dists_)
{
}

Expand Down
8 changes: 4 additions & 4 deletions cpp/include/cuml/manifold/umap.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,8 @@ void fit(const raft::handle_t& handle,
* @param[in] y: pointer to labels array
* @param[in] n: n_samples of input array
* @param[in] d: n_features of input array
* @param[in] knn_indices: pointer to knn_indices of input (optional)
* @param[in] knn_dists: pointer to knn_dists of input (optional)
* @param[in] params: pointer to ML::UMAPParams object
* @param[out] embeddings: pointer to embedding produced through projection
* @param[out] graph: pointer to fuzzy simplicial set graph
Expand All @@ -131,6 +133,8 @@ void fit_sparse(const raft::handle_t& handle,
float* y,
int n,
int d,
int* knn_indices,
float* knn_dists,
UMAPParams* params,
float* embeddings,
raft::sparse::COO<float, int>* graph);
Expand All @@ -142,8 +146,6 @@ void fit_sparse(const raft::handle_t& handle,
* @param[in] X: pointer to input array to be infered
* @param[in] n: n_samples of input array to be infered
* @param[in] d: n_features of input array to be infered
* @param[in] knn_indices: pointer to knn_indices of input (optional)
* @param[in] knn_dists: pointer to knn_dists of input (optional)
* @param[in] orig_X: pointer to original training array
* @param[in] orig_n: number of rows in original training array
* @param[in] embedding: pointer to embedding created during training
Expand All @@ -155,8 +157,6 @@ void transform(const raft::handle_t& handle,
float* X,
int n,
int d,
int64_t* knn_indices,
float* knn_dists,
float* orig_X,
int orig_n,
float* embedding,
Expand Down
59 changes: 34 additions & 25 deletions cpp/src/umap/umap.cu
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ std::unique_ptr<raft::sparse::COO<float, int>> get_graph(
CUML_LOG_DEBUG("Calling UMAP::get_graph() with precomputed KNN");

manifold_precomputed_knn_inputs_t<knn_indices_dense_t, float> inputs(
knn_indices, knn_dists, X, y, n, d, params->n_neighbors);
knn_indices, knn_dists, y, n, d, params->n_neighbors);
if (y != nullptr) {
UMAPAlgo::_get_graph_supervised<knn_indices_dense_t,
float,
Expand Down Expand Up @@ -106,7 +106,7 @@ void fit(const raft::handle_t& handle,
CUML_LOG_DEBUG("Calling UMAP::fit() with precomputed KNN");

manifold_precomputed_knn_inputs_t<knn_indices_dense_t, float> inputs(
knn_indices, knn_dists, X, y, n, d, params->n_neighbors);
knn_indices, knn_dists, y, n, d, params->n_neighbors);
if (y != nullptr) {
UMAPAlgo::_fit_supervised<knn_indices_dense_t,
float,
Expand Down Expand Up @@ -139,48 +139,57 @@ void fit_sparse(const raft::handle_t& handle,
float* y,
int n,
int d,
int* knn_indices,
float* knn_dists,
UMAPParams* params,
float* embeddings,
raft::sparse::COO<float, int>* graph)
{
manifold_sparse_inputs_t<int, float> inputs(indptr, indices, data, y, nnz, n, d);
if (y != nullptr) {
UMAPAlgo::
_fit_supervised<knn_indices_sparse_t, float, manifold_sparse_inputs_t<int, float>, TPB_X>(
handle, inputs, params, embeddings, graph);
if (knn_indices != nullptr && knn_dists != nullptr) {
manifold_precomputed_knn_inputs_t<knn_indices_sparse_t, float> inputs(
knn_indices, knn_dists, y, n, d, params->n_neighbors);
if (y != nullptr) {
UMAPAlgo::_fit_supervised<knn_indices_sparse_t,
float,
manifold_precomputed_knn_inputs_t<knn_indices_sparse_t, float>,
TPB_X>(handle, inputs, params, embeddings, graph);
} else {
UMAPAlgo::_fit<knn_indices_sparse_t,
float,
manifold_precomputed_knn_inputs_t<knn_indices_sparse_t, float>,
TPB_X>(handle, inputs, params, embeddings, graph);
}
} else {
UMAPAlgo::_fit<knn_indices_sparse_t, float, manifold_sparse_inputs_t<int, float>, TPB_X>(
handle, inputs, params, embeddings, graph);
manifold_sparse_inputs_t<int, float> inputs(indptr, indices, data, y, nnz, n, d);
if (y != nullptr) {
UMAPAlgo::_fit_supervised<knn_indices_sparse_t,
float,
manifold_sparse_inputs_t<knn_indices_sparse_t, float>,
TPB_X>(handle, inputs, params, embeddings, graph);
} else {
UMAPAlgo::_fit<knn_indices_sparse_t,
float,
manifold_sparse_inputs_t<knn_indices_sparse_t, float>,
TPB_X>(handle, inputs, params, embeddings, graph);
}
}
}

void transform(const raft::handle_t& handle,
float* X,
int n,
int d,
knn_indices_dense_t* knn_indices,
float* knn_dists,
float* orig_X,
int orig_n,
float* embedding,
int embedding_n,
UMAPParams* params,
float* transformed)
{
if (knn_indices != nullptr && knn_dists != nullptr) {
manifold_precomputed_knn_inputs_t<knn_indices_dense_t, float> inputs(
knn_indices, knn_dists, X, nullptr, n, d, params->n_neighbors);
UMAPAlgo::_transform<knn_indices_dense_t,
float,
manifold_precomputed_knn_inputs_t<knn_indices_dense_t, float>,
TPB_X>(
handle, inputs, inputs, embedding, embedding_n, params, transformed);
} else {
manifold_dense_inputs_t<float> inputs(X, nullptr, n, d);
manifold_dense_inputs_t<float> orig_inputs(orig_X, nullptr, orig_n, d);
UMAPAlgo::_transform<knn_indices_dense_t, float, manifold_dense_inputs_t<float>, TPB_X>(
handle, inputs, orig_inputs, embedding, embedding_n, params, transformed);
}
manifold_dense_inputs_t<float> inputs(X, nullptr, n, d);
manifold_dense_inputs_t<float> orig_inputs(orig_X, nullptr, orig_n, d);
UMAPAlgo::_transform<knn_indices_dense_t, float, manifold_dense_inputs_t<float>, TPB_X>(
handle, inputs, orig_inputs, embedding, embedding_n, params, transformed);
}

void transform_sparse(const raft::handle_t& handle,
Expand Down
2 changes: 0 additions & 2 deletions cpp/test/sg/umap_parametrizable_test.cu
Original file line number Diff line number Diff line change
Expand Up @@ -225,8 +225,6 @@ class UMAPParametrizableTest : public ::testing::Test {
X,
n_samples,
umap_params.n_components,
knn_indices,
knn_dists,
X,
n_samples,
model_embedding,
Expand Down
131 changes: 99 additions & 32 deletions python/cuml/common/sparsefuncs.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
from cuml.internals.safe_imports import gpu_only_import_from
from cuml.common.kernel_utils import cuda_kernel_factory
import cuml.internals
from cuml.internals.import_utils import has_scipy
from cuml.internals.memory_utils import with_cupy_rmm
from cuml.internals.input_utils import input_to_cuml_array
from cuml.internals.safe_imports import gpu_only_import

import math
import cuml
from cuml.internals.input_utils import input_to_cuml_array, input_to_cupy_array
from cuml.internals.memory_utils import with_cupy_rmm
from cuml.internals.import_utils import has_scipy
from cuml.common.kernel_utils import cuda_kernel_factory
from cuml.internals.safe_imports import cpu_only_import
from cuml.internals.safe_imports import gpu_only_import
from cuml.internals.safe_imports import gpu_only_import_from
np = cpu_only_import('numpy')
cp = gpu_only_import('cupy')
cupyx = gpu_only_import('cupyx')
Expand All @@ -30,6 +31,15 @@
cp_csc_matrix = gpu_only_import_from('cupyx.scipy.sparse', 'csc_matrix')


if has_scipy():
from scipy.sparse import csr_matrix, coo_matrix, csc_matrix
else:
from cuml.common.import_utils import DummyClass
csr_matrix = DummyClass
coo_matrix = DummyClass
csc_matrix = DummyClass


def _map_l1_norm_kernel(dtype):
"""Creates cupy RawKernel for csr_raw_normalize_l1 function."""

Expand Down Expand Up @@ -180,20 +190,12 @@ def _insert_zeros(ary, zero_indices):


@with_cupy_rmm
def extract_knn_graph(knn_graph, convert_dtype=True, sparse=False):
def extract_knn_graph(knn_graph):
"""
Converts KNN graph from CSR, COO and CSC formats into separate
distance and indice arrays. Input can be a cupy sparse graph (device)
or a numpy sparse graph (host).
"""
if has_scipy():
from scipy.sparse import csr_matrix, coo_matrix, csc_matrix
else:
from cuml.internals.import_utils import DummyClass
csr_matrix = DummyClass
coo_matrix = DummyClass
csc_matrix = DummyClass

if isinstance(knn_graph, (csc_matrix, cp_csc_matrix)):
knn_graph = cupyx.scipy.sparse.csr_matrix(knn_graph)
n_samples = knn_graph.shape[0]
Expand All @@ -212,25 +214,90 @@ def extract_knn_graph(knn_graph, convert_dtype=True, sparse=False):
knn_indices = knn_graph.col

if knn_indices is not None:
convert_to_dtype = None
if convert_dtype:
convert_to_dtype = np.int32 if sparse else np.int64

knn_dists = knn_graph.data
return knn_indices, knn_dists
else:
return None


@with_cupy_rmm
def extract_pairwise_dists(pw_dists, n_neighbors):
"""
Extract the nearest neighbors distances and indices
from a pairwise distance matrix.
Parameters
----------
pw_dists: paiwise distances matrix of shape (n_samples, n_samples)
n_neighbors: number of nearest neighbors
(inspired from Scikit-Learn code)
"""
pw_dists, _, _, _ = input_to_cupy_array(pw_dists)

n_rows = pw_dists.shape[0]
sample_range = cp.arange(n_rows)[:, None]
knn_indices = cp.argpartition(pw_dists, n_neighbors - 1, axis=1)
knn_indices = knn_indices[:, :n_neighbors]
argdist = cp.argsort(pw_dists[sample_range, knn_indices])
knn_indices = knn_indices[sample_range, argdist]
knn_dists = pw_dists[sample_range, knn_indices]
return knn_indices, knn_dists


@with_cupy_rmm
def extract_knn_infos(knn_info, n_neighbors):
"""
Extract the nearest neighbors distances and indices
from the knn_info parameter.
Parameters
----------
knn_info : array / sparse array / tuple, optional (device or host)
Either one of :
- Tuple (indices, distances) of arrays of
shape (n_samples, n_neighbors)
- Pairwise distances dense array of shape (n_samples, n_samples)
- KNN graph sparse array (preferably CSR/COO)
n_neighbors: number of nearest neighbors
"""
if knn_info is None:
# no KNN was provided
return None

deepcopy = False
if isinstance(knn_info, tuple):
# dists and indices provided as a tuple
results = knn_info
else:
isaKNNGraph = isinstance(knn_info, (csr_matrix, coo_matrix, csc_matrix,
cp_csr_matrix, cp_coo_matrix,
cp_csc_matrix))
if isaKNNGraph:
# extract dists and indices from a KNN graph
deepcopy = True
results = extract_knn_graph(knn_info)
else:
# extract dists and indices from a pairwise distance matrix
results = extract_pairwise_dists(knn_info, n_neighbors)

if results is not None:
knn_indices, knn_dists = results

knn_indices_m, _, _, _ = \
input_to_cuml_array(knn_indices, order='C',
deepcopy=True,
check_dtype=(np.int64, np.int32),
convert_to_dtype=convert_to_dtype)
input_to_cuml_array(knn_indices.flatten(),
order='C',
deepcopy=deepcopy,
check_dtype=np.int64,
convert_to_dtype=np.int64)

knn_dists_m, _, _, _ = \
input_to_cuml_array(knn_dists, order='C',
deepcopy=True,
input_to_cuml_array(knn_dists.flatten(),
order='C',
deepcopy=deepcopy,
check_dtype=np.float32,
convert_to_dtype=(np.float32
if convert_dtype
else None))
convert_to_dtype=np.float32)

return (knn_indices_m, knn_indices_m.ptr),\
(knn_dists_m, knn_dists_m.ptr)
return (None, None), (None, None)
return knn_indices_m, knn_dists_m
else:
return None
Loading

0 comments on commit c1c7b6d

Please sign in to comment.