-
Notifications
You must be signed in to change notification settings - Fork 92
/
convert_masks.py
83 lines (67 loc) · 1.97 KB
/
convert_masks.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
#!/usr/bin/python
# -*- coding: utf-8 -*-
'''
Converts the .mat segmentation labels in the Augmented Pascal VOC
dataset to color-coded .png images.
Download the dataset from:
http://home.bharathh.info/pubs/codes/SBD/download.html
'''
import argparse
import glob
import os
from os import path

import numpy as np
import PIL
import PIL.Image
import scipy.io
# RGB colour for each Pascal VOC label index. The list position is the label
# index (0-20), so the resulting dict maps index -> (r, g, b).
PASCAL_PALETTE = dict(enumerate([
    (0, 0, 0),
    (128, 0, 0),
    (0, 128, 0),
    (128, 128, 0),
    (0, 0, 128),
    (128, 0, 128),
    (0, 128, 128),
    (128, 128, 128),
    (64, 0, 0),
    (192, 0, 0),
    (64, 128, 0),
    (192, 128, 0),
    (64, 0, 128),
    (192, 0, 128),
    (64, 128, 128),
    (192, 128, 128),
    (0, 64, 0),
    (128, 64, 0),
    (0, 192, 0),
    (128, 192, 0),
    (0, 64, 128),
]))
def _convert_mat_to_png(mat_path, out_dir, palette):
    """Convert one SBD ``.mat`` annotation into a palettised ``.png`` file.

    Args:
        mat_path: Path of the input ``.mat`` file; it must contain a ``GTcls``
            struct.
        out_dir: Existing directory the ``.png`` is written into. The output
            keeps the input's base name with the extension swapped to ``.png``.
        palette: Flat ``[r0, g0, b0, r1, g1, b1, ...]`` sequence handed to
            ``Image.putpalette``.
    """
    # Field order inside the GTcls struct: 0 = boundaries, 1 = segmentation,
    # 2 = categories present. Only the segmentation map is converted.
    segmentation_idx = 1
    mat = scipy.io.loadmat(mat_path, mat_dtype=True)
    seg_data = mat['GTcls'][0][0][segmentation_idx]
    # One vectorised cast instead of a per-pixel Python loop; it performs the
    # same element-wise conversion to uint8.
    img_data = seg_data.astype(np.uint8)
    img = PIL.Image.fromarray(img_data)
    # Attaching a palette turns the 'L' image into a 'P' image. Pixel values
    # remain the raw class indices, but viewers show the VOC colours.
    img.putpalette(palette)
    # splitext only touches the final extension, unlike str.replace, which
    # would also rewrite a '.mat' occurring earlier in the file name.
    stem, _ = path.splitext(path.basename(mat_path))
    img.save(path.join(out_dir, stem + '.png'), 'png')


def main():
    """Convert every ``*.mat`` file in ``--in-dir`` to a ``.png`` in ``--out-dir``.

    The output directory is created if needed. Progress is printed as
    ``done/total`` after each file.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--in-dir', type=str, help='Input folder',
                        required=True)
    parser.add_argument('--out-dir', type=str, help='Output folder',
                        required=True)
    args = parser.parse_args()

    files = sorted(glob.glob(path.join(args.in_dir, '*.mat')))
    if not files:
        # Unlike ``assert``, this check is not stripped under ``python -O``.
        parser.error('no matlab files found in the input folder')
    # exist_ok only tolerates an existing directory; other OS errors
    # (e.g. permissions) now surface instead of being silently ignored.
    os.makedirs(args.out_dir, exist_ok=True)

    # PIL palettes hold 256 RGB entries; labels without a colour map to black.
    palette = [0] * (3 * 256)
    for label, rgb in PASCAL_PALETTE.items():
        palette[3 * label:3 * label + 3] = rgb

    n_files = len(files)
    for f_cnt, fname in enumerate(files, start=1):
        _convert_mat_to_png(fname, args.out_dir, palette)
        print(f'{f_cnt:05}/{n_files:05}')


if __name__ == '__main__':
    main()