Skip to content

Commit

Permalink
Merge branch 'main' of github.com:jhauret/vibravox into permanent_dir…
Browse files Browse the repository at this point in the history
…_and_description
  • Loading branch information
jhauret committed May 30, 2024
2 parents 89c91ca + 5ba57dc commit 5b32f8e
Show file tree
Hide file tree
Showing 6 changed files with 192 additions and 8 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -41,5 +41,5 @@ python run.py lightning_datamodule=stp lightning_module=wav2vec2_for_stp lightni

- Test [ECAPA2](https://huggingface.co/Jenthe/ECAPA2) for Speaker Verification
```
python run.py lightning_datamodule=spkv lightning_module=ecapa2 ++trainer.limit_train_batches=0
python run.py lightning_datamodule=spkv lightning_module=ecapa2 logging=csv ++trainer.limit_train_batches=0 ++trainer.limit_val_batches=0
```
5 changes: 5 additions & 0 deletions configs/logging/csv.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
# Lightweight CSV logging configuration — used for evaluation-only runs
# where no training curves need to be visualized (e.g. the spkv test runs).
log_every_n_steps: 100

# Hydra-instantiated logger passed to the Lightning Trainer.
logger:
  _target_: lightning.pytorch.loggers.csv_logs.CSVLogger
  # Metrics are written under this directory, relative to the run's working dir.
  save_dir: "csv/"
145 changes: 145 additions & 0 deletions configs/slurm_array/spkv.txt

Large diffs are not rendered by default.

5 changes: 3 additions & 2 deletions run.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@
seed_everything,
Trainer,
)
from lightning.pytorch.loggers.tensorboard import TensorBoardLogger

from lightning.pytorch.loggers import Logger
from omegaconf import DictConfig

from torchmetrics import MetricCollection
Expand Down Expand Up @@ -51,7 +52,7 @@ def main(cfg: DictConfig):

# Instantiate Trainer
callbacks: List[Callback] = list(hydra.utils.instantiate(cfg.callbacks).values())
logger: TensorBoardLogger = hydra.utils.instantiate(cfg.logging.logger)
logger: Logger = hydra.utils.instantiate(cfg.logging.logger)
trainer: Trainer = hydra.utils.instantiate(
cfg.trainer, callbacks=callbacks, logger=logger, _convert_="partial"
)
Expand Down
36 changes: 36 additions & 0 deletions scripts/run_spkv_slurm_array_JZ.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
#!/bin/bash

# SLURM array job (Jean Zay): evaluate ECAPA2 speaker verification (spkv).
# Each array task (1-144) reads one row of configs/slurm_array/spkv.txt,
# keyed by SLURM_ARRAY_TASK_ID, and launches run.py with those parameters.
# Training is disabled via ++trainer.limit_train_batches=0 (test-only run).

#SBATCH --job-name=spkv_array_job
#SBATCH --output=slurm-%A_%a.out
#SBATCH --error=slurm-%A_%a.err
#SBATCH --constraint=v100-16g
#SBATCH --nodes=1
#SBATCH --ntasks=1
#SBATCH --gres=gpu:1
#SBATCH --cpus-per-task=4
#SBATCH --time=05:00:00
#SBATCH --qos=qos_gpu-t3
#SBATCH --hint=nomultithread
#SBATCH --account=lbo@v100
#SBATCH --array=1-144

module purge
conda deactivate

module load pytorch-gpu/py3/2.2.0

# Compute nodes are offline: point HF datasets at the pre-populated cache
# and forbid any download attempt.
export HF_DATASETS_CACHE=$WORK/huggingface_cache/datasets
export HF_DATASETS_OFFLINE=1

# Per-task parameter table.
# Row layout: $1=task_id $2=dataset_name $3=split $4=sensor_a $5=sensor_b $6=pairs
array_config=./configs/slurm_array/spkv.txt

# Extract all five values of this task's row in a single awk pass
# (the original ran one awk scan per column; one scan is enough).
read -r dataset_name split sensor_a sensor_b pairs <<< "$(awk -v ArrayTaskID="$SLURM_ARRAY_TASK_ID" '$1==ArrayTaskID {print $2, $3, $4, $5, $6}' "$array_config")"

set -x
srun python -u run.py lightning_datamodule=spkv lightning_module=ecapa2 lightning_datamodule.dataset_name="$dataset_name" lightning_datamodule.split="$split" lightning_datamodule.sensor_a="$sensor_a" lightning_datamodule.sensor_b="$sensor_b" lightning_datamodule.pairs="$pairs" ++trainer.limit_train_batches=0
7 changes: 2 additions & 5 deletions vibravox/lightning_modules/ecapa2.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,9 +128,6 @@ def on_test_start(self) -> None:
# Check DataModule parameters
self.check_datamodule_parameter()

# Log description
self.logger.experiment.add_text(tag="description", text_string=self.description)

def on_test_batch_end(
self,
outputs: Dict[str, torch.Tensor],
Expand Down Expand Up @@ -188,12 +185,12 @@ def on_test_epoch_end(self):
Called at the end of the test epoch.
- Triggers the computation of the metrics.
- Logs the metrics to tensorboard.
- Logs the metrics to the logger (preference: csv for extracting results as there is no training curves to log).
"""
# Get the metrics as a dict
metrics_to_log = self.metrics.compute()

# Log in tensorboard
# Log in the logger
self.log_dict(dictionary=metrics_to_log, sync_dist=True, prog_bar=True)

def check_datamodule_parameter(self) -> None:
Expand Down

0 comments on commit 5b32f8e

Please sign in to comment.