-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtest.py
78 lines (60 loc) · 2.6 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
import argparse
import logging
import os
import pprint
import torch
from torch import nn
import torch.backends.cudnn as cudnn
from torch.optim import SGD
from torch.utils.data import DataLoader
import yaml
from mmseg.core.evaluation.metrics import mean_iou
from dataset.semi import SemiDataset
from model.semseg.segmentor import Segmentor
from supervised import evaluate
from util.ohem import ProbOhemCrossEntropy2d
from util.utils import count_params, init_log
from util.dist_helper import setup_distributed
from util.classes import CLASSES
# Command-line interface of the evaluation script.
parser = argparse.ArgumentParser(
    description='Semi-Supervised Semantic Segmentation',
)
# Required inputs: path to the YAML config and path to the checkpoint file.
parser.add_argument('--config', required=True, type=str)
parser.add_argument('--ckpt', required=True, type=str)
# Optional integer options; '--port' is forwarded to setup_distributed() by
# main(). NOTE(review): '--local_rank' is parsed but not read in the visible
# code -- main() takes the local rank from the LOCAL_RANK environment variable.
parser.add_argument('--local_rank', type=int, default=0)
parser.add_argument('--port', type=int, default=None)
def main():
    """Evaluate a checkpoint on the validation split and print IoU results.

    Steps, in order:
      1. Parse the command line and load the YAML config given by ``--config``.
      2. Initialise logging and the distributed environment.
      3. Build ``Segmentor(cfg)``, load the ``--ckpt`` state dict on CPU,
         convert BatchNorm layers to SyncBatchNorm, move the model to CUDA and
         wrap it in ``DistributedDataParallel``.
      4. Build the ``'val'`` split of ``SemiDataset`` with a
         ``DistributedSampler`` and run ``evaluate`` on it.
      5. On rank 0, print the mean IoU followed by the per-class IoU values
         sorted in ascending order.

    Raises:
        KeyError: if the ``LOCAL_RANK`` environment variable is not set.
    """
    args = parser.parse_args()

    # Use a context manager so the config file handle is closed right after
    # parsing instead of being left open until garbage collection.
    # NOTE(review): yaml.Loader is the full (unsafe) loader and can construct
    # arbitrary Python objects; only pass trusted config files here.
    with open(args.config, "r") as config_file:
        cfg = yaml.load(config_file, Loader=yaml.Loader)

    logger = init_log('global', logging.INFO)
    logger.propagate = 0

    # Only the rank is needed below; the second return value is unused.
    rank, _world_size = setup_distributed(port=args.port)

    if rank == 0:
        print('{}\n'.format(pprint.pformat(cfg)))

    cudnn.enabled = True
    cudnn.benchmark = True

    model = Segmentor(cfg)
    # Load the weights on CPU first; the model is moved to CUDA afterwards.
    model.load_state_dict(torch.load(args.ckpt, map_location='cpu'))
    if rank == 0:
        print('Total params: {:.1f}M\n'.format(count_params(model)))

    # The local rank comes from the environment, not from --local_rank.
    local_rank = int(os.environ["LOCAL_RANK"])
    model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
    model.cuda()
    model = torch.nn.parallel.DistributedDataParallel(
        model,
        device_ids=[local_rank],
        output_device=local_rank,
        find_unused_parameters=False,
    )

    valset = SemiDataset(cfg['dataset'], cfg['data_root'], 'val')
    valsampler = torch.utils.data.distributed.DistributedSampler(valset)
    valloader = DataLoader(valset, batch_size=1, pin_memory=True, num_workers=2,
                           drop_last=False, sampler=valsampler)

    # Cityscapes is evaluated with the 'sliding_window' mode; every other
    # dataset uses the 'original' mode.
    eval_mode = 'sliding_window' if cfg['dataset'] == 'cityscapes' else 'original'
    mIOU, iou_class = evaluate(model, valloader, eval_mode, cfg, local_rank)

    if rank == 0:
        print('***** Evaluation {} ***** >>>> meanIOU: {:.3f}\n'.format(eval_mode, mIOU))
        # Report classes from lowest to highest IoU, keeping the original
        # class index so the class name can be looked up.
        class_names = CLASSES[cfg['dataset']]
        ranked = sorted(enumerate(iou_class), key=lambda pair: pair[1])
        for cls_idx, iou in ranked:
            print('***** Evaluation ***** >>>> Class [{:} {:}] IoU: {:.2f}'.format(
                cls_idx, class_names[cls_idx], iou))
# Run the evaluation only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == "__main__":
    main()