Skip to content

Commit

Permalink
Updates to consistency of MNMG PCA/TSVD solvers (docs + code consolid…
Browse files Browse the repository at this point in the history
…ation) (#4556)

Authors:
  - Corey J. Nolet (https://github.com/cjnolet)

Approvers:
  - Dante Gama Dessavre (https://github.com/dantegd)
  - Micka (https://github.com/lowener)

URL: #4556
  • Loading branch information
cjnolet authored Mar 17, 2022
1 parent da89b78 commit 48120ce
Show file tree
Hide file tree
Showing 11 changed files with 86 additions and 78 deletions.
8 changes: 6 additions & 2 deletions cpp/include/cuml/decomposition/params.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2018-2021, NVIDIA CORPORATION.
* Copyright (c) 2018-2022, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -77,7 +77,11 @@ class paramsPCATemplate : public paramsTSVDTemplate<enum_solver> {
};

typedef paramsTSVDTemplate<> paramsTSVD;

typedef paramsPCATemplate<> paramsPCA;

enum class mg_solver { COV_EIG_DQ, COV_EIG_JACOBI, QR };

typedef paramsPCATemplate<mg_solver> paramsPCAMG;
typedef paramsTSVDTemplate<mg_solver> paramsTSVDMG;

}; // end namespace ML
7 changes: 0 additions & 7 deletions cpp/include/cuml/decomposition/pca_mg.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,6 @@
#include "pca.hpp"

namespace ML {

enum class mg_solver { COV_EIG_DQ, COV_EIG_JACOBI, QR };

typedef paramsTSVDTemplate<mg_solver> paramsTSVDMG;

typedef paramsPCATemplate<mg_solver> paramsPCAMG;

namespace PCA {
namespace opg {

Expand Down
16 changes: 8 additions & 8 deletions cpp/include/cuml/decomposition/tsvd_mg.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ void fit(raft::handle_t& handle,
MLCommon::Matrix::floatData_t** input,
float* components,
float* singular_vals,
paramsTSVD prms,
paramsTSVDMG& prms,
bool verbose = false);

void fit(raft::handle_t& handle,
Expand All @@ -51,7 +51,7 @@ void fit(raft::handle_t& handle,
MLCommon::Matrix::doubleData_t** input,
double* components,
double* singular_vals,
paramsTSVD prms,
paramsTSVDMG& prms,
bool verbose = false);

/**
Expand All @@ -77,7 +77,7 @@ void fit_transform(raft::handle_t& handle,
float* explained_var,
float* explained_var_ratio,
float* singular_vals,
paramsTSVD prms,
paramsTSVDMG& prms,
bool verbose);

void fit_transform(raft::handle_t& handle,
Expand All @@ -89,7 +89,7 @@ void fit_transform(raft::handle_t& handle,
double* explained_var,
double* explained_var_ratio,
double* singular_vals,
paramsTSVD prms,
paramsTSVDMG& prms,
bool verbose);

/**
Expand All @@ -109,7 +109,7 @@ void transform(raft::handle_t& handle,
MLCommon::Matrix::Data<float>** input,
float* components,
MLCommon::Matrix::Data<float>** trans_input,
paramsTSVD prms,
paramsTSVDMG& prms,
bool verbose);

void transform(raft::handle_t& handle,
Expand All @@ -118,7 +118,7 @@ void transform(raft::handle_t& handle,
MLCommon::Matrix::Data<double>** input,
double* components,
MLCommon::Matrix::Data<double>** trans_input,
paramsTSVD prms,
paramsTSVDMG& prms,
bool verbose);

/**
Expand All @@ -138,7 +138,7 @@ void inverse_transform(raft::handle_t& handle,
MLCommon::Matrix::Data<float>** trans_input,
float* components,
MLCommon::Matrix::Data<float>** input,
paramsTSVD prms,
paramsTSVDMG& prms,
bool verbose);

void inverse_transform(raft::handle_t& handle,
Expand All @@ -147,7 +147,7 @@ void inverse_transform(raft::handle_t& handle,
MLCommon::Matrix::Data<double>** trans_input,
double* components,
MLCommon::Matrix::Data<double>** input,
paramsTSVD prms,
paramsTSVDMG& prms,
bool verbose);

}; // end namespace opg
Expand Down
30 changes: 15 additions & 15 deletions cpp/src/tsvd/tsvd_mg.cu
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ void fit_impl(raft::handle_t& handle,
Matrix::PartDescriptor& input_desc,
T* components,
T* singular_vals,
paramsTSVD prms,
paramsTSVDMG& prms,
cudaStream_t* streams,
std::uint32_t n_streams,
bool verbose)
Expand Down Expand Up @@ -93,7 +93,7 @@ void fit_impl(raft::handle_t& handle,
Matrix::Data<T>** input,
T* components,
T* singular_vals,
paramsTSVD prms,
paramsTSVDMG& prms,
bool verbose)
{
int rank = handle.get_comms().get_rank();
Expand Down Expand Up @@ -128,7 +128,7 @@ void transform_impl(raft::handle_t& handle,
Matrix::PartDescriptor input_desc,
T* components,
std::vector<Matrix::Data<T>*>& trans_input,
paramsTSVD prms,
paramsTSVDMG& prms,
cudaStream_t* streams,
std::uint32_t n_streams,
bool verbose)
Expand Down Expand Up @@ -180,7 +180,7 @@ void transform_impl(raft::handle_t& handle,
Matrix::Data<T>** input,
T* components,
Matrix::Data<T>** trans_input,
paramsTSVD prms,
paramsTSVDMG& prms,
bool verbose)
{
int rank = handle.get_comms().get_rank();
Expand Down Expand Up @@ -215,7 +215,7 @@ void inverse_transform_impl(raft::handle_t& handle,
Matrix::PartDescriptor trans_input_desc,
T* components,
std::vector<Matrix::Data<T>*>& input,
paramsTSVD prms,
paramsTSVDMG& prms,
cudaStream_t* streams,
std::uint32_t n_streams,
bool verbose)
Expand Down Expand Up @@ -265,7 +265,7 @@ void inverse_transform_impl(raft::handle_t& handle,
Matrix::Data<T>** trans_input,
T* components,
Matrix::Data<T>** input,
paramsTSVD prms,
paramsTSVDMG& prms,
bool verbose)
{
int rank = handle.get_comms().get_rank();
Expand Down Expand Up @@ -318,7 +318,7 @@ void fit_transform_impl(raft::handle_t& handle,
T* explained_var,
T* explained_var_ratio,
T* singular_vals,
paramsTSVD prms,
paramsTSVDMG& prms,
bool verbose)
{
int rank = handle.get_comms().get_rank();
Expand Down Expand Up @@ -386,7 +386,7 @@ void fit(raft::handle_t& handle,
Matrix::floatData_t** input,
float* components,
float* singular_vals,
paramsTSVD prms,
paramsTSVDMG& prms,
bool verbose)
{
fit_impl(handle, rank_sizes, n_parts, input, components, singular_vals, prms, verbose);
Expand All @@ -398,7 +398,7 @@ void fit(raft::handle_t& handle,
Matrix::doubleData_t** input,
double* components,
double* singular_vals,
paramsTSVD prms,
paramsTSVDMG& prms,
bool verbose)
{
fit_impl(handle, rank_sizes, n_parts, input, components, singular_vals, prms, verbose);
Expand All @@ -413,7 +413,7 @@ void fit_transform(raft::handle_t& handle,
float* explained_var,
float* explained_var_ratio,
float* singular_vals,
paramsTSVD prms,
paramsTSVDMG& prms,
bool verbose)
{
fit_transform_impl(handle,
Expand All @@ -438,7 +438,7 @@ void fit_transform(raft::handle_t& handle,
double* explained_var,
double* explained_var_ratio,
double* singular_vals,
paramsTSVD prms,
paramsTSVDMG& prms,
bool verbose)
{
fit_transform_impl(handle,
Expand All @@ -460,7 +460,7 @@ void transform(raft::handle_t& handle,
Matrix::Data<float>** input,
float* components,
Matrix::Data<float>** trans_input,
paramsTSVD prms,
paramsTSVDMG& prms,
bool verbose)
{
transform_impl(handle, rank_sizes, n_parts, input, components, trans_input, prms, verbose);
Expand All @@ -472,7 +472,7 @@ void transform(raft::handle_t& handle,
Matrix::Data<double>** input,
double* components,
Matrix::Data<double>** trans_input,
paramsTSVD prms,
paramsTSVDMG& prms,
bool verbose)
{
transform_impl(handle, rank_sizes, n_parts, input, components, trans_input, prms, verbose);
Expand All @@ -484,7 +484,7 @@ void inverse_transform(raft::handle_t& handle,
Matrix::Data<float>** trans_input,
float* components,
Matrix::Data<float>** input,
paramsTSVD prms,
paramsTSVDMG prms,
bool verbose)
{
inverse_transform_impl(
Expand All @@ -497,7 +497,7 @@ void inverse_transform(raft::handle_t& handle,
Matrix::Data<double>** trans_input,
double* components,
Matrix::Data<double>** input,
paramsTSVD prms,
paramsTSVDMG prms,
bool verbose)
{
inverse_transform_impl(
Expand Down
17 changes: 10 additions & 7 deletions python/cuml/dask/decomposition/pca.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2019-2021, NVIDIA CORPORATION.
# Copyright (c) 2019-2022, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -32,9 +32,11 @@ class PCA(BaseDecomposition,
the data. N_components is usually small, say at 3, where it can be used for
data visualization, data compression and exploratory analysis.
cuML's multi-node multi-gpu (MNMG) PCA expects a dask cuDF input, and
provides a "Full" algorithm. It uses a full eigendecomposition
then selects the top K eigenvectors.
cuML's multi-node multi-gpu (MNMG) PCA expects a dask-cuDF object as input
and provides 2 algorithms, Full and Jacobi. Full (default) uses a full
eigendecomposition then selects the top K eigenvectors. The Jacobi
algorithm can be much faster as it iteratively tries to correct the top K
eigenvectors, but might be less accurate.
Examples
--------
Expand Down Expand Up @@ -107,9 +109,10 @@ class PCA(BaseDecomposition,
n_components : int (default = 1)
The number of top K singular vectors / values you want.
Must be <= number(columns).
svd_solver : 'full', 'jacobi', or 'tsqr'
'full': run exact full SVD and select the components by postprocessing
'jacobi': iteratively compute SVD of the covariance matrix
svd_solver : 'full', 'jacobi', 'auto'
'full': Run exact full SVD and select the components by postprocessing
'jacobi': Iteratively compute SVD of the covariance matrix
'auto': For compatiblity with Scikit-learn. Alias for 'jacobi'.
verbose : int or boolean, default=False
Sets logging level. It must be one of `cuml.common.logger.level_*`.
See :ref:`verbosity-levels` for more info.
Expand Down
4 changes: 2 additions & 2 deletions python/cuml/dask/decomposition/tsvd.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2019-2021, NVIDIA CORPORATION.
# Copyright (c) 2019-2022, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -97,7 +97,7 @@ class TruncatedSVD(BaseDecomposition,
n_components : int (default = 1)
The number of top K singular vectors / values you want.
Must be <= number(columns).
svd_solver : 'full'
svd_solver : 'full', 'jacobi'
Only Full algorithm is supported since it's significantly faster on GPU
then the other solvers including randomized SVD.
verbose : int or boolean, default=False
Expand Down
8 changes: 8 additions & 0 deletions python/cuml/decomposition/base_mg.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,14 @@ from cuml.decomposition.utils cimport *
from cuml.common import input_to_cuml_array
from cuml.common.opg_data_utils_mg cimport *

from enum import IntEnum


class MGSolver(IntEnum):
COV_EIG_DQ = <underlying_type_t_solver> mg_solver.COV_EIG_DQ
COV_EIG_JACOBI = <underlying_type_t_solver> mg_solver.COV_EIG_JACOBI
QR = <underlying_type_t_solver> mg_solver.QR


class BaseDecompositionMG(object):

Expand Down
32 changes: 4 additions & 28 deletions python/cuml/decomposition/pca_mg.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -35,30 +35,12 @@ import cuml.common.opg_data_utils_mg as opg
import cuml.internals
from cuml.common.base import Base
from raft.common.handle cimport handle_t
from cuml.decomposition.utils cimport paramsSolver
from cuml.decomposition.utils cimport *
from cuml.common.opg_data_utils_mg cimport *

from cuml.decomposition import PCA
from cuml.decomposition.base_mg import BaseDecompositionMG


ctypedef int underlying_type_t_solver


cdef extern from "cuml/decomposition/pca_mg.hpp" namespace "ML":

ctypedef enum mg_solver "ML::mg_solver":
COV_EIG_DQ "ML::mg_solver::COV_EIG_DQ"
COV_EIG_JACOBI "ML::mg_solver::COV_EIG_JACOBI"
QR "ML::mg_solver::QR"

cdef cppclass paramsTSVDMG(paramsSolver):
size_t n_components
mg_solver algorithm # = solver::COV_EIG_DQ

cdef cppclass paramsPCAMG(paramsTSVDMG):
bool copy
bool whiten
from cuml.decomposition import PCA
from cuml.decomposition.base_mg import BaseDecompositionMG, MGSolver


cdef extern from "cuml/decomposition/pca_mg.hpp" namespace "ML::PCA::opg":
Expand Down Expand Up @@ -88,12 +70,6 @@ cdef extern from "cuml/decomposition/pca_mg.hpp" namespace "ML::PCA::opg":
bool verbose) except +


class MGSolver(IntEnum):
COV_EIG_DQ = <underlying_type_t_solver> mg_solver.COV_EIG_DQ
COV_EIG_JACOBI = <underlying_type_t_solver> mg_solver.COV_EIG_JACOBI
QR = <underlying_type_t_solver> mg_solver.QR


class PCAMG(BaseDecompositionMG, PCA):

def __init__(self, **kwargs):
Expand All @@ -103,6 +79,7 @@ class PCAMG(BaseDecompositionMG, PCA):
algo_map = {
'full': MGSolver.COV_EIG_DQ,
'auto': MGSolver.COV_EIG_JACOBI,
'jacobi': MGSolver.COV_EIG_JACOBI,
# 'arpack': NOT_SUPPORTED,
# 'randomized': NOT_SUPPORTED,
}
Expand All @@ -118,7 +95,6 @@ class PCAMG(BaseDecompositionMG, PCA):
params.n_rows = n_rows
params.n_cols = n_cols
params.whiten = self.whiten
params.n_iterations = self.iterated_power
params.tol = self.tol
params.algorithm = <mg_solver> (<underlying_type_t_solver> (
self.c_algorithm))
Expand Down
Loading

0 comments on commit 48120ce

Please sign in to comment.