Skip to content

Commit

Permalink
fixed black formatting
Browse files Browse the repository at this point in the history
  • Loading branch information
nsutezo committed Apr 5, 2023
1 parent b6069c9 commit 83c3038
Showing 1 changed file with 5 additions and 4 deletions.
9 changes: 5 additions & 4 deletions torchgeo/trainers/segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,8 @@ def __init__(self, **kwargs: Any) -> None:
the backbone
in_channels: Number of channels in input image
num_classes: Number of semantic classes to predict
class_weights: Optional rescaling weight given to each class and used with 'ce' loss
class_weights: Optional rescaling weight given to each
class and used with 'ce' loss
loss: Name of the loss function, currently supports
'ce', 'jaccard' or 'focal' loss
ignore_index: Optional integer class index to ignore in the loss and metrics
Expand Down Expand Up @@ -123,11 +124,11 @@ def __init__(self, **kwargs: Any) -> None:
UserWarning,
)
self.ignore_index = kwargs["ignore_index"]

self.class_weights = kwargs.get("class_weights", None)
if not isinstance(self.class_weights, (list, type(None))):
raise ValueError("class_weights must be a List or None")
raise ValueError("class_weights must be a List or None")

self.config_task()

self.train_metrics = MetricCollection(
Expand Down

0 comments on commit 83c3038

Please sign in to comment.