Skip to content

Commit

Permalink
deprecate rpart_train() (#1048)
Browse files Browse the repository at this point in the history
  • Loading branch information
simonpcouch authored Jan 25, 2024
1 parent 7c40966 commit 8ccf1be
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 1 deletion.
2 changes: 2 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
# parsnip (development version)

* `rpart_train()` has been deprecated in favor of using `decision_tree()` with the `"rpart"` engine or `rpart::rpart()` directly (#1044).

* Fixed bug in fitting some model types with the `"spark"` engine (#1045).

* Fixed issues in metadata for the `"brulee"` engine where several arguments were mistakenly protected. (#1050, #1054)
Expand Down
11 changes: 10 additions & 1 deletion R/decision_tree.R
Original file line number Diff line number Diff line change
Expand Up @@ -138,9 +138,12 @@ check_args.decision_tree <- function(object) {

#' Decision trees via rpart
#'
#' `rpart_train` is a wrapper for `rpart()` tree-based models
#' @description
#' `rpart_train()` is a wrapper for `rpart()` tree-based models
#' where all of the model arguments are in the main function.
#'
#' The function is now deprecated, as parsnip uses `rpart::rpart()` directly.
#'
#' @param formula A model formula.
#' @param data A data frame.
#' @param cp A non-negative number for complexity parameter. Any split
Expand All @@ -166,6 +169,12 @@ check_args.decision_tree <- function(object) {
#' @export
rpart_train <-
function(formula, data, weights = NULL, cp = 0.01, minsplit = 20, maxdepth = 30, ...) {
lifecycle::deprecate_warn(
"1.2.0",
"rpart_train()",
details = 'Instead, use `decision_tree(engine = "rpart")` or `rpart::rpart()` directly.'
)

bitness <- 8 * .Machine$sizeof.pointer
if (bitness == 32 & maxdepth > 30)
maxdepth <- 30
Expand Down
10 changes: 10 additions & 0 deletions tests/testthat/test_decision_tree.R
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,16 @@ test_that('bad input', {
expect_snapshot_error(translate(decision_tree(formula = y ~ x)))
})

test_that('rpart_train is stop-deprecated when it ought to be (#1044)', {
skip_on_cran()

# once this test fails, transition `rpart_train()` to `deprecate_stop()`
# and transition this test to fail if `rpart_train()` still exists after a year.
if (Sys.Date() > "2025-01-01") {
expect_error(rpart_train(mpg ~ ., mtcars))
}
})

# ------------------------------------------------------------------------------

test_that('argument checks for data dimensions', {
Expand Down

0 comments on commit 8ccf1be

Please sign in to comment.