Skip to content

Commit

Permalink
Fast and Faster R-CNN
Browse files Browse the repository at this point in the history
This commit is a port from the following [fork](https://github.com/rbgirshick/caffe-fast-rcnn/tree/0dcd397b29507b8314e252e850518c5695efbb83)

It adds :
 - smooth l1 loss layer
 - roi pooling layer
 - dropout scaling at test time (needed for MSRA-trained ZF network)

LICENSE :
Faster R-CNN

The MIT License (MIT)

Copyright (c) 2015 Microsoft Corporation

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.
  • Loading branch information
Austriker committed May 19, 2016
1 parent bb0c1a5 commit 234fcbe
Show file tree
Hide file tree
Showing 15 changed files with 879 additions and 12 deletions.
2 changes: 2 additions & 0 deletions include/caffe/layer.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -316,6 +316,8 @@ class Layer {
param_propagate_down_[param_id] = value;
}

inline Phase phase() { return phase_; }


protected:
/** The protobuf that stores the layer parameters */
Expand Down
1 change: 1 addition & 0 deletions include/caffe/layers/dropout_layer.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ class DropoutLayer : public NeuronLayer<Dtype> {
/// the scale for undropped inputs at train time @f$ 1 / (1 - p) @f$
Dtype scale_;
unsigned int uint_thres_;
bool scale_train_;
};

} // namespace caffe
Expand Down
58 changes: 58 additions & 0 deletions include/caffe/layers/roi_pooling_layer.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
#ifndef CAFFE_ROI_POOLING_LAYER_HPP_
#define CAFFE_ROI_POOLING_LAYER_HPP_

#include <vector>

#include "caffe/blob.hpp"
#include "caffe/common.hpp"
#include "caffe/layer.hpp"
#include "caffe/proto/caffe.pb.h"

namespace caffe {

/**
* @brief ROIPoolingLayer - Region of Interest Pooling Layer.
*
* Fast R-CNN
* Written by Ross Girshick
*/

template <typename Dtype>
class ROIPoolingLayer : public Layer<Dtype> {
public:
explicit ROIPoolingLayer(const LayerParameter& param)
: Layer<Dtype>(param) {}
virtual void LayerSetUp(const vector<Blob<Dtype>*>& bottom,
const vector<Blob<Dtype>*>& top);
virtual void Reshape(const vector<Blob<Dtype>*>& bottom,
const vector<Blob<Dtype>*>& top);

virtual inline const char* type() const { return "ROIPooling"; }

virtual inline int MinBottomBlobs() const { return 2; }
virtual inline int MaxBottomBlobs() const { return 2; }
virtual inline int MinTopBlobs() const { return 1; }
virtual inline int MaxTopBlobs() const { return 1; }

protected:
virtual void Forward_cpu(const vector<Blob<Dtype>*>& bottom,
const vector<Blob<Dtype>*>& top);
virtual void Forward_gpu(const vector<Blob<Dtype>*>& bottom,
const vector<Blob<Dtype>*>& top);
virtual void Backward_cpu(const vector<Blob<Dtype>*>& top,
const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom);
virtual void Backward_gpu(const vector<Blob<Dtype>*>& top,
const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom);

int channels_;
int height_;
int width_;
int pooled_height_;
int pooled_width_;
Dtype spatial_scale_;
Blob<int> max_idx_;
};

} // namespace caffe

#endif // CAFFE_ROI_POOLING_LAYER_HPP_
65 changes: 65 additions & 0 deletions include/caffe/layers/smooth_l1_loss_layer.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
#ifndef CAFFE_SMOOTH_L1_LOSS_LAYER_HPP_
#define CAFFE_SMOOTH_L1_LOSS_LAYER_HPP_

#include <vector>

#include "caffe/blob.hpp"
#include "caffe/common.hpp"
#include "caffe/layer.hpp"
#include "caffe/proto/caffe.pb.h"

#include "caffe/layers/loss_layer.hpp"

namespace caffe {

/**
* @brief SmoothL1LossLayer
*
* Fast R-CNN
* Written by Ross Girshick
*/
template <typename Dtype>
class SmoothL1LossLayer : public LossLayer<Dtype> {
public:
explicit SmoothL1LossLayer(const LayerParameter& param)
: LossLayer<Dtype>(param), diff_() {}
virtual void LayerSetUp(const vector<Blob<Dtype>*>& bottom,
const vector<Blob<Dtype>*>& top);
virtual void Reshape(const vector<Blob<Dtype>*>& bottom,
const vector<Blob<Dtype>*>& top);

virtual inline const char* type() const { return "SmoothL1Loss"; }

virtual inline int ExactNumBottomBlobs() const { return -1; }
virtual inline int MinBottomBlobs() const { return 2; }
virtual inline int MaxBottomBlobs() const { return 4; }

/**
* Unlike most loss layers, in the SmoothL1LossLayer we can backpropagate
* to both inputs -- override to return true and always allow force_backward.
*/
virtual inline bool AllowForceBackward(const int bottom_index) const {
return true;
}

protected:
virtual void Forward_cpu(const vector<Blob<Dtype>*>& bottom,
const vector<Blob<Dtype>*>& top);
virtual void Forward_gpu(const vector<Blob<Dtype>*>& bottom,
const vector<Blob<Dtype>*>& top);

virtual void Backward_cpu(const vector<Blob<Dtype>*>& top,
const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom);
virtual void Backward_gpu(const vector<Blob<Dtype>*>& top,
const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom);

Blob<Dtype> diff_;
Blob<Dtype> errors_;
Blob<Dtype> ones_;
bool has_weights_;
Dtype sigma2_;
};

} // namespace caffe

#endif // CAFFE_SMOOTH_L1_LOSS_LAYER_HPP_
2 changes: 1 addition & 1 deletion python/caffe/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from .pycaffe import Net, SGDSolver, NesterovSolver, AdaGradSolver, RMSPropSolver, AdaDeltaSolver, AdamSolver
from ._caffe import set_mode_cpu, set_mode_gpu, set_device, Layer, get_solver, layer_type_list
from ._caffe import set_mode_cpu, set_mode_gpu, set_device, Layer, get_solver, layer_type_list, set_random_seed
from ._caffe import __version__
from .proto.caffe_pb2 import TRAIN, TEST
from .classifier import Classifier
Expand Down
1 change: 1 addition & 0 deletions python/caffe/_caffe.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,7 @@ BOOST_PYTHON_MODULE(_caffe) {
bp::def("set_mode_cpu", &set_mode_cpu);
bp::def("set_mode_gpu", &set_mode_gpu);
bp::def("set_device", &Caffe::SetDevice);
bp::def("set_random_seed", &Caffe::set_random_seed);

bp::def("layer_type_list", &LayerRegistry<Dtype>::LayerTypeList);

Expand Down
27 changes: 23 additions & 4 deletions src/caffe/layers/dropout_layer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ void DropoutLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bottom,
DCHECK(threshold_ < 1.);
scale_ = 1. / (1. - threshold_);
uint_thres_ = static_cast<unsigned int>(UINT_MAX * threshold_);
scale_train_ = this->layer_param_.dropout_param().scale_train();
}

template <typename Dtype>
Expand All @@ -37,11 +38,20 @@ void DropoutLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
if (this->phase_ == TRAIN) {
// Create random numbers
caffe_rng_bernoulli(count, 1. - threshold_, mask);
for (int i = 0; i < count; ++i) {
top_data[i] = bottom_data[i] * mask[i] * scale_;
if (scale_train_) {
for (int i = 0; i < count; ++i) {
top_data[i] = bottom_data[i] * mask[i] * scale_;
}
} else {
for (int i = 0; i < count; ++i) {
top_data[i] = bottom_data[i] * mask[i];
}
}
} else {
caffe_copy(bottom[0]->count(), bottom_data, top_data);
if (!scale_train_) {
caffe_scal<Dtype>(count, 1. / scale_, top_data);
}
}
}

Expand All @@ -55,11 +65,20 @@ void DropoutLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
if (this->phase_ == TRAIN) {
const unsigned int* mask = rand_vec_.cpu_data();
const int count = bottom[0]->count();
for (int i = 0; i < count; ++i) {
bottom_diff[i] = top_diff[i] * mask[i] * scale_;
if (scale_train_) {
for (int i = 0; i < count; ++i) {
bottom_diff[i] = top_diff[i] * mask[i] * scale_;
}
} else {
for (int i = 0; i < count; ++i) {
bottom_diff[i] = top_diff[i] * mask[i];
}
}
} else {
caffe_copy(top[0]->count(), top_diff, bottom_diff);
if (!scale_train_) {
caffe_scal<Dtype>(top[0]->count(), 1. / scale_, bottom_diff);
}
}
}
}
Expand Down
35 changes: 28 additions & 7 deletions src/caffe/layers/dropout_layer.cu
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,23 @@ void DropoutLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
static_cast<unsigned int*>(rand_vec_.mutable_gpu_data());
caffe_gpu_rng_uniform(count, mask);
// set thresholds
// NOLINT_NEXT_LINE(whitespace/operators)
DropoutForward<Dtype><<<CAFFE_GET_BLOCKS(count), CAFFE_CUDA_NUM_THREADS>>>(
count, bottom_data, mask, uint_thres_, scale_, top_data);
if (scale_train_) {
// NOLINT_NEXT_LINE(whitespace/operators)
DropoutForward<Dtype><<<CAFFE_GET_BLOCKS(count),
CAFFE_CUDA_NUM_THREADS>>>(
count, bottom_data, mask, uint_thres_, scale_, top_data);
} else {
// NOLINT_NEXT_LINE(whitespace/operators)
DropoutForward<Dtype><<<CAFFE_GET_BLOCKS(count),
CAFFE_CUDA_NUM_THREADS>>>(
count, bottom_data, mask, uint_thres_, 1.f, top_data);
}
CUDA_POST_KERNEL_CHECK;
} else {
caffe_copy(count, bottom_data, top_data);
if (!scale_train_) {
caffe_gpu_scal<Dtype>(count, 1. / scale_, top_data);
}
}
}

Expand All @@ -54,13 +65,23 @@ void DropoutLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
const unsigned int* mask =
static_cast<const unsigned int*>(rand_vec_.gpu_data());
const int count = bottom[0]->count();
// NOLINT_NEXT_LINE(whitespace/operators)
DropoutBackward<Dtype><<<CAFFE_GET_BLOCKS(count),
CAFFE_CUDA_NUM_THREADS>>>(
count, top_diff, mask, uint_thres_, scale_, bottom_diff);
if (scale_train_) {
// NOLINT_NEXT_LINE(whitespace/operators)
DropoutBackward<Dtype><<<CAFFE_GET_BLOCKS(count),
CAFFE_CUDA_NUM_THREADS>>>(
count, top_diff, mask, uint_thres_, scale_, bottom_diff);
} else {
// NOLINT_NEXT_LINE(whitespace/operators)
DropoutBackward<Dtype><<<CAFFE_GET_BLOCKS(count),
CAFFE_CUDA_NUM_THREADS>>>(
count, top_diff, mask, uint_thres_, 1.f, bottom_diff);
}
CUDA_POST_KERNEL_CHECK;
} else {
caffe_copy(top[0]->count(), top_diff, bottom_diff);
if (!scale_train_) {
caffe_gpu_scal<Dtype>(top[0]->count(), 1. / scale_, bottom_diff);
}
}
}
}
Expand Down
Loading

0 comments on commit 234fcbe

Please sign in to comment.