Skip to content
This repository has been archived by the owner on May 25, 2022. It is now read-only.

Commit

Permalink
[keras-team#2287] expose generator_queue args
Browse files Browse the repository at this point in the history
  • Loading branch information
jbencook committed Apr 13, 2016
1 parent 66ebd2a commit e7b0d4e
Show file tree
Hide file tree
Showing 3 changed files with 52 additions and 17 deletions.
26 changes: 19 additions & 7 deletions keras/engine/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -369,7 +369,7 @@ def standardize_weights(y, sample_weight=None, class_weight=None,
def generator_queue(generator, max_q_size=10,
wait_time=0.05, nb_worker=1):
'''Builds a threading queue out of a data generator.
Used in `fit_generator`, `evaluate_generator`.
Used in `fit_generator`, `evaluate_generator`, `predict_generator`.
'''
q = queue.Queue()
_stop = threading.Event()
Expand Down Expand Up @@ -1184,7 +1184,8 @@ def predict_on_batch(self, x):
def fit_generator(self, generator, samples_per_epoch, nb_epoch,
verbose=1, callbacks=[],
validation_data=None, nb_val_samples=None,
class_weight={}):
class_weight={}, max_q_size=10, wait_time=0.05,
nb_worker=1):
'''Fits the model on data generated batch-by-batch by
a Python generator.
The generator is run in parallel to the model, for efficiency.
Expand Down Expand Up @@ -1214,6 +1215,9 @@ def fit_generator(self, generator, samples_per_epoch, nb_epoch,
at the end of every epoch.
class_weight: dictionary mapping class indices to a weight
for the class.
max_q_size: maximum size for the generator queue
wait_time: time to sleep before retry when queue is full
nb_worker: number of threads for running generator task
# Returns
A `History` object.
Expand Down Expand Up @@ -1287,7 +1291,7 @@ def generate_arrays_from_file(path):
self.validation_data = None

# start generator thread storing batches into a queue
data_gen_queue, _stop = generator_queue(generator)
data_gen_queue, _stop = generator_queue(generator, max_q_size=max_q_size)

self.stop_training = False
while epoch < nb_epoch:
Expand Down Expand Up @@ -1380,7 +1384,8 @@ def generate_arrays_from_file(path):
callbacks.on_train_end()
return self.history

def evaluate_generator(self, generator, val_samples):
def evaluate_generator(self, generator, val_samples, max_q_size=10,
wait_time=0.05, nb_worker=1):
'''Evaluates the model on a data generator. The generator should
return the same kind of data as accepted by `test_on_batch`.
Expand All @@ -1391,6 +1396,9 @@ def evaluate_generator(self, generator, val_samples):
val_samples:
total number of samples to generate from `generator`
before returning.
max_q_size: maximum size for the generator queue
wait_time: time to sleep before retry when queue is full
nb_worker: number of threads for running generator task
# Returns
Scalar test loss (if the model has a single output and no metrics)
Expand All @@ -1404,7 +1412,7 @@ def evaluate_generator(self, generator, val_samples):
wait_time = 0.01
all_outs = []
weights = []
data_gen_queue, _stop = generator_queue(generator)
data_gen_queue, _stop = generator_queue(generator, max_q_size=max_q_size)

while processed_samples < val_samples:
generator_output = None
Expand Down Expand Up @@ -1456,7 +1464,8 @@ def evaluate_generator(self, generator, val_samples):
weights=weights))
return averages

def predict_generator(self, generator, val_samples):
def predict_generator(self, generator, val_samples, max_q_size=10,
wait_time=0.05, nb_worker=1):
'''Generates predictions for the input samples from a data generator.
The generator should return the same kind of data as accepted by
`predict_on_batch`.
Expand All @@ -1465,6 +1474,9 @@ def predict_generator(self, generator, val_samples):
generator: generator yielding batches of input samples.
val_samples: total number of samples to generate from `generator`
before returning.
max_q_size: maximum size for the generator queue
wait_time: time to sleep before retry when queue is full
nb_worker: number of threads for running generator task
# Returns
Numpy array(s) of predictions.
Expand All @@ -1474,7 +1486,7 @@ def predict_generator(self, generator, val_samples):
processed_samples = 0
wait_time = 0.01
all_outs = []
data_gen_queue, _stop = generator_queue(generator)
data_gen_queue, _stop = generator_queue(generator, max_q_size=max_q_size)

while processed_samples < val_samples:
generator_output = None
Expand Down
35 changes: 27 additions & 8 deletions keras/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -562,8 +562,8 @@ def predict_classes(self, x, batch_size=32, verbose=1):
def fit_generator(self, generator, samples_per_epoch, nb_epoch,
verbose=1, callbacks=[],
validation_data=None, nb_val_samples=None,
class_weight=None,
**kwargs):
class_weight=None, max_q_size=10,
wait_time=0.05, nb_worker=1, **kwargs):
'''Fits the model on data generated batch-by-batch by
a Python generator.
The generator is run in parallel to the model, for efficiency.
Expand Down Expand Up @@ -593,6 +593,9 @@ def fit_generator(self, generator, samples_per_epoch, nb_epoch,
at the end of every epoch.
class_weight: dictionary mapping class indices to a weight
for the class.
max_q_size: maximum size for the generator queue
wait_time: time to sleep before retry when queue is full
nb_worker: number of threads for running generator task
# Returns
A `History` object.
Expand Down Expand Up @@ -641,10 +644,13 @@ def generate_arrays_from_file(path):
callbacks=callbacks,
validation_data=validation_data,
nb_val_samples=nb_val_samples,
class_weight=class_weight)
class_weight=class_weight,
max_q_size=max_q_size,
wait_time=wait_time,
nb_worker=nb_worker)

def evaluate_generator(self, generator, val_samples,
**kwargs):
def evaluate_generator(self, generator, val_samples, max_q_size=10,
wait_time=0.05, nb_worker=1, **kwargs):
'''Evaluates the model on a data generator. The generator should
return the same kind of data as accepted by `test_on_batch`.
Expand All @@ -655,6 +661,9 @@ def evaluate_generator(self, generator, val_samples,
val_samples:
total number of samples to generate from `generator`
before returning.
max_q_size: maximum size for the generator queue
wait_time: time to sleep before retry when queue is full
nb_worker: number of threads for running generator task
'''
if self.model is None:
raise Exception('The model needs to be compiled before being used.')
Expand All @@ -672,9 +681,13 @@ def evaluate_generator(self, generator, val_samples,
raise Exception('Received unknown keyword arguments: ' +
str(kwargs))
return self.model.evaluate_generator(generator,
val_samples)
val_samples,
max_q_size=max_q_size,
wait_time=wait_time,
nb_worker=nb_worker)

def predict_generator(self, generator, val_samples):
def predict_generator(self, generator, val_samples, max_q_size=10,
wait_time=0.05, nb_worker=1):
'''Generates predictions for the input samples from a data generator.
The generator should return the same kind of data as accepted by
`predict_on_batch`.
Expand All @@ -683,13 +696,19 @@ def predict_generator(self, generator, val_samples):
generator: generator yielding batches of input samples.
val_samples: total number of samples to generate from `generator`
before returning.
max_q_size: maximum size for the generator queue
wait_time: time to sleep before retry when queue is full
nb_worker: number of threads for running generator task
# Returns
A Numpy array of predictions.
'''
if self.model is None:
raise Exception('The model needs to be compiled before being used.')
return self.model.predict_generator(generator, val_samples)
return self.model.predict_generator(generator, val_samples,
max_q_size=max_q_size,
wait_time=wait_time,
nb_worker=nb_worker)

def get_config(self):
'''Returns the model configuration
Expand Down
8 changes: 6 additions & 2 deletions tests/keras/test_sequential_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,8 @@ def data_generator(train):
model.fit_generator(data_generator(True), len(X_train), nb_epoch, validation_data=(X_test, y_test))
model.fit_generator(data_generator(True), len(X_train), nb_epoch,
validation_data=data_generator(False), nb_val_samples=batch_size * 3)
model.fit_generator(data_generator(True), len(X_train), nb_epoch,
max_q_size=2, wait_time=0.1, nb_worker=2)

loss = model.evaluate(X_train, y_train)

Expand Down Expand Up @@ -100,8 +102,10 @@ def data_generator(x, y, batch_size=50):

loss = model.evaluate(X_test, y_test)

prediction = model.predict_generator(data_generator(X_test, y_test), X_test.shape[0])
gen_loss = model.evaluate_generator(data_generator(X_test, y_test, 50), X_test.shape[0])
prediction = model.predict_generator(data_generator(X_test, y_test), X_test.shape[0],
max_q_size=2, wait_time=0.1, nb_worker=2)
gen_loss = model.evaluate_generator(data_generator(X_test, y_test, 50), X_test.shape[0],
max_q_size=2, wait_time=0.1, nb_worker=2)
pred_loss = K.eval(K.mean(objectives.get(model.loss)(K.variable(y_test), K.variable(prediction))))

assert(np.isclose(pred_loss, loss))
Expand Down

0 comments on commit e7b0d4e

Please sign in to comment.