Skip to content

Commit

Permalink
fix: only convert dtype when target dtype is not integer (#12)
Browse files Browse the repository at this point in the history
  • Loading branch information
lsorber authored Apr 1, 2024
1 parent 4a1a4cc commit 98b0c06
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 23 deletions.
28 changes: 14 additions & 14 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -49,13 +49,13 @@ conformal_predictor.fit(X_train, y_train)

When the input data is a pandas DataFrame, the output is also a pandas DataFrame. For example, printing the head of `ŷ_test_quantiles` yields:

| house_id | 0.025 | 0.05 | 0.1 | 0.9 | 0.95 | 0.975 |
|-----------:|--------:|-------:|-------:|-------:|-------:|--------:|
| 1357 | 114784 | 120894 | 131618 | 175761 | 188052 | 205449 |
| 2367 | 67417 | 80074 | 86754 | 117854 | 127583 | 142322 |
| 2822 | 119423 | 132048 | 138725 | 178526 | 197246 | 214206 |
| 2126 | 94031 | 99850 | 110891 | 150249 | 164703 | 182528 |
| 1544 | 68996 | 81516 | 88232 | 121774 | 132425 | 147110 |
| house_id | 0.025 | 0.05 | 0.1 | 0.9 | 0.95 | 0.975 |
|-----------:|---------:|---------:|---------:|-------:|-------:|--------:|
| 1357 | 114784 | 120894 | 131618 | 175761 | 188052 | 205449 |
| 2367 | 67416.6 | 80073.7 | 86754 | 117854 | 127583 | 142322 |
| 2822 | 119423 | 132048 | 138725 | 178526 | 197246 | 214206 |
| 2126 | 94030.6 | 99850 | 110891 | 150249 | 164703 | 182528 |
| 1544 | 68996.2 | 81516.3 | 88231.6 | 121774 | 132425 | 147110 |

Let's visualize the predicted quantiles on the test set:

Expand Down Expand Up @@ -114,13 +114,13 @@ print(coverage) # 96.6%

When the input data is a pandas DataFrame, the output is also a pandas DataFrame. For example, printing the head of `ŷ_test_interval` yields:

| house_id | 0.025 | 0.975 |
|-----------:|--------:|--------:|
| 1357 | 107203 | 206290 |
| 2367 | 66665 | 146005 |
| 2822 | 115592 | 220315 |
| 2126 | 85288 | 183038 |
| 1544 | 67890 | 150646 |
| house_id | 0.025 | 0.975 |
|-----------:|---------:|--------:|
| 1357 | 107203 | 206290 |
| 2367 | 66665.1 | 146005 |
| 2822 | 115592 | 220315 |
| 2126 | 85288.1 | 183038 |
| 1544 | 67889.9 | 150646 |

## Contributing

Expand Down
20 changes: 11 additions & 9 deletions src/conformal_tights/_conformal_coherent_quantile_regressor.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,11 +101,10 @@ def fit(
y = np.ravel(np.asarray(y))
self.n_features_in_: int = X.shape[1]
self.y_dtype_: npt.DTypeLike = y.dtype # Used to cast predictions to the correct dtype.
self.y_is_integer_: bool = bool(np.all(y.astype(np.intp) == y))
y = y.astype(np.float64) # To support datetime64[ns] and timedelta64[ns].
if sample_weight is not None:
check_consistent_length(y, sample_weight)
sample_weight = np.ravel(np.asarray(sample_weight).astype(np.float64))
sample_weight = np.ravel(np.asarray(sample_weight).astype(y.dtype))
# Use the smallest of the relative and absolute calibration sizes.
calib_size = min(
int(self.conformal_calibration_size[0] * X.shape[0]), self.conformal_calibration_size[1]
Expand Down Expand Up @@ -274,9 +273,8 @@ def predict_quantiles(
np.arange(Δŷ_quantiles.shape[0]), :, np.argmin(dispersion, axis=-1)
]
ŷ_quantiles: FloatMatrix[F] = ŷ[:, np.newaxis] + Δŷ_quantiles
if self.y_is_integer_ and np.issubdtype(self.y_dtype_, np.integer):
ŷ_quantiles = np.round(ŷ_quantiles)
ŷ_quantiles = ŷ_quantiles.astype(self.y_dtype_)
if not np.issubdtype(self.y_dtype_, np.integer):
ŷ_quantiles = ŷ_quantiles.astype(self.y_dtype_)
# Convert ŷ_quantiles to a pandas DataFrame if X is a pandas DataFrame.
if hasattr(X, "dtypes") and hasattr(X, "index"):
try:
Expand Down Expand Up @@ -354,9 +352,8 @@ def predict(
ŷ_quantiles = self.predict_quantiles(X, quantiles=quantiles)
return ŷ_quantiles
ŷ = self.estimator_.predict(X)
if self.y_is_integer_:
ŷ = np.round(ŷ)
ŷ = ŷ.astype(self.y_dtype_)
if not np.issubdtype(self.y_dtype_, np.integer):
ŷ = ŷ.astype(self.y_dtype_)
if hasattr(X, "dtypes") and hasattr(X, "index"):
try:
import pandas as pd
Expand All @@ -369,4 +366,9 @@ def predict(

def _more_tags(self) -> dict[str, Any]:
"""Return more tags for the estimator."""
return {"allow_nan": True, "_xfail_checks": {"check_regressors_int": "Incompatible check"}}
return {
"allow_nan": True,
"_xfail_checks": {
"check_sample_weights_invariance": "Conformal calibration not invariant to removing zero-weight examples"
},
}

0 comments on commit 98b0c06

Please sign in to comment.