Skip to content

Commit

Permalink
Add utility to simplify assignment of normal prior
Browse files Browse the repository at this point in the history
  • Loading branch information
ziatdinovmax committed Oct 7, 2023
1 parent c056600 commit 0b2627a
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 1 deletion.
9 changes: 9 additions & 0 deletions gpax/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,3 +165,12 @@ def preprocess_sparse_image(sparse_image):
# Generate indices for the entire image
full_indices = onp.array(onp.meshgrid(*[onp.arange(dim) for dim in sparse_image.shape])).T.reshape(-1, sparse_image.ndim)
return gp_input, targets, full_indices


def normal_prior(param_name, loc=0, scale=1):
"""
Samples a value from a normal distribution with the specified mean (loc) and standard deviation (scale),
and assigns it to a named random variable in the probabilistic model. Can be useful for defining prior mean functions
in structured Gaussian processes.
"""
return numpyro.sample(param_name, numpyro.distributions.Normal(loc, scale))
19 changes: 18 additions & 1 deletion tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,12 @@
import numpy as onp
import jax.numpy as jnp
import jax.random as jra
import numpyro
from numpy.testing import assert_equal, assert_, assert_array_equal

sys.path.insert(0, "../gpax/")

from gpax.utils import preprocess_sparse_image, split_dict, random_sample_dict, get_keys
from gpax.utils import preprocess_sparse_image, split_dict, random_sample_dict, get_keys, normal_prior


def test_sparse_img_processing():
Expand Down Expand Up @@ -104,3 +105,19 @@ def test_get_keys_different_seeds():
key1a, key2a = get_keys(42)
assert_(not onp.array_equal(key1, key1a))
assert_(not onp.array_equal(key2, key2a))


def test_normal_prior():
with numpyro.handlers.seed(rng_seed=1):
sample = normal_prior("a")
assert_(isinstance(sample, jnp.ndarray))


def test_normal_prior_params():
with numpyro.handlers.seed(rng_seed=1):
with numpyro.handlers.trace() as tr:
normal_prior("a", loc=0.5, scale=0.1)
site = tr["a"]
assert_(isinstance(site['fn'], numpyro.distributions.Normal))
assert_equal(site['fn'].loc, 0.5)
assert_equal(site['fn'].scale, 0.1)

0 comments on commit 0b2627a

Please sign in to comment.