
Fix the ordering bugs when using pickle_safe=True #6891

Merged: 29 commits merged into keras-team:master on Jun 20, 2017

Conversation

Dref360
Contributor

@Dref360 Dref360 commented Jun 7, 2017

This PR fixes the multiple problems with GeneratorEnqueuer. When pickle_safe=True, the order is not preserved, which can be annoying for predict_generator. The structure is like PyTorch's Dataset: http://pytorch.org/docs/data.html.

This PR guarantees that the order will be preserved, at no cost.

While GeneratorEnqueuer is still supported, it should be deprecated in favour of this new feature.

I would really appreciate your thoughts on this.

Work to do:
[ ] Default Dataset for folder, HDF5
[ ] Validate Windows behaviour


import multiprocessing

"""Get the uid for the default graph.
Contributor Author

Will clean that up

is_dataset = isinstance(generator, Dataset)
if not is_dataset:
    warnings.warn(
        """Using a generator for `generator` is now deprecated.
Contributor Author

Confusing error message, but I don't want to introduce breaking changes.

@Dref360
Contributor Author

Dref360 commented Jun 7, 2017

Never mind, this only works for Python 3. Otherwise we need to add a new dependency.

Should I just forget this idea? Otherwise, we need the futures package for the Python 2 version.

@fchollet
Collaborator

fchollet commented Jun 7, 2017

It would be good to:

  • Preserve order when using predict_generator with multiple processes, at no performance cost
  • Offer a way to easily add enqueuing capabilities to any source of data. Currently we kind of have that by requiring people to write generators. But the experience of writing these could be improved.

This should be achieved in a fully backwards compatible way. Do you see a solution?

@Dref360
Contributor Author

Dref360 commented Jun 7, 2017

Using a manager or pipes, processes could communicate which input (or index) they have processed. This is rather ugly, but this would work.

This could cause memory issues, but if you're using predict_generator, I guess you're not outputting images. (In this case, switching predict_generator to output a generator would help people in the segmentation domain like myself. Quick suggestion.) This is why I recommend adding futures to the dependencies for Py2, since it's only a backport of the real module from Py3.

Adding helper functions to handle many types of data would have to be done. Most of them would be easy to do (NumPy, directory, HDF5; see the sketch below), but handling generators would be really difficult. Of course, the community would have to adapt.
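
For instance, a minimal sketch of what the HDF5 helper could look like (the name HDF5Dataset, the file layout, and the key names are hypothetical; it only illustrates the proposed __len__/__getitem__ interface):

import h5py


class HDF5Dataset(object):
    """Hypothetical helper: expose an HDF5 file through the proposed
    __len__/__getitem__ interface, one sample per index."""

    def __init__(self, path, x_key='x', y_key='y'):
        self.file = h5py.File(path, 'r')
        self.x = self.file[x_key]
        self.y = self.file[y_key]

    def __len__(self):
        return len(self.x)

    def __getitem__(self, idx):
        # h5py reads lazily, so only the requested sample is loaded here
        return self.x[idx], self.y[idx]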

@Dref360
Contributor Author

Dref360 commented Jun 8, 2017

So the Theano tests are timing out; otherwise my approach works, it just adds a dependency.

@fchollet
Collaborator

fchollet commented Jun 8, 2017

Can you list the user-facing changes? Provide some examples of the user experience? (e.g. what code would the user be writing, what problems would the user potentially be facing).

@ahundt
Contributor

ahundt commented Jun 8, 2017

@PavlosMelissinos may want to comment on this since he is working on a dataset API as well and is trying to consider segmentation: https://github.com/PavlosMelissinos/enet-keras/blob/master/src/data/datasets.py#L14

@Dref360
Contributor Author

Dref360 commented Jun 9, 2017

@fchollet I'll be working on a blog post about the changes and how to update generators to Dataset. The transition should be really easy to do. Previously, people were doing

for x, y in zip(X_train, Y_train):
    yield resize(imread(x), (200, 200)), y

Now it would be:

class CIFAR10Dataset(Dataset):
    def __init__(self, x_set, y_set):
        self.X, self.y = x_set, y_set

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        return resize(imread(self.X[idx]), (200, 200)), self.y[idx]

where X_train would be, for example, a list of filenames of images and Y_train their classes.
So a little more boilerplate code, but most people were already using objects to encapsulate their generators.
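
A hypothetical usage sketch (assuming fit_generator accepts a Dataset instance directly, the way it accepts a generator today; model, X_train and Y_train are placeholders):

train_ds = CIFAR10Dataset(X_train, Y_train)

model.fit_generator(train_ds,
                    steps_per_epoch=len(train_ds),
                    epochs=10,
                    workers=4,
                    pickle_safe=True)  # multiple processes, order stays deterministic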

Problems that I can think of:

  • ImageDataGenerator will need to be remade in a future PR
  • People who were doing funky stuff with generators, like making them process-safe, won't need that anymore. May be disturbing.
  • Non-OO people will not like this change.

For real, there should not be any major problem with Dataset, and generators will still be around for a while.

@fchollet
Collaborator

fchollet commented Jun 9, 2017

As long as passing generators as usual keeps working, and as long as all changes to the built-in generators (ImageDataGenerator and the like) are backwards compatible, then the addition of Dataset seems fine to me.

The dependence on futures is problematic though. This Dataset API would have to move to the Keras version in TF, and TF does not have that dependency. Can you make it work without it?

@fchollet
Collaborator

fchollet commented Jun 9, 2017

Orthogonally, we seriously need a way to make the Theano tests less slow. They take 5x longer than on other backends (and we now have several other backends!), even though generally fewer tests are run than on TF. Maybe changing the compilation mode would do it?

@Dref360
Contributor Author

Dref360 commented Jun 9, 2017

I'm really not experienced with Theano, sorry. Is there a Theano expert in the community?

For the dependency on futures, it could be an optional dependency. The default behaviour would still be GeneratorEnqueuer, but if you pass a Dataset and you're using Python 2, you would need to install futures. The docs would need to reflect this as well. Just like h5py and pydot-ng, it would not be a required dependency.

@Dref360
Contributor Author

Dref360 commented Jun 9, 2017

We could disable a lot of tests for Theano. Since models, topology, and engine are already tested with TF, we can fairly assume that if the backend is well tested, there should not be any problem with Theano. This way, we only need to test the backend itself on Theano.

indexes = range(len(self.dataset))
if self.scheduling is not 'sequential':
    random.shuffle(indexes)
indexes = itertools.cycle(indexes)
@jonilaserson jonilaserson Jun 9, 2017

Wouldn't it make more sense to reshuffle the index order after each epoch? (rather than cycling through the same shuffled order).
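
For illustration, a standalone sketch of the per-epoch reshuffling idea (index_stream and its arguments are made-up names, not code from this PR):

import random


def index_stream(dataset_len, scheduling='random'):
    """Yield dataset indices forever, reshuffling at every epoch boundary
    instead of cycling through one fixed shuffled order."""
    while True:
        indexes = list(range(dataset_len))
        if scheduling != 'sequential':
            random.shuffle(indexes)
        for idx in indexes:
            yield idx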

@jonilaserson

jonilaserson commented Jun 9, 2017

@Dref360 I'm excited about the Datasets addition!

I think generators should not be deprecated, because Datasets introduce additional constraints that generators don't have:

  1. It seems that they have to have a finite number of samples (len must be implemented).
  2. The index ordering is determined from the start (must be either sequential or random).

For example, you can imagine having a predictor listening to HTTP requests and applying the model to whatever indexes or queries are sent to it. This logic could be implemented with generators but not with a Dataset (see the sketch below). What do you think?
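
For illustration, a hypothetical sketch of such an open-ended source (request_queue and preprocess are made-up names); there is no meaningful __len__ here, which is exactly what a finite Dataset cannot express:

def request_generator(request_queue, preprocess):
    """Yield model inputs built from whatever queries arrive at runtime.

    `request_queue` is any object with a blocking get() (e.g. queue.Queue);
    an HTTP handler would put incoming queries on it.
    """
    while True:
        query = request_queue.get()  # blocks until a request arrives
        yield preprocess(query)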

@Dref360
Contributor Author

Dref360 commented Jun 9, 2017

@jonilaserson You're right, generators won't go away anytime soon (I'll remove the warnings and change the docs), mostly because of the added dependency. If @fchollet approves the optional dependency, I'll remove the warning and make GeneratorEnqueuer the default.

(I'll try to make it work with a Pool from multiprocessing; this would remove the dependency.)
EDIT: The dependency on futures has been removed!

@fchollet
Collaborator

fchollet commented Jun 9, 2017

> Just like h5py and pydot-ng, it would not be a required dependency.

I'm not a fan of this pattern in general, it has caused a bunch of problems and a lot of user frustration. Can we avoid it? Can Python 2.7 really not support something like this?

@Dref360
Contributor Author

Dref360 commented Jun 9, 2017

@fchollet I fixed it. We do not need the futures package anymore.
It's now using multiprocessing.Pool, which returns AsyncResult objects instead of futures.
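
A minimal sketch of the order-preserving idea (assuming a picklable dataset exposing __len__/__getitem__; this is an illustration, not the PR's actual enqueuer code):

import multiprocessing


def _get_item(dataset, i):
    # Module-level so it can be pickled and shipped to worker processes.
    return dataset[i]


def ordered_iter(dataset, workers=4, max_queue_size=10):
    """Yield dataset items in index order even though workers run in parallel."""
    pool = multiprocessing.Pool(workers)
    pending = []
    try:
        for i in range(len(dataset)):
            pending.append(pool.apply_async(_get_item, (dataset, i)))
            if len(pending) > max_queue_size:
                # AsyncResults are consumed in submission order, so order is preserved.
                yield pending.pop(0).get()
        while pending:
            yield pending.pop(0).get()
    finally:
        pool.close()
        pool.join()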

@fchollet fchollet left a comment (Collaborator)

I don't think we should put them in a new directory. Maybe in data_utils?

@@ -0,0 +1,11 @@
# TODO add some default like hdf5 dataset, directory, etc
class Dataset():
Collaborator

Should inherit from object

# TODO add some default like hdf5 dataset, directory, etc
class Dataset():
"""
Base object for every dataset
Collaborator

Standardize docstring format and add a more complete description of what datasets are

return ds[i]


class DatasetHandler():
Collaborator

Should inherit from object



class DatasetHandler():
"""Base class to enqueue datasets."""
Collaborator

Add a description of how the whole setup works

@@ -232,7 +235,7 @@ def _check_array_lengths(inputs, targets, weights):
raise ValueError('Input arrays should have '
'the same number of samples as target arrays. '
'Found ' + str(list(set_x)[0]) + ' input samples '
'and ' + str(list(set_y)[0]) + ' target samples.')
'and ' + str(list(set_y)[0]) + ' target samples.')
Collaborator

?

@@ -453,6 +456,7 @@ def weighted(y_true, y_pred, weights, mask=None):
score_array *= weights
score_array /= K.mean(K.cast(K.not_equal(weights, 0), K.floatx()))
return K.mean(score_array)

Collaborator

?


"""
Purpose of this file is to have multiples workers working on the same generator to speed up training/testing
Thanks to http://anandology.com for the base code. Handle Python 2 and 3.
Collaborator

Put these as private functions in the same file as the classes that use them

setup.py Outdated
@@ -1,7 +1,6 @@
from setuptools import setup
from setuptools import find_packages


Collaborator

?

@ahundt ahundt left a comment (Contributor)

I read through everything and I think a dataset API will make for a very valuable addition! As @fchollet mentioned, there isn't yet a complete description in the docstrings, so please forgive me if I misunderstood something.

Since this is a substantial API expansion I have a couple of questions, alternative design suggestions, and additions that may be worth considering:

  1. What if Datasets were layers and a bit more similar to the functional API?
    • This could enable backends to supply the data in addition to pure python, and preprocessing layers could become possible.
    • Regarding performance: This pull request appears to have similarities to tf.FIFOQueue and tf.train.queue_runner, which underutilize GPUs according to TensorFlow's detailed guide for high performance models. I expect the design in this pull request may have the same problem for similar reasons, and a layer/backend based design might solve it.
  2. Is this a design that can be sensibly extended to supply more than one GPU device, or would that require a redesign/rewrite?
  3. Does this design account for a mechanism to insert preprocessing steps?
  4. Can preprocessing be added to both data and labels?
  5. The new tensorflow dataset API may also have some good design elements and discuss gotchas worth considering.
  6. Could the following function, or an adaptation of it as a subclass of the Dataset class, be added?
def generate_samples_from_filesystem(sample_sets, callbacks=load_image,  batch_size=1, data_dirs=None):
    """Generate numpy arrays from files on disk in groups, such as single images or pairs of images.
    # Arguments
    sample_sets: A list of lists, each containing the data's filenames such as [['img1.jpg', 'img2.jpg'], ['label1.png', 'label2.png']].
        Also supports a list of txt files, each containing the list of filenames in each set such as ['images.txt', 'labels.txt'].
        If None, all images in the folders specified in data_dirs are loaded in lexicographic order.
    callbacks: One callback that loads data from the specified file path into a numpy array, `load_image` by default. 
       Either a single callback should be specified or a callback must be provided for each sample set, and must be the same length as sample_sets. 
    data_dirs: Directory or list of directories to load. 
        Default None means each entry in sample_sets contains the full path to each file.
        Specifying a directory means filenames sample_sets can be found in that directory.
        Specifying a list of directories means each sample set is in that separate directory, and must be the same length as sample_sets.
    batch_size: number of samples in a batch 
    # Returns
      Yields batch_size data points in each list provided.
    """

The list of lists permitted for each parameter above is so multiple inputs & label types can be processed cleanly.

Thanks again for working on such a great addition!


def __getitem__(self, index):
    raise NotImplementedError

@ahundt ahundt (Contributor) Jun 9, 2017

Consider adding this function or something similar:

    def download(self, data_dir=None, subset='all'):
        '''Downloads a dataset and extracts the files

        # Arguments

            data_dir: Location to save the dataset.
                Downloads to `~/.keras/datasets/dataset_name` by default.
            subset: The name of the subset to download; downloads all by default.
                Most useful for large datasets with useful subsets.

        # Returns

           list of paths to the downloaded files

        '''
        raise NotImplementedError

subset='all' might instead be subset=None, but this parameter is worth considering so large dataset downloads can be supported easily without inadvertently filling people's HDD.

Contributor Author

This is highly dataset-specific. I don't want to be forced to implement something that I may not need. If my dataset is already on disk, why should I implement something like this? This can still be done by the user. Just as with plain generators today, there is nothing 'mandatory' to implement.

raise NotImplemented

@abstractmethod
def get(self):
Contributor

Could this design allow a tensor, op, or an equivalent on another backend to provide the data?

Contributor Author

This is a limitation of Keras itself since fit_generator is calling sess.run with a feed_dict.
I don't think we want to go this deep :P Or maybe for a future PR?

Collaborator

Yes. Btw feed_dict no longer has a memory copy overhead. It's fast these days. Modern times!

@Dref360
Contributor Author

Dref360 commented Jun 10, 2017

Just a copy of Slack for diligent bookkeeping.
The most challenging point of your comment is the first one. You could chain Datasets, I guess, like:

ds = FileLoaderDataset(filenames) #generic file loading
ds = RandomTranslationDataset(ds)
ds = RandomShearDataset(ds)

In their __getitem__ they would have to do something like

class RandomTranslationDataset(Dataset):
    def __init__(self, ds):
        self.ds = ds

    def __getitem__(self, idx):
        return self.random_translation(self.ds[idx])

In my blog post (dref360.github.io), I found that the thing that sucks about tf.FIFOQueue is that it cannot run across different processes. This can.

I think this design is flexible enough to do pretty much whatever you want. We only ask for the input at index idx; how you produce that input is up to you.

This PR is pretty much a WIP; I wanted some feedback, so the design is still pretty flexible.

@fchollet fchollet left a comment (Collaborator)

LGTM 👍

# and `y_set` are the associated classes.

class CIFAR10Sequence(Sequence):
class CIFAR10Sequence(Sequence):
Collaborator

Redundant

Every `Sequence` must implements the `__getitem__` and the `__len__` methods.

# Examples
```python
Collaborator

Add line break above



class Sequence(object):
""" Base object for fitting to a sequence of data, such as a dataset.
Collaborator

Remove leading space

is_sequence = isinstance(generator, Sequence)
if not is_sequence and pickle_safe:
    warnings.warn(
        """Using a generator with `pickle_safe=True` may duplicate your data.
Collaborator

You can use ' as the string delimiter here (which would be consistent with other warning/error messages elsewhere)

@@ -2060,6 +1980,7 @@ def predict_generator(self, generator, steps,

# Arguments
generator: Generator yielding batches of input samples.
Sequence object to avoid duplicate data.
Collaborator

Rephrase a bit to make it more user friendly, e.g., "or an instance of Sequence (from keras.utils.Sequence), which you can use in order to..."

"""Create a generator to extract data from the queue. Skip the data if it's None.

#Returns
A generator
Collaborator

+1

with pytest.warns(Warning) as w:
    out = model.fit_generator(gen_data(4), steps_per_epoch=10, pickle_safe=True, workers=2)
assert any(['Sequence' in str(w_.message) for w_ in w]), \
    "No warning raised when using generator with processes."
Collaborator

Use ' everywhere for consistency. Do not break lines with \

import six

from keras.utils.data_utils import Sequence
from keras.utils.data_utils import GeneratorEnqueuer
from keras.utils.data_utils import OrderedEnqueuer
Collaborator

Add these 3 classes to utils/__init__.py so they can be imported from utils by users (internally it doesn't matter)
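
A sketch of the requested re-export (the rest of keras/utils/__init__.py is omitted):

# keras/utils/__init__.py
from .data_utils import Sequence
from .data_utils import GeneratorEnqueuer
from .data_utils import OrderedEnqueuer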


@abstractmethod
def __getitem__(self, index):
    """Get batch at position `index`.
Collaborator

Gets


return np.array([resize(imread(file_name), (200, 200))
                 for file_name in batch_x]), np.array(batch_y)
Collaborator

Pretty hard to read, put the resize on a separate line, not on the return line

@Dref360
Contributor Author

Dref360 commented Jun 19, 2017

I renamed pickle_safe and max_q_size. They are now in legacy.

@Dref360
Contributor Author

Dref360 commented Jun 20, 2017

All good? Can be merged?

@ahundt ahundt left a comment (Contributor)

LGTM

@fchollet
Collaborator

LGTM

@fchollet fchollet merged commit ab6b82c into keras-team:master Jun 20, 2017
@gokceneraslan
Contributor

Thanks for the new API. One minor comment: the start() method of the *Enqueuer classes shouldn't really start anything if it's already started, i.e. if self.is_running(): return.
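
A sketch of the suggested guard, as a method fragment (_start_workers is a hypothetical internal helper, not the merged code):

def start(self, workers=1, max_queue_size=10):
    """Start worker processes, unless the enqueuer is already running."""
    if self.is_running():
        return  # calling start() twice becomes a no-op
    self._start_workers(workers, max_queue_size)  # hypothetical internal helper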

@fchollet
Collaborator

Following this PR, one test has started failing: the commented-out lines here. Please take a look.

@Dref360
Contributor Author

Dref360 commented Jun 20, 2017

Yeah, @ahundt fixed it, but it seems like my PR overrode his fix.
This test doesn't make sense. Since generators are copied, there is no way to get it to pass with only good_batches + 1 steps. It is only guaranteed to trigger at n_workers * good_batches + 1.

@gokceneraslan
Contributor

gokceneraslan commented Jun 20, 2017

@Dref360, more comments:

  • OrderedEnqueuer::stop() should call self.executor.join() after self.executor.close(), otherwise worker processes turn into zombies.

  • The thread.terminate() call in GeneratorEnqueuer::stop() might corrupt the queue; why not thread.join()? (A sketch of both suggestions follows below.)
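
A sketch of both suggestions, with assumed attribute names (not the merged code):

def stop(self, timeout=None):
    """Shut down cleanly instead of leaving zombies or a corrupted queue."""
    self.executor.close()          # stop accepting new tasks
    self.executor.join()           # wait for worker processes to exit (no zombies)
    self.run_thread.join(timeout)  # join rather than terminate, so the queue is not corrupted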

@ahundt
Contributor

ahundt commented Jun 20, 2017

@gokceneraslan This is already merged; for the first comment, could you open a PR? For the second comment, see #6891 (comment).

@joeyearsley
Contributor

joeyearsley commented Jun 24, 2017

If I have another fix for generators which allows ordering to be kept in the file reader, should I make a PR?

I believe people should move to the Sequence class anyway, but that will be more lines written than the solution I've been using locally (~8 lines of code in the Iterator class).

@Dref360
Contributor Author

Dref360 commented Jun 24, 2017

Please submit a PR! I'll review it asap

@joeyearsley
Contributor

Provided another solution in #7118: it sets up shared variables when multiprocessing is used, which allows the multiprocessing lock to actually take effect. This should save people from converting their current generators to sequences.

However, it would be nice going forward to see sequences phased in.

@ahundt
Contributor

ahundt commented Jun 26, 2017

If this actually solves the problem, are there still benefits provided by sequences?

@joeyearsley
Contributor

It does, however I've run into a heisenbug with multiprocessing now.

The following happens stochastically: the evaluation generator starts for validation, all processes go to sleep and never hit the next call (for some reason); because they never return from next, the queue is always empty. Hence the entire process is comatose.

By "never hit the next call" I mean that I've enabled logging, and next never gets called once the process enters this weird state.

if self._use_multiprocessing:
    # Reset random seed else all children processes
    # share the same seed
    np.random.seed(self.random_seed)
Collaborator

You seed np.random but you shuffle with the Python random module. Is this a bug?

Contributor Author

This seeding is for the child processes. The shuffling is for the order of the batches. For now, there is no way to control the shuffling seed; it could be added.
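
A standalone sketch of the distinction, with hypothetical helper names (the NumPy seed is set per child process, while batch order is shuffled in the parent with Python's random module):

import random

import numpy as np


def worker_init(base_seed, worker_id):
    """Give each child process its own NumPy seed so random augmentations
    differ between workers."""
    np.random.seed(base_seed + worker_id)


def shuffled_batch_order(num_batches, shuffle_seed=None):
    """Shuffle the batch order in the parent with Python's random module;
    a dedicated shuffle seed could be exposed later, as noted above."""
    order = list(range(num_batches))
    random.Random(shuffle_seed).shuffle(order)
    return order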
