-
Notifications
You must be signed in to change notification settings - Fork 0
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
update docs #18
Merged
Merged
update docs #18
Changes from all commits
Commits
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,4 @@ | ||
import time | ||
from typing import Union, Sequence, Callable, Optional | ||
|
||
import brainstate as bst | ||
|
@@ -74,6 +75,7 @@ def compile( | |
self, | ||
optimizer: bst.optim.Optimizer, | ||
metrics: Union[str, Sequence[str]] = None, | ||
measture_train_step_compile_time: bool = False, | ||
): | ||
""" | ||
Configures the trainer for training. | ||
|
@@ -138,6 +140,12 @@ def _loss_fun(): | |
self.fn_outputs_losses_test = bst.compile.jit(fn_outputs_losses_test) | ||
self.fn_train_step = bst.compile.jit(fn_train_step) | ||
|
||
if measture_train_step_compile_time: | ||
t0 = time.time() | ||
self._compile_training_step(self.batch_size) | ||
t1 = time.time() | ||
return self, t1 - t0 | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. issue (bug_risk): Inconsistent return types could cause bugs - method returns tuple when measuring time but self otherwise Consider using a consistent return type and providing the timing information through a different mechanism, such as a class attribute or logging. |
||
|
||
return self | ||
|
||
@utils.timing | ||
|
@@ -150,6 +158,7 @@ def train( | |
callbacks: Union[Callback, Sequence[Callback]] = None, | ||
model_restore_path: str = None, | ||
model_save_path: str = None, | ||
measture_train_step_time: bool = False, | ||
): | ||
""" | ||
Trains the trainer. | ||
|
@@ -177,6 +186,9 @@ def train( | |
model_save_path (String): Prefix of filenames created for the checkpoint. | ||
""" | ||
|
||
if measture_train_step_time: | ||
t0 = time.time() | ||
|
||
if self.metrics is None: | ||
raise ValueError("Compile the trainer before training.") | ||
|
||
|
@@ -210,8 +222,21 @@ def train( | |
training_display.summary(self.train_state) | ||
if model_save_path is not None: | ||
self.save(model_save_path, verbose=1) | ||
|
||
if measture_train_step_time: | ||
t1 = time.time() | ||
return self, t1 - t0 | ||
return self | ||
|
||
def _compile_training_step(self, batch_size=None): | ||
# get data | ||
self.train_state.set_data_train(*self.problem.train_next_batch(batch_size)) | ||
|
||
# train one batch | ||
self.fn_train_step.compile(self.train_state.X_train, | ||
self.train_state.y_train, | ||
**self.train_state.Aux_train) | ||
|
||
def _train(self, iterations, display_every, batch_size, callbacks): | ||
for i in range(iterations): | ||
callbacks.on_epoch_begin() | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
issue (typo): Parameter name contains a typo: 'measture' should be 'measure'
Suggested implementation:
You may need to: