diff --git a/torchgeo/trainers/segmentation.py b/torchgeo/trainers/segmentation.py index 3396ff7b5f7..e6cd938e775 100644 --- a/torchgeo/trainers/segmentation.py +++ b/torchgeo/trainers/segmentation.py @@ -59,7 +59,13 @@ def config_task(self) -> None: if self.hyperparams["loss"] == "ce": ignore_value = -1000 if self.ignore_index is None else self.ignore_index - self.loss = nn.CrossEntropyLoss(ignore_index=ignore_value) + + class_weights = ( + torch.FloatTensor(self.class_weights) if self.class_weights else None + ) + self.loss = nn.CrossEntropyLoss( + ignore_index=ignore_value, weight=class_weights + ) elif self.hyperparams["loss"] == "jaccard": self.loss = smp.losses.JaccardLoss( mode="multiclass", classes=self.hyperparams["num_classes"] @@ -86,6 +92,8 @@ def __init__(self, **kwargs: Any) -> None: num_classes: Number of semantic classes to predict loss: Name of the loss function, currently supports 'ce', 'jaccard' or 'focal' loss + class_weights: Optional rescaling weight given to each + class and used with 'ce' loss ignore_index: Optional integer class index to ignore in the loss and metrics learning_rate: Learning rate for optimizer learning_rate_schedule_patience: Patience for learning rate scheduler @@ -100,6 +108,9 @@ def __init__(self, **kwargs: Any) -> None: The *segmentation_model* parameter was renamed to *model*, *encoder_name* renamed to *backbone*, and *encoder_weights* to *weights*. + + .. versionadded: 0.5 + The *class_weights* parameter. """ super().__init__() @@ -115,6 +126,8 @@ def __init__(self, **kwargs: Any) -> None: UserWarning, ) self.ignore_index = kwargs["ignore_index"] + self.class_weights = kwargs.get("class_weights", None) + self.config_task() self.train_metrics = MetricCollection(