diff --git a/code2seq/utils/train.py b/code2seq/utils/train.py index d5af862..0213533 100644 --- a/code2seq/utils/train.py +++ b/code2seq/utils/train.py @@ -1,10 +1,10 @@ from os.path import basename import torch -from commode_utils.callback import UploadCheckpointCallback, PrintEpochResultCallback +from commode_utils.callback import PrintEpochResultCallback, ModelCheckpointWithUpload from omegaconf import DictConfig from pytorch_lightning import seed_everything, Trainer, LightningModule, LightningDataModule -from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping, LearningRateMonitor +from pytorch_lightning.callbacks import EarlyStopping, LearningRateMonitor from pytorch_lightning.loggers import WandbLogger @@ -18,7 +18,7 @@ def train(model: LightningModule, data_module: LightningDataModule, config: Dict wandb_logger = WandbLogger(project=f"{model_name} -- {dataset_name}", log_model=False, offline=config.log_offline) # define model checkpoint callback - checkpoint_callback = ModelCheckpoint( + checkpoint_callback = ModelCheckpointWithUpload( dirpath=wandb_logger.experiment.dir, filename="{epoch:02d}-val_loss={val/loss:.4f}", monitor="val/loss", @@ -26,7 +26,6 @@ def train(model: LightningModule, data_module: LightningDataModule, config: Dict save_top_k=-1, auto_insert_metric_name=False, ) - upload_checkpoint_callback = UploadCheckpointCallback(wandb_logger.experiment.dir) # define early stopping callback early_stopping_callback = EarlyStopping(patience=params.patience, monitor="val/loss", verbose=True, mode="min") # define callback for printing intermediate result @@ -48,7 +47,6 @@ def train(model: LightningModule, data_module: LightningDataModule, config: Dict lr_logger, early_stopping_callback, checkpoint_callback, - upload_checkpoint_callback, print_epoch_result_callback, ], resume_from_checkpoint=config.get("checkpoint", None), diff --git a/requirements.txt b/requirements.txt index cf8c480..11a5bee 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,10 +2,10 @@ black==21.7b0 mypy==0.910 torch==1.9.0 -pytorch-lightning==1.4.2 -torchmetrics==0.5.0 +pytorch-lightning==1.4.7 +torchmetrics==0.5.1 -tqdm==4.62.1 -wandb==0.12.0 +tqdm==4.62.2 +wandb==0.12.2 omegaconf==2.1.1 -commode-utils==0.3.8 +commode-utils==0.3.9 diff --git a/setup.py b/setup.py index b3b979b..1023b4e 100644 --- a/setup.py +++ b/setup.py @@ -1,6 +1,6 @@ from setuptools import setup, find_packages -VERSION = "1.0.1" +VERSION = "1.0.2" with open("README.md") as readme_file: readme = readme_file.read()