-
Notifications
You must be signed in to change notification settings - Fork 28
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #45 from ziatdinovmax/utils
Utilities for simplifying assignment of priors
- Loading branch information
Showing
2 changed files
with
283 additions
and
7 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -7,7 +7,8 @@ | |
Created by Maxim Ziatdinov (email: [email protected]) | ||
""" | ||
|
||
from typing import Union, Dict, Type, List | ||
import inspect | ||
from typing import Union, Dict, Type, List, Callable | ||
|
||
import jax | ||
import jax.numpy as jnp | ||
|
@@ -167,10 +168,161 @@ def preprocess_sparse_image(sparse_image): | |
return gp_input, targets, full_indices | ||
|
||
|
||
def normal_prior(param_name, loc=0, scale=1): | ||
def place_normal_prior(param_name: str, loc: float = 0.0, scale: float = 1.0): | ||
""" | ||
Samples a value from a normal distribution with the specified mean (loc) and standard deviation (scale), | ||
and assigns it to a named random variable in the probabilistic model. Can be useful for defining prior mean functions | ||
in structured Gaussian processes. | ||
""" | ||
return numpyro.sample(param_name, numpyro.distributions.Normal(loc, scale)) | ||
return numpyro.sample(param_name, normal_dist(loc, scale)) | ||
|
||
|
||
def place_halfnormal_prior(param_name: str, scale: float = 1.0): | ||
""" | ||
Samples a value from a half-normal distribution with the specified standard deviation (scale), | ||
and assigns it to a named random variable in the probabilistic model. Can be useful for defining prior mean functions | ||
in structured Gaussian processes. | ||
""" | ||
return numpyro.sample(param_name, halfnormal_dist(scale)) | ||
|
||
|
||
def place_uniform_prior(param_name: str, | ||
low: float = None, | ||
high: float = None, | ||
X: jnp.ndarray = None): | ||
""" | ||
Samples a value from a uniform distribution with the specified low and high values, | ||
and assigns it to a named random variable in the probabilistic model. Can be useful for defining prior mean functions | ||
in structured Gaussian processes. | ||
""" | ||
d = uniform_dist(low, high, X) | ||
return numpyro.sample(param_name, d) | ||
|
||
|
||
def place_gamma_prior(param_name: str, | ||
c: float = None, | ||
r: float = None, | ||
X: jnp.ndarray = None): | ||
""" | ||
Samples a value from a uniform distribution with the specified concentration (c) and rate (r) values, | ||
and assigns it to a named random variable in the probabilistic model. Can be useful for defining prior mean functions | ||
in structured Gaussian processes. | ||
""" | ||
d = gamma_dist(c, r, X) | ||
return numpyro.sample(param_name, d) | ||
|
||
|
||
def normal_dist(loc: float = None, scale: float = None | ||
) -> numpyro.distributions.Distribution: | ||
""" | ||
Generate a Normal distribution based on provided center (loc) and standard deviation (scale) parameters. | ||
I neithere are provided, uses 0 and 1 by default. | ||
""" | ||
loc = loc if loc is not None else 0.0 | ||
scale = scale if scale is not None else 1.0 | ||
return numpyro.distributions.Normal(loc, scale) | ||
|
||
|
||
def halfnormal_dist(scale: float = None) -> numpyro.distributions.Distribution: | ||
""" | ||
Generate a half-normal distribution based on provided standard deviation (scale). | ||
If none is provided, uses 1.0 by default. | ||
""" | ||
scale = scale if scale is not None else 1.0 | ||
return numpyro.distributions.HalfNormal(scale) | ||
|
||
|
||
def gamma_dist(c: float = None, | ||
r: float = None, | ||
input_vec: jnp.ndarray = None | ||
) -> numpyro.distributions.Distribution: | ||
""" | ||
Generate a Gamma distribution based on provided shape (c) and rate (r) parameters. If the shape (c) is not provided, | ||
it attempts to infer it using the range of the input vector divided by 2. The rate parameter defaults to 1.0 if not provided. | ||
""" | ||
if c is None: | ||
if input_vec is not None: | ||
c = (input_vec.max() - input_vec.min()) / 2 | ||
else: | ||
raise ValueError("Provide either c or an input array") | ||
if r is None: | ||
r = 1.0 | ||
return numpyro.distributions.Gamma(c, r) | ||
|
||
|
||
def uniform_dist(low: float = None, | ||
high: float = None, | ||
input_vec: jnp.ndarray = None | ||
) -> numpyro.distributions.Distribution: | ||
""" | ||
Generate a Uniform distribution based on provided low and high bounds. If one of the bounds is not provided, | ||
it attempts to infer the missing bound(s) using the minimum or maximum value from the input vector. | ||
""" | ||
if (low is None or high is None) and input_vec is None: | ||
raise ValueError( | ||
"If 'low' or 'high' is not provided, an input array must be provided.") | ||
low = low if low is not None else input_vec.min() | ||
high = high if high is not None else input_vec.max() | ||
|
||
return numpyro.distributions.Uniform(low, high) | ||
|
||
|
||
def set_fn(func: Callable) -> Callable: | ||
""" | ||
Transforms the given deterministic function to use a params dictionary | ||
for its parameters, excluding the first one (assumed to be the dependent variable). | ||
Args: | ||
- func (Callable): The deterministic function to be transformed. | ||
Returns: | ||
- Callable: The transformed function where parameters are accessed | ||
from a `params` dictionary. | ||
""" | ||
# Extract parameter names excluding the first one (assumed to be the dependent variable) | ||
params_names = list(inspect.signature(func).parameters.keys())[1:] | ||
|
||
# Create the transformed function definition | ||
transformed_code = f"def {func.__name__}(x, params):\n" | ||
|
||
# Retrieve the source code of the function and indent it to be a valid function body | ||
source = inspect.getsource(func).split("\n", 1)[1] | ||
source = " " + source.replace("\n", "\n ") | ||
|
||
# Replace each parameter name with its dictionary lookup | ||
for name in params_names: | ||
source = source.replace(f" {name}", f' params["{name}"]') | ||
|
||
# Combine to get the full source | ||
transformed_code += source | ||
|
||
# Define the transformed function in the local namespace | ||
local_namespace = {} | ||
exec(transformed_code, globals(), local_namespace) | ||
|
||
# Return the transformed function | ||
return local_namespace[func.__name__] | ||
|
||
|
||
def auto_normal_priors(func: Callable, loc: float = 0.0, scale: float = 1.0) -> Callable: | ||
""" | ||
Generates a function that, when invoked, samples from normal distributions | ||
for each parameter of the given deterministic function, except the first one. | ||
Args: | ||
- func (Callable): The deterministic function for which to set normal priors. | ||
- loc (float, optional): Mean of the normal distribution. Defaults to 0.0. | ||
- scale (float, optional): Standard deviation of the normal distribution. Defaults to 1.0. | ||
Returns: | ||
- Callable: A function that, when invoked, returns a dictionary of sampled values | ||
from normal distributions for each parameter of the original function. | ||
""" | ||
# Get the names of the parameters of the function excluding the first one (dependent variable) | ||
params_names = list(inspect.signature(func).parameters.keys())[1:] | ||
|
||
def sample_priors() -> Dict[str, Union[float, Type[Callable]]]: | ||
# Return a dictionary with normal priors for each parameter | ||
return {name: place_normal_prior(name, loc, scale) for name in params_names} | ||
|
||
return sample_priors |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters