From f947fd7b6ae1eb4fe67b8cc4b16f7ab7c9734009 Mon Sep 17 00:00:00 2001
From: Hannah Frick
Date: Wed, 15 Mar 2023 10:46:48 +0000
Subject: [PATCH] add tests on usage of data descriptors

---
 tests/testthat/test-glmnet-linear.R   | 35 ++++++++++++++++++++++++
 tests/testthat/test-glmnet-logistic.R | 38 +++++++++++++++++++++++++++
 tests/testthat/test-glmnet-multinom.R | 37 ++++++++++++++++++++++++++
 tests/testthat/test-glmnet-poisson.R  | 37 ++++++++++++++++++++++++++
 4 files changed, 147 insertions(+)

diff --git a/tests/testthat/test-glmnet-linear.R b/tests/testthat/test-glmnet-linear.R
index 9493e8fe..aef8ca21 100644
--- a/tests/testthat/test-glmnet-linear.R
+++ b/tests/testthat/test-glmnet-linear.R
@@ -348,3 +348,38 @@ test_that("base-R families: type NULL", {
   mpred_numeric <- multi_predict(f_fit, hpc[1:5,], type = "numeric")
   expect_identical(mpred, mpred_numeric)
 })
+
+test_that("data descriptors and quosures work", {
+  skip_if_not_installed("glmnet")
+
+  my_penalty <- 1
+  my_mixture <- 0.3
+  my_penalties <- c(0.05, 1)
+
+  # use data descriptor .cols()
+  # mtcars has 11 columns so 10 predictor columns, thus penalty is 1
+  f_fit <- linear_reg(penalty = .cols() - 9, mixture = my_mixture) %>%
+    set_engine("glmnet", nlambda = 15) %>%
+    fit(mpg ~ ., data = mtcars)
+
+  expect_identical(
+    predict(f_fit, mtcars[1:3,]),
+    predict(f_fit, mtcars[1:3,], penalty = 1)
+  )
+
+  expect_identical(
+    predict(f_fit, mtcars[1:3,], penalty = my_penalty),
+    predict(f_fit, mtcars[1:3,], penalty = 1)
+  )
+
+  expect_identical(
+    multi_predict(f_fit, mtcars[1:3, ], penalty = my_penalties) %>%
+      tidyr::unnest(cols = .pred) %>%
+      dplyr::arrange(penalty) %>%
+      dplyr::pull(.pred),
+    c(
+      predict(f_fit, mtcars[1:3,], penalty = 0.05) %>% pull(.pred),
+      predict(f_fit, mtcars[1:3,], penalty = 1) %>% pull(.pred)
+    )
+  )
+})
diff --git a/tests/testthat/test-glmnet-logistic.R b/tests/testthat/test-glmnet-logistic.R
index 0f97f81c..03c842d1 100644
--- a/tests/testthat/test-glmnet-logistic.R
+++ b/tests/testthat/test-glmnet-logistic.R
@@ -511,3 +511,41 @@ test_that("base-R families: type NULL", {
   mpred_class <- multi_predict(f_fit, lending_club[1:5,], type = "class")
   expect_identical(mpred, mpred_class)
 })
+
+test_that("data descriptors and quosures work", {
+  skip_if_not_installed("glmnet")
+
+  data("lending_club", package = "modeldata", envir = rlang::current_env())
+  lending_club <- lending_club[1:200, ]
+
+  my_penalty <- 1
+  my_mixture <- 0.3
+  my_penalties <- c(0.05, 1)
+
+  # use data descriptor .cols()
+  # formula has 3 predictor columns, thus penalty is 1
+  f_fit <- logistic_reg(penalty = .cols() - 2, mixture = my_mixture) %>%
+    set_engine("glmnet", nlambda = 15) %>%
+    fit(Class ~ log(funded_amnt) + int_rate + term, data = lending_club)
+
+  expect_identical(
+    predict(f_fit, lending_club[1:3,]),
+    predict(f_fit, lending_club[1:3,], penalty = 1)
+  )
+
+  expect_identical(
+    predict(f_fit, lending_club[1:3,], penalty = my_penalty),
+    predict(f_fit, lending_club[1:3,], penalty = 1)
+  )
+
+  expect_identical(
+    multi_predict(f_fit, lending_club[1:3, ], penalty = my_penalties) %>%
+      tidyr::unnest(cols = .pred) %>%
+      dplyr::arrange(penalty) %>%
+      dplyr::pull(.pred_class),
+    c(
+      predict(f_fit, lending_club[1:3,], penalty = 0.05) %>% pull(.pred_class),
+      predict(f_fit, lending_club[1:3,], penalty = 1) %>% pull(.pred_class)
+    )
+  )
+})
diff --git a/tests/testthat/test-glmnet-multinom.R b/tests/testthat/test-glmnet-multinom.R
index fcc55a67..cb5b6d73 100644
--- a/tests/testthat/test-glmnet-multinom.R
+++ b/tests/testthat/test-glmnet-multinom.R
@@ -421,3 +421,40 @@ test_that('error traps', {
       multi_predict(hpc_data, type = "numeric")
   })
 })
+
+test_that("data descriptors and quosures work", {
+  skip_if_not_installed("glmnet")
+
+  data("hpc_data", package = "modeldata", envir = rlang::current_env())
+
+  my_penalty <- 1
+  my_mixture <- 0.3
+  my_penalties <- c(0.05, 1)
+
+  # use data descriptor .cols()
+  # formula has 3 predictor columns, thus penalty is 1
+  f_fit <- multinom_reg(penalty = .cols() - 2, mixture = my_mixture) %>%
+    set_engine("glmnet", nlambda = 15) %>%
+    fit(class ~ protocol + log(compounds) + input_fields, data = hpc_data)
+
+  expect_identical(
+    predict(f_fit, hpc_data[1:3,]),
+    predict(f_fit, hpc_data[1:3,], penalty = 1)
+  )
+
+  expect_identical(
+    predict(f_fit, hpc_data[1:3,], penalty = my_penalty),
+    predict(f_fit, hpc_data[1:3,], penalty = 1)
+  )
+
+  expect_identical(
+    multi_predict(f_fit, hpc_data[1:3, ], penalty = my_penalties) %>%
+      tidyr::unnest(cols = .pred) %>%
+      dplyr::arrange(penalty) %>%
+      dplyr::pull(.pred_class),
+    c(
+      predict(f_fit, hpc_data[1:3,], penalty = 0.05) %>% pull(.pred_class),
+      predict(f_fit, hpc_data[1:3,], penalty = 1) %>% pull(.pred_class)
+    )
+  )
+})
diff --git a/tests/testthat/test-glmnet-poisson.R b/tests/testthat/test-glmnet-poisson.R
index aca9658f..9be4277c 100644
--- a/tests/testthat/test-glmnet-poisson.R
+++ b/tests/testthat/test-glmnet-poisson.R
@@ -137,3 +137,40 @@ test_that('error traps', {
       fit(mpg ~ ., data = mtcars[-(1:4), ])
   })
 })
+
+test_that("data descriptors and quosures work", {
+  skip_if_not_installed("glmnet")
+
+  data(seniors, package = "poissonreg", envir = rlang::current_env())
+
+  my_penalty <- 1
+  my_mixture <- 0.3
+  my_penalties <- c(0.05, 1)
+
+  # use data descriptor .cols()
+  # formula has 3 predictor columns, thus penalty is 1
+  f_fit <- poisson_reg(penalty = .cols() - 2, mixture = my_mixture) %>%
+    set_engine("glmnet", nlambda = 15) %>%
+    fit(count ~ ., data = seniors)
+
+  expect_identical(
+    predict(f_fit, seniors[1:3,]),
+    predict(f_fit, seniors[1:3,], penalty = 1)
+  )
+
+  expect_identical(
+    predict(f_fit, seniors[1:3,], penalty = my_penalty),
+    predict(f_fit, seniors[1:3,], penalty = 1)
+  )
+
+  expect_identical(
+    multi_predict(f_fit, seniors[1:3, ], penalty = my_penalties) %>%
+      tidyr::unnest(cols = .pred) %>%
+      dplyr::arrange(penalty) %>%
+      dplyr::pull(.pred),
+    c(
+      predict(f_fit, seniors[1:3,], penalty = 0.05) %>% pull(.pred),
+      predict(f_fit, seniors[1:3,], penalty = 1) %>% pull(.pred)
+    )
+  )
+})