Skip to content

Commit

Permalink
add label smoothing parameter
Browse files Browse the repository at this point in the history
  • Loading branch information
Ely-S committed Jun 10, 2020
1 parent 1159c1b commit cc79eb4
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 18 deletions.
35 changes: 17 additions & 18 deletions efficientdet/det_model_fn.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,18 +204,6 @@ def focal_loss(logits, targets, alpha, gamma, normalizer):
return weighted_loss


def _classification_loss(cls_outputs,
cls_targets,
num_positives,
alpha=0.25,
gamma=2.0):
"""Computes classification loss."""
normalizer = num_positives
classification_loss = focal_loss(cls_outputs, cls_targets, alpha, gamma,
normalizer)
return classification_loss


def _box_loss(box_outputs, box_targets, num_positives, delta=0.1):
"""Computes box regression loss."""
# delta is typically around the mean value of regression target.
Expand Down Expand Up @@ -280,6 +268,15 @@ class and box losses from all levels.
# Onehot encoding for classification labels.
cls_targets_at_level = tf.one_hot(labels['cls_targets_%d' % level],
params['num_classes'])

if params['label_smoothing'] > 0:
# see https://arxiv.org/pdf/1512.00567.pdf p7 for a discussion of label_smoothing
assert 1 > params['label_smoothing'] > 0
smooth_positives = tf.cast(1.0 - params['label_smoothing'], tf.float32)
smooth_negatives = tf.cast(params['label_smoothing'] / params['num_classes'],
tf.float32)
cls_targets_at_level = cls_targets_at_level * smooth_positives + smooth_negatives

if params['data_format'] == 'channels_first':
bs, _, width, height, _ = cls_targets_at_level.get_shape().as_list()
cls_targets_at_level = tf.reshape(cls_targets_at_level,
Expand All @@ -289,12 +286,14 @@ class and box losses from all levels.
cls_targets_at_level = tf.reshape(cls_targets_at_level,
[bs, width, height, -1])
box_targets_at_level = labels['box_targets_%d' % level]
cls_loss = _classification_loss(
cls_outputs[level],
cls_targets_at_level,
num_positives_sum,
alpha=params['alpha'],
gamma=params['gamma'])

cls_loss = focal_loss(
cls_outputs[level],
cls_targets_at_level,
params['alpha'],
params['gamma'],
num_positives_sum)

if params['data_format'] == 'channels_first':
cls_loss = tf.reshape(cls_loss,
[bs, -1, width, height, params['num_classes']])
Expand Down
1 change: 1 addition & 0 deletions efficientdet/hparams_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,7 @@ def default_detection_configs():
# classification loss
h.alpha = 0.25
h.gamma = 1.5
h.label_smoothing = 0.0 # 0.1 is a good default
# localization loss
h.delta = 0.1
h.box_loss_weight = 50.0
Expand Down

0 comments on commit cc79eb4

Please sign in to comment.