model can't be reproduced when data is big with tree_method = "gpu_hist" ? #3921

Closed
joegaotao opened this issue Nov 19, 2018 · 15 comments

@joegaotao

joegaotao commented Nov 19, 2018

I did some tests in R (xgboost-0.81.0.1). When the data size N is big, models trained with the same parameters on GPU (tree_method = "gpu_hist") are not identical; when N is relatively small, they are. But when I use tree_method = "hist" to train the model repeatedly on CPU, all the results are the same. I don't know what happens during GPU training; could it be due to precision?

GPU test code, big data:

library(xgboost)
# Simulate an N x p random matrix with a continuous response depending on all p columns
set.seed(111)
N <- 800000
p <- 100
X <- matrix(runif(N * p), ncol = p)
beta <- runif(p)
y <- X %*% beta + rnorm(N, mean = 0, sd  = 0.1)

tr <- sample.int(N, N * 0.75)

param <- list(nrounds = 10, num_parallel_tree = 1, nthread = 1L, eta = 0.3, max_depth = 30,
  seed = 2018, colsample_bytree = 0.4, subsample = 0.6,  min_child_weight = 1000,
  tree_method = 'gpu_hist', grow_policy = "lossguide", max_leaves = 1e4,  max_bin = 256,
  n_gpus = 1, gpu_id = 3, verbose = FALSE)
param$data <- X[tr,]
param$label <- y[tr]

set.seed(2019)
bst_gpu1 <- do.call(xgboost::xgboost, param)
test_pred1 <- predict(bst_gpu1, newdata = X)

set.seed(2019)
bst_gpu2 <- do.call(xgboost::xgboost, param)
test_pred2 <- predict(bst_gpu2, newdata = X)

set.seed(2019)
bst_gpu3 <- do.call(xgboost::xgboost, param)
test_pred3 <- predict(bst_gpu3, newdata = X)

set.seed(2019)
bst_gpu4 <- do.call(xgboost::xgboost, param)
test_pred4 <- predict(bst_gpu4, newdata = X)

set.seed(2019)
bst_gpu5 <- do.call(xgboost::xgboost, param)
test_pred5 <- predict(bst_gpu5, newdata = X)

all_pred <- cbind(test_pred1, test_pred2, test_pred3, test_pred4, test_pred5)
head(all_pred)
#       test_pred1 test_pred2 test_pred3 test_pred4 test_pred5
# [1,]   22.43434   22.65794   22.46917   22.60526   22.43433
# [2,]   24.28225   24.42978   24.34619   24.60111   24.28225
# [3,]   23.11788   23.15692   23.07406   23.22111   23.11788
# [4,]   23.74367   23.92602   24.26277   24.11207   23.74367
# [5,]   22.97502   23.24378   23.25752   22.92594   22.97502
# [6,]   23.34638   23.52209   23.47491   23.71274   23.34638

summary(test_pred1 - test_pred2)
#      Min.    1st Qu.     Median       Mean    3rd Qu.       Max. 
# -1.2855778 -0.1688867 -0.0002308 -0.0002195  0.1685147  1.3110085
summary(test_pred1 - test_pred3)
#      Min.    1st Qu.     Median       Mean    3rd Qu.       Max. 
# -1.3292294 -0.1703205  0.0000973 -0.0001312  0.1701469  1.3229237 

The differences are large, but if I change N to 80000 or replace tree_method = "gpu_hist" with tree_method = "hist", the results are identical.
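
For reference, a minimal sketch of the CPU control run described above (same simulated data and param list, only tree_method changed); the names param_cpu, bst_cpu1 and bst_cpu2 are illustrative and not part of the original report:

# CPU control: identical settings except tree_method = "hist".
# The report above says repeated CPU runs produce identical models.
param_cpu <- param
param_cpu$tree_method <- "hist"
param_cpu$n_gpus <- NULL   # drop GPU-specific settings
param_cpu$gpu_id <- NULL

set.seed(2019)
bst_cpu1 <- do.call(xgboost::xgboost, param_cpu)
set.seed(2019)
bst_cpu2 <- do.call(xgboost::xgboost, param_cpu)

# expected: all differences are exactly zero on CPU
summary(predict(bst_cpu1, newdata = X) - predict(bst_cpu2, newdata = X))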

@trivialfis
Member

@hcho3 I'm not sure whether the seed has any effect on ColumnSampler; a simple grep over the C++ source code doesn't turn up many mentions of seed or random_state.

@hcho3
Collaborator

hcho3 commented Nov 19, 2018

@trivialfis This line sets the seed globally:

common::GlobalRandom().seed(tparam_.seed);

@trivialfis
Member

@hcho3 Thanks!

@joegaotao
Author

joegaotao commented Nov 20, 2018

@trivialfis I think it might not be related to the random seed, because the sample size affects the result. Testing with cuda-9.2 and cuda-9.1 on the latest master branch, the problem remains on cuda-9.2, while the results are the same on cuda-9.1. I also tested another real dataset; the differences still exist, but they are smaller on cuda-9.1, which seems more robust. Maybe it is related to the CUDA version? Could you reproduce my problem?

@trivialfis
Member

@joegaotao I'm running CUDA 9.2 and I can reproduce your issue in R, but not in Python, which makes things even weirder. I will try to instrument the CUDA code when I have time.

@hcho3
Collaborator

hcho3 commented Nov 28, 2018

@trivialfis I wonder if it has to do with the fact that XGBoost-R uses its own random generator. See #3781.

@joegaotao
Author

@hcho3 @trivialfis Yes, the seed parameter has no effect in XGBoost-R, so I have to call set.seed() globally before each run.
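
For illustration, a minimal sketch of that workaround using the param list from the test above (run_a and run_b are illustrative names, not from the original post):

# Workaround reported above: the seed entry in the parameter list has no
# effect in XGBoost-R, so re-seed R's own RNG immediately before every call.
set.seed(2019)
run_a <- do.call(xgboost::xgboost, param)

set.seed(2019)   # re-seed before each run, not just once
run_b <- do.call(xgboost::xgboost, param)

# with "hist" on CPU these match exactly; with "gpu_hist" they may still
# differ, which is the subject of this issue
identical(predict(run_a, newdata = X), predict(run_b, newdata = X))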

@trivialfis
Member

Reopening. We will look into this eventually.

trivialfis reopened this Dec 24, 2018
@thanish

thanish commented Jan 10, 2019

I have been running into the same issue of non-reproducible results with Python XGBoost when I use 'tree_method': 'gpu_hist'. I set np.random.seed() just before running xgb.cv and also passed 'seed' in the parameters to xgb.cv, but I can't get the same result when re-running the model with the same data and the same parameters.

@trivialfis
Member

@thanish Interesting, could you post the relevant section of your script and describe the data shape and sparsity?

@thanish

thanish commented Jan 10, 2019

My data has shape (1600000, 26).

Below is my code

import numpy as np
import xgboost as xgb

dtrain_prod = xgb.DMatrix(data = train_prod_df[indep], label = train_prod_df[dep])
dtest_prod = xgb.DMatrix(data = test_prod_df[indep])
num_rounds = 20000

params = {'objective' : 'reg:linear',
          'max_depth' : 8,
          'eta' : 0.05,
          'subsample': 1,
          'colsample_bytree': 1
          ,'tree_method' : 'gpu_hist'
          }

folds = 3
np.random.seed(100)
xgb_model_cv = xgb.cv(params,
                      dtrain_prod, 
                      num_rounds,
                      nfold = folds , 
                      verbose_eval = True,
                      early_stopping_rounds = 20,
                      seed = 100)

print("")
best_eval = xgb_model_cv['test-rmse-mean'].min()
best_round = xgb_model_cv.loc[xgb_model_cv['test-rmse-mean'] == xgb_model_cv['test-rmse-mean'].min(),].index
best_round = list(best_round)[-1]
print("The best round", best_round)
print("The best eval", best_eval)

Output of the first run

[10658]	train-rmse:1.47646e+07+22322.8	test-rmse:3.00938e+07+522885
[10659]	train-rmse:1.47643e+07+22461.4	test-rmse:3.00938e+07+522889
[10660]	train-rmse:1.47639e+07+22637.1	test-rmse:3.00937e+07+522870
[10661]	train-rmse:1.47635e+07+22674.8	test-rmse:3.00937e+07+522928
[10662]	train-rmse:1.47632e+07+22582.7	test-rmse:3.00937e+07+522883
[10663]	train-rmse:1.47628e+07+22617.4	test-rmse:3.00937e+07+522892
[10664]	train-rmse:1.47625e+07+22571.1	test-rmse:3.00937e+07+522881

The best round 10645
The best eval 30093620.0

output of the 2nd run without any changes

[10487]	train-rmse:1.48827e+07+20892.9	test-rmse:3.00699e+07+527809
[10488]	train-rmse:1.48822e+07+20810.9	test-rmse:3.00699e+07+527743
[10489]	train-rmse:1.48814e+07+20773.5	test-rmse:3.00699e+07+527754
[10490]	train-rmse:1.48806e+07+20589.1	test-rmse:3.00699e+07+527657
[10491]	train-rmse:1.48798e+07+20357.8	test-rmse:3.00699e+07+527707
[10492]	train-rmse:1.48793e+07+20262.4	test-rmse:3.00698e+07+527589
[10493]	train-rmse:1.48789e+07+20238.8	test-rmse:3.00699e+07+527621
[10494]	train-rmse:1.48784e+07+20330.2	test-rmse:3.00699e+07+527594
[10495]	train-rmse:1.4878e+07+20380.1	test-rmse:3.00699e+07+527537
[10496]	train-rmse:1.48778e+07+20342.7	test-rmse:3.00699e+07+527482

The best round 10477
The best eval 30069673.333333332

@trivialfis
Member

@RAMitchell I tried printing the inputs and outputs of EvaluateSplitKernel and used diff to compare the printed values; it seems at least some of the error comes from this kernel. But I can't tell what is going on inside the kernel with this mass of data.

@mirekphd

mirekphd commented Jul 2, 2019

> @trivialfis Maybe related to cuda version? Could you reproduce my problem?

It's not related to the CUDA version. We reproduced your problem* on yet another, even older CUDA version, 9.0.176.

The positive link between the bug's occurrence and data size does seem to be there, but we haven't tested it thoroughly.


*i.e. XGBoost GPU (histogram) predictions being unstable between program runs, despite fixed seeds and unchanged predictions on the CPU

@NamLQ

NamLQ commented Sep 11, 2019

Using single_precision_histogram = F will give you reproducible results.
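
For illustration, a minimal sketch of that suggestion applied to the param list from the original report (single_precision_histogram is a gpu_hist-specific parameter; whether it is available depends on the installed XGBoost version):

# Suggestion above: build the GPU histograms in double precision.
param$single_precision_histogram <- FALSE

set.seed(2019)
bst_dp1 <- do.call(xgboost::xgboost, param)
set.seed(2019)
bst_dp2 <- do.call(xgboost::xgboost, param)

# expected, per the comment above: repeated runs now agree
summary(predict(bst_dp1, newdata = X) - predict(bst_dp2, newdata = X))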

@trivialfis
Member

Closing in favour of #5023

lock bot locked as resolved and limited conversation to collaborators May 5, 2020