Skip to content

Commit

Permalink
fix: improve dtype handling in CoherentLinearQuantileRegressor (#10)
Browse files Browse the repository at this point in the history
  • Loading branch information
lsorber authored Apr 1, 2024
1 parent 9e8058f commit 959eeb0
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 9 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,10 @@ jobs:
PYTHON_VERSION=${{ matrix.python-version }} devcontainer up --workspace-folder .
- name: Lint package
run: devcontainer exec --remote-env CI=true --workspace-folder . poe lint
run: devcontainer exec --workspace-folder . poe lint

- name: Test package
run: devcontainer exec --remote-env CI=true --workspace-folder . poe test
run: devcontainer exec --workspace-folder . poe test

- name: Upload coverage
uses: codecov/codecov-action@v4
Expand Down
19 changes: 12 additions & 7 deletions src/conformal_tights/_coherent_linear_quantile_regressor.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,13 +216,13 @@ def fit(
# Validate input.
X, y = check_X_y(X, y, y_numeric=True)
self.n_features_in_: int = X.shape[1]
self.y_dtype_: npt.DTypeLike = y.dtype # Used to cast predictions to the correct dtype.
if np.all(y.astype(np.intp) == y):
self.y_dtype_ = np.intp # To satisfy sklearn's `check_regressors_int`.
y = y.astype(np.float64) # To support datetime64[ns] and timedelta64[ns].
self.y_dtype_: npt.DTypeLike = ( # Used to cast predictions to the correct dtype.
X.dtype if np.issubdtype(y.dtype, np.integer) else y.dtype
)
X, y = X.astype(np.float64), y.astype(np.float64) # To support datetime64 and timedelta64.
if sample_weight is not None:
check_consistent_length(y, sample_weight)
sample_weight = np.asarray(sample_weight).astype(np.float64)
sample_weight = np.asarray(sample_weight).astype(y.dtype)
# Add a constant column to X to allow for a bias in the regression.
if self.fit_intercept:
X = np.hstack([X, np.ones((X.shape[0], 1), dtype=X.dtype)])
Expand All @@ -240,18 +240,23 @@ def predict(self, X: FloatMatrix[F]) -> FloatMatrix[F]:
"""Predict the output on a given dataset."""
# Check input.
check_is_fitted(self)
X = check_array(X)
X = check_array(X, dtype=np.float64)
# Add a constant column to X to allow for a bias in the regression.
if self.fit_intercept:
X = np.hstack([X, np.ones((X.shape[0], 1), dtype=X.dtype)])
# Predict the output.
ŷ: FloatMatrix[F] = X @ self.β_
# Map back to the training target dtype.
ŷ = np.squeeze(ŷ.astype(self.y_dtype_), axis=1 if ŷ.shape[1] == 1 else ())
ŷ = np.squeeze(ŷ, axis=1 if ŷ.shape[1] == 1 else ())
if not np.issubdtype(self.y_dtype_, np.integer):
ŷ.astype(self.y_dtype_)
return ŷ

def intercept_clip(self, X: FloatMatrix[F], y: FloatVector[F]) -> FloatMatrix[F]:
"""Compute a clip for a delta on the intercept that retains quantile coherence."""
check_is_fitted(self)
X, y = check_X_y(X, y, y_numeric=True)
X, y = X.astype(np.float64), y.astype(np.float64)
if self.fit_intercept:
X = np.hstack([X, np.ones((X.shape[0], 1), dtype=X.dtype)])
Q = X @ self.β_full_ - y[:, np.newaxis]
Expand Down

0 comments on commit 959eeb0

Please sign in to comment.