diff --git a/direct/nn/openvino/openvino_model.py b/direct/nn/openvino/openvino_model.py index 793cd0ac..1fbf50fb 100644 --- a/direct/nn/openvino/openvino_model.py +++ b/direct/nn/openvino/openvino_model.py @@ -1,9 +1,9 @@ import io import torch -from torch import nn from openvino.inference_engine import IECore from openvino_extensions import get_extensions_path +from torch import nn class InstanceNorm2dFunc(torch.autograd.Function): diff --git a/direct/types.py b/direct/types.py index 2858810e..4faa1504 100644 --- a/direct/types.py +++ b/direct/types.py @@ -10,4 +10,4 @@ Number = Union[float, int] PathOrString = Union[pathlib.Path, str] FileOrUrl = NewType("FileOrUrl", PathOrString) -HasStateDict = Union[nn.Module, torch.optim.Optimizer, GradScaler] +HasStateDict = Union[nn.Module, torch.optim.Optimizer, torch.optim.lr_scheduler._LRScheduler, GradScaler] diff --git a/setup.py b/setup.py index 18b1275a..22183760 100644 --- a/setup.py +++ b/setup.py @@ -39,7 +39,7 @@ "numpy>=1.21.2", "h5py==3.3.0", "omegaconf==2.1.1", - "torch>=1.10.2", + "torch==1.11.0", "torchvision", "scikit-image>=0.19.0", "scikit-learn>=1.0.1",