[BUG] .predict_proba on fitted Pipeline object with a ColumnTransformer step raises exception #4368
Thanks for including a reproducible example. From an initial triage, there are at least a couple of issues intermingled here, which may partially cause this behavior. It looks like a cuML Pipeline with a cuML OneHotEncoder fails:

import pandas as pd
from sklearn.pipeline import Pipeline as sk_Pipeline
from sklearn.preprocessing import OneHotEncoder as sk_OneHotEncoder
from cuml.experimental.preprocessing import ColumnTransformer as cu_ColumnTransformer
from cuml.preprocessing import OneHotEncoder as cu_OneHotEncoder
from cuml.pipeline import Pipeline as cu_Pipeline
X_train = pd.DataFrame(
[{"id": 1, "cat": "a", "num": 1.0, "extra": 5},
{"id": 2, "cat": "a", "num": 2.0, "extra": -1},
{"id": 3, "cat": "b", "num": 3.0, "extra": 100}]
)
# skl Pipeline, skl OHE
categorical_vars = ["cat"]
categorical_transformer = sk_Pipeline(
[
("ordinal", sk_OneHotEncoder(sparse=False)),
]
)
print(categorical_transformer.fit(X_train[categorical_vars]))
# cuml Pipeline, skl OHE
categorical_transformer = cu_Pipeline(
[
("ordinal", sk_OneHotEncoder(sparse=False)),
]
)
print(categorical_transformer.fit(X_train[categorical_vars]))
# cuml Pipeline, cuml OHE
categorical_transformer = cu_Pipeline(
[
("ordinal", cu_OneHotEncoder(sparse=False)),
]
)
print(categorical_transformer.fit(X_train[categorical_vars]))
Pipeline(steps=[('ordinal', OneHotEncoder(sparse=False))])
Pipeline(steps=[('ordinal', OneHotEncoder(sparse=False))])
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
/tmp/ipykernel_65047/1591776796.py in <module>
39 )
40
---> 41 print(categorical_transformer.fit(X_train[categorical_vars]))
~/conda/envs/rapids-21.12/lib/python3.8/site-packages/sklearn/pipeline.py in fit(self, X, y, **fit_params)
392 if self._final_estimator != "passthrough":
393 fit_params_last_step = fit_params_steps[self.steps[-1][0]]
--> 394 self._final_estimator.fit(Xt, y, **fit_params_last_step)
395
396 return self
~/conda/envs/rapids-21.12/lib/python3.8/site-packages/cuml/internals/api_decorators.py in inner(*args, **kwargs)
358 def inner(*args, **kwargs):
359 with self._recreate_cm(func, args):
--> 360 return func(*args, **kwargs)
361
362 return inner
TypeError: fit() takes 2 positional arguments but 3 were given

It also looks like a cuML ColumnTransformer can fail when a Pipeline has multiple steps:

import pandas as pd
from sklearn.pipeline import Pipeline as sk_Pipeline
from sklearn.impute import SimpleImputer as sk_SimpleImputer
from sklearn.preprocessing import StandardScaler as sk_StandardScaler
from sklearn.compose import ColumnTransformer as sk_ColumnTransformer
from cuml.experimental.preprocessing import ColumnTransformer as cu_ColumnTransformer
from cuml.preprocessing import SimpleImputer as cu_SimpleImputer
from cuml.preprocessing import StandardScaler as cu_StandardScaler
from cuml.pipeline import Pipeline as cu_Pipeline
X_train = pd.DataFrame(
[{"id": 1, "cat": "a", "num": 1.0, "extra": 5},
{"id": 2, "cat": "a", "num": 2.0, "extra": -1},
{"id": 3, "cat": "b", "num": 3.0, "extra": 100}]
)
## all cuml except ColumnTransformer
numeric_vars = ["num"]
numeric_transformer = cu_Pipeline(
steps=[
("imputer", cu_SimpleImputer(strategy="mean")),
("scaler", cu_StandardScaler()),
]
)
preprocessor = sk_ColumnTransformer(
transformers=[
("numeric", numeric_transformer, numeric_vars),
],
)
preprocessor.fit_transform(X_train) # works
preprocessor.fit(X_train); preprocessor.transform(X_train) # works
array([[-1.22474487],
[ 0. ],
[ 1.22474487]])

## cuml ColumnTransformer with single step pipeline
numeric_vars = ["num"]
numeric_transformer = cu_Pipeline(
steps=[
# ("imputer", cu_SimpleImputer(strategy="mean")),
("scaler2", cu_StandardScaler()),
]
)
preprocessor = cu_ColumnTransformer(
transformers=[
("numeric", numeric_transformer, numeric_vars),
],
)
preprocessor.fit_transform(X_train) # works
preprocessor.fit(X_train); preprocessor.transform(X_train) # works
array([[-1.22474487],
[ 0. ],
[ 1.22474487]])

## cuml ColumnTransformer with two step pipeline (also happens with SimpleImputer)
numeric_vars = ["num"]
numeric_transformer = cu_Pipeline(
steps=[
("scaler1", cu_StandardScaler()),
("scaler2", cu_StandardScaler()),
]
)
preprocessor = cu_ColumnTransformer(
transformers=[
("numeric", numeric_transformer, numeric_vars),
],
)
preprocessor.fit_transform(X_train) # works
preprocessor.fit(X_train); preprocessor.transform(X_train)  # raises, see traceback below
---------------------------------------------------------------------------
AttributeError Traceback (most recent call last)
/tmp/ipykernel_65651/4058867824.py in <module>
15
16 preprocessor.fit_transform(X_train) # works
---> 17 preprocessor.fit(X_train); preprocessor.transform(X_train) # works
~/conda/envs/rapids-21.12/lib/python3.8/site-packages/cuml/internals/api_decorators.py in inner_get(*args, **kwargs)
584
585 # Call the function
--> 586 ret_val = func(*args, **kwargs)
587
588 return cm.process_return(ret_val)
~/conda/envs/rapids-21.12/lib/python3.8/site-packages/cuml/_thirdparty/sklearn/preprocessing/_column_transformer.py in transform(self, X)
932 "data given during fit."
933 )
--> 934 Xs = self._fit_transform(X, None, _transform_one, fitted=True)
935 self._validate_output(Xs)
936
~/conda/envs/rapids-21.12/lib/python3.8/site-packages/cuml/_thirdparty/sklearn/preprocessing/_column_transformer.py in _fit_transform(self, X, y, func, fitted)
804 self._iter(fitted=fitted, replace_strings=True))
805 try:
--> 806 return Parallel(n_jobs=self.n_jobs)(
807 delayed(func)(
808 transformer=clone(trans) if not fitted else trans,
~/conda/envs/rapids-21.12/lib/python3.8/site-packages/joblib/parallel.py in __call__(self, iterable)
1041 # remaining jobs.
1042 self._iterating = False
-> 1043 if self.dispatch_one_batch(iterator):
1044 self._iterating = self._original_iterator is not None
1045
~/conda/envs/rapids-21.12/lib/python3.8/site-packages/joblib/parallel.py in dispatch_one_batch(self, iterator)
859 return False
860 else:
--> 861 self._dispatch(tasks)
862 return True
863
~/conda/envs/rapids-21.12/lib/python3.8/site-packages/joblib/parallel.py in _dispatch(self, batch)
777 with self._lock:
778 job_idx = len(self._jobs)
--> 779 job = self._backend.apply_async(batch, callback=cb)
780 # A job can complete so quickly than its callback is
781 # called before we get here, causing self._jobs to
~/conda/envs/rapids-21.12/lib/python3.8/site-packages/joblib/_parallel_backends.py in apply_async(self, func, callback)
206 def apply_async(self, func, callback=None):
207 """Schedule a func to be run"""
--> 208 result = ImmediateResult(func)
209 if callback:
210 callback(result)
~/conda/envs/rapids-21.12/lib/python3.8/site-packages/joblib/_parallel_backends.py in __init__(self, batch)
570 # Don't delay the application, to avoid keeping the input
571 # arguments in memory
--> 572 self.results = batch()
573
574 def get(self):
~/conda/envs/rapids-21.12/lib/python3.8/site-packages/joblib/parallel.py in __call__(self)
260 # change the default number of processes to -1
261 with parallel_backend(self._backend, n_jobs=self._n_jobs):
--> 262 return [func(*args, **kwargs)
263 for func, args, kwargs in self.items]
264
~/conda/envs/rapids-21.12/lib/python3.8/site-packages/joblib/parallel.py in <listcomp>(.0)
260 # change the default number of processes to -1
261 with parallel_backend(self._backend, n_jobs=self._n_jobs):
--> 262 return [func(*args, **kwargs)
263 for func, args, kwargs in self.items]
264
~/conda/envs/rapids-21.12/lib/python3.8/site-packages/cuml/_thirdparty/sklearn/preprocessing/_column_transformer.py in __call__(self, *args, **kwargs)
359 def __call__(self, *args, **kwargs):
360 _global_settings_data.shared_state = self.config
--> 361 return self.function(*args, **kwargs)
362
363
~/conda/envs/rapids-21.12/lib/python3.8/site-packages/cuml/_thirdparty/sklearn/preprocessing/_column_transformer.py in _transform_one(transformer, X, y, weight, **fit_params)
285
286 def _transform_one(transformer, X, y, weight, **fit_params):
--> 287 res = transformer.transform(X).to_output('cupy')
288 # if we have a weight for this transformer, multiply output
289 if weight is None:
~/conda/envs/rapids-21.12/lib/python3.8/site-packages/sklearn/utils/metaestimators.py in <lambda>(*args, **kwargs)
111
112 # lambda, but not partial, allows help() to work with update_wrapper
--> 113 out = lambda *args, **kwargs: self.fn(obj, *args, **kwargs) # noqa
114 else:
115
~/conda/envs/rapids-21.12/lib/python3.8/site-packages/sklearn/pipeline.py in transform(self, X)
645 Xt = X
646 for _, _, transform in self._iter():
--> 647 Xt = transform.transform(Xt)
648 return Xt
649
~/conda/envs/rapids-21.12/lib/python3.8/site-packages/cuml/internals/api_decorators.py in inner_get(*args, **kwargs)
584
585 # Call the function
--> 586 ret_val = func(*args, **kwargs)
587
588 return cm.process_return(ret_val)
~/conda/envs/rapids-21.12/lib/python3.8/site-packages/cuml/_thirdparty/sklearn/preprocessing/_data.py in transform(self, X, copy)
793 copy = copy if copy is not None else self.copy
794
--> 795 X = self._validate_data(X, reset=False,
796 accept_sparse=['csr', 'csc'], copy=copy,
797 estimator=self, dtype=FLOAT_DTYPES,
~/conda/envs/rapids-21.12/lib/python3.8/site-packages/cuml/_thirdparty/sklearn/utils/skl_dependencies.py in _validate_data(self, X, y, reset, validate_separately, **check_params)
109 f"requires y to be passed, but the target y is None."
110 )
--> 111 X = check_array(X, **check_params)
112 out = X
113 else:
~/conda/envs/rapids-21.12/lib/python3.8/site-packages/cuml/thirdparty_adapters/adapters.py in check_array(array, accept_sparse, accept_large_sparse, dtype, order, copy, force_all_finite, ensure_2d, allow_nd, ensure_min_samples, ensure_min_features, warn_on_dtype, estimator)
248 raise ValueError("Not enough samples")
249
--> 250 if ensure_min_features > 0 and hasshape and array.ndim == 2:
251 n_features = array.shape[1]
252 if n_features < ensure_min_features:
AttributeError: 'CumlArray' object has no attribute 'ndim'
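The failing line in check_array assumed every array-like exposes an ndim attribute, which the CumlArray reaching it here does not. As an illustration only (check_min_features and FakeArray are hypothetical names, not cuML's actual patch), a defensive version of that feature-count check could fall back to the length of shape when ndim is missing:

```python
# Hypothetical sketch, not the actual cuML fix: guard the feature-count
# validation so it does not assume the input exposes an `ndim` attribute.
def check_min_features(array, ensure_min_features=1):
    """Validate the number of features without requiring array.ndim."""
    shape = getattr(array, "shape", None)
    # Fall back to len(shape) when the object has no `ndim` attribute --
    # the situation the AttributeError above hit.
    ndim = getattr(array, "ndim", len(shape) if shape is not None else 0)
    if ensure_min_features > 0 and shape is not None and ndim == 2:
        if shape[1] < ensure_min_features:
            raise ValueError(
                f"Found array with {shape[1]} feature(s) while a minimum "
                f"of {ensure_min_features} is required.")
    return array

class FakeArray:
    """Mimics an array-like that exposes shape but not ndim."""
    shape = (3, 1)

check_min_features(FakeArray())  # no AttributeError with the fallback
```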
…th a ColumnTransformer step (#4774)

This PR fixes a subtle bug in check_array of cuml.thirdparty_adapters.adapters, which is the primary cause for the bug. Fix #4368.

Authors:
- https://github.com/VamsiTallam95
- Ray Douglass (https://github.com/raydouglass)

Approvers:
- Dante Gama Dessavre (https://github.com/dantegd)

URL: #4774
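The other issue from the triage above was sklearn's Pipeline calling the final estimator as fit(Xt, y, **fit_params), passing y positionally to an estimator whose fit only accepts X. That mismatch, and a thin wrapper that works around it, can be sketched without any cuML dependency (XOnlyEncoder and DropYAdapter are hypothetical names used only for illustration):

```python
# Hypothetical illustration of the TypeError from the triage comment:
# sklearn's Pipeline calls final_estimator.fit(Xt, y, **params) with y
# passed positionally, so an estimator whose fit() accepts only X fails.
class XOnlyEncoder:
    """Stand-in for an estimator whose fit() signature is fit(self, X)."""
    def fit(self, X):
        self.n_features_ = len(X[0])
        return self

class DropYAdapter:
    """Wrapper that accepts and discards y before delegating to fit(X)."""
    def __init__(self, estimator):
        self.estimator = estimator
    def fit(self, X, y=None, **fit_params):
        self.estimator.fit(X)  # delegate without the unsupported y
        return self

try:
    XOnlyEncoder().fit([[1, 2]], None)  # mimics the Pipeline's positional call
except TypeError as exc:
    print("raised:", exc)

DropYAdapter(XOnlyEncoder()).fit([[1, 2]], None)  # succeeds
```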
Describe the bug

When using the ColumnTransformer from cuml.experimental.preprocessing in an already-fitted Pipeline, the methods predict/predict_proba raise exceptions stating that X has a mismatched number of features, even though the data has the same shape as the DataFrame passed to fit.

Steps/Code to reproduce bug

Here's a minimal example, preserving the types and shape of my real data and the structure of the pipeline (same encoders, imputers, and classifier used):

The stack trace output is immediately below. Using sklearn equivalents, predict_proba executes without exception.

Expected behavior

Calling predict_proba on a fitted pipeline should return an array of predictions. If there is a larger error with the input to predict_proba or predict, a more descriptive error message would also be much appreciated.

Environment details (please complete the following information):

rapidsai/rapidsai-core:21.08-cuda11.0-base-ubuntu18.04-py3.8