Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added cycling through the CIFAR and flowers datasets #11640

Merged
merged 1 commit into from
Jun 22, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 18 additions & 9 deletions python/paddle/v2/dataset/cifar.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@
CIFAR100_MD5 = 'eb9058c3a382ffc7106e4002c42a8d85'


def reader_creator(filename, sub_name):
def reader_creator(filename, sub_name, cycle=False):
def read_batch(batch):
data = batch['data']
labels = batch.get('labels', batch.get('fine_labels', None))
Expand All @@ -56,10 +56,13 @@ def reader():
names = (each_item.name for each_item in f
if sub_name in each_item.name)

for name in names:
batch = cPickle.load(f.extractfile(name))
for item in read_batch(batch):
yield item
while True:
for name in names:
batch = cPickle.load(f.extractfile(name))
for item in read_batch(batch):
yield item
if not cycle:
break

return reader

Expand Down Expand Up @@ -94,34 +97,40 @@ def test100():
'test')


def train10(cycle=False):
    """
    CIFAR-10 training set creator.

    It returns a reader creator. Each sample in the reader is image pixels
    in [0, 1] and label in [0, 9].

    :param cycle: whether to cycle through the dataset instead of stopping
        after a single pass
    :type cycle: bool
    :return: Training reader creator
    :rtype: callable
    """
    # NOTE(review): the visible part of reader_creator builds `names` as a
    # generator expression and re-iterates it inside `while True`; a generator
    # is consumed by the first pass, so confirm that cycle=True really yields
    # samples again rather than looping without producing anything.
    return reader_creator(
        paddle.v2.dataset.common.download(CIFAR10_URL, 'cifar', CIFAR10_MD5),
        'data_batch',
        cycle=cycle)


def test10(cycle=False):
    """
    CIFAR-10 test set creator.

    It returns a reader creator. Each sample in the reader is image pixels
    in [0, 1] and label in [0, 9].

    :param cycle: whether to cycle through the dataset instead of stopping
        after a single pass
    :type cycle: bool
    :return: Test reader creator.
    :rtype: callable
    """
    # NOTE(review): the visible part of reader_creator builds `names` as a
    # generator expression and re-iterates it inside `while True`; a generator
    # is consumed by the first pass, so confirm that cycle=True really yields
    # samples again rather than looping without producing anything.
    return reader_creator(
        paddle.v2.dataset.common.download(CIFAR10_URL, 'cifar', CIFAR10_MD5),
        'test_batch',
        cycle=cycle)


def fetch():
Expand Down
50 changes: 34 additions & 16 deletions python/paddle/v2/dataset/flowers.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,8 @@ def reader_creator(data_file,
dataset_name,
mapper,
buffered_size=1024,
use_xmap=True):
use_xmap=True,
cycle=False):
'''
1. read images from tar file and
merge images into batch files in 102flowers.tgz_batch/
Expand All @@ -96,6 +97,8 @@ def reader_creator(data_file,
:type mapper: callable
:param buffered_size: the size of buffer used to process images
:type buffered_size: int
:param cycle: whether to cycle through the dataset
:type cycle: bool
:return: data reader
:rtype: callable
'''
Expand All @@ -108,15 +111,18 @@ def reader_creator(data_file,
file_list = batch_images_from_tar(data_file, dataset_name, img2label)

def reader():
for file in open(file_list):
file = file.strip()
batch = None
with open(file, 'r') as f:
batch = cPickle.load(f)
data = batch['data']
labels = batch['label']
for sample, label in itertools.izip(data, batch['label']):
yield sample, int(label) - 1
while True:
for file in open(file_list):
file = file.strip()
batch = None
with open(file, 'r') as f:
batch = cPickle.load(f)
data = batch['data']
labels = batch['label']
for sample, label in itertools.izip(data, batch['label']):
yield sample, int(label) - 1
if not cycle:
break

if use_xmap:
cpu_num = int(os.environ.get('CPU_NUM', cpu_count()))
Expand All @@ -125,7 +131,7 @@ def reader():
return map_readers(mapper, reader)


def train(mapper=train_mapper, buffered_size=1024, use_xmap=True):
def train(mapper=train_mapper, buffered_size=1024, use_xmap=True, cycle=False):
'''
Create flowers training set reader.
It returns a reader, each sample in the reader is
Expand All @@ -138,17 +144,23 @@ def train(mapper=train_mapper, buffered_size=1024, use_xmap=True):
:type mapper: callable
:param buffered_size: the size of buffer used to process images
:type buffered_size: int
:param cycle: whether to cycle through the dataset
:type cycle: bool
:return: train data reader
:rtype: callable
'''
return reader_creator(
download(DATA_URL, 'flowers', DATA_MD5),
download(LABEL_URL, 'flowers', LABEL_MD5),
download(SETID_URL, 'flowers', SETID_MD5), TRAIN_FLAG, mapper,
buffered_size, use_xmap)
download(SETID_URL, 'flowers', SETID_MD5),
TRAIN_FLAG,
mapper,
buffered_size,
use_xmap,
cycle=cycle)


def test(mapper=test_mapper, buffered_size=1024, use_xmap=True):
def test(mapper=test_mapper, buffered_size=1024, use_xmap=True, cycle=False):
'''
Create flowers test set reader.
It returns a reader, each sample in the reader is
Expand All @@ -161,14 +173,20 @@ def test(mapper=test_mapper, buffered_size=1024, use_xmap=True):
:type mapper: callable
:param buffered_size: the size of buffer used to process images
:type buffered_size: int
:param cycle: whether to cycle through the dataset
:type cycle: bool
:return: test data reader
:rtype: callable
'''
return reader_creator(
download(DATA_URL, 'flowers', DATA_MD5),
download(LABEL_URL, 'flowers', LABEL_MD5),
download(SETID_URL, 'flowers', SETID_MD5), TEST_FLAG, mapper,
buffered_size, use_xmap)
download(SETID_URL, 'flowers', SETID_MD5),
TEST_FLAG,
mapper,
buffered_size,
use_xmap,
cycle=cycle)


def valid(mapper=test_mapper, buffered_size=1024, use_xmap=True):
Expand Down