-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrain_NAS_utils.py
65 lines (56 loc) · 2.02 KB
/
train_NAS_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
from super_gradients.training.losses import PPYoloELoss
from super_gradients.training.metrics import DetectionMetrics_050
from super_gradients.training.models.detection_models.pp_yolo_e import PPYoloEPostPredictionCallback
from super_gradients.training import models
from super_gradients.training import Trainer
from super_gradients.training import dataloaders
from super_gradients.training.dataloaders.dataloaders import coco_detection_yolo_format_train, coco_detection_yolo_format_val
# Directory name handed to the Trainer as its checkpoint root.
CHECKPOINT_DIR = 'checkpoints'

# Module-level Trainer for the 'yolonas_AI' experiment; it is created at
# import time, so importing this module instantiates it.
trainer = Trainer(
    ckpt_root_dir=CHECKPOINT_DIR,
    experiment_name='yolonas_AI',
)
# Central description of the dataset: one root directory, an images/labels
# directory pair for each of the train, val and test splits, and the list of
# class names. The loader definitions in this file read these entries by key.
# NOTE(review): the per-split paths are presumably resolved relative to
# 'data_dir' by the dataloader factories -- confirm against super_gradients.
dataset_params = {
    # Dataset root.
    'data_dir': 'data',

    # Training split.
    'train_images_dir': 'train/train_img',
    'train_labels_dir': 'train/trainyolo',

    # Validation split.
    'val_images_dir': 'val/val_img',
    'val_labels_dir': 'val/valyolo',

    # Test split.
    'test_images_dir': 'test/test_img',
    'test_labels_dir': 'test/testyolo',

    # Class names.
    # NOTE(review): order presumably has to match the class indices used in
    # the label files -- verify before reordering.
    'classes': [
        'Pedestrian',
        'Cyclist',
        'Car',
        'Truck',
        'Tram',
        'Tricycle',
    ],
}
# Training data: the train-format factory pointed at the train image/label
# directories from dataset_params, loading batches of 7 with 2 workers.
# The dataloader_params dict is a fresh literal for this call only.
train_data = coco_detection_yolo_format_train(
    dataloader_params={
        'batch_size': 7,
        'num_workers': 2,
    },
    dataset_params={
        'data_dir': dataset_params['data_dir'],
        'images_dir': dataset_params['train_images_dir'],
        'labels_dir': dataset_params['train_labels_dir'],
        'classes': dataset_params['classes'],
    },
)
# Validation data: the val-format factory pointed at the val image/label
# directories from dataset_params, loading batches of 7 with 2 workers.
# The dataloader_params dict is a fresh literal for this call only.
val_data = coco_detection_yolo_format_val(
    dataloader_params={
        'batch_size': 7,
        'num_workers': 2,
    },
    dataset_params={
        'data_dir': dataset_params['data_dir'],
        'images_dir': dataset_params['val_images_dir'],
        'labels_dir': dataset_params['val_labels_dir'],
        'classes': dataset_params['classes'],
    },
)
# Test data: built with the same val-format factory as val_data, but pointed
# at the test image/label directories from dataset_params. Batches of 7 with
# 2 workers. The dataloader_params dict is a fresh literal for this call only.
test_data = coco_detection_yolo_format_val(
    dataloader_params={
        'batch_size': 7,
        'num_workers': 2,
    },
    dataset_params={
        'data_dir': dataset_params['data_dir'],
        'images_dir': dataset_params['test_images_dir'],
        'labels_dir': dataset_params['test_labels_dir'],
        'classes': dataset_params['classes'],
    },
)