a capsule cnn on cifar-10 #9193

Merged
merged 10 commits into from
Feb 12, 2018

Conversation

bojone
Contributor

@bojone bojone commented Jan 26, 2018

This is a fast capsule implementation. It gets better performance than https://github.com/keras-team/keras/blob/master/examples/cifar10_cnn.py.

@fchollet
Collaborator

Thanks for the PR. There seem to be various style issues. Please work on fixing the code style.

@bojone
Contributor Author

bojone commented Jan 27, 2018

@fchollet please check it again. I'm sorry, but I'm not sure what "style issues" refers to.

@fchollet
Collaborator

You can start by getting your code to pass the checks run by a PEP8 linter: https://pypi.python.org/pypi/pep8

@bojone
Contributor Author

bojone commented Jan 27, 2018

@fchollet thank you for introducing me to such a useful tool. I have changed the code following the pep8 checks.

Contributor

@tboquet tboquet left a comment

Nice addition! Please address or discuss the few modifications suggested.

return ex / K.sum(ex, axis=axis, keepdims=True)


'''
Contributor
Please put this comment on top of your file with the description of the example.


class Capsule(Layer):
    def __init__(self,
                 num_capsule,  # the number of output capsules
Contributor
Please put the documentation in the docstring and follow Keras' documentation style. You can find the style in the docstring of any layer.
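
A sketch of what the requested docstring could look like, assuming the `# Arguments` section layout used in Keras layer docstrings. The parameter names are taken from the diff fragments in this thread; every description other than the one for `num_capsule` is a guess for illustration, and the plain `object` base class is a stand-in so the snippet runs without Keras.

```python
class Capsule(object):  # stand-in base; the PR subclasses the Keras Layer class
    """Capsule layer (illustrative docstring only).

    # Arguments
        num_capsule: Integer, the number of output capsules.
        dim_capsule: Integer, size of each output capsule vector (assumed).
        routings: Integer, number of routing iterations (assumed).
        share_weights: Boolean, whether all input capsules share one
            transformation kernel (assumed).
        activation: Activation applied to the output capsules (assumed).
    """

    def __init__(self, num_capsule, dim_capsule, routings,
                 share_weights, activation):
        # Documentation lives in the class docstring, not in inline comments.
        self.num_capsule = num_capsule
        self.dim_capsule = dim_capsule
        self.routings = routings
        self.share_weights = share_weights
        self.activation = activation
```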

model = Model(inputs=input_image, outputs=output)

# we use a margin loss
model.compile(loss=lambda y_true, y_pred: y_true * K.relu(0.9 - y_pred)**2 +
Contributor
It is a good idea to implement your loss outside of compile so that your code is readable.
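
A minimal sketch of the loss as a named function. The lambda in the diff is cut off after the first term, so the second term here is an assumption: it follows the margin loss from the capsule paper, with `m_minus=0.1` and `lam=0.5` taken from that paper rather than from this PR. NumPy is used so the snippet runs on its own; in the example itself the same body would presumably use `K.relu` and `K.sum`, and the function would be passed as `loss=margin_loss`.

```python
import numpy as np


def margin_loss(y_true, y_pred, m_plus=0.9, m_minus=0.1, lam=0.5):
    """Per-sample margin loss, summed over classes."""
    # Term visible in the diff: penalise a true class scored below m_plus.
    positive = y_true * np.maximum(0.0, m_plus - y_pred) ** 2
    # Assumed second term: penalise a wrong class scored above m_minus.
    negative = lam * (1.0 - y_true) * np.maximum(0.0, y_pred - m_minus) ** 2
    return np.sum(positive + negative, axis=-1)
```

Pulling the expression out of `compile` also makes it possible to test the loss on a few hand-built vectors before training.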

self.num_capsule,
self.dim_capsule))
u_hat_vecs = K.permute_dimensions(u_hat_vecs, (0, 2, 1, 3))
# final u_hat_vecs.shape = [None, num_capsule,
Contributor
Are those comments needed?

# final u_hat_vecs.shape = [None, num_capsule,
# input_num_capsule, dim_capsule]

b = K.zeros_like(u_hat_vecs[:, :, :, 0]) # shape = [None, num_capsule,
Contributor
Put your comment on top of the operation please.
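
The requested layout, sketched with NumPy so it runs standalone. It assumes `u_hat_vecs` has the four-dimensional shape given in the comment quoted earlier in this review; the concrete sizes are hypothetical.

```python
import numpy as np

# Hypothetical sizes, for illustration only.
batch, num_capsule, input_num_capsule, dim_capsule = 2, 10, 32, 16
u_hat_vecs = np.zeros((batch, num_capsule, input_num_capsule, dim_capsule))

# shape = [batch, num_capsule, input_num_capsule]
b = np.zeros_like(u_hat_vecs[:, :, :, 0])
```

With the comment above the statement, it can span the full line width instead of wrapping behind the code.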

return (None, self.num_capsule, self.dim_capsule)


# some parameters
Contributor
Clarify or remove.

metrics=['accuracy'])
model.summary()

# we can compare the perfermace with or without data augmentation
Contributor
Typo: perfermace -> performance

print('Using real-time data augmentation.')
# This will do preprocessing and realtime data augmentation:
datagen = ImageDataGenerator(
featurewise_center=False, # set input mean to 0 over the dataset
Contributor
Everything in the ImageDataGenerator is already documented. I'm not 100% sure, but I would remove the inline comments.

Contributor
This is copied verbatim from another example. If comments are removed, it'll be inconsistent with other examples, where this is pervasive throughout.

@tboquet
Contributor

tboquet commented Feb 9, 2018

@bojone any news? Your example is a nice addition and you only need to change a few things 👍 .

@bojone
Contributor Author

bojone commented Feb 10, 2018

@tboquet thanks for your useful suggestions! I will fix it as soon as possible.

        super(Capsule, self).build(input_shape)
        input_dim_capsule = input_shape[-1]
        if self.share_weights:
            self.W = self.add_weight(
Collaborator
Call this kernel for consistency with other Keras layers

        self.dim_capsule = dim_capsule
        self.routings = routings
        self.share_weights = share_weights
        if activation == 'default':
Collaborator
Why default? What does that mean?

Contributor Author
"default" means we use the squashing function as the activation. I have renamed it now.

        if activation == 'default':
            self.activation = squash
        else:
            self.activation = Activation(activation)
Collaborator
This should be activations.get(activation), with activations imported via: from keras import activations

            self.activation = Activation(activation)

    def build(self, input_shape):
        super(Capsule, self).build(input_shape)
Collaborator
No point in calling super here


# A common Conv2D model
input_image = Input(shape=(None, None, 3))
cnn = Conv2D(64, (3, 3), activation='relu')(input_image)
Collaborator
cnn sounds like a model instance, but it's a tensor. Call it x

trainable=True)

def call(self, u_vecs):
"""
Collaborator
Format your docstrings like other docstrings in the codebase


class Capsule(Layer):
    """
    A Capsule Implement with Pure Keras
Collaborator
Format your docstrings like other docstrings in the codebase

@bojone
Contributor Author

bojone commented Feb 11, 2018

@fchollet my latest submission fails with an "SSLError: [SSL: CERTIFICATE_VERIFY_FAILED] certificate verify failed (_ssl.c:661)" error. What is the reason?


Without Data Augmentation:
It gets to 75% validation accuracy in 10 epochs,
and 79% after 15 epochs, and overfitting after 20 epcohs
Collaborator
epochs (typo)



# a squashing function. but it has litte difference from the Hinton's paper.
# it seems that this form of squashing performs better.
Collaborator
Explain what the difference is and what motivates it
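
For reference, a NumPy sketch of the squashing function as defined in the capsule paper, v = (|s|^2 / (1 + |s|^2)) * (s / |s|). The body of the PR's variant is not shown in this thread, so this sketch does not reproduce it; it only gives the baseline that the variant is being compared against. The `eps` term is an assumption added to avoid dividing by zero.

```python
import numpy as np


def squash_paper(s, axis=-1, eps=1e-7):
    """Squash from the capsule paper: keeps direction, maps length into [0, 1)."""
    squared_norm = np.sum(np.square(s), axis=axis, keepdims=True)
    # |s|^2 / (1 + |s|^2) sets the output length; dividing by |s| normalises s.
    scale = squared_norm / (1.0 + squared_norm) / np.sqrt(squared_norm + eps)
    return scale * s
```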

return scale * x


# define our own softmax function instead of K.softmax
Collaborator
Why though?



# define our own softmax function instead of K.softmax
# because K.softmax can not specify axis.
Collaborator
That's something we should fix in the Keras backend.

Contributor
Wasn't it implemented in #8841?
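
A NumPy sketch of a softmax that normalises along a chosen axis, which is the behaviour the custom function in the diff is after. The last line mirrors the `return ex / K.sum(ex, axis=axis, keepdims=True)` quoted at the top of this review; subtracting the per-axis maximum is a common stabilisation step and is an assumption here, not something confirmed by the thread.

```python
import numpy as np


def softmax(x, axis=-1):
    """Softmax along an arbitrary axis."""
    # Subtracting the max does not change the result but avoids overflow in exp.
    ex = np.exp(x - np.max(x, axis=axis, keepdims=True))
    return ex / np.sum(ex, axis=axis, keepdims=True)
```

For a tensor shaped like the routing logits, `softmax(b, axis=1)` would normalise over the second dimension instead of the last one.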

Collaborator

@fchollet fchollet left a comment

LGTM, thanks

@fchollet fchollet merged commit f6958ab into keras-team:master Feb 12, 2018
ahundt added a commit to ahundt/keras that referenced this pull request Feb 16, 2018
* 'master' of github.com:fchollet/keras: (57 commits)
  Minor README edit
  Speed up Travis tests (keras-team#9386)
  fix typo (keras-team#9391)
  Fix style issue in docstring
  Prepare 2.1.4 release.
  Fix activity regularizer + model composition test
  Corrected copyright years (keras-team#9375)
  Change default interpolation from nearest to bilinear. (keras-team#8849)
  a capsule cnn on cifar-10 (keras-team#9193)
  Enable us to use sklearn to do cv for functional api (keras-team#9320)
  Add support for stateful metrics. (keras-team#9253)
  The type of list keys was float (keras-team#9324)
  Fix mnist sklearn wrapper example (keras-team#9317)
  keras-team#9287 Fix most of the file-handle resource leaks. (keras-team#9309)
  Pass current learning rate to schedule() in LearningRateScheduler (keras-team#8865)
  Simplify with from six.moves import input (keras-team#9216)
  fixed RemoteMonitor: Json to handle np.float32 and np.int32 types (keras-team#9261)
  Update tweet length from 140 to 280 in docs
  Add `depthconv_conv2d` tests (keras-team#9225)
  Remove `force` option in progbar
  ...
lenjoy pushed a commit to lenjoy/keras that referenced this pull request Feb 22, 2018
* a capsule cnn on cifar-10

* Update cifar10_cnn_capsule.py

* update the style

* Update cifar10_cnn_capsule.py

* Update cifar10_cnn_capsule.py

* Update cifar10_cnn_capsule.py

* Update cifar10_cnn_capsule.py

* pass pep8 verify

* Update cifar10_cnn_capsule.py

* Update cifar10_cnn_capsule.py

4 participants