forked from JackonYang/captcha-tensorflow
-
Notifications
You must be signed in to change notification settings - Fork 0
/
input_data.py
107 lines (80 loc) · 2.84 KB
/
input_data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
# -*- coding:utf-8 -*-
import os
from PIL import Image
import numpy as np
import json
def load_data(data_dir, flatten=False):
    """Load the captcha data set stored under *data_dir*.

    The directory is expected to contain ``meta.json`` plus ``train`` and
    ``test`` sub-directories. Every key of the parsed meta dict is forwarded
    as a keyword argument to the image/label readers.

    Returns:
        A tuple ``(meta, train_set, test_set)`` where the two sets are
        ``DataSet`` instances.
    """
    with open(os.path.join(data_dir, 'meta.json'), 'r') as meta_file:
        meta = json.load(meta_file)

    def _split(name):
        # Read one sub-directory ('train' or 'test') into a DataSet.
        images, labels = _read_images_and_labels(
            os.path.join(data_dir, name), flatten=flatten, **meta)
        return DataSet(images, labels)

    # Tuple elements evaluate left to right: train is read before test.
    return meta, _split('train'), _split('test')
class DataSet(object):
    """In-memory image/label pairs that can be consumed in mini-batches.

    Provides ``next_batch`` for walking over the data in (optionally
    shuffled) batches. ``images`` and ``labels`` are expected to have the
    same length along axis 0; rows with the same index belong together.
    Inherits from ``object`` so that, under Python 2, this is a new-style
    class and the read-only properties below behave as properties.
    """

    def __init__(self, images, labels):
        """Wrap *images* and *labels*, which are indexed in parallel."""
        self._images = images
        self._labels = labels
        self._num_examples = images.shape[0]
        # Index of the first example of the next batch; 0 means "start of
        # a new pass", which is when shuffling happens.
        self.ptr = 0

    @property
    def images(self):
        """All images, in their current (possibly shuffled) order."""
        return self._images

    @property
    def labels(self):
        """All labels, aligned row-for-row with ``images``."""
        return self._labels

    @property
    def num_examples(self):
        """Number of examples (rows along axis 0) in this data set."""
        return self._num_examples

    def next_batch(self, size=100, shuffle=True):
        """Return the next ``(images, labels)`` batch of at most *size* rows.

        If fewer than *size* examples remain in the current pass, those
        leftover examples are skipped and a new pass starts from index 0.
        At the start of every pass (including the very first call) images
        and labels are permuted together when *shuffle* is true. If *size*
        exceeds the number of examples, the whole data set is returned.
        """
        if self.ptr + size > self._num_examples:
            # Not enough examples left for a full batch: begin a new pass.
            self.ptr = 0
        if self.ptr == 0 and shuffle:
            # One shared permutation keeps image/label pairs aligned.
            perm = np.random.permutation(self._num_examples)
            self._images = self._images[perm]
            self._labels = self._labels[perm]
        start = self.ptr
        self.ptr += size
        return self._images[start:self.ptr], self._labels[start:self.ptr]
def _read_images_and_labels(dir_name, flatten, ext='.png', **meta):
    """Read every file ending in *ext* in *dir_name* into parallel arrays.

    Each matching file is decoded with ``_read_image`` and labelled with
    ``_read_lable``; the remaining *meta* keywords are forwarded to both.

    Returns:
        ``(images, labels)`` as numpy arrays whose first axis indexes the
        files, in sorted filename order.
    """
    images = []
    labels = []
    # os.listdir returns entries in arbitrary, platform-dependent order.
    # Sorting makes the loaded arrays (and any unshuffled batches taken
    # from them) reproducible across runs and machines.
    for fn in sorted(os.listdir(dir_name)):
        if not fn.endswith(ext):
            continue
        path = os.path.join(dir_name, fn)
        images.append(_read_image(path, flatten=flatten, **meta))
        labels.append(_read_lable(path, **meta))
    return np.array(images), np.array(labels)
def _read_image(filename, flatten, width, height, **extra_meta):
    """Load *filename* as a greyscale (PIL mode ``'L'``) numpy array.

    When *flatten* is true the pixels are reshaped into a 1-D vector of
    ``width * height`` values (numpy raises if the image size differs);
    otherwise the decoded array is returned unchanged. Any other meta
    keys in *extra_meta* are ignored.
    """
    pixels = np.asarray(Image.open(filename).convert('L'))
    return pixels.reshape(width * height) if flatten else pixels
def _read_lable(filename, label_choices, **extra_meta):
basename = os.path.basename(filename)
idx = label_choices.index(basename.split('_')[0])
data = np.zeros(len(label_choices))
data[idx] = 1
return data
def display_info(meta, train_data, test_data):
    """Print the meta information and the array shapes of a loaded data set.

    Uses single-argument ``print(...)`` calls, which behave the same under
    Python 2 and Python 3 (the former ``print`` statements are a syntax
    error on Python 3).

    Note: this draws one batch of 100 from *train_data* via ``next_batch``,
    so it advances that object's batch position (and, for a ``DataSet``
    at the start of a pass, reshuffles it).
    """
    print('=' * 20)
    for k, v in meta.items():
        print('%s: %s' % (k, v))
    print('=' * 20)
    print('train images: %s, labels: %s' % (train_data.images.shape, train_data.labels.shape))
    print('test images: %s, labels: %s' % (test_data.images.shape, test_data.labels.shape))
    batch_xs, batch_ys = train_data.next_batch(100)
    print('batch images: %s, labels: %s' % (batch_xs.shape, batch_ys.shape))
if __name__ == '__main__':
    # Show the same data set twice: first as 2-D images, then flattened.
    for use_flatten in (False, True):
        display_info(*load_data('images/char-1-groups-1000/', flatten=use_flatten))