Skip to content

Commit

Permalink
Merge branch 'main' into rpart-1044
Browse files Browse the repository at this point in the history
  • Loading branch information
topepo authored Jan 23, 2024
2 parents 5866aea + bf3f505 commit df57ac5
Show file tree
Hide file tree
Showing 31 changed files with 76 additions and 15 deletions.
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -76,4 +76,4 @@ Config/testthat/edition: 3
Encoding: UTF-8
LazyData: true
Roxygen: list(markdown = TRUE)
RoxygenNote: 7.2.3
RoxygenNote: 7.3.0
24 changes: 24 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,29 @@ S3method(.censoring_weights_graf,model_fit)
S3method(augment,model_fit)
S3method(autoplot,glmnet)
S3method(autoplot,model_fit)
S3method(check_args,C5_rules)
S3method(check_args,bag_tree)
S3method(check_args,boost_tree)
S3method(check_args,cubist_rules)
S3method(check_args,decision_tree)
S3method(check_args,default)
S3method(check_args,discrim_flexible)
S3method(check_args,discrim_linear)
S3method(check_args,discrim_regularized)
S3method(check_args,linear_reg)
S3method(check_args,logistic_reg)
S3method(check_args,mars)
S3method(check_args,mlp)
S3method(check_args,multinom_reg)
S3method(check_args,nearest_neighbor)
S3method(check_args,pls)
S3method(check_args,poisson_reg)
S3method(check_args,rand_forest)
S3method(check_args,surv_reg)
S3method(check_args,survival_reg)
S3method(check_args,svm_linear)
S3method(check_args,svm_poly)
S3method(check_args,svm_rbf)
S3method(extract_fit_engine,model_fit)
S3method(extract_parameter_dials,model_spec)
S3method(extract_parameter_set_dials,model_spec)
Expand Down Expand Up @@ -180,6 +203,7 @@ export(bartMachine_interval_calc)
export(boost_tree)
export(case_weights_allowed)
export(cforest_train)
export(check_args)
export(check_empty_ellipse)
export(check_final_param)
export(condense_control)
Expand Down
4 changes: 4 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,10 @@

* `rpart_train()` has been deprecated in favor of using `decision_tree()` with the `"rpart"` engine or `rpart::rpart()` directly (#1044).

* Fixed bug in fitting some model types with the `"spark"` engine (#1045).

* Fixed issue in `mlp()` metadata where the `stop_iter` engine argument had been mistakenly protected for the `"brulee"` engine. (#1050)

* `.filter_eval_time()` was moved to the survival standalone file.

* Improved errors and documentation related to special terms in formulas. See `?model_formula` to learn more. (#770, #1014)
Expand Down
10 changes: 9 additions & 1 deletion R/arguments.R
Original file line number Diff line number Diff line change
Expand Up @@ -328,7 +328,11 @@ min_cols <- function(num_cols, source) {
#' @export
#' @rdname min_cols
min_rows <- function(num_rows, source, offset = 0) {
n <- nrow(source)
if (inherits(source, "tbl_spark")) {
n <- nrow_spark(source)
} else {
n <- nrow(source)
}

if (num_rows > n - offset) {
msg <- paste0(num_rows, " samples were requested but there were ", n,
Expand All @@ -340,3 +344,7 @@ min_rows <- function(num_rows, source, offset = 0) {
as.integer(num_rows)
}

nrow_spark <- function(source) {
rlang::check_installed("sparklyr")
sparklyr::sdf_nrow(source)
}
1 change: 1 addition & 0 deletions R/bag_tree.R
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ update.bag_tree <-

# ------------------------------------------------------------------------------

#' @export
check_args.bag_tree <- function(object) {
if (object$engine == "C5.0" && object$mode == "regression")
stop("C5.0 is classification only.", call. = FALSE)
Expand Down
1 change: 1 addition & 0 deletions R/boost_tree.R
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,7 @@ translate.boost_tree <- function(x, engine = x$engine, ...) {

# ------------------------------------------------------------------------------

#' @export
check_args.boost_tree <- function(object) {

args <- lapply(object$args, rlang::eval_tidy)
Expand Down
1 change: 1 addition & 0 deletions R/c5_rules.R
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,7 @@ update.C5_rules <-

# make work in different places

#' @export
check_args.C5_rules <- function(object) {

args <- lapply(object$args, rlang::eval_tidy)
Expand Down
1 change: 1 addition & 0 deletions R/cubist_rules.R
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,7 @@ update.cubist_rules <-

# make work in different places

#' @export
check_args.cubist_rules <- function(object) {

args <- lapply(object$args, rlang::eval_tidy)
Expand Down
1 change: 1 addition & 0 deletions R/decision_tree.R
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,7 @@ translate.decision_tree <- function(x, engine = x$engine, ...) {

# ------------------------------------------------------------------------------

#' @export
check_args.decision_tree <- function(object) {
if (object$engine == "C5.0" && object$mode == "regression")
rlang::abort("C5.0 is classification only.")
Expand Down
1 change: 1 addition & 0 deletions R/discrim_flexible.R
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ update.discrim_flexible <-

# ------------------------------------------------------------------------------

#' @export
check_args.discrim_flexible <- function(object) {

args <- lapply(object$args, rlang::eval_tidy)
Expand Down
1 change: 1 addition & 0 deletions R/discrim_linear.R
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ update.discrim_linear <-

# ------------------------------------------------------------------------------

#' @export
check_args.discrim_linear <- function(object) {

args <- lapply(object$args, rlang::eval_tidy)
Expand Down
1 change: 1 addition & 0 deletions R/discrim_regularized.R
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@ update.discrim_regularized <-

# ------------------------------------------------------------------------------

#' @export
check_args.discrim_regularized <- function(object) {

args <- lapply(object$args, rlang::eval_tidy)
Expand Down
1 change: 1 addition & 0 deletions R/linear_reg.R
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,7 @@ update.linear_reg <-

# ------------------------------------------------------------------------------

#' @export
check_args.linear_reg <- function(object) {

args <- lapply(object$args, rlang::eval_tidy)
Expand Down
1 change: 1 addition & 0 deletions R/logistic_reg.R
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,7 @@ update.logistic_reg <-

# ------------------------------------------------------------------------------

#' @export
check_args.logistic_reg <- function(object) {

args <- lapply(object$args, rlang::eval_tidy)
Expand Down
1 change: 1 addition & 0 deletions R/mars.R
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,7 @@ translate.mars <- function(x, engine = x$engine, ...) {

# ------------------------------------------------------------------------------

#' @export
check_args.mars <- function(object) {

args <- lapply(object$args, rlang::eval_tidy)
Expand Down
4 changes: 4 additions & 0 deletions R/misc.R
Original file line number Diff line number Diff line change
Expand Up @@ -282,10 +282,14 @@ show_fit <- function(model, eng) {

# Check non-translated core arguments
# Each model has its own definition of this
#' @export
#' @keywords internal
#' @rdname add_on_exports
check_args <- function(object) {
UseMethod("check_args")
}

#' @export
check_args.default <- function(object) {
invisible(object)
}
Expand Down
1 change: 1 addition & 0 deletions R/mlp.R
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,7 @@ translate.mlp <- function(x, engine = x$engine, ...) {

# ------------------------------------------------------------------------------

#' @export
check_args.mlp <- function(object) {

args <- lapply(object$args, rlang::eval_tidy)
Expand Down
10 changes: 0 additions & 10 deletions R/mlp_data.R
Original file line number Diff line number Diff line change
Expand Up @@ -433,16 +433,6 @@ set_model_arg(
has_submodel = FALSE
)


set_model_arg(
model = "mlp",
eng = "brulee",
parsnip = "stop_iter",
original = "stop_iter",
func = list(pkg = "dials", fun = "stop_iter"),
has_submodel = FALSE
)

set_model_arg(
model = "mlp",
eng = "brulee",
Expand Down
1 change: 1 addition & 0 deletions R/multinom_reg.R
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@ update.multinom_reg <-

# ------------------------------------------------------------------------------

#' @export
check_args.multinom_reg <- function(object) {

args <- lapply(object$args, rlang::eval_tidy)
Expand Down
1 change: 1 addition & 0 deletions R/nearest_neighbor.R
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@ positive_int_scalar <- function(x) {

# ------------------------------------------------------------------------------

#' @export
check_args.nearest_neighbor <- function(object) {

args <- lapply(object$args, rlang::eval_tidy)
Expand Down
1 change: 1 addition & 0 deletions R/pls.R
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ update.pls <-

# ------------------------------------------------------------------------------

#' @export
check_args.pls <- function(object) {

args <- lapply(object$args, rlang::eval_tidy)
Expand Down
1 change: 1 addition & 0 deletions R/poisson_reg.R
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@ translate.poisson_reg <- function(x, engine = x$engine, ...) {

# ------------------------------------------------------------------------------

#' @export
check_args.poisson_reg <- function(object) {

args <- lapply(object$args, rlang::eval_tidy)
Expand Down
1 change: 1 addition & 0 deletions R/rand_forest.R
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,7 @@ translate.rand_forest <- function(x, engine = x$engine, ...) {

# ------------------------------------------------------------------------------

#' @export
check_args.rand_forest <- function(object) {
# move translate checks here?
invisible(object)
Expand Down
1 change: 1 addition & 0 deletions R/surv_reg.R
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ translate.surv_reg <- function(x, engine = x$engine, ...) {

# ------------------------------------------------------------------------------

#' @export
check_args.surv_reg <- function(object) {

if (object$engine == "flexsurv") {
Expand Down
2 changes: 1 addition & 1 deletion R/survival_reg.R
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ translate.survival_reg <- function(x, engine = x$engine, ...) {
x
}


#' @export
check_args.survival_reg <- function(object) {

if (object$engine == "flexsurv") {
Expand Down
1 change: 1 addition & 0 deletions R/svm_linear.R
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,7 @@ translate.svm_linear <- function(x, engine = x$engine, ...) {

# ------------------------------------------------------------------------------

#' @export
check_args.svm_linear <- function(object) {
invisible(object)
}
Expand Down
1 change: 1 addition & 0 deletions R/svm_poly.R
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,7 @@ translate.svm_poly <- function(x, engine = x$engine, ...) {

# ------------------------------------------------------------------------------

#' @export
check_args.svm_poly <- function(object) {
invisible(object)
}
Expand Down
1 change: 1 addition & 0 deletions R/svm_rbf.R
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,7 @@ translate.svm_rbf <- function(x, engine = x$engine, ...) {

# ------------------------------------------------------------------------------

#' @export
check_args.svm_rbf <- function(object) {
invisible(object)
}
Expand Down
4 changes: 2 additions & 2 deletions R/tunable.R
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,7 @@ brulee_linear_engine_args <-
brulee_mlp_engine_args %>%
dplyr::filter(name %in% c("momentum", "batch_size", "stop_iter"))

brulee_logistc_engine_args <-
brulee_logistic_engine_args <-
brulee_mlp_engine_args %>%
dplyr::filter(name %in% c("momentum", "batch_size", "stop_iter", "class_weights"))

Expand Down Expand Up @@ -258,7 +258,7 @@ tunable.logistic_reg <- function(x, ...) {
res$call_info[res$name == "mixture"] <-
list(list(pkg = "dials", fun = "mixture", range = c(0.05, 1.00)))
} else if (x$engine == "brulee") {
res <- add_engine_parameters(res, brulee_logistc_engine_args)
res <- add_engine_parameters(res, brulee_logistic_engine_args)
}
res
}
Expand Down
3 changes: 3 additions & 0 deletions man/add_on_exports.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

7 changes: 7 additions & 0 deletions tests/testthat/test_boost_tree.R
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,9 @@ test_that('bad input', {
## -----------------------------------------------------------------------------

test_that('argument checks for data dimensions', {
skip_if_not_installed("sparklyr")
library(sparklyr)
skip_if(nrow(spark_installed_versions()) == 0)

spec <-
boost_tree(mtry = 1000, min_n = 1000, trees = 5) %>%
Expand All @@ -36,6 +39,10 @@ test_that('argument checks for data dimensions', {

args <- translate(spec)$method$fit$args
expect_equal(args$min_instances_per_node, expr(min_rows(1000, x)))

sc = spark_connect(master = "local")
cars = copy_to(sc, mtcars, overwrite = TRUE)
expect_equal(min_rows(10, cars), 10)
})

test_that('boost_tree can be fit with 1 predictor if validation is used', {
Expand Down

0 comments on commit df57ac5

Please sign in to comment.