-
Notifications
You must be signed in to change notification settings - Fork 252
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Included callable layer selector support for IntegratedGradients
.
#894
Conversation
Codecov Report
Additional details and impacted files@@ Coverage Diff @@
## master #894 +/- ##
==========================================
+ Coverage 85.32% 85.34% +0.01%
==========================================
Files 74 74
Lines 8779 8801 +22
==========================================
+ Hits 7491 7511 +20
- Misses 1288 1290 +2
Flags with carried forward coverage won't be shown. Click here to find out more.
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nice one! Left a few comments for some improvements.
try: | ||
layer_num = model.layers.index(layer) | ||
layer_meta = model.layers.index(layer) | ||
except ValueError: | ||
logger.info("Layer not in the list of model.layers") | ||
layer_num = None | ||
layer_meta = None |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think we should use this case to also disallow serialization as currently it leads to a bad outcome where the explainer serializes successfully but is actually unusable/cannot be deserialized. I think this involves modifying the save logic depending on the value of layer_meta
.
layer_num: Optional[int] = 0 | ||
else: | ||
self.layer = None | ||
layer_meta: Optional[Union[int, Callable[[tf.keras.Model], tf.keras.layers.Layer]]] = 0 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Maybe this is a good time to define layer_meta
better instead of using sentinel values like 0
and None
which aren't very descriptive?
Perhaps a string Enum could be used here with 2 states corresonding to these two states + a callable corresponding to the 3rd state? Then the config would save either a string (if one of the first 2 states is true) or the lambda callable otherwise to layer_meta
.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nice one, looks great!
This PR includes callable layer selector support for
IntegratedGradients
. Previously, the reloading functionality was using the index of the layer within themodel.layers
list. Such an approach fails when the model is constructed base on nested layers. With the new functionality, the user can pass a function which select the desired model.