Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Expose max_q_size and other generator_queue args #2300

Merged
Merged 2 commits on Apr 14, 2016
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 12 additions & 8 deletions keras/engine/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -369,7 +369,7 @@ def standardize_weights(y, sample_weight=None, class_weight=None,
def generator_queue(generator, max_q_size=10,
wait_time=0.05, nb_worker=1):
'''Builds a threading queue out of a data generator.
Used in `fit_generator`, `evaluate_generator`.
Used in `fit_generator`, `evaluate_generator`, `predict_generator`.
'''
q = queue.Queue()
_stop = threading.Event()
Expand Down Expand Up @@ -1184,7 +1184,7 @@ def predict_on_batch(self, x):
def fit_generator(self, generator, samples_per_epoch, nb_epoch,
verbose=1, callbacks=[],
validation_data=None, nb_val_samples=None,
class_weight={}):
class_weight={}, max_q_size=10):
'''Fits the model on data generated batch-by-batch by
a Python generator.
The generator is run in parallel to the model, for efficiency.
Expand Down Expand Up @@ -1214,6 +1214,7 @@ def fit_generator(self, generator, samples_per_epoch, nb_epoch,
at the end of every epoch.
class_weight: dictionary mapping class indices to a weight
for the class.
max_q_size: maximum size for the generator queue

# Returns
A `History` object.
Expand Down Expand Up @@ -1287,7 +1288,7 @@ def generate_arrays_from_file(path):
self.validation_data = None

# start generator thread storing batches into a queue
data_gen_queue, _stop = generator_queue(generator)
data_gen_queue, _stop = generator_queue(generator, max_q_size=max_q_size)

self.stop_training = False
while epoch < nb_epoch:
Expand Down Expand Up @@ -1358,7 +1359,8 @@ def generate_arrays_from_file(path):
if samples_seen >= samples_per_epoch and do_validation:
if val_gen:
val_outs = self.evaluate_generator(validation_data,
nb_val_samples)
nb_val_samples,
max_q_size=max_q_size)
else:
# no need for try/except because
# data has already been validated
Expand All @@ -1380,7 +1382,7 @@ def generate_arrays_from_file(path):
callbacks.on_train_end()
return self.history

def evaluate_generator(self, generator, val_samples):
def evaluate_generator(self, generator, val_samples, max_q_size=10):
'''Evaluates the model on a data generator. The generator should
return the same kind of data as accepted by `test_on_batch`.

Expand All @@ -1391,6 +1393,7 @@ def evaluate_generator(self, generator, val_samples):
val_samples:
total number of samples to generate from `generator`
before returning.
max_q_size: maximum size for the generator queue

# Returns
Scalar test loss (if the model has a single output and no metrics)
Expand All @@ -1404,7 +1407,7 @@ def evaluate_generator(self, generator, val_samples):
wait_time = 0.01
all_outs = []
weights = []
data_gen_queue, _stop = generator_queue(generator)
data_gen_queue, _stop = generator_queue(generator, max_q_size=max_q_size)

while processed_samples < val_samples:
generator_output = None
Expand Down Expand Up @@ -1456,7 +1459,7 @@ def evaluate_generator(self, generator, val_samples):
weights=weights))
return averages

def predict_generator(self, generator, val_samples):
def predict_generator(self, generator, val_samples, max_q_size=10):
'''Generates predictions for the input samples from a data generator.
The generator should return the same kind of data as accepted by
`predict_on_batch`.
Expand All @@ -1465,6 +1468,7 @@ def predict_generator(self, generator, val_samples):
generator: generator yielding batches of input samples.
val_samples: total number of samples to generate from `generator`
before returning.
max_q_size: maximum size for the generator queue

# Returns
Numpy array(s) of predictions.
Expand All @@ -1474,7 +1478,7 @@ def predict_generator(self, generator, val_samples):
processed_samples = 0
wait_time = 0.01
all_outs = []
data_gen_queue, _stop = generator_queue(generator)
data_gen_queue, _stop = generator_queue(generator, max_q_size=max_q_size)

while processed_samples < val_samples:
generator_output = None
Expand Down
20 changes: 12 additions & 8 deletions keras/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -562,8 +562,7 @@ def predict_classes(self, x, batch_size=32, verbose=1):
def fit_generator(self, generator, samples_per_epoch, nb_epoch,
verbose=1, callbacks=[],
validation_data=None, nb_val_samples=None,
class_weight=None,
**kwargs):
class_weight=None, max_q_size=10, **kwargs):
'''Fits the model on data generated batch-by-batch by
a Python generator.
The generator is run in parallel to the model, for efficiency.
Expand Down Expand Up @@ -593,6 +592,7 @@ def fit_generator(self, generator, samples_per_epoch, nb_epoch,
at the end of every epoch.
class_weight: dictionary mapping class indices to a weight
for the class.
max_q_size: maximum size for the generator queue

# Returns
A `History` object.
Expand Down Expand Up @@ -641,10 +641,10 @@ def generate_arrays_from_file(path):
callbacks=callbacks,
validation_data=validation_data,
nb_val_samples=nb_val_samples,
class_weight=class_weight)
class_weight=class_weight,
max_q_size=max_q_size)

def evaluate_generator(self, generator, val_samples,
**kwargs):
def evaluate_generator(self, generator, val_samples, max_q_size=10, **kwargs):
'''Evaluates the model on a data generator. The generator should
return the same kind of data as accepted by `test_on_batch`.

Expand All @@ -655,6 +655,7 @@ def evaluate_generator(self, generator, val_samples,
val_samples:
total number of samples to generate from `generator`
before returning.
max_q_size: maximum size for the generator queue
'''
if self.model is None:
raise Exception('The model needs to be compiled before being used.')
Expand All @@ -672,9 +673,10 @@ def evaluate_generator(self, generator, val_samples,
raise Exception('Received unknown keyword arguments: ' +
str(kwargs))
return self.model.evaluate_generator(generator,
val_samples)
val_samples,
max_q_size=max_q_size)

def predict_generator(self, generator, val_samples):
def predict_generator(self, generator, val_samples, max_q_size=10):
'''Generates predictions for the input samples from a data generator.
The generator should return the same kind of data as accepted by
`predict_on_batch`.
Expand All @@ -683,13 +685,15 @@ def predict_generator(self, generator, val_samples):
generator: generator yielding batches of input samples.
val_samples: total number of samples to generate from `generator`
before returning.
max_q_size: maximum size for the generator queue

# Returns
A Numpy array of predictions.
'''
if self.model is None:
raise Exception('The model needs to be compiled before being used.')
return self.model.predict_generator(generator, val_samples)
return self.model.predict_generator(generator, val_samples,
max_q_size=max_q_size)

def get_config(self):
'''Returns the model configuration
Expand Down
5 changes: 3 additions & 2 deletions tests/keras/test_sequential_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ def data_generator(train):
model.fit_generator(data_generator(True), len(X_train), nb_epoch, validation_data=(X_test, y_test))
model.fit_generator(data_generator(True), len(X_train), nb_epoch,
validation_data=data_generator(False), nb_val_samples=batch_size * 3)
model.fit_generator(data_generator(True), len(X_train), nb_epoch, max_q_size=2)

loss = model.evaluate(X_train, y_train)

Expand Down Expand Up @@ -100,8 +101,8 @@ def data_generator(x, y, batch_size=50):

loss = model.evaluate(X_test, y_test)

prediction = model.predict_generator(data_generator(X_test, y_test), X_test.shape[0])
gen_loss = model.evaluate_generator(data_generator(X_test, y_test, 50), X_test.shape[0])
prediction = model.predict_generator(data_generator(X_test, y_test), X_test.shape[0], max_q_size=2)
gen_loss = model.evaluate_generator(data_generator(X_test, y_test, 50), X_test.shape[0], max_q_size=2)
pred_loss = K.eval(K.mean(objectives.get(model.loss)(K.variable(y_test), K.variable(prediction))))

assert(np.isclose(pred_loss, loss))
Expand Down