-
Notifications
You must be signed in to change notification settings - Fork 322
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
gradient verification callback #465
Merged
Merged
Changes from all commits
Commits
Show all changes
40 commits
Select commit
Hold shift + click to select a range
b0e76b8
initial commit
awaelchli 44f8ade
docs cleanup
awaelchli 52a5134
isort
awaelchli 64c389b
black
awaelchli 8b8548e
top level imports
awaelchli 1b77ce5
rst docs
awaelchli 21cb8a1
update chlog
awaelchli b97a65a
isort again
awaelchli 9f2adcd
format
Borda 15a3240
Merge branch 'master' into feature/gradient-verification
awaelchli c7b610c
Apply suggestions from code review
awaelchli 028be19
fix import
awaelchli 188bf09
Merge remote-tracking branch 'origin/feature/gradient-verification' i…
awaelchli e0dc1fb
increase coverage
awaelchli 1b54855
don't skip tests that can partially run on cpu
awaelchli 04eb2d7
black format
awaelchli 6fa47a0
make bots happy
awaelchli 79d36dc
cleanup
awaelchli 1f2d250
more tests for full coverage
awaelchli 4365e9a
isort, black
awaelchli 7cb9f43
mypy complaining
awaelchli b8b2af2
remove unused import
awaelchli 4ff9630
stop complain
awaelchli beba0b9
try type ignore
awaelchli af7db67
try ignore
awaelchli 15e8991
try ignore
awaelchli 6f52175
try ignore
awaelchli b8ff39e
try ignore
awaelchli 590b722
stupid mypy
awaelchli df8c746
stupid mypy
awaelchli f4d9866
stupid mypy
awaelchli 791f4fa
stupid mypi
awaelchli b6402e7
stupid mypy
awaelchli 2638e3e
ugly yapf
awaelchli 4a1defc
Merge branch 'master' into feature/gradient-verification
awaelchli 3101e06
yapf :(
awaelchli 84d98bb
yapffffffff
awaelchli 224de45
chlog
Borda 8439bd9
Apply suggestions from code review
Borda ae63dad
yapf
Borda File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,123 @@ | ||
# type: ignore | ||
from abc import abstractmethod | ||
from copy import deepcopy | ||
from typing import Any, Optional | ||
|
||
import torch.nn as nn | ||
from pytorch_lightning import Callback | ||
from pytorch_lightning.core.lightning import LightningModule | ||
from pytorch_lightning.utilities import move_data_to_device, rank_zero_warn | ||
|
||
|
||
class VerificationBase: | ||
""" | ||
Base class for model verification. | ||
All verifications should run with any :class:`torch.nn.Module` unless otherwise stated. | ||
""" | ||
|
||
def __init__(self, model: nn.Module): | ||
""" | ||
Arguments: | ||
model: The model to run verification for. | ||
""" | ||
super().__init__() | ||
self.model = model | ||
|
||
@abstractmethod | ||
def check(self, *args: Any, **kwargs: Any) -> bool: | ||
""" Runs the actual test on the model. All verification classes must implement this. | ||
|
||
Arguments: | ||
*args: Any positional arguments that are needed to run the test | ||
*kwargs: Keyword arguments that are needed to run the test | ||
|
||
Returns: | ||
`True` if the test passes, and `False` otherwise. Some verifications can only be performed | ||
with a heuristic accuracy, thus the return value may not always reflect the true state of | ||
the system in these cases. | ||
""" | ||
|
||
def _get_input_array_copy(self, input_array: Optional[Any] = None) -> Any: | ||
""" | ||
Returns a deep copy of the example input array in cases where it is expected that the | ||
input changes during the verification process. | ||
|
||
Arguments: | ||
input_array: The input to clone. | ||
""" | ||
if input_array is None and isinstance(self.model, LightningModule): | ||
input_array = self.model.example_input_array | ||
input_array = deepcopy(input_array) | ||
|
||
if isinstance(self.model, LightningModule): | ||
input_array = self.model.transfer_batch_to_device(input_array, self.model.device) | ||
else: | ||
input_array = move_data_to_device(input_array, device=next(self.model.parameters()).device) | ||
|
||
return input_array | ||
|
||
def _model_forward(self, input_array: Any) -> Any: | ||
""" | ||
Feeds the input array to the model via the ``__call__`` method. | ||
|
||
Arguments: | ||
input_array: The input that goes into the model. If it is a tuple, it gets | ||
interpreted as the sequence of positional arguments and is passed in by tuple unpacking. | ||
If it is a dict, the contents get passed in as named parameters by unpacking the dict. | ||
Otherwise, the input array gets passed in as a single argument. | ||
|
||
Returns: | ||
The output of the model. | ||
""" | ||
if isinstance(input_array, tuple): | ||
return self.model(*input_array) | ||
if isinstance(input_array, dict): | ||
return self.model(**input_array) | ||
return self.model(input_array) | ||
|
||
|
||
class VerificationCallbackBase(Callback): | ||
""" | ||
Base class for model verification in form of a callback. | ||
This type of verification is expected to only work with | ||
:class:`~pytorch_lightning.core.lightning.LightningModule` and will take the input array | ||
from :attr:`~pytorch_lightning.core.lightning.LightningModule.example_input_array` if needed. | ||
""" | ||
|
||
def __init__(self, warn: bool = True, error: bool = False) -> None: | ||
""" | ||
Arguments: | ||
warn: If ``True``, prints a warning message when verification fails. Default: ``True``. | ||
error: If ``True``, prints an error message when verification fails. Default: ``False``. | ||
""" | ||
self._raise_warning = warn | ||
self._raise_error = error | ||
|
||
def message(self, *args: Any, **kwargs: Any) -> str: | ||
""" | ||
The message to be printed when the model does not pass the verification. | ||
If the message for warning and error differ, override the | ||
:meth:`warning_message` and :meth:`error_message` | ||
methods directly. | ||
|
||
Arguments: | ||
*args: Any positional arguments that are needed to construct the message. | ||
**kwargs: Any keyword arguments that are needed to construct the message. | ||
|
||
Returns: | ||
The message as a string. | ||
""" | ||
|
||
def warning_message(self, *args: Any, **kwargs: Any) -> str: | ||
""" The warning message printed when the model does not pass the verification. """ | ||
return self.message(*args, **kwargs) | ||
|
||
def error_message(self, *args: Any, **kwargs: Any) -> str: | ||
""" The error message printed when the model does not pass the verification. """ | ||
return self.message(*args, **kwargs) | ||
|
||
def _raise(self, *args: Any, **kwargs: Any) -> None: | ||
if self._raise_error: | ||
raise RuntimeError(self.error_message(*args, **kwargs)) | ||
if self._raise_warning: | ||
rank_zero_warn(self.warning_message(*args, **kwargs)) |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@akihironitta this mypy tool is not very smart :( It is not liking that the subclass has different args than here. I want this abstract method to be as general as possible and not specify concrete arguments and types. It should only act as an interface. Any suggestions how to proceed? I believe I have to add
# type: ignore
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I found a related issue in mypy repo, and it basically says that an easy workaround would be to add
# type: ignore[override]
, so shall we just ignore it then?python/mypy#1237 (comment)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
no chance, it won't work. I tried to put it everywhere: at the top of method, on the same line as signature, below it, on top of the class, on top of the file, both in subclass and superclass.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I had to spam
#type: ignore
everywhere to make it work.This mypy tool, I don't understand it. I spent hours now studying the docs of this tool trying to figure out what the error messages mean. I tried everything, but the
type: ignore
are unavoidable, yet they pollute the code unnecessarily. It's unbelievably frustrating, and I can no longer work on this, sorry.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I completely understand your frustration. Let's ignore them all for now.