From cb1fc816044f510eb7811d09f3ee377f96c69e77 Mon Sep 17 00:00:00 2001 From: Maxim Ziatdinov <34245227+ziatdinovmax@users.noreply.github.com> Date: Tue, 10 Oct 2023 18:22:39 -0400 Subject: [PATCH] Add option for using a custom noise kernel prior --- gpax/models/hskgp.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/gpax/models/hskgp.py b/gpax/models/hskgp.py index 40d7f54..de3c464 100644 --- a/gpax/models/hskgp.py +++ b/gpax/models/hskgp.py @@ -56,6 +56,7 @@ def __init__( mean_fn: Optional[Callable[[jnp.ndarray, Dict[str, jnp.ndarray]], jnp.ndarray]] = None, kernel_prior: Optional[Callable[[], Dict[str, jnp.ndarray]]] = None, mean_fn_prior: Optional[Callable[[], Dict[str, jnp.ndarray]]] = None, + noise_kernel_prior: Optional[Callable[[], Dict[str, jnp.ndarray]]] = None, lengthscale_prior_dist: Optional[dist.Distribution] = None, noise_mean_fn: Optional[Callable[[jnp.ndarray, Dict[str, jnp.ndarray]], jnp.ndarray]] = None, noise_mean_fn_prior: Optional[Callable[[], Dict[str, jnp.ndarray]]] = None, @@ -68,6 +69,7 @@ def __init__( self.noise_mean_fn = noise_mean_fn self.noise_mean_fn_prior = noise_mean_fn_prior + self.noise_kernel_prior = noise_kernel_prior self.noise_lengthscale_prior_dist = noise_lengthscale_prior_dist def model(self, X: jnp.ndarray, y: jnp.ndarray = None, **kwargs: float) -> None: @@ -77,7 +79,10 @@ def model(self, X: jnp.ndarray, y: jnp.ndarray = None, **kwargs: float) -> None: noise_f_loc = jnp.zeros(X.shape[0]) # Sample noise kernel parameters - noise_kernel_params = self._sample_noise_kernel_params() + if self.noise_kernel_prior: + noise_kernel_params = self.noise_kernel_prior() + else: + noise_kernel_params = self._sample_noise_kernel_params() # Add noise prior mean function (if any) if self.noise_mean_fn is not None: args = [X]