Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/master' into bugfix/SG-000_fix_c…
Browse files Browse the repository at this point in the history
…yclic_lr_state_dict

# Conflicts:
#	src/super_gradients/training/utils/checkpoint_utils.py
  • Loading branch information
shaydeci committed Sep 21, 2023
2 parents b677a89 + 96df027 commit 6fa1bf9
Show file tree
Hide file tree
Showing 3 changed files with 89 additions and 47 deletions.
12 changes: 7 additions & 5 deletions src/super_gradients/training/models/model_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,13 @@ def instantiate_model(
net = architecture_cls(arch_params=arch_params)

if pretrained_weights:
# The logic is follows - first we initialize the preprocessing params using default hard-coded params
# If pretrained checkpoint contains preprocessing params, new params will be loaded and override the ones from
# this step in load_pretrained_weights_local/load_pretrained_weights
if isinstance(net, HasPredict):
processing_params = get_pretrained_processing_params(model_name, pretrained_weights)
net.set_dataset_processing_params(**processing_params)

if is_remote and pretrained_weights_path:
load_pretrained_weights_local(net, model_name, pretrained_weights_path)
else:
Expand All @@ -162,11 +169,6 @@ def instantiate_model(
net.replace_head(new_num_classes=num_classes_new_head)
arch_params.num_classes = num_classes_new_head

# STILL NEED TO GET PREPROCESSING PARAMS IN CASE CHECKPOINT HAS NO RECIPE
if isinstance(net, HasPredict):
processing_params = get_pretrained_processing_params(model_name, pretrained_weights)
net.set_dataset_processing_params(**processing_params)

_add_model_name_attribute(net, model_name)

return net
Expand Down
95 changes: 56 additions & 39 deletions src/super_gradients/training/utils/checkpoint_utils.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
import collections
import os
import tempfile
from typing import Union, Mapping, Dict
from typing import Union, Mapping

import pkg_resources
import torch
from torch import nn, Tensor
from torch.optim.lr_scheduler import CyclicLR

from super_gradients.common.abstractions.abstract_logger import get_logger
from super_gradients.common.data_interface.adnn_model_repository_data_interface import ADNNModelRepositoryDataInterfaces
Expand All @@ -23,6 +22,7 @@
except (ModuleNotFoundError, ImportError, NameError):
from torch.hub import _download_url_to_file as download_url_to_file


logger = get_logger(__name__)


Expand Down Expand Up @@ -1517,16 +1517,7 @@ def load_checkpoint_to_model(
message_model = "model" if not load_backbone else "model's backbone"
logger.info("Successfully loaded " + message_model + " weights from " + ckpt_local_path + message_suffix)

if (isinstance(net, HasPredict)) and load_processing_params:
if "processing_params" not in checkpoint.keys():
raise ValueError("Can't load processing params - could not find any stored in checkpoint file.")
try:
net.set_dataset_processing_params(**checkpoint["processing_params"])
except Exception as e:
logger.warning(
f"Could not set preprocessing pipeline from the checkpoint dataset: {e}. Before calling"
"predict make sure to call set_dataset_processing_params."
)
_maybe_load_preprocessing_params(net, checkpoint)

if load_weights_only or load_backbone:
# DISCARD ALL THE DATA STORED IN CHECKPOINT OTHER THAN THE WEIGHTS
Expand All @@ -1549,10 +1540,12 @@ def __init__(self, desc):
def load_pretrained_weights(model: torch.nn.Module, architecture: str, pretrained_weights: str):
"""
Loads pretrained weights from the MODEL_URLS dictionary to model
:param architecture: name of the model's architecture
:param model: model to load pretrinaed weights for
:param pretrained_weights: name for the pretrianed weights (i.e imagenet)
:return: None
:param architecture: name of the model's architecture
:param model: model to load pretrinaed weights for
:param pretrained_weights: name for the pretrianed weights (i.e imagenet)
:return: None
"""
from super_gradients.common.object_names import Models

Expand All @@ -1569,22 +1562,23 @@ def load_pretrained_weights(model: torch.nn.Module, architecture: str, pretraine
"By downloading the pre-trained weight files you agree to comply with these terms."
)

unique_filename = url.split("https://sghub.deci.ai/models/")[1].replace("/", "_").replace(" ", "_")
map_location = torch.device("cpu")
with wait_for_the_master(get_local_rank()):
pretrained_state_dict = load_state_dict_from_url(url=url, map_location=map_location, file_name=unique_filename)
_load_weights(architecture, model, pretrained_state_dict)

# Basically this check allows settings pretrained weights from local path using file:///path/to/weights scheme
# which is a valid URI scheme for local files
# Supporting local files and file URI allows us modification of pretrained weights dics in unit tests
if url.startswith("file://") or os.path.exists(url):
pretrained_state_dict = torch.load(url.replace("file://", ""), map_location="cpu")
else:
unique_filename = url.split("https://sghub.deci.ai/models/")[1].replace("/", "_").replace(" ", "_")
map_location = torch.device("cpu")
with wait_for_the_master(get_local_rank()):
pretrained_state_dict = load_state_dict_from_url(url=url, map_location=map_location, file_name=unique_filename)

def _load_weights(architecture, model, pretrained_state_dict):
if "ema_net" in pretrained_state_dict.keys():
pretrained_state_dict["net"] = pretrained_state_dict["ema_net"]
solver = YoloXCheckpointSolver() if "yolox" in architecture else DefaultCheckpointSolver()
adaptive_load_state_dict(net=model, state_dict=pretrained_state_dict, strict=StrictLoad.NO_KEY_MATCHING, solver=solver)
logger.info(f"Successfully loaded pretrained weights for architecture {architecture}")
_load_weights(architecture, model, pretrained_state_dict)
_maybe_load_preprocessing_params(model, pretrained_state_dict)


def load_pretrained_weights_local(model: torch.nn.Module, architecture: str, pretrained_weights: str):

"""
Loads pretrained weights from the MODEL_URLS dictionary to model
:param architecture: name of the model's architecture
Expand All @@ -1597,18 +1591,41 @@ def load_pretrained_weights_local(model: torch.nn.Module, architecture: str, pre

pretrained_state_dict = torch.load(pretrained_weights, map_location=map_location)
_load_weights(architecture, model, pretrained_state_dict)
_maybe_load_preprocessing_params(model, pretrained_state_dict)


def _load_weights(architecture, model, pretrained_state_dict):
if "ema_net" in pretrained_state_dict.keys():
pretrained_state_dict["net"] = pretrained_state_dict["ema_net"]
solver = YoloXCheckpointSolver() if "yolox" in architecture else DefaultCheckpointSolver()
adaptive_load_state_dict(net=model, state_dict=pretrained_state_dict, strict=StrictLoad.NO_KEY_MATCHING, solver=solver)
logger.info(f"Successfully loaded pretrained weights for architecture {architecture}")


def get_scheduler_state(scheduler) -> Dict:
def _maybe_load_preprocessing_params(model: Union[nn.Module, HasPredict], checkpoint: Mapping[str, Tensor]) -> bool:
"""
Wrapper for getting a torch lr scheduler state dict, resolving some issues with CyclicLR
(see https://github.com/pytorch/pytorch/pull/91400)
:param scheduler: torch.optim.lr_scheduler._LRScheduler, the scheduler
:return: the scheduler's state_dict
Tries to load preprocessing params from the checkpoint to the model.
The function does not crash, and raises a warning if the loading fails.
:param model: Instance of nn.Module
:param checkpoint: Entire checkpoint dict (not state_dict with model weights)
:return: True if the loading was successful, False otherwise.
"""
from super_gradients.training.utils import torch_version_is_greater_or_equal

state = scheduler.state_dict()
if isinstance(scheduler, CyclicLR) and not torch_version_is_greater_or_equal(2, 0):
del state["_scale_fn_ref"]
return state
model = unwrap_model(model)
checkpoint_has_preprocessing_params = "processing_params" in checkpoint.keys()
model_has_predict = isinstance(model, HasPredict)
logger.debug(
f"Trying to load preprocessing params from checkpoint. Preprocessing params in checkpoint: {checkpoint_has_preprocessing_params}. "
f"Model {model.__class__.__name__} inherit HasPredict: {model_has_predict}"
)

if model_has_predict and checkpoint_has_preprocessing_params:
try:
model.set_dataset_processing_params(**checkpoint["processing_params"])
logger.debug(f"Successfully loaded preprocessing params from checkpoint {checkpoint['processing_params']}")
return True
except Exception as e:
logger.warning(
f"Could not set preprocessing pipeline from the checkpoint dataset: {e}. Before calling"
"predict make sure to call set_dataset_processing_params."
)
return False
29 changes: 26 additions & 3 deletions tests/unit_tests/pretrained_models_unit_test.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,19 @@
import os
import shutil
import tempfile
import unittest

import numpy as np
import torch

import super_gradients
from super_gradients.common.object_names import Models
from super_gradients.training import models
from super_gradients.training import Trainer
from super_gradients.training import models
from super_gradients.training.dataloaders.dataloaders import classification_test_dataloader
from super_gradients.training.metrics import Accuracy
import os
import shutil
from super_gradients.training.pretrained_models import MODEL_URLS, PRETRAINED_NUM_CLASSES
from super_gradients.training.processing.processing import default_yolo_nas_coco_processing_params


class PretrainedModelsUnitTest(unittest.TestCase):
Expand All @@ -29,6 +36,22 @@ def test_pretrained_repvgg_a0_imagenet(self):
model = models.get(Models.REPVGG_A0, pretrained_weights="imagenet", arch_params={"build_residual_branches": True})
trainer.test(model=model, test_loader=classification_test_dataloader(), test_metrics_list=[Accuracy()], metrics_progress_verbose=True)

def test_pretrained_models_load_preprocessing_params(self):
"""
Test that checks whether preprocessing params from pretrained model load correctly.
"""
state = {"net": models.get(Models.YOLO_NAS_S, num_classes=80).state_dict(), "processing_params": default_yolo_nas_coco_processing_params()}
with tempfile.TemporaryDirectory() as td:
checkpoint_path = os.path.join(td, "yolo_nas_s_coco.pth")
torch.save(state, checkpoint_path)

MODEL_URLS[Models.YOLO_NAS_S + "_test"] = checkpoint_path
PRETRAINED_NUM_CLASSES["test"] = 80

model = models.get(Models.YOLO_NAS_S, pretrained_weights="test")
# .predict() would fail it model has no preprocessing params
self.assertIsNotNone(model.predict(np.zeros(shape=(512, 512, 3), dtype=np.uint8)))

def tearDown(self) -> None:
if os.path.exists("~/.cache/torch/hub/"):
shutil.rmtree("~/.cache/torch/hub/")
Expand Down

0 comments on commit 6fa1bf9

Please sign in to comment.