
TuneWrapper with classif.glmnet does not seem to apply tuned hyperparameter in outer cross-validation #2472

Closed
MartinKlm opened this issue Nov 6, 2018 · 8 comments

Comments

@MartinKlm

I came across this bug when comparing different predictors with "benchmark" and found Lasso to perform surprisingly poorly.

What I think happens is that the tuned "s" hyperparameter of "classif.glmnet" is not applied within the outer cross-validation when used in the TuneWrapper setting. The results merely depend on the initial value of "s"; the extracted tuning values for each cross-validation loop appear correct.

To illustrate this, I added three examples on the standard iris data set:

  1. Not initializing "s" when creating the learner and subsequently tuning "s"
  2. Initializing "s" with a reasonable value (0.01; this might be the default according to my simulation, but 0.1 is stated in the appendix) when creating the learner and subsequently tuning "s"
  3. Initializing "s" with an unreasonable value (100) when creating the learner and subsequently tuning "s"

In general, all three examples should produce identical results, as the initial value should not matter, only the tuned value. This was the case when executing the examples with the ksvm classifier, but it does not hold for glmnet.

Here is the code:

library(mlr)

###
###
### optimize hyperparameter without initializing it

set.seed(123)
ps = makeParamSet(
  makeDiscreteParam("s", values = 2^(-10:10))
)
ctrl = makeTuneControlGrid()
inner = makeResampleDesc("Subsample", iters = 2)
lasso_lrn <- makeLearner("classif.glmnet", alpha=1)
lrn = makeTuneWrapper(lasso_lrn, resampling = inner, par.set = ps, control = ctrl, show.info = FALSE, measures = list(mlr::acc))

# Outer resampling loop
outer = makeResampleDesc("CV", iters = 3)
r1 = resample(lrn, iris.task, resampling = outer, extract = getTuneResult, show.info = FALSE, measures = list(mlr::acc))


###
###
### optimize hyperparameter, initializing it with the (presumed) default of 0.01

set.seed(123)
ps = makeParamSet(
  makeDiscreteParam("s", values = 2^(-10:10))
)
ctrl = makeTuneControlGrid()
inner = makeResampleDesc("Subsample", iters = 2)
lasso_lrn <- makeLearner("classif.glmnet", alpha=1, s=0.01)
lrn = makeTuneWrapper(lasso_lrn, resampling = inner, par.set = ps, control = ctrl, show.info = FALSE, measures = list(mlr::acc))

# Outer resampling loop
outer = makeResampleDesc("CV", iters = 3)
r2 = resample(lrn, iris.task, resampling = outer, extract = getTuneResult, show.info = FALSE, measures = list(mlr::acc))


###
###
### optimize hyperparameter, initializing it with a bad value (100)

set.seed(123)
ps = makeParamSet(
  makeDiscreteParam("s", values = 2^(-10:10))
)
ctrl = makeTuneControlGrid()
inner = makeResampleDesc("Subsample", iters = 2)
lasso_lrn <- makeLearner("classif.glmnet", alpha=1, s=100)
lrn = makeTuneWrapper(lasso_lrn, resampling = inner, par.set = ps, control = ctrl, show.info = FALSE, measures = list(mlr::acc))

# Outer resampling loop
outer = makeResampleDesc("CV", iters = 3)
r3 = resample(lrn, iris.task, resampling = outer, extract = getTuneResult, show.info = FALSE, measures = list(mlr::acc))



###
###
###
### print results
r1$extract
r2$extract
r3$extract

r1$measures.test
r2$measures.test
r3$measures.test


###
###
### output
#> r1$extract
#[[1]]
#Tune result:
#Op. pars: s=0.00390625
#acc.test.mean=0.9705882
#
#[[2]]
#Tune result:
#Op. pars: s=0.001953125
#acc.test.mean=0.9558824
#
#[[3]]
#Tune result:
#Op. pars: s=0.015625
#acc.test.mean=0.9411765
#
#> r2$extract
#[[1]]
#Tune result:
#Op. pars: s=0.00390625
#acc.test.mean=0.9705882
#
#[[2]]
#Tune result:
#Op. pars: s=0.001953125
#acc.test.mean=0.9558824
#
#[[3]]
#Tune result:
#Op. pars: s=0.015625
#acc.test.mean=0.9411765
#
#> r3$extract
#[[1]]
#Tune result:
#Op. pars: s=0.00390625
#acc.test.mean=0.9705882
#
#[[2]]
#Tune result:
#Op. pars: s=0.001953125
#acc.test.mean=0.9558824
#
#[[3]]
#Tune result:
#Op. pars: s=0.015625
#acc.test.mean=0.9411765
#
#> 
#> r1$measures.test
#  iter  acc
#1    1 0.96
#2    2 0.94
#3    3 0.94
#> r2$measures.test
#  iter  acc
#1    1 0.96
#2    2 0.94
#3    3 0.94
#> r3$measures.test
#  iter  acc
#1    1 0.32
#2    2 0.26
#3    3 0.28

You can see from the output that the tuning results are identical for all three examples, but the cross-validation prediction results are not. While the example without initializing "s" and the initialization with 0.01 lead to the same results (r1 and r2, respectively), the initialization with 100 leads to completely different predictions (r3).

This brought me to the conclusion that the tuned values are not applied in the outer cross-validation.

Thanks for looking into this,
Martin

sessionInfo()

#R version 3.4.2 (2017-09-28)
#Platform: x86_64-w64-mingw32/x64 (64-bit)
#Running under: Windows >= 8 x64 (build 9200)
#
#Matrix products: default
#
#locale:
#[1] LC_COLLATE=German_Germany.1252  LC_CTYPE=German_Germany.1252    LC_MONETARY=German_Germany.1252 LC_NUMERIC=C                    LC_TIME=German_Germany.1252
#
#attached base packages:
#[1] stats     graphics  grDevices utils     datasets  methods   base     
#
#other attached packages:
#[1] mlr_2.13          ParamHelpers_1.11
#
#loaded via a namespace (and not attached):
#[1] parallelMap_1.3   Rcpp_0.12.18      pillar_1.3.0      compiler_3.4.2    plyr_1.8.4        bindr_0.1.1       iterators_1.0.10  tools_3.4.2       tibble_1.4.2      gtable_0.2.0
#[11] checkmate_1.8.5   lattice_0.20-35   pkgconfig_2.0.2   rlang_0.2.2       foreach_1.4.4     Matrix_1.2-14     fastmatch_1.1-0   rstudioapi_0.7    parallel_3.4.2    bindrcpp_0.2.2
#[21] dplyr_0.7.6       grid_3.4.2        glmnet_2.0-16     tidyselect_0.2.4  glue_1.3.0        data.table_1.11.4 R6_2.2.2          XML_3.98-1.16     survival_2.42-6   ggplot2_3.0.0
#[31] purrr_0.2.5       magrittr_1.5      codetools_0.2-15  backports_1.1.2   scales_1.0.0      BBmisc_1.11       splines_3.4.2     assertthat_0.2.0  colorspace_1.3-2  stringi_1.1.7
#[41] lazyeval_0.2.1    munsell_0.5.0     crayon_1.3.4    
@mb706 (Contributor) commented Nov 6, 2018

Hi Martin, thanks for the report. This is indeed a bug in the TuneWrapper predictLearner function. The problem is that the ... argument contains the "outer" value of s; even though the lrn object is modified, the bad value is still passed on inside the .... This only happens with when = "predict" (or "both") parameters.
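
To make that mechanism concrete, here is a minimal, self-contained R sketch (the function and object names are purely hypothetical; this is not mlr source code): a stale value captured in ... wins over the updated value stored on the learner object.

# Hypothetical illustration of the bug mechanism, not mlr's implementation.
predict_with_dots = function(learner, ...) {
  dots = list(...)                           # captures the stale "outer" value of s
  # the tuned value sits on the learner object, but the dots value is used here
  if (!is.null(dots$s)) dots$s else learner$par.vals$s
}

learner = list(par.vals = list(s = 100))     # initial, bad value
learner$par.vals$s = 0.0039                  # tuning updates the object ...
predict_with_dots(learner, s = 100)          # ... but the stale value 100 is still used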

P.S. I don't know whether you are aware of the classif.cvglmnet learner, which does internal (and more computationally efficient) tuning of the s parameter. Be sure to set the type.measure parameter to something appropriate ("class" in your case) when you use it.
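
For reference, a minimal sketch of that suggestion on the same task (assuming the remaining defaults are acceptable):

# classif.cvglmnet lets cv.glmnet pick s internally; type.measure = "class"
# uses misclassification error as the internal CV criterion.
cv_lasso = makeLearner("classif.cvglmnet", alpha = 1, type.measure = "class")
r_cv = resample(cv_lasso, iris.task, resampling = makeResampleDesc("CV", iters = 3),
                measures = list(mlr::acc))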

@berndbischl (Member)

Didn't we have EXACTLY that bug before? And I fixed that? Is this due to partial matching in R?

@mb706 (Contributor) commented Nov 6, 2018

I don't know if this was here before, but the problem here is unrelated to partial matching. The bug happens because predictLearner.<LRN> in most (maybe all) cases doesn't use the .learner object to determine parameters, but uses the ... parameters passed in. See e.g. the glmnet learner. TuneParams should call predictLearner(lrn, <xxx>) with the changed arguments.
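
A rough, purely illustrative sketch of what "with the changed arguments" could mean here (not the actual patch later merged in PR #2479): the tuned values have to overwrite whatever stale values are sitting in ... before the call is forwarded.

# Illustrative only: merge the tuned values over the stale dots before forwarding.
merge_tuned_args = function(tuned_par_vals, ...) {
  dots = list(...)
  dots[names(tuned_par_vals)] = tuned_par_vals   # tuned values take precedence
  dots
}

merge_tuned_args(list(s = 0.0039), s = 100)      # $s is now 0.0039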

@MartinKlm (Author)

Thanks for the quick reply and the hint with cvglmnet!
Keep up the great work!
Martin

@berndbischl (Member)

I fixed it yesterday on the train. Need to upload and test though. Thx a lot for this report. This was a somewhat weird bug

@jakob-r (Member) commented Nov 9, 2018

@berndbischl Do you have a PR?

@berndbischl (Member)

yes here: PR #2479

@pat-s (Member) commented Apr 15, 2019

Merged in 7ea4a57.

pat-s closed this as completed Apr 15, 2019