Refactored fit() interface #1039
Conversation
Hey @ourownstory @noxan, this PR would bring some interface changes (remove minimal, add metrics and checkpointing). Should we provide the metrics and checkpointing args in NeuralProphet or in the fit() function?
@karl-richter I like your approach, and I also agree with moving the metrics and checkpoints to where they actually matter (in the fit function) - it would also make it very clear that if people disable metrics, they cannot expect them as a return value.
Interface Change

```python
# 0.4.2
def fit(
    self,
    df,
    freq="auto",
    validation_df=None,
    progress="bar",
    minimal=False,
)

# 0.5.0
def fit(
    self,
    df: pd.DataFrame,
    freq: str = "auto",
    validation_df: pd.DataFrame = None,
    # Control the training process (required once we support continuous
    # training, since we might just want to train a few additional epochs)
    epochs: int = None,
    batch_size: int = None,
    learning_rate: float = None,
    early_stopping: bool = False,  # new with Lightning
    # Control the verbosity of the training; `minimal=True` deactivates all of the below
    minimal: bool = False,
    metrics: bool = None,  # moved from __init__
    progress="bar",
    checkpointing: bool = False,  # new with Lightning
    # Activate continued training
    continue_training: bool = False,  # new with Lightning
    num_workers=0,  # distributed training
)
```
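To illustrate the proposed interface, here is a hypothetical stand-in for the 0.5.0 `fit()` signature. The parameter names follow the PR; the body is an illustrative sketch (not NeuralProphet's actual implementation) of how `minimal=True` could remain a shortcut that disables metrics, progress output, and checkpointing.

```python
# Hypothetical sketch of the proposed 0.5.0 fit() interface.
# The body only demonstrates the assumed `minimal` shortcut semantics.
def fit(
    df,
    freq="auto",
    validation_df=None,
    epochs=None,
    batch_size=None,
    learning_rate=None,
    early_stopping=False,
    minimal=False,
    metrics=None,
    progress="bar",
    checkpointing=False,
    continue_training=False,
    num_workers=0,
):
    if minimal:
        # `minimal` overrides the individual verbosity-related flags.
        metrics, progress, checkpointing = False, None, False
    elif metrics is None:
        metrics = True  # assumed default when not set explicitly
    # With metrics disabled, no metrics can be expected as a return value.
    return {"metrics": metrics, "progress": progress, "checkpointing": checkpointing}

print(fit(df=None, minimal=True))
# → {'metrics': False, 'progress': None, 'checkpointing': False}
```

This would make the behavior explicit at the call site: a user who passes `minimal=True` sees all verbosity options resolved to off, rather than silently losing the metrics return value.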
Codecov Report
```diff
@@           Coverage Diff           @@
##             main    #1039   +/-  ##
======================================
- Coverage   90.20%   89.98%  -0.22%
======================================
  Files          21       21
  Lines        4757     4783     +26
======================================
+ Hits         4291     4304     +13
- Misses        466      479     +13
```
Model Benchmark
Added preparation for distributed training on request by @judussoari. This includes the
@karl-richter Many thanks for the great work. I'm still working on the review. For the next pull requests it would be great to split this into several smaller ones, as it is hard to review everything at once and especially to keep track of which changes relate to what...
Happy to have a chat with you to explain this in detail, I'll try to conclude the review soon :)
@karl-richter Overall a great set of changes. Might be best to have a short call on this. We should consider extracting changes into separate PRs and then review and merge them individually. Especially with the release of
Co-authored-by: Richard Stromer <[email protected]>
@noxan I addressed most of our comments:
LGTM
🔬 Background
Previously, `minimal` used a separate function to `_train()`, called `_train_minimal()`, that did not contain all of the above-mentioned features and was "minimal". This was merged into the `_train()` function, adding `minimal` as a parameter to it, alongside the new `progress`, `metrics` and `checkpointing` parameters. The `minimal` mode further persists, but only remains a shortcut to set the other parameters to false.

🔮 Key changes
- Removed the `collect_metrics` parameter in `__init__`, since this information can only be provided in the fit method in the future
- Added the `metrics` and `checkpointing` parameters to `fit()` and subsequent functions

📋 Review Checklist
Please make sure to follow our best practices in the Contributing guidelines.