diff --git a/examples/poyo/configs/base.yaml b/examples/poyo/configs/base.yaml
index 163282b..8b634ef 100644
--- a/examples/poyo/configs/base.yaml
+++ b/examples/poyo/configs/base.yaml
@@ -6,6 +6,7 @@ log_dir: ./logs
 sequence_length: 1.0 # in seconds
 latent_step: 0.125 # in seconds
 readout_modality_name: cursor_velocity_2d
+readout_metric_name: r2
 
 epochs: 1000
 eval_epochs: 1 # frequency for doing validation
diff --git a/examples/poyo/train.py b/examples/poyo/train.py
index 03bfcc1..3380ec4 100644
--- a/examples/poyo/train.py
+++ b/examples/poyo/train.py
@@ -1,9 +1,10 @@
 import logging
-from typing import Callable, Dict
+from typing import Callable, Dict, List
 import copy
 
 import hydra
 import lightning as L
+import pandas as pd
 import torch
 import torch.nn as nn
 from torch.utils.data import DataLoader
@@ -20,7 +21,7 @@
 from torch_brain.models.poyo import POYOTokenizer, poyo_mp
 from torch_brain.utils import callbacks as tbrain_callbacks
 from torch_brain.utils import seed_everything
-from torch_brain.utils.stitcher import DecodingStitchEvaluator
+from torch_brain.utils.stitcher import Stitcher
 from torch_brain.data import Dataset, collate
 from torch_brain.nn import compute_loss_or_metric
 from torch_brain.data.sampler import (
@@ -39,12 +40,14 @@ def __init__(
         self,
         cfg: DictConfig,
         model: nn.Module,
         modality_spec: ModalitySpec,
+        session_ids: List[str],
     ):
         super().__init__()
         self.cfg = cfg
         self.model = model
         self.modality_spec = modality_spec
+        self.stitchers = {k: Stitcher() for k in session_ids}
         self.save_hyperparameters(OmegaConf.to_container(cfg))
 
     def configure_optimizers(self):
@@ -123,18 +126,56 @@ def validation_step(self, batch, batch_idx):
         # forward pass
         output_values = self.model(**batch)
 
-        # add removed elements back to batch
-        batch["target_values"] = target_values
-        batch["absolute_start"] = absolute_starts
-        batch["session_id"] = session_ids
-        batch["output_subtask_index"] = output_subtask_index
-        batch["output_mask"] = output_mask
+        for i in range(len(output_values)):
+            mask = output_mask[i]
+            self.stitchers[session_ids[i]].update(
+                timestamps=batch["output_timestamps"][i][mask] + absolute_starts[i],
+                preds=output_values[i][mask],
+                target=target_values[i][mask],
+            )
+
+    def on_validation_epoch_end(self, prefix="val"):
+        metrics = {}
+        for sess_id, stitcher in self.stitchers.items():
+            stitched_preds, stitched_target = stitcher.compute()
+            stitcher.reset()
+            metrics[sess_id] = compute_loss_or_metric(
+                loss_or_metric=self.cfg.readout_metric_name,
+                output_type=self.modality_spec.type,
+                output=stitched_preds,
+                target=stitched_target,
+            )
+
+        metrics[f"avg_{prefix}_metric"] = torch.tensor(list(metrics.values())).mean()
+
+        # logging
+        self.log_dict(metrics)
+        metrics_df = pd.DataFrame(
+            [{"metric": k, "value": v.item()} for k, v in metrics.items()]
+        )
+        if self.trainer.is_global_zero:
+            from rich import print as rprint
+
+            rprint(metrics_df)
 
-        return output_values
+        for logger in self.trainer.loggers:
+            if isinstance(logger, L.pytorch.loggers.TensorBoardLogger):
+                logger.experiment.add_text(
+                    f"{prefix}_metrics", metrics_df.to_markdown()
+                )
+            if isinstance(logger, L.pytorch.loggers.WandbLogger):
+                import wandb
+
+                logger.experiment.log(
+                    {f"{prefix}_metrics": wandb.Table(dataframe=metrics_df)}
+                )
 
     def test_step(self, batch, batch_idx):
         return self.validation_step(batch, batch_idx)
 
+    def on_test_epoch_end(self):
+        return self.on_validation_epoch_end(prefix="test")
+
 
 class DataModule(L.LightningDataModule):
     def __init__(self, cfg: DictConfig, tokenizer: Callable[[Data], Dict]):
@@ -311,15 +352,10 @@ def main(cfg: DictConfig):
         cfg=cfg,
         model=model,
         modality_spec=readout_spec,
-    )
-
-    stitch_evaluator = DecodingStitchEvaluator(
         session_ids=data_module.get_session_ids(),
-        modality_spec=readout_spec,
     )
 
     callbacks = [
-        stitch_evaluator,
         ModelSummary(max_depth=2),  # Displays the number of parameters in the model.
         ModelCheckpoint(
             save_last=True,
diff --git a/torch_brain/nn/loss.py b/torch_brain/nn/loss.py
index 8a13881..1908681 100644
--- a/torch_brain/nn/loss.py
+++ b/torch_brain/nn/loss.py
@@ -1,3 +1,4 @@
+from typing import Optional
 import torch
 import torch.nn.functional as F
 from torchmetrics import R2Score
@@ -10,7 +11,7 @@ def compute_loss_or_metric(
     output_type: DataType,
     output: torch.Tensor,
     target: torch.Tensor,
-    weights: torch.Tensor,
+    weights: Optional[torch.Tensor] = None,
 ) -> torch.Tensor:
     r"""Helper function to compute various losses or metrics for a given output type.
 
diff --git a/torch_brain/utils/stitcher.py b/torch_brain/utils/stitcher.py
index 2708bbb..263f9cd 100644
--- a/torch_brain/utils/stitcher.py
+++ b/torch_brain/utils/stitcher.py
@@ -7,8 +7,10 @@
 import pandas as pd
 from rich import print as rprint
 import torch
+from torch import Tensor
 import lightning as L
 import torchmetrics
+from torchmetrics.utilities import dim_zero_cat
 import wandb
 
 import torch_brain
@@ -387,3 +389,31 @@ def on_test_batch_end(self, *args, **kwargs):
 
     def on_test_epoch_end(self, *args, **kwargs):
         self.on_validation_epoch_end(*args, **kwargs, prefix="test")
+
+
+class Stitcher(torchmetrics.Metric):
+    r"""A simple prediction stitcher. Use this if your model output has associated
+    timestamps and your sampling strategy involves overlapping time windows,
+    requiring stitching to coalesce the predictions and targets before computing
+    the evaluation metric.
+    """
+
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+        self.add_state("timestamps", default=[], dist_reduce_fx="cat")
+        self.add_state("preds", default=[], dist_reduce_fx="cat")
+        self.add_state("target", default=[], dist_reduce_fx="cat")
+
+    def update(self, timestamps: Tensor, preds: Tensor, target: Tensor) -> None:
+        self.timestamps.append(timestamps)
+        self.preds.append(preds)
+        self.target.append(target)
+
+    def compute(self):
+        timestamps = dim_zero_cat(self.timestamps)
+        preds = dim_zero_cat(self.preds)
+        target = dim_zero_cat(self.target)
+
+        stitched_preds = stitch(timestamps, preds)
+        stitched_target = stitch(timestamps, target)
+        return stitched_preds, stitched_target
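
Usage note (not part of the diff): the sketch below shows how the new Stitcher can be driven on its own, mirroring the update/compute/reset cycle that validation_step and on_validation_epoch_end follow above. It assumes torch_brain's stitch(timestamps, values) coalesces values that share a timestamp across overlapping windows; the window starts, tensor shapes, and random data are illustrative only.

    import torch
    from torch_brain.utils.stitcher import Stitcher

    stitcher = Stitcher()

    # Two overlapping 1.0 s windows whose starts are 0.5 s apart. Timestamps
    # are shifted to absolute time before updating, as validation_step does
    # with `absolute_starts`.
    for start in (0.0, 0.5):
        timestamps = torch.arange(0.0, 1.0, 0.25) + start
        preds = torch.randn(4, 2)   # e.g. a cursor_velocity_2d readout
        target = torch.randn(4, 2)
        stitcher.update(timestamps=timestamps, preds=preds, target=target)

    # compute() concatenates all accumulated updates and stitches overlapping
    # samples; reset() clears state for the next epoch, as the epoch-end hook does.
    stitched_preds, stitched_target = stitcher.compute()
    stitcher.reset()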