Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

added class_weights for cross entropy loss to segmentation.py #1221

Merged
14 changes: 14 additions & 0 deletions torchgeo/trainers/segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,13 @@ def config_task(self) -> None:
if self.hyperparams["loss"] == "ce":
ignore_value = -1000 if self.ignore_index is None else self.ignore_index
self.loss = nn.CrossEntropyLoss(ignore_index=ignore_value)
calebrob6 marked this conversation as resolved.
Show resolved Hide resolved

class_weights = (
torch.FloatTensor(self.class_weights) if self.class_weights else None
)
self.loss = nn.CrossEntropyLoss(
ignore_index=ignore_value, weight=class_weights
)
elif self.hyperparams["loss"] == "jaccard":
self.loss = smp.losses.JaccardLoss(
mode="multiclass", classes=self.hyperparams["num_classes"]
Expand All @@ -84,6 +91,8 @@ def __init__(self, **kwargs: Any) -> None:
the backbone
in_channels: Number of channels in input image
num_classes: Number of semantic classes to predict
class_weights: Optional rescaling weight given to each
adamjstewart marked this conversation as resolved.
Show resolved Hide resolved
class and used with 'ce' loss
loss: Name of the loss function, currently supports
'ce', 'jaccard' or 'focal' loss
ignore_index: Optional integer class index to ignore in the loss and metrics
adamjstewart marked this conversation as resolved.
Show resolved Hide resolved
Expand Down Expand Up @@ -115,6 +124,11 @@ def __init__(self, **kwargs: Any) -> None:
UserWarning,
)
self.ignore_index = kwargs["ignore_index"]

self.class_weights = kwargs.get("class_weights", None)
if not isinstance(self.class_weights, (list, type(None))):
raise ValueError("class_weights must be a List or None")
adamjstewart marked this conversation as resolved.
Show resolved Hide resolved

self.config_task()

self.train_metrics = MetricCollection(
Expand Down