Skip to content

Commit

Permalink
Port LKJ example from Pyro (#1065)
Browse files Browse the repository at this point in the history
* port LKJ example from Pyro

* add code samples in docstring for LKJ and LKJCholesky

* delete LKJ example

* change sample name for correlation matrix from L_omega to corr_mat

* incorporate review comments
  • Loading branch information
irustandi authored Jun 29, 2021
1 parent ab18282 commit f8f482a
Showing 1 changed file with 43 additions and 0 deletions.
43 changes: 43 additions & 0 deletions numpyro/distributions/continuous.py
Original file line number Diff line number Diff line change
Expand Up @@ -545,6 +545,27 @@ class LKJ(TransformedDistribution):
When ``concentration < 1``, the distribution favors samples with small determinent. This is
useful when we know a priori that some underlying variables are correlated.
Sample code for using LKJ in the context of multivariate normal sample::
def model(y): # y has dimension N x d
d = y.shape[1]
N = y.shape[0]
# Vector of variances for each of the d variables
theta = numpyro.sample("theta", dist.HalfCauchy(jnp.ones(d)))
concentration = jnp.ones(1) # Implies a uniform distribution over correlation matrices
corr_mat = numpyro.sample("corr_mat", dist.LKJ(d, concentration))
sigma = jnp.sqrt(theta)
# we can also use a faster formula `cov_mat = jnp.outer(theta, theta) * corr_mat`
cov_mat = jnp.matmul(jnp.matmul(jnp.diag(sigma), corr_mat), jnp.diag(sigma))
# Vector of expectations
mu = jnp.zeros(d)
with numpyro.plate("observations", N):
obs = numpyro.sample("obs", dist.MultivariateNormal(mu, covariance_matrix=cov_mat), obs=y)
return obs
:param int dimension: dimension of the matrices
:param ndarray concentration: concentration/shape parameter of the
distribution (often referred to as eta)
Expand Down Expand Up @@ -606,6 +627,28 @@ class LKJCholesky(Distribution):
(hence small determinent). This is useful when we know a priori that some underlying
variables are correlated.
Sample code for using LKJCholesky in the context of multivariate normal sample::
def model(y): # y has dimension N x d
d = y.shape[1]
N = y.shape[0]
# Vector of variances for each of the d variables
theta = numpyro.sample("theta", dist.HalfCauchy(jnp.ones(d)))
# Lower cholesky factor of a correlation matrix
concentration = jnp.ones(1) # Implies a uniform distribution over correlation matrices
L_omega = numpyro.sample("L_omega", dist.LKJCholesky(d, concentration))
# Lower cholesky factor of the covariance matrix
sigma = jnp.sqrt(theta)
# we can also use a faster formula `L_Omega = sigma[..., None] * L_omega`
L_Omega = jnp.matmul(jnp.diag(sigma), L_omega)
# Vector of expectations
mu = jnp.zeros(d)
with numpyro.plate("observations", N):
obs = numpyro.sample("obs", dist.MultivariateNormal(mu, scale_tril=L_Omega), obs=y)
return obs
:param int dimension: dimension of the matrices
:param ndarray concentration: concentration/shape parameter of the
distribution (often referred to as eta)
Expand Down

0 comments on commit f8f482a

Please sign in to comment.