Merge pull request #1569 from ryanrussell:main
PiperOrigin-RevId: 453218168
tensorflower-gardener committed Jun 6, 2022
2 parents 6685d04 + 1166fe9 commit 5172cc7
Showing 11 changed files with 20 additions and 19 deletions.
2 changes: 1 addition & 1 deletion CONTRIBUTING.md
@@ -87,7 +87,7 @@ example invocation (presumed to run from the root of the TFP repo:

To run the unit tests, you'll need several packages installed (again, we
strongly recommend you work in a virtualenv). We include a script to do this for
- you, which also does some sanity checks on the environtment:
+ you, which also does some sanity checks on the environment:

```shell
./testing/install_test_dependencies.sh
4 changes: 2 additions & 2 deletions discussion/technical_note_on_unrolled_nuts.md
@@ -13,7 +13,7 @@ two reasons:

To accomodate these concerns our implementation makes the following
novel observations:
- - We *offline enumerate* the recursion possibilties and note all read/write
+ - We *offline enumerate* the recursion possibilities and note all read/write
operations.
- We pre-allocate requisite memory (for what would otherwise be the recursion
stack).
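As a rough illustration of what *offline enumerate* means here (a simplified sketch, not the actual instruction table built by the implementation), one can list, for each new leaf of a depth-`d` doubling trajectory, which previously stored state it completes a subtree with and therefore must be compared against in a U-turn check:

```python
def enumerate_uturn_checks(max_tree_depth):
  """Sketch: leaf i (1-based) closes subtrees of 2, 4, ... leaves while 2**k divides i."""
  checks = {}
  for leaf in range(1, 2 ** max_tree_depth + 1):
    size, pairs = 2, []
    while leaf % size == 0:
      # The subtree of `size` leaves ending at `leaf` starts at `leaf - size + 1`;
      # its two endpoint states are the ones compared in a U-turn check.
      pairs.append((leaf - size + 1, leaf))
      size *= 2
    checks[leaf] = pairs
  return checks

# enumerate_uturn_checks(2) == {1: [], 2: [(1, 2)], 3: [], 4: [(3, 4), (1, 4)]}
```

Because this table depends only on `max_tree_depth`, it can be computed once ahead of time, which is what makes the pre-allocated, recursion-free formulation possible.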
@@ -258,7 +258,7 @@ step 1(0): x0 ==> U([x0], [1]) ==> x1 --> MH([x',x1], 1/1) --> x''
## Performance Optimization

Using a memory slot of the size 2^max_tree_depth like above is quite
- convenient for both sampling and u turn check, as we have the whole history avaiable
+ convenient for both sampling and u turn check, as we have the whole history available
and can easily index to it. In practice, while it works well for small
problem, users could quickly ran into memory problem with large batch size (i.e.,
number of chains), large latent size (i.e., dimension of the free parameters),
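For a sense of the memory cost being described (illustrative numbers only, not taken from the note):

```python
# Keeping the full 2**max_tree_depth state history in float32:
max_tree_depth = 10       # 1024 slots per chain
num_chains = 4096         # batch size
latent_dim = 1000         # dimension of the free parameters
bytes_per_value = 4       # float32
total_bytes = 2 ** max_tree_depth * num_chains * latent_dim * bytes_per_value
print(total_bytes / 1e9)  # ~16.8 GB for the state history alone
```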
10 changes: 5 additions & 5 deletions discussion/turnkey_inference_candidate/README.md
@@ -4,8 +4,8 @@ This directory contains proposals and design documents for turnkey inference.

Goal: user specifies how many MCMC samples (or effective samples) they want, and
the sampling method takes care of the rest. This includes the definition of
- `target_log_prob_fn`, inital states, and choosing the optimal
- (paramterization of) the `TransitionKernel`.
+ `target_log_prob_fn`, initial states, and choosing the optimal
+ (parameterization of) the `TransitionKernel`.

### An expanding window tuning for HMC/NUTS

@@ -24,14 +24,14 @@ posterior.
Currently, the TFP NUTS implementation has a speed bottleneck of waiting for the
slowest chain/batch (due to the SIMD nature), and it could seriously hinder
performance, especially when the (initial) step size is poorly chosen. Thus,
- our strategy here is to run very few chains in the inital warm up (1 or 2).
+ our strategy here is to run very few chains in the initial warm up (1 or 2).
Moreover, by analogy to Stan's expanding memoryless windows (stage II of Stan's
- automatic parameter tuning), we implmented an expanding batch, fixed step count
+ automatic parameter tuning), we implemented an expanding batch, fixed step count
method.

It is worth noting that, in TFP HMC step sizes are defined per dimension of the
target_log_prob_fn. To separate the tuning of the step size (a scalar) and the
- mass matrix (a vector for diagnoal mass matrix), we apply an inner transform
+ mass matrix (a vector for diagonal mass matrix), we apply an inner transform
transition kernel (recall that the covariance matrix Σ acts as a Euclidean
metric to rotate and scale the target_log_prob_fn) using a shift and scale
bijector.
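As a minimal sketch of this pattern (not the code in this directory; `target_log_prob_fn`, `loc_estimate`, and `scale_estimate` are assumed to come from an earlier warm-up window), the inner transform can be written with `tfp.mcmc.TransformedTransitionKernel` and a shift-and-scale bijector:

```python
import tensorflow_probability as tfp

tfb = tfp.bijectors

nuts = tfp.mcmc.NoUTurnSampler(
    target_log_prob_fn=target_log_prob_fn,
    step_size=0.1)  # scalar step size: captures the overall scale only
kernel = tfp.mcmc.TransformedTransitionKernel(
    inner_kernel=nuts,
    # The shift-and-scale bijector plays the role of a diagonal mass matrix,
    # so sampling happens in a space where the target is roughly unit-scale.
    bijector=tfb.Chain([tfb.Shift(loc_estimate), tfb.Scale(scale_estimate)]))
```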
@@ -365,13 +365,14 @@ def window_tune_nuts_sampling(target_log_prob,
(possibly unnormalized) log-density under the target distribution.
prior_samples: Nested structure of `Tensor`s, each of shape `[batches,
latent_part_event_shape]` and should be sample from the prior. They are
- used to generate an inital chain position if `init_state` is not supplied.
+ used to generate an initial chain position if `init_state` is not
+ supplied.
constraining_bijectors: `tfp.distributions.Bijector` or list of
`tfp.distributions.Bijector`s. These bijectors use `forward` to map the
state on the real space to the constrained state expected by
`target_log_prob`.
init_state: (Optional) `Tensor` or Python `list` of `Tensor`s representing
- the inital state(s) of the Markov chain(s).
+ the initial state(s) of the Markov chain(s).
num_samples: Integer number of the Markov chain draws after tuning.
nchains: Integer number of the Markov chains after tuning.
init_nchains: Integer number of the Markov chains in the first phase of
@@ -380,7 +381,7 @@ def window_tune_nuts_sampling(target_log_prob,
probability for step size adaptation.
max_tree_depth: Maximum depth of the tree implicitly built by NUTS. See
`tfp.mcmc.NoUTurnSampler` for more details
- use_scaled_init: Boolean. If `True`, generate inital state within [-1, 1]
+ use_scaled_init: Boolean. If `True`, generate initial state within [-1, 1]
scaled by prior sample standard deviation in the unconstrained real space.
This kwarg is ignored if `init_state` is not None
tuning_window_schedule: List-like sequence of integers that specify the
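Going only by the argument names visible in the docstring above, a call might look like the following sketch (every value is a placeholder; the actual signature and return structure are defined in the turnkey_inference_candidate sources):

```python
# Hypothetical invocation; all values below are illustrative placeholders.
result = window_tune_nuts_sampling(
    target_log_prob=target_log_prob,            # callable returning log-density
    prior_samples=prior_samples,                # nested structure of Tensors
    constraining_bijectors=constraining_bijectors,
    num_samples=1000,
    nchains=64,
    init_nchains=2,
    max_tree_depth=10,
    use_scaled_init=True)
```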
2 changes: 1 addition & 1 deletion spinoffs/inference_gym/README.md
@@ -36,7 +36,7 @@ Check out the [tutorial].

```bash
pip install tfp-nightly inference_gym
- # Install at least one the folowing
+ # Install at least one the following
pip install tf-nightly # For the TensorFlow backend.
pip install jax jaxlib # For the JAX backend.
# Install to support external datasets
4 changes: 2 additions & 2 deletions spinoffs/inference_gym/model_contract.md
@@ -14,7 +14,7 @@ others for future refinement.
The primary use case of a model is to be able to run an inference algorithm on
it. The secondary goal is to be able to verify the accuracy of the algorithm.
There are other finer points of usability which also matter, but the overarching
- princple of the contract for models is that it's better to have a model usable
+ principle of the contract for models is that it's better to have a model usable
for its primary use case without all the nice-to-haves, rather than not have
the model available at all.

@@ -82,7 +82,7 @@ argument for inclusion of the model.
example, regression models should support computing held-out negative
log-likelihood. Rationale: This is similar to having a standard
parameterization. In this case, there are certain transformations which are
- natural to look at when analyizing a model.
+ natural to look at when analyzing a model.

4. If the model has analytic ground truth values, they should be filled in.
Rationale: Ground truth values enable one way of measuring the bias of an
@@ -7623,7 +7623,7 @@
" exposed[..., WUHAN_IDX] = wuhan_exposed\n",
" undocumented_infectious[..., WUHAN_IDX] = wuhan_undocumented_infectious\n",
"\n",
" # Following Li et al, we do not remove the inital exposed and infectious\n",
" # Following Li et al, we do not remove the initial exposed and infectious\n",
" # persons from the susceptible population.\n",
" return SEIRComponents(\n",
" susceptible=tf.constant(susceptible),\n",
@@ -167,7 +167,7 @@ def retrying_cholesky(
Args:
matrix: A batch of symmetric square matrices, with shape `[..., n, n]`.
- jitter: Initial jitter to add to the diagnoal. Default: 1e-6, unless
+ jitter: Initial jitter to add to the diagonal. Default: 1e-6, unless
`matrix.dtype` is float64, in which case the default is 1e-10.
max_iters: Maximum number of times to retry the Cholesky decomposition
with larger diagonal jitter. Default: 5.
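The retry idea described by these arguments can be illustrated with a small NumPy sketch (an illustration only, not the TFP implementation; the factor by which the jitter grows between attempts is an assumption):

```python
import numpy as np

def retrying_cholesky_sketch(matrix, jitter=1e-6, max_iters=5):
  """Illustrative only: retry the Cholesky with growing diagonal jitter."""
  eye = np.eye(matrix.shape[-1], dtype=matrix.dtype)
  for _ in range(max_iters + 1):
    try:
      return np.linalg.cholesky(matrix)
    except np.linalg.LinAlgError:
      matrix = matrix + jitter * eye  # add jitter to the diagonal and retry
      jitter *= 10.0                  # assumed growth factor
  raise np.linalg.LinAlgError('Cholesky still failing after the allowed retries.')
```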
@@ -731,7 +731,7 @@ def _build_sub_tree(self,
name=None):
with tf.name_scope('build_sub_tree'):
batch_shape = ps.shape(current_step_meta_info.init_energy)
- # We never want to select the inital state
+ # We never want to select the initial state
if MULTINOMIAL_SAMPLE:
init_weight = tf.fill(
batch_shape,
@@ -84,7 +84,7 @@ def _left_doubling_increments(batch_shape, max_doublings, step_size, seed=None,
widths = width_multipliers * step_size

# Take the cumulative sum of the left side increments in slice width to give
- # the resulting distance from the inital lower bound.
+ # the resulting distance from the initial lower bound.
left_increments = tf.cumsum(widths * expand_left, exclusive=True, axis=0)
return left_increments, widths
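A small example of what the exclusive cumulative sum computes here (illustrative values, not from the library):

```python
import tensorflow as tf

# With doubling slice widths and every doubling expanding to the left, the
# exclusive cumsum gives each left edge's distance from the initial lower bound.
widths = tf.constant([1., 2., 4., 8.])
left_increments = tf.cumsum(widths, exclusive=True, axis=0)
# => [0., 1., 3., 7.]
```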

2 changes: 1 addition & 1 deletion tensorflow_probability/python/mcmc/nuts.py
@@ -722,7 +722,7 @@ def _build_sub_tree(self,
name=None):
with tf.name_scope('build_sub_tree'):
batch_shape = ps.shape(current_step_meta_info.init_energy)
- # We never want to select the inital state
+ # We never want to select the initial state
if MULTINOMIAL_SAMPLE:
init_weight = tf.fill(
batch_shape,
