[SegGPT] Fix loss calculation #30421
Conversation
Force-pushed from 3b78bdb to 5b001b8
Thanks for working on this!
Have we verified the loss value with that of the original model?
Yeap, also added some tests to make sure this won't be an issue again 🙂. Regarding:

> This should be done in a separate PR.
Thanks for fixing!
What does this PR do?
This PR fixes #30419 and ensures that the loss is being correctly calculated. While working on this PR I noticed that `SegGptLoss` was not only broken, but also being incorrectly calculated (shame on the `SegGpt` contributor 😞). The proposed solution includes:

- Adding `labels` to `SegGptModel` so that the `forward` pass is performed correctly when training in the In-Context Painting style.
- Modifying the `SegGptLoss` `forward` method and its docstrings accordingly, so that the output is the same as the one obtained in the original implementation.
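The corrected loss can be illustrated as a masked smooth-L1 over the reconstructed canvas, where only the positions indicated by the mask (e.g. the half of the in-context painting that the model must predict) contribute. This is a minimal self-contained sketch, not the actual `SegGptLoss` code; the function names are hypothetical:

```python
import numpy as np

def smooth_l1(x, beta=1.0):
    # Elementwise smooth L1 (Huber-style): quadratic near zero, linear beyond beta.
    ax = np.abs(x)
    return np.where(ax < beta, 0.5 * ax**2 / beta, ax - 0.5 * beta)

def masked_smooth_l1_loss(pred, target, mask):
    # Average smooth-L1 only over positions where mask == 1,
    # so the prompt half of the canvas does not contribute to the loss.
    per_elem = smooth_l1(pred - target) * mask
    return per_elem.sum() / np.maximum(mask.sum(), 1)
```

With `pred = [0, 2]`, `target = [0, 0]` and `mask = [0, 1]`, only the second position counts, giving a loss of `smooth_l1(2) = 1.5` instead of averaging over the full canvas.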
Note: while running `test_modeling_seggpt` with `is_training = True`, I found that `gradient_checkpointing` is also not working, due to the `type_token_semantic` parameter that is not used in the forward pass. It is controlled by the `embedding_type` argument of the model's `forward`, and by default we use `type_token_instance`, just like the original implementation. Hence, we could probably move `embedding_type` to the config to allow `gradient_checkpointing`, or remove it entirely, since it is not clear from the original implementation what the use case for `type_token_semantic` is.
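The config-based suggestion above could look roughly like the following. This is an illustrative sketch with hypothetical names, not the actual transformers API: the idea is that only the selected type token is ever instantiated, so no parameter is left unused in the forward pass to trip up gradient checkpointing:

```python
import numpy as np

class TypeTokenEmbedding:
    """Sketch: select the type token once, from config, instead of per forward call."""

    def __init__(self, hidden_size, embedding_type="instance"):
        if embedding_type not in ("instance", "semantic"):
            raise ValueError(f"unknown embedding_type: {embedding_type}")
        self.embedding_type = embedding_type
        # Only the token actually used in forward() is created, so every
        # parameter participates in the computation graph.
        self.type_token = np.zeros((1, 1, hidden_size))

    def __call__(self, hidden_states):
        # Broadcast-add the single learned type token to every position.
        return hidden_states + self.type_token
```

A `TypeTokenEmbedding(hidden_size, config.embedding_type)` built at init time would then replace the branching on a `forward`-time `embedding_type` argument.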
cc @amyeroberts