-
Notifications
You must be signed in to change notification settings - Fork 19.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
a capsule cnn on cifar-10 #9193
Conversation
Thanks for the PR. There seems to be various style issues. Please work on fixing the code style. |
@fchollet please check it again. I am sorry that I am not very clear what style issues means. |
You can start by getting your code to pass the checks run by a PEP8 linter: https://pypi.python.org/pypi/pep8 |
@fchollet thank you for introducting me such a useful tool. I have change the code with the guide of pep8 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nice addition! Please address or discuss the few modifications suggested.
examples/cifar10_cnn_capsule.py
Outdated
return ex / K.sum(ex, axis=axis, keepdims=True) | ||
|
||
|
||
''' |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please put this comment on top of your file with the description of the example.
examples/cifar10_cnn_capsule.py
Outdated
|
||
class Capsule(Layer): | ||
def __init__(self, | ||
num_capsule, # the number of output capsules |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please put the documentation in the docstring and follow Keras' documentation style. You can find the style in the docstring of any layer.
examples/cifar10_cnn_capsule.py
Outdated
model = Model(inputs=input_image, outputs=output) | ||
|
||
# we use a margin loss | ||
model.compile(loss=lambda y_true, y_pred: y_true * K.relu(0.9 - y_pred)**2 + |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It is a good idea to implement your loss outside of compile so that your code is readable.
examples/cifar10_cnn_capsule.py
Outdated
self.num_capsule, | ||
self.dim_capsule)) | ||
u_hat_vecs = K.permute_dimensions(u_hat_vecs, (0, 2, 1, 3)) | ||
# final u_hat_vecs.shape = [None, num_capsule, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Are those comments needed?
examples/cifar10_cnn_capsule.py
Outdated
# final u_hat_vecs.shape = [None, num_capsule, | ||
# input_num_capsule, dim_capsule] | ||
|
||
b = K.zeros_like(u_hat_vecs[:, :, :, 0]) # shape = [None, num_capsule, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Put your comment on top of the operation please.
examples/cifar10_cnn_capsule.py
Outdated
return (None, self.num_capsule, self.dim_capsule) | ||
|
||
|
||
# some parameters |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Clarify or remove.
examples/cifar10_cnn_capsule.py
Outdated
metrics=['accuracy']) | ||
model.summary() | ||
|
||
# we can compare the perfermace with or without data augmentation |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Typo: perfermace -> performance
print('Using real-time data augmentation.') | ||
# This will do preprocessing and realtime data augmentation: | ||
datagen = ImageDataGenerator( | ||
featurewise_center=False, # set input mean to 0 over the dataset |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Everything in the ImageDataGenerator
is already documented, I'm not 100% sure but I would remove the inline comments.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is copied verbatim from another example. If comments are removed, it'll be inconsistent with other examples, where this is pervasive throughout.
@bojone any news? Your example is a nice addition and you only need to change a few things 👍 . |
@tboquet thanks your useful suggestions! I will fix it as soon as possible. |
examples/cifar10_cnn_capsule.py
Outdated
super(Capsule, self).build(input_shape) | ||
input_dim_capsule = input_shape[-1] | ||
if self.share_weights: | ||
self.W = self.add_weight( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Call this kernel
for consistency with other Keras layers
examples/cifar10_cnn_capsule.py
Outdated
self.dim_capsule = dim_capsule | ||
self.routings = routings | ||
self.share_weights = share_weights | ||
if activation == 'default': |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why default? What does that mean?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
default means we use squashing function as activation. I rename it now.
examples/cifar10_cnn_capsule.py
Outdated
if activation == 'default': | ||
self.activation = squash | ||
else: | ||
self.activation = Activation(activation) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This should be activations.get(activation)
where activations
is from keras import activations
examples/cifar10_cnn_capsule.py
Outdated
self.activation = Activation(activation) | ||
|
||
def build(self, input_shape): | ||
super(Capsule, self).build(input_shape) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
No point in calling super here
examples/cifar10_cnn_capsule.py
Outdated
|
||
# A common Conv2D model | ||
input_image = Input(shape=(None, None, 3)) | ||
cnn = Conv2D(64, (3, 3), activation='relu')(input_image) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
cnn
sounds like a model instance, but it's a tensor. Call it x
examples/cifar10_cnn_capsule.py
Outdated
trainable=True) | ||
|
||
def call(self, u_vecs): | ||
""" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Format your docstrings like other docstrings in the codebase
examples/cifar10_cnn_capsule.py
Outdated
|
||
class Capsule(Layer): | ||
""" | ||
A Capsule Implement with Pure Keras |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Format your docstrings like other docstrings in the codebase
@fchollet my latest submit has a "SSLError: [SSL: CERTIFICATE_VERIFY_FAILED] certificate verify failed (_ssl.c:661)" error. what is the reason? |
examples/cifar10_cnn_capsule.py
Outdated
|
||
Without Data Augmentation: | ||
It gets to 75% validation accuracy in 10 epochs, | ||
and 79% after 15 epochs, and overfitting after 20 epcohs |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
epochs
(typo)
examples/cifar10_cnn_capsule.py
Outdated
|
||
|
||
# a squashing function. but it has litte difference from the Hinton's paper. | ||
# it seems that this form of squashing performs better. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Explain what the difference is and what motivates it
return scale * x | ||
|
||
|
||
# define our own softmax function instead of K.softmax |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why though?
|
||
|
||
# define our own softmax function instead of K.softmax | ||
# because K.softmax can not specify axis. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
That's something we should fix in the Keras backend.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Wasn't it implemented in #8841?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM, thanks
* 'master' of github.com:fchollet/keras: (57 commits) Minor README edit Speed up Travis tests (keras-team#9386) fix typo (keras-team#9391) Fix style issue in docstring Prepare 2.1.4 release. Fix activity regularizer + model composition test Corrected copyright years (keras-team#9375) Change default interpolation from nearest to bilinear. (keras-team#8849) a capsule cnn on cifar-10 (keras-team#9193) Enable us to use sklearn to do cv for functional api (keras-team#9320) Add support for stateful metrics. (keras-team#9253) The type of list keys was float (keras-team#9324) Fix mnist sklearn wrapper example (keras-team#9317) keras-team#9287 Fix most of the file-handle resource leaks. (keras-team#9309) Pass current learning rate to schedule() in LearningRateScheduler (keras-team#8865) Simplify with from six.moves import input (keras-team#9216) fixed RemoteMonitor: Json to handle np.float32 and np.int32 types (keras-team#9261) Update tweet length from 140 to 280 in docs Add `depthconv_conv2d` tests (keras-team#9225) Remove `force` option in progbar ...
* a capsule cnn on cifar-10 * Update cifar10_cnn_capsule.py * update the style * Update cifar10_cnn_capsule.py * Update cifar10_cnn_capsule.py * Update cifar10_cnn_capsule.py * Update cifar10_cnn_capsule.py * pass pep8 verify * Update cifar10_cnn_capsule.py * Update cifar10_cnn_capsule.py
this is a fast capsule implement. and got the better performance than https://github.com/keras-team/keras/blob/master/examples/cifar10_cnn.py.