Skip to content

Commit

Permalink
Address comments
Browse files Browse the repository at this point in the history
  • Loading branch information
freddyaboulton committed Sep 28, 2021
1 parent b6ad171 commit fa4c109
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 6 deletions.
4 changes: 2 additions & 2 deletions evalml/model_understanding/_partial_dependence.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,10 +142,10 @@ def _partial_dependence_calculation(pipeline, grid, features, X):
else:
prediction_method = pipeline.predict_proba

X_eval = X.ww.copy()
for _, new_values in grid.iterrows():
X_eval = X.copy()
for i, variable in enumerate(features):
X_eval.loc[:, variable] = new_values[i]
X_eval.ww[variable] = pd.Series([new_values[i]] * X_eval.shape[0])

pred = prediction_method(X_eval)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1534,15 +1534,18 @@ def test_partial_dependence_categorical_nan(fraud_100):
)
pl.fit(X, y)

dep = partial_dependence(pl, X, features="provider", grid_resolution=5)
GRID_RESOLUTION = 5
dep = partial_dependence(
pl, X, features="provider", grid_resolution=GRID_RESOLUTION
)

assert dep.shape[0] == X["provider"].dropna().nunique()
assert not dep["feature_values"].isna().any()
assert not dep["partial_dependence"].isna().any()

dep2way = partial_dependence(
pl, X, features=("amount", "provider"), grid_resolution=5
pl, X, features=("amount", "provider"), grid_resolution=GRID_RESOLUTION
)
assert not dep2way.isna().any().any()
# Minus 1 in the columns because there is `class_label`
assert dep2way.shape == (5, X["provider"].dropna().nunique() + 1)
# Plus 1 in the columns because there is `class_label`
assert dep2way.shape == (GRID_RESOLUTION, X["provider"].dropna().nunique() + 1)

0 comments on commit fa4c109

Please sign in to comment.