diff --git a/numpyro/distributions/continuous.py b/numpyro/distributions/continuous.py index 53119ca5b..75b727872 100644 --- a/numpyro/distributions/continuous.py +++ b/numpyro/distributions/continuous.py @@ -545,6 +545,27 @@ class LKJ(TransformedDistribution): When ``concentration < 1``, the distribution favors samples with small determinent. This is useful when we know a priori that some underlying variables are correlated. + Sample code for using LKJ in the context of multivariate normal sample:: + + def model(y): # y has dimension N x d + d = y.shape[1] + N = y.shape[0] + # Vector of variances for each of the d variables + theta = numpyro.sample("theta", dist.HalfCauchy(jnp.ones(d))) + + concentration = jnp.ones(1) # Implies a uniform distribution over correlation matrices + corr_mat = numpyro.sample("corr_mat", dist.LKJ(d, concentration)) + sigma = jnp.sqrt(theta) + # we can also use a faster formula `cov_mat = jnp.outer(theta, theta) * corr_mat` + cov_mat = jnp.matmul(jnp.matmul(jnp.diag(sigma), corr_mat), jnp.diag(sigma)) + + # Vector of expectations + mu = jnp.zeros(d) + + with numpyro.plate("observations", N): + obs = numpyro.sample("obs", dist.MultivariateNormal(mu, covariance_matrix=cov_mat), obs=y) + return obs + :param int dimension: dimension of the matrices :param ndarray concentration: concentration/shape parameter of the distribution (often referred to as eta) @@ -606,6 +627,28 @@ class LKJCholesky(Distribution): (hence small determinent). This is useful when we know a priori that some underlying variables are correlated. + Sample code for using LKJCholesky in the context of multivariate normal sample:: + + def model(y): # y has dimension N x d + d = y.shape[1] + N = y.shape[0] + # Vector of variances for each of the d variables + theta = numpyro.sample("theta", dist.HalfCauchy(jnp.ones(d))) + # Lower cholesky factor of a correlation matrix + concentration = jnp.ones(1) # Implies a uniform distribution over correlation matrices + L_omega = numpyro.sample("L_omega", dist.LKJCholesky(d, concentration)) + # Lower cholesky factor of the covariance matrix + sigma = jnp.sqrt(theta) + # we can also use a faster formula `L_Omega = sigma[..., None] * L_omega` + L_Omega = jnp.matmul(jnp.diag(sigma), L_omega) + + # Vector of expectations + mu = jnp.zeros(d) + + with numpyro.plate("observations", N): + obs = numpyro.sample("obs", dist.MultivariateNormal(mu, scale_tril=L_Omega), obs=y) + return obs + :param int dimension: dimension of the matrices :param ndarray concentration: concentration/shape parameter of the distribution (often referred to as eta)