Lightning Migration #837
I have a really quick comment. To be able to visualize the model architecture, we need to make the following change in the training_step method: instead of
we need to add:
Also, in the tutorial network_architecture_visualization.ipynb instead of:
we need:
Model Benchmark
Model Training: PeytonManning, YosemiteTemps, AirPassengers
LGTM. Great work!!
All points that we discussed can be addressed in later PRs.
v0 Todos
minimal mode, e.g. _train_minimal(...)
_train(...)
v1 Todos
Support the self.metrics.add_specific_target function when highlight_forecast_step_n is defined
v2 Todos (separate PRs)
Changes
Guiding idea: Migrate from using plain PyTorch to the PyTorch Lightning framework.
Consequent changes:
Before
The training logic is contained in forecaster.py, which manually iterates through epochs and batches. The train_epoch() function directly calls forward() on the TimeNet model. Optimization happens manually in train_epoch().
After
The training loop is abstracted using the Lightning training logic. A Lightning Trainer object is initialized and runs the training loop automatically. Calling the fit() method on the Lightning trainer executes the training_step() function of the model for each epoch and batch. Optimization and parameter updates happen automatically after the training_step() function returns, so the whole training logic is abstracted away. Lightning also provides useful tools such as a progress bar, a learning rate finder, early stopping, and GPU support.
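To make the shift in control flow concrete, here is a dependency-free toy sketch of the before/after pattern. It is not the code in this PR and does not use PyTorch or Lightning: ToyModel, ToyTrainer, and the data are hypothetical stand-ins, and the gradient is computed by hand only so the example runs without any installed packages.

```python
# Toy sketch in plain Python. All names are hypothetical stand-ins,
# not NeuralProphet or PyTorch Lightning APIs.


class ToyModel:
    """Fits y = w * x with squared error; stands in for the TimeNet model."""

    def __init__(self, lr=0.05):
        self.w = 0.0
        self.lr = lr

    def forward(self, x):
        return self.w * x

    def training_step(self, batch, batch_idx):
        # Return the mean squared error and its gradient w.r.t. w for one batch.
        n = len(batch)
        loss = sum((self.forward(x) - y) ** 2 for x, y in batch) / n
        grad = sum(2 * (self.forward(x) - y) * x for x, y in batch) / n
        return loss, grad


def train_epoch(model, batches):
    """'Before': the caller calls forward() and updates the weight itself."""
    for batch in batches:
        n = len(batch)
        grad = sum(2 * (model.forward(x) - y) * x for x, y in batch) / n
        model.w -= model.lr * grad  # manual optimization step


class ToyTrainer:
    """'After': the trainer owns the loop and calls the model's training_step()."""

    def __init__(self, max_epochs):
        self.max_epochs = max_epochs

    def fit(self, model, batches):
        for _ in range(self.max_epochs):
            for batch_idx, batch in enumerate(batches):
                _, grad = model.training_step(batch, batch_idx)
                model.w -= model.lr * grad  # update applied after training_step()


# Two batches sampled from y = 2 * x.
batches = [[(1.0, 2.0), (2.0, 4.0)], [(3.0, 6.0)]]

manual_model = ToyModel()
for _ in range(20):  # the caller manages the epochs
    train_epoch(manual_model, batches)

trainer_model = ToyModel()
ToyTrainer(max_epochs=20).fit(trainer_model, batches)
```

Both variants perform the same updates and recover a weight close to 2. The difference is only in who owns the loop: in the second variant the model exposes training_step() and the trainer decides when to call it and when to update, which is the role the Lightning Trainer takes over in this migration.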