-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtest.py
44 lines (39 loc) · 1.2 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
import sys
import torch
import torch.nn as nn
import torch.optim as optim
from pytorch_optimizer import FAdam
from torch.utils.data import TensorDataset, DataLoader
from run_benchmark import run_benchmark
from optimizers.sophia import SophiaG
#from optimizers.ngd import NGD
#from optimizers.adahessian import Adahessian
#from optimizers.lbfgsnew import LBFGSNew
from datasets2.cifar10 import CIFAR10
import torchvision
import torchvision.transforms as transforms
import torchvision.models as models
# Pick the compute device once. set_default_device makes tensors created
# without an explicit device (e.g. inside model/dataset constructors) use it.
device = "cuda" if torch.cuda.is_available() else "cpu"
torch.set_default_device(device)
# NOTE(review): this seeds the global RNG with 0 regardless of the `seed`
# command-line argument, so the model initialisation below always uses seed 0.
# Presumably run_benchmark re-seeds with `seed` for training — confirm.
torch.manual_seed(0)

if len(sys.argv) != 8:
    # Report misuse on stderr with a non-zero status so calling scripts can
    # detect it; sys.exit (unlike the site-injected exit()) always exists.
    print("python test.py OptimizerClass ModelClass CriterionClass Dataset lr epochs seed",
          file=sys.stderr)
    sys.exit(1)

# Positional arguments: four Python expressions, then lr, epochs, seed.
optimizer_str, model_str, criterion_str, ds_str, lr_str = sys.argv[1:6]
epochs = int(sys.argv[6])
seed = int(sys.argv[7])

# SECURITY: eval() runs arbitrary Python taken from the command line. It is
# used so callers can name anything importable here (e.g. `optim.Adam`,
# `nn.CrossEntropyLoss`, `FAdam`, `SophiaG`, `CIFAR10`). That is only
# acceptable because argv comes from the local operator; never feed these
# arguments from an untrusted source.
ModelClass = eval(model_str)
OptimizerClass = eval(optimizer_str)
CriterionClass = eval(criterion_str)
lr = float(lr_str)
DSLoader = eval(ds_str)

# Instantiate dataset and model with no arguments; the `.label` attributes
# record the CLI names (presumably for run_benchmark's reporting — verify).
ds = DSLoader()
ds.label = ds_str
model = ModelClass()
model.label = model_str
model.to(device)

# Optimizer and criterion are passed as classes (uninstantiated), together
# with the run configuration.
run_benchmark(model, OptimizerClass, CriterionClass, ds, epochs, seed, lr)