Skip to content

Commit

Permalink
Dataframe Index as columns in ColumnTransformer (rapidsai#4481)
Browse files Browse the repository at this point in the history
Answers rapidsai#4435

Authors:
  - Victor Lafargue (https://github.com/viclafargue)

Approvers:
  - Dante Gama Dessavre (https://github.com/dantegd)

URL: rapidsai#4481
  • Loading branch information
viclafargue authored Jan 18, 2022
1 parent df4c61a commit 71f50a0
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,9 @@
import cupy as np
import numba

import pandas as pd
import cudf

import cuml
from cuml.internals.global_settings import _global_settings_data
from cuml.common.array_sparse import SparseCumlArray
Expand Down Expand Up @@ -220,6 +223,9 @@ def _safe_indexing(X, indices, *, axis=0):
" column). Got {} instead.".format(axis)
)

if isinstance(indices, (pd.Index, cudf.Index)):
indices = list(indices)

indices_dtype = _determine_key_type(indices)

if axis == 0 and indices_dtype == 'str':
Expand Down
16 changes: 15 additions & 1 deletion python/cuml/test/test_compose.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2021, NVIDIA CORPORATION.
# Copyright (c) 2021-2022, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -268,3 +268,17 @@ def test_make_column_selector():

assert_allclose(t_X, sk_t_X)
assert type(t_X) == type(X)


def test_column_transformer_index(clf_dataset): # noqa: F811
X_np, X = clf_dataset

if not isinstance(X, (pdDataFrame, cuDataFrame)):
pytest.skip()

cu_transformers = [
("scaler", cuStandardScaler(), X.columns)
]

transformer = cuColumnTransformer(cu_transformers)
transformer.fit_transform(X)

0 comments on commit 71f50a0

Please sign in to comment.