Skip to content

Commit

Permalink
Add label_from_zero_one argument to LogisticLoss (apache#9265)
Browse files Browse the repository at this point in the history
* add use_zero_one argument to logisticloss

* add comment

* revise name

* update

* update
  • Loading branch information
sxjscience authored and zheng-da committed Jun 28, 2018
1 parent ed4ec89 commit da04371
Showing 1 changed file with 13 additions and 4 deletions.
17 changes: 13 additions & 4 deletions python/mxnet/gluon/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -619,16 +619,19 @@ class LogisticLoss(Loss):
L = \sum_i \log(1 + \exp(- {pred}_i \cdot {label}_i))
where `pred` is the classifier prediction and `label` is the target tensor
containing values -1 or 1. `pred` and `label` can have arbitrary shape as
long as they have the same number of elements.
containing values -1 or 1 (0 or 1 if `label_format` is 'binary').
`pred` and `label` can have arbitrary shape as long as they have the same number of elements.
Parameters
----------
weight : float or None
Global scalar weight for loss.
batch_axis : int, default 0
The axis that represents mini-batch.
label_format : str, default 'signed'
Can be either 'signed' or 'binary'. If the label_format is 'signed', all label values should
be either -1 or 1. If the label_format is 'binary', all label values should be either
0 or 1.
Inputs:
- **pred**: prediction tensor with arbitrary shape.
Expand All @@ -643,11 +646,17 @@ class LogisticLoss(Loss):
- **loss**: loss tensor with shape (batch_size,). Dimensions other than
batch_axis are averaged out.
"""
def __init__(self, weight=None, batch_axis=0, **kwargs):
def __init__(self, weight=None, batch_axis=0, label_format='signed', **kwargs):
    """Initialize the logistic loss.

    Parameters
    ----------
    weight : float or None
        Global scalar weight for loss.
    batch_axis : int, default 0
        The axis that represents mini-batch.
    label_format : str, default 'signed'
        Either 'signed' (labels in {-1, 1}) or 'binary' (labels in {0, 1}).

    Raises
    ------
    ValueError
        If `label_format` is neither 'signed' nor 'binary'.
    """
    super(LogisticLoss, self).__init__(weight, batch_axis, **kwargs)
    self._label_format = label_format
    # Validate eagerly so a bad format fails at construction time,
    # not on the first forward pass.
    if self._label_format not in ["signed", "binary"]:
        # Fixed typo in the original message: "recieved" -> "received".
        raise ValueError("label_format can only be signed or binary, received %s."
                         % label_format)

def hybrid_forward(self, F, pred, label, sample_weight=None):
    """Compute the logistic loss.

    Reshapes `label` to match `pred`, maps binary labels to the signed
    convention when needed, applies the loss formula
    log(1 + exp(-pred * label)), weights it, and averages over all axes
    except the batch axis.
    """
    # Match label layout to pred before elementwise computation.
    target = _reshape_like(F, label, pred)
    if self._label_format == 'binary':
        # Map {0, 1} labels onto the {-1, 1} convention the formula expects.
        target = target * 2 - 1
    elementwise = F.log(1.0 + F.exp(-pred * target))
    elementwise = _apply_weighting(F, elementwise, self._weight, sample_weight)
    # Mean over every axis except the batch axis.
    return F.mean(elementwise, axis=self._batch_axis, exclude=True)
Expand Down

0 comments on commit da04371

Please sign in to comment.