Skip to content
This repository has been archived by the owner on May 25, 2022. It is now read-only.

Commit

Permalink
[keras-team#2287] only expose max_q_size
Browse files Browse the repository at this point in the history
  • Loading branch information
jbencook committed Apr 14, 2016
1 parent 78c11ac commit ac83587
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 46 deletions.
27 changes: 8 additions & 19 deletions keras/engine/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -1184,8 +1184,7 @@ def predict_on_batch(self, x):
def fit_generator(self, generator, samples_per_epoch, nb_epoch,
verbose=1, callbacks=[],
validation_data=None, nb_val_samples=None,
class_weight={}, max_q_size=10, wait_time=0.05,
nb_worker=1):
class_weight={}, max_q_size=10):
'''Fits the model on data generated batch-by-batch by
a Python generator.
The generator is run in parallel to the model, for efficiency.
Expand Down Expand Up @@ -1216,8 +1215,6 @@ def fit_generator(self, generator, samples_per_epoch, nb_epoch,
class_weight: dictionary mapping class indices to a weight
for the class.
max_q_size: maximum size for the generator queue
wait_time: time to sleep before retry when queue is full
nb_worker: number of threads for running generator task
# Returns
A `History` object.
Expand Down Expand Up @@ -1291,8 +1288,7 @@ def generate_arrays_from_file(path):
self.validation_data = None

# start generator thread storing batches into a queue
data_gen_queue, _stop = generator_queue(generator, max_q_size=max_q_size,
wait_time=wait_time, nb_worker=nb_worker)
data_gen_queue, _stop = generator_queue(generator, max_q_size=max_q_size)

self.stop_training = False
while epoch < nb_epoch:
Expand Down Expand Up @@ -1363,7 +1359,8 @@ def generate_arrays_from_file(path):
if samples_seen >= samples_per_epoch and do_validation:
if val_gen:
val_outs = self.evaluate_generator(validation_data,
nb_val_samples)
nb_val_samples,
max_q_size=max_q_size)
else:
# no need for try/except because
# data has already been validated
Expand All @@ -1385,8 +1382,7 @@ def generate_arrays_from_file(path):
callbacks.on_train_end()
return self.history

def evaluate_generator(self, generator, val_samples, max_q_size=10,
wait_time=0.05, nb_worker=1):
def evaluate_generator(self, generator, val_samples, max_q_size=10):
'''Evaluates the model on a data generator. The generator should
return the same kind of data as accepted by `test_on_batch`.
Expand All @@ -1398,8 +1394,6 @@ def evaluate_generator(self, generator, val_samples, max_q_size=10,
total number of samples to generate from `generator`
before returning.
max_q_size: maximum size for the generator queue
wait_time: time to sleep before retry when queue is full
nb_worker: number of threads for running generator task
# Returns
Scalar test loss (if the model has a single output and no metrics)
Expand All @@ -1413,8 +1407,7 @@ def evaluate_generator(self, generator, val_samples, max_q_size=10,
wait_time = 0.01
all_outs = []
weights = []
data_gen_queue, _stop = generator_queue(generator, max_q_size=max_q_size,
wait_time=wait_time, nb_worker=nb_worker)
data_gen_queue, _stop = generator_queue(generator, max_q_size=max_q_size)

while processed_samples < val_samples:
generator_output = None
Expand Down Expand Up @@ -1466,8 +1459,7 @@ def evaluate_generator(self, generator, val_samples, max_q_size=10,
weights=weights))
return averages

def predict_generator(self, generator, val_samples, max_q_size=10,
wait_time=0.05, nb_worker=1):
def predict_generator(self, generator, val_samples, max_q_size=10):
'''Generates predictions for the input samples from a data generator.
The generator should return the same kind of data as accepted by
`predict_on_batch`.
Expand All @@ -1477,8 +1469,6 @@ def predict_generator(self, generator, val_samples, max_q_size=10,
val_samples: total number of samples to generate from `generator`
before returning.
max_q_size: maximum size for the generator queue
wait_time: time to sleep before retry when queue is full
nb_worker: number of threads for running generator task
# Returns
Numpy array(s) of predictions.
Expand All @@ -1488,8 +1478,7 @@ def predict_generator(self, generator, val_samples, max_q_size=10,
processed_samples = 0
wait_time = 0.01
all_outs = []
data_gen_queue, _stop = generator_queue(generator, max_q_size=max_q_size,
wait_time=wait_time, nb_worker=nb_worker)
data_gen_queue, _stop = generator_queue(generator, max_q_size=max_q_size)

while processed_samples < val_samples:
generator_output = None
Expand Down
27 changes: 6 additions & 21 deletions keras/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -562,8 +562,7 @@ def predict_classes(self, x, batch_size=32, verbose=1):
def fit_generator(self, generator, samples_per_epoch, nb_epoch,
verbose=1, callbacks=[],
validation_data=None, nb_val_samples=None,
class_weight=None, max_q_size=10,
wait_time=0.05, nb_worker=1, **kwargs):
class_weight=None, max_q_size=10, **kwargs):
'''Fits the model on data generated batch-by-batch by
a Python generator.
The generator is run in parallel to the model, for efficiency.
Expand Down Expand Up @@ -594,8 +593,6 @@ def fit_generator(self, generator, samples_per_epoch, nb_epoch,
class_weight: dictionary mapping class indices to a weight
for the class.
max_q_size: maximum size for the generator queue
wait_time: time to sleep before retry when queue is full
nb_worker: number of threads for running generator task
# Returns
A `History` object.
Expand Down Expand Up @@ -645,12 +642,9 @@ def generate_arrays_from_file(path):
validation_data=validation_data,
nb_val_samples=nb_val_samples,
class_weight=class_weight,
max_q_size=max_q_size,
wait_time=wait_time,
nb_worker=nb_worker)
max_q_size=max_q_size)

def evaluate_generator(self, generator, val_samples, max_q_size=10,
wait_time=0.05, nb_worker=1, **kwargs):
def evaluate_generator(self, generator, val_samples, max_q_size=10, **kwargs):
'''Evaluates the model on a data generator. The generator should
return the same kind of data as accepted by `test_on_batch`.
Expand All @@ -662,8 +656,6 @@ def evaluate_generator(self, generator, val_samples, max_q_size=10,
total number of samples to generate from `generator`
before returning.
max_q_size: maximum size for the generator queue
wait_time: time to sleep before retry when queue is full
nb_worker: number of threads for running generator task
'''
if self.model is None:
raise Exception('The model needs to be compiled before being used.')
Expand All @@ -682,12 +674,9 @@ def evaluate_generator(self, generator, val_samples, max_q_size=10,
str(kwargs))
return self.model.evaluate_generator(generator,
val_samples,
max_q_size=max_q_size,
wait_time=wait_time,
nb_worker=nb_worker)
max_q_size=max_q_size)

def predict_generator(self, generator, val_samples, max_q_size=10,
wait_time=0.05, nb_worker=1):
def predict_generator(self, generator, val_samples, max_q_size=10):
'''Generates predictions for the input samples from a data generator.
The generator should return the same kind of data as accepted by
`predict_on_batch`.
Expand All @@ -697,18 +686,14 @@ def predict_generator(self, generator, val_samples, max_q_size=10,
val_samples: total number of samples to generate from `generator`
before returning.
max_q_size: maximum size for the generator queue
wait_time: time to sleep before retry when queue is full
nb_worker: number of threads for running generator task
# Returns
A Numpy array of predictions.
'''
if self.model is None:
raise Exception('The model needs to be compiled before being used.')
return self.model.predict_generator(generator, val_samples,
max_q_size=max_q_size,
wait_time=wait_time,
nb_worker=nb_worker)
max_q_size=max_q_size)

def get_config(self):
'''Returns the model configuration
Expand Down
9 changes: 3 additions & 6 deletions tests/keras/test_sequential_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,8 +66,7 @@ def data_generator(train):
model.fit_generator(data_generator(True), len(X_train), nb_epoch, validation_data=(X_test, y_test))
model.fit_generator(data_generator(True), len(X_train), nb_epoch,
validation_data=data_generator(False), nb_val_samples=batch_size * 3)
model.fit_generator(data_generator(True), len(X_train), nb_epoch,
max_q_size=2, wait_time=0.1, nb_worker=2)
model.fit_generator(data_generator(True), len(X_train), nb_epoch, max_q_size=2)

loss = model.evaluate(X_train, y_train)

Expand Down Expand Up @@ -102,10 +101,8 @@ def data_generator(x, y, batch_size=50):

loss = model.evaluate(X_test, y_test)

prediction = model.predict_generator(data_generator(X_test, y_test), X_test.shape[0],
max_q_size=2, wait_time=0.1, nb_worker=2)
gen_loss = model.evaluate_generator(data_generator(X_test, y_test, 50), X_test.shape[0],
max_q_size=2, wait_time=0.1, nb_worker=2)
prediction = model.predict_generator(data_generator(X_test, y_test), X_test.shape[0], max_q_size=2)
gen_loss = model.evaluate_generator(data_generator(X_test, y_test, 50), X_test.shape[0], max_q_size=2)
pred_loss = K.eval(K.mean(objectives.get(model.loss)(K.variable(y_test), K.variable(prediction))))

assert(np.isclose(pred_loss, loss))
Expand Down

0 comments on commit ac83587

Please sign in to comment.