From 31d26b664e87edff45874b594951f47de6ead318 Mon Sep 17 00:00:00 2001 From: Louis Dupont Date: Thu, 8 Dec 2022 17:09:40 +0200 Subject: [PATCH] Fix evaluate_from_recipe: honor an explicitly provided checkpoint_params.checkpoint_path, falling back to checkpoints_dir/ckpt_name only when it is None --- .../training/sg_trainer/sg_trainer.py | 13 +++++++++---- .../training/utils/checkpoint_utils.py | 2 +- 2 files changed, 10 insertions(+), 5 deletions(-) diff --git a/src/super_gradients/training/sg_trainer/sg_trainer.py b/src/super_gradients/training/sg_trainer/sg_trainer.py index de9f297a40..a3c34689af 100755 --- a/src/super_gradients/training/sg_trainer/sg_trainer.py +++ b/src/super_gradients/training/sg_trainer/sg_trainer.py @@ -270,9 +270,14 @@ def evaluate_from_recipe(cls, cfg: DictConfig) -> None: name=cfg.val_dataloader, dataset_params=cfg.dataset_params.val_dataset_params, dataloader_params=cfg.dataset_params.val_dataloader_params ) - checkpoints_dir = Path(get_checkpoints_dir_path(experiment_name=cfg.experiment_name, ckpt_root_dir=cfg.ckpt_root_dir)) - checkpoint_path = str(checkpoints_dir / cfg.training_hyperparams.ckpt_name) - logger.info(f"Evaluating checkpoint: {checkpoint_path}") + if cfg.checkpoint_params.checkpoint_path is None: + logger.info( + "checkpoint_params.checkpoint_path was not provided, " "so the recipe will be evaluated using checkpoints_dir/training_hyperparams.ckpt_name" + ) + checkpoints_dir = Path(get_checkpoints_dir_path(experiment_name=cfg.experiment_name, ckpt_root_dir=cfg.ckpt_root_dir)) + cfg.checkpoint_params.checkpoint_path = str(checkpoints_dir / cfg.training_hyperparams.ckpt_name) + + logger.info(f"Evaluating checkpoint: {cfg.checkpoint_params.checkpoint_path}") # BUILD NETWORK model = models.get( @@ -280,7 +285,7 @@ def evaluate_from_recipe(cls, cfg: DictConfig) -> None: num_classes=cfg.arch_params.num_classes, arch_params=cfg.arch_params, pretrained_weights=cfg.checkpoint_params.pretrained_weights, - checkpoint_path=checkpoint_path, + checkpoint_path=cfg.checkpoint_params.checkpoint_path, load_backbone=cfg.checkpoint_params.load_backbone, ) diff --git 
a/src/super_gradients/training/utils/checkpoint_utils.py b/src/super_gradients/training/utils/checkpoint_utils.py index 19c59b6ff5..cc468a8bf8 100644 --- a/src/super_gradients/training/utils/checkpoint_utils.py +++ b/src/super_gradients/training/utils/checkpoint_utils.py @@ -138,7 +138,7 @@ def copy_ckpt_to_local_folder( def read_ckpt_state_dict(ckpt_path: str, device="cpu"): if not os.path.exists(ckpt_path): - raise ValueError("Incorrect Checkpoint path") + raise FileNotFoundError(f"Incorrect Checkpoint path: {ckpt_path} (This should be an absolute path)") if device == "cuda": state_dict = torch.load(ckpt_path)