[REVIEW][PROPOSAL] Add tags and preferred memory order tags to estimators #3113

Merged: 27 commits merged on Nov 20, 2020
Changes from 25 commits
27 commits
39e94a4
FEA Add preferred_order class parameter to linear models
dantegd Nov 4, 2020
bc99ec6
Merge branch 'branch-0.17' of https://github.com/rapidsai/cuml into 0…
dantegd Nov 4, 2020
7e84f8c
ENH adopt tags from scikit-learn API to support preferred order attri…
dantegd Nov 8, 2020
564ae51
DOC remove attribute docstrings
dantegd Nov 8, 2020
86830dc
FIX Change straggling classes
dantegd Nov 8, 2020
d6d8a51
FIX Change straggling classes
dantegd Nov 8, 2020
c85acea
FIX Add missing self
dantegd Nov 8, 2020
2f42eaa
FIX straggling attribute
dantegd Nov 9, 2020
bb8e2f7
ENH Add device data tag for proposal
dantegd Nov 11, 2020
9c5b63f
FEA Add all scikit-learn API tags to base and improve gpu input types…
dantegd Nov 11, 2020
340a5e2
FEA Add preferred_order tag to cluster models
dantegd Nov 11, 2020
c66ad00
FEA Add preferred_order tag to most models
dantegd Nov 11, 2020
0a892d1
Merge branch-0.17
dantegd Nov 14, 2020
fe6efb7
ENH Improvements and PR review feedback
dantegd Nov 14, 2020
4e5bb3f
DOC add tag documentation to estimator guide
dantegd Nov 14, 2020
e0cc4f9
DOC add scikit link
dantegd Nov 14, 2020
fe117db
Update wiki/python/ESTIMATOR_GUIDE.md
dantegd Nov 18, 2020
520a4c4
Update wiki/python/ESTIMATOR_GUIDE.md
dantegd Nov 18, 2020
c3a9b41
Update wiki/python/ESTIMATOR_GUIDE.md
dantegd Nov 18, 2020
637dde3
Update wiki/python/ESTIMATOR_GUIDE.md
dantegd Nov 18, 2020
a9ba498
Update wiki/python/ESTIMATOR_GUIDE.md
dantegd Nov 18, 2020
44e698a
ENH Rename test_fit to test_api and add tags tests
dantegd Nov 19, 2020
5689ef8
FIX fixes from PR review
dantegd Nov 19, 2020
d713c89
DOC Added entry to changelog
dantegd Nov 19, 2020
1b02eac
Merge branch 'branch-0.17' into 017-fea-pref-order
cjnolet Nov 19, 2020
5fb7ef1
FIX PEP8 fixes
dantegd Nov 19, 2020
17d0ff6
Merge branch '017-fea-pref-order' of github.com:dantegd/cuml into 017…
dantegd Nov 19, 2020
3 changes: 2 additions & 1 deletion CHANGELOG.md
@@ -33,6 +33,7 @@
- PR #3135: Add QuasiNewton tests
- PR #3040: Improved Array Conversion with CumlArrayDescriptor and Decorators
- PR #3134: Improving the Deprecation Message Formatting in Documentation
- PR #3113: Add tags and preferred memory order tags to estimators
- PR #3137: Reorganize Pytest Config and Add Quick Run Option
- PR #3144: Adding Ability to Set Arbitrary Cmake Flags in ./build.sh
- PR #3155: Eliminate unnecessary warnings from random projection test
@@ -58,7 +59,7 @@
- PR #3086: Reverting FIL Notebook Testing
- PR #3114: Fixed a typo in SVC's predict_proba AttributeError
- PR #3117: Fix two crashes in experimental RF backend
- PR #3119: Fix memset args for benchmark
- PR #3130: Return Python string from `dump_as_json()` of RF
- PR #3136: Fix stochastic gradient descent example
- PR #3156: Force local conda artifact install
5 changes: 5 additions & 0 deletions python/cuml/cluster/dbscan.pyx
@@ -346,3 +346,8 @@ class DBSCAN(Base):
"max_mbytes_per_batch",
"calc_core_sample_indices",
]

def _more_tags(self):
return {
'preferred_input_order': 'C'
}
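For orientation (the defaults and the `_get_tags` resolution these overrides feed into appear in `base.pyx` further down), here is a hedged illustration of what the tag lookup is expected to return for DBSCAN after this change; it is an expectation sketch, not an executed result:

```python
# Hedged expectation sketch: DBSCAN overrides only the memory-order tag,
# so everything else should come from the defaults defined in base.pyx.
from cuml.cluster import DBSCAN

tags = DBSCAN()._get_tags()
assert tags['preferred_input_order'] == 'C'   # from DBSCAN._more_tags
assert tags['X_types_gpu'] == ['2darray']     # inherited default
assert tags['requires_fit'] is True           # inherited default
```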
5 changes: 5 additions & 0 deletions python/cuml/cluster/kmeans.pyx
@@ -622,3 +622,8 @@ class KMeans(Base):
['n_init', 'oversampling_factor', 'max_samples_per_batch',
'init', 'max_iter', 'n_clusters', 'random_state',
'tol']

def _more_tags(self):
return {
'preferred_input_order': 'C'
}
48 changes: 48 additions & 0 deletions python/cuml/common/base.pyx
@@ -27,6 +27,34 @@ from cuml.common.doc_utils import generate_docstring
import cuml.common.input_utils


# tag system based on experimental tag system from Scikit-learn >=0.21
# https://scikit-learn.org/stable/developers/develop.html#estimator-tags
_default_tags = {
# cuML specific tags
'preferred_input_order': None,
'X_types_gpu': ['2darray'],

# Scikit-learn API standard tags
'non_deterministic': False,
'requires_positive_X': False,
'requires_positive_y': False,
'X_types': ['2darray'],
'poor_score': False,
'no_validation': False,
'multioutput': False,
'allow_nan': False,
'stateless': False,
'multilabel': False,
'_skip_test': False,
'_xfail_checks': False,
'multioutput_only': False,
'binary_only': False,
'requires_fit': True,
'requires_y': False,
'pairwise': False,
}


class Base(metaclass=cuml.internals.BaseMetaClass):
"""
Base class for all the ML algos. It handles some of the common operations
@@ -348,6 +376,16 @@ class Base(metaclass=cuml.internals.BaseMetaClass):
else:
self.n_features_in_ = X.shape[1]

def _get_tags(self):
# method and code based on scikit-learn 0.21 _get_tags functionality:
# https://scikit-learn.org/stable/developers/develop.html#estimator-tags
collected_tags = _default_tags.copy()  # copy so the shared defaults are not mutated
for cl in reversed(inspect.getmro(self.__class__)):
if hasattr(cl, '_more_tags') and cl != Base:
more_tags = cl._more_tags(self)
collected_tags.update(more_tags)
return collected_tags


class RegressorMixin:
"""Mixin class for regression estimators in cuML"""
@@ -379,6 +417,11 @@ class RegressorMixin:
preds = self.predict(X, **kwargs)
return r2_score(y, preds, handle=handle)

def _more_tags(self):
return {
'requires_y': True
}


class ClassifierMixin:
"""Mixin class for classifier estimators in cuML"""
@@ -410,6 +453,11 @@ class ClassifierMixin:
preds = self.predict(X, **kwargs)
return accuracy_score(y, preds, handle=handle)

def _more_tags(self):
return {
'requires_y': True
}


# Internal, non class owned helper functions
def _check_output_type_str(output_str):
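To make the resolution order above concrete, here is a minimal, self-contained sketch of the pattern `_get_tags` implements: start from the defaults, then apply `_more_tags` overrides collected along the MRO so that mixins and concrete estimators win. The classes and tag values below are toy stand-ins, not the real cuML estimators.

```python
import inspect

# Toy stand-ins for the cuML classes above; illustrative only.
_default_tags = {'preferred_input_order': None, 'requires_y': False}


class Base:
    def _get_tags(self):
        # Walk the MRO from most generic to most specific so subclasses
        # override mixins, which in turn override the defaults.
        collected_tags = _default_tags.copy()
        for cl in reversed(inspect.getmro(self.__class__)):
            if hasattr(cl, '_more_tags') and cl != Base:
                collected_tags.update(cl._more_tags(self))
        return collected_tags


class RegressorMixin:
    def _more_tags(self):
        return {'requires_y': True}


class ToyRegressor(Base, RegressorMixin):
    def _more_tags(self):
        return {'preferred_input_order': 'F'}


print(ToyRegressor()._get_tags())
# {'preferred_input_order': 'F', 'requires_y': True}
```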
7 changes: 7 additions & 0 deletions python/cuml/decomposition/pca.pyx
@@ -728,3 +728,10 @@ class PCA(Base):
def __setstate__(self, state):
self.__dict__.update(state)
self.handle = Handle()

def _more_tags(self):
return {
'preferred_input_order': 'F',
'X_types_gpu': ['2darray', 'sparse'],
'X_types': ['2darray', 'sparse']
}
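The `preferred_input_order` tag is meant to be read by code that prepares inputs before calling an estimator. Below is a hedged sketch of that decision logic, using host-side NumPy only to keep the example self-contained; cuML itself works with device arrays, and `to_preferred_order` is an illustrative helper, not part of the cuML API:

```python
import numpy as np


def to_preferred_order(X, estimator):
    # Return X in the estimator's preferred memory layout, if it states one.
    order = estimator._get_tags().get('preferred_input_order')
    if order == 'F':
        return np.asfortranarray(X)     # column-major, e.g. PCA, linear models
    if order == 'C':
        return np.ascontiguousarray(X)  # row-major, e.g. DBSCAN, KMeans, UMAP
    return X  # None: no preference (e.g. random forests), keep caller's layout
```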
5 changes: 5 additions & 0 deletions python/cuml/decomposition/tsvd.pyx
@@ -476,3 +476,8 @@ class TruncatedSVD(Base):
def get_param_names(self):
return super().get_param_names() + \
["algorithm", "n_components", "n_iter", "random_state", "tol"]

def _more_tags(self):
return {
'preferred_input_order': 'F'
}
7 changes: 7 additions & 0 deletions python/cuml/ensemble/randomforestclassifier.pyx
@@ -950,3 +950,10 @@ class RandomForestClassifier(BaseRandomForestModel, ClassifierMixin):
if self.dtype == np.float64:
return dump_rf_as_json(rf_forest64).decode('utf-8')
return dump_rf_as_json(rf_forest).decode('utf-8')

def _more_tags(self):
return {
# fit and predict require conflicting memory layouts
'preferred_input_order': None
}

6 changes: 6 additions & 0 deletions python/cuml/ensemble/randomforestregressor.pyx
@@ -746,3 +746,9 @@ class RandomForestRegressor(BaseRandomForestModel, RegressorMixin):
if self.dtype == np.float64:
return dump_rf_as_json(rf_forest64).decode('utf-8')
return dump_rf_as_json(rf_forest).decode('utf-8')

def _more_tags(self):
return {
# fit and predict require conflicting memory layouts
'preferred_input_order': None
}
5 changes: 5 additions & 0 deletions python/cuml/fil/fil.pyx
@@ -728,3 +728,8 @@ class ForestInference(Base):

# DO NOT RETURN self._impl here!!
return self

def _more_tags(self):
return {
'preferred_input_order': 'C'
}
5 changes: 5 additions & 0 deletions python/cuml/linear_model/elastic_net.pyx
@@ -240,3 +240,8 @@ class ElasticNet(Base, RegressorMixin):
"tol",
"selection",
]

def _more_tags(self):
return {
'preferred_input_order': 'F'
}
5 changes: 5 additions & 0 deletions python/cuml/linear_model/lasso.pyx
@@ -208,3 +208,8 @@ class Lasso(Base, RegressorMixin):
"tol",
"selection",
]

def _more_tags(self):
return {
'preferred_input_order': 'F'
}
5 changes: 5 additions & 0 deletions python/cuml/linear_model/linear_regression.pyx
@@ -352,3 +352,8 @@ class LinearRegression(Base, RegressorMixin):
def get_param_names(self):
return super().get_param_names() + \
['algorithm', 'fit_intercept', 'normalize']

def _more_tags(self):
return {
'preferred_input_order': 'F'
}
5 changes: 5 additions & 0 deletions python/cuml/linear_model/logistic_regression.pyx
@@ -424,3 +424,8 @@ class LogisticRegression(Base, ClassifierMixin):
super(LogisticRegression, self).__init__(handle=None,
verbose=state["verbose"])
self.__dict__.update(state)

def _more_tags(self):
return {
'preferred_input_order': 'F'
}
5 changes: 5 additions & 0 deletions python/cuml/linear_model/mbsgd_classifier.pyx
@@ -219,3 +219,8 @@ class MBSGDClassifier(Base, ClassifierMixin):
"batch_size",
"n_iter_no_change",
]

def _more_tags(self):
return {
'preferred_input_order': 'F'
}
5 changes: 5 additions & 0 deletions python/cuml/linear_model/mbsgd_regressor.pyx
@@ -213,3 +213,8 @@ class MBSGDRegressor(Base, RegressorMixin):
"batch_size",
"n_iter_no_change",
]

def _more_tags(self):
return {
'preferred_input_order': 'F'
}
6 changes: 5 additions & 1 deletion python/cuml/linear_model/ridge.pyx
@@ -214,7 +214,6 @@ class Ridge(Base, RegressorMixin):
def __init__(self, alpha=1.0, solver='eig', fit_intercept=True,
normalize=False, handle=None, output_type=None,
verbose=False):

"""
Initializes the linear ridge regression class.

@@ -394,3 +393,8 @@
def get_param_names(self):
return super().get_param_names() + \
['solver', 'fit_intercept', 'normalize', 'alpha']

def _more_tags(self):
return {
'preferred_input_order': 'F'
}
5 changes: 5 additions & 0 deletions python/cuml/manifold/t_sne.pyx
@@ -485,3 +485,8 @@ class TSNE(Base):
"pre_momentum",
"post_momentum",
]

def _more_tags(self):
return {
'preferred_input_order': 'C'
[Review comment] Just leaving a small note (more for myself) that t_sne & UMAP both could probably accept 'F' now that the underlying KNN prim can accept it.
}
5 changes: 5 additions & 0 deletions python/cuml/manifold/umap.pyx
@@ -877,3 +877,8 @@ class UMAP(Base):
"optim_batch_size",
"callback",
]

def _more_tags(self):
return {
'preferred_input_order': 'C'
}
6 changes: 6 additions & 0 deletions python/cuml/neighbors/kneighbors_classifier.pyx
@@ -305,3 +305,9 @@ class KNeighborsClassifier(NearestNeighbors, ClassifierMixin):

def get_param_names(self):
return super().get_param_names() + ["weights"]

def _more_tags(self):
return {
# fit and predict require conflicting memory layouts
'preferred_input_order': 'F'
}
6 changes: 6 additions & 0 deletions python/cuml/neighbors/kneighbors_regressor.pyx
@@ -231,3 +231,9 @@ class KNeighborsRegressor(NearestNeighbors, RegressorMixin):

def get_param_names(self):
return super().get_param_names() + ["weights"]

def _more_tags(self):
return {
# fit and predict require conflicting memory layouts
'preferred_input_order': 'F'
}
5 changes: 5 additions & 0 deletions python/cuml/neighbors/nearest_neighbors.pyx
@@ -746,3 +746,8 @@ def kneighbors_graph(X=None, n_neighbors=5, mode='connectivity', verbose=False,
query = X.X_m

return X.kneighbors_graph(X=query, n_neighbors=n_neighbors, mode=mode)

def _more_tags(self):
return {
'preferred_input_order': 'F'
}
5 changes: 5 additions & 0 deletions python/cuml/random_projection/random_projection.pyx
@@ -589,3 +589,8 @@ class SparseRandomProjection(Base, BaseRandomProjection):
"dense_output",
"random_state"
]

def _more_tags(self):
return {
'preferred_input_order': 'F'
}
5 changes: 5 additions & 0 deletions python/cuml/solvers/cd.pyx
@@ -349,3 +349,8 @@ class CD(Base):
"tol",
"shuffle",
]

def _more_tags(self):
return {
'preferred_input_order': 'F'
}
5 changes: 5 additions & 0 deletions python/cuml/solvers/qn.pyx
@@ -537,3 +537,8 @@ class QN(Base):
return super().get_param_names() + \
['loss', 'fit_intercept', 'l1_strength', 'l2_strength',
'max_iter', 'tol', 'linesearch_max_iter', 'lbfgs_memory']

def _more_tags(self):
return {
'preferred_input_order': 'F'
}
5 changes: 5 additions & 0 deletions python/cuml/solvers/sgd.pyx
@@ -507,3 +507,8 @@ class SGD(Base):
"batch_size",
"n_iter_no_change",
]

def _more_tags(self):
return {
'preferred_input_order': 'F'
}
5 changes: 5 additions & 0 deletions python/cuml/svm/svc.pyx
@@ -519,3 +519,8 @@ class SVC(SVMBase, ClassifierMixin):
params.remove("epsilon")

return params

def _more_tags(self):
return {
'preferred_input_order': 'F'
}
5 changes: 5 additions & 0 deletions python/cuml/svm/svm_base.pyx
@@ -574,3 +574,8 @@ class SVMBase(Base):
self.__dict__.update(state)
self._model = self._get_svm_model()
self._freeSvmBuffers = False

def _more_tags(self):
return {
'preferred_input_order': 'F'
}
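One of the commits above renames `test_fit` to `test_api` and adds tests over these tags. Here is a hedged sketch of the kind of check such a test might make, tying together the estimator-level overrides and the mixin tags; the test name and parametrization are illustrative, not the PR's actual test code:

```python
import pytest

from cuml.linear_model import LinearRegression
from cuml.svm import SVC


@pytest.mark.parametrize("est_cls, expected_order", [
    (LinearRegression, 'F'),
    (SVC, 'F'),
])
def test_preferred_input_order_tag(est_cls, expected_order):
    tags = est_cls()._get_tags()
    # Estimator-level _more_tags should win over the None default in base.pyx
    assert tags['preferred_input_order'] == expected_order
    # Mixin tags (RegressorMixin / ClassifierMixin) are merged in as well
    assert tags['requires_y'] is True
```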