diff --git a/hourglass_tensorflow/losses/__init__.py b/hourglass_tensorflow/losses/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/hourglass_tensorflow/losses/sigmoid_cross_entropy.py b/hourglass_tensorflow/losses/sigmoid_cross_entropy.py new file mode 100644 index 0000000..cedbf07 --- /dev/null +++ b/hourglass_tensorflow/losses/sigmoid_cross_entropy.py @@ -0,0 +1,14 @@ +import tensorflow as tf +import keras.losses + + +class SigmoidCrossEntropyLoss(keras.losses.Loss): + def __init__(self, reduction=..., name=None, *args, **kwargs): + super().__init__(reduction, name) + + def call(self, y_true, y_pred): + return tf.nn.sigmoid_cross_entropy_with_logits( + logits=y_pred, + labels=y_true, + name="nn.sigmoid_cross_entropy_with_logits", + )