
[BUG] SIGABRT in CUML RF, out of bounds memory usage #4046

Closed · Tracked by #4139
pseudotensor opened this issue Jul 12, 2021 · 19 comments
Labels: ? - Needs Triage, bug, inactive-90d

Comments

pseudotensor commented Jul 12, 2021

Describe the bug
SIGABRT, seemingly caused by an out-of-bounds memory access.

Steps/Code to reproduce bug

Unknown exactly, but the data was the Kaggle Paribas dataset with various frequency-encoding features added, giving a shape of (91457, 331).

Parameters:

 OrderedDict([('output_type', 'numpy'), ('random_state', 840607124), ('verbose', False), ('n_estimators', 200), ('n_bins', 128), ('split_criterion', 1), ('max_depth', 18), ('max_leaves', 1024), ('max_features', 'auto'), ('min_samples_leaf', 1), ('min_samples_split', 10), ('min_impurity_decrease', 0.0)])

This was a binary classification problem.
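For reference, a minimal sketch of how those parameters map onto a fit call (the parameter names and the cuml.ensemble.RandomForestClassifier API match the reproduction script later in this thread; the real frequency-encoded Paribas matrix is not attached, so X_placeholder/y_placeholder are stand-ins):

# Illustrative sketch only: random placeholder data stands in for the real
# (91457, 331) frequency-encoded Paribas features, which are not attached here.
import numpy as np
from cuml.ensemble import RandomForestClassifier

params = dict(output_type='numpy', random_state=840607124, verbose=False,
              n_estimators=200, n_bins=128, split_criterion=1, max_depth=18,
              max_leaves=1024, max_features='auto', min_samples_leaf=1,
              min_samples_split=10, min_impurity_decrease=0.0)

X_placeholder = np.random.rand(91457, 331).astype(np.float32)
y_placeholder = np.random.randint(0, 2, size=91457).astype(np.int32)  # binary target

clf = RandomForestClassifier(**params)
clf.fit(X_placeholder, y_placeholder)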

There were no messages in the console at all, even though this ran in debug mode with verbose=4. All I got was a SIGABRT, plus the following in dmesg:

[Sun Jul 11 21:15:41 2021] NVRM: GPU at PCI:0000:01:00: GPU-0bb167f8-b3cd-8df7-9644-d5f95716e554
[Sun Jul 11 21:15:41 2021] NVRM: GPU Board Serial Number: 
[Sun Jul 11 21:15:41 2021] NVRM: Xid (PCI:0000:01:00): 13, pid=2041, Graphics SM Warp Exception on (GPC 3, TPC 3, SM 0): Out Of Range Address
[Sun Jul 11 21:15:41 2021] NVRM: Xid (PCI:0000:01:00): 13, pid=2041, Graphics SM Global Exception on (GPC 3, TPC 3, SM 0): Multiple Warp Errors
[Sun Jul 11 21:15:41 2021] NVRM: Xid (PCI:0000:01:00): 13, pid=2041, Graphics Exception: ESR 0x51df30=0xc13000e 0x51df34=0x24 0x51df28=0x4c1eb72 0x51df2c=0x174
[Sun Jul 11 21:15:41 2021] NVRM: Xid (PCI:0000:01:00): 43, pid=6304, Ch 00000088
[Sun Jul 11 21:15:54 2021] NVRM: Xid (PCI:0000:01:00): 13, pid=6304, Graphics SM Warp Exception on (GPC 4, TPC 2, SM 1): Out Of Range Address
[Sun Jul 11 21:15:54 2021] NVRM: Xid (PCI:0000:01:00): 13, pid=6304, Graphics SM Global Exception on (GPC 4, TPC 2, SM 1): Multiple Warp Errors
[Sun Jul 11 21:15:54 2021] NVRM: Xid (PCI:0000:01:00): 13, pid=6304, Graphics Exception: ESR 0x5257b0=0xc12000e 0x5257b4=0x24 0x5257a8=0x4c1eb72 0x5257ac=0x174
[Sun Jul 11 21:15:54 2021] NVRM: Xid (PCI:0000:01:00): 43, pid=8874, Ch 00000088

Expected behavior

Not to crash; the algorithm should be more stable.

Environment details (please complete the following information):

  • Environment location: Bare-metal
  • Linux Distro/Architecture: Ubuntu 18.04 LTS
  • GPU Model/Driver: RTX 2080, driver 460.80
  • CUDA: 11.2.2
  • Method of cuDF & cuML install: conda nightly 21.08 -- nightly as of 7 days ago.

conda_list.txt.zip

Additional context

If I hit this again I will try to produce a repro, but I expect routine testing on NVIDIA's side will reveal it. I've only been using cuML RF for a day and have already hit this after (maybe) 200 fits on small data.

pseudotensor added the "? - Needs Triage" and "bug" labels on Jul 12, 2021
pseudotensor (Author) commented Jul 12, 2021

Same kind of run, SIGABRT again. This time dmesg says:

[803642.849684] NVRM: Xid (PCI:0000:01:00): 31, pid=17188, Ch 00000060, intr 00000000. MMU Fault: ENGINE GRAPHICS GPCCLIENT_T1_0 faulted @ 0x1459_25a73000. Fault is of type FAULT_PDE ACCESS_TYPE_VIRT_READ

Basically, cuML RF is not stable/usable as-is.

This time the parameters were:

 ('output_type', 'numpy'), ('random_state', 620152258), ('verbose', False), ('n_estimators', 300), ('n_bins', 128), ('split_criterion', 0), ('max_depth', 20), ('max_leaves', 1024), ('max_features', 'auto'), ('min_samples_leaf', 1), ('min_samples_split', 2), ('min_impurity_decrease', 0.0)])

and data was (91457, 492)

pseudotensor (Author) commented:

('output_type', 'numpy'), ('random_state', 1037940298), ('verbose', False), ('n_estimators', 300), ('n_bins', 128), ('split_criterion', 1), ('max_depth', 16), ('max_leaves', 1024), ('max_features', 'auto'), ('min_samples_leaf', 1), ('min_samples_split', 10), ('min_impurity_decrease', 0.0)])

(91457, 498)

[805879.057252] NVRM: Xid (PCI:0000:01:00): 31, pid=26264, Ch 00000078, intr 00000000. MMU Fault: ENGINE GRAPHICS GPCCLIENT_T1_6 faulted @ 0x14bf_5f49c000. Fault is of type FAULT_PDE ACCESS_TYPE_VIRT_READ

dantegd (Member) commented Jul 12, 2021

Tagging @vinaydes, who's looking into the issue.

vinaydes (Contributor) commented:

Without a reproducer it is going to be difficult to debug this one. I created the following snippet to try to reproduce the crash; however, running it on two different GPUs (Titan V, RTX 3070 Ti) gave me no crash.

import sys
import numpy as np
from sklearn.datasets import make_classification
from cuml.ensemble import RandomForestClassifier as cumlRFClassifier
import time

N_REPS = 50
# (91457, 331)
# OrderedDict([('output_type', 'numpy'), ('random_state', 840607124), ('verbose', False),
#  ('n_estimators', 200), ('n_bins', 128), ('split_criterion', 1), ('max_depth', 18),
#   ('max_leaves', 1024), ('max_features', 'auto'), ('min_samples_leaf', 1), 
#   ('min_samples_split', 10), ('min_impurity_decrease', 0.0)])


rf_params = {
    'n_estimators'             : 200,
    'split_criterion'          : 1,
    'bootstrap'                : True,
    'max_samples'              : 1.0,
    'max_depth'                : 18,
    'max_leaves'               : 1024,
    'max_features'             : 'auto',
    'n_bins'                   : 128,
    'min_samples_leaf'         : 1,
    'min_samples_split'        : 10,
    'min_impurity_decrease'    : 0.0,
    'accuracy_metric'          : 'mse',
    'max_batch_size'           : 128,
    'random_state'             : 840607124,
    'n_streams'                : 4,
    'output_type'              : 'numpy',
    'verbose'                   : False
}

dataset_params = {
    'n_samples'            : 91457,
    'n_features'           : 331,
    'n_informative'        : 60,
    'n_redundant'          : 0,
    'n_repeated'           : 0,
    'n_classes'            : 2,
    'n_clusters_per_class' : 5,
    'weights'              : None,
    'flip_y'               : 0.1,
    'class_sep'            : 1.0,
    'hypercube'            : True,
    'shift'                : 0.0,
    'scale'                : 1.0,
    'shuffle'              : True,
    'random_state'         : None
}

start = time.time()
for i in range(N_REPS):
    [X, y] = make_classification(**dataset_params)
    X = np.float32(X)
    y = np.int32(y)
    cuml_cls = cumlRFClassifier(**rf_params)
    cuml_cls.fit(X, y)
end = time.time()
print('Time to fit = ', end - start)

# (91457, 492)
# OrderedDict([('output_type', 'numpy'), ('random_state', 620152258), ('verbose', False),
# ('n_estimators', 300), ('n_bins', 128), ('split_criterion', 0), ('max_depth', 20),
# ('max_leaves', 1024), ('max_features', 'auto'), ('min_samples_leaf', 1),
# ('min_samples_split', 2), ('min_impurity_decrease', 0.0)])

dataset_params['n_features'] = 492

rf_params['random_state'] = 620152258
rf_params['split_criterion'] = 0
rf_params['max_depth'] = 20
rf_params['min_samples_split'] = 2

start = time.time()
for i in range(N_REPS):
    [X, y] = make_classification(**dataset_params)
    X = np.float32(X)
    y = np.int32(y)
    cuml_cls = cumlRFClassifier(**rf_params)
    cuml_cls.fit(X, y)
end = time.time()
print('Time to fit = ', end - start)

# (91457, 498)
# OrderedDict([('output_type', 'numpy'), ('random_state', 1037940298), ('verbose', False), 
# ('n_estimators', 300), ('n_bins', 128), ('split_criterion', 1), ('max_depth', 16),
# ('max_leaves', 1024), ('max_features', 'auto'), ('min_samples_leaf', 1),
# ('min_samples_split', 10), ('min_impurity_decrease', 0.0)])

dataset_params['n_features'] = 498

rf_params['random_state'] = 1037940298
rf_params['n_estimators'] = 300
rf_params['split_criterion'] = 1
rf_params['min_samples_split'] = 10

start = time.time()
for i in range(N_REPS):
    [X, y] = make_classification(**dataset_params)
    X = np.float32(X)
    y = np.int32(y)
    cuml_cls = cumlRFClassifier(**rf_params)
    cuml_cls.fit(X, y)
end = time.time()
print('Time to fit = ', end - start)

Can you try running this snippet on the machine where you see the crash? Does any cuML RF code crash for you, or is the crash observed only for certain parameter values?

pseudotensor (Author) commented Jul 13, 2021

OK, I will try to set up a repro. This happens quite often. New ones:

[Tue Jul 13 01:28:05 2021] NVRM: GPU at PCI:0000:81:00: GPU-9274a7fb-220c-e792-2b46-32fabfdd7c42
[Tue Jul 13 01:28:05 2021] NVRM: GPU Board Serial Number: 
[Tue Jul 13 01:28:05 2021] NVRM: Xid (PCI:0000:81:00): 31, pid=8202, Ch 00000008, intr 10000000. MMU Fault: ENGINE GRAPHICS GPCCLIENT_T1_5 faulted @ 0x14a5_4d4e8000. Fault is of type FAULT_PDE ACCESS_TYPE_READ
[Tue Jul 13 02:38:45 2021] NVRM: GPU at PCI:0000:02:00: GPU-51ed028d-9df3-4f93-5808-fe891b6658bc
[Tue Jul 13 02:38:45 2021] NVRM: GPU Board Serial Number: 0321317025701
[Tue Jul 13 02:38:45 2021] NVRM: Xid (PCI:0000:02:00): 31, pid=16108, Ch 00000028, intr 10000000. MMU Fault: ENGINE GRAPHICS GPCCLIENT_T1_5 faulted @ 0x151e_9555e000. Fault is of type FAULT_PDE ACCESS_TYPE_READ

pseudotensor (Author) commented Jul 24, 2021

@vinaydes thanks for your patience. I've been trying out various cuML things and got back to this because it still happens all the time. Here is a repro for me, though it does not fail every time I run it. It might not fail for you at all, since the bug seems to involve occasionally accessing the wrong memory.

import pickle
import random
import numpy as np

model_class, params, X, y = pickle.load(open("sigabrt.pkl", "rb"))

model = model_class(**params)

random.seed(928529388)
np.random.seed(928529388)
model.fit(X, y)

sigabrt.pkl.zip

which gives:

/home/jon/minicondadai_py38/lib/python3.8/site-packages/cuml/internals/api_decorators.py:794: UserWarning: For reproducible results in Random Forest Classifier or for almost reproducible results in Random Forest Regressor, n_streams==1 is recommended. If n_streams is > 1, results may vary due to stream/thread timing differences, even when random_state is set
  return func(**kwargs)
/home/jon/minicondadai_py38/lib/python3.8/site-packages/cuml/internals/api_decorators.py:794: UserWarning: The 'use_experimental_backend' parameter is deprecated and has no effect. It will be removed in 21.10 release.
  return func(**kwargs)
/home/jon/minicondadai_py38/lib/python3.8/site-packages/cuml/internals/api_decorators.py:794: UserWarning: The 'split_algo' parameter is deprecated and has no effect. It will be removed in 21.10 release.
  return func(**kwargs)
CUDA call='cudaFreeHost(p)' at file=/workspace/.conda-bld/cuml_1627014716115/work/python/_external_repositories/raft/cpp/include/raft/mr/host/allocator.hpp line=51 failed with an illegal memory access was encountered
terminate called after throwing an instance of 'raft::cuda_error'
  what():  CUDA error encountered at: file=../src/decisiontree/batched-levelalgo/builder_base.cuh line=400: call='cudaStreamSynchronize(s)', Reason=cudaErrorIllegalAddress:an illegal memory access was encountered
Obtained 10 stack frames
#0 in /home/jon/minicondadai_py38/lib/python3.8/site-packages/cuml/common/../../../../libcuml++.so(_ZN4raft9exception18collect_call_stackEv+0x3b) [0x7f6a8b20c4db]
#1 in /home/jon/minicondadai_py38/lib/python3.8/site-packages/cuml/common/../../../../libcuml++.so(_ZN4raft10cuda_errorC2ERKNSt7__cxx1112basic_stringIcSt11char_traitsIcESaIcEEE+0x5a) [0x7f6a8b20cc6a]
#2 in /home/jon/minicondadai_py38/lib/python3.8/site-packages/cuml/common/../../../../libcuml++.so(_ZN2ML2DT7BuilderINS0_21GiniObjectiveFunctionIfiiEEE7doSplitERSt6vectorINS0_4NodeIfiiEESaIS7_EEP11CUstream_st+0xd09) [0x7f6a8b2ea299]
#3 in /home/jon/minicondadai_py38/lib/python3.8/site-packages/cuml/common/../../../../libcuml++.so(_ZN2ML2DT9grow_treeINS0_21GiniObjectiveFunctionIfiiEEfiiEEvSt10shared_ptrIN4raft2mr6device9allocatorEES4_INS6_4host9allocatorEEPKT0_T2_mSG_SG_PKT1_SF_PSG_iiRKNS0_18DecisionTreeParamsEP11CUstream_stRSt6vectorI14SparseTreeNodeISD_SH_iESaISS_EERSG_SW_+0x437) [0x7f6a8b2eac47]
#4 in /home/jon/minicondadai_py38/lib/python3.8/site-packages/cuml/common/../../../../libcuml++.so(_ZN2ML2DT9grow_treeIfiiEEvSt10shared_ptrIN4raft2mr6device9allocatorEES2_INS4_4host9allocatorEEPKT_T1_mSE_SE_PKT0_SD_PSE_iiRKNS0_18DecisionTreeParamsEP11CUstream_stRSt6vectorI14SparseTreeNodeISB_SF_iESaISQ_EERSE_SU_+0x238) [0x7f6a8b2ef2b8]
#5 in /home/jon/minicondadai_py38/lib/python3.8/site-packages/cuml/common/../../../../libcuml++.so(_ZN2ML2DT12DecisionTreeIfiE3fitERKN4raft8handle_tEPKfiiPKiPjiiRPNS0_16TreeMetaDataNodeIfiEENS0_18DecisionTreeParamsEmPf+0x3a8) [0x7f6a8b2ef888]
#6 in /home/jon/minicondadai_py38/lib/python3.8/site-packages/cuml/common/../../../../libcuml++.so(+0x664ad1) [0x7f6a8b643ad1]
#7 in /home/jon/minicondadai_py38/lib/python3.8/site-packages/numpy/core/../../../.././libgomp.so.1(+0x146c9) [0x7f6b73a926c9]
#8 in /lib/x86_64-linux-gnu/libpthread.so.0(+0x76db) [0x7f6b752216db]
#9 in /lib/x86_64-linux-gnu/libc.so.6(clone+0x3f) [0x7f6b74f4aa3f]

Aborted (core dumped)

This is not a rare situation. As I mentioned, when changing various data and model parameters, this happens within about 100 fits. It's a fatal crash, so it's not possible to rely on the algorithm until this is fixed.

FYI, this is on the latest nightly as of two nights ago.

(base) jon@pseudotensor:~/h2oai$ conda list | grep 'cu\|nvidia\|rapids'
arrow-cpp                 4.0.1           py38hf0991f3_4_cuda    conda-forge
arrow-cpp-proc            3.0.0                      cuda    conda-forge
cudatoolkit               11.2.72              h2bc3f7f_0    nvidia
cudf                      21.08.00a210723 cuda_11.2_py38_gfc95992e4a_323    rapidsai-nightly
cudf_kafka                21.08.00a210723 py38_gfc95992e4a_323    rapidsai-nightly
cudnn                     8.1.0.77             h90431f1_0    conda-forge
cugraph                   21.08.00a210722 cuda11.2_py38_ge5b35997_89    rapidsai-nightly
cuml                      21.08.00a210723 cuda11.2_py38_g40af8af46_125    rapidsai-nightly
cupy                      9.0.0            py38ha69542f_0    conda-forge
cupy-cuda112              9.0.0                    pypi_0    pypi
cusignal                  21.08.00a210723 py37_gb197d6f_24    rapidsai-nightly
cuspatial                 21.08.00a210722 py38_g2344dcd_24    rapidsai-nightly
custreamz                 21.08.00a210723 py38_gfc95992e4a_323    rapidsai-nightly
cutensor                  1.2.2.5              h96e36e3_3    conda-forge
cuxfilter                 21.08.00a210722 py38_gc51a660_21    rapidsai-nightly
dask-cuda                 21.08.00a210722         py38_37    rapidsai-nightly
dask-cudf                 21.08.00a210723 py38_gfc95992e4a_323    rapidsai-nightly
faiss-proc                1.0.0                      cuda    rapidsai
libcudf                   21.08.00a210723 cuda11.2_gfc95992e4a_323    rapidsai-nightly
libcudf_kafka             21.08.00a210723 gfc95992e4a_323    rapidsai-nightly
libcugraph                21.08.00a210722 cuda11.2_ge5b35997_89    rapidsai-nightly
libcuml                   21.08.00a210723 cuda11.2_g40af8af46_125    rapidsai-nightly
libcumlprims              21.08.00a210715 cuda11.2_g4db0971_5    rapidsai-nightly
libcurl                   7.77.0               h2574ce0_0    conda-forge
libcuspatial              21.08.00a210722 cuda11.2_g2344dcd_24    rapidsai-nightly
libfaiss                  1.7.0           cuda112h5bea7ad_8_cuda    conda-forge
librmm                    21.08.00a210721 cuda11.2_g82fe22f_39    rapidsai-nightly
pyarrow                   4.0.1           py38hb53058b_4_cuda    conda-forge
rapids                    21.08.00a210702 cuda11.2_py38_g2d7ee9d_10    rapidsai-nightly
rapids-blazing            21.08.00a210723 cuda11.2_py38_g8ec899a_47    rapidsai-nightly
rmm                       21.08.00a210722 cuda_11.2_py38_g4d52e5c_40    rapidsai-nightly
torch                     1.9.0+cu111              pypi_0    pypi
torchvision               0.10.0+cu111             pypi_0    pypi
ucx                       1.9.0+gcd9efd3       cuda11.2_0    rapidsai
ucx-proc                  1.0.0                       gpu    rapidsai
ucx-py                    0.21.0a210722   py38_gcd9efd3_34    rapidsai-nightly

The relevant parts of the environment came from the following install commands:

	conda install -c pytorch --override-channels -c nvidia -c rapidsai -c rapidsai-nightly -c numba -c pytorch -c conda-forge cudf=21.08 cuml=21.08 custreamz=21.08 python=3.8 nomkl cryptography==3.4.7 cudatoolkit=11.2 rapids-blazing=21.08 cusignal=21.08 pyarrow cython==0.29.21 matplotlib scikit-learn pynvml pandas==1.2.4 numpy==1.19.5 fastavro fsspec dask-cuda dask distributed cugraph=21.08 networkx scipy pip=20.2.3 -y
	conda install -c pytorch -c rapidsai -c rapidsai-nightly -c nvidia -c conda-forge -c anaconda -c defaults plotly=5.1.0 plotly-orca=3.4.2 -y

pseudotensor (Author) commented:

Here's another one that doesn't crash every time, but does most of the time:

import pickle
import random
import numpy as np

model_class, params, X, y = pickle.load(open("sigabrt2.pkl", "rb"))

model = model_class(**params)

random.seed(928529388)
np.random.seed(928529388)
model.fit(X, y)

https://0xdata-public.s3.amazonaws.com/jon/sigabrt2.pkl.zip

/home/jon/minicondadai_py38/lib/python3.8/site-packages/cuml/internals/api_decorators.py:794: UserWarning: For reproducible results in Random Forest Classifier or for almost reproducible results in Random Forest Regressor, n_streams==1 is recommended. If n_streams is > 1, results may vary due to stream/thread timing differences, even when random_state is set
  return func(**kwargs)
/home/jon/minicondadai_py38/lib/python3.8/site-packages/cuml/internals/api_decorators.py:794: UserWarning: The 'use_experimental_backend' parameter is deprecated and has no effect. It will be removed in 21.10 release.
  return func(**kwargs)
/home/jon/minicondadai_py38/lib/python3.8/site-packages/cuml/internals/api_decorators.py:794: UserWarning: The 'split_algo' parameter is deprecated and has no effect. It will be removed in 21.10 release.
  return func(**kwargs)
(base) jon@pseudotensor:~/h2oai$ LD_LIBRARY_PATH=~/minicondadai_py38/lib python cuml_sigabrt.py
/home/jon/minicondadai_py38/lib/python3.8/site-packages/cuml/internals/api_decorators.py:794: UserWarning: For reproducible results in Random Forest Classifier or for almost reproducible results in Random Forest Regressor, n_streams==1 is recommended. If n_streams is > 1, results may vary due to stream/thread timing differences, even when random_state is set
  return func(**kwargs)
/home/jon/minicondadai_py38/lib/python3.8/site-packages/cuml/internals/api_decorators.py:794: UserWarning: The 'use_experimental_backend' parameter is deprecated and has no effect. It will be removed in 21.10 release.
  return func(**kwargs)
/home/jon/minicondadai_py38/lib/python3.8/site-packages/cuml/internals/api_decorators.py:794: UserWarning: The 'split_algo' parameter is deprecated and has no effect. It will be removed in 21.10 release.
  return func(**kwargs)
terminate called after throwing an instance of 'raft::cuda_error'
  what():  CUDA error encountered at: file=_deps/raft-src/cpp/include/raft/mr/buffer_base.hpp line=68: call='cudaStreamSynchronize(stream_)', Reason=cudaErrorIllegalAddress:an illegal memory access was encountered
Obtained 9 stack frames
#0 in /home/jon/minicondadai_py38/lib/python3.8/site-packages/cuml/common/../../../../libcuml++.so(_ZN4raft9exception18collect_call_stackEv+0x3b) [0x7f6a1c9124db]
#1 in /home/jon/minicondadai_py38/lib/python3.8/site-packages/cuml/common/../../../../libcuml++.so(_ZN4raft10cuda_errorC2ERKNSt7__cxx1112basic_stringIcSt11char_traitsIcESaIcEEE+0x5a) [0x7f6a1c912c6a]
#2 in /home/jon/minicondadai_py38/lib/python3.8/site-packages/cuml/common/../../../../libcuml++.so(_ZN2ML2DT9grow_treeINS0_24EntropyObjectiveFunctionIfiiEEfiiEEvSt10shared_ptrIN4raft2mr6device9allocatorEES4_INS6_4host9allocatorEEPKT0_T2_mSG_SG_PKT1_SF_PSG_iiRKNS0_18DecisionTreeParamsEP11CUstream_stRSt6vectorI14SparseTreeNodeISD_SH_iESaISS_EERSG_SW_+0xa91) [0x7f6a1c9f3101]
#3 in /home/jon/minicondadai_py38/lib/python3.8/site-packages/cuml/common/../../../../libcuml++.so(_ZN2ML2DT9grow_treeIfiiEEvSt10shared_ptrIN4raft2mr6device9allocatorEES2_INS4_4host9allocatorEEPKT_T1_mSE_SE_PKT0_SD_PSE_iiRKNS0_18DecisionTreeParamsEP11CUstream_stRSt6vectorI14SparseTreeNodeISB_SF_iESaISQ_EERSE_SU_+0x3a4) [0x7f6a1c9f5424]
#4 in /home/jon/minicondadai_py38/lib/python3.8/site-packages/cuml/common/../../../../libcuml++.so(_ZN2ML2DT12DecisionTreeIfiE3fitERKN4raft8handle_tEPKfiiPKiPjiiRPNS0_16TreeMetaDataNodeIfiEENS0_18DecisionTreeParamsEmPf+0x3a8) [0x7f6a1c9f5888]
#5 in /home/jon/minicondadai_py38/lib/python3.8/site-packages/cuml/common/../../../../libcuml++.so(+0x664ad1) [0x7f6a1cd49ad1]
#6 in /home/jon/minicondadai_py38/lib/python3.8/site-packages/numpy/core/../../../.././libgomp.so.1(+0x146c9) [0x7f6b0d8196c9]
#7 in /lib/x86_64-linux-gnu/libpthread.so.0(+0x76db) [0x7f6b0efa86db]
#8 in /lib/x86_64-linux-gnu/libc.so.6(clone+0x3f) [0x7f6b0ecd1a3f]

Aborted (core dumped)

pseudotensor (Author) commented Jul 25, 2021

I tried changing various parameters back to their defaults, but none mattered until I reached

params['max_leaves'] = -1

So this:

import pickle
import random
import sys

import numpy as np

model_class, params, X, y = pickle.load(open("sigabrt2.pkl", "rb"))

print(params)
for i in range(0, 100):
    params['max_leaves'] = -1
    model = model_class(**params)

    random.seed(928529388)
    np.random.seed(928529388)
    model.fit(X, y)
    print("done %s" % i)
    sys.stdout.flush()

does not crash. But that just means unlimited leaves. However, a max_depth of 18 allows up to 2^18 = 262144 leaves, so limiting the leaves to 1024 shouldn't violate some condition that happens to be unprotected.

And going back to the documented default of params['max_batch_size'] = 128:

import pickle
import random
import sys

import numpy as np

model_class, params, X, y = pickle.load(open("sigabrt2.pkl", "rb"))

print(params)
for i in range(0, 100):
    params['max_batch_size'] = 128
    model = model_class(**params)

    random.seed(928529388)
    np.random.seed(928529388)
    model.fit(X, y)
    print("done %s" % i)
    sys.stdout.flush()

doesn't crash.

But this still crashes, suggesting the experimental backend setting doesn't matter / isn't the cause:

import pickle
import random
import sys

import numpy as np

model_class, params, X, y = pickle.load(open("sigabrt2.pkl", "rb"))

print(params)
for i in range(0, 100):
    params['use_experimental_backend'] = False
    model = model_class(**params)

    random.seed(928529388)
    np.random.seed(928529388)
    model.fit(X, y)
    print("done %s" % i)
    sys.stdout.flush()

However, this does not crash:

import pickle
import random
import sys

import numpy as np

model_class, params, X, y = pickle.load(open("sigabrt2.pkl", "rb"))

print(params)
for i in range(0, 100):
    params['use_experimental_backend'] = False
    params['max_batch_size'] = 128
    model = model_class(**params)

    random.seed(928529388)
    np.random.seed(928529388)
    model.fit(X, y)
    print("done %s" % i)
    sys.stdout.flush()

So it seems max_batch_size matters even when use_experimental_backend = False, which makes no sense given that the documentation says:

This is used only when ‘use_experimental_backend’ is true. 

And this still crashes:

import pickle
import random
import sys

import numpy as np

model_class, params, X, y = pickle.load(open("sigabrt2.pkl", "rb"))

print(params)
for i in range(0, 100):
    params['n_streams'] = 1
    print(params)
    sys.stdout.flush()
    model = model_class(**params)

    random.seed(928529388)
    np.random.seed(928529388)
    model.fit(X, y)
    print("done %s" % i)
    sys.stdout.flush()

so it seems unrelated to streams.

pseudotensor (Author) commented Jul 25, 2021

Another odd thing: I got those parameters by calling get_params() on the model that crashed on me. max_batch_size is always 4096, but the documentation says the default is 128, so something is wrong in either the docs or the code.

sigabrt.pkl is the same, with max_batch_size automatically set to 4096 according to get_params(), even though the default is supposed to be 128.

vinaydes (Contributor) commented:

Thanks @pseudotensor. I'll run your samples and see if I can reproduce the issue.

> sigabrt.pkl is the same, with max_batch_size automatically set to 4096 according to get_params(), even though the default is supposed to be 128.

The default batch size was recently changed from 128 to 4096. You are using nightly cuML but probably looking at the stable-release documentation. The nightly documentation reflects the correct default value of 4096: https://docs.rapids.ai/api/cuml/nightly/api.html#random-forest.
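As a quick sanity check, the effective default can be read straight off an instance via get_params(), as pseudotensor did above (a minimal sketch; on the 21.08 nightlies discussed here this is expected to print 4096):

# Sketch: confirm which max_batch_size default the installed cuML actually uses.
from cuml.ensemble import RandomForestClassifier

rf = RandomForestClassifier()
print(rf.get_params()['max_batch_size'])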

pseudotensor (Author) commented:

@vinaydes OK, but with 4096 I get all of the above SIGABRTs. Only when using 128 do I happen not to hit them, so it seems the new default exposes major issues.

I gave many reproducible examples.

vinaydes (Contributor) commented:

Got it, I am investigating the issue.

vinaydes (Contributor) commented:

Thanks to @venkywonka, we seem to have found the cause of this issue. For now, to unblock you, I would suggest leaving the max_leaves parameter at its default value of -1; specifying a non-default value seems to be what triggers the issue. I'll update this bug when we have a proper fix.

@RAMitchell This issue is related to how we are updating n_leaves. Basically, the value of n_leaves can be different for different threads of the same threadblock in the nodeSplitKernel. I have a couple of ways to fix the issue; however, since you are planning to update the node queue anyway, I was wondering whether that would also take care of this issue.

vinaydes (Contributor) commented:

@pseudotensor Thanks for your patience. PR #4126 should fix this issue. Please test it in your setup once the PR is merged.

pseudotensor (Author) commented Jul 29, 2021

Great, thanks! Will this be part of the 21.08 release, or is it too late for that? This seems like a critical bug whose fix needs to be in 21.08.

vinaydes (Contributor) commented:

Yes, it should be part of 21.08.

rapids-bot bot pushed a commit that referenced this issue Jul 29, 2021
Fixes issue #4046.
In the `nodeSplitKernel`, each thread calls `leafBasedOnParams()`, which reads the global variable `n_leaves`. Different threads of the same threadblock read `n_leaves` at different times, and between two such reads its value can be changed by another threadblock. Thus one or a few threads might conclude that `max_leaves` has been reached while the rest conclude otherwise. This caused a crash when partitioning the samples.

In the solution provided here, instead of every thread reading `n_leaves`, only one thread per threadblock reads the value and broadcasts it to the other threads via shared memory. This ensures complete agreement on the `max_leaves` criterion among the threads of a threadblock.

Performance results to be posted shortly.

Authors:
  - Vinay Deshpande (https://github.com/vinaydes)

Approvers:
  - Venkat (https://github.com/venkywonka)
  - Rory Mitchell (https://github.com/RAMitchell)
  - Thejaswi. N. S (https://github.com/teju85)

URL: #4126
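To make the broadcast pattern from the commit above concrete, here is a minimal Numba CUDA sketch (illustrative only, not the actual cuML C++ kernel; Numba, and the names buggy_pattern, fixed_pattern, leaf_counter, decisions, and MAX_LEAVES are stand-ins for n_leaves, max_leaves, and the per-thread split decision):

# Illustrative sketch in Numba CUDA, not cuML code. leaf_counter plays the
# role of n_leaves; MAX_LEAVES plays the role of the max_leaves parameter.
import numpy as np
from numba import cuda, int32

MAX_LEAVES = 1024

@cuda.jit
def buggy_pattern(leaf_counter, decisions):
    # Each thread reads the global counter itself. Another block may bump the
    # counter between these reads, so threads of the SAME block can disagree
    # on whether MAX_LEAVES has been reached -- the hazard described above.
    tid = cuda.grid(1)
    if tid < decisions.shape[0]:
        if leaf_counter[0] < MAX_LEAVES:
            decisions[tid] = 1
        else:
            decisions[tid] = 0

@cuda.jit
def fixed_pattern(leaf_counter, decisions):
    # One thread per block reads the counter and broadcasts it through shared
    # memory; after the barrier every thread in the block sees the same value,
    # so the whole block agrees on the max_leaves criterion.
    snapshot = cuda.shared.array(1, int32)
    if cuda.threadIdx.x == 0:
        snapshot[0] = leaf_counter[0]
    cuda.syncthreads()
    tid = cuda.grid(1)
    if tid < decisions.shape[0]:
        if snapshot[0] < MAX_LEAVES:
            decisions[tid] = 1
        else:
            decisions[tid] = 0

leaf_counter = cuda.to_device(np.array([1023], dtype=np.int32))
decisions = cuda.device_array(256, dtype=np.int32)
fixed_pattern[1, 256](leaf_counter, decisions)
print(decisions.copy_to_host()[:4])
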
dantegd (Member) commented Jul 29, 2021

Confirming that the fix was just in time for 21.08 and was just merged. Thanks!!!

github-actions bot commented:

This issue has been labeled inactive-90d due to no recent activity in the past 90 days. Please close this issue if no further response or action is needed. Otherwise, please respond with a comment indicating any updates or changes to the original issue and/or confirm this issue still needs to be addressed.

vinaydes (Contributor) commented Feb 8, 2023

This was closed a long time ago: #4046 (comment)

vinaydes closed this as completed Feb 8, 2023
vimarsh6739 pushed a commit to vimarsh6739/cuml that referenced this issue Oct 9, 2023
(rapidsai#4126; the commit message is the same as the one quoted above)