Deprecate/compute on step #792
Conversation
Codecov Report
```diff
@@            Coverage Diff            @@
##           master    #792    +/-    ##
=========================================
- Coverage      95%     95%     -0%
=========================================
  Files         166     166
  Lines        6795    6796     +1
=========================================
- Hits         6469    6460     -9
- Misses        326     336    +10
```
@SkafteNicki mind checking the last two failing tests?

@SkafteNicki why was Thanks
Agree with @Alec-Stashevsky's comment. If I am reading the code correctly:

```python
def forward(self, *args: Any, **kwargs: Any) -> Any:
    """Automatically calls ``update()``.

    Returns the metric value over inputs if ``compute_on_step`` is True.
    """
    # add current step
    if self._is_synced:
        raise TorchMetricsUserError(
            "The Metric shouldn't be synced when performing ``update``. "
            "HINT: Did you forget to call ``unsync`` ?."
        )

    # global accumulation
    self.update(*args, **kwargs)

    if self.compute_on_step:
        self._to_sync = self.dist_sync_on_step
        # skip restore cache operation from compute as cache is stored below.
        self._should_unsync = False

        # save context before switch
        cache = {attr: getattr(self, attr) for attr in self._defaults}

        # call reset, update, compute, on single batch
        self._enable_grad = True  # allow grads for batch computation
        self.reset()
        self.update(*args, **kwargs)
        self._forward_cache = self.compute()

        # restore context
        for attr, val in cache.items():
            setattr(self, attr, val)
        self._is_synced = False

        self._should_unsync = True
        self._to_sync = True
        self._computed = None
        self._enable_grad = False
        return self._forward_cache
```

This used to be the
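The cache-reset-restore dance in `forward()` above can be illustrated with a minimal stand-in (a hypothetical `SumMetric`, not the real torchmetrics class; syncing and grad handling omitted): the metric accumulates globally, yet each call still returns the batch-local value.

```python
# Minimal sketch of the forward() pattern: accumulate globally, then cache
# state, reset, re-update on the single batch, compute, and restore the cache.
class SumMetric:
    def __init__(self):
        self.total = 0.0  # accumulated state (analogous to a metric "default")

    def update(self, value):
        self.total += value

    def compute(self):
        return self.total

    def reset(self):
        self.total = 0.0

    def forward(self, value):
        # global accumulation
        self.update(value)
        # save context before switch
        cache = {"total": self.total}
        # call reset, update, compute on the single batch
        self.reset()
        self.update(value)
        batch_value = self.compute()
        # restore context
        self.total = cache["total"]
        return batch_value

m = SumMetric()
print(m.forward(2.0))  # batch-local value: 2.0
print(m.forward(3.0))  # batch-local value: 3.0
print(m.compute())     # accumulated value: 5.0
```

This is why removing `compute_on_step` does not change what `update()`/`compute()` accumulate; it only changes whether `forward()` also computes the per-batch value.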
Hi @Alec-Stashevsky, instead of

```python
metric = Metric(compute_on_step=False)
metric(input)
```

you should instead just do

```python
metric = Metric()
metric.update(input)
```

@fierval the change in 0.8 just has the effect that
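The migration can be sketched end-to-end with a hypothetical stand-in class (`AccumMetric`, not the real torchmetrics `Metric`): when you only want accumulation, the old `compute_on_step=False` call style is simply replaced by calling `update()` directly.

```python
# Hypothetical torchmetrics-style metric, to illustrate the migration:
# old: metric = Metric(compute_on_step=False); metric(batch)
# new: metric.update(batch), then compute() once at the end.
class AccumMetric:
    def __init__(self):
        self.total = 0.0

    def update(self, value):
        self.total += value  # accumulate only; no batch-level result

    def compute(self):
        return self.total

metric = AccumMetric()
for batch in [1.0, 2.0, 3.0]:
    metric.update(batch)       # replaces metric(batch) with compute_on_step=False
print(metric.compute())        # 6.0
```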
What does this PR do?

Fixes #789

Deprecates `compute_on_step` argument.

Before submitting
PR review
Anyone in the community is free to review the PR once the tests have passed.
If we didn't discuss your PR in GitHub issues, there's a high chance it will not be merged.
Did you have fun?
Make sure you had fun coding 🙃