From 71f50a063fd8b37f2f3f9493d56a63351aebdc46 Mon Sep 17 00:00:00 2001 From: Victor Lafargue Date: Tue, 18 Jan 2022 01:07:38 +0100 Subject: [PATCH] Dataframe Index as columns in ColumnTransformer (#4481) Answers #4435 Authors: - Victor Lafargue (https://github.com/viclafargue) Approvers: - Dante Gama Dessavre (https://github.com/dantegd) URL: https://github.com/rapidsai/cuml/pull/4481 --- .../sklearn/preprocessing/_column_transformer.py | 6 ++++++ python/cuml/test/test_compose.py | 16 +++++++++++++++- 2 files changed, 21 insertions(+), 1 deletion(-) diff --git a/python/cuml/_thirdparty/sklearn/preprocessing/_column_transformer.py b/python/cuml/_thirdparty/sklearn/preprocessing/_column_transformer.py index e248656a83..f3d9177d5f 100644 --- a/python/cuml/_thirdparty/sklearn/preprocessing/_column_transformer.py +++ b/python/cuml/_thirdparty/sklearn/preprocessing/_column_transformer.py @@ -30,6 +30,9 @@ import cupy as np import numba +import pandas as pd +import cudf + import cuml from cuml.internals.global_settings import _global_settings_data from cuml.common.array_sparse import SparseCumlArray @@ -220,6 +223,9 @@ def _safe_indexing(X, indices, *, axis=0): " column). Got {} instead.".format(axis) ) + if isinstance(indices, (pd.Index, cudf.Index)): + indices = list(indices) + indices_dtype = _determine_key_type(indices) if axis == 0 and indices_dtype == 'str': diff --git a/python/cuml/test/test_compose.py b/python/cuml/test/test_compose.py index f80c390662..34375c2ecf 100644 --- a/python/cuml/test/test_compose.py +++ b/python/cuml/test/test_compose.py @@ -1,4 +1,4 @@ -# Copyright (c) 2021, NVIDIA CORPORATION. +# Copyright (c) 2021-2022, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -268,3 +268,17 @@ def test_make_column_selector(): assert_allclose(t_X, sk_t_X) assert type(t_X) == type(X) + + +def test_column_transformer_index(clf_dataset): # noqa: F811 + X_np, X = clf_dataset + + if not isinstance(X, (pdDataFrame, cuDataFrame)): + pytest.skip() + + cu_transformers = [ + ("scaler", cuStandardScaler(), X.columns) + ] + + transformer = cuColumnTransformer(cu_transformers) + transformer.fit_transform(X)