Skip to content

Commit

Permalink
Fix type hint in split function. (rapidsai#5625)
Browse files Browse the repository at this point in the history
Authors:
  - Jiaming Yuan (https://github.com/trivialfis)

Approvers:
  - Simon Adorf (https://github.com/csadorf)

URL: rapidsai#5625
  • Loading branch information
trivialfis authored Oct 25, 2023
1 parent 3146025 commit 2b5aa3e
Showing 1 changed file with 11 additions and 9 deletions.
20 changes: 11 additions & 9 deletions python/cuml/model_selection/_split.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,15 @@
# limitations under the License.
#

from typing import Union
from cuml.internals.safe_imports import gpu_only_import_from
from typing import Optional, Union

from cuml.common import input_to_cuml_array
from cuml.internals.array import array_to_memory_order
from cuml.internals.safe_imports import cpu_only_import
from cuml.internals.safe_imports import gpu_only_import
from cuml.internals.safe_imports import (
cpu_only_import,
gpu_only_import,
gpu_only_import_from,
)

cudf = gpu_only_import("cudf")
cp = gpu_only_import("cupy")
Expand Down Expand Up @@ -138,7 +141,6 @@ def _stratify_split(
if hasattr(X, "__cuda_array_interface__") or isinstance(
X, cupyx.scipy.sparse.csr_matrix
):

X_train_i = cp.array(
X[perm_indices_class_i[: n_i[i]]], order=x_order
)
Expand Down Expand Up @@ -244,11 +246,11 @@ def _approximate_mode(class_counts, n_draws, rng):
def train_test_split(
X,
y=None,
test_size: Union[float, int] = None,
train_size: Union[float, int] = None,
test_size: Optional[Union[float, int]] = None,
train_size: Optional[Union[float, int]] = None,
shuffle: bool = True,
random_state: Union[
int, cp.random.RandomState, np.random.RandomState
random_state: Optional[
Union[int, cp.random.RandomState, np.random.RandomState]
] = None,
stratify=None,
):
Expand Down

0 comments on commit 2b5aa3e

Please sign in to comment.