-
Notifications
You must be signed in to change notification settings - Fork 28
/
Copy pathcreate_imagenet_benchmark_datasets.py
49 lines (46 loc) · 1.84 KB
/
create_imagenet_benchmark_datasets.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
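"""
Convert the downsampled ImageNet image folders into NumPy .npy arrays.

Run the following commands from your home directory (~) before running this
script. They download and extract the folders it reads:
~/train_64x64, ~/valid_64x64, ~/train_32x32 and ~/valid_32x32.

wget http://image-net.org/small/train_64x64.tar
wget http://image-net.org/small/valid_64x64.tar
tar -xvf train_64x64.tar
tar -xvf valid_64x64.tar
wget http://image-net.org/small/train_32x32.tar
wget http://image-net.org/small/valid_32x32.tar
tar -xvf train_32x32.tar
tar -xvf valid_32x32.tar

Each folder is then saved next to it as ~/<folder>.npy.

NOTE(review): the image-net.org/small URLs may no longer be served. Verify
they are still reachable before relying on these commands.
"""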
import numpy as np
import scipy.ndimage
import os
from os import listdir
from os.path import isfile, join
import sys
from tqdm import tqdm
def convert_path_to_npy(*, path='~/train_64x64', outfile='~/train_64x64.npy',
                        imread=None, show_progress=True):
    """Stack every image file in ``path`` into one uint8 array and save it.

    All images must be RGB, i.e. shape (H, W, 3), and share the same H and W.
    The resolution is taken from the first file (in sorted name order), so
    the same function handles both the 64x64 and the 32x32 folders.

    Args:
        path: Directory containing the image files. A leading ``~`` is
            expanded to the user's home directory.
        outfile: Destination passed to ``np.save``. When it is a string, a
            leading ``~`` is expanded as well.
        imread: Callable mapping a file path to an array-like image. It
            defaults to ``scipy.ndimage.imread``, which exists only in
            SciPy < 1.2. On newer SciPy, pass a reader such as
            ``imageio.imread``.
        show_progress: Wrap the read loop in a ``tqdm`` progress bar.

    Returns:
        The ``uint8`` array of shape (N, H, W, 3) that was written to
        ``outfile``.

    Raises:
        TypeError: ``path`` is not a string.
        FileNotFoundError: ``path`` does not exist.
        ImportError: ``imread`` was not given and SciPy has no ``imread``.
        ValueError: ``path`` contains no files, an image is not
            (H, W, 3), the images differ in shape, or a pixel value lies
            outside [0, 255].
    """
    if not isinstance(path, str):
        raise TypeError('Expected a string input for the path')
    # Neither os.path.exists nor np.save expands "~". Without this step the
    # defaults, and every call in __main__, point at a literal "~" directory.
    path = os.path.expanduser(path)
    if isinstance(outfile, str):
        outfile = os.path.expanduser(outfile)
    if not os.path.exists(path):
        raise FileNotFoundError("Input path doesn't exist: {!r}".format(path))

    if imread is None:
        # scipy.ndimage.imread was deprecated in SciPy 1.0 and removed in 1.2.
        imread = getattr(scipy.ndimage, 'imread', None)
        if imread is None:
            raise ImportError(
                'scipy.ndimage.imread is not available in this SciPy version; '
                'pass an image reader via imread=, e.g. imageio.imread')

    # Sort so the row order of the saved array is reproducible; the order
    # returned by os.listdir is arbitrary and filesystem-dependent.
    files = sorted(f for f in listdir(path) if isfile(join(path, f)))
    print('Number of valid images is:', len(files))
    if not files:
        raise ValueError('No files found in {!r}'.format(path))

    names = tqdm(files) if show_progress else files
    imgs = None
    for idx, name in enumerate(names):
        fpath = join(path, name)
        img = np.asarray(imread(fpath))
        if imgs is None:
            if img.ndim != 3 or img.shape[2] != 3:
                raise ValueError('{}: expected an (H, W, 3) image, got shape {}'
                                 .format(fpath, img.shape))
            # Preallocate from the first image's resolution. Building a list
            # of N arrays and then stacking it would need about 2x the memory.
            imgs = np.empty((len(files),) + img.shape, dtype='uint8')
        if img.shape != imgs.shape[1:]:
            raise ValueError('{}: expected shape {}, got {}'
                             .format(fpath, imgs.shape[1:], img.shape))
        # Check the raw values *before* casting. After astype('uint8'),
        # out-of-range values have already wrapped and cannot be detected.
        if img.min() < 0 or img.max() > 255:
            raise ValueError('{}: pixel values outside [0, 255]'.format(fpath))
        imgs[idx] = img.astype('uint8')

    print('Total number of images is:', imgs.shape[0])
    print('All assertions done, dumping into npy file')
    np.save(outfile, imgs)
    return imgs
if __name__ == '__main__':
    # Convert each extracted folder (created by the download commands in the
    # module docstring) into a .npy file in the home directory.
    for source_dir, target_file in (
            ('~/train_64x64', '~/train_64x64.npy'),
            ('~/valid_64x64', '~/valid_64x64.npy'),
            ('~/train_32x32', '~/train_32x32.npy'),
            ('~/valid_32x32', '~/valid_32x32.npy'),
    ):
        convert_path_to_npy(path=source_dir, outfile=target_file)