diff --git a/gpax/utils.py b/gpax/utils.py index f031508..f7b11c4 100644 --- a/gpax/utils.py +++ b/gpax/utils.py @@ -7,7 +7,8 @@ Created by Maxim Ziatdinov (email: maxim.ziatdinov@ai4microscopy.com) """ -from typing import Union, Dict, Type, List +import inspect +from typing import Union, Dict, Type, List, Callable import jax import jax.numpy as jnp @@ -167,10 +168,161 @@ def preprocess_sparse_image(sparse_image): return gp_input, targets, full_indices -def normal_prior(param_name, loc=0, scale=1): +def place_normal_prior(param_name: str, loc: float = 0.0, scale: float = 1.0): """ Samples a value from a normal distribution with the specified mean (loc) and standard deviation (scale), and assigns it to a named random variable in the probabilistic model. Can be useful for defining prior mean functions in structured Gaussian processes. """ - return numpyro.sample(param_name, numpyro.distributions.Normal(loc, scale)) + return numpyro.sample(param_name, normal_dist(loc, scale)) + + +def place_halfnormal_prior(param_name: str, scale: float = 1.0): + """ + Samples a value from a half-normal distribution with the specified standard deviation (scale), + and assigns it to a named random variable in the probabilistic model. Can be useful for defining prior mean functions + in structured Gaussian processes. + """ + return numpyro.sample(param_name, halfnormal_dist(scale)) + + +def place_uniform_prior(param_name: str, + low: float = None, + high: float = None, + X: jnp.ndarray = None): + """ + Samples a value from a uniform distribution with the specified low and high values, + and assigns it to a named random variable in the probabilistic model. Can be useful for defining prior mean functions + in structured Gaussian processes. 
+ """ + d = uniform_dist(low, high, X) + return numpyro.sample(param_name, d) + + +def place_gamma_prior(param_name: str, + c: float = None, + r: float = None, + X: jnp.ndarray = None): + """ + Samples a value from a gamma distribution with the specified concentration (c) and rate (r) values, + and assigns it to a named random variable in the probabilistic model. Can be useful for defining prior mean functions + in structured Gaussian processes. + """ + d = gamma_dist(c, r, X) + return numpyro.sample(param_name, d) + + +def normal_dist(loc: float = None, scale: float = None + ) -> numpyro.distributions.Distribution: + """ + Generate a Normal distribution based on provided center (loc) and standard deviation (scale) parameters. + If neither is provided, uses 0 and 1 by default. + """ + loc = loc if loc is not None else 0.0 + scale = scale if scale is not None else 1.0 + return numpyro.distributions.Normal(loc, scale) + + +def halfnormal_dist(scale: float = None) -> numpyro.distributions.Distribution: + """ + Generate a half-normal distribution based on provided standard deviation (scale). + If none is provided, uses 1.0 by default. + """ + scale = scale if scale is not None else 1.0 + return numpyro.distributions.HalfNormal(scale) + + +def gamma_dist(c: float = None, + r: float = None, + input_vec: jnp.ndarray = None + ) -> numpyro.distributions.Distribution: + """ + Generate a Gamma distribution based on provided shape (c) and rate (r) parameters. If the shape (c) is not provided, + it attempts to infer it using the range of the input vector divided by 2. The rate parameter defaults to 1.0 if not provided. 
+ """ + if c is None: + if input_vec is not None: + c = (input_vec.max() - input_vec.min()) / 2 + else: + raise ValueError("Provide either c or an input array") + if r is None: + r = 1.0 + return numpyro.distributions.Gamma(c, r) + + +def uniform_dist(low: float = None, + high: float = None, + input_vec: jnp.ndarray = None + ) -> numpyro.distributions.Distribution: + """ + Generate a Uniform distribution based on provided low and high bounds. If one of the bounds is not provided, + it attempts to infer the missing bound(s) using the minimum or maximum value from the input vector. + """ + if (low is None or high is None) and input_vec is None: + raise ValueError( + "If 'low' or 'high' is not provided, an input array must be provided.") + low = low if low is not None else input_vec.min() + high = high if high is not None else input_vec.max() + + return numpyro.distributions.Uniform(low, high) + + +def set_fn(func: Callable) -> Callable: + """ + Transforms the given deterministic function to use a params dictionary + for its parameters, excluding the first one (assumed to be the dependent variable). + + Args: + - func (Callable): The deterministic function to be transformed. + + Returns: + - Callable: The transformed function where parameters are accessed + from a `params` dictionary. 
+ """ + # Extract parameter names excluding the first one (assumed to be the dependent variable) + params_names = list(inspect.signature(func).parameters.keys())[1:] + + # Create the transformed function definition + transformed_code = f"def {func.__name__}(x, params):\n" + + # Retrieve the source code of the function and indent it to be a valid function body + source = inspect.getsource(func).split("\n", 1)[1] + source = " " + source.replace("\n", "\n ") + + # Replace each parameter name with its dictionary lookup + for name in params_names: + source = source.replace(f" {name}", f' params["{name}"]') + + # Combine to get the full source + transformed_code += source + + # Define the transformed function in the local namespace + local_namespace = {} + exec(transformed_code, globals(), local_namespace) + + # Return the transformed function + return local_namespace[func.__name__] + + +def auto_normal_priors(func: Callable, loc: float = 0.0, scale: float = 1.0) -> Callable: + """ + Generates a function that, when invoked, samples from normal distributions + for each parameter of the given deterministic function, except the first one. + + Args: + - func (Callable): The deterministic function for which to set normal priors. + - loc (float, optional): Mean of the normal distribution. Defaults to 0.0. + - scale (float, optional): Standard deviation of the normal distribution. Defaults to 1.0. + + Returns: + - Callable: A function that, when invoked, returns a dictionary of sampled values + from normal distributions for each parameter of the original function. 
+ """ + # Get the names of the parameters of the function excluding the first one (dependent variable) + params_names = list(inspect.signature(func).parameters.keys())[1:] + + def sample_priors() -> Dict[str, Union[float, Type[Callable]]]: + # Return a dictionary with normal priors for each parameter + return {name: place_normal_prior(name, loc, scale) for name in params_names} + + return sample_priors diff --git a/tests/test_utils.py b/tests/test_utils.py index 8609bc9..62d0025 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1,4 +1,5 @@ import sys +import pytest import numpy as onp import jax.numpy as jnp import jax.random as jra @@ -7,9 +8,14 @@ sys.path.insert(0, "../gpax/") -from gpax.utils import preprocess_sparse_image, split_dict, random_sample_dict, get_keys, normal_prior +from gpax.utils import preprocess_sparse_image, split_dict, random_sample_dict, get_keys +from gpax.utils import place_normal_prior, place_halfnormal_prior, place_uniform_prior, place_gamma_prior, gamma_dist, uniform_dist, normal_dist, halfnormal_dist +from gpax.utils import set_fn, auto_normal_priors +def sample_function(x, a, b): + return a + b * x + def test_sparse_img_processing(): img = onp.random.randn(16, 16) # Generate random indices @@ -107,17 +113,135 @@ def test_get_keys_different_seeds(): assert_(not onp.array_equal(key2, key2a)) -def test_normal_prior(): +@pytest.mark.parametrize("prior", [place_normal_prior, place_halfnormal_prior]) +def test_normal_prior(prior): + with numpyro.handlers.seed(rng_seed=1): + sample = prior("a") + assert_(isinstance(sample, jnp.ndarray)) + + +def test_uniform_prior(): + with numpyro.handlers.seed(rng_seed=1): + sample = place_uniform_prior("a", 0, 1) + assert_(isinstance(sample, jnp.ndarray)) + + +def test_gamma_prior(): with numpyro.handlers.seed(rng_seed=1): - sample = normal_prior("a") + sample = place_gamma_prior("a", 2, 2) assert_(isinstance(sample, jnp.ndarray)) def test_normal_prior_params(): with 
numpyro.handlers.seed(rng_seed=1): with numpyro.handlers.trace() as tr: - normal_prior("a", loc=0.5, scale=0.1) + place_normal_prior("a", loc=0.5, scale=0.1) site = tr["a"] assert_(isinstance(site['fn'], numpyro.distributions.Normal)) assert_equal(site['fn'].loc, 0.5) assert_equal(site['fn'].scale, 0.1) + + +def test_halfnormal_prior_params(): + with numpyro.handlers.seed(rng_seed=1): + with numpyro.handlers.trace() as tr: + place_halfnormal_prior("a", 0.1) + site = tr["a"] + assert_(isinstance(site['fn'], numpyro.distributions.HalfNormal)) + assert_equal(site['fn'].scale, 0.1) + + +def test_uniform_prior_params(): + with numpyro.handlers.seed(rng_seed=1): + with numpyro.handlers.trace() as tr: + place_uniform_prior("a", low=0.5, high=1.0) + site = tr["a"] + assert_(isinstance(site['fn'], numpyro.distributions.Uniform)) + assert_equal(site['fn'].low, 0.5) + assert_equal(site['fn'].high, 1.0) + + +def test_gamma_prior_params(): + with numpyro.handlers.seed(rng_seed=1): + with numpyro.handlers.trace() as tr: + place_gamma_prior("a", c=2.0, r=1.0) + site = tr["a"] + assert_(isinstance(site['fn'], numpyro.distributions.Gamma)) + assert_equal(site['fn'].concentration, 2.0) + assert_equal(site['fn'].rate, 1.0) + + +def test_get_uniform_dist(): + uniform_dist_ = uniform_dist(low=1.0, high=5.0) + assert isinstance(uniform_dist_, numpyro.distributions.Uniform) + assert uniform_dist_.low == 1.0 + assert uniform_dist_.high == 5.0 + + +def test_get_uniform_dist_infer_params(): + uniform_dist_ = uniform_dist(input_vec=jnp.array([1.0, 2.0, 3.0, 4.0, 5.0])) + assert uniform_dist_.low == 1.0 + assert uniform_dist_.high == 5.0 + + +def test_get_gamma_dist(): + gamma_dist_ = gamma_dist(c=2.0, r=1.0) + assert isinstance(gamma_dist_, numpyro.distributions.Gamma) + assert gamma_dist_.concentration == 2.0 + assert gamma_dist_.rate == 1.0 + + +def test_get_normal_dist(): + normal_dist_ = normal_dist(loc=2.0, scale=3.0) + assert isinstance(normal_dist_, numpyro.distributions.Normal) + 
assert normal_dist_.loc == 2.0 + assert normal_dist_.scale == 3.0 + + +def test_get_halfnormal_dist(): + halfnormal_dist_ = halfnormal_dist(scale=1.5) + assert isinstance(halfnormal_dist_, numpyro.distributions.HalfNormal) + assert halfnormal_dist_.scale == 1.5 + + +def test_get_gamma_dist_infer_param(): + gamma_dist_ = gamma_dist(input_vec=jnp.linspace(0, 10, 20)) + assert isinstance(gamma_dist_, numpyro.distributions.Gamma) + assert gamma_dist_.concentration == 5.0 + assert gamma_dist_.rate == 1.0 + + +def test_get_uniform_dist_error(): + with pytest.raises(ValueError): + uniform_dist(low=1.0) # Only low provided without input_vec + with pytest.raises(ValueError): + uniform_dist(high=5.0) # Only high provided without input_vec + with pytest.raises(ValueError): + uniform_dist() # Neither low nor high, and no input_vec + + +def test_get_gamma_dist_error(): + with pytest.raises(ValueError): + uniform_dist() # Neither concentration, nor input_vec + + +def test_set_fn(): + transformed_fn = set_fn(sample_function) + result = transformed_fn(2, {"a": 1, "b": 3}) + assert result == 7 # Expected output: 1 + 3*2 = 7 + + +def test_auto_normal_priors(): + prior_fn = auto_normal_priors(sample_function, loc=2.0, scale=1.0) + with numpyro.handlers.seed(rng_seed=1): + with numpyro.handlers.trace() as tr: + prior_fn() + site1 = tr["a"] + assert_(isinstance(site1['fn'], numpyro.distributions.Normal)) + assert_equal(site1['fn'].loc, 2.0) + assert_equal(site1['fn'].scale, 1.0) + site2 = tr["b"] + assert_(isinstance(site2['fn'], numpyro.distributions.Normal)) + assert_equal(site2['fn'].loc, 2.0) + assert_equal(site2['fn'].scale, 1.0) +