diff --git a/README.md b/README.md index 965436c..2107b88 100644 --- a/README.md +++ b/README.md @@ -49,13 +49,13 @@ conformal_predictor.fit(X_train, y_train) When the input data is a pandas DataFrame, the output is also a pandas DataFrame. For example, printing the head of `ŷ_test_quantiles` yields: -| house_id | 0.025 | 0.05 | 0.1 | 0.9 | 0.95 | 0.975 | -|-----------:|--------:|-------:|-------:|-------:|-------:|--------:| -| 1357 | 114784 | 120894 | 131618 | 175761 | 188052 | 205449 | -| 2367 | 67417 | 80074 | 86754 | 117854 | 127583 | 142322 | -| 2822 | 119423 | 132048 | 138725 | 178526 | 197246 | 214206 | -| 2126 | 94031 | 99850 | 110891 | 150249 | 164703 | 182528 | -| 1544 | 68996 | 81516 | 88232 | 121774 | 132425 | 147110 | +| house_id | 0.025 | 0.05 | 0.1 | 0.9 | 0.95 | 0.975 | +|-----------:|---------:|---------:|---------:|-------:|-------:|--------:| +| 1357 | 114784 | 120894 | 131618 | 175761 | 188052 | 205449 | +| 2367 | 67416.6 | 80073.7 | 86754 | 117854 | 127583 | 142322 | +| 2822 | 119423 | 132048 | 138725 | 178526 | 197246 | 214206 | +| 2126 | 94030.6 | 99850 | 110891 | 150249 | 164703 | 182528 | +| 1544 | 68996.2 | 81516.3 | 88231.6 | 121774 | 132425 | 147110 | Let's visualize the predicted quantiles on the test set: @@ -114,13 +114,13 @@ print(coverage) # 96.6% When the input data is a pandas DataFrame, the output is also a pandas DataFrame. For example, printing the head of `ŷ_test_interval` yields: -| house_id | 0.025 | 0.975 | -|-----------:|--------:|--------:| -| 1357 | 107203 | 206290 | -| 2367 | 66665 | 146005 | -| 2822 | 115592 | 220315 | -| 2126 | 85288 | 183038 | -| 1544 | 67890 | 150646 | +| house_id | 0.025 | 0.975 | +|-----------:|---------:|--------:| +| 1357 | 107203 | 206290 | +| 2367 | 66665.1 | 146005 | +| 2822 | 115592 | 220315 | +| 2126 | 85288.1 | 183038 | +| 1544 | 67889.9 | 150646 | ## Contributing diff --git a/src/conformal_tights/_conformal_coherent_quantile_regressor.py b/src/conformal_tights/_conformal_coherent_quantile_regressor.py index 4894d0a..0c945a7 100644 --- a/src/conformal_tights/_conformal_coherent_quantile_regressor.py +++ b/src/conformal_tights/_conformal_coherent_quantile_regressor.py @@ -101,11 +101,10 @@ def fit( y = np.ravel(np.asarray(y)) self.n_features_in_: int = X.shape[1] self.y_dtype_: npt.DTypeLike = y.dtype # Used to cast predictions to the correct dtype. - self.y_is_integer_: bool = bool(np.all(y.astype(np.intp) == y)) y = y.astype(np.float64) # To support datetime64[ns] and timedelta64[ns]. if sample_weight is not None: check_consistent_length(y, sample_weight) - sample_weight = np.ravel(np.asarray(sample_weight).astype(np.float64)) + sample_weight = np.ravel(np.asarray(sample_weight).astype(y.dtype)) # Use the smallest of the relative and absolute calibration sizes. calib_size = min( int(self.conformal_calibration_size[0] * X.shape[0]), self.conformal_calibration_size[1] @@ -274,9 +273,8 @@ def predict_quantiles( np.arange(Δŷ_quantiles.shape[0]), :, np.argmin(dispersion, axis=-1) ] ŷ_quantiles: FloatMatrix[F] = ŷ[:, np.newaxis] + Δŷ_quantiles - if self.y_is_integer_ and np.issubdtype(self.y_dtype_, np.integer): - ŷ_quantiles = np.round(ŷ_quantiles) - ŷ_quantiles = ŷ_quantiles.astype(self.y_dtype_) + if not np.issubdtype(self.y_dtype_, np.integer): + ŷ_quantiles = ŷ_quantiles.astype(self.y_dtype_) # Convert ŷ_quantiles to a pandas DataFrame if X is a pandas DataFrame. if hasattr(X, "dtypes") and hasattr(X, "index"): try: @@ -354,9 +352,8 @@ def predict( ŷ_quantiles = self.predict_quantiles(X, quantiles=quantiles) return ŷ_quantiles ŷ = self.estimator_.predict(X) - if self.y_is_integer_: - ŷ = np.round(ŷ) - ŷ = ŷ.astype(self.y_dtype_) + if not np.issubdtype(self.y_dtype_, np.integer): + ŷ = ŷ.astype(self.y_dtype_) if hasattr(X, "dtypes") and hasattr(X, "index"): try: import pandas as pd @@ -369,4 +366,9 @@ def predict( def _more_tags(self) -> dict[str, Any]: """Return more tags for the estimator.""" - return {"allow_nan": True, "_xfail_checks": {"check_regressors_int": "Incompatible check"}} + return { + "allow_nan": True, + "_xfail_checks": { + "check_sample_weights_invariance": "Conformal calibration not invariant to removing zero-weight examples" + }, + }