-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathmain.py
18 lines (17 loc) · 875 Bytes
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
import config as c
from train import train
from utils import load_datasets, make_dataloaders, get_loaders
import os
# Point PyTorch's cache root (TORCH_HOME) at a project-local directory.
# The path is built with os.path.join so it is valid on POSIX as well as on
# Windows; on Windows the resulting value is still 'models\\EfficientNet'.
os.environ['TORCH_HOME'] = os.path.join('models', 'EfficientNet')

# Alternative class list, kept so the script can be switched to the other
# dataset (see the commented-out "sensory dataset" loaders in the loop below).
# class_name = ['bottle', 'cable', 'capsule', 'carpet', 'grid', 'hazelnut',
#               'leather', 'metal_nut', 'pill', 'screw', 'tile', 'toothbrush',
#               'transistor', 'wood', 'zipper']
class_name = ['airplane', 'automobile', 'bird', 'cat', 'deer',
              'dog', 'frog', 'horse', 'ship', 'truck']

# NOTE(review): if the loaders spawn worker processes on Windows, this loop
# may need to live under `if __name__ == '__main__':` -- confirm against utils.
for name in class_name:
    # Record the active class on the shared config module. train() is only
    # given the loaders, so it presumably reads c.class_name -- TODO confirm.
    c.class_name = name

    # sensory dataset
    # train_set, test_set = load_datasets(c.dataset_path, c.class_name)
    # train_loader, test_loader = make_dataloaders(train_set, test_set)

    # semantic dataset
    train_loader, test_loader = get_loaders(c.dataset, c.class_name, c.batch_size)

    # One training run per class; `model` holds the result of the latest run.
    model = train(train_loader, test_loader)
    # model = eval(train_loader, test_loader)