Skip to content

Commit

Permalink
fix: xfail sklearn's check_regressors_int (#9)
Browse files Browse the repository at this point in the history
  • Loading branch information
lsorber authored Apr 1, 2024
1 parent 2b14c2a commit 9fea615
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 7 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ warn_unreachable = true

[tool.pytest.ini_options] # https://docs.pytest.org/en/latest/reference/reference.html#ini-options-ref
addopts = "--color=yes --doctest-modules --exitfirst --failed-first --strict-config --strict-markers --verbosity=2 --junitxml=reports/pytest.xml"
filterwarnings = ["error", "ignore::DeprecationWarning"]
filterwarnings = ["error", "ignore::DeprecationWarning", "ignore::sklearn.exceptions.SkipTestWarning"]
testpaths = ["src", "tests"]
xfail_strict = true

Expand Down
10 changes: 4 additions & 6 deletions src/conformal_tights/_conformal_coherent_quantile_regressor.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""Conformal Coherent Quantile Regressor meta-estimator."""

from typing import TYPE_CHECKING, Literal, TypeVar, overload
from typing import TYPE_CHECKING, Any, Literal, TypeVar, overload

import numpy as np
import numpy.typing as npt
Expand Down Expand Up @@ -102,8 +102,6 @@ def fit(
self.n_features_in_: int = X.shape[1]
self.y_dtype_: npt.DTypeLike = y.dtype # Used to cast predictions to the correct dtype.
self.y_is_integer_: bool = bool(np.all(y.astype(np.intp) == y))
if self.y_is_integer_:
self.y_dtype_ = np.intp # To satisfy sklearn's `check_regressors_int`.
y = y.astype(np.float64) # To support datetime64[ns] and timedelta64[ns].
if sample_weight is not None:
check_consistent_length(y, sample_weight)
Expand Down Expand Up @@ -276,7 +274,7 @@ def predict_quantiles(
np.arange(Δŷ_quantiles.shape[0]), :, np.argmin(dispersion, axis=-1)
]
ŷ_quantiles: FloatMatrix[F] = ŷ[:, np.newaxis] + Δŷ_quantiles
if self.y_is_integer_:
if self.y_is_integer_ and np.issubdtype(self.y_dtype_, np.integer):
ŷ_quantiles = np.round(ŷ_quantiles)
ŷ_quantiles = ŷ_quantiles.astype(self.y_dtype_)
# Convert ŷ_quantiles to a pandas DataFrame if X is a pandas DataFrame.
Expand Down Expand Up @@ -369,6 +367,6 @@ def predict(
return ŷ_series
return ŷ

def _more_tags(self) -> dict[str, bool]:
def _more_tags(self) -> dict[str, Any]:
"""Return more tags for the estimator."""
return {"allow_nan": True}
return {"allow_nan": True, "_xfail_checks": {"check_regressors_int": "Incompatible check"}}

0 comments on commit 9fea615

Please sign in to comment.