Skip to content

Commit

Permalink
Merge pull request #7002 from qingqing01/imdb_data
Browse the repository at this point in the history
 Speed up the data reader for the IMDB dataset.
  • Loading branch information
qingqing01 authored Dec 26, 2017
2 parents f839154 + eb8edeb commit c3fd2c2
Showing 1 changed file with 13 additions and 40 deletions.
53 changes: 13 additions & 40 deletions python/paddle/v2/dataset/imdb.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -23,10 +23,9 @@
import paddle.v2.dataset.common
import collections
import tarfile
import Queue
import re
import string
import threading
import random

__all__ = ['build_dict', 'train', 'test', 'convert']

Expand Down Expand Up @@ -74,47 +73,21 @@ def build_dict(pattern, cutoff):
return word_idx


def reader_creator(pos_pattern, neg_pattern, word_idx, buffer_size):
def reader_creator(pos_pattern, neg_pattern, word_idx):
UNK = word_idx['<unk>']
INS = []

qs = [Queue.Queue(maxsize=buffer_size), Queue.Queue(maxsize=buffer_size)]

def load(pattern, queue):
def load(pattern, out, label):
for doc in tokenize(pattern):
queue.put(doc)
queue.put(None)
out.append(([word_idx.get(w, UNK) for w in doc], label))

load(pos_pattern, INS, 0)
load(neg_pattern, INS, 1)
random.shuffle(INS)

def reader():
# Creates two threads that load positive and negative samples
# into qs.
t0 = threading.Thread(
target=load, args=(
pos_pattern,
qs[0], ))
t0.daemon = True
t0.start()

t1 = threading.Thread(
target=load, args=(
neg_pattern,
qs[1], ))
t1.daemon = True
t1.start()

# Read alternately from qs[0] and qs[1].
i = 0
doc = qs[i].get()
while doc != None:
yield [word_idx.get(w, UNK) for w in doc], i % 2
i += 1
doc = qs[i % 2].get()

# If any queue is empty, reads from the other queue.
i += 1
doc = qs[i % 2].get()
while doc != None:
yield [word_idx.get(w, UNK) for w in doc], i % 2
doc = qs[i % 2].get()
for doc, label in INS:
yield doc, label

return reader

Expand All @@ -133,7 +106,7 @@ def train(word_idx):
"""
return reader_creator(
re.compile("aclImdb/train/pos/.*\.txt$"),
re.compile("aclImdb/train/neg/.*\.txt$"), word_idx, 1000)
re.compile("aclImdb/train/neg/.*\.txt$"), word_idx)


def test(word_idx):
Expand All @@ -150,7 +123,7 @@ def test(word_idx):
"""
return reader_creator(
re.compile("aclImdb/test/pos/.*\.txt$"),
re.compile("aclImdb/test/neg/.*\.txt$"), word_idx, 1000)
re.compile("aclImdb/test/neg/.*\.txt$"), word_idx)


def word_dict():
Expand Down

0 comments on commit c3fd2c2

Please sign in to comment.