Merge pull request #4986 from shelhamer/sigce-ignore

Sigmoid Cross-Entropy Loss: ignore selected targets by `ignore_label`

shelhamer authored Nov 17, 2016
2 parents 4a158a8 + 3d62e3c commit 28c135c
Showing 5 changed files with 164 additions and 17 deletions.
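A minimal usage sketch, assuming the standard layer type string; the blob names and the ignore value of -1 are illustrative only, but `ignore_label` and `normalization` are the loss_param fields this change reads:

layer {
  name: "loss"
  type: "SigmoidCrossEntropyLoss"
  bottom: "pred"
  bottom: "label"
  top: "loss"
  loss_param {
    ignore_label: -1
    normalization: VALID
  }
}

With normalization: VALID the loss is averaged over the non-ignored targets; omitting it keeps the historical BATCH_SIZE default noted in the caffe.proto change below.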
16 changes: 16 additions & 0 deletions include/caffe/layers/sigmoid_cross_entropy_loss_layer.hpp
@@ -97,6 +97,13 @@ class SigmoidCrossEntropyLossLayer : public LossLayer<Dtype> {
virtual void Backward_gpu(const vector<Blob<Dtype>*>& top,
const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom);

/// Read the normalization mode parameter and compute the normalizer based
/// on the blob size. If normalization_mode is VALID, the count of valid
/// outputs will be read from valid_count, unless it is -1 in which case
/// all outputs are assumed to be valid.
virtual Dtype get_normalizer(
LossParameter_NormalizationMode normalization_mode, int valid_count);

/// The internal SigmoidLayer used to map predictions to probabilities.
shared_ptr<SigmoidLayer<Dtype> > sigmoid_layer_;
/// sigmoid_output stores the output of the SigmoidLayer.
@@ -105,6 +112,15 @@ class SigmoidCrossEntropyLossLayer : public LossLayer<Dtype> {
vector<Blob<Dtype>*> sigmoid_bottom_vec_;
/// top vector holder to call the underlying SigmoidLayer::Forward
vector<Blob<Dtype>*> sigmoid_top_vec_;

/// Whether to ignore instances with a certain label.
bool has_ignore_label_;
/// The label indicating that an instance should be ignored.
int ignore_label_;
/// How to normalize the loss.
LossParameter_NormalizationMode normalization_;
Dtype normalizer_;
int outer_num_, inner_num_;
};

} // namespace caffe
77 changes: 70 additions & 7 deletions src/caffe/layers/sigmoid_cross_entropy_loss_layer.cpp
@@ -1,3 +1,4 @@
#include <algorithm>
#include <vector>

#include "caffe/layers/sigmoid_cross_entropy_loss_layer.hpp"
@@ -14,35 +15,89 @@ void SigmoidCrossEntropyLossLayer<Dtype>::LayerSetUp(
sigmoid_top_vec_.clear();
sigmoid_top_vec_.push_back(sigmoid_output_.get());
sigmoid_layer_->SetUp(sigmoid_bottom_vec_, sigmoid_top_vec_);

has_ignore_label_ =
this->layer_param_.loss_param().has_ignore_label();
if (has_ignore_label_) {
ignore_label_ = this->layer_param_.loss_param().ignore_label();
}
if (this->layer_param_.loss_param().has_normalization()) {
normalization_ = this->layer_param_.loss_param().normalization();
} else if (this->layer_param_.loss_param().has_normalize()) {
normalization_ = this->layer_param_.loss_param().normalize() ?
LossParameter_NormalizationMode_VALID :
LossParameter_NormalizationMode_BATCH_SIZE;
} else {
normalization_ = LossParameter_NormalizationMode_BATCH_SIZE;
}
}

template <typename Dtype>
void SigmoidCrossEntropyLossLayer<Dtype>::Reshape(
const vector<Blob<Dtype>*>& bottom, const vector<Blob<Dtype>*>& top) {
LossLayer<Dtype>::Reshape(bottom, top);
outer_num_ = bottom[0]->shape(0); // batch size
inner_num_ = bottom[0]->count(1); // instance size: |output| == |target|
CHECK_EQ(bottom[0]->count(), bottom[1]->count()) <<
"SIGMOID_CROSS_ENTROPY_LOSS layer inputs must have the same count.";
sigmoid_layer_->Reshape(sigmoid_bottom_vec_, sigmoid_top_vec_);
}

// TODO(shelhamer) loss normalization should be pulled up into LossLayer,
// instead of duplicated here and in SoftMaxWithLossLayer
template <typename Dtype>
Dtype SigmoidCrossEntropyLossLayer<Dtype>::get_normalizer(
LossParameter_NormalizationMode normalization_mode, int valid_count) {
Dtype normalizer;
switch (normalization_mode) {
case LossParameter_NormalizationMode_FULL:
normalizer = Dtype(outer_num_ * inner_num_);
break;
case LossParameter_NormalizationMode_VALID:
if (valid_count == -1) {
normalizer = Dtype(outer_num_ * inner_num_);
} else {
normalizer = Dtype(valid_count);
}
break;
case LossParameter_NormalizationMode_BATCH_SIZE:
normalizer = Dtype(outer_num_);
break;
case LossParameter_NormalizationMode_NONE:
normalizer = Dtype(1);
break;
default:
LOG(FATAL) << "Unknown normalization mode: "
<< LossParameter_NormalizationMode_Name(normalization_mode);
}
// Some users will have no labels for some examples in order to 'turn off' a
// particular loss in a multi-task setup. The max prevents NaNs in that case.
return std::max(Dtype(1.0), normalizer);
}
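To make the modes concrete, a worked example with made-up sizes: for outer_num_ = 8 and inner_num_ = 4, with 5 targets set to the ignore label, get_normalizer gives

  FULL        -> 8 * 4 = 32
  VALID       -> 32 - 5 = 27  (falls back to 32 when valid_count == -1)
  BATCH_SIZE  -> 8
  NONE        -> 1

so only VALID reflects how many targets actually contribute to the loss.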

template <typename Dtype>
void SigmoidCrossEntropyLossLayer<Dtype>::Forward_cpu(
const vector<Blob<Dtype>*>& bottom, const vector<Blob<Dtype>*>& top) {
// The forward pass computes the sigmoid outputs.
sigmoid_bottom_vec_[0] = bottom[0];
sigmoid_layer_->Forward(sigmoid_bottom_vec_, sigmoid_top_vec_);
// Compute the loss (negative log likelihood)
const int count = bottom[0]->count();
const int num = bottom[0]->num();
// Stable version of loss computation from input data
const Dtype* input_data = bottom[0]->cpu_data();
const Dtype* target = bottom[1]->cpu_data();
int valid_count = 0;
Dtype loss = 0;
for (int i = 0; i < count; ++i) {
for (int i = 0; i < bottom[0]->count(); ++i) {
const int target_value = static_cast<int>(target[i]);
if (has_ignore_label_ && target_value == ignore_label_) {
continue;
}
loss -= input_data[i] * (target[i] - (input_data[i] >= 0)) -
log(1 + exp(input_data[i] - 2 * input_data[i] * (input_data[i] >= 0)));
++valid_count;
}
top[0]->mutable_cpu_data()[0] = loss / num;
normalizer_ = get_normalizer(normalization_, valid_count);
top[0]->mutable_cpu_data()[0] = loss / normalizer_;
}

template <typename Dtype>
@@ -56,14 +111,22 @@ void SigmoidCrossEntropyLossLayer<Dtype>::Backward_cpu(
if (propagate_down[0]) {
// First, compute the diff
const int count = bottom[0]->count();
const int num = bottom[0]->num();
const Dtype* sigmoid_output_data = sigmoid_output_->cpu_data();
const Dtype* target = bottom[1]->cpu_data();
Dtype* bottom_diff = bottom[0]->mutable_cpu_diff();
caffe_sub(count, sigmoid_output_data, target, bottom_diff);
// Zero out gradient of ignored targets.
if (has_ignore_label_) {
for (int i = 0; i < count; ++i) {
const int target_value = static_cast<int>(target[i]);
if (target_value == ignore_label_) {
bottom_diff[i] = 0;
}
}
}
// Scale down gradient
const Dtype loss_weight = top[0]->cpu_diff()[0];
caffe_scal(count, loss_weight / num, bottom_diff);
Dtype loss_weight = top[0]->cpu_diff()[0] / normalizer_;
caffe_scal(count, loss_weight, bottom_diff);
}
}

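For reference, the per-target quantity accumulated by the CPU (and GPU) forward pass is the numerically stable form of sigmoid cross-entropy. With logit x = input_data[i] and binary target t = target[i] in {0, 1}, the expression in the loop above is algebraically equal to

  loss_i = max(x, 0) - x * t + log(1 + exp(-|x|))
         = -( t * log(sigmoid(x)) + (1 - t) * log(1 - sigmoid(x)) )

which never calls exp on a large positive argument and so avoids overflow; the summed loss over non-ignored targets is then divided by the normalizer returned by get_normalizer.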
56 changes: 47 additions & 9 deletions src/caffe/layers/sigmoid_cross_entropy_loss_layer.cu
@@ -5,15 +5,38 @@

namespace caffe {


template <typename Dtype>
__global__ void SigmoidCrossEntropyLossForwardGPU(const int nthreads,
const Dtype* input_data, const Dtype* target, Dtype* loss) {
const Dtype* input_data, const Dtype* target, Dtype* loss,
const bool has_ignore_label_, const int ignore_label_,
Dtype* counts) {
CUDA_KERNEL_LOOP(i, nthreads) {
loss[i] = input_data[i] * (target[i] - (input_data[i] >= 0)) -
log(1 + exp(input_data[i] - 2 * input_data[i] * (input_data[i] >= 0)));
const int target_value = static_cast<int>(target[i]);
if (has_ignore_label_ && target_value == ignore_label_) {
loss[i] = 0;
counts[i] = 0;
} else {
loss[i] = input_data[i] * (target[i] - (input_data[i] >= 0)) -
log(1 + exp(input_data[i] - 2 * input_data[i] *
(input_data[i] >= 0)));
counts[i] = 1;
}
}
}

template <typename Dtype>
__global__ void SigmoidCrossEntropyLossIgnoreDiffGPU(const int count,
const int ignore_label, const Dtype* target, Dtype* diff) {
CUDA_KERNEL_LOOP(i, count) {
const int target_value = static_cast<int>(target[i]);
if (target_value == ignore_label) {
diff[i] = 0;
}
}
}


template <typename Dtype>
void SigmoidCrossEntropyLossLayer<Dtype>::Forward_gpu(
const vector<Blob<Dtype>*>& bottom, const vector<Blob<Dtype>*>& top) {
@@ -22,20 +45,30 @@ void SigmoidCrossEntropyLossLayer<Dtype>::Forward_gpu(
sigmoid_layer_->Forward(sigmoid_bottom_vec_, sigmoid_top_vec_);
// Compute the loss (negative log likelihood)
const int count = bottom[0]->count();
const int num = bottom[0]->num();
// Stable version of loss computation from input data
const Dtype* input_data = bottom[0]->gpu_data();
const Dtype* target = bottom[1]->gpu_data();
// Since this memory is not used for anything until it is overwritten
// on the backward pass, we use it here to avoid having to allocate new GPU
// memory to accumulate intermediate results in the kernel.
Dtype* loss_data = bottom[0]->mutable_gpu_diff();
Dtype* count_data = bottom[1]->mutable_gpu_diff();
Dtype valid_count;
// NOLINT_NEXT_LINE(whitespace/operators)
SigmoidCrossEntropyLossForwardGPU<Dtype><<<CAFFE_GET_BLOCKS(count),
CAFFE_CUDA_NUM_THREADS>>>(count, input_data, target, loss_data);
CAFFE_CUDA_NUM_THREADS>>>(count, input_data, target, loss_data,
has_ignore_label_, ignore_label_, count_data);
// Only launch another CUDA kernel if we actually need the valid count.
if (normalization_ == LossParameter_NormalizationMode_VALID &&
has_ignore_label_) {
caffe_gpu_asum(count, count_data, &valid_count);
} else {
valid_count = count;
}
Dtype loss;
caffe_gpu_asum(count, loss_data, &loss);
top[0]->mutable_cpu_data()[0] = loss / num;
normalizer_ = get_normalizer(normalization_, valid_count);
top[0]->mutable_cpu_data()[0] = loss / normalizer_;
}

template <typename Dtype>
@@ -49,15 +82,20 @@ void SigmoidCrossEntropyLossLayer<Dtype>::Backward_gpu(
if (propagate_down[0]) {
// First, compute the diff
const int count = bottom[0]->count();
const int num = bottom[0]->num();
const Dtype* sigmoid_output_data = sigmoid_output_->gpu_data();
const Dtype* target = bottom[1]->gpu_data();
Dtype* bottom_diff = bottom[0]->mutable_gpu_diff();
caffe_copy(count, sigmoid_output_data, bottom_diff);
caffe_gpu_axpy(count, Dtype(-1), target, bottom_diff);
// Zero out gradient of ignored targets.
if (has_ignore_label_) {
// NOLINT_NEXT_LINE(whitespace/operators)
SigmoidCrossEntropyLossIgnoreDiffGPU<Dtype><<<CAFFE_GET_BLOCKS(count),
CAFFE_CUDA_NUM_THREADS>>>(count, ignore_label_, target, bottom_diff);
}
// Scale down gradient
const Dtype loss_weight = top[0]->cpu_diff()[0];
caffe_gpu_scal(count, loss_weight / num, bottom_diff);
Dtype loss_weight = top[0]->cpu_diff()[0] / normalizer_;
caffe_gpu_scal(count, loss_weight, bottom_diff);
}
}

4 changes: 3 additions & 1 deletion src/caffe/proto/caffe.proto
@@ -434,7 +434,7 @@ message LossParameter {
optional int32 ignore_label = 1;
// How to normalize the loss for loss layers that aggregate across batches,
// spatial dimensions, or other dimensions. Currently only implemented in
// SoftmaxWithLoss layer.
// SoftmaxWithLoss and SigmoidCrossEntropyLoss layers.
enum NormalizationMode {
// Divide by the number of examples in the batch times spatial dimensions.
// Outputs that receive the ignore label will NOT be ignored in computing
@@ -448,6 +448,8 @@
// Do not normalize the loss.
NONE = 3;
}
// For historical reasons, the default normalization for
// SigmoidCrossEntropyLoss is BATCH_SIZE and *not* VALID.
optional NormalizationMode normalization = 3 [default = VALID];
// Deprecated. Ignored if normalization is specified. If normalization
// is not specified, then setting this to false will be equivalent to
28 changes: 28 additions & 0 deletions src/caffe/test/test_sigmoid_cross_entropy_loss_layer.cpp
@@ -116,5 +116,33 @@ TYPED_TEST(SigmoidCrossEntropyLossLayerTest, TestGradient) {
this->blob_top_vec_, 0);
}

TYPED_TEST(SigmoidCrossEntropyLossLayerTest, TestIgnoreGradient) {
typedef typename TypeParam::Dtype Dtype;
FillerParameter data_filler_param;
data_filler_param.set_std(1);
GaussianFiller<Dtype> data_filler(data_filler_param);
data_filler.Fill(this->blob_bottom_data_);
LayerParameter layer_param;
LossParameter* loss_param = layer_param.mutable_loss_param();
loss_param->set_ignore_label(-1);
Dtype* target = this->blob_bottom_targets_->mutable_cpu_data();
const int count = this->blob_bottom_targets_->count();
// Ignore half of targets, then check that diff of this half is zero,
// while the other half is nonzero.
caffe_set(count / 2, Dtype(-1), target);
SigmoidCrossEntropyLossLayer<Dtype> layer(layer_param);
layer.SetUp(this->blob_bottom_vec_, this->blob_top_vec_);
layer.Forward(this->blob_bottom_vec_, this->blob_top_vec_);
vector<bool> propagate_down(2);
propagate_down[0] = true;
propagate_down[1] = false;
layer.Backward(this->blob_top_vec_, propagate_down, this->blob_bottom_vec_);
const Dtype* diff = this->blob_bottom_data_->cpu_diff();
for (int i = 0; i < count / 2; ++i) {
EXPECT_FLOAT_EQ(diff[i], 0.);
EXPECT_NE(diff[i + count / 2], 0.);
}
}


} // namespace caffe
