Skip to content

Commit

Permalink
[ENH] Let cuDF handle input types for label encoder. (rapidsai#5783)
Browse files Browse the repository at this point in the history
cuDF handles more types than the label encoder currently does (like torch tensor). This PR delegates the type checking to cuDF.

- Let cuDF handle input types for label encoder.
- Small cleanups.

Authors:
  - Jiaming Yuan (https://github.com/trivialfis)
  - Dante Gama Dessavre (https://github.com/dantegd)

Approvers:
  - Dante Gama Dessavre (https://github.com/dantegd)

URL: rapidsai#5783
  • Loading branch information
trivialfis authored Mar 4, 2024
1 parent b0ba340 commit e3b898f
Show file tree
Hide file tree
Showing 4 changed files with 71 additions and 73 deletions.
2 changes: 2 additions & 0 deletions python/cuml/_thirdparty/sklearn/utils/validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,8 @@ def check_is_fitted(estimator, attributes=None, *, msg=None, all_or_any=all):
if not isinstance(attributes, (list, tuple)):
attributes = [attributes]
attrs = all_or_any([hasattr(estimator, attr) for attr in attributes])
elif hasattr(estimator, "__sklearn_is_fitted__"):
attrs = estimator.__sklearn_is_fitted__()
else:
attrs = [v for v in vars(estimator)
if v.endswith("_") and not v.startswith("__")]
Expand Down
80 changes: 31 additions & 49 deletions python/cuml/preprocessing/LabelEncoder.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#
# Copyright (c) 2019-2023, NVIDIA CORPORATION.
# Copyright (c) 2019-2024, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand All @@ -14,16 +14,27 @@
# limitations under the License.
#

from cuml.common.exceptions import NotFittedError
from cuml.internals.safe_imports import cpu_only_import_from
from cuml import Base
from cuml.internals.safe_imports import cpu_only_import
from cuml.internals.safe_imports import gpu_only_import
from typing import TYPE_CHECKING

cudf = gpu_only_import("cudf")
cp = gpu_only_import("cupy")
np = cpu_only_import("numpy")
pdSeries = cpu_only_import_from("pandas", "Series")
from cuml import Base
from cuml._thirdparty.sklearn.utils.validation import check_is_fitted
from cuml.common.exceptions import NotFittedError
from cuml.internals.safe_imports import (
cpu_only_import,
cpu_only_import_from,
gpu_only_import,
)

if TYPE_CHECKING:
import cudf
import cupy as cp
import numpy as np
from pandas import Series as pdSeries
else:
cudf = gpu_only_import("cudf")
cp = gpu_only_import("cupy")
np = cpu_only_import("numpy")
pdSeries = cpu_only_import_from("pandas", "Series")


class LabelEncoder(Base):
Expand Down Expand Up @@ -125,7 +136,7 @@ def __init__(
handle=None,
verbose=False,
output_type=None,
):
) -> None:

super().__init__(
handle=handle, verbose=verbose, output_type=output_type
Expand All @@ -136,13 +147,8 @@ def __init__(
self._fitted: bool = False
self.handle_unknown = handle_unknown

def _check_is_fitted(self):
if not self._fitted:
msg = (
"This LabelEncoder instance is not fitted yet. Call 'fit' "
"with appropriate arguments before using this estimator."
)
raise NotFittedError(msg)
def __sklearn_is_fitted__(self) -> bool:
return self.classes_ is not None

def _validate_keywords(self):
if self.handle_unknown not in ("error", "ignore"):
Expand Down Expand Up @@ -174,17 +180,13 @@ def fit(self, y, _classes=None):
self._validate_keywords()

if _classes is None:
y = (
self._to_cudf_series(y)
.drop_duplicates()
.sort_values(ignore_index=True)
) # dedupe and sort
# dedupe and sort
y = cudf.Series(y).drop_duplicates().sort_values(ignore_index=True)
self.classes_ = y
else:
self.classes_ = _classes

self.dtype = y.dtype if y.dtype != cp.dtype("O") else str
self._fitted = True
return self

def transform(self, y) -> cudf.Series:
Expand All @@ -211,11 +213,9 @@ def transform(self, y) -> cudf.Series:
KeyError
if a category appears that was not seen in `fit`
"""
y = self._to_cudf_series(y)
check_is_fitted(self)

self._check_is_fitted()

y = y.astype("category")
y = cudf.Series(y, dtype="category")

encoded = y.cat.set_categories(self.classes_)._column.codes
encoded = cudf.Series(encoded, index=y.index)
Expand All @@ -233,13 +233,12 @@ def fit_transform(self, y, z=None) -> cudf.Series:
`LabelEncoder().fit(y).transform(y)`
"""

y = self._to_cudf_series(y)
y = cudf.Series(y)
self.dtype = y.dtype if y.dtype != cp.dtype("O") else str

y = y.astype("category")
self.classes_ = y._column.categories

self._fitted = True
return cudf.Series(y._column.codes, index=y.index)

def inverse_transform(self, y: cudf.Series) -> cudf.Series:
Expand All @@ -258,9 +257,9 @@ def inverse_transform(self, y: cudf.Series) -> cudf.Series:
Reverted labels
"""
# check LabelEncoder is fitted
self._check_is_fitted()
check_is_fitted(self)
# convert input to a cudf.Series
y = self._to_cudf_series(y)
y = cudf.Series(y)

# check if ord_label out of bound
ord_label = y.unique()
Expand All @@ -285,20 +284,3 @@ def get_param_names(self):
return super().get_param_names() + [
"handle_unknown",
]

def _to_cudf_series(self, y):
if isinstance(y, pdSeries):
y = cudf.from_pandas(y)
elif isinstance(y, cp.ndarray):
y = cudf.Series(y)
elif isinstance(y, np.ndarray):
y = cudf.Series(y)
elif not isinstance(y, cudf.Series):
msg = (
"input should be either 'cupy.ndarray'"
" or 'numpy.ndarray' or 'pandas.Series',"
" or 'cudf.Series'"
"got {0}.".format(type(y))
)
raise TypeError(msg)
return y
17 changes: 9 additions & 8 deletions python/cuml/tests/dask/test_dask_label_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,13 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from cuml.common.exceptions import NotFittedError
import pytest
from cuml.internals.safe_imports import cpu_only_import

import cuml
from cuml._thirdparty.sklearn.utils.validation import check_is_fitted
from cuml.common.exceptions import NotFittedError
from cuml.dask.preprocessing.LabelEncoder import LabelEncoder
from cuml.internals.safe_imports import gpu_only_import
from cuml.internals.safe_imports import cpu_only_import, gpu_only_import

cudf = gpu_only_import("cudf")
np = cpu_only_import("numpy")
Expand Down Expand Up @@ -51,7 +52,7 @@ def test_labelencoder_transform(length, cardinality, client):
tmp = cudf.Series(np.random.choice(cardinality, (length,)))
df = dask_cudf.from_cudf(tmp, npartitions=len(client.has_what()))
le = LabelEncoder().fit(df)
assert le._fitted
check_is_fitted(le)

encoded = le.transform(df)

Expand All @@ -69,7 +70,7 @@ def test_labelencoder_unseen(client):
npartitions=len(client.has_what()),
)
le = LabelEncoder().fit(df)
assert le._fitted
check_is_fitted(le)

with pytest.raises(KeyError):
tmp = dask_cudf.from_cudf(
Expand Down Expand Up @@ -141,7 +142,7 @@ def test_inverse_transform(
le.fit_transform(orig_label)
else:
le.fit(orig_label)
assert le._fitted is True
check_is_fitted(le)

# test if inverse_transform is correct
reverted = le.inverse_transform(ord_label)
Expand Down Expand Up @@ -175,7 +176,7 @@ def test_empty_input(empty, ord_label, client):
ord_label = dask_cudf.from_cudf(ord_label, npartitions=n_workers)
le = LabelEncoder()
le.fit(empty)
assert le._fitted is True
check_is_fitted(le)

# test if correctly raises ValueError
with pytest.raises(ValueError, match="y contains previously unseen label"):
Expand All @@ -184,7 +185,7 @@ def test_empty_input(empty, ord_label, client):
# check fit_transform()
le = LabelEncoder()
transformed = le.fit_transform(empty).compute()
assert le._fitted is True
check_is_fitted(le)
assert len(transformed) == 0


Expand Down
45 changes: 29 additions & 16 deletions python/cuml/tests/test_label_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from cuml.common.exceptions import NotFittedError
import pytest
from cuml.internals.safe_imports import cpu_only_import

from cuml._thirdparty.sklearn.utils.validation import check_is_fitted
from cuml.common.exceptions import NotFittedError
from cuml.internals.safe_imports import cpu_only_import, gpu_only_import
from cuml.preprocessing.LabelEncoder import LabelEncoder
from cuml.internals.safe_imports import gpu_only_import

pd = cpu_only_import("pandas")
cudf = gpu_only_import("cudf")
np = cpu_only_import("numpy")
cp = gpu_only_import("cupy")
Expand Down Expand Up @@ -46,7 +48,7 @@ def test_labelencoder_transform(length, cardinality):
"""Try fitting and then encoding a small subset of the df"""
df = cudf.Series(np.random.choice(cardinality, (length,)))
le = LabelEncoder().fit(df)
assert le._fitted
check_is_fitted(le)

subset = df.iloc[0 : df.shape[0] // 2]
encoded = le.transform(subset)
Expand All @@ -62,7 +64,7 @@ def test_labelencoder_unseen():
"""Try encoding a value that was not present during fitting"""
df = cudf.Series(np.random.choice(10, (10,)))
le = LabelEncoder().fit(df)
assert le._fitted
check_is_fitted(le)

with pytest.raises(KeyError):
le.transform(cudf.Series([-1]))
Expand All @@ -72,7 +74,7 @@ def test_labelencoder_unfitted():
"""Try calling `.transform()` without fitting first"""
df = cudf.Series(np.random.choice(10, (10,)))
le = LabelEncoder()
assert not le._fitted
assert not le.__sklearn_is_fitted__()

with pytest.raises(NotFittedError):
le.transform(df)
Expand Down Expand Up @@ -117,7 +119,7 @@ def test_inverse_transform(
le.fit_transform(orig_label)
else:
le.fit(orig_label)
assert le._fitted is True
check_is_fitted(le)

# test if inverse_transform is correct
reverted = le.inverse_transform(ord_label)
Expand All @@ -132,7 +134,7 @@ def test_unfitted_inverse_transform():
"""Try calling `.inverse_transform()` without fitting first"""
df = cudf.Series(np.random.choice(10, (10,)))
le = LabelEncoder()
assert not le._fitted
assert not le.__sklearn_is_fitted__()

with pytest.raises(NotFittedError):
le.transform(df)
Expand All @@ -145,7 +147,7 @@ def test_empty_input(empty, ord_label):
# prepare LabelEncoder
le = LabelEncoder()
le.fit(empty)
assert le._fitted is True
check_is_fitted(le)

# test if correctly raises ValueError
with pytest.raises(ValueError, match="y contains previously unseen label"):
Expand All @@ -154,7 +156,7 @@ def test_empty_input(empty, ord_label):
# check fit_transform()
le = LabelEncoder()
transformed = le.fit_transform(empty)
assert le._fitted is True
check_is_fitted(le)
assert len(transformed) == 0


Expand Down Expand Up @@ -187,18 +189,29 @@ def _array_to_similarity_mat(x):

@pytest.mark.parametrize("length", [10, 1000])
@pytest.mark.parametrize("cardinality", [5, 10, 50])
@pytest.mark.parametrize("dtype", ["cupy", "numpy"])
def test_labelencoder_fit_transform_cupy_numpy(length, cardinality, dtype):
"""Try encoding the cupy array"""
@pytest.mark.parametrize("dtype", ["cupy", "numpy", "pd"])
def test_labelencoder_fit_transform_cupy_numpy_pd(length, cardinality, dtype):
"""Try encoding with various types"""
x = cp.random.choice(cardinality, (length,))
# to series
if dtype == "numpy":
x = x.get()
elif dtype == "pd":
x = pd.Series(x.get())
encoded = LabelEncoder().fit_transform(x)

x_arr = _array_to_similarity_mat(x)
if dtype == "pd":
x_arr = _df_to_similarity_mat(x)
else:
x_arr = _array_to_similarity_mat(x)

encoded_arr = _array_to_similarity_mat(encoded.values)
if dtype == "numpy":

# to array
if dtype == "numpy" or dtype == "pd":
encoded_arr = encoded_arr.get()
if dtype == "pd":
x = x.to_numpy()
assert ((encoded_arr == encoded_arr.T) == (x == x_arr.T)).all()


Expand Down Expand Up @@ -229,7 +242,7 @@ def test_inverse_transform_cupy_numpy(
le.fit_transform(orig_label)
else:
le.fit(orig_label)
assert le._fitted is True
check_is_fitted(le)

# test if inverse_transform is correct
reverted = le.inverse_transform(ord_label)
Expand Down

0 comments on commit e3b898f

Please sign in to comment.