Allow type = "pred_int" for logistic regression? #1036

Closed
millermc38 opened this issue Dec 11, 2023 · 3 comments


millermc38 commented Dec 11, 2023

As far as I can tell from the code below, prediction intervals for logistic regression have not been implemented. Apologies if I'm mistaken and have just not managed to figure out how to do it. I could do the song and dance of bootstrapping, but that is not my preferred method (unless you recommend it for technical reasons over a theoretical solution). Here is a small reproducible example, taken mostly from the documentation:

library(tidymodels)
library(tidyverse)

hotels <- 
  read_csv("https://tidymodels.org/start/case-study/hotels.csv") %>%
  mutate(across(where(is.character), as.factor))

set.seed(123)
splits      <- initial_split(hotels, strata = children)

hotel_other <- training(splits)
hotel_test  <- testing(splits)

set.seed(234)
val_set <- validation_split(hotel_other,
                            strata = children,
                            prop = 0.80)

lr_mod <-
  logistic_reg() %>%
  set_engine("glm")

holidays <- c("AllSouls", "AshWednesday", "ChristmasEve", "Easter", 
              "ChristmasDay", "GoodFriday", "NewYearsDay", "PalmSunday")

lr_recipe <- 
  recipe(children ~ ., data = hotel_other) %>% 
  step_date(arrival_date) %>% 
  step_holiday(arrival_date, holidays = holidays) %>% 
  step_rm(arrival_date) %>% 
  step_dummy(all_nominal_predictors()) %>% 
  step_zv(all_predictors()) %>% 
  step_normalize(all_predictors())


lr_workflow <- 
  workflow() %>% 
  add_model(lr_mod) %>% 
  add_recipe(lr_recipe)

lr_workflow_trained <- lr_workflow %>% fit(data = hotel_other)

# Preprocess the test set once so both predict() calls use the same data
hotel_test_baked <- bake(prep(lr_recipe), new_data = hotel_test)

# Confidence intervals work
ci_conf_int <- predict(extract_fit_parsnip(lr_workflow_trained),
                       new_data = hotel_test_baked,
                       type = "conf_int")

# Does not work :(
ci_pred_int <- predict(extract_fit_parsnip(lr_workflow_trained),
                       new_data = hotel_test_baked,
                       type = "pred_int")

Session Info:

R version 4.2.2 (2022-10-31 ucrt)
Platform: x86_64-w64-mingw32/x64 (64-bit)
Running under: Windows 10 x64 (build 19045)

Matrix products: default

locale:
[1] LC_COLLATE=English_United States.utf8  LC_CTYPE=English_United States.utf8   
[3] LC_MONETARY=English_United States.utf8 LC_NUMERIC=C                          
[5] LC_TIME=English_United States.utf8    

attached base packages:
[1] stats     graphics  grDevices utils     datasets  methods   base     

other attached packages:
 [1] lubridate_1.9.2    forcats_1.0.0      stringr_1.5.0      readr_2.1.4       
 [5] tidyverse_2.0.0    yardstick_1.2.0    workflowsets_1.0.1 workflows_1.1.3   
 [9] tune_1.1.2         tidyr_1.3.0        tibble_3.2.1       rsample_1.2.0     
[13] recipes_1.0.8      purrr_1.0.2        parsnip_1.1.1      modeldata_1.2.0   
[17] infer_1.0.4        ggplot2_3.4.3      dplyr_1.1.2        dials_1.2.0       
[21] scales_1.2.1       broom_1.0.5        tidymodels_1.1.1  

loaded via a namespace (and not attached):
 [1] fs_1.6.2            bit64_4.0.5         DiceDesign_1.9      fBasics_4032.96    
 [5] tools_4.2.2         backports_1.4.1     utf8_1.2.2          R6_2.5.1           
 [9] rpart_4.1.19        colorspace_2.0-3    nnet_7.3-18         withr_2.5.0        
[13] clock_0.6.1         tidyselect_1.2.0    timeSeries_4031.107 curl_4.3.3         
[17] bit_4.0.5           compiler_4.2.2      cli_3.6.1           pacman_0.5.1       
[21] mvtnorm_1.1-3       spatial_7.3-15      digest_0.6.31       pkgconfig_2.0.3    
[25] parallelly_1.35.0   lhs_1.1.6           stabledist_0.7-1    rlang_1.1.1        
[29] rstudioapi_0.15.0   circular_0.4-95     generics_0.1.3      vroom_1.6.3        
[33] zip_2.2.2           magrittr_2.0.3      Matrix_1.6-0        Rcpp_1.0.10        
[37] munsell_0.5.0       fansi_1.0.3         GPfit_1.0-8         lifecycle_1.0.3    
[41] furrr_0.3.1         stringi_1.7.8       snakecase_0.11.0    MASS_7.3-58.1      
[45] stable_1.1.6        grid_4.2.2          parallel_4.2.2      listenv_0.9.0      
[49] crayon_1.5.2        lattice_0.20-45     splines_4.2.2       hms_1.1.3          
[53] pillar_1.9.0        statip_0.2.3        boot_1.3-28         rjson_0.2.21       
[57] future.apply_1.10.0 codetools_0.2-18    rmutil_1.1.10       glue_1.6.2         
[61] data.table_1.14.6   vctrs_0.6.3         tzdb_0.3.0          foreach_1.5.2      
[65] gtable_0.3.3        clue_0.3-65         future_1.32.0       gower_1.0.1        
[69] janitor_2.2.0       prodlim_2023.03.31  class_7.3-20        survival_3.4-0     
[73] modeest_2.4.0       timeDate_4022.108   this.path_1.1.0     iterators_1.0.14   
[77] hardhat_1.3.0       writexl_1.4.2       cluster_2.1.4       lava_1.7.2.1       
[81] timechange_0.1.1    globals_0.16.2      ellipsis_0.3.2      ipred_0.9-14  

hfrick commented Dec 15, 2023

For parsnip, we typically only wrap predictions / prediction types available from the engine and that's not the case for glm(). So from that point of view, I think it's not that likely that we'll do exactly that.
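
To illustrate what the engine does offer, here is a minimal base R sketch. It assumes a toy model on `mtcars` (the data, formula, and the `new_cars` values are stand-ins for illustration, not the hotels example above). `predict.glm()` returns standard errors through `se.fit` but has no `interval` argument, so a confidence interval for the fitted probability has to be assembled by hand on the link scale, and nothing comparable exists for a prediction interval.

```r
# Toy logistic regression; mtcars and the formula are illustrative only
fit <- glm(am ~ wt, data = mtcars, family = binomial())

# Hypothetical new observations to predict for
new_cars <- data.frame(wt = c(2.5, 3.0, 3.5))

# predict.glm() can return standard errors, but it has no `interval` argument
pred_link <- predict(fit, newdata = new_cars, type = "link", se.fit = TRUE)

# Wald-type 95% confidence interval on the link (log-odds) scale,
# mapped back to the probability scale with the inverse logit
z <- qnorm(0.975)
conf_int <- data.frame(
  wt    = new_cars$wt,
  prob  = plogis(pred_link$fit),
  lower = plogis(pred_link$fit - z * pred_link$se.fit),
  upper = plogis(pred_link$fit + z * pred_link$se.fit)
)
conf_int
```

This interval describes uncertainty in the estimated probability, not in a new 0/1 observation, which is why it corresponds to type = "conf_int" rather than type = "pred_int".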

However, you can get prediction intervals via the probably package. It goes by the name "conformal inference", which does not make it very obvious that it produces prediction intervals. Max presented on this at posit::conf this year and used the alternative title "how to make prediction intervals with no parametric assumptions"... (video here). There is also a tidymodels.org article on the topic. Is that helpful?
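
A rough sketch of the split conformal approach from probably might look like the following. As far as I understand, these functions are aimed at regression models with a numeric outcome, so the sketch uses simulated numeric data rather than the hotels classification problem. The data, the split sizes, and the object names (`sim`, `train`, `cal`, `test`) are all assumptions made for illustration.

```r
library(tidymodels)
library(probably)

# Simulated numeric outcome; purely illustrative data
set.seed(1)
n <- 600
sim <- tibble(x = runif(n, -2, 2), y = 1 + 2 * x + rnorm(n))

# Separate training, calibration, and test sets
train <- sim[1:400, ]
cal   <- sim[401:500, ]
test  <- sim[501:600, ]

# Fit an ordinary regression workflow on the training set
wf_fit <-
  workflow() %>%
  add_formula(y ~ x) %>%
  add_model(linear_reg()) %>%
  fit(data = train)

# Calibrate the interval width on held-out data (split conformal inference)
conf_split <- int_conformal_split(wf_fit, cal_data = cal)

# Prediction intervals for new data at the requested level
pred <- predict(conf_split, new_data = test, level = 0.90)
pred
```

The calibration set has to be kept apart from the training data, since the interval width comes from the residuals on rows the model has not seen.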

@millermc38
Author

Oh interesting, I didn't even realize that getting prediction intervals with glm() is not straightforward! In that case your line of thinking makes sense to me. Thank you for your helpful reply. I will close this now.


github-actions bot commented Jan 2, 2024

This issue has been automatically locked. If you believe you have found a related problem, please file a new issue (with a reprex: https://reprex.tidyverse.org) and link to this issue.

@github-actions github-actions bot locked and limited conversation to collaborators Jan 2, 2024