Our own Partial Dependence Implementation #2834
Conversation
Force-pushed 457a2e3 to 91b7a18.
Codecov Report

@@           Coverage Diff            @@
##            main    #2834    +/-   ##
========================================
- Coverage    99.8%    99.8%    -0.0%
========================================
  Files         302      303       +1
  Lines       28148    28226      +78
========================================
+ Hits        28070    28145      +75
- Misses         78       81       +3

Continue to review the full report at Codecov.
if not isinstance(feature_range, (np.ndarray, pd.Series)):
    feature_range = np.array(feature_range)
if feature_range.ndim != 1:
    raise ValueError(
I'm OK if this isn't covered. It's impossible to trigger as a user because custom_range is not a public parameter, but I'd like to keep the check in case we refactor this in the future; it caught a couple of bugs during development.
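For context, a self-contained version of the truncated check above; the helper name and the error message are assumptions, since the diff cuts off at raise ValueError(:

import numpy as np
import pandas as pd

def _validate_custom_range(feature_range):
    # Normalize plain lists/tuples to an ndarray before inspecting the shape.
    if not isinstance(feature_range, (np.ndarray, pd.Series)):
        feature_range = np.array(feature_range)
    # Reject multi-dimensional input; the exact message is assumed here.
    if feature_range.ndim != 1:
        raise ValueError(
            f"custom_range must be one-dimensional, got {feature_range.ndim} dimensions."
        )
    return feature_range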
@@ -653,6 +652,11 @@ def partial_dependence(
    is_datetime = [_is_feature_of_type(features, X, ww.logical_types.Datetime)]

    if isinstance(features, (list, tuple)):
        if any(is_datetime) and len(features) > 1:
There used to be two if isinstance(features, (list, tuple)) checks; this consolidates them into one.
Force-pushed 1f024ea to b6ad171.
This is epic 👏!
I left some nitpicky comments but nothing blocking. Great work @freddyaboulton!
Also, the speedups are a cherry on top :)
    pl, X, features=("amount", "provider"), grid_resolution=5
)
assert not dep2way.isna().any().any()
# Minus 1 in the columns because there is `class_label`
+1, not minus?
Thank you!!
)
assert not dep2way.isna().any().any()
# Minus 1 in the columns because there is `class_label`
assert dep2way.shape == (5, X["provider"].dropna().nunique() + 1)
Omega nitpick, but I think it'd be a good idea to set grid_resolution as a variable and use it both above and here, as in assert dep2way.shape == (grid_resolution, ...), just so it's clearer where this 5 is coming from :)
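A quick sketch of the suggested change, reusing the test snippet above (assumes the test's pipeline pl, data X, and the partial_dependence import are in scope):

grid_resolution = 5
dep2way = partial_dependence(
    pl, X, features=("amount", "provider"), grid_resolution=grid_resolution
)
assert not dep2way.isna().any().any()
# The "+ 1" accounts for the extra `class_label` column.
assert dep2way.shape == (grid_resolution, X["provider"].dropna().nunique() + 1)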
I completely agree!
arrays = [np.asarray(x) for x in arrays]
shape = (len(x) for x in arrays)

ix = np.indices(shape)
ix = ix.reshape(len(arrays), -1).T

out = pd.DataFrame()

for n, arr in enumerate(arrays):
    out[n] = arrays[n][ix[:, n]]

return out
This seems to be the same as https://github.com/scikit-learn/scikit-learn/blob/844b4be24d20fc42cc13b957374c718956a0db39/sklearn/utils/extmath.py#L655 except we return a dataframe, and since it's a public sklearn method, we could just import it--whatcha think? Also totally down to take their impl 😂
This is a great idea. I know that when we first trudged through partial dependence we borrowed a lot, perhaps a bit more from some private methods than I would like, but it was necessary. If we can refactor to use their public methods, that's great.
@angela97lin Great point. Originally I wanted to use their method, but the problem is that numpy arrays don't handle mixed types very well, so if we want a grid of categoricals and datetimes, storing it in a single numpy array won't really work.
There may be a way around it I'm not seeing (maybe this), but IMO that's a nice-to-have as opposed to a requirement?
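A small illustration of the dtype issue (toy example, not from the PR; column names are made up): once mixed values are stacked into a single ndarray they collapse to one dtype, whereas a DataFrame keeps per-column dtypes.

import numpy as np
import pandas as pd

cats = np.array(["visa", "mastercard"])
dates = pd.date_range("2021-01-01", periods=2)

# Stacking mixed values into one ndarray upcasts everything to one dtype.
print(np.asarray([cats[0], dates[0]]).dtype)  # object

# A DataFrame-based cartesian product keeps each column's dtype instead.
ix = np.indices((len(cats), len(dates))).reshape(2, -1).T
grid = pd.DataFrame({"provider": cats[ix[:, 0]], "date": dates[ix[:, 1]]})
print(grid.dtypes)  # provider: object, date: datetime64[ns]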
Wow @freddyaboulton, this is amazing. I am really impressed. I feel like you cleaned up the code substantially, improved performance, and enhanced functionality. This is a great PR. I had a question about the handling of the times, but it isn't blocking.
    pd.Series: Range of dates between percentiles.
    """
    timestamps = np.array(
        [X_dt - pd.Timestamp("1970-01-01")] // np.timedelta64(1, "s")
I don't know why, but I fixated on this, probably because I come from a natural science background... but is it worth leaving the reference date and the quantum of time as variables? I don't think any of our common or current use cases extend to people doing time series modeling on something like a chemical-reaction timescale (~milli/microseconds), but I can definitely see pharma customers being interested in it.
Let me know what you think. I don't think we necessarily have to do the work here, but it might be nice to at least talk about it.
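One way to read the suggestion, as a hedged sketch (the function name and signature are hypothetical, not evalml's API): parameterize the reference date and the time unit instead of hard-coding the Unix epoch and seconds.

import numpy as np
import pandas as pd

def to_numeric_timestamps(X_dt, epoch=pd.Timestamp("1970-01-01"), unit="s"):
    # Integer offsets from `epoch`, measured in `unit` (e.g. "s", "ms", "us").
    return (X_dt - epoch) // np.timedelta64(1, unit)

dates = pd.Series(pd.date_range("2021-01-01", periods=3, freq="s"))
print(to_numeric_timestamps(dates, unit="ms"))  # 1609459200000, +1000, +2000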
Fantastic point @chukarsten! I think what this is getting at is making our internal custom_range parameter public. I think there can be value in letting users specify how the grid for their features is computed!
I will file a separate issue to track that.
prediction_method = pipeline.predict_proba

for _, new_values in grid.iterrows():
    X_eval = X.copy()
Do we need to copy this each time? Does it make more sense to just rebuild the new dataframe with a concat or something at the end? If it's just as performant, then whatever, this makes sense and is clear.
Great point. I think we can move it out of the loop. Will test it out!
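For reference, a hedged sketch of moving the copy out of the loop (helper name hypothetical, not the PR's final code); it relies on every iteration overwriting the same grid columns:

def predictions_over_grid(X, grid, prediction_method):
    X_eval = X.copy()  # one copy up front instead of one per grid row
    preds = []
    for _, new_values in grid.iterrows():
        # Each grid column is fully overwritten every iteration, so values
        # from the previous iteration cannot leak into this one.
        for col in grid.columns:
            X_eval[col] = new_values[col]
        preds.append(prediction_method(X_eval))
    return preds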
@chukarsten @angela97lin Thank you so much for the reviews! I didn't think this would get into the coming release. I kicked off perf tests out of paranoia, to make sure none of the datasets error out on partial dependence. Will merge if those look good.
Force-pushed fa4c109 to 0756eb7.
Perf tests are here and they look good to me!
Pull Request Description
Fixes #2502
Fixes #2475
Same run-time as main for the model understanding tests:
main: 9m 59s
this branch: 9m 14s

Plots match between this branch and main.
[Comparison plots: "This branch" vs. "Main"]
After creating the pull request: in order to pass the release_notes_updated check you will need to update the "Future Release" section of docs/source/release_notes.rst to include this pull request by adding :pr:`123`.