Skip to content

Commit

Permalink
fix regression tests (#1408)
Browse files Browse the repository at this point in the history
* fix regression tests

* fix

* fix prepare_conversion_cfgs

* fix

---------

Co-authored-by: Louis Dupont <[email protected]>
  • Loading branch information
ofrimasad and Louis-Dupont authored Aug 30, 2023
1 parent 3e1019f commit dc52a0d
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 12 deletions.
16 changes: 7 additions & 9 deletions src/super_gradients/training/models/conversion.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import os
import pathlib
from pathlib import Path

import hydra
import numpy as np
Expand All @@ -17,6 +16,7 @@
from super_gradients.common.factories.transforms_factory import TransformsFactory
from super_gradients.training import models
from super_gradients.training.utils.sg_trainer_utils import parse_args
from super_gradients.training.utils.utils import get_param

logger = get_logger(__name__)

Expand Down Expand Up @@ -228,22 +228,20 @@ def prepare_conversion_cfgs(cfg: DictConfig):
# CREATE THE EXPERIMENT CFG

# Load the latest experiment config
# TODO: check if we can load the cfg from run
experiment_cfg = load_experiment_cfg(ckpt_root_dir=cfg.ckpt_root_dir, experiment_name=cfg.experiment_name)
run_id = get_param(cfg, "run_id")
if run_id is None:
run_id = get_latest_run_id(experiment_name=cfg.experiment_name, checkpoints_root_dir=cfg.ckpt_root_dir)
experiment_cfg = load_experiment_cfg(ckpt_root_dir=cfg.ckpt_root_dir, experiment_name=cfg.experiment_name, run_id=run_id)

hydra.utils.instantiate(experiment_cfg)
if cfg.checkpoint_path is None:
logger.info(
"checkpoint_params.checkpoint_path was not provided, so the model will be converted using weights from "
"checkpoints_dir/training_hyperparams.ckpt_name "
)
if cfg.run_id is None:
checkpoints_dir = Path(get_latest_run_id(experiment_name=cfg.experiment_name, checkpoints_root_dir=cfg.ckpt_root_dir))
else:
checkpoints_dir = Path(get_checkpoints_dir_path(experiment_name=cfg.experiment_name, ckpt_root_dir=cfg.ckpt_root_dir))
checkpoints_dir = os.path.join(checkpoints_dir, cfg.run_id)

checkpoints_dir = get_checkpoints_dir_path(experiment_name=cfg.experiment_name, ckpt_root_dir=cfg.ckpt_root_dir, run_id=run_id)
cfg.checkpoint_path = os.path.join(checkpoints_dir, cfg.ckpt_name)

cfg.out_path = cfg.out_path or cfg.checkpoint_path.replace(".pth", ".onnx")
logger.info(f"Exporting checkpoint: {cfg.checkpoint_path} to ONNX.")
return cfg, experiment_cfg
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import os
import torch

from super_gradients.common.environment.checkpoints_dir_utils import get_checkpoints_dir_path
from super_gradients.common.environment.checkpoints_dir_utils import get_checkpoints_dir_path, get_latest_run_id


class ShortenedRecipesAccuracyTests(unittest.TestCase):
Expand All @@ -20,7 +20,9 @@ def test_shortened_cifar10_resnet_accuracy(self):
self.assertTrue(self._reached_goal_metric(experiment_name="shortened_cifar10_resnet_accuracy_test", metric_value=0.9167, delta=0.05))

def test_convert_shortened_cifar10_resnet(self):
ckpt_dir = get_checkpoints_dir_path(experiment_name="shortened_cifar10_resnet_accuracy_test")
experiment_name = "shortened_cifar10_resnet_accuracy_test"
run_id = get_latest_run_id(experiment_name=experiment_name)
ckpt_dir = get_checkpoints_dir_path(experiment_name=experiment_name, run_id=run_id)
self.assertTrue(os.path.exists(os.path.join(ckpt_dir, "ckpt_best.onnx")))

def test_shortened_coco2017_yolox_n_map(self):
Expand All @@ -34,7 +36,8 @@ def test_shortened_coco_dekr_32_ap_test(self):

@classmethod
def _reached_goal_metric(cls, experiment_name: str, metric_value: float, delta: float):
checkpoints_dir_path = get_checkpoints_dir_path(experiment_name=experiment_name)
run_id = get_latest_run_id(experiment_name=experiment_name)
checkpoints_dir_path = get_checkpoints_dir_path(experiment_name=experiment_name, run_id=run_id)
sd = torch.load(os.path.join(checkpoints_dir_path, "ckpt_best.pth"))
metric_val_reached = sd["acc"].cpu().item()
diff = abs(metric_val_reached - metric_value)
Expand Down

0 comments on commit dc52a0d

Please sign in to comment.