Fixup all checkpointing examples #323
@sgugger 50/50 on whether to consider these "slow" tests or not. They add ~1.5 min to the CI.
LGTM! Just one question on the tests added (nice addition btw :-) )
Thanks a lot for fixing those!
tests/test_examples.py
with mock.patch.object(sys, "argv", testargs):
    checkpointing.main()
    self.assertTrue(os.path.exists(os.path.join(tmpdir, "epoch_0")))
with self.assertRaises(AssertionError):
    mocked_print.assert_any_call("epoch 0:", {"accuracy": 0.5, "f1": 0.0})
How are we sure we will get those values exactly?
This stems from the scheduler we use in the examples: it makes it impossible for the model to train quickly, so we always get an accuracy of 0.5 and an f1 of 0 for all of our epochs. That's why none of these example tests check whether we get "good" accuracy; they only test independent behavior.
But I dug deeper and found mock.ANY. In this case we only care about matching the "epoch *" text, not the results. So instead we have something like this:
dummy_results = {"accuracy": mock.ANY, "f1": mock.ANY}
with self.assertRaises(AssertionError):
    mocked_print.assert_any_call("epoch 0:", dummy_results)
Which helps me sleep much better at night
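For anyone reading along, here is a minimal standalone sketch of how mock.ANY behaves inside assert_any_call. It is not taken from the test suite: the metric values and the bare Mock standing in for the patched print are made up for illustration.

```python
from unittest import mock

# Stand-in for a patched print; the metric values below are made up.
mocked_print = mock.Mock()
mocked_print("epoch 0:", {"accuracy": 0.5123, "f1": 0.25})

# mock.ANY compares equal to anything, so only the "epoch 0:" text
# and the dictionary keys have to match.
dummy_results = {"accuracy": mock.ANY, "f1": mock.ANY}
mocked_print.assert_any_call("epoch 0:", dummy_results)

# A call that never happened still raises AssertionError.
try:
    mocked_print.assert_any_call("epoch 1:", dummy_results)
    epoch_1_printed = True
except AssertionError:
    epoch_1_printed = False
```

So the assertion depends only on which epoch label was printed, not on whatever the model happened to score.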
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You never know if your model might have learned a tiny something, so yes, much better :-)
Fix logic in all checkpointing examples
What does this add?
This PR fixes a number of bugs currently present in the save/load examples.
Who is it for?
Closes #322
Why is it needed?
As I was exploring a fix for #322 and writing tests, I noticed that some behaviors weren't working the way I would have expected.
It also didn't make logical sense to me that if we resume at epoch 1, the numbering starts at epoch 0 again (and thus, our checkpoint saves do as well!). So, that behavior had to change slightly.
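A rough sketch of the intended numbering, not the actual example code: the helper names below are hypothetical, and it assumes a checkpoint folder named epoch_N is written at the end of epoch N, so that resuming from it continues at epoch N + 1.

```python
import os


def starting_epoch_from_checkpoint(checkpoint_dir):
    """Hypothetical helper: read N from an 'epoch_N' folder, return N + 1."""
    name = os.path.basename(os.path.normpath(checkpoint_dir))
    if not name.startswith("epoch_"):
        raise ValueError(f"expected a folder named 'epoch_N', got {name!r}")
    # Assumption: epoch N is already finished, so training continues at N + 1.
    return int(name[len("epoch_"):]) + 1


def checkpoint_names(num_epochs, resume_from=None):
    """Hypothetical helper: checkpoint folders a run of num_epochs would write."""
    start = starting_epoch_from_checkpoint(resume_from) if resume_from else 0
    # Counting from `start` instead of 0 keeps a resumed run from
    # reusing the names of checkpoints that already exist.
    return [f"epoch_{epoch}" for epoch in range(start, num_epochs)]


fresh_run = checkpoint_names(3)
resumed_run = checkpoint_names(3, resume_from="output/epoch_0")
```

With these assumptions, a fresh three-epoch run writes epoch_0 through epoch_2, while a run resumed from epoch_0 writes only epoch_1 and epoch_2.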