-
Notifications
You must be signed in to change notification settings - Fork 136
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Fix RISE algorithm for explain function #1263
Conversation
Codecov ReportAttention:
Additional details and impacted files@@ Coverage Diff @@
## develop #1263 +/- ##
===========================================
+ Coverage 80.54% 80.61% +0.06%
===========================================
Files 271 271
Lines 30438 30382 -56
Branches 5930 5909 -21
===========================================
- Hits 24517 24491 -26
+ Misses 4532 4507 -25
+ Partials 1389 1384 -5
Flags with carried forward coverage won't be shown. Click here to find out more. ☔ View full report in Codecov by Sentry. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks good to me. I left some minor comments.
normalized_saliency = np.empty_like(saliency) | ||
for idx, sal in enumerate(saliency): | ||
normalized_saliency[idx, ...] = (sal - np.min(sal)) / (np.max(sal) - np.min(sal)) | ||
return normalized_saliency |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
normalized_saliency = np.empty_like(saliency) | |
for idx, sal in enumerate(saliency): | |
normalized_saliency[idx, ...] = (sal - np.min(sal)) / (np.max(sal) - np.min(sal)) | |
return normalized_saliency | |
return (saliency - np.min(saliency)) / (np.max(saliency) - np.min(saliency)) |
I think this could be replaced like this through vectorized operation.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
thank you for the feedback, but the original code is to normalize each saliency map per class_idx, while your suggestion is to normalize each saliency map from global max and min. So I expect the results will be different.
logit = pred.get(self.LOGIT_KEY) | ||
if logit is None: | ||
raise DatumaroError(f'"{self.LOGIT_KEY}" key should exist in the model prediction.') |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
logit = pred.get(self.LOGIT_KEY) | |
if logit is None: | |
raise DatumaroError(f'"{self.LOGIT_KEY}" key should exist in the model prediction.') | |
logit = pred.get(self.LOGIT_KEY, []) | |
if not logit: | |
raise DatumaroError(f'"{self.LOGIT_KEY}" key should exist in the model prediction.') |
It seems better to use default values for safer.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
thank you. let me update this later.
feature_vector = pred.get(self.FEAT_KEY) | ||
if feature_vector is None: | ||
raise DatumaroError(f'"{self.FEAT_KEY}" key should exist in the model prediction.') |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
feature_vector = pred.get(self.FEAT_KEY) | |
if feature_vector is None: | |
raise DatumaroError(f'"{self.FEAT_KEY}" key should exist in the model prediction.') | |
feature_vector = pred.get(self.FEAT_KEY, []) | |
if not feature_vector: | |
raise DatumaroError(f'"{self.FEAT_KEY}" key should exist in the model prediction.') |
Summary
How to test
Checklist
License
Feel free to contact the maintainers if that's a concern.