-
Notifications
You must be signed in to change notification settings - Fork 15
/
engine.py
111 lines (86 loc) · 4.31 KB
/
engine.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
import math
import os
import sys
from typing import Iterable
import numpy as np
import copy
import itertools
import torch
import util.misc as utils
from datasets.hico_eval import HICOEvaluator
from datasets.vcoco_eval import VCOCOEvaluator
def train_one_epoch(model: torch.nn.Module, criterion: torch.nn.Module,
                    data_loader: Iterable, optimizer: torch.optim.Optimizer,
                    device: torch.device, epoch: int, max_norm: float = 0):
    """Train ``model`` for one pass over ``data_loader`` and return averaged metrics.

    Args:
        model: Network called as ``model(samples)``; switched to train mode here.
        criterion: Loss module called as ``criterion(outputs, targets)``. It returns a
            dict of named losses and exposes ``weight_dict`` mapping loss names to weights.
        data_loader: Yields ``(samples, targets)`` pairs; each target is a dict.
        optimizer: Optimizer stepped once per batch; the learning rate of its first
            param group is logged.
        device: Device that samples and target tensors are moved to.
        epoch: Epoch index, used only in the progress header.
        max_norm: If > 0, gradients are clipped to this total norm before each step.

    Returns:
        Dict mapping every logged meter name to its ``global_avg`` over the epoch.

    Note:
        Calls ``sys.exit(1)`` if the reduced, weighted loss becomes non-finite.
    """
    model.train()
    criterion.train()

    # The classification-error meter name depends on the criterion flavour: one that
    # exposes ``loss_labels`` reports 'class_error', otherwise 'obj_class_error'.
    error_key = 'class_error' if hasattr(criterion, 'loss_labels') else 'obj_class_error'

    logger = utils.MetricLogger(delimiter=" ")
    logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
    logger.add_meter(error_key, utils.SmoothedValue(window_size=1, fmt='{value:.2f}'))

    header = f'Epoch: [{epoch}]'
    log_interval = 10

    for samples, targets in logger.log_every(data_loader, log_interval, header):
        samples = samples.to(device)
        # Drop the 'filename' entry (presumably not a tensor, so it cannot be moved
        # with .to) and move every other target field to the device.
        targets = [
            {key: val.to(device) for key, val in tgt.items() if key != 'filename'}
            for tgt in targets
        ]

        outputs = model(samples)
        loss_dict = criterion(outputs, targets)
        weights = criterion.weight_dict
        # Gradients come from this process's weighted loss; only losses that have a
        # weight contribute.
        losses = sum(loss_dict[name] * weights[name] for name in loss_dict if name in weights)

        # Losses averaged across processes are used only for logging and the
        # finiteness check below.
        reduced = utils.reduce_dict(loss_dict)
        reduced_unscaled = {f'{name}_unscaled': val for name, val in reduced.items()}
        reduced_scaled = {name: val * weights[name] for name, val in reduced.items() if name in weights}
        loss_value = sum(reduced_scaled.values()).item()

        if not math.isfinite(loss_value):
            print(f"Loss is {loss_value}, stopping training")
            print(reduced)
            sys.exit(1)

        optimizer.zero_grad()
        losses.backward()
        if max_norm > 0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
        optimizer.step()

        logger.update(loss=loss_value, **reduced_scaled, **reduced_unscaled)
        logger.update(**{error_key: reduced[error_key]})
        logger.update(lr=optimizer.param_groups[0]["lr"])

    # Merge meter statistics from all processes before reporting.
    logger.synchronize_between_processes()
    print("Averaged stats:", logger)
    return {name: meter.global_avg for name, meter in logger.meters.items()}
@torch.no_grad()
def evaluate_hoi(dataset_file, model, postprocessors, data_loader, subject_category_id, device, args):
    """Run HOI inference over ``data_loader`` and score it with the dataset's evaluator.

    Args:
        dataset_file: Which evaluator to use; must be ``'hico'`` or ``'vcoco'``.
        model: Network called as ``model(samples)``; switched to eval mode here.
        postprocessors: Mapping whose ``'hoi'`` entry turns raw ``outputs`` plus the
            original image sizes into per-image predictions.
        data_loader: Yields ``(samples, targets)``; each target dict must carry
            ``'orig_size'`` and ``'id'``. Its ``dataset`` must expose ``correct_mat``
            and, for HICO, also ``rare_triplets`` and ``non_rare_triplets``.
        subject_category_id: Not used by this function; kept for interface compatibility.
        device: Device the input ``samples`` are moved to.
        args: Forwarded to ``HICOEvaluator``; for V-COCO only ``args.use_nms_filter`` is read.

    Returns:
        Whatever the selected evaluator's ``evaluate()`` returns.

    Raises:
        ValueError: If ``dataset_file`` is not supported. This is checked before any
            inference, so a bad value fails fast instead of after a full evaluation pass.
    """
    if dataset_file not in ('hico', 'vcoco'):
        raise ValueError(f"Unsupported dataset_file {dataset_file!r}; expected 'hico' or 'vcoco'")

    model.eval()

    metric_logger = utils.MetricLogger(delimiter=" ")
    header = 'Test:'

    preds = []
    gts = []
    for samples, targets in metric_logger.log_every(data_loader, 10, header):
        samples = samples.to(device)
        outputs = model(samples)
        orig_target_sizes = torch.stack([t["orig_size"] for t in targets], dim=0)
        results = postprocessors['hoi'](outputs, orig_target_sizes)

        # Collect predictions and targets from every process and flatten them into
        # per-image lists (utils.all_gather presumably returns one list per rank).
        preds.extend(itertools.chain.from_iterable(utils.all_gather(results)))
        gts.extend(itertools.chain.from_iterable(utils.all_gather(copy.deepcopy(targets))))
    metric_logger.synchronize_between_processes()

    # The same image id can appear more than once after gathering (presumably from
    # sampler padding across ranks). Keep only the first occurrence of each id, in the
    # original order. Membership is tested against a set: testing against the numpy
    # index array is an O(n) scan per element, i.e. O(n^2) overall.
    img_ids = [img_gts['id'] for img_gts in gts]
    _, first_occurrence = np.unique(img_ids, return_index=True)
    keep = set(first_occurrence.tolist())
    preds = [img_preds for i, img_preds in enumerate(preds) if i in keep]
    gts = [img_gts for i, img_gts in enumerate(gts) if i in keep]

    if dataset_file == 'hico':
        evaluator = HICOEvaluator(preds, gts, data_loader.dataset.rare_triplets,
                                  data_loader.dataset.non_rare_triplets, data_loader.dataset.correct_mat,
                                  args=args)
    else:  # 'vcoco' -- guaranteed by the validation at the top
        evaluator = VCOCOEvaluator(preds, gts, data_loader.dataset.correct_mat,
                                   use_nms_filter=args.use_nms_filter)

    return evaluator.evaluate()