How to add model metrics (accuracy, recall, precision, etc.) to a segmentation task #1013
I'm new to PyTorch, and TorchGeo is what got me interested. I have a segmentation task and a trainer set up as follows:

```python
import pytorch_lightning as pl
from torchgeo.trainers import SemanticSegmentationTask

# class_names, learning_rate, experiment_dir and gpu_id are defined elsewhere
task = SemanticSegmentationTask(
    segmentation_model="unet",
    encoder_name="resnet18",
    encoder_weights="imagenet",
    in_channels=4,
    num_classes=len(class_names),
    loss="jaccard",
    learning_rate=learning_rate,
    learning_rate_schedule_patience=10,
)
trainer = pl.Trainer(
    default_root_dir=experiment_dir,
    min_epochs=3,
    max_epochs=10,
    accelerator="gpu",
    devices=[gpu_id],
)
```

After calling `trainer.fit`, the process only prints the training loss to the screen after every batch step. How can I print the loss, accuracy, or F1 score of both the training set and the validation set after every epoch?
You'll want to read https://pytorch-lightning.readthedocs.io/en/stable/extensions/logging.html to understand how to combine TorchMetrics and PyTorch Lightning for logging within TorchGeo. There are actually a few questions here:

Currently, only loss is displayed on the command line. To add other metrics to the progress bar, you'll need to add `prog_bar=True` to all `self.log(...)` and `self.log_dict(...)` calls in `torchgeo/trainers/segmentation.py`. However, you'll notice that the progress bar is rather short, and you probably won't have enough room to see all metrics printed unless you have an ultrawide monitor. A better solution would be to use a logger like CSVLogger or T…

At the moment, …

The same metrics are recorded for train, val, and test, so you shouldn't need to change anything there.

The aforementioned …

Hopefully this wasn't information overload. Let me know if you have any questions that the PyTorch Lightning docs didn't answer. I'd be happy to accept PRs that make any of the above modifications for all tasks!