This repository contains a TensorFlow implementation of the Generative Query Network (GQN) described in 'Neural Scene Representation and Rendering' by Eslami et al. (2018).
Neural Scene Representation and Rendering [PDF] [blog]
S. M. Ali Eslami, Danilo J. Rezende, Frederic Besse, Fabio Viola, Ari S. Morcos, Marta Garnelo, Avraham Ruderman, Andrei A. Rusu, Ivo Danihelka, Karol Gregor, David P. Reichert, Lars Buesing, Theophane Weber, Oriol Vinyals, Dan Rosenbaum, Neil Rabinowitz, Helen King, Chloe Hillier, Matt Botvinick, Daan Wierstra, Koray Kavukcuoglu and Demis Hassabis
If you use this repository, please cite the original publication:
@article{eslami2018neural,
title={Neural scene representation and rendering},
author={Eslami, SM Ali and Rezende, Danilo Jimenez and Besse, Frederic and Viola, Fabio and Morcos, Ari S and Garnelo, Marta and Ruderman, Avraham and Rusu, Andrei A and Danihelka, Ivo and Gregor, Karol and others},
journal={Science},
volume={360},
number={6394},
pages={1204--1210},
year={2018},
publisher={American Association for the Advancement of Science}
}
The major software requirements can be installed on an Ubuntu machine via:
$ sudo apt-get install python3-pip python3-dev virtualenv
The code requires at least TensorFlow 1.12.0. To run the models efficiently on a GPU, the latest NVIDIA drivers, CUDA and cuDNN versions that are compatible with your TensorFlow release should also be installed (see version list).
All Python dependencies should live in their own virtual environment. All runtime requirements can be easily installed via the following commands:
$ virtualenv -p python3 venv
$ source venv/bin/activate
(venv) $ pip3 install -r requirements.txt
The data provider implementation is adapted from https://github.com/deepmind/gqn-datasets and uses the more up-to-date tf.data.Dataset input pipeline approach.
The training datasets can be downloaded from: https://console.cloud.google.com/storage/gqn-dataset
To download the datasets, you can use the gsutil cp command; see also the gsutil installation instructions.
The training script can be started with the following command, assuming the GQN datasets are located in data/gqn-dataset:
(venv) $ python3 train_gqn.py \
--data_dir data/gqn-dataset \
--dataset rooms_ring_camera \
--model_dir models/rooms_ring_camera/gqn
For more verbose information (and tensorboard summaries), you can pass the --debug option to the script as well.
When the --debug flag is passed to the training script, image summaries will be written to the tensorboard records.
During the training phase of the network, results from the inference network will be shown (target_inference). These images will resemble the target images relatively quickly, but they are not indicative of model performance because they are computed with the posterior from the target images, which are only available during the training phase.
During the evaluation phase of the network, results from the generator network will be shown (target_generation). These visual results indicate how well the GQN performs when deployed in prediction mode.
Using the tf.estimator API is the most basic form of using the GQN model.
An estimator can be set up by passing the gqn_draw_model_fn, the model parameters and the path to the model directory with a corresponding snapshot. An example can be found in the training script.
Once the estimator is instantiated, it can be trained further (model.train()) or used for evaluation or prediction purposes (model.evaluate() or model.predict()).
In evaluation and prediction mode, the generator is used.
We provide a convenience wrapper around the tf.estimator with the GqnViewPredictor class.
The view predictor can be set up by pointing to a model directory containing a model config (gqn_config.json) and a corresponding snapshot.
The predictor features APIs to add new context frames and render a query view based on the currently loaded context.
An example application of the view predictor class can be found in the view interpolation notebook.
Model snapshots for the following GQN datasets are available:
In order to use a snapshot, just download the archive and unpack it into the 'models' sub-directory of this repository (which is the default path for all scripts and notebooks to use them).
Each snapshot directory also contains *-runcmd.json and gqn_config.json files detailing all training settings and model hyper-parameters. You can also run tensorboard on the models directory to display all summaries which have been tracked during the model training runs.
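To look at the settings files mentioned above without opening them by hand, something like the following minimal sketch may help. It only assumes that the files are plain JSON; the function name load_snapshot_settings and the snapshot path used in the example are illustrative and not part of this repository.

```python
import glob
import json
import os


def load_snapshot_settings(snapshot_dir):
    """Return {file name: parsed JSON} for the settings files in a snapshot directory."""
    # gqn_config.json holds the model config; *-runcmd.json holds the training settings.
    paths = [os.path.join(snapshot_dir, "gqn_config.json")]
    paths += sorted(glob.glob(os.path.join(snapshot_dir, "*-runcmd.json")))
    settings = {}
    for path in paths:
        if os.path.isfile(path):  # silently skip files that are not present
            with open(path) as f:
                settings[os.path.basename(path)] = json.load(f)
    return settings


if __name__ == "__main__":
    # Hypothetical snapshot location; adjust to wherever the archive was unpacked.
    for name, content in load_snapshot_settings("models/rooms_ring_camera/gqn").items():
        print(name)
        print(json.dumps(content, indent=2, sort_keys=True))
```

The function returns an empty dictionary when the directory or the files are missing, so it can be pointed at any candidate path safely.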
Jupyter notebooks for running examples of the data loader and view predictor can be found under notebooks/ and a jupyter server can be started with:
(venv) $ cd notebooks/
(venv) $ jupyter notebook
The dataset viewer notebook illustrates the use of the gqn_input_fn and can be used to browse through the different GQN datasets.
The view interpolation notebook illustrates the use of a GqnViewPredictor and can be used to render an imagined flight through a scene as shown in DeepMind's blog post.
A few random notes about this implementation:
- We were not able to train the model with the learning rate scheme reported in the original paper (from 5e-4 to 5e-5 over 200K steps). This always resulted in a local minimum that only generated light blue sky and a grey blob of background. We achieved good results by lowering all learning rates by one order of magnitude.
- Currently, our implementation does not share the convolutional cores between the inference and generation LSTMs. With shared cores we observed the KL divergence between posterior and prior collapsing to zero frequently and obtained generally inferior results (which is in line with the results reported in the paper).
- In our tests, we found eight generation steps to be a good trade-off between training stability, training speed and visual quality.
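The learning rate scheme from the first note can be sketched as follows. This is an illustration only, not the code used in this repository: it assumes a linear annealing shape between the two endpoints, and the function name annealed_learning_rate and its parameter names are made up for the example. The lowered variant simply divides both endpoints by ten.

```python
def annealed_learning_rate(step, lr_initial=5e-4, lr_final=5e-5, annealing_steps=200_000):
    """Anneal linearly from lr_initial to lr_final, then hold lr_final.

    Assumption: the annealing is linear; the notes above only state the
    two endpoints and the number of steps.
    """
    # Fraction of the annealing period completed, clamped to [0, 1].
    progress = min(max(step, 0) / annealing_steps, 1.0)
    return lr_initial + (lr_final - lr_initial) * progress


# Compare the reported schedule with one lowered by an order of magnitude.
for step in (0, 100_000, 200_000, 400_000):
    reported_lr = annealed_learning_rate(step)
    lowered_lr = annealed_learning_rate(step, lr_initial=5e-5, lr_final=5e-6)
    print(f"step {step:>7}: reported {reported_lr:.2e}, lowered {lowered_lr:.2e}")
```

Keeping the endpoints as parameters makes it easy to try other scalings of the schedule without touching the annealing logic.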
This work was done during our PhD research at the Oxford Robotics Institute and the Visual Geometry Group.