
Make Binary cross entropy with logit numerically stable for high logit values #2562

Open

BeneSim wants to merge 2 commits into main
Conversation

BeneSim commented Oct 14, 2024

See #2561 for details. I added detailed commentary to the commits.

The documentation of binary_cross_entropy_with_logit says that it expects the target to be of type usize, which is wrong and yields a runtime error due to a dtype mismatch in the multiplication step.

In the current implementation of binary_cross_entropy_with_logit the loss will actually be NaN due to taking log(0), which occurs for high logits after passing through a sigmoid and an affine transformation:

inp.affine(-1., 1.)?.log()?
^      ^              ^
|      |              |
1.0    |              |
       0.0            |
                      NaN
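
For concreteness, a minimal repro sketch (mine, not from this PR), assuming candle's Tensor API and candle_nn::ops::sigmoid:

use candle_core::{Device, Result, Tensor};

fn main() -> Result<()> {
    // A large logit saturates sigmoid(inp) to exactly 1.0 in f32.
    let inp = Tensor::new(&[100f32], &Device::Cpu)?;
    let sig = candle_nn::ops::sigmoid(&inp)?;
    // 1 - sigmoid(inp) is then exactly 0, and log(0) is not finite,
    // which poisons the rest of the loss computation.
    let log_one_minus = sig.affine(-1., 1.)?.log()?;
    println!("{log_one_minus}");
    Ok(())
}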

The proposed implementation is taken more or less directly from PyTorch:
https://github.com/pytorch/pytorch/blob/41977a05314bbf537e1c5d6cf5916a368d1907d9/aten/src/ATen/native/Loss.cpp#L362
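
For reference, the linked PyTorch code computes, per element, (1 - t) * x + max_val + log(exp(-max_val) + exp(-x - max_val)) with max_val = max(-x, 0), which keeps both exponents non-positive. A candle-style sketch of that formulation (the function name and the mean reduction are my assumptions):

use candle_core::{Result, Tensor};

fn bce_with_logits(inp: &Tensor, target: &Tensor) -> Result<Tensor> {
    // max_val = max(-inp, 0), i.e. relu(-inp); shifting by it keeps both
    // exponents below zero so neither exp can overflow.
    let max_val = inp.neg()?.relu()?;
    // log(exp(-max_val) + exp(-inp - max_val)) is a stable softplus(-inp).
    let log_term = max_val
        .neg()?
        .exp()?
        .add(&inp.neg()?.sub(&max_val)?.exp()?)?
        .log()?;
    // per-element loss: (1 - target) * inp + max_val + log_term
    let loss = target.affine(-1., 1.)?.mul(inp)?.add(&max_val)?.add(&log_term)?;
    loss.mean_all()
}
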
BeneSim (Author) commented Oct 14, 2024

Just a quick update: I also encountered the case where the logit was so small that the sigmoid returned 0, so I guess we need a better way. The method TensorFlow uses might make sense; they basically calculate log_sigmoid via softplus:

See log_sigmoid and softplus. This could be implemented like:

let log_sigmoid_input = (inp.neg()?.exp()? + 1.)?.log()?.neg()?;

EDIT: This, however, will also quickly overflow for small inp ...
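
A common way around that overflow (my suggestion, not something settled in this thread) is the usual max-shift rewrite: log_sigmoid(x) = min(x, 0) - log(1 + exp(-|x|)), where the exponent is always <= 0. A candle-style sketch:

use candle_core::{Result, Tensor};

fn log_sigmoid(inp: &Tensor) -> Result<Tensor> {
    // min(inp, 0) = -relu(-inp)
    let min_x0 = inp.neg()?.relu()?.neg()?;
    // exp(-|inp|) <= 1, so the softplus term cannot overflow
    let softplus = (inp.abs()?.neg()?.exp()? + 1.)?.log()?;
    min_x0.sub(&softplus)
}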
