[R] Add evaluation set and early stopping for xgboost()
#11065
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -22,6 +22,10 @@ prescreen.parameters <- function(params) { | |
|
||
prescreen.objective <- function(objective) { | ||
if (!is.null(objective)) { | ||
if (!is.character(objective) || length(objective) != 1L || is.na(objective)) { | ||
stop("'objective' must be a single character/string variable.") | ||
} | ||
|
||
if (objective %in% .OBJECTIVES_NON_DEFAULT_MODE()) { | ||
stop( | ||
"Objectives with non-default prediction mode (", | ||
|
@@ -30,8 +34,8 @@ prescreen.objective <- function(objective) { | |
) | ||
} | ||
|
||
if (!is.character(objective) || length(objective) != 1L || is.na(objective)) { | ||
stop("'objective' must be a single character/string variable.") | ||
if (objective %in% .RANKING_OBJECTIVES()) { | ||
stop("Ranking objectives are not supported in 'xgboost()'. Try 'xgb.train()'.") | ||
} | ||
} | ||
} | ||
|
@@ -501,7 +505,7 @@ check.nthreads <- function(nthreads) { | |
return(as.integer(nthreads)) | ||
} | ||
|
||
check.can.use.qdm <- function(x, params) { | ||
check.can.use.qdm <- function(x, params, eval_set) { | ||
if ("booster" %in% names(params)) { | ||
if (params$booster == "gblinear") { | ||
return(FALSE) | ||
|
@@ -512,6 +516,9 @@ check.can.use.qdm <- function(x, params) { | |
return(FALSE) | ||
} | ||
} | ||
if (NROW(eval_set)) { | ||
Review comment: What does this imply? If [...]
Review comment: Yes, because it then slices the DMatrix that gets created.
Review comment: Does the slicing need to happen after the DMatrix is created?
Review comment: Yes, because otherwise there'd be issues with things like needing to make sure categorical 'y' and features have the same encodings between the two sets, objects from package [...]
Review comment: Hmm, so from your perspective the DMatrix is more suitable for slicing than built-in classes...
Review comment: We are planning to work on CV with shared quantile cuts for improved computation performance (sharing the quantiles between QDM folds). It's a minor information leak, but it can significantly increase performance, especially with external memory. As a result, I have to consider how this can be implemented. If we double down on DMatrix slicing, it will prevent us from applying that optimization. It's very unlikely that we can slice an external-memory DMatrix. Also, the slice method in XGBoost is quite slow and memory-inefficient. I can merge this PR as is, but I think we might have more trouble when applying the optimization for CV.
Review comment: Well, the alternative would be to: [...] But it'd end up being inefficient either way, all the more so considering that on the R side the slicing would happen with a vector of random indices on one of the following: [...] It could in theory be more efficient to do the slicing in R for base [...]
Review comment: Any suggestion for the future implementation of the CV optimization mentioned previously? It's designed for the QDM and the external-memory version of the QDM.
Review comment: Sounds good and quite helpful for the CV function indeed. But I don't think it's very relevant here, since unlike [...] Hence, it doesn't need to consider special cases like external memory or distributed mode, and there isn't too much room for improvement in terms of speed savings.
Review comment: Makes sense. We will have a second CV function in the future for high-level inputs (like data.table, iterators, etc.).
||
return(FALSE) | ||
} | ||
return(TRUE) | ||
} | ||
|
||
|
@@ -717,6 +724,129 @@ process.x.and.col.args <- function( | |
return(lst_args) | ||
} | ||
|
||
process.eval.set <- function(eval_set, lst_args) { | ||
if (!NROW(eval_set)) { | ||
return(NULL) | ||
} | ||
nrows <- nrow(lst_args$dmatrix_args$data) | ||
is_classif <- hasName(lst_args$metadata, "y_levels") | ||
processed_y <- lst_args$dmatrix_args$label | ||
eval_set <- as.vector(eval_set) | ||
if (length(eval_set) == 1L) { | ||
|
||
eval_set <- as.numeric(eval_set) | ||
if (is.na(eval_set) || eval_set < 0 || eval_set >= 1) { | ||
stop("'eval_set' as a fraction must be a number between zero and one (non-inclusive).") | ||
} | ||
if (eval_set == 0) { | ||
return(NULL) | ||
} | ||
nrow_eval <- as.integer(round(nrows * eval_set, 0)) | ||
if (nrow_eval < 1) { | ||
warning( | ||
"Desired 'eval_set' fraction amounts to zero observations.", | ||
" Will not create evaluation set." | ||
) | ||
return(NULL) | ||
} | ||
nrow_train <- nrows - nrow_eval | ||
if (nrow_train < 2L) { | ||
stop("Desired 'eval_set' fraction would leave less than 2 observations for training data.") | ||
} | ||
if (is_classif && nrow_train < length(lst_args$metadata$y_levels)) { | ||
stop("Desired 'eval_set' fraction would not leave enough samples for each class of 'y'.") | ||
} | ||
|
||
seed <- lst_args$params$seed | ||
if (!is.null(seed)) { | ||
set.seed(seed) | ||
} | ||
|
||
idx_shuffled <- sample(nrows, nrows, replace = FALSE) | ||
idx_eval <- idx_shuffled[seq(1L, nrow_eval)] | ||
idx_train <- idx_shuffled[seq(nrow_eval + 1L, nrows)] | ||
# Here we want the training set to include all of the classes of 'y' for classification | ||
# objectives. If that condition doesn't hold with the random sample, then it forcibly | ||
# makes a new random selection in such a way that the condition would always hold, by | ||
# first sampling one random example of 'y' for training and then choosing the evaluation | ||
# set from the remaining rows. The procedure here is quite inefficient, but there aren't | ||
# enough random-related functions in base R to be able to construct an efficient version. | ||
if (is_classif && length(unique(processed_y[idx_train])) < length(lst_args$metadata$y_levels)) { | ||
# These are defined in order to avoid NOTEs from CRAN checks | ||
# when using non-standard data.table evaluation with column names. | ||
idx <- NULL | ||
y <- NULL | ||
ranked_idx <- NULL | ||
chosen <- NULL | ||
|
||
dt <- data.table::data.table(y = processed_y, idx = seq(1L, nrows))[ | ||
, .( | ||
ranked_idx = seq(1L, .N), | ||
chosen = rep(sample(.N, 1L), .N), | ||
idx | ||
) | ||
, by = y | ||
] | ||
min_idx_train <- dt[ranked_idx == chosen, idx] | ||
rem_idx <- dt[ranked_idx != chosen, idx] | ||
if (length(rem_idx) == nrow_eval) { | ||
idx_train <- min_idx_train | ||
idx_eval <- rem_idx | ||
} else { | ||
rem_idx <- rem_idx[sample(length(rem_idx), length(rem_idx), replace = FALSE)] | ||
idx_eval <- rem_idx[seq(1L, nrow_eval)] | ||
idx_train <- c(min_idx_train, rem_idx[seq(nrow_eval + 1L, length(rem_idx))]) | ||
} | ||
} | ||
|
||
} else { | ||
|
||
if (any(eval_set != floor(eval_set))) { | ||
stop("'eval_set' as indices must contain only integers.") | ||
} | ||
eval_set <- as.integer(eval_set) | ||
idx_min <- min(eval_set) | ||
if (is.na(idx_min) || idx_min < 1L) { | ||
stop("'eval_set' contains invalid indices.") | ||
} | ||
idx_max <- max(eval_set) | ||
if (is.na(idx_max) || idx_max > nrows) { | ||
stop("'eval_set' contains row indices beyond the size of the input data.") | ||
} | ||
idx_train <- seq(1L, nrows)[-eval_set] | ||
if (is_classif && length(unique(processed_y[idx_train])) < length(lst_args$metadata$y_levels)) { | ||
warning("'eval_set' indices will leave some classes of 'y' outside of the training data.") | ||
} | ||
idx_eval <- eval_set | ||
|
||
} | ||
|
||
# Note: slicing is done in the constructed DMatrix object instead of in the | ||
# original input, because objects from 'Matrix' might change class after | ||
# being sliced (e.g. 'dgRMatrix' turns into 'dgCMatrix'). | ||
return(list(idx_train = idx_train, idx_eval = idx_eval)) | ||
} | ||
|
||
check.early.stopping.rounds <- function(early_stopping_rounds, eval_set) { | ||
if (is.null(early_stopping_rounds)) { | ||
return(NULL) | ||
} | ||
if (is.null(eval_set)) { | ||
stop("'early_stopping_rounds' requires passing 'eval_set'.") | ||
} | ||
if (NROW(early_stopping_rounds) != 1L) { | ||
stop("'early_stopping_rounds' must be NULL or an integer greater than zero.") | ||
} | ||
early_stopping_rounds <- as.integer(early_stopping_rounds) | ||
if (is.na(early_stopping_rounds) || early_stopping_rounds <= 0L) { | ||
stop( | ||
"'early_stopping_rounds' must be NULL or an integer greater than zero. Got: ", | ||
early_stopping_rounds | ||
) | ||
} | ||
return(early_stopping_rounds) | ||
} | ||
|
||
#' Fit XGBoost Model | ||
#' | ||
#' @export | ||
|
@@ -808,6 +938,34 @@ process.x.and.col.args <- function( | |
#' 2 (info), and 3 (debug). | ||
#' @param monitor_training Whether to monitor objective optimization progress on the input data. | ||
#' Note that same 'x' and 'y' data are used for both model fitting and evaluation. | ||
#' @param eval_set Subset of the data to use as evaluation set. Can be passed as: | ||
#' - A vector of row indices (1-based indexing) indicating the observations that are to be designated | ||
#' as evaluation data. | ||
#' - A number between zero and one indicating a random fraction of the input data to use as | ||
#' evaluation data. Note that the selection will be done uniformly at random, regardless of | ||
#' argument `weights`. | ||
#' | ||
#' If passed, this subset of the data will be excluded from the training procedure, and a default | ||
#' metric for the selected objective will be calculated on this dataset after each boosting | ||
#' iteration (pass `verbosity>0` to have these metrics printed during training). | ||
#' | ||
#' If passing a fraction, in classification problems, the evaluation set will be chosen in such a | ||
#' way that at least one observation of each class will be kept in the training data. | ||
#' | ||
#' For more elaborate evaluation variants (e.g. custom metrics, multiple evaluation sets, etc.), | ||
#' one might want to use [xgb.train()] instead. | ||
#' @param early_stopping_rounds Number of boosting rounds after which training will be stopped | ||
#' if there is no improvement in performance (as measured by the last metric passed under | ||
#' `eval_metric`, or by the default metric for the objective if `eval_metric` is not passed) on the | ||
#' evaluation data from `eval_set`. Must pass `eval_set` in order to use this functionality. | ||
#' | ||
#' If `NULL`, early stopping will not be used. | ||
#' @param print_every_n When passing `verbosity>0` and either `monitor_training=TRUE` or `eval_set`, | ||
#' evaluation logs (metrics calculated on the training and/or evaluation data) will be printed every | ||
#' nth iteration according to the value passed here. The first and last iteration are always | ||
#' included regardless of this 'n'. | ||
#' | ||
#' Only has an effect when passing `verbosity>0`. | ||
#' @param nthreads Number of parallel threads to use. If passing zero, will use all CPU threads. | ||
#' @param seed Seed to use for random number generation. If passing `NULL`, will draw a random | ||
#' number using R's PRNG system to use as seed. | ||
|
@@ -895,6 +1053,9 @@ xgboost <- function( | |
weights = NULL, | ||
verbosity = 0L, | ||
monitor_training = verbosity > 0, | ||
eval_set = NULL, | ||
early_stopping_rounds = NULL, | ||
print_every_n = 1L, | ||
nthreads = parallel::detectCores(), | ||
seed = 0L, | ||
monotone_constraints = NULL, | ||
|
@@ -907,7 +1068,7 @@ xgboost <- function( | |
params <- list(...) | ||
params <- prescreen.parameters(params) | ||
prescreen.objective(objective) | ||
use_qdm <- check.can.use.qdm(x, params) | ||
use_qdm <- check.can.use.qdm(x, params, eval_set) | ||
lst_args <- process.y.margin.and.objective(y, base_margin, objective, params) | ||
lst_args <- process.row.weights(weights, lst_args) | ||
lst_args <- process.x.and.col.args( | ||
|
@@ -918,8 +1079,9 @@ xgboost <- function( | |
lst_args, | ||
use_qdm | ||
) | ||
eval_set <- process.eval.set(eval_set, lst_args) | ||
|
||
if (use_qdm && "max_bin" %in% names(params)) { | ||
if (use_qdm && hasName(params, "max_bin")) { | ||
lst_args$dmatrix_args$max_bin <- params$max_bin | ||
} | ||
|
||
|
@@ -932,15 +1094,23 @@ xgboost <- function( | |
|
||
fn_dm <- if (use_qdm) xgb.QuantileDMatrix else xgb.DMatrix | ||
dm <- do.call(fn_dm, lst_args$dmatrix_args) | ||
if (!is.null(eval_set)) { | ||
dm_eval <- xgb.slice.DMatrix(dm, eval_set$idx_eval) | ||
dm <- xgb.slice.DMatrix(dm, eval_set$idx_train) | ||
} | ||
evals <- list() | ||
if (monitor_training) { | ||
evals <- list(train = dm) | ||
} | ||
if (!is.null(eval_set)) { | ||
evals <- c(evals, list(eval = dm_eval)) | ||
} | ||
model <- xgb.train( | ||
params = params, | ||
data = dm, | ||
nrounds = nrounds, | ||
verbose = verbosity, | ||
print_every_n = print_every_n, | ||
evals = evals | ||
) | ||
attributes(model)$metadata <- lst_args$metadata | ||
|
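The class-preserving fallback in `process.eval.set` above can be illustrated with a minimal base-R sketch. This is not the PR's code: `split_keep_classes` is a hypothetical helper, it skips the `data.table` grouping used in the diff, and it always reserves one row per class for training instead of doing so only after a plain shuffle has failed, so the resulting split is not distributed the same way as the one in the PR.

```r
# Hypothetical helper (not part of xgboost): split row indices into training
# and evaluation sets so that every class of 'y' keeps at least one training row.
split_keep_classes <- function(y, frac, seed = NULL) {
  if (!is.null(seed)) set.seed(seed)
  n <- length(y)
  n_eval <- as.integer(round(n * frac))
  # Reserve one random row per class for training. Indexing with sample.int()
  # avoids the sample(x, 1) pitfall when a class has a single row.
  reserved <- vapply(
    split(seq_len(n), y, drop = TRUE),
    function(rows) rows[sample.int(length(rows), 1L)],
    integer(1L)
  )
  remaining <- setdiff(seq_len(n), reserved)
  stopifnot(n_eval >= 1L, n_eval <= length(remaining))
  # Shuffle the non-reserved rows; the first n_eval go to the evaluation set.
  remaining <- remaining[sample.int(length(remaining))]
  idx_eval <- remaining[seq_len(n_eval)]
  idx_train <- c(unname(reserved), remaining[-seq_len(n_eval)])
  list(idx_train = idx_train, idx_eval = idx_eval)
}

# Two of the three classes have a single row, so a plain shuffle could
# easily push them entirely into the evaluation set.
y <- factor(c(rep("a", 8), "b", "c"))
res <- split_keep_classes(y, frac = 0.5, seed = 123)
```

Here `res$idx_train` contains at least one row of each of `"a"`, `"b"` and `"c"`, while `res$idx_eval` holds the requested half of the rows.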
Is there a defined behavior for using multiple metrics or multiple evals in R? In Python, the last metric and the last validation dataset are used for early stopping.
It's the same in R. Updated the docs.
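As a rough illustration of the stopping rule discussed here, the sketch below takes the per-iteration values of a single metric and reports where training would stop. It is an assumption-laden toy, not XGBoost's implementation: `early_stop_iteration` is a hypothetical helper, the metric is assumed to be one where lower is better, and the caller is assumed to pass the values of the last metric on the last evaluation set.

```r
# Hypothetical helper: given one metric value per boosting iteration (lower is
# better), stop once 'patience' iterations have passed without a new best value.
early_stop_iteration <- function(metric, patience) {
  best <- Inf
  best_iter <- 0L
  for (i in seq_along(metric)) {
    if (metric[i] < best) {
      # New best value: remember it and reset the patience window.
      best <- metric[i]
      best_iter <- i
    } else if (i - best_iter >= patience) {
      return(list(stopped_at = i, best_iteration = best_iter))
    }
  }
  # Ran through all iterations without exhausting the patience window.
  list(stopped_at = length(metric), best_iteration = best_iter)
}

# Made-up metric values, for illustration only.
eval_metric_values <- c(0.60, 0.50, 0.45, 0.46, 0.47, 0.44, 0.48)
res <- early_stop_iteration(eval_metric_values, patience = 2L)
```

With these made-up values the best score is reached at iteration 3, and training stops at iteration 5 because two further iterations bring no improvement; the later improvement at iteration 6 is never seen.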