forked from ming71/UCAS-AOD-benchmark
-
Notifications
You must be signed in to change notification settings - Fork 0
/
data_prepare.py
92 lines (81 loc) · 3.4 KB
/
data_prepare.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
import os
import glob
import random
import shutil
from tqdm import tqdm
# Pin the global RNG seed at import time so that any use of the `random`
# module by this script is reproducible from run to run.
random.seed(a=666)
def _dst_filename(src_path, is_plane=False, plane_offset=510):
    """Return the file name *src_path* should get in the merged output folder.

    Car files keep their base name unchanged. Plane files are expected to be
    named ``P<digits><ext>``; the number is shifted by *plane_offset* and
    zero-padded to at least four digits, so renumbered plane files do not
    overwrite car files written to the same folder (car annotations are also
    named ``P*``).
    NOTE(review): 510 is presumably the number of car samples in UCAS-AOD;
    confirm against the dataset.

    Raises:
        ValueError: if a plane file's stem is not ``P`` plus an integer.
    """
    filename = os.path.basename(src_path)
    if not is_plane:
        return filename
    stem, ext = os.path.splitext(filename)
    number = int(stem.strip('P')) + plane_offset
    return 'P' + str(number).zfill(4) + ext


def copyfiles(src_files, dst_folder, is_plane=False):
    """Copy every file in *src_files* into *dst_folder*, showing a progress bar.

    Args:
        src_files: iterable of source file paths.
        dst_folder: destination directory; it must already exist.
        is_plane: if True, plane files are renumbered (see ``_dst_filename``)
            instead of keeping their original base name.
    """
    pbar = tqdm(src_files)
    # The description never changes, so set it once instead of forcing a
    # bar refresh on every file.
    pbar.set_description("Creating {}:".format(dst_folder))
    for src in pbar:
        dst = os.path.join(dst_folder, _dst_filename(src, is_plane))
        shutil.copyfile(src, dst)


def rewrite_label(annos, dst_folder, is_plane=False):
    """Copy annotation files into *dst_folder*, tagging each line with a class.

    Every non-blank line of each source file is prefixed with ``car`` or,
    when *is_plane* is True, ``airplane``, followed by a tab. Plane files are
    renamed with the same rule ``copyfiles`` uses for plane images, so each
    label file keeps matching its image.

    Args:
        annos: iterable of annotation file paths.
        dst_folder: destination directory; it must already exist.
        is_plane: selects the class tag and the plane renaming rule.
    """
    prefix = 'airplane\t' if is_plane else 'car\t'
    pbar = tqdm(annos)
    pbar.set_description("Rewriting to {}:".format(dst_folder))
    for src in pbar:
        dst = os.path.join(dst_folder, _dst_filename(src, is_plane))
        # Read the source completely before opening (and truncating) the
        # destination, so a failed read cannot leave an empty label file.
        with open(src, 'r') as f:
            lines = f.readlines()
        # Blank lines are skipped: tagging them would emit a record that has
        # a class name but no coordinates.
        with open(dst, 'w') as fw:
            fw.write(''.join(prefix + line for line in lines if line.strip()))
def creat_tree(root_dir):
    """Merge the CAR and PLANE sub-folders of *root_dir* into one flat layout.

    Images matching ``CAR/*.png`` and ``PLANE/*.png`` are copied into
    ``<root_dir>/AllImages``. Annotations matching ``CAR/P*.txt`` and
    ``PLANE/P*.txt`` are rewritten into ``<root_dir>/Annotations``. Plane
    files are passed with ``is_plane=True`` so they are renamed.

    Raises:
        RuntimeError: if *root_dir* is not an existing directory.
    """
    if not os.path.isdir(root_dir):
        raise RuntimeError('invalid dataset path!')
    img_dir = os.path.join(root_dir, 'AllImages')
    anno_dir = os.path.join(root_dir, 'Annotations')
    # exist_ok lets the script be re-run instead of dying with
    # FileExistsError; files from an earlier run are simply overwritten.
    os.makedirs(img_dir, exist_ok=True)
    os.makedirs(anno_dir, exist_ok=True)

    car_dir = os.path.join(root_dir, 'CAR')
    plane_dir = os.path.join(root_dir, 'PLANE')
    car_imgs = glob.glob(os.path.join(car_dir, '*.png'))
    car_annos = glob.glob(os.path.join(car_dir, 'P*.txt'))
    airplane_imgs = glob.glob(os.path.join(plane_dir, '*.png'))
    airplane_annos = glob.glob(os.path.join(plane_dir, 'P*.txt'))

    copyfiles(car_imgs, img_dir)
    copyfiles(airplane_imgs, img_dir, True)
    rewrite_label(car_annos, anno_dir)
    rewrite_label(airplane_annos, anno_dir, True)
def generate_files(root_dir, type):
    """Copy one split's images and labels into ``<root_dir>/<type>/``.

    The split is listed in ``<root_dir>/ImageSets/<type>.txt``, one file stem
    per line. For each stem, ``AllImages/<stem>.png`` and
    ``Annotations/<stem>.txt`` are copied into the split directory.

    Args:
        root_dir: dataset root containing ``AllImages``, ``Annotations`` and
            ``ImageSets``.
        type: ``"train"``, ``"test"`` or ``"val"``. The name shadows the
            builtin but is kept so keyword callers keep working.

    Raises:
        ValueError: if *type* is not one of the known splits.
        RuntimeError: if the split list file does not exist.
    """
    split = type
    if split not in ("train", "test", "val"):
        # Unlike ``assert``, an explicit raise still fires under ``python -O``.
        raise ValueError("wrong type")
    setfile = os.path.join(root_dir, 'ImageSets/{}.txt'.format(split))
    # Validate before touching the filesystem, so a missing list file does
    # not leave an empty split directory behind.
    if not os.path.exists(setfile):
        raise RuntimeError('{} is not founded!'.format(setfile))

    img_dir = os.path.join(root_dir, 'AllImages')
    label_dir = os.path.join(root_dir, 'Annotations')
    out_dir = os.path.join(root_dir, split)
    os.makedirs(out_dir, exist_ok=True)

    with open(setfile, 'r') as f:
        # Blank lines (e.g. a trailing empty line) are dropped; otherwise they
        # would resolve to "<dir>/.png" and abort the copy.
        stems = [line.strip() for line in f if line.strip()]

    pbar = tqdm(stems)
    pbar.set_description("Copying to {} dir...".format(split))
    for stem in pbar:
        shutil.copyfile(os.path.join(img_dir, stem + '.png'),
                        os.path.join(out_dir, stem + '.png'))
        shutil.copyfile(os.path.join(label_dir, stem + '.txt'),
                        os.path.join(out_dir, stem + '.txt'))
if __name__ == "__main__":
    import sys

    # The dataset root defaults to the previously hard-coded location but can
    # be overridden on the command line: python <this script> /path/to/UCAS_AOD
    root_dir = sys.argv[1] if len(sys.argv) > 1 else 'UCAS_AOD/'
    creat_tree(root_dir)
    for split in ("train", "test", "val"):
        generate_files(root_dir, split)