voc_gen_trainval_test.py
'''
Generate ImageSets/Main/train.txt, val.txt, trainval.txt and test.txt for a Pascal VOC format dataset.
'''
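# Example invocation (a minimal sketch; the dataset path below is illustrative,
# not part of the original script):
#   python voc_gen_trainval_test.py --voc-root /path/to/VOCdevkit/VOC2007 --test-ratio 0.2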
from pathlib import Path
import os
import sys
import xml.etree.ElementTree as ET
import random
import argparse
from sklearn.model_selection import train_test_split
from sklearn.utils import shuffle
import shutil


def mkdir(path):
    # Strip leading/trailing spaces and any trailing backslash
    path = path.strip()
    path = path.rstrip("\\")
    # Check whether the path already exists
    isExists = os.path.exists(path)
    if not isExists:
        # Create the directory if it does not exist
        os.makedirs(path)
        print(path + ' created')
        return True
    else:
        # The directory already exists, so do not create it again
        print(path + ' already exists')
        return False


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--voc-root', type=str, required=True,
                        help='Root directory of the VOC-format dataset; it must contain the JPEGImages and Annotations folders')
    parser.add_argument('--test-ratio', type=float, default=0.2,
                        help='Test set ratio, 0.2 by default')
    opt = parser.parse_args()

    voc_root = opt.voc_root
    print('Pascal VOC dataset root:', voc_root)
    xml_file = []
    img_files = []
    ANNO = os.path.join(voc_root, 'Annotations')
    JPEG = os.path.join(voc_root, 'JPEGImages')
    ImgSets = os.path.join(voc_root, 'ImageSets')

    # Remove any stale ImageSets/Main directories, then recreate them
    try:
        shutil.rmtree(ImgSets)
    except FileNotFoundError:
        pass
    mkdir(ImgSets)

    ImgSetsMain = os.path.join(ImgSets, 'Main')
    try:
        shutil.rmtree(ImgSetsMain)
    except FileNotFoundError:
        pass
    mkdir(ImgSetsMain)
    # Collect image file names (without extension) from JPEGImages
    p = Path(JPEG)
    files = []
    for file in p.iterdir():
        name = file.stem  # file name without its suffix
        files.append(name)
    print('Dataset size:', len(files))

    # Shuffle, split into trainval/test, then split trainval into train/val (80/20)
    files = shuffle(files)
    ratio = opt.test_ratio
    trainval, test = train_test_split(files, test_size=ratio)
    train, val = train_test_split(trainval, test_size=0.2)
    print('Training set size: ', len(train))
    print('Validation set size: ', len(val))
    print('Test set size: ', len(test))
    def write_txt(txt_path, data):
        '''Write one entry per line to a txt file.'''
        with open(txt_path, 'w') as f:
            for d in data:
                f.write(str(d))
                f.write('\n')

    # Write each split to its txt file under ImageSets/Main
    trainvaltest_txt = os.path.join(ImgSetsMain, 'trainvaltest.txt')
    write_txt(trainvaltest_txt, files)

    trainval_txt = os.path.join(ImgSetsMain, 'trainval.txt')
    write_txt(trainval_txt, trainval)

    train_txt = os.path.join(ImgSetsMain, 'train.txt')
    write_txt(train_txt, train)

    val_txt = os.path.join(ImgSetsMain, 'val.txt')
    write_txt(val_txt, val)

    test_txt = os.path.join(ImgSetsMain, 'test.txt')
    write_txt(test_txt, test)
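
# Expected layout after a run (directory names taken from the paths built above;
# image and annotation contents are untouched):
#   <voc-root>/
#     Annotations/
#     JPEGImages/
#     ImageSets/
#       Main/
#         train.txt  val.txt  trainval.txt  test.txt  trainvaltest.txt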