Specifying a FCN based on VGG16 #3540

Closed
psalvaggio opened this issue Aug 22, 2016 · 20 comments

@psalvaggio

Hi, I'm new to the area of deep learning, so please forgive me if I get some terminology wrong.

I'm trying to train a Fully Convolutional Network (FCN) to perform semantic segmentation. I am attempting to use the VGG16 network as the first part of my network.

from keras.models import Sequential
from keras.layers import ZeroPadding2D, Convolution2D, MaxPooling2D  # Keras 1.x API

def DefineVGG16(img_width, img_height):
  # build the VGG16 convolutional base; with Theano ('th') dim ordering the
  # input shape is (batch, channels, rows, cols), i.e. height before width
  model = Sequential()
  model.add(ZeroPadding2D((1, 1),
            batch_input_shape=(1, 3, img_height, img_width)))

  # build the rest of the network
  model.add(Convolution2D(64, 3, 3, activation='relu', name='conv1_1'))
  model.add(ZeroPadding2D((1, 1)))
  model.add(Convolution2D(64, 3, 3, activation='relu', name='conv1_2'))
  model.add(MaxPooling2D((2, 2), strides=(2, 2)))

  model.add(ZeroPadding2D((1, 1)))
  model.add(Convolution2D(128, 3, 3, activation='relu', name='conv2_1'))
  model.add(ZeroPadding2D((1, 1)))
  model.add(Convolution2D(128, 3, 3, activation='relu', name='conv2_2'))
  model.add(MaxPooling2D((2, 2), strides=(2, 2)))

  model.add(ZeroPadding2D((1, 1)))
  model.add(Convolution2D(256, 3, 3, activation='relu', name='conv3_1'))
  model.add(ZeroPadding2D((1, 1)))
  model.add(Convolution2D(256, 3, 3, activation='relu', name='conv3_2'))
  model.add(ZeroPadding2D((1, 1)))
  model.add(Convolution2D(256, 3, 3, activation='relu', name='conv3_3'))
  model.add(MaxPooling2D((2, 2), strides=(2, 2)))

  model.add(ZeroPadding2D((1, 1)))
  model.add(Convolution2D(512, 3, 3, activation='relu', name='conv4_1'))
  model.add(ZeroPadding2D((1, 1)))
  model.add(Convolution2D(512, 3, 3, activation='relu', name='conv4_2'))
  model.add(ZeroPadding2D((1, 1)))
  model.add(Convolution2D(512, 3, 3, activation='relu', name='conv4_3'))
  model.add(MaxPooling2D((2, 2), strides=(2, 2)))

  model.add(ZeroPadding2D((1, 1)))
  model.add(Convolution2D(512, 3, 3, activation='relu', name='conv5_1'))
  model.add(ZeroPadding2D((1, 1)))
  model.add(Convolution2D(512, 3, 3, activation='relu', name='conv5_2'))
  model.add(ZeroPadding2D((1, 1)))
  model.add(Convolution2D(512, 3, 3, activation='relu', name='conv5_3'))
  model.add(MaxPooling2D((2, 2), strides=(2, 2)))

  return model

My issue is that I want to perform analysis on much bigger images than the 224x224 size the network was trained on. I can tile them down to something reasonable like 512x512, but 224x224 is too small. I'm a bit confused about how to specify the middle and back half of the network. After the last VGG16 MaxPooling2D layer, my output size is (1, 512, 16, 16). I am assuming I need to insert another block of convolutions to get it down to (1, 512, 8, 8), so I can keep the number of parameters under control. At that point, I insert the fully-connected layers as

from keras.layers import Flatten, Dense, Dropout

model.add(Flatten())
model.add(Dense(4096, activation='relu', name='fc6'))
model.add(Dropout(0.5))
model.add(Dense(4096, activation='relu', name='fc7'))
model.add(Dropout(0.5))
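To see why shrinking the feature map before the fully-connected layers keeps the parameter count under control, here is a small sketch that counts the weights of a Dense fc6 layer applied after Flatten() to a (512, n, n) feature map. The function name and its defaults (512 channels, 4096 units, one bias per unit) are illustrative assumptions, not Keras API.

```python
def fc6_dense_params(spatial, channels=512, units=4096):
    """Weights + biases of Dense(units) applied after Flatten() to a
    (channels, spatial, spatial) feature map."""
    inputs = channels * spatial * spatial   # length of the flattened vector
    return inputs * units + units           # weight matrix plus one bias per unit


for n in (7, 8, 16):
    print(n, fc6_dense_params(n))
```

For 224x224 inputs (7x7 after VGG16) this gives about 103M weights for fc6 alone; at 16x16 (512x512 inputs) it grows to about 537M, while halving the map to 8x8 brings it back to about 134M.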

Here's where I'm lost. I know there is a Deconvolution2D layer, but there's no documentation on the Keras site on how to use it to build up the deconvolution side of the network. I think I need to reshape the output of the 'fc7' layer back to 3D and then use a combination of Deconvolution2D and UpSampling2D to mirror the front half of the network, but I don't know how that would look in the code.

Any help would be greatly appreciated.

@lukovkin
Contributor

There's an example of Deconvolution2D usage in one of the examples: https://github.com/fchollet/keras/blob/master/examples/variational_autoencoder_deconv.py#L53-L55
The main difference in this deconvolution implementation is that you have to pass the output shape explicitly as a tuple to the layer constructor (see the example for the correct order).
So you have to calculate the output rows and cols manually. The formula can be taken, for example, from http://deeplearning.net/software/theano_versions/dev/tutorial/conv_arithmetic.html#transposed-convolution-arithmetic or from https://arxiv.org/pdf/1603.07285v1.pdf.
I'll put it here for convenience:
$o = s(i - 1) + a + k - 2p$, where $a \in \{0, 1, \dots, s - 1\}$
where:
o - output size,
i - input size,
k - kernel (filter) size,
s - stride (`subsample` in Keras terms),
p - padding size,
a - user-specified quantity used to distinguish between the s different possible output sizes.

If the stride is 1, $a$ can only be 0, and the formula simplifies to $o = (i - 1) + k - 2p$.
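A minimal Python helper that evaluates this formula, assuming square inputs and kernels; the name `transposed_conv_output_size` is hypothetical and is not part of Keras.

```python
def transposed_conv_output_size(i, k, s=1, p=0, a=0):
    """Output size o = s*(i - 1) + a + k - 2*p of a transposed convolution.

    i: input size, k: kernel size, s: stride (Keras `subsample`),
    p: padding, a: extra offset in [0, s - 1] that selects one of the
    s output sizes that map back to the same input size.
    """
    if not 0 <= a < s:
        raise ValueError("a must satisfy 0 <= a <= s - 1")
    return s * (i - 1) + a + k - 2 * p
```

For example, it returns 16 for an 8x8 input with k = 3, s = 2, p = 1, a = 1, and also 16 for a 16x16 input with k = 3, s = 1, p = 1 (the size-preserving 'same' case).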

Please keep in mind that calculating the shape by hand can be problematic if you deal with variable-sized inputs (which a fully convolutional implementation can otherwise handle).

I've tried to implement a simplified shape autoinference here: https://github.com/lukovkin/keras/blob/master/keras/layers/convolutional.py#L414-L436. It seems to work with the TF backend, but with Theano I've run into some specific issues and haven't resolved them yet.

@dolaameng
Contributor

Thanks @lukovkin for the update on this long-debated topic. The discussions were scattered in many threads, which makes it a little hard to track. Additional questions:

  1. There is an UpSampling2D implementation, as demonstrated in Francois's blog post, which plays a similar role to the Unpooling in this paper. But their implementations are quite different: one is based on repeating and the other on switch variables. What will be the difference between the two in practice? For example, might Unpooling enable a deeper deconvolutional network without losing too much spatial information?
  2. There are discussions on using a Convolution layer to "simulate" the Deconvolution layer, although the latter is supposed to be the transpose of the former. There are also implementations of AtrousConvolution2D. Again, what are their main differences?

I'd appreciate your opinions, because I didn't find clear answers in the related discussions, e.g., #3122, #2822, #2087, #378, etc.
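To make the difference in question 1 concrete, here is a NumPy sketch on a single 2D feature map: "repeat" upsampling copies each value into a 2x2 block (the strategy behind UpSampling2D), while switch-based unpooling records where each maximum came from during max pooling and puts the value back only at that location. The function names are hypothetical, and the loops favour clarity over speed.

```python
import numpy as np

def upsample_repeat(x, factor=2):
    """Nearest-neighbour upsampling: copy each value into a factor x factor block."""
    return np.repeat(np.repeat(x, factor, axis=0), factor, axis=1)

def max_pool_with_switches(x, size=2):
    """size x size max pooling that also records the (row, col) of each maximum."""
    h, w = x.shape
    pooled = np.zeros((h // size, w // size), dtype=x.dtype)
    switches = np.zeros((h // size, w // size, 2), dtype=int)
    for r in range(h // size):
        for c in range(w // size):
            block = x[r * size:(r + 1) * size, c * size:(c + 1) * size]
            dr, dc = np.unravel_index(np.argmax(block), block.shape)
            pooled[r, c] = block[dr, dc]
            switches[r, c] = (r * size + dr, c * size + dc)
    return pooled, switches

def unpool_with_switches(pooled, switches, out_shape):
    """Put each pooled value back where it came from; zeros everywhere else."""
    out = np.zeros(out_shape, dtype=pooled.dtype)
    for r in range(pooled.shape[0]):
        for c in range(pooled.shape[1]):
            out[tuple(switches[r, c])] = pooled[r, c]
    return out

x = np.array([[1, 5, 2, 0],
              [3, 4, 8, 1],
              [0, 2, 1, 1],
              [7, 1, 3, 9]])
pooled, switches = max_pool_with_switches(x)
print(upsample_repeat(pooled))                          # smeared 2x2 blocks
print(unpool_with_switches(pooled, switches, x.shape))  # sparse, location-preserving
```

Repeating spreads each maximum over its whole 2x2 block, whereas switch-based unpooling keeps the original positions of the maxima, which is the spatial information question 1 is about.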

@psalvaggio
Author

psalvaggio commented Aug 23, 2016

Thanks for the great responses and the explanations and examples for Deconvolution2D! I was going off of the same paper that @dolaameng mentioned. I saw the difference in the upsampling operation and I would also be interested in the effects of the differences, although I am hoping that the system I am trying to build will not be too sensitive. Here's the solution I came up with:

# Start the network with the VGG16 network
# (LoadVGG16Weights is a helper not shown here; it loads the pre-trained VGG16 weights)
model = DefineVGG16(512, 512)
LoadVGG16Weights('vgg16_weights.h5', model)

DefineVGG16() is in the original post. Then I added additional convolutional layers to get down to around 8x8 for larger images.

# Add additional convolutional layers to get the size down to at most 8x8
additional_layers = 0
while model.layers[-1].output_shape[2] > 8:
  additional_layers += 1
  layer = 5 + additional_layers
  model.add(ZeroPadding2D((1, 1)))
  model.add(Convolution2D(512, 3, 3, activation='relu',
                          name='conv%d_1' % layer))
  model.add(ZeroPadding2D((1, 1)))
  model.add(Convolution2D(512, 3, 3, activation='relu',
                          name='conv%d_2' % layer))
  model.add(ZeroPadding2D((1, 1)))
  model.add(Convolution2D(512, 3, 3, activation='relu',
                          name='conv%d_3' % layer))
  model.add(MaxPooling2D((2, 2), strides=(2, 2)))
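The number of extra blocks the loop above adds depends only on the input size. A small sketch that reproduces the count, assuming (as in the code) that each ZeroPadding2D((1, 1)) + 3x3 convolution keeps the size and each 2x2, stride-2 MaxPooling2D halves it, rounding down; the function name is hypothetical.

```python
def extra_blocks_needed(input_size, target=8, vgg_pools=5):
    """Return (number of extra conv+pool blocks, final spatial size)."""
    size = input_size
    for _ in range(vgg_pools):   # the five pooling layers inside VGG16
        size //= 2
    extra = 0
    while size > target:         # mirrors the while-loop above
        size //= 2
        extra += 1
    return extra, size


for n in (224, 512, 1024):
    print(n, extra_blocks_needed(n))
```

So 224x224 inputs need no extra block (7x7 after VGG16), 512x512 inputs need one (16 down to 8), and 1024x1024 inputs need two.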

Then I added the fully connected layers in the middle. This was my biggest original issue: I was using Dense instead of Convolution2D. That probably should have raised an alarm, since I was using a non-convolutional layer in a "fully convolutional" network.

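from keras.layers import Convolution2D, Deconvolution2D, Dropout

layer_size = model.layers[-1].output_shape[2]

# Add the fully convolutional layers: fc6 is a 'valid' convolution whose kernel
# covers the whole layer_size x layer_size map (so its output is 1x1),
# and fc7 is a 1x1 convolution
model.add(Convolution2D(4096, layer_size, layer_size, activation='relu',
                        name='fc6'))
model.add(Dropout(0.5))
model.add(Convolution2D(4096, 1, 1, activation='relu', name='fc7'))
model.add(Dropout(0.5))
# Deconvolution2D needs its output shape explicitly: (batch, channels, rows, cols)
model.add(Deconvolution2D(512, layer_size, layer_size,
                          (1, 512, layer_size, layer_size),
                          subsample=(2, 2), name='deconv-fc6'))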
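As a check of the explicit output shape passed to deconv-fc6, write $n$ for layer_size (8 for 512x512 inputs) and combine the standard output-size formula for a direct convolution, $o = (i + 2p - k)/s + 1$, with the transposed-convolution formula quoted earlier in the thread:

```latex
\begin{aligned}
\text{fc6 (valid conv, } k = n,\ s = 1,\ p = 0\text{)}:&\quad o = \frac{i + 2p - k}{s} + 1 = \frac{n + 0 - n}{1} + 1 = 1 \\
\text{deconv-fc6 } (i = 1,\ k = n,\ s = 2,\ p = 0,\ a = 0):&\quad o = s(i - 1) + a + k - 2p = 2 \cdot 0 + 0 + n - 0 = n
\end{aligned}
```

Because the input to deconv-fc6 is 1x1, the stride of 2 has no effect on its output size here; the $n \times n$ shape comes entirely from the kernel size.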

Then I started building up the deconvolutional side of the network. I started by adding deconvolutional layers to mirror any additional convolutional layers that I had added to handle larger images.

# Add deconvolutional layers for any additional convolutional layers
# that we added to get down to 8x8
for i in range(additional_layers, 0, -1):
  layer = 5 + i
  output_size = 2 * model.layers[-1].output_shape[2]
  model.add(UpSampling2D((2, 2)))
  model.add(Deconvolution2D(512, 3, 3,
                            (1, 512, output_size, output_size),
                            name='deconv%d_1' % layer,
                            border_mode='same'))
  model.add(Deconvolution2D(512, 3, 3,
                            (1, 512, output_size, output_size),
                            name='deconv%d_2' % layer,
                            border_mode='same'))
  model.add(Deconvolution2D(512, 3, 3,
                            (1, 512, output_size, output_size),
                            name='deconv%d_3' % layer,
                            border_mode='same'))

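Then, I created the mirrored version of the VGG16 network. In my case, the formula for the output size that @lukovkin provided simplified to 2*input_size per UpSampling2D + Deconvolution2D block (with stride 1 and border_mode='same', each Deconvolution2D keeps the size unchanged, and UpSampling2D doubles it), so I just referenced the output size of the UpSampling2D layer.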

model.add(UpSampling2D((2, 2)))
output_size = model.layers[-1].output_shape[2]
model.add(Deconvolution2D(512, 3, 3, (1, 512, output_size, output_size),
                          border_mode='same', name='deconv5_1'))
model.add(Deconvolution2D(512, 3, 3, (1, 512, output_size, output_size),
                          border_mode='same', name='deconv5_2'))
model.add(Deconvolution2D(512, 3, 3, (1, 512, output_size, output_size),
                          border_mode='same', name='deconv5_3'))

model.add(UpSampling2D((2, 2)))
output_size = model.layers[-1].output_shape[2]
model.add(Deconvolution2D(512, 3, 3, (1, 512, output_size, output_size),
                          border_mode='same', name='deconv4_1'))
model.add(Deconvolution2D(512, 3, 3, (1, 512, output_size, output_size),
                          border_mode='same', name='deconv4_2'))
model.add(Deconvolution2D(256, 3, 3, (1, 256, output_size, output_size),
                          border_mode='same', name='deconv4_3'))

model.add(UpSampling2D((2, 2)))
output_size = model.layers[-1].output_shape[2]
model.add(Deconvolution2D(256, 3, 3, (1, 256, output_size, output_size),
                          border_mode='same', name='deconv3_1'))
model.add(Deconvolution2D(256, 3, 3, (1, 256, output_size, output_size),
                          border_mode='same', name='deconv3_2'))
model.add(Deconvolution2D(128, 3, 3, (1, 128, output_size, output_size),
                          border_mode='same', name='deconv3_3'))

model.add(UpSampling2D((2, 2)))
output_size = model.layers[-1].output_shape[2]
model.add(Deconvolution2D(128, 3, 3, (1, 128, output_size, output_size),
                          border_mode='same', name='deconv2_1'))
model.add(Deconvolution2D(64, 3, 3, (1, 64, output_size, output_size),
                          border_mode='same', name='deconv2_2'))

model.add(UpSampling2D((2, 2)))
output_size = model.layers[-1].output_shape[2]
model.add(Deconvolution2D(64, 3, 3, (1, 64, output_size, output_size),
                          border_mode='same', name='deconv1_1'))
model.add(Deconvolution2D(64, 3, 3, (1, 64, output_size, output_size),
                          border_mode='same', name='deconv1_2'))
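The 2*input_size relationship can be checked against the transposed-convolution formula $o = s(i - 1) + a + k - 2p$, assuming that border_mode='same' with a 3x3 kernel and stride 1 corresponds to $p = 1$:

```latex
\begin{aligned}
\text{UpSampling2D}((2, 2)):&\quad i \;\mapsto\; 2i \\
\text{Deconvolution2D},\ k = 3,\ s = 1,\ p = 1,\ a = 0:&\quad o = 1 \cdot (2i - 1) + 0 + 3 - 2 \cdot 1 = 2i
\end{aligned}
```

So for a 512x512 input, the decoder takes the 8x8 output of deconv-fc6 through 16, 32, 64, 128 and 256 back to 512.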

Finally, I specified the output layer as

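# linear per-pixel class scores; apply a softmax over the class axis
# (e.g. in the loss) to turn them into probabilities
model.add(Convolution2D(num_classes, 1, 1, activation='linear', name='output'))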

Where num_classes is the number of classes in the specific problem (2 in my case). For 224x224 images, the total number of parameters matches the 252M value that is given in the paper. Now on to training!
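The output layer produces one score map per class, shaped (1, num_classes, rows, cols) in Theano ordering. A NumPy sketch, assuming those raw per-pixel scores, of turning them into class probabilities and a label map; the function names are hypothetical.

```python
import numpy as np

def pixelwise_softmax(scores, axis=1):
    """Softmax over the class axis, computed independently for every pixel."""
    shifted = scores - scores.max(axis=axis, keepdims=True)   # numerical stability
    exp = np.exp(shifted)
    return exp / exp.sum(axis=axis, keepdims=True)

def label_map(scores, axis=1):
    """Index of the most likely class at each pixel."""
    return np.argmax(scores, axis=axis)

scores = np.random.randn(1, 2, 4, 4)   # (batch, num_classes, rows, cols), 2 classes
probs = pixelwise_softmax(scores)      # same shape, sums to 1 over axis 1
labels = label_map(scores)             # shape (1, 4, 4)
```

Taking the argmax of the raw scores gives the same label map as taking it of the probabilities, since softmax preserves the ordering within each pixel.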

fchollet pushed a commit that referenced this issue Aug 24, 2016
* Fix exception message for Deconvolution2D

* Docs update for Deconvolution2D layer (#3540)

* Corrections to Deconvolution2D docs

* References formatted as Markdown links
* Blank lines added
@dolaameng
Contributor

dolaameng commented Aug 25, 2016

Hi @psalvaggio, thanks for sharing the code. Any progress on your experiment? I am especially interested in how using Convolution in place of Deconvolution (aka Transposed Convolution) compares.

It seems that the current Deconvolution implementation in Keras is the backprop operation of convolution w.r.t. its inputs. But in terms of learning capability, do you think a convolution layer will be enough to give similar results, e.g., for semantic segmentation?

However, I think there will be certain differences because of the way they share weights (see the animations). Do you have any visualizations of the Deconvolution layers? Thanks!
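The statement that Keras' Deconvolution is the gradient of convolution w.r.t. its inputs can be checked in 1D with NumPy: writing a strided "valid" convolution as a matrix C, the transposed convolution is multiplication by C.T, which is what backprop through the convolution computes. This is a sketch with hypothetical helper names; real layers work on 4D tensors with many channels.

```python
import numpy as np

def conv1d_valid(x, w, stride=1):
    """Direct 'valid' cross-correlation, as computed by a Convolution layer."""
    k = len(w)
    n_out = (len(x) - k) // stride + 1
    return np.array([np.dot(x[i * stride:i * stride + k], w) for i in range(n_out)])

def conv_matrix(n_in, w, stride=1):
    """Matrix C such that C @ x == conv1d_valid(x, w, stride)."""
    k = len(w)
    n_out = (n_in - k) // stride + 1
    C = np.zeros((n_out, n_in))
    for i in range(n_out):
        C[i, i * stride:i * stride + k] = w
    return C

def conv1d_transpose(y, w, stride=1):
    """Transposed convolution: scatter each input value, scaled by w, back onto
    the window it was computed from (equivalent to C.T @ y)."""
    k = len(w)
    out = np.zeros(stride * (len(y) - 1) + k)
    for i, v in enumerate(y):
        out[i * stride:i * stride + k] += v * w
    return out

x = np.arange(7.0)
w = np.array([1.0, 2.0, 3.0])
y = np.array([1.0, -1.0, 2.0])
C = conv_matrix(len(x), w, stride=2)
print(np.allclose(C @ x, conv1d_valid(x, w, stride=2)))     # forward pass
print(np.allclose(C.T @ y, conv1d_transpose(y, w, stride=2)))  # transposed pass
```

The transposed pass maps the 3-element `y` back to 7 elements, matching $o = s(i - 1) + k$ with $s = 2$, $i = 3$, $k = 3$. The same weights are used in both directions, only their arrangement differs.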

@HarisIqbal88

Hi @psalvaggio, how did you initialize your network? Can you share your full code?

@psalvaggio
Author

@dolaameng Not much yet. I'm still working on the training code, and according to that paper I'm looking at around a week of training time. This is actually my first deep learning project, so I don't have much intuition about the effects of the implementation details. I previously worked in optical modeling, and deconvolution was pretty much at the center of my research, but deconvolution here doesn't seem to have much in common with that. My intuition says that since the switch variables from that paper aren't present, this will lead to some distortion at object boundaries, but I don't have anything to back that up yet. I will definitely produce visualizations once I have trained the network, so I can start to understand what these "deconvolutional" layers are doing.

@HarisIqbal88 I initialized the first part of the network with the VGG16 weights, as described in this gist: https://gist.github.com/baraldilorenzo/07d7802847aaad0a35d3. I haven't gotten the training running yet, so I don't have a good answer for the rest of the network.

@lukovkin
Contributor

@dolaameng Not at all.

Some comments:

  1. Unfortunately, I cannot give any practical advice on it. I mostly tried to avoid downsampling and the subsequent upsampling/unpooling (except in the autoencoder case). I think there will definitely be a difference, but what exactly?
  2. As far as I understand, Deconvolution (or transposed convolution, which is a better term) could be emulated by a direct convolution by inserting zeros into the input, but it's not efficient.
    In general, I understand those 3 types in the following way (it would be great if someone would correct me or add something):
  • Convolution: we 'generalize' the input data, trying to extract common features at different resolution levels. It could also be considered a low-pass filter or some sort of smart averaging of the input.
  • Transposed convolution (deconvolution): kind of the reverse operation, when we want to somehow restore details from the 'generalized'/averaged data; effectively, it's the backward pass of a convolution used in place of the forward pass. It could be considered a high-pass filter or some sort of differencing operation.
  • 'Atrous' convolution (dilated convolution): same as direct convolution, but we skip input values at a rate that is defined as a parameter. If it's 1, we take every value, as in a direct convolution. If it's 2, we take every 2nd value; if 4, every 4th value, etc. I see it as just a 'cheap' way to get a large receptive field with a modest number of weights. I could be wrong, but it seems to me that dilated convolution is good in the fully convolutional case, but with a large number of small input samples and fully connected blocks it will not be as applicable.
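Points 2 and the dilated-convolution bullet can be illustrated in 1D with NumPy. The sketch below, with hypothetical helper names, assumes no padding in the transposed convolution: inserting (stride - 1) zeros between inputs, padding k - 1 zeros on both sides and running a direct convolution with the flipped kernel reproduces the scatter-add transposed convolution exactly (while spending most multiplications on zeros), and a dilated convolution simply reads every `rate`-th input under the kernel.

```python
import numpy as np

def conv1d_transpose(y, w, stride):
    """Reference transposed convolution via scatter-add (no padding)."""
    k = len(w)
    out = np.zeros(stride * (len(y) - 1) + k)
    for i, v in enumerate(y):
        out[i * stride:i * stride + k] += v * w
    return out

def conv1d_transpose_via_zero_insertion(y, w, stride):
    """Same result using only a direct convolution on a zero-stuffed input."""
    k = len(w)
    z = np.zeros(stride * (len(y) - 1) + 1)
    z[::stride] = y               # insert (stride - 1) zeros between inputs
    z = np.pad(z, k - 1)          # 'full' zero padding on both sides
    wf = w[::-1]                  # flipped kernel
    return np.array([np.dot(z[j:j + k], wf) for j in range(len(z) - k + 1)])

def conv1d_dilated(x, w, rate):
    """'Valid' dilated (atrous) convolution: kernel taps are `rate` apart."""
    k = len(w)
    span = rate * (k - 1) + 1     # receptive field of one output
    return np.array([np.dot(x[i:i + span:rate], w) for i in range(len(x) - span + 1)])

y = np.array([1.0, -2.0, 3.0])
w = np.array([1.0, 2.0, 3.0])
for s in (1, 2, 3):
    print(s, np.allclose(conv1d_transpose(y, w, s),
                         conv1d_transpose_via_zero_insertion(y, w, s)))

x = np.arange(10.0)
print(conv1d_dilated(x, np.ones(3), rate=1))  # 3 weights see 3 inputs
print(conv1d_dilated(x, np.ones(3), rate=2))  # same 3 weights see a span of 5
```

With rate 2 the three weights cover a span of five inputs, which is the 'cheap' receptive-field growth described above.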

@dolaameng
Contributor

Thank you @psalvaggio for your update - look forward to your results.
Thank you @lukovkin for the helpful comments.

I was trying to trace back how the terms "deconvolution" and "unpooling" were used. If I am not mistaken, they were originally used for "convolutional sparse encoding" in papers 1 and 2, as a way to learn filters and infer feature maps, quite similar to the idea of a convolutional autoencoder. Then people started to use them more as an inverse mapping from a feature map to the original image space (projection vs. reconstruction) in 3, 4, 5 and 6.

Since then, it seems to be generally agreed that deconvolution can be implemented as the 'transpose' of convolution, which is now the gradient w.r.t. the inputs in Theano, TensorFlow and thus Keras, even though Matthew Zeiler's original paper 3 suggested implementing it by flipping the convolutional filters horizontally and vertically, to inverse-map a single activation at a time. Based on this, using deconvolution as a transposed convolution seems definitely recommended for mapping middle-layer activations back to the original image space, e.g., for visualizing activations or mapping from class labels to pixels in semantic segmentation. If you only care about "learning a representation" of an image, as in a convolutional autoencoder, it's reasonable to just use convolution layers for deconvolution.

As for Unpooling, again Matthew's paper 2 (section 4.6) discussed that an implementation based on switch variables is preferable to "repeating" for reconstructing images. It seems that TensorFlow and Caffe already have implementations based on that, but I am not aware of any similar implementation in Keras yet.

In the end, I found that the Torch documentation gives clear and complete definitions of all these operators. I hope these discussions help people who were confused like me to find these resources more easily.

@lukovkin
Contributor

@dolaameng Jeez, that was a pretty extensive review, thank you very much! I'll take time to browse through the links and maybe come back with something later.

@HarisIqbal88

Hi @psalvaggio, were you successful in training the FCN?
I was reading the Caffe code for FCN. I could not understand why they used lr_mult = 0 in their deconvolutional layers. Also, how did they initialize this layer? Do you have any idea?

@psalvaggio
Author

@HarisIqbal88 I got pulled away from this project almost immediately after this thread. I will be getting back to it shortly, however. I'm not sure which caffe model you are referencing, but is it possible that they preloaded the deconv layer to perform bilinear upsampling and then had some regular conv layers after that that were learned?
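@psalvaggio's guess matches a common practice: initialize the upsampling deconvolution with a fixed bilinear-interpolation kernel and keep it frozen (in Caffe, lr_mult: 0 scales that parameter's learning rate to zero, so it is never updated). A NumPy sketch, assuming a 4-tap kernel with stride 2 for 2x upsampling and cropping one sample from each end; the names are hypothetical and this is not the FCN authors' code.

```python
import numpy as np

def bilinear_1d(size):
    """1D bilinear (triangle) interpolation weights of length `size`."""
    factor = (size + 1) // 2
    center = factor - 1 if size % 2 == 1 else factor - 0.5
    return 1 - np.abs(np.arange(size) - center) / factor

def bilinear_kernel(size):
    """2D bilinear kernel: outer product of the 1D weights."""
    w = bilinear_1d(size)
    return np.outer(w, w)

def conv1d_transpose(y, w, stride):
    """Scatter-add transposed convolution (no padding)."""
    k = len(w)
    out = np.zeros(stride * (len(y) - 1) + k)
    for i, v in enumerate(y):
        out[i * stride:i * stride + k] += v * w
    return out

w = bilinear_1d(4)                       # [0.25, 0.75, 0.75, 0.25]
ramp = np.array([0.0, 1.0, 2.0])
up = conv1d_transpose(ramp, w, 2)[1:-1]  # crop one sample per side -> 2x length
print(w)
print(up)                                # interior values step linearly by 0.5
```

Away from the borders, the stride-2 transposed convolution with this kernel performs linear interpolation (a constant input stays constant, a ramp stays a ramp), which is why it can serve as a fixed upsampling layer before, or instead of, learned layers.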

@HarisIqbal88

@psalvaggio I am now writing an FCN in Keras but ran into an interesting problem. I used the Deconvolution2D layer as you did. However, the input image size is not fixed in FCN (the same is the case for my dataset, and I cannot reshape the images to the same size). This carries the 'None' dimension from Input() to the Deconvolution2D layer, which does not accept it. Any idea how to implement FCN without fixing the input image size?

About the earlier discussion, I converged to the same conclusion.

@psalvaggio
Author

@HarisIqbal88 Right, the network I proposed here does not work for variable size input. One of the papers I was looking at (https://arxiv.org/abs/1606.02585) does make a "no-downsampling FCN" which can work for variable size input. In Keras, you do have to specify the size of the input image, but the number of parameters for that network is independent of the input size, so there would be no retraining, just some manipulation on the size of the input layer.
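The point that a fully convolutional network's parameter count does not depend on the input size follows from how convolution weights are counted: only the kernel size and the channel counts enter. A sketch that counts the 13 VGG16 convolution layers (3x3 kernels, with biases); the helper names are hypothetical.

```python
def conv_layer_params(k, c_in, c_out, bias=True):
    """k x k kernel from c_in to c_out channels; rows and cols never appear."""
    return k * k * c_in * c_out + (c_out if bias else 0)

# (input channels, output channels) of the VGG16 convolution layers
VGG16_CONV = [(3, 64), (64, 64),
              (64, 128), (128, 128),
              (128, 256), (256, 256), (256, 256),
              (256, 512), (512, 512), (512, 512),
              (512, 512), (512, 512), (512, 512)]

def vgg16_conv_params():
    return sum(conv_layer_params(3, c_in, c_out) for c_in, c_out in VGG16_CONV)

print(vgg16_conv_params())   # the same for 224x224, 512x512 or any other input
```

Unlike a Flatten + Dense head, whose weight count grows with the rows x cols of the feature map, nothing in this count changes when the image size changes, so the same trained weights can be loaded into a model built for a different input size.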

@ahundt
Contributor

ahundt commented Dec 18, 2016

FYI there is an implementation here: https://github.com/guojiyao/deep_fcn

@mzaradzki

mzaradzki commented Apr 4, 2017

Thanks @lukovkin for your comment; it was very useful to me in debugging my implementation.

In the end I found implementing the FCN upsampling part quite fun (if a bit frustrating along the way), so I posted a Medium article explaining it and showing the impact of various settings:

https://medium.com/@m.zaradzki/image-segmentation-with-neural-net-d5094d571b1e

Here is the corresponding repo for FCN32s, FCN16s and FCN8s (handling variable image size):
https://github.com/mzaradzki/neuralnets/tree/master/vgg_segmentation_keras

Hope this helps!

PS: Related to this topic, I just found this link to a DeepMask implementation in Keras:
http://www.gitxiv.com/posts/pmb5ESpGRyNcewLmk/learning-to-segment-object-candidates

@ahundt
Contributor

ahundt commented Apr 4, 2017

DenseNetFCN is now available in keras-contrib.

@ahundt
Contributor

ahundt commented Apr 19, 2017

These items & repositories are also relevant to FCN:

@stale stale bot added the stale label Jul 18, 2017
@stale

stale bot commented Jul 18, 2017

This issue has been automatically marked as stale because it has not had recent activity. It will be closed after 30 days if no further activity occurs, but feel free to re-open a closed issue if needed.

@milk-bottle-liyu

I have a question: why is the learning rate of the deconvolutional layer set to 0?

@stale stale bot removed the stale label Aug 6, 2017
@stale

stale bot commented Nov 4, 2017

This issue has been automatically marked as stale because it has not had recent activity. It will be closed after 30 days if no further activity occurs, but feel free to re-open a closed issue if needed.

@stale stale bot added the stale label Nov 4, 2017
@stale stale bot closed this as completed Dec 4, 2017