Skip to content

Commit

Permalink
Support CPU object for train_test_split (#5873)
Browse files Browse the repository at this point in the history
This PR adds support for CPU objects in `train_test_split`, leveraging the
input conversion tools defined in `input_utils.py`. This PR also adds the
`output_to_df_obj_like` API, which converts a CumlArray back to a
series/dataframe whose metadata matches the input.

In addition, this PR reimplements the majority of `train_test_split` by
centralizing index computation and gathering. This reduces the number of
kernels launched, especially when stratify keys are provided.

Closes #5619

Authors:
  - Michael Wang (https://github.com/isVoid)

Approvers:
  - Dante Gama Dessavre (https://github.com/dantegd)

URL: #5873
  • Loading branch information
isVoid authored Apr 30, 2024
1 parent 3c1a72c commit 1609fcd
Show file tree
Hide file tree
Showing 3 changed files with 227 additions and 254 deletions.
50 changes: 50 additions & 0 deletions python/cuml/internals/input_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#

from collections import namedtuple
from typing import Literal

from cuml.internals.array import CumlArray
from cuml.internals.array_sparse import SparseCumlArray
Expand Down Expand Up @@ -46,6 +47,7 @@
cp_ndarray = gpu_only_import_from("cupy", "ndarray")
CudfSeries = gpu_only_import_from("cudf", "Series")
CudfDataFrame = gpu_only_import_from("cudf", "DataFrame")
CudfIndex = gpu_only_import_from("cudf", "Index")
DaskCudfSeries = gpu_only_import_from("dask_cudf", "Series")
DaskCudfDataFrame = gpu_only_import_from("dask_cudf", "DataFrame")
np_ndarray = cpu_only_import_from("numpy", "ndarray")
Expand All @@ -64,6 +66,7 @@
nvtx_annotate = gpu_only_import_from("nvtx", "annotate", alt=null_decorator)
PandasSeries = cpu_only_import_from("pandas", "Series")
PandasDataFrame = cpu_only_import_from("pandas", "DataFrame")
PandasIndex = cpu_only_import_from("pandas", "Index")

cuml_array = namedtuple("cuml_array", "array n_rows n_cols dtype")

Expand All @@ -73,13 +76,15 @@
np_ndarray: "numpy",
PandasSeries: "pandas",
PandasDataFrame: "pandas",
PandasIndex: "pandas",
}


try:
_input_type_to_str[cp_ndarray] = "cupy"
_input_type_to_str[CudfSeries] = "cudf"
_input_type_to_str[CudfDataFrame] = "cudf"
_input_type_to_str[CudfIndex] = "cudf"
_input_type_to_str[NumbaDeviceNDArrayBase] = "numba"
except UnavailableError:
pass
Expand Down Expand Up @@ -160,9 +165,15 @@ def get_supported_input_type(X):
if isinstance(X, PandasSeries):
return PandasSeries

if isinstance(X, PandasIndex):
return PandasIndex

if isinstance(X, CudfDataFrame):
return CudfDataFrame

if isinstance(X, CudfIndex):
return CudfIndex

try:
if numba_cuda.devicearray.is_cuda_ndarray(X):
return numba_cuda.devicearray.DeviceNDArrayBase
Expand Down Expand Up @@ -205,6 +216,21 @@ def determine_array_type(X):
return _input_type_to_str.get(gen_type, None)


def determine_df_obj_type(X):
    """Classify ``X`` as a dataframe-like or series-like object.

    Parameters
    ----------
    X : object or None
        Input to classify. Its generic type is resolved through
        ``get_supported_input_type``.

    Returns
    -------
    str or None
        ``"dataframe"`` when the resolved type is the cuDF or pandas
        DataFrame type, ``"series"`` when it is the cuDF or pandas Series
        type, and ``None`` for every other input (including ``X is None``).
    """
    # Nothing to classify; skip the type lookup entirely.
    if X is None:
        return None

    resolved = get_supported_input_type(X)

    # DataFrame types are tested before Series types, as in the original.
    if resolved in (CudfDataFrame, PandasDataFrame):
        return "dataframe"
    if resolved in (CudfSeries, PandasSeries):
        return "series"

    # Any other supported (or unsupported) type is not a dataframe object.
    return None


def determine_array_dtype(X):

if X is None:
Expand Down Expand Up @@ -575,3 +601,27 @@ def sparse_scipy_to_cp(sp, dtype):
v = cp.asarray(values, dtype=dtype)

return cupyx.scipy.sparse.coo_matrix((v, (r, c)), sp.shape)


def output_to_df_obj_like(
    X_out: CumlArray, X_in, output_type: Literal["series", "dataframe"]
):
    """Convert ``X_out`` to a series / dataframe labelled like ``X_in``.

    ``CumlArray`` does not carry series / dataframe metadata. When an API
    has to hand back an object that mirrors the original input, this helper
    converts the array and copies the relevant label from the input onto
    the converted result.

    Parameters
    ----------
    X_out : CumlArray
        Array to convert; ``X_out.to_output(output_type)`` produces the
        returned object.
    X_in : series-like or dataframe-like
        Original input. Its ``name`` is read when ``output_type`` is
        ``"series"``; its ``columns`` is read when ``output_type`` is
        ``"dataframe"``.
    output_type : {"series", "dataframe"}
        Kind of object to produce.

    Returns
    -------
    object
        The object returned by ``X_out.to_output(...)``, with ``name``
        (series) or ``columns`` (dataframe) set from ``X_in``.

    Raises
    ------
    ValueError
        If ``output_type`` is neither ``"series"`` nor ``"dataframe"``.
    """
    if output_type not in ("series", "dataframe"):
        raise ValueError(
            f'output_type must be either "series" or "dataframe" : {output_type}'
        )

    if output_type == "series":
        # A series carries a single label: its name.
        converted = X_out.to_output("series")
        converted.name = X_in.name
        return converted

    # Only "dataframe" remains after the validation above.
    converted = X_out.to_output("dataframe")
    converted.columns = X_in.columns
    return converted
Loading

0 comments on commit 1609fcd

Please sign in to comment.