Expose max_q_size and other generator_queue args (#2300)
* [#2287] expose generator_queue args

* [#2287] only expose max_q_size
Ben Cook authored and fchollet committed Apr 14, 2016
1 parent c1c2b33 commit 4f5f88b
Showing 3 changed files with 27 additions and 18 deletions.
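In practice the change threads one new keyword argument, max_q_size, from the public generator methods down to the internal generator_queue helper, whose default of 10 was previously not overridable from those methods. Below is a minimal usage sketch against the Keras 1.x API shown in this diff; the toy model and batch_generator are hypothetical and only the max_q_size keyword comes from this commit.

import numpy as np
from keras.models import Sequential
from keras.layers.core import Dense

# Hypothetical toy model and data generator, for illustration only.
model = Sequential()
model.add(Dense(1, input_dim=8))
model.compile(optimizer='sgd', loss='mse')

def batch_generator(batch_size=32):
    while True:  # Keras data generators are expected to loop indefinitely
        x = np.random.random((batch_size, 8))
        y = np.random.random((batch_size, 1))
        yield x, y

# max_q_size caps how many batches the background generator thread may
# buffer ahead of training; before this commit it stayed at the helper's
# default of 10.
model.fit_generator(batch_generator(), samples_per_epoch=320, nb_epoch=2,
                    max_q_size=2)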
20 changes: 12 additions & 8 deletions keras/engine/training.py
@@ -385,7 +385,7 @@ def standardize_weights(y, sample_weight=None, class_weight=None,
 def generator_queue(generator, max_q_size=10,
                     wait_time=0.05, nb_worker=1):
     '''Builds a threading queue out of a data generator.
-    Used in `fit_generator`, `evaluate_generator`.
+    Used in `fit_generator`, `evaluate_generator`, `predict_generator`.
     '''
     q = queue.Queue()
     _stop = threading.Event()
@@ -1203,7 +1203,7 @@ def predict_on_batch(self, x):
     def fit_generator(self, generator, samples_per_epoch, nb_epoch,
                       verbose=1, callbacks=[],
                       validation_data=None, nb_val_samples=None,
-                      class_weight={}):
+                      class_weight={}, max_q_size=10):
         '''Fits the model on data generated batch-by-batch by
         a Python generator.
         The generator is run in parallel to the model, for efficiency.
@@ -1233,6 +1233,7 @@ def fit_generator(self, generator, samples_per_epoch, nb_epoch,
                 at the end of every epoch.
             class_weight: dictionary mapping class indices to a weight
                 for the class.
+            max_q_size: maximum size for the generator queue

         # Returns
             A `History` object.
@@ -1311,7 +1312,7 @@ def generate_arrays_from_file(path):
             self.validation_data = None

         # start generator thread storing batches into a queue
-        data_gen_queue, _stop = generator_queue(generator)
+        data_gen_queue, _stop = generator_queue(generator, max_q_size=max_q_size)

         callback_model.stop_training = False
         while epoch < nb_epoch:
@@ -1384,7 +1385,8 @@ def generate_arrays_from_file(path):
                 if samples_seen >= samples_per_epoch and do_validation:
                     if val_gen:
                         val_outs = self.evaluate_generator(validation_data,
-                                                           nb_val_samples)
+                                                           nb_val_samples,
+                                                           max_q_size=max_q_size)
                     else:
                         # no need for try/except because
                         # data has already been validated
@@ -1406,7 +1408,7 @@ def generate_arrays_from_file(path):
         callbacks.on_train_end()
         return self.history

-    def evaluate_generator(self, generator, val_samples):
+    def evaluate_generator(self, generator, val_samples, max_q_size=10):
         '''Evaluates the model on a data generator. The generator should
         return the same kind of data as accepted by `test_on_batch`.
@@ -1417,6 +1419,7 @@ def evaluate_generator(self, generator, val_samples):
             val_samples:
                 total number of samples to generate from `generator`
                 before returning.
+            max_q_size: maximum size for the generator queue

         # Returns
             Scalar test loss (if the model has a single output and no metrics)
@@ -1430,7 +1433,7 @@ def evaluate_generator(self, generator, val_samples):
         wait_time = 0.01
         all_outs = []
         weights = []
-        data_gen_queue, _stop = generator_queue(generator)
+        data_gen_queue, _stop = generator_queue(generator, max_q_size=max_q_size)

         while processed_samples < val_samples:
             generator_output = None
@@ -1484,7 +1487,7 @@ def evaluate_generator(self, generator, val_samples):
                                                weights=weights))
         return averages

-    def predict_generator(self, generator, val_samples):
+    def predict_generator(self, generator, val_samples, max_q_size=10):
         '''Generates predictions for the input samples from a data generator.
         The generator should return the same kind of data as accepted by
         `predict_on_batch`.
@@ -1493,6 +1496,7 @@ def predict_generator(self, generator, val_samples):
             generator: generator yielding batches of input samples.
             val_samples: total number of samples to generate from `generator`
                 before returning.
+            max_q_size: maximum size for the generator queue

         # Returns
             Numpy array(s) of predictions.
@@ -1502,7 +1506,7 @@ def predict_generator(self, generator, val_samples):
         processed_samples = 0
         wait_time = 0.01
         all_outs = []
-        data_gen_queue, _stop = generator_queue(generator)
+        data_gen_queue, _stop = generator_queue(generator, max_q_size=max_q_size)

         while processed_samples < val_samples:
             generator_output = None
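For context on what max_q_size bounds: generator_queue returns a plain queue.Queue that a background thread fills from the Python generator, plus a threading.Event used to stop that thread. The following is a rough, hypothetical sketch of calling the internal helper directly; only its signature, return values, and default come from the diff above, while dummy_generator and the consumption loop are illustrative.

import time
from keras.engine.training import generator_queue

def dummy_generator():
    # Placeholder producer; a real one would yield (x, y) batches forever.
    i = 0
    while True:
        yield i
        i += 1

# A background thread fills the queue, keeping at most max_q_size items ahead.
data_gen_queue, _stop = generator_queue(dummy_generator(), max_q_size=2)

batches = []
while len(batches) < 5:
    if not data_gen_queue.empty():
        batches.append(data_gen_queue.get())
    else:
        time.sleep(0.05)  # queue momentarily empty; give the producer time

_stop.set()  # ask the producer thread to exit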
20 changes: 12 additions & 8 deletions keras/models.py
@@ -565,8 +565,7 @@ def predict_classes(self, x, batch_size=32, verbose=1):
     def fit_generator(self, generator, samples_per_epoch, nb_epoch,
                       verbose=1, callbacks=[],
                       validation_data=None, nb_val_samples=None,
-                      class_weight=None,
-                      **kwargs):
+                      class_weight=None, max_q_size=10, **kwargs):
         '''Fits the model on data generated batch-by-batch by
         a Python generator.
         The generator is run in parallel to the model, for efficiency.
@@ -596,6 +595,7 @@ def fit_generator(self, generator, samples_per_epoch, nb_epoch,
                 at the end of every epoch.
             class_weight: dictionary mapping class indices to a weight
                 for the class.
+            max_q_size: maximum size for the generator queue

         # Returns
             A `History` object.
@@ -644,10 +644,10 @@ def generate_arrays_from_file(path):
                                         callbacks=callbacks,
                                         validation_data=validation_data,
                                         nb_val_samples=nb_val_samples,
-                                        class_weight=class_weight)
+                                        class_weight=class_weight,
+                                        max_q_size=max_q_size)

-    def evaluate_generator(self, generator, val_samples,
-                           **kwargs):
+    def evaluate_generator(self, generator, val_samples, max_q_size=10, **kwargs):
         '''Evaluates the model on a data generator. The generator should
         return the same kind of data as accepted by `test_on_batch`.
@@ -658,6 +658,7 @@ def evaluate_generator(self, generator, val_samples,
             val_samples:
                 total number of samples to generate from `generator`
                 before returning.
+            max_q_size: maximum size for the generator queue
         '''
         if self.model is None:
             raise Exception('The model needs to be compiled before being used.')
@@ -675,9 +676,10 @@ def evaluate_generator(self, generator, val_samples,
             raise Exception('Received unknown keyword arguments: ' +
                             str(kwargs))
         return self.model.evaluate_generator(generator,
-                                              val_samples)
+                                              val_samples,
+                                              max_q_size=max_q_size)

-    def predict_generator(self, generator, val_samples):
+    def predict_generator(self, generator, val_samples, max_q_size=10):
         '''Generates predictions for the input samples from a data generator.
         The generator should return the same kind of data as accepted by
         `predict_on_batch`.
@@ -686,13 +688,15 @@ def predict_generator(self, generator, val_samples):
             generator: generator yielding batches of input samples.
             val_samples: total number of samples to generate from `generator`
                 before returning.
+            max_q_size: maximum size for the generator queue

         # Returns
             A Numpy array of predictions.
         '''
         if self.model is None:
             raise Exception('The model needs to be compiled before being used.')
-        return self.model.predict_generator(generator, val_samples)
+        return self.model.predict_generator(generator, val_samples,
+                                             max_q_size=max_q_size)

     def get_config(self):
         '''Returns the model configuration
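The Sequential wrapper changes are pure pass-throughs to the underlying Model methods. Here is a short, hypothetical wrapper-level sketch (placeholder model and generator; only the max_q_size keyword is what this file exposes), which the test changes below exercise in the same way.

import numpy as np
from keras.models import Sequential
from keras.layers.core import Dense

# Placeholder model and generator, mirroring the wrapper signatures above.
model = Sequential()
model.add(Dense(1, input_dim=8))
model.compile(optimizer='sgd', loss='mse')

def eval_generator(batch_size=25):
    while True:
        x = np.random.random((batch_size, 8))
        y = np.random.random((batch_size, 1))
        yield x, y

# Both wrappers now forward max_q_size to the underlying Model methods.
score = model.evaluate_generator(eval_generator(), val_samples=100, max_q_size=2)
preds = model.predict_generator(eval_generator(), val_samples=100, max_q_size=2)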
5 changes: 3 additions & 2 deletions tests/keras/test_sequential_model.py
@@ -66,6 +66,7 @@ def data_generator(train):
     model.fit_generator(data_generator(True), len(X_train), nb_epoch, validation_data=(X_test, y_test))
     model.fit_generator(data_generator(True), len(X_train), nb_epoch,
                         validation_data=data_generator(False), nb_val_samples=batch_size * 3)
+    model.fit_generator(data_generator(True), len(X_train), nb_epoch, max_q_size=2)

     loss = model.evaluate(X_train, y_train)

@@ -100,8 +101,8 @@ def data_generator(x, y, batch_size=50):

     loss = model.evaluate(X_test, y_test)

-    prediction = model.predict_generator(data_generator(X_test, y_test), X_test.shape[0])
-    gen_loss = model.evaluate_generator(data_generator(X_test, y_test, 50), X_test.shape[0])
+    prediction = model.predict_generator(data_generator(X_test, y_test), X_test.shape[0], max_q_size=2)
+    gen_loss = model.evaluate_generator(data_generator(X_test, y_test, 50), X_test.shape[0], max_q_size=2)
     pred_loss = K.eval(K.mean(objectives.get(model.loss)(K.variable(y_test), K.variable(prediction))))

     assert(np.isclose(pred_loss, loss))
