Skip to content

Commit

Permalink
Sklearn meta-estimators into namespace (#3493)
Browse files Browse the repository at this point in the history
Closes #3484

Imports scikit-learn's Pipeline and GridSearchCV meta-estimators into the cuML namespace for ease of use.

Authors:
  - Victor Lafargue (@viclafargue)

Approvers:
  - William Hicks (@wphicks)
  - John Zedlewski (@JohnZed)

URL: #3493
  • Loading branch information
viclafargue authored Mar 16, 2021
1 parent 96eaf62 commit 28c3e39
Show file tree
Hide file tree
Showing 3 changed files with 154 additions and 1 deletion.
26 changes: 25 additions & 1 deletion python/cuml/model_selection/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,28 @@
#
# Copyright (c) 2021, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from cuml.model_selection._split import train_test_split
from sklearn.model_selection import GridSearchCV


# Prepend an attribution notice to scikit-learn's own docstring so that the
# origin of the code is visible wherever the class documentation is shown.
# NOTE: this rebinds ``__doc__`` on the class object imported from sklearn,
# so the notice is also visible through ``sklearn.model_selection``.
# ``__doc__`` is None when docstrings are stripped (``python -OO``); fall back
# to an empty string so that importing this module does not raise TypeError.
GridSearchCV.__doc__ = """
This code is developed and maintained by scikit-learn and imported
by cuML to maintain the familiar sklearn namespace structure.
cuML includes tests to ensure full compatibility of these wrappers
with CUDA-based data and cuML estimators, but all of the underlying code
is due to the scikit-learn developers.\n\n""" + (GridSearchCV.__doc__ or "")

# NOTE(review): the first binding is immediately overwritten by the line that
# follows it, so only the second list is in effect as the module's ``__all__``.
__all__ = ['train_test_split']
__all__ = ['train_test_split', 'GridSearchCV']
27 changes: 27 additions & 0 deletions python/cuml/pipeline/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
#
# Copyright (c) 2021, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from sklearn.pipeline import Pipeline


# Prepend an attribution notice to scikit-learn's own docstring so that the
# origin of the code is visible wherever the class documentation is shown.
# NOTE: this rebinds ``__doc__`` on the class object imported from sklearn,
# so the notice is also visible through ``sklearn.pipeline.Pipeline``.
# ``__doc__`` is None when docstrings are stripped (``python -OO``); fall back
# to an empty string so that importing this module does not raise TypeError.
Pipeline.__doc__ = """
This code is developed and maintained by scikit-learn and imported
by cuML to maintain the familiar sklearn namespace structure.
cuML includes tests to ensure full compatibility of these wrappers
with CUDA-based data and cuML estimators, but all of the underlying code
is due to the scikit-learn developers.\n\n""" + (Pipeline.__doc__ or "")

# Public API of this package: the re-exported scikit-learn Pipeline only.
__all__ = ['Pipeline']
102 changes: 102 additions & 0 deletions python/cuml/test/test_meta_estimators.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
#
# Copyright (c) 2021, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import pytest
import cuml
import cupy

from cuml.pipeline import Pipeline
from cuml.model_selection import GridSearchCV

from cuml.test.utils import ClassEnumerator

from cuml.datasets import make_regression, make_classification
from cuml.model_selection import train_test_split
from sklearn.datasets import load_iris

from cuml.experimental.preprocessing import StandardScaler
from cuml.svm import SVC


def test_pipeline():
    """Fit a scaler + SVC pipeline and require a held-out score above 0.8."""
    features, labels = make_classification(random_state=0)
    split = train_test_split(features, labels, random_state=0)
    X_train, X_test, y_train, y_test = split

    steps = [('scaler', StandardScaler()), ('svc', SVC())]
    pipe = Pipeline(steps=steps)
    pipe.fit(X_train, y_train)

    assert pipe.score(X_test, y_test) > 0.8


def test_gridsearchCV():
    """Grid-search an SVC on the iris data and check the selected parameters."""
    dataset = load_iris()
    param_grid = {'kernel': ('linear', 'rbf'), 'C': [1, 10]}

    search = GridSearchCV(SVC(), param_grid)
    search.fit(dataset.data, dataset.target)

    # The search is expected to settle on the RBF kernel with the larger C.
    best = search.best_params_
    assert best['kernel'] == 'rbf'
    assert best['C'] == 10


@pytest.fixture(scope="session")
def regression_dataset(request):
    """Session-scoped train/test split of a small synthetic regression set.

    Generates 10 samples with 5 features (seeded with ``random_state=0``) and
    returns the result of ``train_test_split``, which the consuming test
    unpacks as ``X_train, X_test, y_train, y_test``.
    """
    features, targets = make_regression(
        n_samples=10, n_features=5, random_state=0
    )
    split = train_test_split(features, targets, random_state=0)
    return split


@pytest.fixture(scope="session")
def classification_dataset(request):
    """Session-scoped train/test split of a small synthetic classification set.

    Generates 10 samples with 5 features (seeded with ``random_state=0``) and
    returns the result of ``train_test_split``, which the consuming test
    unpacks as ``X_train, X_test, y_train, y_test``.
    """
    features, labels = make_classification(
        n_samples=10, n_features=5, random_state=0
    )
    split = train_test_split(features, labels, random_state=0)
    return split


# Build a ClassEnumerator over the top-level ``cuml`` module (presumably it
# collects the estimator classes exposed there; confirm in cuml.test.utils).
# The parametrized tests below index ``models`` by model name and call the
# looked-up value to construct the estimator.
models_config = ClassEnumerator(module=cuml)
models = models_config.get_models()


# NOTE(review): 'LogisticRegression' is listed in this regression test;
# confirm that fitting it on the regression fixture is intentional.
@pytest.mark.parametrize(
    'model_key',
    [
        'ElasticNet',
        'Lasso',
        'Ridge',
        'LinearRegression',
        'LogisticRegression',
        'MBSGDRegressor',
        'RandomForestRegressor',
        'KNeighborsRegressor',
    ],
)
def test_pipeline_with_regression(regression_dataset, model_key):
    """Run each listed model behind a StandardScaler inside a Pipeline.

    The only check is that ``Pipeline.predict`` returns a ``cupy.ndarray``;
    prediction quality is not asserted.
    """
    X_train, X_test, y_train, _y_test = regression_dataset

    # Only the random forest is constructed with a non-default argument.
    is_forest = model_key == 'RandomForestRegressor'
    ctor_kwargs = {'n_bins': 2} if is_forest else {}
    estimator = models[model_key](**ctor_kwargs)

    steps = [('scaler', StandardScaler()), ('model', estimator)]
    pipe = Pipeline(steps=steps)
    pipe.fit(X_train, y_train)

    assert isinstance(pipe.predict(X_test), cupy.ndarray)


@pytest.mark.parametrize(
    'model_key',
    [
        'MBSGDClassifier',
        'RandomForestClassifier',
        'KNeighborsClassifier',
    ],
)
def test_pipeline_with_classification(classification_dataset, model_key):
    """Run each listed classifier behind a StandardScaler inside a Pipeline.

    The only check is that ``Pipeline.predict`` returns a ``cupy.ndarray``;
    prediction quality is not asserted.
    """
    X_train, X_test, y_train, _y_test = classification_dataset

    # Only the random forest is constructed with a non-default argument.
    is_forest = model_key == 'RandomForestClassifier'
    ctor_kwargs = {'n_bins': 2} if is_forest else {}
    estimator = models[model_key](**ctor_kwargs)

    steps = [('scaler', StandardScaler()), ('model', estimator)]
    pipe = Pipeline(steps=steps)
    pipe.fit(X_train, y_train)

    assert isinstance(pipe.predict(X_test), cupy.ndarray)

0 comments on commit 28c3e39

Please sign in to comment.