forked from rllab-snu/Deep-Elastic-Network
-
Notifications
You must be signed in to change notification settings - Fork 0
/
CIFAR_input.py
97 lines (62 loc) · 2.77 KB
/
CIFAR_input.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
import numpy as np
# Locations of the extracted CIFAR-100 pickled train/test files.
# NOTE(review): hard-coded, machine-specific absolute paths -- adjust per machine.
full_data_dir = '/home/ShiMengnan/Dataset/cifar-100-python/train'
vali_dir = '/home/ShiMengnan/Dataset/cifar-100-python/test'
# Download location of the dataset archive; not used by any function in this file.
DATA_URL = 'http://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz'

# Image geometry and number of label classes.
IMG_HEIGHT, IMG_WIDTH, IMG_DEPTH = 32, 32, 3
NUM_CLASS = 100

# Number of training batch files; here it only scales EPOCH_SIZE
# (presumably 50,000 training images per file -- TODO confirm).
NUM_TRAIN_BATCH = 1
EPOCH_SIZE = NUM_TRAIN_BATCH * 50000
def _read_one_batch(path):
    """Load one pickled batch file and return its pixel data and labels.

    ``np.load`` with ``allow_pickle=True`` unpickles non-``.npy`` files, and
    ``encoding='latin1'`` is forwarded to the unpickler so Python-2 pickles load.

    NOTE(review): unpickling can execute arbitrary code stored in the file --
    only use this on trusted dataset files.

    :param path: path of the pickled batch file
    :return: tuple ``(data, label)`` -- the stored ``'data'`` entry unchanged,
        and the ``'fine_labels'`` entry converted with ``np.array``
    """
    batch = np.load(path, allow_pickle=True, encoding='latin1')
    fine_labels = np.array(batch['fine_labels'])
    return batch['data'], fine_labels
def read_in_all_images(address_list, shuffle=True):
    """Load, stack and reshape every batch file in *address_list*.

    Each file is read with ``_read_one_batch``. Its rows are flat,
    channel-major vectors: all ``IMG_HEIGHT * IMG_WIDTH`` values of channel 0,
    then channel 1, and so on. They are rearranged into channels-last images.

    :param address_list: iterable of batch-file paths; may be empty
    :param shuffle: if truthy, images and labels are permuted together using a
        single ``np.random.permutation`` call
    :return: tuple ``(data, label)``. ``data`` is a C-contiguous float32 array
        of shape ``(num_data, IMG_HEIGHT, IMG_WIDTH, IMG_DEPTH)``. ``label`` is
        a float64 array of shape ``(num_data,)``.
    """
    flat_size = IMG_WIDTH * IMG_HEIGHT * IMG_DEPTH
    data_parts = []
    label_parts = []
    for address in address_list:
        batch_data, batch_label = _read_one_batch(address)
        data_parts.append(batch_data)
        label_parts.append(batch_label)

    # Concatenate once rather than once per file. Keep the stored dtype until
    # the final float32 cast instead of upcasting the whole set to float64.
    if data_parts:
        data = np.concatenate(data_parts)
        # Labels stay float64 for backward compatibility: earlier versions of
        # this loader seeded the concatenation with a float64 array.
        label = np.concatenate(label_parts).astype(np.float64)
    else:
        data = np.empty((0, flat_size))
        label = np.empty((0,))
    num_data = len(label)

    # Flat channel-major rows -> (N, C, H, W) -> channels-last (N, H, W, C).
    data = data.reshape((num_data, IMG_DEPTH, IMG_HEIGHT, IMG_WIDTH)).transpose(0, 2, 3, 1)

    if shuffle:
        order = np.random.permutation(num_data)
        data = data[order, ...]
        label = label[order]
    # The transpose above yields a strided view; hand callers a compact
    # C-ordered float32 array.
    return np.ascontiguousarray(data, dtype=np.float32), label
def horizontal_flip(image, axis):
    """Mirror *image* along *axis* with probability 1/2.

    One integer in {0, 1} is drawn from ``np.random.randint``. On 0 the image
    is reversed with ``np.flip``; otherwise it is returned as-is.

    :param image: numpy array to (maybe) flip
    :param axis: axis to reverse
    :return: the flipped array, or *image* itself when no flip happens
    """
    coin = np.random.randint(low=0, high=2)
    return np.flip(image, axis) if coin == 0 else image
def whitening_image(image_np):
    """Standardize each image of a batch to zero mean and unit variance.

    Each image ``x`` is one slice along axis 0. It becomes
    ``(x - mean(x)) / max(std(x), 1 / sqrt(x.size))``. The lower bound on the
    divisor keeps uniform (zero-variance) images from dividing by zero. For
    ``IMG_HEIGHT x IMG_WIDTH x IMG_DEPTH`` images this equals the previous
    fixed bound ``1 / sqrt(IMG_HEIGHT * IMG_WIDTH * IMG_DEPTH)``.

    Floating-point input is standardized in place and the same array is
    returned. Non-floating input (e.g. raw ``uint8`` pixels) is first copied to
    float32, because writing the results back into an integer array would
    truncate them.

    :param image_np: numpy array of shape ``(num_images, ...)``
    :return: the standardized batch
    """
    if not np.issubdtype(image_np.dtype, np.floating):
        image_np = image_np.astype(np.float32)
    if image_np.size == 0:
        return image_np

    per_image_axes = tuple(range(1, image_np.ndim))
    num_elements = int(np.prod(image_np.shape[1:]))
    mean = image_np.mean(axis=per_image_axes, keepdims=True)
    std = image_np.std(axis=per_image_axes, keepdims=True)
    std = np.maximum(std, 1.0 / np.sqrt(num_elements))

    # In-place ops keep the original contract of mutating float input.
    image_np -= mean
    image_np /= std
    return image_np
def random_crop_and_flip(batch_data, padding_size):
    """Randomly crop and mirror each image of a batch.

    Every image is zero-padded by *padding_size* pixels on both sides of the
    height and width axes. A random ``IMG_HEIGHT x IMG_WIDTH`` window is cut
    out of the padded image and then passed to ``horizontal_flip`` along the
    width axis.

    :param batch_data: array of shape ``(N, H, W, C)``; presumably
        ``(N, IMG_HEIGHT, IMG_WIDTH, IMG_DEPTH)`` -- TODO confirm with callers
    :param padding_size: non-negative pad width per side; ``0`` disables the
        random shift
    :return: float64 array of shape ``(N, IMG_HEIGHT, IMG_WIDTH, IMG_DEPTH)``
    """
    pad_width = ((0, 0), (padding_size, padding_size), (padding_size, padding_size), (0, 0))
    padded = np.pad(batch_data, pad_width=pad_width, mode='constant', constant_values=0)
    num_images = len(padded)
    cropped_batch = np.zeros((num_images, IMG_HEIGHT, IMG_WIDTH, IMG_DEPTH))

    # For IMG-sized inputs the padded side is IMG + 2 * padding_size, so the
    # valid window origins are 0..2 * padding_size inclusive. randint's `high`
    # is exclusive, hence the +1. Without it the bottom/right-most shift is
    # never drawn and padding_size=0 raises.
    offset_high = 2 * padding_size + 1
    for i in range(num_images):
        row = np.random.randint(low=0, high=offset_high)
        col = np.random.randint(low=0, high=offset_high)
        crop = padded[i, row:row + IMG_HEIGHT, col:col + IMG_WIDTH, :]
        cropped_batch[i, ...] = horizontal_flip(image=crop, axis=1)
    return cropped_batch
def read_train_data():
    """Return the training split stored at ``full_data_dir``.

    ``read_in_all_images`` is called with its default shuffling.

    :return: ``(data, label)`` tuple produced by ``read_in_all_images``
    """
    return read_in_all_images([full_data_dir])
def read_vali_data():
    """Return the validation split stored at ``vali_dir``.

    Shuffling is requested explicitly (it is also ``read_in_all_images``'s
    default), so validation samples come back in random order.

    :return: ``(images, labels)`` tuple produced by ``read_in_all_images``
    """
    images, labels = read_in_all_images([vali_dir], shuffle=True)
    return images, labels