-
Notifications
You must be signed in to change notification settings - Fork 223
/
create_mnistm.py
84 lines (65 loc) · 2.17 KB
/
create_mnistm.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tarfile
import os
import pickle as pkl
import numpy as np
import skimage
import skimage.io
import skimage.transform
from tensorflow.examples.tutorials.mnist import input_data
# Load the MNIST splits (cached in ./MNIST_data).
mnist = input_data.read_data_sets('MNIST_data')

# Archive holding the BSDS500 images used as backgrounds.
BST_PATH = 'BSR_bsds500.tgz'

# Seeded generator used to pick a background image for every digit.
rand = np.random.RandomState(42)

# Open the archive in a context manager so the handle is always closed,
# even if reading one of the members fails unexpectedly.
with tarfile.open(BST_PATH) as f:
    train_files = [
        name for name in f.getnames()
        if name.startswith('BSR/BSDS500/data/images/train/')
    ]

    print('Loading BSR training images')

    background_data = []
    for name in train_files:
        try:
            fp = f.extractfile(name)
            bg_img = skimage.io.imread(fp)
        except Exception:
            # Best effort: members that cannot be decoded as images are
            # skipped. Catching Exception (rather than a bare except) keeps
            # KeyboardInterrupt and SystemExit working.
            continue
        background_data.append(bg_img)
def compose_image(digit, background, rng=None):
    """Difference-blend a digit and a random patch from a background image.

    Args:
        digit: Array of shape (dw, dh, channels) holding the digit image.
        background: Array of shape (w, h, channels); it must be at least as
            large as `digit` along the first two axes.
        rng: Optional random source exposing `randint(low, high)`, e.g. a
            seeded `np.random.RandomState`. Defaults to the global
            `np.random` module.

    Returns:
        A uint8 array with the absolute difference between the selected
        background patch and the digit.

    Raises:
        ValueError: If the background is smaller than the digit.
    """
    if rng is None:
        rng = np.random

    w, h, _ = background.shape
    dw, dh, _ = digit.shape
    if w < dw or h < dh:
        raise ValueError(
            'background %r is smaller than digit %r'
            % (background.shape, digit.shape))

    # randint's upper bound is exclusive: the +1 makes the last valid offset
    # reachable and lets a background of exactly the digit's size work.
    x = rng.randint(0, w - dw + 1)
    y = rng.randint(0, h - dh + 1)

    bg = background[x:x + dw, y:y + dh]
    if bg.dtype.kind == 'u':
        # Widen unsigned pixels so the subtraction cannot wrap around when
        # the digit is an integer array as well.
        bg = bg.astype(np.int64)
    return np.abs(bg - digit).astype(np.uint8)
def mnist_to_img(x):
    """Turn one MNIST digit into a binarized three-channel image.

    Every strictly positive pixel becomes 255 and everything else 0. The
    resulting 28x28 plane is copied into three identical channels, giving a
    float32 array of shape (28, 28, 3).
    """
    binary = (x > 0).astype(np.float32)
    plane = binary.reshape([28, 28, 1]) * 255
    # Replicate the single channel along the last axis to get R, G and B.
    return np.repeat(plane, 3, axis=2)
def create_mnistm(X):
    """
    Given an array of MNIST digits, blend random background patches to
    build the MNIST-M dataset as described in
    http://jmlr.org/papers/volume17/15-239/15-239.pdf

    Args:
        X: Array whose first axis indexes the digits; each entry must be
            reshapeable to 28x28.

    Returns:
        A uint8 array of shape (X.shape[0], 28, 28, 3) with the blended
        images, in the same order as the input digits.
    """
    num_examples = X.shape[0]
    num_backgrounds = len(background_data)
    X_ = np.zeros([num_examples, 28, 28, 3], np.uint8)
    for i in range(num_examples):
        if i % 1000 == 0:
            print('Processing example', i)
        # Pick the background by index. rand.choice(background_data) would
        # first convert the whole list of images into one array, which is
        # slow and fails both for same-shaped images ("a must be
        # 1-dimensional") and, on recent NumPy, for differently shaped ones.
        bg_img = background_data[rand.randint(num_backgrounds)]
        d = mnist_to_img(X[i])
        d = compose_image(d, bg_img)
        X_[i] = d
    return X_
# Build every split in turn, announcing each one before the work starts.
dataset = {}
for message, key, split in (
        ('Building train set...', 'train', mnist.train),
        ('Building test set...', 'test', mnist.test),
        ('Building validation set...', 'valid', mnist.validation),
):
    print(message)
    dataset[key] = create_mnistm(split.images)

train, test, valid = dataset['train'], dataset['test'], dataset['valid']

# Persist the three splits together as a single pickled dict.
with open('mnistm_data.pkl', 'wb') as f:
    pkl.dump(dataset, f, protocol=pkl.HIGHEST_PROTOCOL)