Skip to content

Commit

Permalink
Fixes #368
Browse files Browse the repository at this point in the history
  • Loading branch information
ejolly committed Dec 18, 2022
1 parent 7b062bd commit 738f9c1
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 18 deletions.
16 changes: 6 additions & 10 deletions nltools/data/brain_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -1209,13 +1209,9 @@ def predict(self, algorithm=None, cv_dict=None, plot=True, **kwargs):
)
# Intercept
if predictor_settings["algorithm"] == "pcr":
output["intercept_xval"].append(
predictor_settings["_regress"].intercept_
)
output["intercept_xval"].append(predictor_cv["regress"].intercept_)
elif predictor_settings["algorithm"] == "lassopcr":
output["intercept_xval"].append(
predictor_settings["_lasso"].intercept_
)
output["intercept_xval"].append(predictor_cv["lasso"].intercept_)
else:
output["intercept_xval"].append(predictor_cv.intercept_)
output["cv_idx"].append((train, test))
Expand All @@ -1233,15 +1229,15 @@ def predict(self, algorithm=None, cv_dict=None, plot=True, **kwargs):
if predictor_settings["algorithm"] == "lassopcr":
wt_map_xval.append(
np.dot(
predictor_settings["_pca"].components_.T,
predictor_settings["_lasso"].coef_,
predictor_cv["pca"].components_.T,
predictor_cv["lasso"].coef_,
)
)
elif predictor_settings["algorithm"] == "pcr":
wt_map_xval.append(
np.dot(
predictor_settings["_pca"].components_.T,
predictor_settings["_regress"].coef_,
predictor_cv["pca"].components_.T,
predictor_cv["regress"].coef_,
)
)
else:
Expand Down
24 changes: 16 additions & 8 deletions nltools/tests/test_brain_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -389,27 +389,23 @@ def test_bootstrap(sim_brain_data):

def test_predict(sim_brain_data):
holdout = np.array([[x] * 2 for x in range(3)]).flatten()
cv_dict = {"type": "kfolds", "n_folds": 2}
stats = sim_brain_data.predict(
algorithm="svm",
cv_dict={"type": "kfolds", "n_folds": 2},
plot=False,
**{"kernel": "linear"}
algorithm="svm", cv_dict=cv_dict, plot=False, **{"kernel": "linear"}
)

# Support Vector Regression, with 5 fold cross-validation with Platt Scaling
# This will output probabilities of each class
stats = sim_brain_data.predict(
algorithm="svm",
cv_dict=None,
cv_dict=cv_dict,
plot=False,
**{"kernel": "linear", "probability": True}
)
assert isinstance(stats["weight_map"], Brain_Data)

# Logistic classificiation, with 2 fold cross-validation.
stats = sim_brain_data.predict(
algorithm="logistic", cv_dict={"type": "kfolds", "n_folds": 2}, plot=False
)
stats = sim_brain_data.predict(algorithm="logistic", cv_dict=cv_dict, plot=False)
assert isinstance(stats["weight_map"], Brain_Data)

# Ridge classificiation,
Expand All @@ -436,6 +432,18 @@ def test_predict(sim_brain_data):

# PCR
stats = sim_brain_data.predict(algorithm="pcr", cv_dict=None, plot=False)
stats = sim_brain_data.predict(algorithm="pcr", cv_dict=cv_dict, plot=False)

# Issue #368
stats = sim_brain_data.predict(
algorithm="lassopcr",
cv_dict={"type": "kfolds", "n_folds": 2},
plot=False,
**{"kernel": "linear"}
)
assert not np.allclose(
[1.0, 1.0], stats["weight_map"].similarity(stats["weight_map_xval"])
)


def test_predict_multi():
Expand Down

0 comments on commit 738f9c1

Please sign in to comment.