Skip to content

Commit

Permalink
Fix #406 (#409)
Browse files (browse the repository at this point in the history)
* Fix #406

* Add changelog entry.

* Add changelog entry.

* Move check inside _expand_categorical_penalties

* Remove glmnet_python.

* Changelog.

* Remove changelog entry since the bug was introduced after the previous release.

Co-authored-by: lbittarello <[email protected]>
  • Loading branch information
tbenthompson and lbittarello authored Sep 24, 2021
1 parent 4595f84 commit 53acb49
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 19 deletions.
40 changes: 21 additions & 19 deletions src/quantcore/glm/_glm.py
Original file line number Diff line number Diff line change
Expand Up @@ -1504,27 +1504,29 @@ def _expand_categorical_penalties(penalty, X):
"""
if isinstance(penalty, str):
return penalty
else:
if np.asarray(penalty).shape[0] == X.shape[1]:
if np.asarray(penalty).ndim == 2:
raise ValueError(
"When the penalty is two dimensional, it has "
"to have the same length as the number of "
"columns of X, after the categoricals "
"have been expanded."
)
return np.array(
list(
chain.from_iterable(
[elmt for _ in dtype.categories]
if pd.api.types.is_categorical_dtype(dtype)
else [elmt]
for elmt, dtype in zip(penalty, X.dtypes)
)
if not sparse.issparse(penalty):
penalty = np.asanyarray(penalty)

if penalty.shape[0] == X.shape[1]:
if penalty.ndim == 2:
raise ValueError(
"When the penalty is two dimensional, it has "
"to have the same length as the number of "
"columns of X, after the categoricals "
"have been expanded."
)
return np.array(
list(
chain.from_iterable(
[elmt for _ in dtype.categories]
if pd.api.types.is_categorical_dtype(dtype)
else [elmt]
for elmt, dtype in zip(penalty, X.dtypes)
)
)
else:
return penalty
)
else:
return penalty

P1 = _expand_categorical_penalties(self.P1, X)
P2 = _expand_categorical_penalties(self.P2, X)
Expand Down
6 changes: 6 additions & 0 deletions tests/glm/test_glm.py
Original file line number Diff line number Diff line change
Expand Up @@ -345,6 +345,12 @@ def test_P1_P2_expansion_with_categoricals():
mdl2.fit(X, y)
np.testing.assert_allclose(mdl1.coef_, mdl2.coef_)

mdl2 = GeneralizedLinearRegressor(
l1_ratio=0.01, P1=[1, 2], P2=sparse.diags([2, 1, 1, 1, 1, 1])
)
mdl2.fit(X, y)
np.testing.assert_allclose(mdl1.coef_, mdl2.coef_)


@pytest.mark.parametrize(
"estimator", [GeneralizedLinearRegressor, GeneralizedLinearRegressorCV]
Expand Down

0 comments on commit 53acb49

Please sign in to comment.