Learning rate scheduler can be checkpointed (#218, Closes #217)
The HasStateDict type was changed to include torch.optim.lr_scheduler._LRScheduler, which was missing before; as a result, the checkpointer did not save/load the state_dict of LR schedulers.
georgeyiasemis committed Jul 31, 2022
1 parent acaf797 commit dc8bb1c
Showing 2 changed files with 2 additions and 2 deletions.
direct/types.py (2 changes: 1 addition & 1 deletion)
@@ -10,4 +10,4 @@
 Number = Union[float, int]
 PathOrString = Union[pathlib.Path, str]
 FileOrUrl = NewType("FileOrUrl", PathOrString)
-HasStateDict = Union[nn.Module, torch.optim.Optimizer, GradScaler]
+HasStateDict = Union[nn.Module, torch.optim.Optimizer, torch.optim.lr_scheduler._LRScheduler, GradScaler]
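For context, a minimal sketch of why the union matters: every member of HasStateDict exposes state_dict()/load_state_dict(), so a checkpointer can serialize all of them uniformly, and adding _LRScheduler means scheduler state now round-trips too. The save_checkpoint/load_checkpoint helpers below are illustrative assumptions, not DIRECT's actual checkpointer API.

import pathlib
from typing import Dict, Union

import torch
import torch.nn as nn
from torch.cuda.amp import GradScaler

# The union as defined in direct/types.py after this commit.
HasStateDict = Union[nn.Module, torch.optim.Optimizer, torch.optim.lr_scheduler._LRScheduler, GradScaler]

def save_checkpoint(path: pathlib.Path, objects: Dict[str, HasStateDict]) -> None:
    # Hypothetical helper: every member of the union provides state_dict(),
    # so the LR scheduler is serialized alongside model, optimizer, and scaler.
    torch.save({name: obj.state_dict() for name, obj in objects.items()}, path)

def load_checkpoint(path: pathlib.Path, objects: Dict[str, HasStateDict]) -> None:
    # Hypothetical helper: restore each object's state in place.
    checkpoint = torch.load(path)
    for name, obj in objects.items():
        obj.load_state_dict(checkpoint[name])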
setup.py (2 changes: 1 addition & 1 deletion)
@@ -39,7 +39,7 @@
 "numpy>=1.21.2",
 "h5py==3.3.0",
 "omegaconf==2.1.1",
-"torch>=1.10.2",
+"torch==1.11.0",
 "torchvision",
 "scikit-image>=0.19.0",
 "scikit-learn>=1.0.1",
