-
Notifications
You must be signed in to change notification settings - Fork 6
/
Copy pathcnn_get_batch.m
150 lines (134 loc) · 4.01 KB
/
cnn_get_batch.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
function imo = cnn_get_batch(images, varargin)
%CNN_GET_BATCH  Load, pad, resize, crop and mean-normalise a batch of images.
%   IMO = CNN_GET_BATCH(IMAGES, 'option', value, ...) returns a SINGLE array
%   of size imageSize(1) x imageSize(2) x 3 x (numel(IMAGES)*numAugments).
%   IMAGES is a cell array containing either file names or image arrays
%   that are already in memory.
%
%   Modified from CNN_IMAGENET_GET_BATCH:
%    - added `pad` option
%    - deals with images of types other than jpeg
%
%   Options (defaults in brackets):
%     imageSize      [227 227]  Height and width of every output crop.
%     border         [29 29]    Extra rows/columns added to imageSize when
%                               computing the resize factor; a scalar is
%                               used for both dimensions.
%     pad            [0]        [TOP BOTTOM LEFT RIGHT] padding, filled with
%                               the value 255; a scalar pads all four sides.
%     keepAspect     [true]     Use one scale factor (the larger of the two)
%                               for both dimensions.
%     numAugments    [1]        Number of crops produced per image.
%     transformation ['none']   'none'    centre crop, no flip
%                               'f5'      centre + 4 corners, with/without flip
%                               'f25'     5x5 grid of crops, with/without flip
%                               'stretch' random crop size/position and flip
%     averageImage   []         Subtracted from every crop; a 3-element value
%                               is treated as a per-channel mean.
%     rgbVariance    zeros(0,3) If non-empty, rgbVariance*randn(3,1) is added
%                               to the average for every crop.
%     interpolation  ['bilinear'] Passed to IMRESIZE as 'method'.
%     numThreads     [1]        Passed to VL_IMREADJPEG.
%     prefetch       [false]    If true, only start a VL_IMREADJPEG prefetch
%                               (when all files are .jpg) and return [].

opts.imageSize = [227, 227] ;
opts.border = [29, 29] ;
opts.pad = 0 ; % [TOP BOTTOM LEFT RIGHT]
opts.keepAspect = true ;
opts.numAugments = 1 ;
opts.transformation = 'none' ;
opts.averageImage = [] ;
opts.rgbVariance = zeros(0,3,'single') ;
opts.interpolation = 'bilinear' ;
opts.numThreads = 1 ;
opts.prefetch = false ;
opts = vl_argparse(opts, varargin) ;

% if only one value is given, apply the same amount to all borders
if numel(opts.pad) == 1, opts.pad = repmat(opts.pad, [1 4]) ; end
if numel(opts.border) == 1, opts.border = repmat(opts.border, [1 2]) ; end

% fetch is true if images is a list of filenames (instead of
% a cell array of images)
fetch = numel(images) >= 1 && ischar(images{1}) ;

% isjpg is true only if ALL images to fetch have a .jpg extension. Every
% name is checked (not just the first) and names shorter than the extension
% are handled without an indexing error.
isjpg = false ;
if fetch
  hasJpgExt = cellfun(@(f) ischar(f) && numel(f) >= 4 && ...
                      strcmpi(f(end-3:end), '.jpg'), images) ;
  isjpg = all(hasJpgExt(:)) ;
end

if opts.prefetch
  % Only kick off the asynchronous read; nothing is returned in this mode.
  if isjpg
    vl_imreadjpeg(images, 'numThreads', opts.numThreads, 'prefetch') ;
  end
  imo = [] ;
  return ;
end

if fetch
  if isjpg
    im = vl_imreadjpeg(images, 'numThreads', opts.numThreads) ;
  else
    % Leave every slot empty so the loop below falls back to imread.
    im = cell(size(images)) ;
  end
else
  im = images ;
end

% Each column of tfs is [vertical offset ; horizontal offset ; flip], with
% the offsets expressed as fractions of the free space around the crop.
tfs = [] ;
switch opts.transformation
  case 'none'
    tfs = [
      .5 ;
      .5 ;
      0 ] ;
  case 'f5'
    tfs = [...
      .5 0 0 1 1 .5 0 0 1 1 ;
      .5 0 1 0 1 .5 0 1 0 1 ;
       0 0 0 0 0  1 1 1 1 1] ;
  case 'f25'
    [tx,ty] = meshgrid(linspace(0,1,5)) ;
    tfs = [tx(:)' ; ty(:)' ; zeros(1,numel(tx))] ;
    tfs_ = tfs ;
    tfs_(3,:) = 1 ;
    tfs = [tfs,tfs_] ;
  case 'stretch'
    % crops are drawn at random inside the main loop; no table needed
  otherwise
    error('Unknown transformation %s', opts.transformation) ;
end

% One random permutation of the transformation indices per image
% (column i belongs to image i).
numTfs = size(tfs,2) ;
[~,transformations] = sort(rand(numTfs, numel(images)), 1) ;

% rgbVariance is applied as a perturbation of the average, so make sure
% there is an average to perturb.
if ~isempty(opts.rgbVariance) && isempty(opts.averageImage)
  opts.averageImage = zeros(1,1,3) ;
end
if numel(opts.averageImage) == 3
  opts.averageImage = reshape(opts.averageImage, 1,1,3) ;
end

imo = zeros(opts.imageSize(1), opts.imageSize(2), 3, ...
            numel(images)*opts.numAugments, 'single') ;

si = 1 ;
for i = 1:numel(images)

  % acquire image
  if isempty(im{i})
    imt = imread(images{i}) ;
    imt = single(imt) ; % faster than im2single (and multiplies by 255)
  else
    imt = im{i} ;
  end
  if size(imt,3) == 1
    % replicate a single-channel image into three channels
    imt = cat(3, imt, imt, imt) ;
  end

  % pad with the value 255 on the requested sides
  if ~isempty(opts.pad) && any(opts.pad > 0)
    imtt = imt ;
    imt = 255*ones(size(imtt,1)+sum(opts.pad(1:2)), ...
                   size(imtt,2)+sum(opts.pad(3:4)), 3, 'like', imtt) ;
    imt(opts.pad(1)+(1:size(imtt,1)), opts.pad(3)+(1:size(imtt,2)), :) = imtt ;
  end

  % resize so that the image covers imageSize + border
  w = size(imt,2) ;
  h = size(imt,1) ;
  factor = [(opts.imageSize(1)+opts.border(1))/h ...
            (opts.imageSize(2)+opts.border(2))/w] ;
  if opts.keepAspect
    factor = max(factor) ;
  end
  if any(abs(factor - 1) > 0.0001)
    imt = imresize(imt, ...
                   'scale', factor, ...
                   'method', opts.interpolation) ;
  end

  % crop & flip
  w = size(imt,2) ;
  h = size(imt,1) ;
  for ai = 1:opts.numAugments
    switch opts.transformation
      case 'stretch'
        sz = round(min(opts.imageSize(1:2)' .* (1-0.1+0.2*rand(2,1)), [h;w])) ;
        dx = randi(w - sz(2) + 1, 1) ;
        dy = randi(h - sz(1) + 1, 1) ;
        flip = rand > 0.5 ;
      otherwise
        % Use this image's own permutation (column i), cycling through it
        % when numAugments exceeds the number of transformations.
        tf = tfs(:, transformations(mod(ai-1, numTfs) + 1, i)) ;
        sz = opts.imageSize(1:2) ;
        dx = floor((w - sz(2)) * tf(2)) + 1 ;
        dy = floor((h - sz(1)) * tf(1)) + 1 ;
        flip = tf(3) ;
    end

    % Sample the sz-sized window on an imageSize grid (nearest pixel).
    sx = round(linspace(dx, sz(2)+dx-1, opts.imageSize(2))) ;
    sy = round(linspace(dy, sz(1)+dy-1, opts.imageSize(1))) ;
    if flip, sx = fliplr(sx) ; end

    if ~isempty(opts.averageImage)
      offset = opts.averageImage ;
      if ~isempty(opts.rgbVariance)
        offset = bsxfun(@plus, offset, ...
                        reshape(opts.rgbVariance * randn(3,1), 1,1,3)) ;
      end
      imo(:,:,:,si) = bsxfun(@minus, imt(sy,sx,:), offset) ;
    else
      imo(:,:,:,si) = imt(sy,sx,:) ;
    end
    si = si + 1 ;
  end
end