diff --git a/tensorflow_probability/examples/BUILD b/tensorflow_probability/examples/BUILD index 92225273e3..95397b4068 100644 --- a/tensorflow_probability/examples/BUILD +++ b/tensorflow_probability/examples/BUILD @@ -120,46 +120,6 @@ py_test( ], ) -py_binary( - name = "latent_dirichlet_allocation_distributions", - srcs = ["latent_dirichlet_allocation_distributions.py"], - deps = [ - ":latent_dirichlet_allocation_distributions_lib", - ], -) - -py_library( - name = "latent_dirichlet_allocation_distributions_lib", - srcs = ["latent_dirichlet_allocation_distributions.py"], - deps = [ - # absl/flags dep, - # absl/logging dep, - # numpy dep, - # scipy dep, - # six dep, - # tensorflow dep, - "//tensorflow_probability", - "//tensorflow_probability/python/distributions", - ], -) - -py_test( - name = "latent_dirichlet_allocation_distributions_test", - size = "small", - srcs = ["latent_dirichlet_allocation_distributions.py"], - args = [ - "--fake_data", - "--max_steps=5", - "--delete_existing", - "--viz_steps=5", - "--learning_rate=1e-7", - ], - main = "latent_dirichlet_allocation_distributions.py", - deps = [ - ":latent_dirichlet_allocation_distributions_lib", - ], -) - py_binary( name = "logistic_regression", srcs = ["logistic_regression.py"], @@ -207,44 +167,6 @@ py_library( ], ) -py_binary( - name = "vae", - srcs = ["vae.py"], - deps = [ - ":vae_lib", - ], -) - -py_library( - name = "vae_lib", - srcs = ["vae.py"], - deps = [ - # absl/flags dep, - # numpy dep, - # six dep, - # tensorflow dep, - "//tensorflow_probability", - "//tensorflow_probability/python/distributions", - ], -) - -py_test( - name = "vae_test", - size = "medium", - srcs = ["vae.py"], - args = [ - "--fake_data", - "--max_steps=5", - "--delete_existing", - "--viz_steps=5", - "--learning_rate=1e-7", - ], - main = "vae.py", - deps = [ - ":vae_lib", - ], -) - py_binary( name = "vq_vae", srcs = ["vq_vae.py"], diff --git a/tensorflow_probability/examples/latent_dirichlet_allocation_distributions.py b/tensorflow_probability/examples/latent_dirichlet_allocation_distributions.py deleted file mode 100644 index a081f68773..0000000000 --- a/tensorflow_probability/examples/latent_dirichlet_allocation_distributions.py +++ /dev/null @@ -1,545 +0,0 @@ -# Copyright 2018 The TensorFlow Probability Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -"""Trains a Latent Dirichlet Allocation (LDA) model on 20 Newsgroups. - -LDA [1] is a topic model for documents represented as bag-of-words -(word counts). It attempts to find a set of topics so that every document from -the corpus is well-described by a few topics. - -Suppose that there are `V` words in the vocabulary and we want to learn `K` -topics. For each document, let `w` be its `V`-dimensional vector of word counts -and `theta` be its `K`-dimensional vector of topics. 
Let `Beta` be a `KxN` -matrix in which each row is a discrete distribution over words in the -corresponding topic (in other words, belong to a unit simplex). Also, let -`alpha` be the `K`-dimensional vector of prior distribution parameters -(prior topic weights). - -The model we consider here is obtained from the standard LDA by collapsing -the (non-reparameterizable) Categorical distribution over the topics -[1, Sec. 3.2; 3]. Then, the prior distribution is -`p(theta) = Dirichlet(theta | alpha)`, and the likelihood is -`p(w | theta, Beta) = OneHotCategorical(w | theta Beta)`. This means that we -sample the words from a Categorical distribution that is a weighted average -of topics, with the weights specified by `theta`. The number of samples (words) -in the document is assumed to be known, and the words are sampled independently. -We follow [2] and perform amortized variational inference similarly to -Variational Autoencoders. We use a neural network encoder to -parameterize a Dirichlet variational posterior distribution `q(theta | w)`. -Then, an evidence lower bound (ELBO) is maximized with respect to -`alpha`, `Beta` and the parameters of the variational posterior distribution. - -We use the preprocessed version of 20 newsgroups dataset from [3]. -This implementation uses the hyperparameters of [2] and reproduces the reported -results (test perplexity ~875). - -Example output for the final iteration: - -```none -elbo --567.829 - -loss -567.883 - -global_step -180000 - -reconstruction --562.065 - -topics -index=8 alpha=0.46 write article get think one like know say go make -index=21 alpha=0.29 use get thanks one write know anyone car please like -index=0 alpha=0.09 file use key program window image available information -index=43 alpha=0.08 drive use card disk system problem windows driver mac run -index=6 alpha=0.07 god one say christian jesus believe people bible think man -index=5 alpha=0.07 space year new program use research launch university nasa -index=33 alpha=0.07 government gun law people state use right weapon crime -index=36 alpha=0.05 game team play player year win season hockey league score -index=42 alpha=0.05 go say get know come one think people see tell -index=49 alpha=0.04 bike article write post get ride dod car one go - -kl -5.76408 - -perplexity -873.206 -``` - -#### References - -[1]: David M. Blei, Andrew Y. Ng, Michael I. Jordan. Latent Dirichlet - Allocation. In _Journal of Machine Learning Research_, 2003. - http://www.jmlr.org/papers/volume3/blei03a/blei03a.pdf -[2]: Michael Figurnov, Shakir Mohamed, Andriy Mnih. Implicit Reparameterization - Gradients, 2018 - https://arxiv.org/abs/1805.08498 -[3]: Akash Srivastava, Charles Sutton. Autoencoding Variational Inference For - Topic Models. In _International Conference on Learning Representations_, - 2017. 
- https://arxiv.org/abs/1703.01488 -""" - -import functools -import os - -# Dependency imports -from absl import flags -from absl import logging -import numpy as np -import scipy.sparse -from six.moves import cPickle as pickle -from six.moves import urllib -import tensorflow.compat.v1 as tf1 -import tensorflow.compat.v2 as tf -import tensorflow_probability as tfp - -tfb = tfp.bijectors -tfd = tfp.distributions - - -flags.DEFINE_float( - "learning_rate", default=3e-4, help="Learning rate.") -flags.DEFINE_integer( - "max_steps", default=180000, help="Number of training steps to run.") -flags.DEFINE_integer( - "num_topics", - default=50, - help="The number of topics.") -flags.DEFINE_list( - "layer_sizes", - default=["300", "300", "300"], - help="Comma-separated list denoting hidden units per layer in the encoder.") -flags.DEFINE_string( - "activation", - default="relu", - help="Activation function for all hidden layers.") -flags.DEFINE_integer( - "batch_size", - default=32, - help="Batch size.") -flags.DEFINE_float( - "prior_initial_value", default=0.7, help="The initial value for prior.") -flags.DEFINE_integer( - "prior_burn_in_steps", - default=120000, - help="The number of training steps with fixed prior.") -flags.DEFINE_string( - "data_dir", - default=os.path.join(os.getenv("TEST_TMPDIR", "/tmp"), "lda/data"), - help="Directory where data is stored (if using real data).") -flags.DEFINE_string( - "model_dir", - default=os.path.join(os.getenv("TEST_TMPDIR", "/tmp"), "lda/"), - help="Directory to put the model's fit.") -flags.DEFINE_integer( - "viz_steps", default=10000, help="Frequency at which save visualizations.") -flags.DEFINE_bool("fake_data", default=False, help="If true, uses fake data.") -flags.DEFINE_bool( - "delete_existing", - default=False, - help="If true, deletes existing directory.") - -FLAGS = flags.FLAGS - - -def _clip_dirichlet_parameters(x): - """Clips Dirichlet param for numerically stable KL and nonzero samples.""" - return tf.clip_by_value(x, .1, 1e3) - - -def make_encoder(activation, num_topics, layer_sizes): - """Create the encoder function. - - Args: - activation: Activation function to use. - num_topics: The number of topics. - layer_sizes: The number of hidden units per layer in the encoder. - - Returns: - encoder: A `callable` mapping a bag-of-words `Tensor` to a - `tfd.Distribution` instance over topics. - """ - encoder_net = tf.keras.Sequential() - for num_hidden_units in layer_sizes: - encoder_net.add( - tf.keras.layers.Dense( - num_hidden_units, - activation=activation, - kernel_initializer=tf.initializers.GlorotNormal())) - encoder_net.add( - tf.keras.layers.Dense( - num_topics, - activation=lambda x: _clip_dirichlet_parameters(tf.nn.softplus(x)), - kernel_initializer=tf.initializers.GlorotNormal())) - - def encoder(bag_of_words): - with tf.name_scope("encoder"): - return tfd.Dirichlet(concentration=encoder_net(bag_of_words), - name="topics_posterior") - - return encoder - - -def make_decoder(num_topics, num_words): - """Create the decoder function. - - Args: - num_topics: The number of topics. - num_words: The number of words. - - Returns: - decoder: A `callable` mapping a `Tensor` of encodings to a - `tfd.Distribution` instance over words. - """ - topics_words = tfp.util.TransformedVariable( - tf.nn.softmax(tf.initializers.GlorotNormal()([num_topics, num_words])), - tfb.SoftmaxCentered(), - name="topics_words") - - def decoder(topics): - word_probs = tf.matmul(topics, topics_words) - # The observations are bag of words and therefore not one-hot. 
However, - # log_prob of OneHotCategorical computes the probability correctly in - # this case. - return tfd.OneHotCategorical(probs=word_probs, name="bag_of_words") - - return decoder, topics_words - - -def make_prior(num_topics, initial_value): - """Create the prior distribution. - - Args: - num_topics: Number of topics. - initial_value: The starting value for the prior parameters. - - Returns: - prior: A `callable` that returns a `tf.distribution.Distribution` - instance, the prior distribution. - """ - concentration = tfp.util.TransformedVariable( - tf.fill([1, num_topics], initial_value), - tfb.Softplus(), - name="concentration") - - return tfd.Dirichlet( - concentration=tfp.util.DeferredTensor( - concentration, _clip_dirichlet_parameters), - name="topics_prior") - - -def model_fn(features, labels, mode, params, config): - """Build the model function for use in an estimator. - - Args: - features: The input features for the estimator. - labels: The labels, unused here. - mode: Signifies whether it is train or test or predict. - params: Some hyperparameters as a dictionary. - config: The RunConfig, unused here. - Returns: - EstimatorSpec: A tf.estimator.EstimatorSpec instance. - """ - del labels, config - - encoder = make_encoder(params["activation"], - params["num_topics"], - params["layer_sizes"]) - decoder, topics_words = make_decoder(params["num_topics"], - features.shape[1]) - topics_prior = make_prior(params["num_topics"], - params["prior_initial_value"]) - - alpha = topics_prior.concentration - - topics_posterior = encoder(features) - topics = topics_posterior.sample(seed=234) - random_reconstruction = decoder(topics) - - reconstruction = random_reconstruction.log_prob(features) - tf1.summary.scalar("reconstruction", tf.reduce_mean(reconstruction)) - - # Compute the KL-divergence between two Dirichlets analytically. - # The sampled KL does not work well for "sparse" distributions - # (see Appendix D of [2]). - kl = tfd.kl_divergence(topics_posterior, topics_prior) - tf1.summary.scalar("kl", tf.reduce_mean(kl)) - - # Ensure that the KL is non-negative (up to a very small slack). - # Negative KL can happen due to numerical instability. - with tf.control_dependencies( - [tf.debugging.assert_greater(kl, -1e-3, message="kl")]): - kl = tf.identity(kl) - - elbo = reconstruction - kl - avg_elbo = tf.reduce_mean(elbo) - tf1.summary.scalar("elbo", avg_elbo) - loss = -avg_elbo - - # Perform variational inference by minimizing the -ELBO. - global_step = tf1.train.get_or_create_global_step() - optimizer = tf1.train.AdamOptimizer(params["learning_rate"]) - - # This implements the "burn-in" for prior parameters (see Appendix D of [2]). - # For the first prior_burn_in_steps steps they are fixed, and then trained - # jointly with the other parameters. - grads_and_vars = optimizer.compute_gradients(loss) - grads_and_vars_except_prior = [ - x for x in grads_and_vars if x[1] not in topics_prior.variables] - - def train_op_except_prior(): - return optimizer.apply_gradients( - grads_and_vars_except_prior, - global_step=global_step) - - def train_op_all(): - return optimizer.apply_gradients( - grads_and_vars, - global_step=global_step) - - train_op = tf.cond( - pred=global_step < params["prior_burn_in_steps"], - true_fn=train_op_except_prior, - false_fn=train_op_all) - - # The perplexity is an exponent of the average negative ELBO per word. 
- words_per_document = tf.reduce_sum(features, axis=1) - log_perplexity = -elbo / words_per_document - tf1.summary.scalar("perplexity", tf.exp(tf.reduce_mean(log_perplexity))) - (log_perplexity_tensor, - log_perplexity_update) = tf1.metrics.mean(log_perplexity) - perplexity_tensor = tf.exp(log_perplexity_tensor) - - # Obtain the topics summary. Implemented as a py_func for simplicity. - topics = tf1.py_func( - functools.partial(get_topics_strings, vocabulary=params["vocabulary"]), - [topics_words, alpha], - tf.string, - stateful=False) - tf1.summary.text("topics", topics) - - return tf1.estimator.EstimatorSpec( - mode=mode, - loss=loss, - train_op=train_op, - eval_metric_ops={ - "elbo": tf1.metrics.mean(elbo), - "reconstruction": tf1.metrics.mean(reconstruction), - "kl": tf1.metrics.mean(kl), - "perplexity": (perplexity_tensor, log_perplexity_update), - "topics": (topics, tf.no_op()), - }, - ) - - -def get_topics_strings(topics_words, alpha, vocabulary, - topics_to_print=10, words_per_topic=10): - """Returns the summary of the learned topics. - - Args: - topics_words: KxV tensor with topics as rows and words as columns. - alpha: 1xK tensor of prior Dirichlet concentrations for the - topics. - vocabulary: A mapping of word's integer index to the corresponding string. - topics_to_print: The number of topics with highest prior weight to - summarize. - words_per_topic: Number of wodrs per topic to return. - Returns: - summary: A np.array with strings. - """ - alpha = np.squeeze(alpha, axis=0) - # Use a stable sorting algorithm so that when alpha is fixed - # we always get the same topics. - highest_weight_topics = np.argsort(-alpha, kind="mergesort") - top_words = np.argsort(-topics_words, axis=1) - - res = [] - for topic_idx in highest_weight_topics[:topics_to_print]: - l = ["index={} alpha={:.2f}".format(topic_idx, alpha[topic_idx])] - l += [vocabulary[word] for word in top_words[topic_idx, :words_per_topic]] - res.append(" ".join(l)) - - return np.array(res) - - -ROOT_PATH = "https://github.com/akashgit/autoencoding_vi_for_topic_models/raw/9db556361409ecb3a732f99b4ef207aeb8516f83/data/20news_clean" -FILE_TEMPLATE = "{split}.txt.npy" - - -def download(directory, filename): - """Download a file.""" - filepath = os.path.join(directory, filename) - if tf.io.gfile.exists(filepath): - return filepath - if not tf.io.gfile.exists(directory): - tf.io.gfile.makedirs(directory) - url = os.path.join(ROOT_PATH, filename) - print("Downloading %s to %s" % (url, filepath)) - urllib.request.urlretrieve(url, filepath) - return filepath - - -def newsgroups_dataset(directory, split_name, num_words, shuffle_and_repeat): - """Return 20 newsgroups tf.data.Dataset.""" - data = np.load(download(directory, FILE_TEMPLATE.format(split=split_name)), - allow_pickle=True, encoding="latin1") - # The last row is empty in both train and test. - data = data[:-1] - - # Each row is a list of word ids in the document. We first convert this to - # sparse COO matrix (which automatically sums the repeating words). Then, - # we convert this COO matrix to CSR format which allows for fast querying of - # documents. 
- num_documents = data.shape[0] - indices = np.array([(row_idx, column_idx) - for row_idx, row in enumerate(data) - for column_idx in row]) - sparse_matrix = scipy.sparse.coo_matrix( - (np.ones(indices.shape[0]), (indices[:, 0], indices[:, 1])), - shape=(num_documents, num_words), - dtype=np.float32) - sparse_matrix = sparse_matrix.tocsr() - - dataset = tf.data.Dataset.range(num_documents) - - # For training, we shuffle each epoch and repeat the epochs. - if shuffle_and_repeat: - dataset = dataset.shuffle(num_documents).repeat() - - # Returns a single document as a dense TensorFlow tensor. The dataset is - # stored as a sparse matrix outside of the graph. - def get_row_py_func(idx): - def get_row_python(idx_py): - return np.squeeze(np.array(sparse_matrix[idx_py].todense()), axis=0) - - py_func = tf1.py_func( - get_row_python, [idx], tf.float32, stateful=False) - py_func.set_shape((num_words,)) - return py_func - - dataset = dataset.map(get_row_py_func) - return dataset - - -def build_fake_input_fns(batch_size): - """Build fake data for unit testing.""" - num_words = 1000 - vocabulary = [str(i) for i in range(num_words)] - - random_sample = np.random.randint( - 10, size=(batch_size, num_words)).astype(np.float32) - - def train_input_fn(): - dataset = tf.data.Dataset.from_tensor_slices(random_sample) - dataset = dataset.batch(batch_size) - return tf1.data.make_one_shot_iterator(dataset.repeat()).get_next() - - def eval_input_fn(): - dataset = tf.data.Dataset.from_tensor_slices(random_sample) - dataset = dataset.batch(batch_size) - return tf1.data.make_one_shot_iterator(dataset).get_next() - - return train_input_fn, eval_input_fn, vocabulary - - -def build_input_fns(data_dir, batch_size): - """Builds iterators for train and evaluation data. - - Each object is represented as a bag-of-words vector. - - Args: - data_dir: Folder in which to store the data. - batch_size: Batch size for both train and evaluation. - Returns: - train_input_fn: A function that returns an iterator over the training data. - eval_input_fn: A function that returns an iterator over the evaluation data. - vocabulary: A mapping of word's integer index to the corresponding string. - """ - - with open(download(data_dir, "vocab.pkl"), "rb") as f: - words_to_idx = pickle.load(f) - num_words = len(words_to_idx) - - vocabulary = [None] * num_words - for word, idx in words_to_idx.items(): - vocabulary[idx] = word - - # Build an iterator over training batches. - def train_input_fn(): - dataset = newsgroups_dataset( - data_dir, "train", num_words, shuffle_and_repeat=True) - # Prefetching makes training about 1.5x faster. - dataset = dataset.batch(batch_size).prefetch(32) - return tf1.data.make_one_shot_iterator(dataset).get_next() - - # Build an iterator over the heldout set. 
- def eval_input_fn(): - dataset = newsgroups_dataset( - data_dir, "test", num_words, shuffle_and_repeat=False) - dataset = dataset.batch(batch_size) - return tf1.data.make_one_shot_iterator(dataset).get_next() - - return train_input_fn, eval_input_fn, vocabulary - - -def main(argv): - del argv # unused - - params = FLAGS.flag_values_dict() - params["layer_sizes"] = [int(units) for units in params["layer_sizes"]] - params["activation"] = getattr(tf.nn, params["activation"]) - if FLAGS.delete_existing and tf.io.gfile.exists(FLAGS.model_dir): - logging.warn("Deleting old log directory at %s", FLAGS.model_dir) - tf.io.gfile.rmtree(FLAGS.model_dir) - tf.io.gfile.makedirs(FLAGS.model_dir) - - if FLAGS.fake_data: - train_input_fn, eval_input_fn, vocabulary = build_fake_input_fns( - FLAGS.batch_size) - else: - train_input_fn, eval_input_fn, vocabulary = build_input_fns( - FLAGS.data_dir, FLAGS.batch_size) - params["vocabulary"] = vocabulary - - estimator = tf.estimator.Estimator( - model_fn, - params=params, - config=tf.estimator.RunConfig( - model_dir=FLAGS.model_dir, - save_checkpoints_steps=FLAGS.viz_steps, - ), - ) - - tf.random.set_seed(123) - for _ in range(FLAGS.max_steps // FLAGS.viz_steps): - estimator.train(train_input_fn, steps=FLAGS.viz_steps) - eval_results = estimator.evaluate(eval_input_fn) - # Print the evaluation results. The keys are strings specified in - # eval_metric_ops, and the values are NumPy scalars/arrays. - for key, value in eval_results.items(): - print(key) - if key == "topics": - # Topics description is a np.array which prints better row-by-row. - for s in value: - print(s) - else: - print(str(value)) - print("") - print("") - - -if __name__ == "__main__": - tf1.app.run() diff --git a/tensorflow_probability/examples/vae.py b/tensorflow_probability/examples/vae.py deleted file mode 100644 index 166d82355c..0000000000 --- a/tensorflow_probability/examples/vae.py +++ /dev/null @@ -1,526 +0,0 @@ -# Copyright 2018 The TensorFlow Probability Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -"""Trains a variational auto-encoder (VAE) on binarized MNIST. - -The VAE defines a generative model in which a latent code `Z` is sampled from a -prior `p(Z)`, then used to generate an observation `X` by way of a decoder -`p(X|Z)`. The full reconstruction follows - -```none - X ~ p(X) # A random image from some dataset. - Z ~ q(Z | X) # A random encoding of the original image ("encoder"). -Xhat ~ p(Xhat | Z) # A random reconstruction of the original image - # ("decoder"). -``` - -To fit the VAE, we assume an approximate representation of the posterior in the -form of an encoder `q(Z|X)`. 
We minimize the KL divergence between `q(Z|X)` and -the true posterior `p(Z|X)`: this is equivalent to maximizing the evidence lower -bound (ELBO), - -```none --log p(x) -= -log int dz p(x|z) p(z) -= -log int dz q(z|x) p(x|z) p(z) / q(z|x) -<= int dz q(z|x) (-log[ p(x|z) p(z) / q(z|x) ]) # Jensen's Inequality -=: KL[q(Z|x) || p(x|Z)p(Z)] -= -E_{Z~q(Z|x)}[log p(x|Z)] + KL[q(Z|x) || p(Z)] -``` - --or- - -```none --log p(x) -= KL[q(Z|x) || p(x|Z)p(Z)] - KL[q(Z|x) || p(Z|x)] -<= KL[q(Z|x) || p(x|Z)p(Z) # Positivity of KL -= -E_{Z~q(Z|x)}[log p(x|Z)] + KL[q(Z|x) || p(Z)] -``` - -The `-E_{Z~q(Z|x)}[log p(x|Z)]` term is an expected reconstruction loss and -`KL[q(Z|x) || p(Z)]` is a kind of distributional regularizer. See -[Kingma and Welling (2014)][1] for more details. - -This script supports both a (learned) mixture of Gaussians prior as well as a -fixed standard normal prior. You can enable the fixed standard normal prior by -setting `mixture_components` to 1. Note that fixing the parameters of the prior -(as opposed to fitting them with the rest of the model) incurs no loss in -generality when using only a single Gaussian. The reasoning for this is -two-fold: - - * On the generative side, the parameters from the prior can simply be absorbed - into the first linear layer of the generative net. If `z ~ N(mu, Sigma)` and - the first layer of the generative net is given by `x = Wz + b`, this can be - rewritten, - - s ~ N(0, I) - x = Wz + b - = W (As + mu) + b - = (WA) s + (W mu + b) - - where Sigma has been decomposed into A A^T = Sigma. In other words, the log - likelihood of the model (E_{Z~q(Z|x)}[log p(x|Z)]) is independent of whether - or not we learn mu and Sigma. - - * On the inference side, we can adjust any posterior approximation - q(z | x) ~ N(mu[q], Sigma[q]), with - - new_mu[p] := 0 - new_Sigma[p] := eye(d) - new_mu[q] := inv(chol(Sigma[p])) @ (mu[p] - mu[q]) - new_Sigma[q] := inv(Sigma[q]) @ Sigma[p] - - A bit of algebra on the KL divergence term `KL[q(Z|x) || p(Z)]` reveals that - it is also invariant to the prior parameters as long as Sigma[p] and - Sigma[q] are invertible. - -This script also supports using the analytic KL (KL[q(Z|x) || p(Z)]) with the -`analytic_kl` flag. Using the analytic KL is only supported when -`mixture_components` is set to 1 since otherwise no analytic form is known. - -Here we also compute tighter bounds, the IWAE [Burda et. al. (2015)][2]. - -These as well as image summaries can be seen in Tensorboard. For help using -Tensorboard see -https://www.tensorflow.org/guide/summaries_and_tensorboard -which can be run with - `python -m tensorboard.main --logdir=MODEL_DIR` - -#### References - -[1]: Diederik Kingma and Max Welling. Auto-Encoding Variational Bayes. In - _International Conference on Learning Representations_, 2014. - https://arxiv.org/abs/1312.6114 -[2]: Yuri Burda, Roger Grosse, Ruslan Salakhutdinov. Importance Weighted - Autoencoders. In _International Conference on Learning Representations_, - 2015. 
- https://arxiv.org/abs/1509.00519 -""" - -import functools -import os - -# Dependency imports -from absl import flags -import numpy as np -from six.moves import urllib -import tensorflow.compat.v1 as tf -import tensorflow_probability as tfp - -tfd = tfp.distributions - -IMAGE_SHAPE = [28, 28, 1] - -flags.DEFINE_float( - "learning_rate", default=0.001, help="Initial learning rate.") -flags.DEFINE_integer( - "max_steps", default=5001, help="Number of training steps to run.") -flags.DEFINE_integer( - "latent_size", - default=16, - help="Number of dimensions in the latent code (z).") -flags.DEFINE_integer("base_depth", default=32, help="Base depth for layers.") -flags.DEFINE_string( - "activation", - default="leaky_relu", - help="Activation function for all hidden layers.") -flags.DEFINE_integer( - "batch_size", - default=32, - help="Batch size.") -flags.DEFINE_integer( - "n_samples", default=16, help="Number of samples to use in encoding.") -flags.DEFINE_integer( - "mixture_components", - default=100, - help="Number of mixture components to use in the prior. Each component is " - "a diagonal normal distribution. The parameters of the components are " - "intialized randomly, and then learned along with the rest of the " - "parameters. If `analytic_kl` is True, `mixture_components` must be " - "set to `1`.") -flags.DEFINE_bool( - "analytic_kl", - default=False, - help="Whether or not to use the analytic version of the KL. When set to " - "False the E_{Z~q(Z|X)}[log p(Z)p(X|Z) - log q(Z|X)] form of the ELBO " - "will be used. Otherwise the -KL(q(Z|X) || p(Z)) + " - "E_{Z~q(Z|X)}[log p(X|Z)] form will be used. If analytic_kl is True, " - "then you must also specify `mixture_components=1`.") -flags.DEFINE_string( - "data_dir", - default=os.path.join(os.getenv("TEST_TMPDIR", "/tmp"), "vae/data"), - help="Directory where data is stored (if using real data).") -flags.DEFINE_string( - "model_dir", - default=os.path.join(os.getenv("TEST_TMPDIR", "/tmp"), "vae/"), - help="Directory to put the model's fit.") -flags.DEFINE_integer( - "viz_steps", default=500, help="Frequency at which to save visualizations.") -flags.DEFINE_bool( - "fake_data", - default=False, - help="If true, uses fake data instead of MNIST.") -flags.DEFINE_bool( - "delete_existing", - default=False, - help="If true, deletes existing `model_dir` directory.") - -FLAGS = flags.FLAGS - - -def _softplus_inverse(x): - """Helper which computes the function inverse of `tf.nn.softplus`.""" - return tf.math.log(tf.math.expm1(x)) - - -def make_encoder(activation, latent_size, base_depth): - """Creates the encoder function. - - Args: - activation: Activation function in hidden layers. - latent_size: The dimensionality of the encoding. - base_depth: The lowest depth for a layer. - - Returns: - encoder: A `callable` mapping a `Tensor` of images to a - `tfd.Distribution` instance over encodings. 
- """ - conv = functools.partial( - tf.keras.layers.Conv2D, padding="SAME", activation=activation) - - encoder_net = tf.keras.Sequential([ - conv(base_depth, 5, 1), - conv(base_depth, 5, 2), - conv(2 * base_depth, 5, 1), - conv(2 * base_depth, 5, 2), - conv(4 * latent_size, 7, padding="VALID"), - tf.keras.layers.Flatten(), - tf.keras.layers.Dense(2 * latent_size, activation=None), - ]) - - def encoder(images): - images = 2 * tf.cast(images, dtype=tf.float32) - 1 - net = encoder_net(images) - return tfd.MultivariateNormalDiag( - loc=net[..., :latent_size], - scale_diag=tf.nn.softplus(net[..., latent_size:] + - _softplus_inverse(1.0)), - name="code") - - return encoder - - -def make_decoder(activation, latent_size, output_shape, base_depth): - """Creates the decoder function. - - Args: - activation: Activation function in hidden layers. - latent_size: Dimensionality of the encoding. - output_shape: The output image shape. - base_depth: Smallest depth for a layer. - - Returns: - decoder: A `callable` mapping a `Tensor` of encodings to a - `tfd.Distribution` instance over images. - """ - deconv = functools.partial( - tf.keras.layers.Conv2DTranspose, padding="SAME", activation=activation) - conv = functools.partial( - tf.keras.layers.Conv2D, padding="SAME", activation=activation) - - decoder_net = tf.keras.Sequential([ - deconv(2 * base_depth, 7, padding="VALID"), - deconv(2 * base_depth, 5), - deconv(2 * base_depth, 5, 2), - deconv(base_depth, 5), - deconv(base_depth, 5, 2), - deconv(base_depth, 5), - conv(output_shape[-1], 5, activation=None), - ]) - - def decoder(codes): - original_shape = tf.shape(input=codes) - # Collapse the sample and batch dimension and convert to rank-4 tensor for - # use with a convolutional decoder network. - codes = tf.reshape(codes, (-1, 1, 1, latent_size)) - logits = decoder_net(codes) - logits = tf.reshape( - logits, shape=tf.concat([original_shape[:-1], output_shape], axis=0)) - return tfd.Independent(tfd.Bernoulli(logits=logits), - reinterpreted_batch_ndims=len(output_shape), - name="image") - - return decoder - - -def make_mixture_prior(latent_size, mixture_components): - """Creates the mixture of Gaussians prior distribution. - - Args: - latent_size: The dimensionality of the latent representation. - mixture_components: Number of elements of the mixture. - - Returns: - random_prior: A `tfd.Distribution` instance representing the distribution - over encodings in the absence of any evidence. - """ - if mixture_components == 1: - # See the module docstring for why we don't learn the parameters here. 
- return tfd.MultivariateNormalDiag( - loc=tf.zeros([latent_size]), - scale_identity_multiplier=1.0) - - loc = tf.compat.v1.get_variable( - name="loc", shape=[mixture_components, latent_size]) - raw_scale_diag = tf.compat.v1.get_variable( - name="raw_scale_diag", shape=[mixture_components, latent_size]) - mixture_logits = tf.compat.v1.get_variable( - name="mixture_logits", shape=[mixture_components]) - - return tfd.MixtureSameFamily( - components_distribution=tfd.MultivariateNormalDiag( - loc=loc, - scale_diag=tf.nn.softplus(raw_scale_diag)), - mixture_distribution=tfd.Categorical(logits=mixture_logits), - name="prior") - - -def pack_images(images, rows, cols): - """Helper utility to make a field of images.""" - shape = tf.shape(input=images) - width = shape[-3] - height = shape[-2] - depth = shape[-1] - images = tf.reshape(images, (-1, width, height, depth)) - batch = tf.shape(input=images)[0] - rows = tf.minimum(rows, batch) - cols = tf.minimum(batch // rows, cols) - images = images[:rows * cols] - images = tf.reshape(images, (rows, cols, width, height, depth)) - images = tf.transpose(a=images, perm=[0, 2, 1, 3, 4]) - images = tf.reshape(images, [1, rows * width, cols * height, depth]) - return images - - -def image_tile_summary(name, tensor, rows=8, cols=8): - tf.compat.v1.summary.image( - name, pack_images(tensor, rows, cols), max_outputs=1) - - -def model_fn(features, labels, mode, params, config): - """Builds the model function for use in an estimator. - - Args: - features: The input features for the estimator. - labels: The labels, unused here. - mode: Signifies whether it is train or test or predict. - params: Some hyperparameters as a dictionary. - config: The RunConfig, unused here. - - Returns: - EstimatorSpec: A tf.estimator.EstimatorSpec instance. - """ - del labels, config - - if params["analytic_kl"] and params["mixture_components"] != 1: - raise NotImplementedError( - "Using `analytic_kl` is only supported when `mixture_components = 1` " - "since there's no closed form otherwise.") - - encoder = make_encoder(params["activation"], - params["latent_size"], - params["base_depth"]) - decoder = make_decoder(params["activation"], - params["latent_size"], - IMAGE_SHAPE, - params["base_depth"]) - latent_prior = make_mixture_prior(params["latent_size"], - params["mixture_components"]) - - image_tile_summary( - "input", tf.cast(features, dtype=tf.float32), rows=1, cols=16) - - approx_posterior = encoder(features) - approx_posterior_sample = approx_posterior.sample(params["n_samples"]) - decoder_likelihood = decoder(approx_posterior_sample) - image_tile_summary( - "recon/sample", - tf.cast(decoder_likelihood.sample()[:3, :16], dtype=tf.float32), - rows=3, - cols=16) - image_tile_summary( - "recon/mean", - decoder_likelihood.mean()[:3, :16], - rows=3, - cols=16) - - # `distortion` is just the negative log likelihood. 
- distortion = -decoder_likelihood.log_prob(features) - avg_distortion = tf.reduce_mean(input_tensor=distortion) - tf.compat.v1.summary.scalar("distortion", avg_distortion) - - if params["analytic_kl"]: - rate = tfd.kl_divergence(approx_posterior, latent_prior) - else: - rate = (approx_posterior.log_prob(approx_posterior_sample) - - latent_prior.log_prob(approx_posterior_sample)) - avg_rate = tf.reduce_mean(input_tensor=rate) - tf.compat.v1.summary.scalar("rate", avg_rate) - - elbo_local = -(rate + distortion) - - elbo = tf.reduce_mean(input_tensor=elbo_local) - loss = -elbo - tf.compat.v1.summary.scalar("elbo", elbo) - - importance_weighted_elbo = tf.reduce_mean( - input_tensor=tf.reduce_logsumexp(input_tensor=elbo_local, axis=0) - - tf.math.log(tf.cast(params["n_samples"], dtype=tf.float32))) - tf.compat.v1.summary.scalar("elbo/importance_weighted", - importance_weighted_elbo) - - # Decode samples from the prior for visualization. - random_image = decoder(latent_prior.sample(16)) - image_tile_summary( - "random/sample", - tf.cast(random_image.sample(), dtype=tf.float32), - rows=4, - cols=4) - image_tile_summary("random/mean", random_image.mean(), rows=4, cols=4) - - # Perform variational inference by minimizing the -ELBO. - global_step = tf.compat.v1.train.get_or_create_global_step() - learning_rate = tf.compat.v1.train.cosine_decay( - params["learning_rate"], global_step, params["max_steps"]) - tf.compat.v1.summary.scalar("learning_rate", learning_rate) - optimizer = tf.compat.v1.train.AdamOptimizer(learning_rate) - train_op = optimizer.minimize(loss, global_step=global_step) - - return tf.estimator.EstimatorSpec( - mode=mode, - loss=loss, - train_op=train_op, - eval_metric_ops={ - "elbo": - tf.compat.v1.metrics.mean(elbo), - "elbo/importance_weighted": - tf.compat.v1.metrics.mean(importance_weighted_elbo), - "rate": - tf.compat.v1.metrics.mean(avg_rate), - "distortion": - tf.compat.v1.metrics.mean(avg_distortion), - }, - ) - - -ROOT_PATH = "http://www.cs.toronto.edu/~larocheh/public/datasets/binarized_mnist/" -FILE_TEMPLATE = "binarized_mnist_{split}.amat" - - -def download(directory, filename): - """Downloads a file.""" - filepath = os.path.join(directory, filename) - if tf.io.gfile.exists(filepath): - return filepath - if not tf.io.gfile.exists(directory): - tf.io.gfile.makedirs(directory) - url = os.path.join(ROOT_PATH, filename) - print("Downloading %s to %s" % (url, filepath)) - urllib.request.urlretrieve(url, filepath) - return filepath - - -def static_mnist_dataset(directory, split_name): - """Returns binary static MNIST tf.data.Dataset.""" - amat_file = download(directory, FILE_TEMPLATE.format(split=split_name)) - dataset = tf.data.TextLineDataset(amat_file) - str_to_arr = lambda string: np.array([c == b"1" for c in string.split()]) - - def _parser(s): - booltensor = tf.compat.v1.py_func(str_to_arr, [s], tf.bool) - reshaped = tf.reshape(booltensor, [28, 28, 1]) - return tf.cast(reshaped, dtype=tf.float32), tf.constant(0, tf.int32) - - return dataset.map(_parser) - - -def build_fake_input_fns(batch_size): - """Builds fake MNIST-style data for unit testing.""" - random_sample = np.random.rand(batch_size, *IMAGE_SHAPE).astype("float32") - - def train_input_fn(): - dataset = tf.data.Dataset.from_tensor_slices( - random_sample).map(lambda row: (row, 0)).batch(batch_size).repeat() - return tf.compat.v1.data.make_one_shot_iterator(dataset).get_next() - - def eval_input_fn(): - dataset = tf.data.Dataset.from_tensor_slices( - random_sample).map(lambda row: (row, 0)).batch(batch_size) - 
return tf.compat.v1.data.make_one_shot_iterator(dataset).get_next() - - return train_input_fn, eval_input_fn - - -def build_input_fns(data_dir, batch_size): - """Builds an Iterator switching between train and heldout data.""" - - # Build an iterator over training batches. - def train_input_fn(): - dataset = static_mnist_dataset(data_dir, "train") - dataset = dataset.shuffle(50000).repeat().batch(batch_size) - return tf.compat.v1.data.make_one_shot_iterator(dataset).get_next() - - # Build an iterator over the heldout set. - def eval_input_fn(): - eval_dataset = static_mnist_dataset(data_dir, "valid") - eval_dataset = eval_dataset.batch(batch_size) - return tf.compat.v1.data.make_one_shot_iterator(eval_dataset).get_next() - - return train_input_fn, eval_input_fn - - -def main(argv): - del argv # unused - - params = FLAGS.flag_values_dict() - params["activation"] = getattr(tf.nn, params["activation"]) - if FLAGS.delete_existing and tf.io.gfile.exists(FLAGS.model_dir): - tf.compat.v1.logging.warn("Deleting old log directory at {}".format( - FLAGS.model_dir)) - tf.io.gfile.rmtree(FLAGS.model_dir) - tf.io.gfile.makedirs(FLAGS.model_dir) - - if FLAGS.fake_data: - train_input_fn, eval_input_fn = build_fake_input_fns(FLAGS.batch_size) - else: - train_input_fn, eval_input_fn = build_input_fns(FLAGS.data_dir, - FLAGS.batch_size) - - estimator = tf.estimator.Estimator( - model_fn, - params=params, - config=tf.estimator.RunConfig( - model_dir=FLAGS.model_dir, - save_checkpoints_steps=FLAGS.viz_steps, - ), - ) - - for _ in range(FLAGS.max_steps // FLAGS.viz_steps): - estimator.train(train_input_fn, steps=FLAGS.viz_steps) - eval_results = estimator.evaluate(eval_input_fn) - print("Evaluation_results:\n\t%s\n" % eval_results) - - -if __name__ == "__main__": - tf.compat.v1.app.run()
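The `latent_dirichlet_allocation_distributions.py` example deleted above amortizes inference for LDA: an encoder network parameterizes a Dirichlet posterior `q(theta | w)`, the prior is a learnable (clipped) Dirichlet, the bag-of-words likelihood is evaluated via `OneHotCategorical.log_prob`, and perplexity is the exponent of the average negative ELBO per word. Below is a minimal eager-mode sketch of that objective under current TFP APIs; shapes, hyperparameters, and variable names are illustrative, not the example's exact code.

```python
import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions

num_topics, num_words, batch_size = 50, 1000, 32

def clip_dirichlet(x):
  # Keep concentrations bounded away from zero for a stable analytic KL.
  return tf.clip_by_value(x, 0.1, 1e3)

# Encoder mapping word counts to Dirichlet concentrations, i.e. q(theta | w).
encoder = tf.keras.Sequential([
    tf.keras.layers.Dense(300, activation="relu"),
    tf.keras.layers.Dense(
        num_topics, activation=lambda x: clip_dirichlet(tf.nn.softplus(x))),
])
# Per-topic word distributions (softmax rows lie on the simplex) and prior weights.
topic_word_logits = tf.Variable(tf.random.normal([num_topics, num_words]))
prior_concentration = tf.Variable(tf.fill([1, num_topics], 0.7))

# Toy word-count data standing in for the 20 Newsgroups batches.
bag_of_words = tf.round(tf.random.uniform([batch_size, num_words], maxval=5.))

posterior = tfd.Dirichlet(encoder(bag_of_words))
prior = tfd.Dirichlet(clip_dirichlet(prior_concentration))
topics = posterior.sample()                                # [batch, num_topics]
word_probs = tf.matmul(topics, tf.nn.softmax(topic_word_logits, axis=-1))
likelihood = tfd.OneHotCategorical(probs=word_probs)

# OneHotCategorical.log_prob on a count vector sums count * log prob over words.
reconstruction = likelihood.log_prob(bag_of_words)         # [batch]
kl = tfd.kl_divergence(posterior, prior)                   # analytic Dirichlet KL
elbo = reconstruction - kl
loss = -tf.reduce_mean(elbo)
words_per_document = tf.reduce_sum(bag_of_words, axis=1)
perplexity = tf.exp(tf.reduce_mean(-elbo / words_per_document))
```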
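The LDA example also "burns in" the prior: for the first `prior_burn_in_steps` updates the prior's concentration is held fixed, and only afterwards is it trained jointly with everything else (the script implements this by filtering the prior's variables out of `compute_gradients` and switching with `tf.cond`). A rough eager-mode equivalent of that schedule, with hypothetical function and argument names:

```python
import tensorflow as tf

optimizer = tf.keras.optimizers.Adam(3e-4)
prior_burn_in_steps = 120000

def train_step(step, loss_fn, prior_variables, other_variables):
  """One update; the prior's variables receive gradients only after burn-in."""
  with tf.GradientTape() as tape:
    loss = loss_fn()
  if step < prior_burn_in_steps:
    variables = other_variables                  # prior stays fixed
  else:
    variables = other_variables + prior_variables
  grads = tape.gradient(loss, variables)
  optimizer.apply_gradients(zip(grads, variables))
  return loss
```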
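The deleted `vae.py` reports two bounds on the binarized-MNIST log likelihood: the standard ELBO, averaged over `n_samples` draws from `q(Z|X)`, and the tighter importance-weighted (IWAE) bound obtained by a log-mean-exp over the same per-sample terms. A compact sketch of those two estimates, with small dense stand-ins for the example's convolutional encoder and decoder:

```python
import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions

latent_size, n_samples, batch_size, data_dim = 16, 16, 8, 4
# Toy binary data standing in for binarized MNIST images.
images = tf.cast(tf.random.uniform([batch_size, data_dim]) > 0.5, tf.float32)

prior = tfd.MultivariateNormalDiag(loc=tf.zeros([latent_size]))
encoder_net = tf.keras.layers.Dense(2 * latent_size)    # stand-in encoder
decoder_net = tf.keras.layers.Dense(data_dim)           # stand-in decoder

params = encoder_net(images)
approx_posterior = tfd.MultivariateNormalDiag(
    loc=params[..., :latent_size],
    scale_diag=tf.nn.softplus(params[..., latent_size:]))
z = approx_posterior.sample(n_samples)                   # [n_samples, batch, latent]
likelihood = tfd.Independent(tfd.Bernoulli(logits=decoder_net(z)),
                             reinterpreted_batch_ndims=1)

distortion = -likelihood.log_prob(images)                # -log p(x|z)
rate = approx_posterior.log_prob(z) - prior.log_prob(z)  # Monte Carlo KL term
elbo_local = -(rate + distortion)                        # [n_samples, batch]

elbo = tf.reduce_mean(elbo_local)
iwae = tf.reduce_mean(
    tf.reduce_logsumexp(elbo_local, axis=0) -
    tf.math.log(tf.cast(n_samples, tf.float32)))
```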