diff --git a/src/super_gradients/training/models/model_factory.py b/src/super_gradients/training/models/model_factory.py
index 615a389fd6..df68993da4 100644
--- a/src/super_gradients/training/models/model_factory.py
+++ b/src/super_gradients/training/models/model_factory.py
@@ -153,6 +153,13 @@ def instantiate_model(
         net = architecture_cls(arch_params=arch_params)
 
         if pretrained_weights:
+            # The logic is as follows: first we initialize the preprocessing params using the default hard-coded params.
+            # If the pretrained checkpoint contains preprocessing params, the new params will be loaded and override the ones from
+            # this step in load_pretrained_weights_local/load_pretrained_weights.
+            if isinstance(net, HasPredict):
+                processing_params = get_pretrained_processing_params(model_name, pretrained_weights)
+                net.set_dataset_processing_params(**processing_params)
+
             if is_remote and pretrained_weights_path:
                 load_pretrained_weights_local(net, model_name, pretrained_weights_path)
             else:
@@ -162,11 +169,6 @@
                 net.replace_head(new_num_classes=num_classes_new_head)
                 arch_params.num_classes = num_classes_new_head
 
-        # STILL NEED TO GET PREPROCESSING PARAMS IN CASE CHECKPOINT HAS NO RECIPE
-        if isinstance(net, HasPredict):
-            processing_params = get_pretrained_processing_params(model_name, pretrained_weights)
-            net.set_dataset_processing_params(**processing_params)
-
     _add_model_name_attribute(net, model_name)
 
     return net
diff --git a/src/super_gradients/training/utils/checkpoint_utils.py b/src/super_gradients/training/utils/checkpoint_utils.py
index 8638752a3e..46de06fb58 100644
--- a/src/super_gradients/training/utils/checkpoint_utils.py
+++ b/src/super_gradients/training/utils/checkpoint_utils.py
@@ -1,12 +1,11 @@
 import collections
 import os
 import tempfile
-from typing import Union, Mapping, Dict
+from typing import Union, Mapping
 
 import pkg_resources
 import torch
 from torch import nn, Tensor
-from torch.optim.lr_scheduler import CyclicLR
 
 from super_gradients.common.abstractions.abstract_logger import get_logger
 from super_gradients.common.data_interface.adnn_model_repository_data_interface import ADNNModelRepositoryDataInterfaces
@@ -23,6 +22,7 @@
 except (ModuleNotFoundError, ImportError, NameError):
     from torch.hub import _download_url_to_file as download_url_to_file
 
+
 logger = get_logger(__name__)
 
 
@@ -1517,16 +1517,7 @@ def load_checkpoint_to_model(
         message_model = "model" if not load_backbone else "model's backbone"
         logger.info("Successfully loaded " + message_model + " weights from " + ckpt_local_path + message_suffix)
 
-    if (isinstance(net, HasPredict)) and load_processing_params:
-        if "processing_params" not in checkpoint.keys():
-            raise ValueError("Can't load processing params - could not find any stored in checkpoint file.")
-        try:
-            net.set_dataset_processing_params(**checkpoint["processing_params"])
-        except Exception as e:
-            logger.warning(
-                f"Could not set preprocessing pipeline from the checkpoint dataset: {e}. Before calling"
-                "predict make sure to call set_dataset_processing_params."
-            )
+    _maybe_load_preprocessing_params(net, checkpoint)
 
     if load_weights_only or load_backbone:
         # DISCARD ALL THE DATA STORED IN CHECKPOINT OTHER THAN THE WEIGHTS
@@ -1549,10 +1540,12 @@ def __init__(self, desc):
 def load_pretrained_weights(model: torch.nn.Module, architecture: str, pretrained_weights: str):
     """
     Loads pretrained weights from the MODEL_URLS dictionary to model
-    :param architecture: name of the model's architecture
-    :param model: model to load pretrinaed weights for
-    :param pretrained_weights: name for the pretrianed weights (i.e imagenet)
-    :return: None
+
+    :param architecture: name of the model's architecture
+    :param model: model to load pretrained weights for
+    :param pretrained_weights: name of the pretrained weights (e.g. imagenet)
+
+    :return: None
     """
     from super_gradients.common.object_names import Models
 
@@ -1569,22 +1562,23 @@ def load_pretrained_weights(model: torch.nn.Module, architecture: str, pretraine
         "By downloading the pre-trained weight files you agree to comply with these terms."
     )
 
-    unique_filename = url.split("https://sghub.deci.ai/models/")[1].replace("/", "_").replace(" ", "_")
-    map_location = torch.device("cpu")
-    with wait_for_the_master(get_local_rank()):
-        pretrained_state_dict = load_state_dict_from_url(url=url, map_location=map_location, file_name=unique_filename)
-    _load_weights(architecture, model, pretrained_state_dict)
-
+    # This check allows setting pretrained weights from a local path using the file:///path/to/weights scheme,
+    # which is a valid URI scheme for local files.
+    # Supporting local files and file URIs lets us modify the pretrained weights dicts in unit tests.
+    if url.startswith("file://") or os.path.exists(url):
+        pretrained_state_dict = torch.load(url.replace("file://", ""), map_location="cpu")
+    else:
+        unique_filename = url.split("https://sghub.deci.ai/models/")[1].replace("/", "_").replace(" ", "_")
+        map_location = torch.device("cpu")
+        with wait_for_the_master(get_local_rank()):
+            pretrained_state_dict = load_state_dict_from_url(url=url, map_location=map_location, file_name=unique_filename)
 
-def _load_weights(architecture, model, pretrained_state_dict):
-    if "ema_net" in pretrained_state_dict.keys():
-        pretrained_state_dict["net"] = pretrained_state_dict["ema_net"]
-    solver = YoloXCheckpointSolver() if "yolox" in architecture else DefaultCheckpointSolver()
-    adaptive_load_state_dict(net=model, state_dict=pretrained_state_dict, strict=StrictLoad.NO_KEY_MATCHING, solver=solver)
-    logger.info(f"Successfully loaded pretrained weights for architecture {architecture}")
+    _load_weights(architecture, model, pretrained_state_dict)
+    _maybe_load_preprocessing_params(model, pretrained_state_dict)
 
 
 def load_pretrained_weights_local(model: torch.nn.Module, architecture: str, pretrained_weights: str):
+    """
     Loads pretrained weights from the MODEL_URLS dictionary to model
 
     :param architecture: name of the model's architecture
@@ -1597,18 +1591,41 @@ def load_pretrained_weights_local(model: torch.nn.Module, architecture: str, pre
 
     pretrained_state_dict = torch.load(pretrained_weights, map_location=map_location)
     _load_weights(architecture, model, pretrained_state_dict)
+    _maybe_load_preprocessing_params(model, pretrained_state_dict)
+
+
+def _load_weights(architecture, model, pretrained_state_dict):
+    if "ema_net" in pretrained_state_dict.keys():
+        pretrained_state_dict["net"] = pretrained_state_dict["ema_net"]
+    solver = YoloXCheckpointSolver() if "yolox" in architecture else DefaultCheckpointSolver()
+    adaptive_load_state_dict(net=model, state_dict=pretrained_state_dict, strict=StrictLoad.NO_KEY_MATCHING, solver=solver)
+    logger.info(f"Successfully loaded pretrained weights for architecture {architecture}")
 
 
-def get_scheduler_state(scheduler) -> Dict:
+def _maybe_load_preprocessing_params(model: Union[nn.Module, HasPredict], checkpoint: Mapping[str, Tensor]) -> bool:
     """
-    Wrapper for getting a torch lr scheduler state dict, resolving some issues with CyclicLR
-    (see https://github.com/pytorch/pytorch/pull/91400)
-    :param scheduler: torch.optim.lr_scheduler._LRScheduler, the scheduler
-    :return: the scheduler's state_dict
+    Tries to load preprocessing params from the checkpoint into the model.
+    The function does not crash; it logs a warning if the loading fails.
+    :param model: Instance of nn.Module
+    :param checkpoint: Entire checkpoint dict (not the state_dict with model weights)
+    :return: True if the loading was successful, False otherwise.
     """
-    from super_gradients.training.utils import torch_version_is_greater_or_equal
-
-    state = scheduler.state_dict()
-    if isinstance(scheduler, CyclicLR) and not torch_version_is_greater_or_equal(2, 0):
-        del state["_scale_fn_ref"]
-    return state
+    model = unwrap_model(model)
+    checkpoint_has_preprocessing_params = "processing_params" in checkpoint.keys()
+    model_has_predict = isinstance(model, HasPredict)
+    logger.debug(
+        f"Trying to load preprocessing params from checkpoint. Preprocessing params in checkpoint: {checkpoint_has_preprocessing_params}. "
+        f"Model {model.__class__.__name__} inherits HasPredict: {model_has_predict}"
+    )
+
+    if model_has_predict and checkpoint_has_preprocessing_params:
+        try:
+            model.set_dataset_processing_params(**checkpoint["processing_params"])
+            logger.debug(f"Successfully loaded preprocessing params from checkpoint {checkpoint['processing_params']}")
+            return True
+        except Exception as e:
+            logger.warning(
+                f"Could not set preprocessing pipeline from the checkpoint dataset: {e}. Before calling "
+                "predict, make sure to call set_dataset_processing_params."
+            )
+    return False
diff --git a/tests/unit_tests/pretrained_models_unit_test.py b/tests/unit_tests/pretrained_models_unit_test.py
index bb6a503c9b..6ed2d52db1 100644
--- a/tests/unit_tests/pretrained_models_unit_test.py
+++ b/tests/unit_tests/pretrained_models_unit_test.py
@@ -1,12 +1,19 @@
+import os
+import shutil
+import tempfile
 import unittest
+
+import numpy as np
+import torch
+
 import super_gradients
 from super_gradients.common.object_names import Models
-from super_gradients.training import models
 from super_gradients.training import Trainer
+from super_gradients.training import models
 from super_gradients.training.dataloaders.dataloaders import classification_test_dataloader
 from super_gradients.training.metrics import Accuracy
-import os
-import shutil
+from super_gradients.training.pretrained_models import MODEL_URLS, PRETRAINED_NUM_CLASSES
+from super_gradients.training.processing.processing import default_yolo_nas_coco_processing_params
 
 
 class PretrainedModelsUnitTest(unittest.TestCase):
@@ -29,6 +36,22 @@ def test_pretrained_repvgg_a0_imagenet(self):
         model = models.get(Models.REPVGG_A0, pretrained_weights="imagenet", arch_params={"build_residual_branches": True})
         trainer.test(model=model, test_loader=classification_test_dataloader(), test_metrics_list=[Accuracy()], metrics_progress_verbose=True)
 
+    def test_pretrained_models_load_preprocessing_params(self):
+        """
+        Test that checks whether preprocessing params from a pretrained model load correctly.
+        """
+        state = {"net": models.get(Models.YOLO_NAS_S, num_classes=80).state_dict(), "processing_params": default_yolo_nas_coco_processing_params()}
+        with tempfile.TemporaryDirectory() as td:
+            checkpoint_path = os.path.join(td, "yolo_nas_s_coco.pth")
+            torch.save(state, checkpoint_path)
+
+            MODEL_URLS[Models.YOLO_NAS_S + "_test"] = checkpoint_path
+            PRETRAINED_NUM_CLASSES["test"] = 80
+
+            model = models.get(Models.YOLO_NAS_S, pretrained_weights="test")
+            # .predict() would fail if the model had no preprocessing params
+            self.assertIsNotNone(model.predict(np.zeros(shape=(512, 512, 3), dtype=np.uint8)))
+
     def tearDown(self) -> None:
         if os.path.exists("~/.cache/torch/hub/"):
             shutil.rmtree("~/.cache/torch/hub/")
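Usage note (not part of the diff): the new branch in load_pretrained_weights accepts both plain local paths and file:// URIs, which is what lets the unit test above register a temporary checkpoint in MODEL_URLS. Below is a minimal sketch of the same flow using an explicit file:// URI; the "_local"/"local" registry keys and the checkpoint filename are illustrative, everything else mirrors the test.

import os
import tempfile

import numpy as np
import torch

from super_gradients.common.object_names import Models
from super_gradients.training import models
from super_gradients.training.pretrained_models import MODEL_URLS, PRETRAINED_NUM_CLASSES
from super_gradients.training.processing.processing import default_yolo_nas_coco_processing_params

with tempfile.TemporaryDirectory() as td:
    checkpoint_path = os.path.join(td, "yolo_nas_s_local.pth")
    # The checkpoint carries both weights and "processing_params", so the preprocessing
    # pipeline is restored by _maybe_load_preprocessing_params when the weights are loaded.
    torch.save(
        {
            "net": models.get(Models.YOLO_NAS_S, num_classes=80).state_dict(),
            "processing_params": default_yolo_nas_coco_processing_params(),
        },
        checkpoint_path,
    )

    # file:// URIs and plain local paths are both handled by the new check in load_pretrained_weights.
    MODEL_URLS[Models.YOLO_NAS_S + "_local"] = "file://" + checkpoint_path
    PRETRAINED_NUM_CLASSES["local"] = 80

    model = models.get(Models.YOLO_NAS_S, pretrained_weights="local")
    # No call to set_dataset_processing_params is needed before predict().
    model.predict(np.zeros((512, 512, 3), dtype=np.uint8))

Keeping the override inside MODEL_URLS/PRETRAINED_NUM_CLASSES means tests can exercise the full pretrained-weights code path without touching any production loading logic.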