Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

enable setting stop_iter with brulee mlp() #1051

Merged
merged 1 commit into from
Jan 16, 2024
Merged

enable setting stop_iter with brulee mlp() #1051

merged 1 commit into from
Jan 16, 2024

Conversation

simonpcouch
Copy link
Contributor

Closes #1050.🏄

stop_iter had been mistakenly registered as a model argument for mlp().

library(tidymodels)

set.seed(1)
mlp(hidden_units = 5, epochs = 100) %>% 
  set_engine("brulee", stop_iter = 5)  %>% 
  set_mode("regression") %>%
  fit(price ~ latitude + longitude, data = Sacramento)
#> parsnip model object
#> 
#> Multilayer perceptron
#> 
#> relu activation
#> 5 hidden units,  21 model parameters
#> 932 samples, 2 features, numeric outcome 
#> weight decay: 0.001 
#> dropout proportion: 0 
#> batch size: 839 
#> learn rate: 0.01 
#> scaled validation loss after 100 epochs: 1.29


# tunable information is still preserved:
set.seed(1)
res <- 
  tune_grid(
    mlp(hidden_units = 3, epochs = 50) %>% 
      set_engine("brulee", stop_iter = tune())  %>% 
      set_mode("regression"),
    price ~ latitude + longitude,
    vfold_cv(Sacramento, v = 5),
    grid = 3
  )
#> → A | warning: A correlation computation is required, but `estimate` is constant and has 0
#>                standard deviation, resulting in a divide by 0 error. `NA` will be returned.
#> There were issues with some computations   A: x1
#> There were issues with some computations   A: x4
#> There were issues with some computations   A: x5
#> 

collect_metrics(res)
#> # A tibble: 6 × 7
#>   stop_iter .metric .estimator    mean     n  std_err .config             
#>       <int> <chr>   <chr>        <dbl> <int>    <dbl> <chr>               
#> 1        19 rmse    standard   1.46e+5     5  1.17e+4 Preprocessor1_Model1
#> 2        19 rsq     standard   2.61e-2     1 NA       Preprocessor1_Model1
#> 3        13 rmse    standard   2.74e+9     5  2.74e+9 Preprocessor1_Model2
#> 4        13 rsq     standard   8.45e-2     1 NA       Preprocessor1_Model2
#> 5         5 rmse    standard   2.99e+9     5  1.88e+9 Preprocessor1_Model3
#> 6         5 rsq     standard   1.14e-1     2  2.74e-2 Preprocessor1_Model3

Created on 2024-01-16 with reprex v2.0.2

Probably worth writing up a quick check to see if we've done this anywhere else. :)

Copy link
Member

@topepo topepo left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks great!

@topepo topepo merged commit 2752453 into main Jan 16, 2024
7 checks passed
@topepo topepo deleted the iter-1050 branch January 16, 2024 19:54
Copy link

This pull request has been automatically locked. If you believe you have found a related problem, please file a new issue (with a reprex: https://reprex.tidyverse.org) and link to this issue.

@github-actions github-actions bot locked and limited conversation to collaborators Jan 31, 2024
Sign up for free to subscribe to this conversation on GitHub. Already have an account? Sign in.
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

setting stop_iter with brulee model
2 participants