Document default trainer metrics (microsoft#1914)
* Expand metrics documentation

* typo

* Update documentation following robmarkcole's suggestion

* Update summary to be consistent with other trainers

* Move to unordered lists, fix indentation, use note

* Update configure metrics for other trainers

* typo

* Update torchgeo/trainers/classification.py

Co-authored-by: Adam J. Stewart <[email protected]>

* Add detail on wanted values, reword macro note.

* Remove redundant paragraph

* Add acronyms, clarify regression metrics

---------

Co-authored-by: Adam J. Stewart <[email protected]>
jdilger and adamjstewart authored Mar 2, 2024
1 parent fe33e6c commit 7241b0f
Showing 4 changed files with 74 additions and 6 deletions.
34 changes: 32 additions & 2 deletions torchgeo/trainers/classification.py
@@ -96,7 +96,23 @@ def configure_losses(self) -> None:
             raise ValueError(f"Loss type '{loss}' is not valid.")
 
     def configure_metrics(self) -> None:
-        """Initialize the performance metrics."""
+        """Initialize the performance metrics.
+
+        * Multiclass Overall Accuracy (OA): Ratio of correctly classified instances.
+          Uses 'micro' averaging. Higher values are better.
+        * Multiclass Average Accuracy (AA): Ratio of correctly classified classes.
+          Uses 'macro' averaging. Higher values are better.
+        * Multiclass Jaccard Index (IoU): Per-class overlap between predicted and
+          actual classes. Uses 'macro' averaging. Higher values are better.
+        * Multiclass F1 Score: The harmonic mean of precision and recall.
+          Uses 'micro' averaging. Higher values are better.
+
+        .. note::
+           * 'Micro' averaging suits overall performance evaluation but may not
+             reflect minority class accuracy.
+           * 'Macro' averaging gives equal weight to each class, and is useful for
+             balanced performance assessment across imbalanced classes.
+        """
         metrics = MetricCollection(
             {
                 "OverallAccuracy": MulticlassAccuracy(
@@ -252,7 +268,21 @@ class MultiLabelClassificationTask(ClassificationTask):
     """Multi-label image classification."""
 
     def configure_metrics(self) -> None:
-        """Initialize the performance metrics."""
+        """Initialize the performance metrics.
+
+        * Multilabel Overall Accuracy (OA): Ratio of correctly classified labels.
+          Uses 'micro' averaging. Higher values are better.
+        * Multilabel Average Accuracy (AA): Ratio of correctly classified classes.
+          Uses 'macro' averaging. Higher values are better.
+        * Multilabel F1 Score: The harmonic mean of precision and recall.
+          Uses 'micro' averaging. Higher values are better.
+
+        .. note::
+           * 'Micro' averaging suits overall performance evaluation but may not
+             reflect minority class accuracy.
+           * 'Macro' averaging gives equal weight to each class, and is useful for
+             balanced performance assessment across imbalanced classes.
+        """
         metrics = MetricCollection(
             {
                 "OverallAccuracy": MultilabelAccuracy(
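A similar sketch for the multilabel variant, again with a hypothetical `num_labels` and random inputs; predictions are per-label probabilities and targets are multi-hot vectors:

```python
import torch
from torchmetrics import MetricCollection
from torchmetrics.classification import MultilabelAccuracy, MultilabelFBetaScore

num_labels = 5  # hypothetical; TorchGeo reads this from hparams

metrics = MetricCollection(
    {
        "OverallAccuracy": MultilabelAccuracy(num_labels=num_labels, average="micro"),
        "AverageAccuracy": MultilabelAccuracy(num_labels=num_labels, average="macro"),
        "F1Score": MultilabelFBetaScore(num_labels=num_labels, beta=1.0, average="micro"),
    }
)

preds = torch.rand(8, num_labels)              # per-label probabilities
target = torch.randint(0, 2, (8, num_labels))  # multi-hot ground truth
print(metrics(preds, target))
```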
14 changes: 13 additions & 1 deletion torchgeo/trainers/detection.py
@@ -204,7 +204,19 @@ def configure_models(self) -> None:
             raise ValueError(f"Model type '{model}' is not valid.")
 
     def configure_metrics(self) -> None:
-        """Initialize the performance metrics."""
+        """Initialize the performance metrics.
+
+        * Mean Average Precision (mAP) and Mean Average Recall (mAR): Precision and
+          recall are averaged over classes and over intersection over union (IoU)
+          thresholds between the predicted and ground-truth bounding boxes. Uses
+          'macro' averaging. Higher values are better.
+
+        .. note::
+           * 'Micro' averaging suits overall performance evaluation but may not
+             reflect minority class accuracy.
+           * 'Macro' averaging gives equal weight to each class, and is useful for
+             balanced performance assessment across imbalanced classes.
+        """
         metrics = MetricCollection([MeanAveragePrecision()])
         self.val_metrics = metrics.clone(prefix="val_")
         self.test_metrics = metrics.clone(prefix="test_")
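MeanAveragePrecision consumes lists of per-image dictionaries rather than plain tensors. A minimal sketch with one hypothetical image, one predicted box, and one ground-truth box (the metric requires pycocotools to be installed):

```python
import torch
from torchmetrics.detection.mean_ap import MeanAveragePrecision

metric = MeanAveragePrecision()

preds = [
    {
        "boxes": torch.tensor([[10.0, 10.0, 50.0, 50.0]]),  # (x1, y1, x2, y2)
        "scores": torch.tensor([0.9]),
        "labels": torch.tensor([0]),
    }
]
target = [
    {
        "boxes": torch.tensor([[12.0, 12.0, 48.0, 48.0]]),
        "labels": torch.tensor([0]),
    }
]

metric.update(preds, target)
results = metric.compute()  # includes 'map', 'map_50', 'mar_100', ...
print(results["map"])
```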
16 changes: 15 additions & 1 deletion torchgeo/trainers/regression.py
@@ -97,7 +97,21 @@ def configure_losses(self) -> None:
             )
 
     def configure_metrics(self) -> None:
-        """Initialize the performance metrics."""
+        """Initialize the performance metrics.
+
+        * Root Mean Squared Error (RMSE): The square root of the average of the
+          squared differences between the predicted and actual values. Lower
+          values are better.
+        * Mean Squared Error (MSE): The average of the squared differences between
+          the predicted and actual values. Lower values are better.
+        * Mean Absolute Error (MAE): The average of the absolute differences
+          between the predicted and actual values. Lower values are better.
+
+        .. note::
+           * 'Micro' averaging suits overall performance evaluation but may not
+             reflect minority class accuracy.
+           * 'Macro' averaging gives equal weight to each class, and is useful for
+             balanced performance assessment across imbalanced classes.
+        """
         metrics = MetricCollection(
             {
                 "RMSE": MeanSquaredError(squared=False),
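Because these three metrics are simple averages, their relationship can be checked by hand. A short sketch with hypothetical values, using the same `squared=False` trick the diff uses to get RMSE from MeanSquaredError:

```python
import torch
from torchmetrics import MeanAbsoluteError, MeanSquaredError, MetricCollection

metrics = MetricCollection(
    {
        "RMSE": MeanSquaredError(squared=False),  # square root of MSE
        "MSE": MeanSquaredError(),
        "MAE": MeanAbsoluteError(),
    }
)

preds = torch.tensor([2.5, 0.0, 2.0, 8.0])
target = torch.tensor([3.0, -0.5, 2.0, 7.0])

# errors: -0.5, 0.5, 0.0, 1.0
# MSE  = (0.25 + 0.25 + 0.0 + 1.0) / 4 = 0.375
# RMSE = sqrt(0.375) ≈ 0.612
# MAE  = (0.5 + 0.5 + 0.0 + 1.0) / 4 = 0.5
print(metrics(preds, target))
```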
16 changes: 14 additions & 2 deletions torchgeo/trainers/segmentation.py
@@ -40,7 +40,7 @@ def __init__(
         freeze_backbone: bool = False,
         freeze_decoder: bool = False,
     ) -> None:
-        """Inititalize a new SemanticSegmentationTask instance.
+        """Initialize a new SemanticSegmentationTask instance.
 
         Args:
             model: Name of the
@@ -122,7 +122,19 @@ def configure_losses(self) -> None:
             )
 
     def configure_metrics(self) -> None:
-        """Initialize the performance metrics."""
+        """Initialize the performance metrics.
+
+        * Multiclass Pixel Accuracy: Ratio of correctly classified pixels.
+          Uses 'micro' averaging. Higher values are better.
+        * Multiclass Jaccard Index (IoU): Per-pixel overlap between predicted and
+          actual segments. Uses 'macro' averaging. Higher values are better.
+
+        .. note::
+           * 'Micro' averaging suits overall performance evaluation but may not
+             reflect minority class accuracy.
+           * 'Macro' averaging gives equal weight to each class, and is useful for
+             balanced performance assessment across imbalanced classes.
+        """
         num_classes: int = self.hparams["num_classes"]
         ignore_index: Optional[int] = self.hparams["ignore_index"]
         metrics = MetricCollection(
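For segmentation, the same metrics operate on dense per-pixel tensors. A minimal sketch with hypothetical `num_classes` and `ignore_index` values and random masks; `multidim_average="global"` pools every pixel in the batch before micro-averaging:

```python
import torch
from torchmetrics import MetricCollection
from torchmetrics.classification import MulticlassAccuracy, MulticlassJaccardIndex

num_classes = 3      # hypothetical; TorchGeo reads these from hparams
ignore_index = None

metrics = MetricCollection(
    {
        "OverallAccuracy": MulticlassAccuracy(
            num_classes=num_classes,
            ignore_index=ignore_index,
            multidim_average="global",
            average="micro",
        ),
        "JaccardIndex": MulticlassJaccardIndex(
            num_classes=num_classes,
            ignore_index=ignore_index,
            average="macro",
        ),
    }
)

preds = torch.randint(0, num_classes, (2, 16, 16))   # predicted masks (N, H, W)
target = torch.randint(0, num_classes, (2, 16, 16))  # ground-truth masks
print(metrics(preds, target))
```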
