From c2a56dc2b24310cb12fd5673aac663ed850df201 Mon Sep 17 00:00:00 2001 From: Maxim Ziatdinov <34245227+ziatdinovmax@users.noreply.github.com> Date: Fri, 5 Jan 2024 14:14:12 -0800 Subject: [PATCH] Take the array range check out of the model --- gpax/models/uigp.py | 24 +++++++++++++++--------- 1 file changed, 15 insertions(+), 9 deletions(-) diff --git a/gpax/models/uigp.py b/gpax/models/uigp.py index dc6e528..5a9e021 100644 --- a/gpax/models/uigp.py +++ b/gpax/models/uigp.py @@ -66,15 +66,6 @@ def model(self, X: jnp.ndarray, y: jnp.ndarray = None, **kwargs: float) -> None: """ Gaussian process model for uncertain (stochastic) inputs """ - if not (X.max() == 1 and X.min() == 0): - warnings.warn( - "The default `sigma_x` prior for uncertain (stochastic) inputs assumes data is " - "normalized to (0, 1), which is not be the case for your data. Therefore, the default prior " - "may not be optimal for your case. Consider passing custom prior for sigma_x. For example, " - "`sigma_x_prior_dist=numpyro.distributions.HalfNormal(scale)` if using NumPyro directly " - "or `sigma_x_prior_dist=gpax.utils.halfnormal_dist(scale)` if using a GPax wrapper", - UserWarning, - ) # Initialize mean function at zeros f_loc = jnp.zeros(X.shape[0]) @@ -160,6 +151,21 @@ def _predict( y_sampled = dist.MultivariateNormal(y_mean, K).sample(rng_key, sample_shape=(n,)) return y_mean, y_sampled + def _set_data(self, X: jnp.ndarray, y: Optional[jnp.ndarray] = None) -> Union[Tuple[jnp.ndarray], jnp.ndarray]: + X = X if X.ndim > 1 else X[:, None] + if y is not None: + if not (X.max() == 1 and X.min() == 0): + warnings.warn( + "The default `sigma_x` prior for uncertain (stochastic) inputs assumes data is " + "normalized to (0, 1), which is not be the case for your data. Therefore, the default prior " + "may not be optimal for your case. Consider passing custom prior for sigma_x. For example, " + "`sigma_x_prior_dist=numpyro.distributions.HalfNormal(scale)` if using NumPyro directly " + "or `sigma_x_prior_dist=gpax.utils.halfnormal_dist(scale)` if using a GPax wrapper", + UserWarning, + ) + return X, y.squeeze() + return X + def _print_summary(self): samples = self.get_samples(1) numpyro.diagnostics.print_summary({k: v for (k, v) in samples.items() if 'X_prime' not in k})