Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

added class_weights for cross entropy loss to segmentation.py #1221

Merged

Conversation

nsutezo
Copy link
Contributor

@nsutezo nsutezo commented Apr 5, 2023

Adding ability to pass in class weights into cross entropy loss.

@github-actions github-actions bot added the trainers PyTorch Lightning trainers label Apr 5, 2023
calebrob6
calebrob6 previously approved these changes Apr 5, 2023
Copy link
Collaborator

@adamjstewart adamjstewart left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we also add this to the classification and detection trainers as well?

torchgeo/trainers/segmentation.py Show resolved Hide resolved
@adamjstewart adamjstewart added this to the 0.5.0 milestone Apr 5, 2023
Copy link
Collaborator

@adamjstewart adamjstewart left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM. Only remaining question is if we want to do this for classification/detection trainers too or save that for a different PR?

@calebrob6
Copy link
Member

Different PR

torchgeo/trainers/segmentation.py Outdated Show resolved Hide resolved
@adamjstewart adamjstewart merged commit 8a2e9b4 into microsoft:main Apr 25, 2023
@adamjstewart
Copy link
Collaborator

Thanks for the PR!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
trainers PyTorch Lightning trainers
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants