Skip to content

Commit

Permalink
Csv logger for spkv (#29)
Browse files Browse the repository at this point in the history
  • Loading branch information
jhauret authored May 30, 2024
2 parents d8a56de + e7fb534 commit 5ba57dc
Show file tree
Hide file tree
Showing 4 changed files with 11 additions and 8 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -41,5 +41,5 @@ python run.py lightning_datamodule=stp lightning_module=wav2vec2_for_stp lightni

- **Test** [ECAPA2](https://huggingface.co/Jenthe/ECAPA2) for Speaker Verification
```
python run.py lightning_datamodule=spkv lightning_module=ecapa2 ++trainer.limit_train_batches=0
python run.py lightning_datamodule=spkv lightning_module=ecapa2 logging=csv ++trainer.limit_train_batches=0 ++trainer.limit_val_batches=0
```
5 changes: 5 additions & 0 deletions configs/logging/csv.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
log_every_n_steps: 100

logger:
_target_: lightning.pytorch.loggers.csv_logs.CSVLogger
save_dir: "csv/"
5 changes: 3 additions & 2 deletions run.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@
seed_everything,
Trainer,
)
from lightning.pytorch.loggers.tensorboard import TensorBoardLogger

from lightning.pytorch.loggers import Logger
from omegaconf import DictConfig

from torchmetrics import MetricCollection
Expand Down Expand Up @@ -51,7 +52,7 @@ def main(cfg: DictConfig):

# Instantiate Trainer
callbacks: List[Callback] = list(hydra.utils.instantiate(cfg.callbacks).values())
logger: TensorBoardLogger = hydra.utils.instantiate(cfg.logging.logger)
logger: Logger = hydra.utils.instantiate(cfg.logging.logger)
trainer: Trainer = hydra.utils.instantiate(
cfg.trainer, callbacks=callbacks, logger=logger, _convert_="partial"
)
Expand Down
7 changes: 2 additions & 5 deletions vibravox/lightning_modules/ecapa2.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,9 +128,6 @@ def on_test_start(self) -> None:
# Check DataModule parameters
self.check_datamodule_parameter()

# Log description
self.logger.experiment.add_text(tag="description", text_string=self.description)

def on_test_batch_end(
self,
outputs: Dict[str, torch.Tensor],
Expand Down Expand Up @@ -188,12 +185,12 @@ def on_test_epoch_end(self):
Called at the end of the test epoch.
- Triggers the computation of the metrics.
- Logs the metrics to tensorboard.
- Logs the metrics to the logger (preference: csv for extracting results as there is no training curves to log).
"""
# Get the metrics as a dict
metrics_to_log = self.metrics.compute()

# Log in tensorboard
# Log in the logger
self.log_dict(dictionary=metrics_to_log, sync_dist=True, prog_bar=True)

def check_datamodule_parameter(self) -> None:
Expand Down

0 comments on commit 5ba57dc

Please sign in to comment.