Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for 'reduce_on_plateau' learning rate scheduler #120

Merged
merged 9 commits into from
Jun 10, 2023

Conversation

SvenVw
Copy link
Contributor

@SvenVw SvenVw commented Jun 2, 2023

Hi there! I've created this pull request to enhance the functionality of the package by adding support for the reduce_on_plateau learning rate scheduler. Through my experience with other networks, I've observed that this particular scheduler often leads to improved model performance.

Currently, the documentation states that for the lr_scheduler parameter in tabnet_config, any other torch::lr_scheduler function can be passed. However, when I attempted to do so, I encountered an error. The issue arises because during (pre)training, only scheduler$step() is supported, while reduce_on_plateau requires the loss as input and thus needs scheduler$step(loss).

By submitting this pull request, I aim to introduce this missing functionality to this nice and usefull package 😃. Please let me know if there's anything else I can provide or do to support this pull request.

@SvenVw SvenVw changed the title Add scheduler reduce on plateau Add support for 'reduce_on_plateau' learning rate scheduler Jun 2, 2023
@dfalbel dfalbel requested a review from cregouby June 3, 2023 07:50
@cregouby
Copy link
Collaborator

cregouby commented Jun 3, 2023

Hello @SvenVw

Thanks for this work ! And very happy to see the package being used !

Currently the 4 tests impacted in test_pretraining.R and test-hardhat_parameters.R:97:3 are failing but I don't think it is a big deal to fix them :

It seems to me that the condition in https://github.com/SvenVw/tabnet/blob/f5cefa41cb7886b314d5555fa2f004663001205f/R/model.R#L545-L551 could be rework to first manage the case if (config$lr_scheduler == "reduce_on_plateau") { and use the else { for the other lr_shedulers.

Then there is a typo in your 3rd test (see the comment)

Finally, would you add your name as contributor in the author list of "DESCRIPTION" file ? and add a bullet mentioning the new feature in the NEWS.md file ?

Thanks a lot
Christophe

tests/testthat/test-pretraining.R Outdated Show resolved Hide resolved
tests/testthat/test-pretraining.R Show resolved Hide resolved
@cregouby cregouby added the enhancement New feature or request label Jun 5, 2023
@SvenVw
Copy link
Contributor Author

SvenVw commented Jun 6, 2023

Hi Christophe,

Thanks for your feedback! I have modified the if statement so that it checks if the step function has the argument metrics present or not. In that case it still supports if the user provides NULL or a torch::lr_scheduler function. Otherwise if you check if config$lr_scheduler == 'reduce_on_plateau you get an error if lr_scheduler is NULL. This implementation seems a little bit cumbersome to me, so feel free to suggest an improvement.

@codecov
Copy link

codecov bot commented Jun 6, 2023

Codecov Report

Merging #120 (7bd9c4f) into main (7beb0fe) will decrease coverage by 0.08%.
The diff coverage is 83.33%.

❗ Current head 7bd9c4f differs from pull request most recent head 705be28. Consider uploading reports for the commit 705be28 to get more accurate results

@@            Coverage Diff             @@
##             main     #120      +/-   ##
==========================================
- Coverage   88.02%   87.95%   -0.08%     
==========================================
  Files          10       10              
  Lines        1127     1137      +10     
==========================================
+ Hits          992     1000       +8     
- Misses        135      137       +2     
Impacted Files Coverage Δ
R/model.R 94.54% <83.33%> (-0.22%) ⬇️
R/pretraining.R 94.96% <83.33%> (-0.49%) ⬇️

📣 We’re building smart automated test selection to slash your CI/CD build times. Learn more

@cregouby cregouby merged commit 685fb62 into mlverse:main Jun 10, 2023
@cregouby
Copy link
Collaborator

Thanks a lot @SvenVw for this !

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
enhancement New feature or request
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants