From 96df0276befbeb5195c3a65c99f2c99f3e9b3b1b Mon Sep 17 00:00:00 2001
From: Eugene Khvedchenya
Date: Wed, 20 Sep 2023 19:56:23 +0300
Subject: [PATCH] Fixed loading preprocessing params from pretrained weights
 (#1473)

* Fixed loading preprocessing params from pretrained weights

* Added support for file:/// in pretrained weights

* Added test to ensure we load preprocessing params
---
 .../training/models/model_factory.py          | 12 +--
 .../training/utils/checkpoint_utils.py        | 83 +++++++++++++------
 .../unit_tests/pretrained_models_unit_test.py | 29 ++++++-
 3 files changed, 90 insertions(+), 34 deletions(-)

diff --git a/src/super_gradients/training/models/model_factory.py b/src/super_gradients/training/models/model_factory.py
index 615a389fd6..df68993da4 100644
--- a/src/super_gradients/training/models/model_factory.py
+++ b/src/super_gradients/training/models/model_factory.py
@@ -153,6 +153,13 @@ def instantiate_model(
         net = architecture_cls(arch_params=arch_params)
 
         if pretrained_weights:
+            # The logic is as follows: first we initialize the preprocessing params using the default hard-coded params.
+            # If the pretrained checkpoint contains preprocessing params, they will be loaded and will override the ones
+            # set in this step inside load_pretrained_weights_local/load_pretrained_weights.
+            if isinstance(net, HasPredict):
+                processing_params = get_pretrained_processing_params(model_name, pretrained_weights)
+                net.set_dataset_processing_params(**processing_params)
+
             if is_remote and pretrained_weights_path:
                 load_pretrained_weights_local(net, model_name, pretrained_weights_path)
             else:
@@ -162,11 +169,6 @@ def instantiate_model(
                 net.replace_head(new_num_classes=num_classes_new_head)
                 arch_params.num_classes = num_classes_new_head
 
-            # STILL NEED TO GET PREPROCESSING PARAMS IN CASE CHECKPOINT HAS NO RECIPE
-            if isinstance(net, HasPredict):
-                processing_params = get_pretrained_processing_params(model_name, pretrained_weights)
-                net.set_dataset_processing_params(**processing_params)
-
         _add_model_name_attribute(net, model_name)
 
     return net
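Note: the model_factory.py change above establishes a "defaults first, checkpoint
overrides" order: hard-coded preprocessing params are set on the model before the
weights are loaded, and any params stored in the checkpoint replace them inside
load_pretrained_weights/load_pretrained_weights_local. A minimal standalone sketch
of that merge order (the helper below is illustrative only, not super_gradients API):

    def resolve_processing_params(defaults: dict, checkpoint: dict) -> dict:
        # Step 1: start from the hard-coded defaults.
        params = dict(defaults)
        # Step 2: checkpoint-provided params, when present, override them.
        params.update(checkpoint.get("processing_params", {}))
        return params

    assert resolve_processing_params(
        {"image_size": (640, 640), "conf": 0.25}, {"processing_params": {"conf": 0.5}}
    ) == {"image_size": (640, 640), "conf": 0.5}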
diff --git a/src/super_gradients/training/utils/checkpoint_utils.py b/src/super_gradients/training/utils/checkpoint_utils.py
index a16de11d4e..46de06fb58 100644
--- a/src/super_gradients/training/utils/checkpoint_utils.py
+++ b/src/super_gradients/training/utils/checkpoint_utils.py
@@ -1517,16 +1517,7 @@ def load_checkpoint_to_model(
     message_model = "model" if not load_backbone else "model's backbone"
     logger.info("Successfully loaded " + message_model + " weights from " + ckpt_local_path + message_suffix)
 
-    if (isinstance(net, HasPredict)) and load_processing_params:
-        if "processing_params" not in checkpoint.keys():
-            raise ValueError("Can't load processing params - could not find any stored in checkpoint file.")
-        try:
-            net.set_dataset_processing_params(**checkpoint["processing_params"])
-        except Exception as e:
-            logger.warning(
-                f"Could not set preprocessing pipeline from the checkpoint dataset: {e}. Before calling"
-                "predict make sure to call set_dataset_processing_params."
-            )
+    _maybe_load_preprocessing_params(net, checkpoint)
 
     if load_weights_only or load_backbone:
         # DISCARD ALL THE DATA STORED IN CHECKPOINT OTHER THAN THE WEIGHTS
@@ -1549,10 +1540,12 @@ def __init__(self, desc):
 def load_pretrained_weights(model: torch.nn.Module, architecture: str, pretrained_weights: str):
     """
     Loads pretrained weights from the MODEL_URLS dictionary to model
-    :param architecture: name of the model's architecture
-    :param model: model to load pretrinaed weights for
-    :param pretrained_weights: name for the pretrianed weights (i.e imagenet)
-    :return: None
+
+    :param architecture: name of the model's architecture
+    :param model: model to load pretrained weights for
+    :param pretrained_weights: name of the pretrained weights (e.g. imagenet)
+
+    :return: None
     """
     from super_gradients.common.object_names import Models
@@ -1569,19 +1562,19 @@
         "By downloading the pre-trained weight files you agree to comply with these terms."
     )
 
-    unique_filename = url.split("https://sghub.deci.ai/models/")[1].replace("/", "_").replace(" ", "_")
-    map_location = torch.device("cpu")
-    with wait_for_the_master(get_local_rank()):
-        pretrained_state_dict = load_state_dict_from_url(url=url, map_location=map_location, file_name=unique_filename)
-    _load_weights(architecture, model, pretrained_state_dict)
-
+    # This check allows setting pretrained weights from a local path using the file:///path/to/weights scheme,
+    # which is a valid URI scheme for local files.
+    # Supporting local files and file:// URIs lets us modify the pretrained weights dict in unit tests.
+    if url.startswith("file://") or os.path.exists(url):
+        pretrained_state_dict = torch.load(url.replace("file://", ""), map_location="cpu")
+    else:
+        unique_filename = url.split("https://sghub.deci.ai/models/")[1].replace("/", "_").replace(" ", "_")
+        map_location = torch.device("cpu")
+        with wait_for_the_master(get_local_rank()):
+            pretrained_state_dict = load_state_dict_from_url(url=url, map_location=map_location, file_name=unique_filename)
 
-def _load_weights(architecture, model, pretrained_state_dict):
-    if "ema_net" in pretrained_state_dict.keys():
-        pretrained_state_dict["net"] = pretrained_state_dict["ema_net"]
-    solver = YoloXCheckpointSolver() if "yolox" in architecture else DefaultCheckpointSolver()
-    adaptive_load_state_dict(net=model, state_dict=pretrained_state_dict, strict=StrictLoad.NO_KEY_MATCHING, solver=solver)
-    logger.info(f"Successfully loaded pretrained weights for architecture {architecture}")
+    _load_weights(architecture, model, pretrained_state_dict)
+    _maybe_load_preprocessing_params(model, pretrained_state_dict)
 
 
 def load_pretrained_weights_local(model: torch.nn.Module, architecture: str, pretrained_weights: str):
@@ -1598,3 +1591,41 @@ def load_pretrained_weights_local(model: torch.nn.Module, architecture: str, pre
     pretrained_state_dict = torch.load(pretrained_weights, map_location=map_location)
 
     _load_weights(architecture, model, pretrained_state_dict)
+    _maybe_load_preprocessing_params(model, pretrained_state_dict)
+
+
+def _load_weights(architecture, model, pretrained_state_dict):
+    if "ema_net" in pretrained_state_dict.keys():
+        pretrained_state_dict["net"] = pretrained_state_dict["ema_net"]
+    solver = YoloXCheckpointSolver() if "yolox" in architecture else DefaultCheckpointSolver()
+    adaptive_load_state_dict(net=model, state_dict=pretrained_state_dict, strict=StrictLoad.NO_KEY_MATCHING, solver=solver)
+    logger.info(f"Successfully loaded pretrained weights for architecture {architecture}")
+
+
+def _maybe_load_preprocessing_params(model: Union[nn.Module, HasPredict], checkpoint: Mapping[str, Tensor]) -> bool:
+    """
+    Tries to load preprocessing params from the checkpoint into the model.
+    This function does not raise; it logs a warning if the loading fails.
+    :param model: Instance of nn.Module
+    :param checkpoint: Entire checkpoint dict (not the state_dict with model weights)
+    :return: True if the loading was successful, False otherwise.
+    """
+    model = unwrap_model(model)
+    checkpoint_has_preprocessing_params = "processing_params" in checkpoint.keys()
+    model_has_predict = isinstance(model, HasPredict)
+    logger.debug(
+        f"Trying to load preprocessing params from checkpoint. Preprocessing params in checkpoint: {checkpoint_has_preprocessing_params}. "
+        f"Model {model.__class__.__name__} inherits HasPredict: {model_has_predict}"
+    )
+
+    if model_has_predict and checkpoint_has_preprocessing_params:
+        try:
+            model.set_dataset_processing_params(**checkpoint["processing_params"])
+            logger.debug(f"Successfully loaded preprocessing params from checkpoint {checkpoint['processing_params']}")
+            return True
+        except Exception as e:
+            logger.warning(
+                f"Could not set preprocessing pipeline from the checkpoint dataset: {e}. Before calling "
+                "predict make sure to call set_dataset_processing_params."
+            )
+    return False
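Note: the file:// branch added to load_pretrained_weights above means a "URL" that
is either a file URI or an existing local path is read with torch.load instead of
being downloaded. A standalone sketch of just that branch; the helper name
load_state_dict_from_uri is illustrative, not the library's API:

    import os

    import torch

    def load_state_dict_from_uri(url: str):
        # file:///path/to/weights.pth -> /path/to/weights.pth, read from disk.
        if url.startswith("file://") or os.path.exists(url):
            return torch.load(url.replace("file://", ""), map_location="cpu")
        # The remote branch (load_state_dict_from_url) is omitted in this sketch.
        raise NotImplementedError(f"Download not implemented in this sketch: {url}")

With this, load_state_dict_from_uri("file:///tmp/weights.pth") and
load_state_dict_from_uri("/tmp/weights.pth") behave identically.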
diff --git a/tests/unit_tests/pretrained_models_unit_test.py b/tests/unit_tests/pretrained_models_unit_test.py
index bb6a503c9b..6ed2d52db1 100644
--- a/tests/unit_tests/pretrained_models_unit_test.py
+++ b/tests/unit_tests/pretrained_models_unit_test.py
@@ -1,12 +1,19 @@
+import os
+import shutil
+import tempfile
 import unittest
+
+import numpy as np
+import torch
+
 import super_gradients
 from super_gradients.common.object_names import Models
-from super_gradients.training import models
 from super_gradients.training import Trainer
+from super_gradients.training import models
 from super_gradients.training.dataloaders.dataloaders import classification_test_dataloader
 from super_gradients.training.metrics import Accuracy
-import os
-import shutil
+from super_gradients.training.pretrained_models import MODEL_URLS, PRETRAINED_NUM_CLASSES
+from super_gradients.training.processing.processing import default_yolo_nas_coco_processing_params
 
 
 class PretrainedModelsUnitTest(unittest.TestCase):
@@ -29,6 +36,22 @@ def test_pretrained_repvgg_a0_imagenet(self):
         model = models.get(Models.REPVGG_A0, pretrained_weights="imagenet", arch_params={"build_residual_branches": True})
         trainer.test(model=model, test_loader=classification_test_dataloader(), test_metrics_list=[Accuracy()], metrics_progress_verbose=True)
 
+    def test_pretrained_models_load_preprocessing_params(self):
+        """
+        Test that checks whether preprocessing params from a pretrained model load correctly.
+        """
+        state = {"net": models.get(Models.YOLO_NAS_S, num_classes=80).state_dict(), "processing_params": default_yolo_nas_coco_processing_params()}
+        with tempfile.TemporaryDirectory() as td:
+            checkpoint_path = os.path.join(td, "yolo_nas_s_coco.pth")
+            torch.save(state, checkpoint_path)
+
+            MODEL_URLS[Models.YOLO_NAS_S + "_test"] = checkpoint_path
+            PRETRAINED_NUM_CLASSES["test"] = 80
+
+            model = models.get(Models.YOLO_NAS_S, pretrained_weights="test")
+            # .predict() would fail if the model had no preprocessing params
+            self.assertIsNotNone(model.predict(np.zeros(shape=(512, 512, 3), dtype=np.uint8)))
+
     def tearDown(self) -> None:
         if os.path.exists("~/.cache/torch/hub/"):
             shutil.rmtree("~/.cache/torch/hub/")
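Note: the test above registers a plain local path in MODEL_URLS, which is picked
up by the os.path.exists branch of load_pretrained_weights; with the new check, a
file:// URI works the same way. A hypothetical variant of the registration step,
under the same assumptions as the test (checkpoint_path saved beforehand):

    MODEL_URLS[Models.YOLO_NAS_S + "_test"] = "file://" + checkpoint_path
    PRETRAINED_NUM_CLASSES["test"] = 80
    model = models.get(Models.YOLO_NAS_S, pretrained_weights="test")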