Skip to content

Commit

Permalink
[REVIEW] Improved Array Conversion with CumlArrayDescriptor and Decor…
Browse files Browse the repository at this point in the history
…ators [skip-ci] (#3040)

* Adding additional checking for incorrect use cases. Added CumlArrayDescriptor

* Cleaning up more use cases

* Initial commit of CumlArrayDescriptor in PCA

* Incrementally updating CumlArray uses

* Adding some improvements to decorators to auto detect certain scenarios where a function returns CumlArray

* Adding internals.func_utils to test wrapping all functions and checking output types

* Commit before merging upstream

* Updating native_bayes

* Partial working state

* Updating KMeans

* Partial pass over all Base subclasses

* Mostly complete pass of removing to_output

* Completed cleanup of Base method removal

* Cleaning up more to_output uses. Fixing test errors

* Adding tartet_arg property and fixing tests that can use it

* More cleanup and test fixing

* Updating types derived from Base to properly use get_param_names and allow setting Base values in constructor

* Fixing import order. Adding support for sparse arrays

* Attempting to fix nearest neighbors

* Removing commented code

* Fixing failing tests

* Fixing more tests

* Adding PR to CHANGELOG and style fixes

* Fixing missing import

* Removing protocol interface for python 3.7

* Fixing ARIMA. Required including changes from PR#2956

* Fixing labelbinarizer and KNN failing tests

* Removing "invalid syntax" so flake8 can run

* Adding more wrappers to ARIMA so tests pass.

* Committing CI change to allow tests to run.

* Moving memory check to plugin

* Adding ability to load SPD environment variables to the logger

* Changing pytest import-mode to better support development

* Changing relative imports to absolute

* Adding first iteration of dev guide to see how it looks

* Improving the quick_run plugin

* Removing skip_* from cuml decorators

* Fixing cuml_decorators test.

* Removing the logger environment addition

* Updating non-Base methods to use decorators

* Large cleanup of remaining to_output, with_cupy_rmm and input_to_dev_ptr

* Style cleanup

* Apply John's suggestions from code review on Dev Guide

Co-authored-by: John Zedlewski <[email protected]>

* Large update to Estimator Guide incorporating feedback from JohnZ

* Removing array tracking and putting in plugin

* Removing PR Description file

* Removing ArrayOutputable

* Removing test plugins

* Cleaning up code to remove unnecessary diffs

* Style cleanup

* Defaulting to cp array instead of np, per feedback

* Adding additional tests

* Separating func_tools into separate files

* Removing extra changes to conftest.py which should not have been committed.

* Renaming base.py back to base.pyx

* Apply suggestions from code review

Co-authored-by: Dante Gama Dessavre <[email protected]>

* Incorporating feedback from Dante's code review

* Removing straggling TODO

* Applying Dante's Revisions to ESTIMATOR_GUIDE

Co-authored-by: Dante Gama Dessavre <[email protected]>

* Updateing ESTIMATOR_GUIDE from feedback from Dante

* Cleaning up straggling to_output

* Another iteration on code review feedback

* Style cleanup

* More small items from code review

* One final change to ESTIMATOR_GUIDE

* Updaing all *_mg.pyx files to use the new naming conventions and CumlArrayDescriptor

Co-authored-by: John Zedlewski <[email protected]>
Co-authored-by: Dante Gama Dessavre <[email protected]>
  • Loading branch information
3 people authored Nov 13, 2020
1 parent f1cca8d commit 77da916
Show file tree
Hide file tree
Showing 103 changed files with 4,737 additions and 1,857 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
- PR #3112: Speed test_array
- PR #3111: Adding Cython to Code Coverage
- PR #3129: Update notebooks README
- PR #3040: Improved Array Conversion with CumlArrayDescriptor and Decorators
- PR #3134: Improving the Deprecation Message Formatting in Documentation

## Bug Fixes
Expand Down
2 changes: 1 addition & 1 deletion python/cuml/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@

# Output type configuration

global_output_type = 'input'
global_output_type = None

from cuml.common.memory_utils import set_global_output_type, using_output_type

Expand Down
3 changes: 2 additions & 1 deletion python/cuml/benchmark/datagen.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,8 @@ def _convert_to_gpuarray(data, order='F'):
gs = cudf.Series.from_pandas(data)
return cuda.as_cuda_array(gs)
else:
return input_utils.input_to_dev_array(data, order=order)[0]
return input_utils.input_to_cuml_array(
data, order=order)[0].to_output("numba")


def _convert_to_gpuarray_c(data):
Expand Down
54 changes: 27 additions & 27 deletions python/cuml/cluster/dbscan.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@ from cuml.common.base import Base
from cuml.common.doc_utils import generate_docstring
from cuml.raft.common.handle cimport handle_t
from cuml.common import input_to_cuml_array
from cuml.common import using_output_type
from cuml.common.array_descriptor import CumlArrayDescriptor

from collections import defaultdict

Expand Down Expand Up @@ -186,6 +188,9 @@ class DBSCAN(Base):
<http://scikit-learn.org/stable/modules/generated/sklearn.cluster.DBSCAN.html>`_.
"""

labels_ = CumlArrayDescriptor()
core_sample_indices_ = CumlArrayDescriptor()

def __init__(self, eps=0.5, handle=None, min_samples=5,
verbose=False, max_mbytes_per_batch=None,
output_type=None, calc_core_sample_indices=True):
Expand All @@ -196,18 +201,17 @@ class DBSCAN(Base):
self.calc_core_sample_indices = calc_core_sample_indices

# internal array attributes
self._labels_ = None # accessed via estimator.labels_
self.labels_ = None

# accessed via estimator._core_sample_indices_ when
# self.calc_core_sample_indices == True
self._core_sample_indices_ = None
# One used when `self.calc_core_sample_indices == True`
self.core_sample_indices_ = None

# C++ API expects this to be numeric.
if self.max_mbytes_per_batch is None:
self.max_mbytes_per_batch = 0

@generate_docstring(skip_parameters_heading=True)
def fit(self, X, out_dtype="int32"):
def fit(self, X, out_dtype="int32") -> "DBSCAN":
"""
Perform DBSCAN clustering from features.

Expand All @@ -218,11 +222,6 @@ class DBSCAN(Base):
"int64", np.int64}.

"""
self._set_base_attributes(output_type=X, n_features=X)

if self._labels_ is not None:
del self._labels_

if out_dtype not in ["int32", np.int32, "int64", np.int64]:
raise ValueError("Invalid value for out_dtype. "
"Valid values are {'int32', 'int64', "
Expand All @@ -236,16 +235,16 @@ class DBSCAN(Base):

cdef handle_t* handle_ = <handle_t*><size_t>self.handle.getHandle()

self._labels_ = CumlArray.empty(n_rows, dtype=out_dtype)
cdef uintptr_t labels_ptr = self._labels_.ptr
self.labels_ = CumlArray.empty(n_rows, dtype=out_dtype)
cdef uintptr_t labels_ptr = self.labels_.ptr

cdef uintptr_t core_sample_indices_ptr = <uintptr_t> NULL

# Create the output core_sample_indices only if needed
if self.calc_core_sample_indices:
self._core_sample_indices_ = \
self.core_sample_indices_ = \
CumlArray.empty(n_rows, dtype=out_dtype)
core_sample_indices_ptr = self._core_sample_indices_.ptr
core_sample_indices_ptr = self.core_sample_indices_.ptr

if self.dtype == np.float32:
if out_dtype is "int32" or out_dtype is np.int32:
Expand Down Expand Up @@ -303,20 +302,21 @@ class DBSCAN(Base):
# Finally, resize the core_sample_indices array if necessary
if self.calc_core_sample_indices:

# Temp convert to cupy array only once
core_samples_cupy = self._core_sample_indices_.to_output("cupy")
# Temp convert to cupy array (better than using `cupy.asarray`)
with using_output_type("cupy"):

# First get the min index. These have to monotonically increasing,
# so the min index should be the first returned -1
min_index = cp.argmin(core_samples_cupy).item()
# First get the min index. These have to monotonically
# increasing, so the min index should be the first returned -1
min_index = cp.argmin(self.core_sample_indices_).item()

# Check for the case where there are no -1's
if (min_index == 0 and core_samples_cupy[min_index].item() != -1):
# Nothing to delete. The array has no -1's
pass
else:
self._core_sample_indices_ = \
self._core_sample_indices_[:min_index]
# Check for the case where there are no -1's
if ((min_index == 0 and
self.core_sample_indices_[min_index].item() != -1)):
# Nothing to delete. The array has no -1's
pass
else:
self.core_sample_indices_ = \
self.core_sample_indices_[:min_index]

return self

Expand All @@ -325,7 +325,7 @@ class DBSCAN(Base):
'type': 'dense',
'description': 'Cluster labels',
'shape': '(n_samples, 1)'})
def fit_predict(self, X, out_dtype="int32"):
def fit_predict(self, X, out_dtype="int32") -> CumlArray:
"""
Performs clustering on X and returns cluster labels.

Expand Down
50 changes: 25 additions & 25 deletions python/cuml/cluster/kmeans.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,14 @@ import cudf
import numpy as np
import rmm
import warnings
import typing

from libcpp cimport bool
from libc.stdint cimport uintptr_t, int64_t
from libc.stdlib cimport calloc, malloc, free

from cuml.common.array import CumlArray
from cuml.common.array_descriptor import CumlArrayDescriptor
from cuml.common.base import Base
from cuml.common.doc_utils import generate_docstring
from cuml.raft.common.handle cimport handle_t
Expand Down Expand Up @@ -259,6 +261,9 @@ class KMeans(Base):
<http://scikit-learn.org/stable/modules/generated/sklearn.cluster.KMeans.html>`_.
"""

labels_ = CumlArrayDescriptor()
cluster_centers_ = CumlArrayDescriptor()

def __init__(self, handle=None, n_clusters=8, max_iter=300, tol=1e-4,
verbose=False, random_state=1,
init='scalable-k-means++', n_init=1, oversampling_factor=2.0,
Expand All @@ -275,8 +280,8 @@ class KMeans(Base):
self.max_samples_per_batch=int(max_samples_per_batch)

# internal array attributes
self._labels_ = None # accessed via estimator.labels_
self._cluster_centers_ = None # accessed via estimator.cluster_centers_ # noqa
self.labels_ = None
self.cluster_centers_ = None

cdef KMeansParams params
params.n_clusters = <int>self.n_clusters
Expand All @@ -301,7 +306,7 @@ class KMeans(Base):
else:
self.init = 'preset'
params.init = Array
self._cluster_centers_, n_rows, self.n_cols, self.dtype = \
self.cluster_centers_, n_rows, self.n_cols, self.dtype = \
input_to_cuml_array(init, order='C',
check_dtype=[np.float32, np.float64])

Expand All @@ -316,13 +321,11 @@ class KMeans(Base):
self._params = params

@generate_docstring()
def fit(self, X, sample_weight=None):
def fit(self, X, sample_weight=None) -> "KMeans":
"""
Compute k-means clustering with X.

"""
self._set_base_attributes(output_type=X, n_features=X)

if self.init == 'preset':
check_cols = self.n_cols
check_dtype = self.dtype
Expand All @@ -349,15 +352,15 @@ class KMeans(Base):

cdef uintptr_t sample_weight_ptr = sample_weight_m.ptr

self._labels_ = CumlArray.zeros(shape=n_rows, dtype=np.int32)
cdef uintptr_t labels_ptr = self._labels_.ptr
self.labels_ = CumlArray.zeros(shape=n_rows, dtype=np.int32)
cdef uintptr_t labels_ptr = self.labels_.ptr

if (self.init in ['scalable-k-means++', 'k-means||', 'random']):
self._cluster_centers_ = \
self.cluster_centers_ = \
CumlArray.zeros(shape=(self.n_clusters, self.n_cols),
dtype=self.dtype, order='C')

cdef uintptr_t cluster_centers_ptr = self._cluster_centers_.ptr
cdef uintptr_t cluster_centers_ptr = self.cluster_centers_.ptr

cdef float inertiaf = 0
cdef double inertiad = 0
Expand Down Expand Up @@ -409,15 +412,16 @@ class KMeans(Base):
'type': 'dense',
'description': 'Cluster indexes',
'shape': '(n_samples, 1)'})
def fit_predict(self, X, sample_weight=None):
def fit_predict(self, X, sample_weight=None) -> CumlArray:
"""
Compute cluster centers and predict cluster index for each sample.

"""
return self.fit(X, sample_weight=sample_weight).labels_

def _predict_labels_inertia(self, X, convert_dtype=False,
sample_weight=None):
sample_weight=None) -> typing.Tuple[CumlArray,
float]:
"""
Predict the closest cluster each sample in X belongs to.

Expand Down Expand Up @@ -446,8 +450,6 @@ class KMeans(Base):
Sum of squared distances of samples to their closest cluster center.
"""

out_type = self._get_output_type(X)

X_m, n_rows, n_cols, dtype = \
input_to_cuml_array(X, order='C', check_dtype=self.dtype,
convert_to_dtype=(self.dtype if convert_dtype
Expand All @@ -468,10 +470,10 @@ class KMeans(Base):

cdef handle_t* handle_ = <handle_t*><size_t>self.handle.getHandle()

cdef uintptr_t cluster_centers_ptr = self._cluster_centers_.ptr
cdef uintptr_t cluster_centers_ptr = self.cluster_centers_.ptr

self._labels_ = CumlArray.zeros(shape=n_rows, dtype=np.int32)
cdef uintptr_t labels_ptr = self._labels_.ptr
self.labels_ = CumlArray.zeros(shape=n_rows, dtype=np.int32)
cdef uintptr_t labels_ptr = self.labels_.ptr

# Sum of squared distances of samples to their closest cluster center.
cdef float inertiaf = 0
Expand Down Expand Up @@ -511,13 +513,13 @@ class KMeans(Base):
self.handle.sync()
del(X_m)
del(sample_weight_m)
return self._labels_.to_output(out_type), inertia
return self.labels_, inertia

@generate_docstring(return_values={'name': 'preds',
'type': 'dense',
'description': 'Cluster indexes',
'shape': '(n_samples, 1)'})
def predict(self, X, convert_dtype=False, sample_weight=None):
def predict(self, X, convert_dtype=False, sample_weight=None) -> CumlArray:
"""
Predict the closest cluster each sample in X belongs to.

Expand All @@ -532,14 +534,12 @@ class KMeans(Base):
'type': 'dense',
'description': 'Transformed data',
'shape': '(n_samples, n_clusters)'})
def transform(self, X, convert_dtype=False):
def transform(self, X, convert_dtype=False) -> CumlArray:
"""
Transform X to a cluster-distance space.

"""

out_type = self._get_output_type(X)

X_m, n_rows, n_cols, dtype = \
input_to_cuml_array(X, order='C', check_dtype=self.dtype,
convert_to_dtype=(self.dtype if convert_dtype
Expand All @@ -550,7 +550,7 @@ class KMeans(Base):

cdef handle_t* handle_ = <handle_t*><size_t>self.handle.getHandle()

cdef uintptr_t cluster_centers_ptr = self._cluster_centers_.ptr
cdef uintptr_t cluster_centers_ptr = self.cluster_centers_.ptr

preds = CumlArray.zeros(shape=(n_rows, self.n_clusters),
dtype=self.dtype,
Expand Down Expand Up @@ -589,7 +589,7 @@ class KMeans(Base):
self.handle.sync()

del(X_m)
return preds.to_output(out_type)
return preds

@generate_docstring(return_values={'name': 'score',
'type': 'float',
Expand All @@ -610,7 +610,7 @@ class KMeans(Base):
'type': 'dense',
'description': 'Transformed data',
'shape': '(n_samples, n_clusters)'})
def fit_transform(self, X, convert_dtype=False):
def fit_transform(self, X, convert_dtype=False) -> CumlArray:
"""
Compute clustering and transform X to cluster-distance space.

Expand Down
13 changes: 6 additions & 7 deletions python/cuml/cluster/kmeans_mg.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ class KMeansMG(KMeans):
def __init__(self, **kwargs):
super(KMeansMG, self).__init__(**kwargs)

def fit(self, X):
def fit(self, X) -> "KMeansMG":
"""
Compute k-means clustering with X in a multi-node multi-GPU setting.

Expand All @@ -84,7 +84,6 @@ class KMeansMG(KMeans):
ndarray, cuda array interface compliant array like CuPy

"""
self._set_base_attributes(n_features=X)

X_m, self.n_rows, self.n_cols, self.dtype = \
input_to_cuml_array(X, order='C')
Expand All @@ -94,12 +93,12 @@ class KMeansMG(KMeans):
cdef handle_t* handle_ = <handle_t*><size_t>self.handle.getHandle()

if (self.init in ['scalable-k-means++', 'k-means||', 'random']):
self._cluster_centers_ = CumlArray.zeros(shape=(self.n_clusters,
self.n_cols),
dtype=self.dtype,
order='C')
self.cluster_centers_ = CumlArray.zeros(shape=(self.n_clusters,
self.n_cols),
dtype=self.dtype,
order='C')

cdef uintptr_t cluster_centers_ptr = self._cluster_centers_.ptr
cdef uintptr_t cluster_centers_ptr = self.cluster_centers_.ptr

cdef size_t n_rows = self.n_rows
cdef size_t n_cols = self.n_cols
Expand Down
4 changes: 0 additions & 4 deletions python/cuml/common/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,5 @@

## legacy to be removed after complete CumlAray migration

from cuml.common.numba_utils import zeros
from cuml.common.input_utils import get_cudf_column_ptr
from cuml.common.input_utils import get_dev_array_ptr
from cuml.common.input_utils import input_to_dev_array
from cuml.common.input_utils import sparse_scipy_to_cp
from cuml.common.timing_utils import timed
Loading

0 comments on commit 77da916

Please sign in to comment.