Add support for 'reduce_on_plateau' learning rate scheduler #120
Conversation
Hello @SvenVw, thanks for this work! And very happy to see the package being used!

Currently, 4 tests are impacted. It seems to me that the condition in https://github.com/SvenVw/tabnet/blob/f5cefa41cb7886b314d5555fa2f004663001205f/R/model.R#L545-L551 could be reworked to first manage the reduce_on_plateau case. Then there is a typo in your 3rd test (see the comment). Finally, would you add your name as a contributor to the author list in the "DESCRIPTION" file, and add a bullet mentioning the new feature in the NEWS.md file?

Thanks a lot!
Hi Christophe, thanks for your feedback! I have modified the if statement so that it checks whether the step function has the required argument.
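For context, here is a minimal sketch of that kind of check in R. It is not the actual code from R/model.R; the step_scheduler helper is hypothetical, and the metrics argument name (mirroring PyTorch's ReduceLROnPlateau) is an assumption:

```r
# Hypothetical helper, sketching the dispatch described above:
# schedulers like lr_reduce_on_plateau need the current loss,
# while most other schedulers take no argument in step().
step_scheduler <- function(scheduler, loss) {
  if ("metrics" %in% names(formals(scheduler$step))) {
    scheduler$step(loss)  # reduce_on_plateau-style schedulers
  } else {
    scheduler$step()      # plain schedulers, e.g. torch::lr_step
  }
}
```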
Codecov Report
@@            Coverage Diff             @@
##             main     #120      +/-   ##
==========================================
- Coverage   88.02%   87.95%   -0.08%
==========================================
  Files          10       10
  Lines        1127     1137      +10
==========================================
+ Hits          992     1000       +8
- Misses        135      137       +2
Thanks a lot @SvenVw for this!
Hi there! I've created this pull request to enhance the functionality of the package by adding support for the reduce_on_plateau learning rate scheduler. Through my experience with other networks, I've observed that this particular scheduler often leads to improved model performance.

Currently, the documentation states that for the lr_scheduler parameter in tabnet_config, any other torch::lr_scheduler function can be passed. However, when I attempted to do so, I encountered an error. The issue arises because during (pre)training only scheduler$step() is supported, while reduce_on_plateau requires the loss as input and thus needs scheduler$step(loss).

By submitting this pull request, I aim to introduce this missing functionality to this nice and useful package 😃. Please let me know if there's anything else I can provide or do to support this pull request.
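For illustration, here is a hedged sketch of the usage this PR is meant to enable. The tabnet_fit/tabnet_config calls follow the package's documented interface as I understand it, but the dataset, formula, and hyperparameter values are placeholders, not taken from the PR:

```r
library(tabnet)

# Placeholder data: any data frame with a numeric outcome will do.
df <- data.frame(y = rnorm(100), x1 = rnorm(100), x2 = rnorm(100))

# Pass a reduce-on-plateau scheduler through tabnet_config(); with this
# change, (pre)training can call scheduler$step(loss) for schedulers
# that need the current loss, instead of failing on scheduler$step().
config <- tabnet_config(
  epochs = 10,
  lr_scheduler = torch::lr_reduce_on_plateau
)

fit <- tabnet_fit(y ~ ., data = df, config = config)
```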