diff --git a/README.md b/README.md index 79841c9..b518d7f 100644 --- a/README.md +++ b/README.md @@ -41,5 +41,5 @@ python run.py lightning_datamodule=stp lightning_module=wav2vec2_for_stp lightni - **Test** [ECAPA2](https://huggingface.co/Jenthe/ECAPA2) for Speaker Verification ``` -python run.py lightning_datamodule=spkv lightning_module=ecapa2 ++trainer.limit_train_batches=0 +python run.py lightning_datamodule=spkv lightning_module=ecapa2 logging=csv ++trainer.limit_train_batches=0 ++trainer.limit_val_batches=0 ``` diff --git a/configs/logging/csv.yaml b/configs/logging/csv.yaml new file mode 100644 index 0000000..0459aa9 --- /dev/null +++ b/configs/logging/csv.yaml @@ -0,0 +1,5 @@ +log_every_n_steps: 100 + +logger: + _target_: lightning.pytorch.loggers.csv_logs.CSVLogger + save_dir: "csv/" diff --git a/run.py b/run.py index 321084b..ba0e2a3 100644 --- a/run.py +++ b/run.py @@ -16,7 +16,8 @@ seed_everything, Trainer, ) -from lightning.pytorch.loggers.tensorboard import TensorBoardLogger + +from lightning.pytorch.loggers import Logger from omegaconf import DictConfig from torchmetrics import MetricCollection @@ -51,7 +52,7 @@ def main(cfg: DictConfig): # Instantiate Trainer callbacks: List[Callback] = list(hydra.utils.instantiate(cfg.callbacks).values()) - logger: TensorBoardLogger = hydra.utils.instantiate(cfg.logging.logger) + logger: Logger = hydra.utils.instantiate(cfg.logging.logger) trainer: Trainer = hydra.utils.instantiate( cfg.trainer, callbacks=callbacks, logger=logger, _convert_="partial" ) diff --git a/vibravox/lightning_modules/ecapa2.py b/vibravox/lightning_modules/ecapa2.py index 8a2e7d2..6154957 100644 --- a/vibravox/lightning_modules/ecapa2.py +++ b/vibravox/lightning_modules/ecapa2.py @@ -128,9 +128,6 @@ def on_test_start(self) -> None: # Check DataModule parameters self.check_datamodule_parameter() - # Log description - self.logger.experiment.add_text(tag="description", text_string=self.description) - def on_test_batch_end( self, outputs: Dict[str, torch.Tensor], @@ -188,12 +185,12 @@ def on_test_epoch_end(self): Called at the end of the test epoch. - Triggers the computation of the metrics. - - Logs the metrics to tensorboard. + - Logs the metrics to the logger (preference: csv for extracting results as there is no training curves to log). """ # Get the metrics as a dict metrics_to_log = self.metrics.compute() - # Log in tensorboard + # Log in the logger self.log_dict(dictionary=metrics_to_log, sync_dist=True, prog_bar=True) def check_datamodule_parameter(self) -> None: