Describe the bug
In our latest nightlies (CUDA 11.8, Python 3.10), loading the XGBoost model in the Forest Inference Demo notebook fails with an error saying that `output_type` should be set to `output_type=self.output_type` and that it is currently set to `mirror`. The error message is not helpful and may be incorrect, and the docs do not list `mirror` as a possible value or explain how to fix this, so they may need an update.
Steps/Code to reproduce bug
Generate the model file with this code:
import xgboost as xgb
import cupy
import os
from cuml.testing.utils import array_equal
from cuml.internals.import_utils import has_xgboost
from cuml.model_selection import train_test_split
from cuml.datasets import make_classification
# synthetic data size
n_rows = 10000
n_columns = 100
n_categories = 2
random_state = cupy.random.RandomState(43210)
# fraction of data used for model training
train_size = 0.8
# trained model output filename
model_path = 'xgb.model'
# num of iterations for which xgboost is trained
num_rounds = 100
# maximum tree depth in each training round
max_depth = 20
# create the dataset
X, y = make_classification(
    n_samples=n_rows,
    n_features=n_columns,
    n_informative=int(n_columns/5),
    n_classes=n_categories,
    random_state=42
)
# convert the dataset to float32
X = X.astype('float32')
y = y.astype('float32')
# split the dataset into training and validation splits
X_train, X_validation, y_train, y_validation = train_test_split(X, y, train_size=train_size)
# set the xgboost model parameters
params = {
    'verbosity': 0,
    'eval_metric': 'error',
    'objective': 'binary:logistic',
    'max_depth': max_depth,
    'tree_method': 'gpu_hist'
}
# convert training data into DMatrix
dtrain = xgb.DMatrix(X_train, label=y_train)
# train the xgboost model
trained_model = xgb.train(params, dtrain, num_rounds)
# save the trained xgboost model
trained_model.save_model(model_path)
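Loading the saved model with `ForestInference.load` (as in the Forest Inference Demo notebook) then fails with this error:
---------------------------------------------------------------------------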
AssertionError Traceback (most recent call last)
/rapids/notebooks/cuml/Untitled.ipynb in <module>
----> 1 fil_model = ForestInference.load(
2 filename="xgb.model",
3 algo='BATCH_TREE_REORG',
4 output_class=True,
5 threshold=0.50,
/opt/conda/envs/rapids/lib/python3.10/site-packages/cuml/internals/api_decorators.py in wrapper(*args, **kwargs)
190 ret = func(*args, **kwargs)
191 else:
--> 192 return func(*args, **kwargs)
193
194 return cm.process_return(ret)
fil.pyx in cuml.fil.fil.ForestInference.load()
/opt/conda/envs/rapids/lib/python3.10/site-packages/cuml/internals/api_decorators.py in inner_f(*args, **kwargs)
340 kwargs.update({k: arg for k, arg in zip(sig.parameters, args)})
341
--> 342 return func(**kwargs)
343
344 # Set this flag to prevent auto adding this decorator twice
fil.pyx in cuml.fil.fil.ForestInference.__init__()
/opt/conda/envs/rapids/lib/python3.10/site-packages/cuml/internals/api_decorators.py in inner_f(*args, **kwargs)
340 kwargs.update({k: arg for k, arg in zip(sig.parameters, args)})
341
--> 342 return func(**kwargs)
343
344 # Set this flag to prevent auto adding this decorator twice
base.pyx in cuml.internals.base.Base.__init__()
base.pyx in cuml.internals.base._check_output_type_str()
AssertionError: Cannot pass output_type='mirror' in Base.__init__(). Did you forget to pass `output_type=self.output_type` to a child estimator? Currently `cuml.global_settings.output_type==`mirror`
Nothing in the docs helps me figure out how to fix this properly.
Expected behavior
The model should load, or the error message and the corresponding docs should help me fix the issue so that I can load the model successfully.
Environment details:
docker run --gpus all --rm -it --shm-size=1g --ulimit memlock=-1 --ulimit stack=67108864 -p 8888:8888 -p 8787:8787 -p 8786:8786 rapidsai/rapidsai-core-nightly-arm64:23.02-cuda11.8-runtime-ubuntu22.04-py3.10

With Python 3.10, there appears to be an issue with the interaction between the `staticmethod` decorator and Cython. This workaround temporarily switches all staticmethods in FIL to classmethods until the underlying issue can be sorted.
Resolve #5200.
Authors:
- William Hicks (https://github.com/wphicks)
Approvers:
- Dante Gama Dessavre (https://github.com/dantegd)
URL: #5202
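The workaround of switching FIL's staticmethods to classmethods leaves call sites such as `ForestInference.load(...)` unchanged; only the decorator on the factory differs. A minimal pure-Python sketch of that pattern follows, using a hypothetical `Forest` class rather than FIL's real code. It does not reproduce the Cython interaction, which the description above leaves unexplained.

```python
class Forest:
    """Hypothetical stand-in for an estimator with a factory method."""

    def __init__(self, filename, output_type=None):
        self.filename = filename
        self.output_type = output_type

    # Before: a staticmethod factory has to name the class explicitly.
    @staticmethod
    def load_static(filename, **kwargs):
        return Forest(filename, **kwargs)

    # After: a classmethod factory receives the class as `cls`
    # and constructs through it; callers invoke it the same way.
    @classmethod
    def load(cls, filename, **kwargs):
        return cls(filename, **kwargs)


class TunedForest(Forest):
    """Hypothetical subclass, used only to show how `cls` is bound."""
```

Both factories are called as `Forest.load(...)` or `Forest.load_static(...)`. One side effect of the classmethod form is that `TunedForest.load(...)` builds a `TunedForest`, while `TunedForest.load_static(...)` still builds a plain `Forest`.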