Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[REVIEW] New Feature StratifiedKFold #3109

Merged
merged 24 commits into from
Sep 22, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
cf87af4
Merge pull request #15 from rapidsai/branch-0.15
daxiongshu Jul 26, 2020
e3b7848
Merge pull request #18 from rapidsai/branch-0.17
daxiongshu Nov 1, 2020
7162d2a
copy basic codes
daxiongshu Nov 1, 2020
e6d8ec3
Merge pull request #19 from rapidsai/branch-0.17
daxiongshu Nov 17, 2020
8b1b7c3
Merge pull request #20 from rapidsai/branch-0.18
daxiongshu Dec 28, 2020
39a38b3
Merge pull request #21 from daxiongshu/branch-0.18
daxiongshu Dec 28, 2020
fdbbe0d
Merge branch 'branch-22.02' of https://github.com/rapidsai/cuml into …
daxiongshu Jan 25, 2022
dcbbf9d
Merge branch 'rapidsai-branch-22.02' into fea_stratified_kfold
daxiongshu Jan 25, 2022
80da002
add docs
daxiongshu Jan 25, 2022
f149b73
first test passed
daxiongshu Jan 27, 2022
8907c84
copy right year
daxiongshu Jan 27, 2022
ffc4df0
Merge branch 'rapidsai:branch-22.02' into fea_stratified_kfold
daxiongshu Jan 28, 2022
20b8e49
Merge branch 'rapidsai:branch-22.04' into fea_stratified_kfold
daxiongshu Feb 14, 2022
46443f9
Merge branch 'rapidsai:branch-22.04' into fea_stratified_kfold
daxiongshu Mar 28, 2022
358a6c3
remove self.tpb
daxiongshu Mar 29, 2022
bb6cdc7
use input_to_cuml_array
daxiongshu Mar 29, 2022
154a0fe
more parameters
daxiongshu Mar 29, 2022
86e612a
test_num_classes_check
daxiongshu Mar 29, 2022
2f73b94
fix style
daxiongshu Mar 29, 2022
666d58b
Merge branch 'rapidsai:branch-22.04' into fea_stratified_kfold
daxiongshu Mar 30, 2022
b94ca51
remove unused func
daxiongshu Mar 30, 2022
f6ffad9
Merge branch 'rapidsai:branch-22.06' into fea_stratified_kfold
daxiongshu Apr 8, 2022
c902d1f
Merge branch 'rapidsai:branch-22.10' into fea_stratified_kfold
daxiongshu Sep 12, 2022
101a008
Merge branch 'rapidsai:branch-22.10' into fea_stratified_kfold
daxiongshu Sep 19, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions python/cuml/model_selection/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#
# Copyright (c) 2021, NVIDIA CORPORATION.
# Copyright (c) 2022, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand All @@ -15,6 +15,7 @@
#

from cuml.model_selection._split import train_test_split
from cuml.model_selection._split import StratifiedKFold
from cuml.common.import_utils import has_sklearn

if has_sklearn():
Expand All @@ -27,4 +28,4 @@
with CUDA-based data and cuML estimators, but all of the underlying code
is due to the scikit-learn developers.\n\n""" + GridSearchCV.__doc__

__all__ = ['train_test_split', 'GridSearchCV']
__all__ = ['train_test_split', 'GridSearchCV', 'StratifiedKFold']
103 changes: 103 additions & 0 deletions python/cuml/model_selection/_split.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import numpy as np

from cuml.common.memory_utils import _strides_to_order
from cuml.common import input_to_cuml_array
from numba import cuda
from typing import Union

Expand Down Expand Up @@ -471,3 +472,105 @@ def train_test_split(X,
return X_train, X_test, y_train, y_test
else:
return X_train, X_test


class StratifiedKFold:
    """
    A cudf based implementation of Stratified K-Folds cross-validator.

    Provides train/test indices to split data into stratified K folds.
    The percentage of samples for each class is maintained in each
    fold.

    Parameters
    ----------
    n_splits : int, default=5
        Number of folds. Must be at least 2.
    shuffle : boolean, default=False
        Whether to shuffle each class's samples before splitting.
    random_state : int (default=None)
        Random seed

    Examples
    --------
    Splitting X,y into stratified K folds

    .. code-block:: python

        import cupy
        X = cupy.random.rand(12,10)
        y = cupy.arange(12)%4
        kf = StratifiedKFold(n_splits=3)
        for tr,te in kf.split(X,y):
            print(tr, te)

    Output:

    .. code-block:: python

        [ 4  5  6  7  8  9 10 11] [0 1 2 3]
        [ 0  1  2  3  8  9 10 11] [4 5 6 7]
        [0 1 2 3 4 5 6 7] [ 8  9 10 11]

    """

    def __init__(self, n_splits=5, shuffle=False, random_state=None):
        """
        Validate and store the splitter configuration.

        Raises
        ------
        ValueError
            If ``n_splits`` is not an ``int`` of at least 2, or if
            ``random_state`` is neither ``None`` nor an ``int``.
        """
        # The type check must come first: evaluating ``n_splits < 2`` on a
        # non-numeric value (e.g. a str or None) raises TypeError before the
        # intended ValueError can be reported.
        if not isinstance(n_splits, int) or n_splits < 2:
            raise ValueError(
                f'n_splits {n_splits} is not a integer at least 2')

        if random_state is not None and not isinstance(random_state, int):
            raise ValueError(f'random_state {random_state} is not an integer')

        self.n_splits = n_splits
        self.shuffle = shuffle
        self.seed = random_state

    def get_n_splits(self, X=None, y=None):
        """
        Return the configured number of folds.

        ``X`` and ``y`` are accepted but not used.
        """
        return self.n_splits

    def split(self, x, y):
        """
        Generate train/test index pairs, one pair per fold.

        Parameters
        ----------
        x : array-like
            Samples. Only ``len(x)`` is used, to validate it against ``y``.
        y : array-like
            Class labels; converted to a cupy array through
            ``input_to_cuml_array``.

        Yields
        ------
        train, test
            The ``ids`` values of the rows outside / inside the current
            fold, as returned by ``.values`` on the selected column.

        Raises
        ------
        ValueError
            If ``x`` and ``y`` differ in length, if ``y`` holds fewer than
            two distinct classes, or if ``n_splits`` exceeds the size of
            the smallest class.
        """
        if len(x) != len(y):
            raise ValueError('Expecting same length of x and y')
        y = input_to_cuml_array(y).array.to_output('cupy')
        if len(cp.unique(y)) < 2:
            raise ValueError(
                'number of unique classes cannot be less than 2')
        df = cudf.DataFrame()
        ids = cp.arange(y.shape[0])

        if self.shuffle:
            # NOTE(review): this seeds the module-level ``cp.random``
            # generator, not a private one - confirm that is intended.
            cp.random.seed(self.seed)
            cp.random.shuffle(ids)
            # Keep labels aligned with their (now permuted) original ids.
            y = y[ids]

        df['y'] = y
        df['ids'] = ids
        grpby = df.groupby(['y'])

        # Per-class sample counts; the smallest class bounds n_splits.
        dg = grpby.agg({'y': 'count'})
        col = dg.columns[0]
        if self.n_splits > dg[col].min():
            raise ValueError(
                f'n_splits={self.n_splits} cannot be greater ' +
                'than the number of members in each class.')

        def get_order_in_group(y, ids, order):
            # Each thread writes the positions it strides over into
            # ``order``. Presumably the arrays passed in are one group's
            # rows, making ``order`` the row's rank within its class -
            # confirm against the cudf ``apply_grouped`` documentation.
            for i in range(cuda.threadIdx.x, len(y), cuda.blockDim.x):
                order[i] = i

        got = grpby.apply_grouped(get_order_in_group, incols=['y', 'ids'],
                                  outcols={'order': 'int32'},
                                  tpb=64)
        got = got.sort_values('ids')

        for i in range(self.n_splits):
            # Rows whose order is congruent to i modulo n_splits form the
            # test part of fold i; every other row is training data.
            mask = got['order'] % self.n_splits == i
            train = got.loc[~mask, 'ids'].values
            test = got.loc[mask, 'ids'].values
            if len(test) == 0:
                break
            yield train, test

    def _check_array_shape(self, y):
        """
        Raise ``ValueError`` unless ``y`` can be treated as a 1D array.

        ``None`` is rejected, as is any object whose ``shape`` has more
        than one dimension with a second dimension larger than 1.
        """
        if y is None:
            raise ValueError("Expecting 1D array, got None")
        elif hasattr(y, 'shape') and len(y.shape) > 1 and y.shape[1] > 1:
            raise ValueError(f"Expecting 1D array, got {y.shape}")
65 changes: 65 additions & 0 deletions python/cuml/test/test_stratified_kfold.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
# Copyright (c) 2019-2022, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import cudf
import cupy as cp
import pytest

from cuml.model_selection import StratifiedKFold


def get_x_y(n_samples, n_classes):
    """
    Build a toy dataset for the splitter tests.

    Returns a single-column cudf DataFrame ``X`` holding
    ``range(n_samples)`` and a cudf Series ``y`` holding the values
    ``arange(n_samples) % n_classes`` in shuffled order.
    """
    labels = cp.arange(n_samples) % n_classes
    # In-place shuffle so the classes are not laid out in a fixed cycle.
    cp.random.shuffle(labels)
    features = cudf.DataFrame({"x": range(n_samples)})
    return features, cudf.Series(labels)


@pytest.mark.parametrize("shuffle", [True, False])
@pytest.mark.parametrize("n_splits", [5, 10])
@pytest.mark.parametrize("n_samples", [10000])
@pytest.mark.parametrize("n_classes", [2, 10])
def test_split_dataframe(n_samples, n_classes, n_splits, shuffle):
    """
    Check fold sizes and per-class proportions for every generated fold.

    Each fold must cover all samples, hold ``n_splits - 1`` times as many
    training rows as test rows, and give every class the same share in
    the training and test parts.
    """
    X, y = get_x_y(n_samples, n_classes)
    splitter = StratifiedKFold(n_splits=n_splits, shuffle=shuffle)

    for train_index, test_index in splitter.split(X, y):
        n_train = len(train_index)
        n_test = len(test_index)
        assert n_train + n_test == n_samples
        assert n_train == n_test * (n_splits - 1)

        y_train = y[train_index]
        y_test = y[test_index]
        for label in range(n_classes):
            ratio_tr = (y_train == label).sum() / n_train
            ratio_te = (y_test == label).sum() / n_test
            assert ratio_tr == ratio_te


def test_num_classes_check():
    """Splitting labels that contain a single class must raise ValueError."""
    X, y = get_x_y(n_samples=1000, n_classes=1)
    splitter = StratifiedKFold(n_splits=5)
    with pytest.raises(
            ValueError,
            match="number of unique classes cannot be less than 2"):
        # ``split`` is a generator, so it has to be consumed for its
        # validation code to run.
        list(splitter.split(X, y))


@pytest.mark.parametrize("n_splits", [0, 1])
def test_invalid_folds(n_splits):
    """
    Constructing a splitter with fewer than 2 folds must raise ValueError.

    The constructor performs the ``n_splits`` validation, so it is the only
    statement inside the ``pytest.raises`` block; anything placed after it
    in that block would never execute.
    """
    err_msg = f'n_splits {n_splits} is not a integer at least 2'
    with pytest.raises(ValueError, match=err_msg):
        StratifiedKFold(n_splits=n_splits)