Skip to content

Commit

Permalink
Take the array range check out of the model
Browse files Browse the repository at this point in the history
  • Loading branch information
ziatdinovmax committed Jan 5, 2024
1 parent 216e0c5 commit c2a56dc
Showing 1 changed file with 15 additions and 9 deletions.
24 changes: 15 additions & 9 deletions gpax/models/uigp.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,15 +66,6 @@ def model(self, X: jnp.ndarray, y: jnp.ndarray = None, **kwargs: float) -> None:
"""
Gaussian process model for uncertain (stochastic) inputs
"""
if not (X.max() == 1 and X.min() == 0):
warnings.warn(
"The default `sigma_x` prior for uncertain (stochastic) inputs assumes data is "
"normalized to (0, 1), which is not be the case for your data. Therefore, the default prior "
"may not be optimal for your case. Consider passing custom prior for sigma_x. For example, "
"`sigma_x_prior_dist=numpyro.distributions.HalfNormal(scale)` if using NumPyro directly "
"or `sigma_x_prior_dist=gpax.utils.halfnormal_dist(scale)` if using a GPax wrapper",
UserWarning,
)
# Initialize mean function at zeros
f_loc = jnp.zeros(X.shape[0])

Expand Down Expand Up @@ -160,6 +151,21 @@ def _predict(
y_sampled = dist.MultivariateNormal(y_mean, K).sample(rng_key, sample_shape=(n,))
return y_mean, y_sampled

def _set_data(self, X: jnp.ndarray, y: Optional[jnp.ndarray] = None) -> Union[Tuple[jnp.ndarray], jnp.ndarray]:
X = X if X.ndim > 1 else X[:, None]
if y is not None:
if not (X.max() == 1 and X.min() == 0):
warnings.warn(
"The default `sigma_x` prior for uncertain (stochastic) inputs assumes data is "
"normalized to (0, 1), which is not be the case for your data. Therefore, the default prior "
"may not be optimal for your case. Consider passing custom prior for sigma_x. For example, "
"`sigma_x_prior_dist=numpyro.distributions.HalfNormal(scale)` if using NumPyro directly "
"or `sigma_x_prior_dist=gpax.utils.halfnormal_dist(scale)` if using a GPax wrapper",
UserWarning,
)
return X, y.squeeze()
return X

def _print_summary(self):
samples = self.get_samples(1)
numpyro.diagnostics.print_summary({k: v for (k, v) in samples.items() if 'X_prime' not in k})

0 comments on commit c2a56dc

Please sign in to comment.