-
Notifications
You must be signed in to change notification settings - Fork 10
/
dataset-mnist.lua
84 lines (69 loc) · 2.42 KB
/
dataset-mnist.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
require 'torch'
require 'paths'
-- Global module table for the MNIST loader: where the archive is fetched
-- from and where the train/test files are expected once it is unpacked.
local dataset_dir = 'mnist.t7'
mnist = {
   -- URL of the compressed archive used by mnist.download()
   path_remote = 'https://s3.amazonaws.com/torch7/data/mnist.t7.tgz',
   -- directory the train/test files are looked up in
   path_dataset = dataset_dir,
   -- serialized training and test splits inside that directory
   path_trainset = paths.concat(dataset_dir, 'train_32x32.t7'),
   path_testset = paths.concat(dataset_dir, 'test_32x32.t7'),
}
--- Fetch and unpack the MNIST archive unless both split files already exist.
-- Shells out to `wget`, `tar` and `rm` in a single os.execute call; the
-- result of that call is not inspected.
function mnist.download()
   -- Nothing to do when the train and test files are both on disk.
   if paths.filep(mnist.path_trainset) and paths.filep(mnist.path_testset) then
      return
   end

   local url = mnist.path_remote
   local archive = paths.basename(url)
   -- download, extract, then delete the archive; steps are ';'-separated,
   -- so each one runs regardless of whether the previous one succeeded
   local command = table.concat({
      'wget ' .. url,
      'tar xvf ' .. archive,
      'rm ' .. archive,
   }, '; ')
   os.execute(command)
end
--- Load the training split.
-- @param maxLoad  forwarded to mnist.loadDataset as its example cap
-- @param geometry forwarded as a third argument; mnist.loadDataset declares
--                 only (fileName, maxLoad), so it is currently ignored there
-- @return the result of mnist.loadDataset for mnist.path_trainset
function mnist.loadTrainSet(maxLoad, geometry)
   local train_file = mnist.path_trainset
   return mnist.loadDataset(train_file, maxLoad, geometry)
end
--- Load the test split.
-- @param maxLoad  forwarded to mnist.loadDataset as its example cap
-- @param geometry forwarded as a third argument; mnist.loadDataset declares
--                 only (fileName, maxLoad), so it is currently ignored there
-- @return the result of mnist.loadDataset for mnist.path_testset
function mnist.loadTestSet(maxLoad, geometry)
   local test_file = mnist.path_testset
   return mnist.loadDataset(test_file, maxLoad, geometry)
end
--- Load one MNIST split from a serialized Torch file and wrap it as a dataset.
--
-- Calls mnist.download() first, reads `fileName` with torch.load in 'ascii'
-- format, converts the image tensor to the default tensor type and, when
-- requested, keeps only the first `maxLoad` examples.
-- Raw image values are in [0,255] (per the original note on this function).
--
-- @param fileName path of the .t7 file to read
-- @param maxLoad  optional cap on the number of examples; applied only when
--                 it is > 0 and smaller than the number stored in the file
-- @return a table with fields `data` and `labels`, the methods `normalize`,
--         `normalizeGlobal` and `size`, and numeric indexing `dataset[i]`
--         that yields `{input, label}`
function mnist.loadDataset(fileName, maxLoad)
   mnist.download()

   local f = torch.load(fileName, 'ascii')
   local data = f.data:type(torch.getdefaulttensortype())
   local labels = f.labels

   local nExample = f.data:size(1)
   if maxLoad and maxLoad > 0 and maxLoad < nExample then
      nExample = maxLoad
      print('<mnist> loading only ' .. nExample .. ' examples')
   end
   -- data is indexed with four dimensions, labels with one
   data = data[{{1,nExample},{},{},{}}]
   labels = labels[{{1,nExample}}]
   print('<mnist> done')

   local dataset = {}
   dataset.data = data
   dataset.labels = labels

   --- Standardize every feature (column of the flattened data) in place.
   -- @param mean_ optional precomputed means, indexed as mean_[1][j]
   -- @param std_  optional precomputed standard deviations, same layout
   -- @return mean, std actually used (so they can be reapplied to another set)
   function dataset:normalize(mean_, std_)
      -- (nExample x nFeature) view sharing storage with `data`, so the
      -- in-place column updates below modify the dataset itself
      local flat = data:view(data:size(1), -1)
      local mean = mean_ or flat:mean(1)
      local std = std_ or flat:std(1, true)
      -- Iterate over features, not examples: mean/std hold one entry per
      -- column of `flat`.
      for j = 1, flat:size(2) do
         local column = flat:select(2, j)
         column:add(-mean[1][j])
         -- skip constant features to avoid dividing by zero
         if std[1][j] > 0 then
            column:mul(1 / std[1][j])
         end
      end
      return mean, std
   end

   --- Standardize the whole data tensor in place with one scalar mean/std.
   -- @param mean_ optional precomputed mean
   -- @param std_  optional precomputed standard deviation
   -- @return mean, std actually used
   function dataset:normalizeGlobal(mean_, std_)
      local std = std_ or data:std()
      local mean = mean_ or data:mean()
      data:add(-mean)
      data:mul(1/std)
      return mean, std
   end

   --- Number of examples kept after the optional `maxLoad` cap.
   function dataset:size()
      return nExample
   end

   -- Single 10-element buffer reused for every label returned by dataset[i];
   -- callers that need to keep a label should copy it (e.g. label:clone()),
   -- because the next lookup zeroes and rewrites this same tensor.
   local labelvector = torch.zeros(10)

   setmetatable(dataset, {__index = function(self, index)
      -- Only numeric keys address examples; any other missing key is simply
      -- absent instead of being forwarded to tensor indexing.
      if type(index) ~= 'number' then
         return nil
      end
      local input = self.data[index]
      local class = self.labels[index]
      -- one-hot encode: assumes stored labels are in 1..10 — TODO confirm
      local label = labelvector:zero()
      label[class] = 1
      local example = {input, label}
      return example
   end})

   return dataset
end