-- ************************************************************
-- Author : Bumsoo Kim, 2016
-- Github : https://github.com/meliketoy/wide-residual-network
--
-- Korea University, Data-Mining Lab
-- wide-residual-networks Torch implementation
--
-- Description : dataloader.lua
-- Multi-threaded data loader feeding batches during training.
-- ************************************************************
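--
-- Usage sketch (a minimal example, not part of this file; it assumes the
-- `opt` table built by the training script provides nThreads, batchSize,
-- manualSeed, dataset and tenCrop, and that this file is on the Lua path):
--
--   local DataLoader = require 'dataloader'
--   local trainLoader, valLoader = DataLoader.create(opt)
--   for n, sample in trainLoader:run() do
--      -- sample.input  : FloatTensor, batchSize x channels x height x width
--      -- sample.target : IntTensor of class labels, one per image
--   end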
local datasets = require 'datasets/init'
local Threads = require 'threads'
Threads.serialization('threads.sharedserialize')
local M = {}
local DataLoader = torch.class('resnet.DataLoader', M)
-- Create one loader per split: train and val
function DataLoader.create(opt)
   local loaders = {}

   for i, split in ipairs{'train', 'val'} do
      local dataset = datasets.create(opt, split)
      loaders[i] = M.DataLoader(dataset, opt, split)
   end

   return table.unpack(loaders)
end
function DataLoader:__init(dataset, opt, split)
   local manualSeed = opt.manualSeed

   -- Run once in each worker thread: load the dataset class definition
   local function init()
      require('datasets/' .. opt.dataset)
   end

   -- Run per worker: seed it, pin it to a single thread, and expose the
   -- dataset and its preprocessing pipeline as thread-local globals
   local function main(idx)
      if manualSeed ~= 0 then
         torch.manualSeed(manualSeed + idx)
      end
      torch.setnumthreads(1)
      _G.dataset = dataset
      _G.preprocess = dataset:preprocess()
      return dataset:size()
   end

   local threads, sizes = Threads(opt.nThreads, init, main)
   self.nCrops = (split == 'val' and opt.tenCrop) and 10 or 1
   self.threads = threads
   self.__size = sizes[1][1]
   self.batchSize = math.floor(opt.batchSize / self.nCrops)
end
-- Number of mini-batches per epoch
function DataLoader:size()
   return math.ceil(self.__size / self.batchSize)
end
-- Returns an iterator that yields (n, sample) pairs, where sample.input and
-- sample.target hold one shuffled mini-batch assembled by the worker threads
function DataLoader:run()
   local threads = self.threads
   local size, batchSize = self.__size, self.batchSize
   local perm = torch.randperm(size)

   local idx, sample = 1, nil
   local function enqueue()
      -- Keep the thread pool saturated with batch-loading jobs
      while idx <= size and threads:acceptsjob() do
         local indices = perm:narrow(1, idx, math.min(batchSize, size - idx + 1))
         threads:addjob(
            -- Worker: load, preprocess and stack the images of one batch
            function(indices, nCrops)
               local sz = indices:size(1)
               local batch, imageSize
               local target = torch.IntTensor(sz)
               for i, idx in ipairs(indices:totable()) do
                  local sample = _G.dataset:get(idx)
                  local input = _G.preprocess(sample.input)
                  if not batch then
                     imageSize = input:size():totable()
                     if nCrops > 1 then table.remove(imageSize, 1) end
                     batch = torch.FloatTensor(sz, nCrops, table.unpack(imageSize))
                  end
                  batch[i]:copy(input)
                  target[i] = sample.target
               end
               collectgarbage()
               return {
                  input = batch:view(sz * nCrops, table.unpack(imageSize)),
                  target = target,
               }
            end,
            -- Main thread: receive the finished batch
            function(_sample_)
               sample = _sample_
            end,
            indices,
            self.nCrops
         )
         idx = idx + batchSize
      end
   end

   local n = 0
   local function loop()
      enqueue()
      if not threads:hasjob() then
         return nil
      end
      threads:dojob()
      if threads:haserror() then
         threads:synchronize()
      end
      enqueue()
      n = n + 1
      return n, sample
   end

   return loop
end
return M.DataLoader