diff --git a/gpax/utils.py b/gpax/utils.py index d94b742..f031508 100644 --- a/gpax/utils.py +++ b/gpax/utils.py @@ -165,3 +165,12 @@ def preprocess_sparse_image(sparse_image): # Generate indices for the entire image full_indices = onp.array(onp.meshgrid(*[onp.arange(dim) for dim in sparse_image.shape])).T.reshape(-1, sparse_image.ndim) return gp_input, targets, full_indices + + +def normal_prior(param_name, loc=0, scale=1): + """ + Samples a value from a normal distribution with the specified mean (loc) and standard deviation (scale), + and assigns it to a named random variable in the probabilistic model. Can be useful for defining prior mean functions + in structured Gaussian processes. + """ + return numpyro.sample(param_name, numpyro.distributions.Normal(loc, scale)) diff --git a/tests/test_utils.py b/tests/test_utils.py index 306a582..8609bc9 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -2,11 +2,12 @@ import numpy as onp import jax.numpy as jnp import jax.random as jra +import numpyro from numpy.testing import assert_equal, assert_, assert_array_equal sys.path.insert(0, "../gpax/") -from gpax.utils import preprocess_sparse_image, split_dict, random_sample_dict, get_keys +from gpax.utils import preprocess_sparse_image, split_dict, random_sample_dict, get_keys, normal_prior def test_sparse_img_processing(): @@ -104,3 +105,19 @@ def test_get_keys_different_seeds(): key1a, key2a = get_keys(42) assert_(not onp.array_equal(key1, key1a)) assert_(not onp.array_equal(key2, key2a)) + + +def test_normal_prior(): + with numpyro.handlers.seed(rng_seed=1): + sample = normal_prior("a") + assert_(isinstance(sample, jnp.ndarray)) + + +def test_normal_prior_params(): + with numpyro.handlers.seed(rng_seed=1): + with numpyro.handlers.trace() as tr: + normal_prior("a", loc=0.5, scale=0.1) + site = tr["a"] + assert_(isinstance(site['fn'], numpyro.distributions.Normal)) + assert_equal(site['fn'].loc, 0.5) + assert_equal(site['fn'].scale, 0.1)