diff --git a/pyproject.toml b/pyproject.toml index 05ec09b..87a78d5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -75,7 +75,7 @@ warn_unreachable = true [tool.pytest.ini_options] # https://docs.pytest.org/en/latest/reference/reference.html#ini-options-ref addopts = "--color=yes --doctest-modules --exitfirst --failed-first --strict-config --strict-markers --verbosity=2 --junitxml=reports/pytest.xml" -filterwarnings = ["error", "ignore::DeprecationWarning"] +filterwarnings = ["error", "ignore::DeprecationWarning", "ignore::sklearn.exceptions.SkipTestWarning"] testpaths = ["src", "tests"] xfail_strict = true diff --git a/src/conformal_tights/_conformal_coherent_quantile_regressor.py b/src/conformal_tights/_conformal_coherent_quantile_regressor.py index 6234952..4894d0a 100644 --- a/src/conformal_tights/_conformal_coherent_quantile_regressor.py +++ b/src/conformal_tights/_conformal_coherent_quantile_regressor.py @@ -1,6 +1,6 @@ """Conformal Coherent Quantile Regressor meta-estimator.""" -from typing import TYPE_CHECKING, Literal, TypeVar, overload +from typing import TYPE_CHECKING, Any, Literal, TypeVar, overload import numpy as np import numpy.typing as npt @@ -102,8 +102,6 @@ def fit( self.n_features_in_: int = X.shape[1] self.y_dtype_: npt.DTypeLike = y.dtype # Used to cast predictions to the correct dtype. self.y_is_integer_: bool = bool(np.all(y.astype(np.intp) == y)) - if self.y_is_integer_: - self.y_dtype_ = np.intp # To satisfy sklearn's `check_regressors_int`. y = y.astype(np.float64) # To support datetime64[ns] and timedelta64[ns]. if sample_weight is not None: check_consistent_length(y, sample_weight) @@ -276,7 +274,7 @@ def predict_quantiles( np.arange(Δŷ_quantiles.shape[0]), :, np.argmin(dispersion, axis=-1) ] ŷ_quantiles: FloatMatrix[F] = ŷ[:, np.newaxis] + Δŷ_quantiles - if self.y_is_integer_: + if self.y_is_integer_ and np.issubdtype(self.y_dtype_, np.integer): ŷ_quantiles = np.round(ŷ_quantiles) ŷ_quantiles = ŷ_quantiles.astype(self.y_dtype_) # Convert ŷ_quantiles to a pandas DataFrame if X is a pandas DataFrame. @@ -369,6 +367,6 @@ def predict( return ŷ_series return ŷ - def _more_tags(self) -> dict[str, bool]: + def _more_tags(self) -> dict[str, Any]: """Return more tags for the estimator.""" - return {"allow_nan": True} + return {"allow_nan": True, "_xfail_checks": {"check_regressors_int": "Incompatible check"}}