-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathgenerate_dataset.lua
152 lines (121 loc) · 4.25 KB
/
generate_dataset.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
-- This script loads the dataset: training data and test data
-- For each set, preprocessing, such as normalization, is adopted
--
-- Inspired by https://github.com/torch/demos/blob/master/person-detector/data.lua
------------------------------------------------------------
-- load or generate new dataset:
------------------------------------------------------------
-- NOTE(review): `sys` (used for sys.COLORS below) is assumed to be
-- require'd by whichever script runs this file -- confirm; it is not
-- brought into scope here.
print(sys.COLORS.red .. "<data> creating a new dataset from files:")
-- load files in directory
require "paths"
-- NOTE(review): `currentPath` is a global and is not used again in this
-- file; kept as-is in case sibling scripts read it.
currentPath = paths.cwd() -- Current Working Directory
--print("Current path: " .. currentPath)
-- Note that lua does not work well with relative path
-- Directories holding the training images and their label (.txt) files,
-- relative to the working directory this script is launched from.
local trainImgDir = "./data/train/image/"
local trainLabelDir = "./data/train/label/"
print("Start finding files in " .. trainLabelDir)
-- load label file names into the global `fileLabel`; later sections of
-- this script index it in parallel with the label table
fileLabel = {}
for file in paths.files(trainLabelDir) do
  -- Keep only ".txt" files. The escaped dot ("%.txt$") requires a real
  -- extension; the original pattern ("txt" .. "$" == "txt$") also
  -- accepted any name merely ENDING in the letters "txt".
  if file:find("%.txt$") then
    table.insert(fileLabel, trainLabelDir .. file)
  end
end
local sizeTrain = #fileLabel
-- Sort lexicographically so the ordering is deterministic across runs
-- (paths.files yields entries in arbitrary filesystem order). The
-- explicit `a < b` comparator was redundant -- it is the default.
table.sort(fileLabel)
print("Found " .. sizeTrain .. " files.")
------------------------------------------------------------
-----------------------------------------------------------------
-- Load the label files.
-- Each label file is read as a stream of numbers; only elements
-- LABEL_FIRST..LABEL_LAST (the 8 values used as the regression target)
-- are kept. Exactly one entry is appended to the global `Label` per
-- file -- even on a failed open -- so indices stay aligned with
-- `fileLabel`.
-----------------------------------------------------------------
local LABEL_FIRST, LABEL_LAST = 9, 16 -- 1-based positions of the 8 kept values
Label = {}
cntErr = 0
cntOk = 0
for i = 1, sizeTrain do
  local labelTmp = {}
  local f = io.open(fileLabel[i], "r")
  if f then
    local cntElm = 1
    while true do
      local f_val = f:read("*n") -- next number in the file; nil at EOF
      if not f_val then
        break
      end
      if cntElm >= LABEL_FIRST and cntElm <= LABEL_LAST then
        labelTmp[#labelTmp + 1] = f_val
      end
      cntElm = cntElm + 1
    end
    f:close() -- method form is the idiomatic way to close a handle
    cntOk = cntOk + 1
  else
    cntErr = cntErr + 1
  end
  -- NOTE(review): an unreadable file leaves an EMPTY entry here;
  -- torch.Tensor(Label[i]) later in this script will then not yield
  -- the expected 8 label values -- confirm cntErr is always 0 in practice.
  Label[#Label + 1] = labelTmp
end
print(cntOk .. " files are valid.")
print(cntErr .. " files are invalid.")
--torch.save("./data/train_label.t7", Label)
--print(Label[10])
-----------------------------------------------------------------
-----------------------------------------------------------------
-- construct a table containing file names
-- One {src, tar} record per label file, stored in the global
-- `fileImage`, index-aligned with `fileLabel`.
-----------------------------------------------------------------
fileImage = {}
for idx = 1, sizeTrain do
  -- characters -13..-5 of the label path form the sample's base name
  -- (the file name with the trailing ".txt" stripped)
  local base = fileLabel[idx]:sub(-13, -5)
  fileImage[idx] = {
    src = paths.concat(trainImgDir, base .. "_src.jpg"),
    tar = paths.concat(trainImgDir, base .. "_tar.jpg"),
  }
end
--print("Found images:")
--print(fileImage)
-----------------------------------------------------------------
-- It seems that the size of the dataset is not affordable at the current OS system.
-- So, for a temporal solution, I decided to save only filenames so that they are loaded during constructing mini batches.
--print(sys.COLORS.red .. "Save the image list and label list")
--torch.save("./data/img_list.t7", fileImage)
--torch.save("./data/label_list.t7", fileLabel)
-----------------------------------------------------------------
-- Construct the training dataset structure
-----------------------------------------------------------------
sizeImWidth = 300
sizeImHeight = 300
sizeChannel = 3 -- RGB channels per image (the src/tar pair forms the siamese input)
sizeLabel = 2 * 4 -- 4 corner points x (dx, dy): difference between two corner points
numIm = 2 -- two images (source and target) per sample
trainData = {
  -- NOTE(review): torch image tensors are C x H x W; width and height
  -- are both 300 here so the (width, height) order is harmless.
  imSrc = torch.ByteTensor(sizeTrain, sizeChannel, sizeImWidth, sizeImHeight),
  imTar = torch.ByteTensor(sizeTrain, sizeChannel, sizeImWidth, sizeImHeight),
  labels = torch.Tensor(sizeTrain, sizeLabel),
  diffNorm = 20,
  size = function() return sizeTrain end
}
-- shuffle dataset: get shuffled indices in this variable:
--local idxTrainShuffle = torch.randperm(sizeTrain)
-- load train image data
require "image"
for i = 1, sizeTrain do
  -- image.load returns floats in [0, 1]; scale to [0, 255] and convert
  -- to bytes before storing. (The original called :byte() but discarded
  -- its return value and relied on the implicit float->byte conversion
  -- of the slice assignment; storing the byte tensor directly makes the
  -- conversion explicit and drops the redundant :clone().)
  trainData.imSrc[i] = image.load(fileImage[i].src):mul(255):byte()
  trainData.imTar[i] = image.load(fileImage[i].tar):mul(255):byte()
  trainData.labels[i] = torch.Tensor(Label[i])
  if i % 1000 == 0 then
    print(i .. " files processed...")
  end
end
-- save created dataset
torch.save('./data/train.t7', trainData)
-- Displaying the dataset architecture
print(sys.COLORS.red .. "Training data: ")
print(trainData)
print()
-- (the redundant re-assignment of trainData.size was removed; it is
-- already set in the constructor table above)
------------------------------------------------------------