"""Utilities for real-time data augmentation on image data.
https://raw.githubusercontent.com/keras-team/keras-preprocessing/master/keras_preprocessing/image/numpy_array_iterator.py
modified with no numpy array copying
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import warnings
import numpy as np
import random
import tensorflow as tf
from keras.preprocessing.image import Iterator
from keras.preprocessing.image import array_to_img
class NumpyArrayIterator(Iterator):
"""Iterator yielding data from a Numpy array.
# Arguments
x: Numpy array of input data or tuple.
            If tuple, the second element is either
another numpy array or a list of numpy arrays,
each of which gets passed
through as an output without any modifications.
        y: Numpy array of target data.
image_data_generator: Instance of `ImageDataGenerator`
to use for random transformations and normalization.
batch_size: Integer, size of a batch.
shuffle: Boolean, whether to shuffle the data between epochs.
sample_weight: Numpy array of sample weights.
seed: Random seed for data shuffling.
data_format: String, one of `channels_first`, `channels_last`.
save_to_dir: Optional directory where to save the pictures
being yielded, in a viewable format. This is useful
for visualizing the random transformations being
applied, for debugging purposes.
save_prefix: String prefix to use for saving sample
images (if `save_to_dir` is set).
save_format: Format to use for saving sample images
(if `save_to_dir` is set).
subset: Subset of data (`"training"` or `"validation"`) if
validation_split is set in ImageDataGenerator.
ignore_class_split: Boolean (default: False), ignore difference
in number of classes in labels across train and validation
split (useful for non-classification tasks)
        dtype: Dtype to use for the generated arrays.
        transform_img: Callable applied to each transformed image (and
            to each target) just before it is written into the batch;
            defaults to the identity function.
        network_heads: Integer, number of model outputs. When greater
            than 1, the target batch is replicated once per head.
"""
def __new__(cls, *args, **kwargs):
try:
from tensorflow.keras.utils import Sequence as TFSequence
if TFSequence not in cls.__bases__:
cls.__bases__ = cls.__bases__ + (TFSequence,)
except ImportError:
pass
return super(NumpyArrayIterator, cls).__new__(cls)
def __init__(self,
x,
y,
image_data_generator,
batch_size=32,
shuffle=False,
sample_weight=None,
seed=None,
data_format='channels_last',
save_to_dir=None,
save_prefix='',
save_format='png',
subset=None,
ignore_class_split=False,
dtype='float32',
transform_img=lambda x: x,
network_heads=1):
self.dtype = dtype
if (type(x) is tuple) or (type(x) is list):
if type(x[1]) is not list:
x_misc = [np.asarray(x[1])]
else:
x_misc = [np.asarray(xx) for xx in x[1]]
x = x[0]
for xx in x_misc:
if len(x) != len(xx):
raise ValueError(
'All of the arrays in `x` '
'should have the same length. '
'Found a pair with: len(x[0]) = %s, len(x[?]) = %s' %
(len(x), len(xx)))
else:
x_misc = []
if y is not None and len(x) != len(y):
raise ValueError('`x` (images tensor) and `y` (labels) '
'should have the same length. '
'Found: x.shape = %s, y.shape = %s' %
(np.asarray(x).shape, np.asarray(y).shape))
if sample_weight is not None and len(x) != len(sample_weight):
raise ValueError('`x` (images tensor) and `sample_weight` '
'should have the same length. '
'Found: x.shape = %s, sample_weight.shape = %s' %
(np.asarray(x).shape, np.asarray(sample_weight).shape))
if subset is not None:
if subset not in {'training', 'validation'}:
raise ValueError('Invalid subset name:', subset,
'; expected "training" or "validation".')
split_idx = int(len(x) * image_data_generator._validation_split)
if (y is not None and not ignore_class_split and not
np.array_equal(np.unique(y[:split_idx]),
np.unique(y[split_idx:]))):
raise ValueError('Training and validation subsets '
'have different number of classes after '
'the split. If your numpy arrays are '
'sorted by the label, you might want '
'to shuffle them.')
if subset == 'validation':
x = x[:split_idx]
x_misc = [np.asarray(xx[:split_idx]) for xx in x_misc]
if y is not None:
y = y[:split_idx]
else:
x = x[split_idx:]
x_misc = [np.asarray(xx[split_idx:]) for xx in x_misc]
if y is not None:
y = y[split_idx:]
# self.x = np.asarray(x, dtype=self.dtype)
self.x = x
self.x_misc = x_misc
if self.x.ndim != 4:
raise ValueError('Input data in `NumpyArrayIterator` '
'should have rank 4. You passed an array '
'with shape', self.x.shape)
channels_axis = 3 if data_format == 'channels_last' else 1
if self.x.shape[channels_axis] not in {1, 3, 4}:
warnings.warn('NumpyArrayIterator is set to use the '
'data format convention "' + data_format + '" '
'(channels on axis ' + str(channels_axis) +
'), i.e. expected either 1, 3, or 4 '
'channels on axis ' + str(channels_axis) + '. '
'However, it was passed an array with shape ' +
str(self.x.shape) + ' (' +
str(self.x.shape[channels_axis]) + ' channels).')
if y is not None:
# self.y = np.asarray(y)
self.y = y
else:
self.y = None
if sample_weight is not None:
self.sample_weight = np.asarray(sample_weight)
else:
self.sample_weight = None
self.image_data_generator = image_data_generator
self.data_format = data_format
self.save_to_dir = save_to_dir
self.save_prefix = save_prefix
self.save_format = save_format
self.transform_img = transform_img
self.network_heads = network_heads
super(NumpyArrayIterator, self).__init__(x.shape[0],
batch_size,
shuffle,
seed)
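    # Builds one batch. The same random seed is used for an image and its
    # target, so both receive the identical random transform (useful when
    # `y` is itself an image, e.g. a segmentation-style target).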
def _get_batches_of_transformed_samples(self, index_array):
batch_x = np.zeros(tuple([len(index_array)] + list(self.x.shape)[1:]),
dtype=self.dtype)
        if self.y is not None:
            batch_y = np.zeros(
                tuple([len(index_array)] + list(self.y.shape)[1:]),
                dtype=self.dtype)
        else:
            batch_y = None
        for i, j in enumerate(index_array):
            rseed = random.randint(0, 1000000)
            x = self.x[j]
            params = self.image_data_generator.get_random_transform(
                x.shape, seed=rseed)
            x = self.image_data_generator.apply_transform(
                x.astype(self.dtype), params)
            # x = self.image_data_generator.standardize(x)
            x = self.transform_img(x)
            batch_x[i] = x
            if self.y is not None:
                y = self.y[j]
                # Same seed as for `x`, so the target gets the identical
                # random transform.
                params_y = self.image_data_generator.get_random_transform(
                    y.shape, seed=rseed)
                y = self.image_data_generator.apply_transform(
                    y.astype(self.dtype), params_y)
                # y = self.image_data_generator.standardize(y)
                y = self.transform_img(y)
                batch_y[i] = y
batch_x_miscs = [xx[index_array] for xx in self.x_misc]
output = (batch_x if batch_x_miscs == []
else [batch_x] + batch_x_miscs,)
if self.y is None:
return output[0]
if self.network_heads > 1:
batch_y = self.network_heads * [batch_y]
output += (batch_y,)
if self.sample_weight is not None:
output += (self.sample_weight[index_array],)
return output
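

# A minimal usage sketch (illustrative only): the generator settings, array
# shapes and the `transform_img` callable below are assumptions for the
# example, not values taken from this repository.
if __name__ == '__main__':
    from keras.preprocessing.image import ImageDataGenerator

    # Paired input/target images for an image-to-image task; both sides of
    # each pair receive the same random spatial transform.
    images = np.random.rand(16, 64, 64, 3).astype('float32')
    targets = np.random.rand(16, 64, 64, 3).astype('float32')

    datagen = ImageDataGenerator(rotation_range=10, horizontal_flip=True)
    it = NumpyArrayIterator(images, targets, datagen,
                            batch_size=4, shuffle=True,
                            transform_img=lambda img: img / 255.0)

    batch_x, batch_y = next(it)
    print(batch_x.shape, batch_y.shape)  # (4, 64, 64, 3) (4, 64, 64, 3)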