Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Updates for parsnip #1162 #220

Merged
merged 7 commits into from
Aug 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion tests/testthat/_snaps/glmnet-linear.md
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,6 @@
linear_reg(penalty = 0.01) %>% set_engine("glmnet") %>% fit(mpg ~ ., data = mtcars) %>%
multi_predict(mtcars, type = "class")
Condition
Error in `check_pred_type()`:
Error in `multi_predict()`:
! For class predictions, the object should be a classification model.

2 changes: 1 addition & 1 deletion tests/testthat/_snaps/glmnet-logistic.md
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,6 @@
funded_amnt) + int_rate + term, data = lending_club) %>% multi_predict(
lending_club, type = "time")
Condition
Error in `check_pred_type()`:
Error in `multi_predict()`:
! For event time predictions, the object should be a censored regression.

2 changes: 1 addition & 1 deletion tests/testthat/_snaps/glmnet-multinom.md
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,6 @@
multinom_reg(penalty = 0.01) %>% set_engine("glmnet") %>% fit(class ~ ., data = hpc_data) %>%
multi_predict(hpc_data, type = "numeric")
Condition
Error in `check_pred_type()`:
Error in `multi_predict()`:
! For numeric predictions, the object should be a regression model.

46 changes: 46 additions & 0 deletions tests/testthat/_snaps/parsnip-case-weights.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
# boost_tree - xgboost case weights

Code
print(wt_fit$fit$call)
Output
xgboost::xgb.train(params = list(eta = 0.3, max_depth = 6, gamma = 0,
colsample_bytree = 1, colsample_bynode = 1, min_child_weight = 1,
subsample = 1), data = x$data, nrounds = 15, watchlist = x$watchlist,
verbose = 0, nthread = 1, objective = "binary:logistic")

# decision_tree - rpart case weights

Code
print(wt_fit$fit$call)
Output
rpart::rpart(formula = Class ~ ., data = data, weights = weights)

# logistic_reg - stan case weights

Code
print(wt_fit$fit$call)
Output
rstanarm::stan_glm(formula = Class ~ ., family = stats::binomial,
data = data, weights = weights, seed = ~1, refresh = 0)

# mars - earth case weights

Code
print(wt_fit$fit$call)
Output
earth(formula = Class ~ ., data = data, weights = weights, keepxy = TRUE,
glm = ~list(family = stats::binomial))

# mlp - nnet case weights

Case weights are not enabled by the underlying model implementation.

# rand_forest - ranger case weights

Code
print(wt_fit$fit$call)
Output
ranger::ranger(x = maybe_data_frame(x), y = y, num.threads = 1,
verbose = FALSE, seed = sample.int(10^5, 1), probability = TRUE,
case.weights = weights)

2 changes: 1 addition & 1 deletion tests/testthat/_snaps/parsnip-survival-censoring-model.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,5 +21,5 @@
predict(alt_obj, time = 100)
Condition
Error in `predict()`:
! Don't know how to predict with a censoring model of type: reverse_km
! Don't know how to predict with a censoring model of type reverse_km.

14 changes: 7 additions & 7 deletions tests/testthat/_snaps/parsnip-survival-censoring-weights.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,15 @@
.censoring_weights_graf("nothing useful")
Condition
Error in `.censoring_weights_graf()`:
! There is no `.censoring_weights_graf()` method for objects with class(es): 'character'
! There is no `.censoring_weights_graf()` method for objects with class <character>.

---

Code
.censoring_weights_graf(cox_model, lung)
Condition
Error:
! There should be a single column of class `Surv`
! There should be a single column of class <Surv>.

---

Expand All @@ -22,23 +22,23 @@
.censoring_weights_graf(cox_model, lung_left)
Condition
Error in `.censoring_weights_graf()`:
! For this usage, the allowed censoring type is: 'right'
! For this usage, the allowed censoring type is "right".

---

Code
.censoring_weights_graf(cox_model, lung2)
Condition
Error:
! The input should have a list column called `.pred`.
! The input should have a list column called ".pred".

---

Code
.censoring_weights_graf(cox_model, preds, cens_predictors = "shouldn't be using this anyway!")
Condition
Warning:
The 'cens_predictors' argument to the survival weighting function is not currently used.
`cens_predictors` is not currently used.
Output
# A tibble: 3 x 2
.pred surv
Expand All @@ -60,6 +60,6 @@
Code
.censoring_weights_graf(wrong_model, mtcars)
Condition
Error in `.check_censor_model()`:
! The model needs to be for mode 'censored regression', not for mode 'regression'.
Error in `.censoring_weights_graf()`:
! The model needs to be for mode "censored regression", not for mode 'regression'.

6 changes: 3 additions & 3 deletions tests/testthat/_snaps/parsnip-survival-standalone.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,21 +4,21 @@
parsnip:::.is_surv(1)
Condition
Error:
! The object does not have class `Surv`.
! The object does not have class <Surv>.

# .check_cens_type()

Code
parsnip:::.check_cens_type(left_c, type = "right", fail = TRUE)
Condition
Error:
! For this usage, the allowed censoring type is: 'right'
! For this usage, the allowed censoring type is "right".

---

Code
parsnip:::.check_cens_type(left_c, type = c("right", "interval"), fail = TRUE)
Condition
Error:
! For this usage, the allowed censoring types are: 'right' and 'interval'
! For this usage, the allowed censoring types are "right" or "interval".

12 changes: 8 additions & 4 deletions tests/testthat/_snaps/randomForest.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,21 @@
f_fit <- spec %>% fit(body_mass_g ~ ., data = penguins)
Condition
Warning:
1000 columns were requested but there were 6 predictors in the data. 6 will be used.
! 1000 columns were requested but there were 6 predictors in the data.
i 6 predictors will be used.
Warning:
1000 samples were requested but there were 333 rows in the data. 333 will be used.
! 1000 samples were requested but there were 333 rows in the data.
i 333 samples will be used.

---

Code
xy_fit <- spec %>% fit_xy(x = penguins[, -6], y = penguins$body_mass_g)
Condition
Warning:
1000 columns were requested but there were 6 predictors in the data. 6 will be used.
! 1000 columns were requested but there were 6 predictors in the data.
i 6 predictors will be used.
Warning:
1000 samples were requested but there were 333 rows in the data. 333 will be used.
! 1000 samples were requested but there were 333 rows in the data.
i 333 samples will be used.

13 changes: 13 additions & 0 deletions tests/testthat/_snaps/recipes-varying.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# recipe steps with non-varying args error if specified as varying()

Code
varying_args(rec_bad_varying)
Condition
Error in `map()`:
i In index: 1.
Caused by error in `map2()`:
i In index: 5.
i With name: skip.
Caused by error in `.f()`:
! The argument skip for a recipe step of type "step_type" is not allowed to vary.

36 changes: 36 additions & 0 deletions tests/testthat/_snaps/recipes1.1.0/recipes-nnmf_sparse.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
# Correct values

Code
print(rec)
Message

-- Recipe ----------------------------------------------------------------------

-- Inputs
Number of variables by role
outcome: 1
predictor: 4

-- Operations
* Non-negative matrix factorization for: all_predictors()

# No NNF

Code
print(rec)
Message

-- Recipe ----------------------------------------------------------------------

-- Inputs
Number of variables by role
outcome: 1
predictor: 4

-- Training information
Training data contained 150 data points and no incomplete rows.

-- Operations
* No non-negative matrix factorization was extracted from: Sepal.Length,
Sepal.Width, Petal.Length, Petal.Width | Trained

6 changes: 0 additions & 6 deletions tests/testthat/_snaps/survival-tune-bayes.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,6 @@

Code
expect_snapshot_plot(print(autoplot(bayes_dynamic_res)), "dyn-bayes")
Condition
Warning in `filter_plot_eval_time()`:
No evaluation time was set; a value of 5 was used.

# Bayesian tuning survival models with mixture of metric types

Expand All @@ -32,9 +29,6 @@

Code
expect_snapshot_plot(print(autoplot(bayes_mixed_res)), "mix-bayes-0-times")
Condition
Warning in `filter_plot_eval_time()`:
No evaluation time was set; a value of 5 was used.

---

Expand Down
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified tests/testthat/_snaps/survival-tune-bayes/dyn-bayes.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified tests/testthat/_snaps/survival-tune-bayes/mix-bayes-0-times.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
2 changes: 1 addition & 1 deletion tests/testthat/test-glmnet-linear.R
Original file line number Diff line number Diff line change
Expand Up @@ -265,7 +265,7 @@ test_that('multi_predict() with default or single penalty value', {

test_that('error traps', {
skip_if_not_installed("glmnet")
skip_if_not_installed("parsnip", minimum_version = "1.2.1.9001")
skip_if_not_installed("parsnip", minimum_version = "1.2.1.9002")

expect_snapshot(error = TRUE, {
linear_reg(penalty = 0.01) %>%
Expand Down
2 changes: 1 addition & 1 deletion tests/testthat/test-glmnet-logistic.R
Original file line number Diff line number Diff line change
Expand Up @@ -377,7 +377,7 @@ test_that("class predictions are factors with all levels", {

test_that('error traps', {
skip_if_not_installed("glmnet")
skip_if_not_installed("parsnip", minimum_version = "1.2.1.9001")
skip_if_not_installed("parsnip", minimum_version = "1.2.1.9002")

data("lending_club", package = "modeldata", envir = rlang::current_env())

Expand Down
2 changes: 1 addition & 1 deletion tests/testthat/test-glmnet-multinom.R
Original file line number Diff line number Diff line change
Expand Up @@ -399,7 +399,7 @@ test_that("class predictions are factors with all levels", {

test_that('error traps', {
skip_if_not_installed("glmnet")
skip_if_not_installed("parsnip", minimum_version = "1.2.1.9001")
skip_if_not_installed("parsnip", minimum_version = "1.2.1.9002")

data("hpc_data", package = "modeldata", envir = rlang::current_env())

Expand Down
2 changes: 1 addition & 1 deletion tests/testthat/test-parsnip-survival-censoring-model.R
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ test_that("predict() avoids zero probabilities", {
})

test_that("Handle unknown censoring model", {
skip_if_not_installed("parsnip", minimum_version = "1.1.0.9002")
skip_if_not_installed("parsnip", minimum_version = "1.2.1.9002")
mod_fit <-
survival_reg() %>%
fit(Surv(time, status) ~ age + sex, data = lung)
Expand Down
4 changes: 3 additions & 1 deletion tests/testthat/test-parsnip-survival-censoring-weights.R
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,8 @@ test_that('compute Graf weights', {
})

test_that("error messages in context of .censoring_weights_graf()", {
skip_if_not_installed("parsnip", minimum_version = "1.2.1.9002")

lung2 <- lung %>%
dplyr::mutate(surv = Surv(time, status), .keep = "unused")

Expand Down Expand Up @@ -228,7 +230,7 @@ test_that("error for .censoring_weights_graf.workflow()", {

test_that("error for .censoring_weights_graf() from .check_censor_model()", {
# temporarily its own test, see above
skip_if_not_installed("parsnip", minimum_version = "1.1.0.9003")
skip_if_not_installed("parsnip", minimum_version = "1.2.1.9002")
wrong_model <- fit(linear_reg(), mpg ~ ., data = mtcars)
expect_snapshot(error = TRUE, .censoring_weights_graf(wrong_model, mtcars))
})
Expand Down
2 changes: 2 additions & 0 deletions tests/testthat/test-parsnip-survival-standalone.R
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
skip_if_not_installed("survival")

test_that(".is_surv()", {
skip_if_not_installed("parsnip", minimum_version = "1.2.1.9002")
times <- seq(1, 100, length.out = 5)
events <- c(1, 0, 1, 0, 1)
right_c <- survival::Surv(times, events)
Expand All @@ -11,6 +12,7 @@ test_that(".is_surv()", {
})

test_that(".check_cens_type()", {
skip_if_not_installed("parsnip", minimum_version = "1.2.1.9002")
times <- seq(1, 100, length.out = 5)
events <- c(1, 0, 1, 0, 1)
left_c <- survival::Surv(times, events, type = "left")
Expand Down
1 change: 1 addition & 0 deletions tests/testthat/test-randomForest.R
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,7 @@ test_that('randomForest regression prediction', {
test_that('argument checks for data dimensions', {

skip_if_not_installed("randomForest")
skip_if_not_installed("parsnip", minimum_version = "1.2.1.9002")

data(penguins, package = "modeldata")
penguins <- na.omit(penguins)
Expand Down
5 changes: 3 additions & 2 deletions tests/testthat/test-recipes-varying.R
Original file line number Diff line number Diff line change
Expand Up @@ -61,13 +61,14 @@ test_that('recipe parameters', {

test_that("recipe steps with non-varying args error if specified as varying()", {
withr::local_options(lifecycle_verbosity = "quiet")
skip("not applicable")

rec_bad_varying <- rec_1
rec_bad_varying$steps[[1]]$skip <- varying()

expect_error(
expect_snapshot(
varying_args(rec_bad_varying),
"The following argument for a recipe step of type 'step_center' is not allowed to vary: 'skip'."
error = TRUE
)
})

Expand Down
Loading