TensorBoard images logging by datamodules plot() #1598
-
I would be happy to get an explanation of the datamodule plot() method, which is called and returns a figure that is logged to TensorBoard. It's a very useful feature for debugging and reviewing training. It seems that a figure is only logged "occasionally". We would like to know how and when this happens, and if possible to control the logging. For example, we can see that plot() is called once in the validation step:
We are not sure where or how this happens in training. Questions:
Thanks!
Replies: 5 comments 5 replies
-
Also, for the logging, how can we get the coordinates or location of a sample while it is being processed or plotted?
-
I would recommend extending GeoDataModule or NonGeoDataModule. That's what they're there for.
It actually isn't called during training; it's only called during validation. But when you run trainer.fit, it runs both the train and validation steps.
This would be a cool feature to add. Want to open a PR?
At the moment, this isn't possible, as the
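
To illustrate the point about trainer.fit above, here is a toy, pure-Python stand-in for a fit loop. It is only a sketch, not torchgeo or Lightning code: `toy_fit`, `plot`, and `max_plotted_batches` are hypothetical names, and the real cap on how many validation batches get plotted is not stated in this thread. It shows why figures can appear "occasionally" while training is running: every epoch of a fit includes a validation pass, and only that pass calls plot().

```python
from typing import Callable, List, Tuple


def toy_fit(
    num_epochs: int,
    train_batches: int,
    val_batches: int,
    plot: Callable[[int], str],
    max_plotted_batches: int = 1,  # hypothetical cap, not a real torchgeo parameter
) -> List[Tuple[int, str]]:
    """Toy stand-in for a fit loop; returns the (epoch, figure) pairs that would be logged."""
    logged: List[Tuple[int, str]] = []
    for epoch in range(num_epochs):
        # Training pass: plot() is never called here.
        for _batch_idx in range(train_batches):
            pass
        # Validation pass: plot() is called only for the first few batches.
        for batch_idx in range(val_batches):
            if batch_idx < max_plotted_batches:
                logged.append((epoch, plot(batch_idx)))
    return logged


figures = toy_fit(
    num_epochs=3,
    train_batches=100,
    val_batches=5,
    plot=lambda batch_idx: f"figure-for-val-batch-{batch_idx}",
)
print(figures)
```

With these numbers, 300 training steps and 15 validation steps produce only three logged figures, one per epoch.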
-
Sure. I'll open a PR and start logging the needs that come up while working on a task.
-
Regarding the datamodule.plot() call during validation: do you have to implement the plot function in custom DataModules? I didn't get this error in torchgeo before the update.
-
Of course, with the following code snippet I was able to reproduce it:

import os

import torch
from lightning.pytorch import Trainer
from torch.utils.data import DataLoader
from torchgeo.datasets import NonGeoDataset
from torchgeo.models import ViTSmall16_Weights
from torchgeo.trainers import RegressionTask

accelerator = "gpu" if torch.cuda.is_available() else "cpu"
default_root_dir = os.path.join("experiments")
num_workers = 2
max_epochs = 10
fast_dev_run = False
weights = ViTSmall16_Weights.LANDSAT_ETM_SR_MOCO


class CustomDataset(NonGeoDataset):
    """Dummy dataset: random 6-band images with a constant label."""

    def __init__(self):
        self.labels = [1] * 300

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, index: int):
        label = self.labels[index]
        # torch.Tensor(label) would allocate an uninitialized tensor of size
        # `label`, so build the label tensor from its value instead.
        return {
            "image": torch.rand(6, 224, 224),
            "label": torch.tensor([label], dtype=torch.float32),
        }


reg_task = RegressionTask(
    model="vit_small_patch16_224",
    weights=weights,
    in_channels=6,
    num_outputs=1,
    loss="mse",
    lr=0.001,
    patience=5,
)

trainer = Trainer(
    accelerator=accelerator,
    default_root_dir=default_root_dir,
    fast_dev_run=fast_dev_run,
    log_every_n_steps=1,
    min_epochs=1,
    max_epochs=max_epochs,
)

# The dataloaders are passed directly; no datamodule is used.
trainer.fit(
    model=reg_task,
    train_dataloaders=DataLoader(CustomDataset(), batch_size=16),
    val_dataloaders=DataLoader(CustomDataset(), batch_size=16),
)
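
One possible reading of this reproduction, stated here as an assumption rather than a confirmed diagnosis: because the dataloaders are passed to trainer.fit directly, no datamodule is attached to the trainer, so a validation step that calls datamodule.plot() unconditionally has nothing to call it on. A minimal sketch of a defensive guard for that situation follows. `maybe_plot` and `DummyDataModule` are hypothetical names used only for illustration and are not part of torchgeo.

```python
from typing import Any, Optional


def maybe_plot(datamodule: Optional[Any], sample: dict) -> Optional[Any]:
    """Return datamodule.plot(sample), or None if there is nothing to plot with."""
    # getattr with a default also covers the case where datamodule is None.
    plot = getattr(datamodule, "plot", None)
    if not callable(plot):
        return None
    return plot(sample)


class DummyDataModule:
    """Stand-in for a datamodule that implements plot()."""

    def plot(self, sample: dict) -> str:
        return f"figure with keys {sorted(sample)}"


sample = {"image": 0, "label": 1}
print(maybe_plot(None, sample))               # None
print(maybe_plot(DummyDataModule(), sample))  # figure with keys ['image', 'label']
```

With a guard like this, passing plain dataloaders would simply skip figure logging instead of raising.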