Skip to content

Commit

Permalink
Learning rate scheduler can be checkpointed (#218, Closes #217)
Browse files Browse the repository at this point in the history
HasStateDict type changed to include torch.optim.lr_scheduler._LRScheduler which was missing before, causing the checkpointer to not save/load the state_dict of LRSchedulers.
  • Loading branch information
georgeyiasemis authored Jul 13, 2022
1 parent 702b8f4 commit c058f4c
Show file tree
Hide file tree
Showing 3 changed files with 3 additions and 3 deletions.
2 changes: 1 addition & 1 deletion direct/nn/openvino/openvino_model.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
import io

import torch
from torch import nn
from openvino.inference_engine import IECore
from openvino_extensions import get_extensions_path
from torch import nn


class InstanceNorm2dFunc(torch.autograd.Function):
Expand Down
2 changes: 1 addition & 1 deletion direct/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,4 +10,4 @@
Number = Union[float, int]
PathOrString = Union[pathlib.Path, str]
FileOrUrl = NewType("FileOrUrl", PathOrString)
HasStateDict = Union[nn.Module, torch.optim.Optimizer, GradScaler]
HasStateDict = Union[nn.Module, torch.optim.Optimizer, torch.optim.lr_scheduler._LRScheduler, GradScaler]
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@
"numpy>=1.21.2",
"h5py==3.3.0",
"omegaconf==2.1.1",
"torch>=1.10.2",
"torch==1.11.0",
"torchvision",
"scikit-image>=0.19.0",
"scikit-learn>=1.0.1",
Expand Down

0 comments on commit c058f4c

Please sign in to comment.