From 4f5f88b9bab58e363b95ddc1931f2036d13d14e6 Mon Sep 17 00:00:00 2001 From: Ben Cook Date: Wed, 13 Apr 2016 20:00:25 -0500 Subject: [PATCH] Expose max_q_size and other generator_queue args (#2300) * [#2287] expose generator_queue args * [#2287] only expose max_q_size --- keras/engine/training.py | 20 ++++++++++++-------- keras/models.py | 20 ++++++++++++-------- tests/keras/test_sequential_model.py | 5 +++-- 3 files changed, 27 insertions(+), 18 deletions(-) diff --git a/keras/engine/training.py b/keras/engine/training.py index 4ffbcf793d5..4fff7da875d 100644 --- a/keras/engine/training.py +++ b/keras/engine/training.py @@ -385,7 +385,7 @@ def standardize_weights(y, sample_weight=None, class_weight=None, def generator_queue(generator, max_q_size=10, wait_time=0.05, nb_worker=1): '''Builds a threading queue out of a data generator. - Used in `fit_generator`, `evaluate_generator`. + Used in `fit_generator`, `evaluate_generator`, `predict_generator`. ''' q = queue.Queue() _stop = threading.Event() @@ -1203,7 +1203,7 @@ def predict_on_batch(self, x): def fit_generator(self, generator, samples_per_epoch, nb_epoch, verbose=1, callbacks=[], validation_data=None, nb_val_samples=None, - class_weight={}): + class_weight={}, max_q_size=10): '''Fits the model on data generated batch-by-batch by a Python generator. The generator is run in parallel to the model, for efficiency. @@ -1233,6 +1233,7 @@ def fit_generator(self, generator, samples_per_epoch, nb_epoch, at the end of every epoch. class_weight: dictionary mapping class indices to a weight for the class. + max_q_size: maximum size for the generator queue # Returns A `History` object. @@ -1311,7 +1312,7 @@ def generate_arrays_from_file(path): self.validation_data = None # start generator thread storing batches into a queue - data_gen_queue, _stop = generator_queue(generator) + data_gen_queue, _stop = generator_queue(generator, max_q_size=max_q_size) callback_model.stop_training = False while epoch < nb_epoch: @@ -1384,7 +1385,8 @@ def generate_arrays_from_file(path): if samples_seen >= samples_per_epoch and do_validation: if val_gen: val_outs = self.evaluate_generator(validation_data, - nb_val_samples) + nb_val_samples, + max_q_size=max_q_size) else: # no need for try/except because # data has already been validated @@ -1406,7 +1408,7 @@ def generate_arrays_from_file(path): callbacks.on_train_end() return self.history - def evaluate_generator(self, generator, val_samples): + def evaluate_generator(self, generator, val_samples, max_q_size=10): '''Evaluates the model on a data generator. The generator should return the same kind of data as accepted by `test_on_batch`. @@ -1417,6 +1419,7 @@ def evaluate_generator(self, generator, val_samples): val_samples: total number of samples to generate from `generator` before returning. + max_q_size: maximum size for the generator queue # Returns Scalar test loss (if the model has a single output and no metrics) @@ -1430,7 +1433,7 @@ def evaluate_generator(self, generator, val_samples): wait_time = 0.01 all_outs = [] weights = [] - data_gen_queue, _stop = generator_queue(generator) + data_gen_queue, _stop = generator_queue(generator, max_q_size=max_q_size) while processed_samples < val_samples: generator_output = None @@ -1484,7 +1487,7 @@ def evaluate_generator(self, generator, val_samples): weights=weights)) return averages - def predict_generator(self, generator, val_samples): + def predict_generator(self, generator, val_samples, max_q_size=10): '''Generates predictions for the input samples from a data generator. The generator should return the same kind of data as accepted by `predict_on_batch`. @@ -1493,6 +1496,7 @@ def predict_generator(self, generator, val_samples): generator: generator yielding batches of input samples. val_samples: total number of samples to generate from `generator` before returning. + max_q_size: maximum size for the generator queue # Returns Numpy array(s) of predictions. @@ -1502,7 +1506,7 @@ def predict_generator(self, generator, val_samples): processed_samples = 0 wait_time = 0.01 all_outs = [] - data_gen_queue, _stop = generator_queue(generator) + data_gen_queue, _stop = generator_queue(generator, max_q_size=max_q_size) while processed_samples < val_samples: generator_output = None diff --git a/keras/models.py b/keras/models.py index b42de4db315..5ff5c645516 100644 --- a/keras/models.py +++ b/keras/models.py @@ -565,8 +565,7 @@ def predict_classes(self, x, batch_size=32, verbose=1): def fit_generator(self, generator, samples_per_epoch, nb_epoch, verbose=1, callbacks=[], validation_data=None, nb_val_samples=None, - class_weight=None, - **kwargs): + class_weight=None, max_q_size=10, **kwargs): '''Fits the model on data generated batch-by-batch by a Python generator. The generator is run in parallel to the model, for efficiency. @@ -596,6 +595,7 @@ def fit_generator(self, generator, samples_per_epoch, nb_epoch, at the end of every epoch. class_weight: dictionary mapping class indices to a weight for the class. + max_q_size: maximum size for the generator queue # Returns A `History` object. @@ -644,10 +644,10 @@ def generate_arrays_from_file(path): callbacks=callbacks, validation_data=validation_data, nb_val_samples=nb_val_samples, - class_weight=class_weight) + class_weight=class_weight, + max_q_size=max_q_size) - def evaluate_generator(self, generator, val_samples, - **kwargs): + def evaluate_generator(self, generator, val_samples, max_q_size=10, **kwargs): '''Evaluates the model on a data generator. The generator should return the same kind of data as accepted by `test_on_batch`. @@ -658,6 +658,7 @@ def evaluate_generator(self, generator, val_samples, val_samples: total number of samples to generate from `generator` before returning. + max_q_size: maximum size for the generator queue ''' if self.model is None: raise Exception('The model needs to be compiled before being used.') @@ -675,9 +676,10 @@ def evaluate_generator(self, generator, val_samples, raise Exception('Received unknown keyword arguments: ' + str(kwargs)) return self.model.evaluate_generator(generator, - val_samples) + val_samples, + max_q_size=max_q_size) - def predict_generator(self, generator, val_samples): + def predict_generator(self, generator, val_samples, max_q_size=10): '''Generates predictions for the input samples from a data generator. The generator should return the same kind of data as accepted by `predict_on_batch`. @@ -686,13 +688,15 @@ def predict_generator(self, generator, val_samples): generator: generator yielding batches of input samples. val_samples: total number of samples to generate from `generator` before returning. + max_q_size: maximum size for the generator queue # Returns A Numpy array of predictions. ''' if self.model is None: raise Exception('The model needs to be compiled before being used.') - return self.model.predict_generator(generator, val_samples) + return self.model.predict_generator(generator, val_samples, + max_q_size=max_q_size) def get_config(self): '''Returns the model configuration diff --git a/tests/keras/test_sequential_model.py b/tests/keras/test_sequential_model.py index 589871799e5..9116fc540a4 100644 --- a/tests/keras/test_sequential_model.py +++ b/tests/keras/test_sequential_model.py @@ -66,6 +66,7 @@ def data_generator(train): model.fit_generator(data_generator(True), len(X_train), nb_epoch, validation_data=(X_test, y_test)) model.fit_generator(data_generator(True), len(X_train), nb_epoch, validation_data=data_generator(False), nb_val_samples=batch_size * 3) + model.fit_generator(data_generator(True), len(X_train), nb_epoch, max_q_size=2) loss = model.evaluate(X_train, y_train) @@ -100,8 +101,8 @@ def data_generator(x, y, batch_size=50): loss = model.evaluate(X_test, y_test) - prediction = model.predict_generator(data_generator(X_test, y_test), X_test.shape[0]) - gen_loss = model.evaluate_generator(data_generator(X_test, y_test, 50), X_test.shape[0]) + prediction = model.predict_generator(data_generator(X_test, y_test), X_test.shape[0], max_q_size=2) + gen_loss = model.evaluate_generator(data_generator(X_test, y_test, 50), X_test.shape[0], max_q_size=2) pred_loss = K.eval(K.mean(objectives.get(model.loss)(K.variable(y_test), K.variable(prediction)))) assert(np.isclose(pred_loss, loss))