Skip to content

Commit

Permalink
added the ability to download and import external code for models fro… (
Browse files Browse the repository at this point in the history
#420)

* added the ability to download and import external code for models from the deci-lab
fixed a bug in all_architectures.py (model names were missing or changed)

* fix todo

* merge

* fix redundant param

* lint

Co-authored-by: Shay Aharon <[email protected]>
  • Loading branch information
ofrimasad and shaydeci authored Oct 11, 2022
1 parent 747b891 commit a73f9df
Show file tree
Hide file tree
Showing 3 changed files with 65 additions and 7 deletions.
42 changes: 42 additions & 0 deletions src/super_gradients/common/plugins/deci_client.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,12 @@
import json
import sys
from zipfile import ZipFile

import hydra

import importlib.util

import os
import pkg_resources
from hydra.core.global_hydra import GlobalHydra
from omegaconf import DictConfig
Expand Down Expand Up @@ -81,3 +87,39 @@ def get_model_weights(self, model_name: str) -> str:
return None

return self._get_file(model_name=model_name, file_name=AutoNACFileName.WEIGHTS_PTH)

def download_and_load_model_additional_code(self, model_name: str, target_path: str, package_name: str = "deci_model_code") -> None:
"""
try to download code files for this model.
if found, code files will be placed in the target_path/package_name and imported dynamically
"""

file = self._get_file(model_name=model_name, file_name=AutoNACFileName.CODE_ZIP)

package_path = os.path.join(target_path, package_name)
if file is not None:
# crete the directory
os.makedirs(package_path, exist_ok=True)

# extract code files
with ZipFile(file) as zipfile:
zipfile.extractall(package_path)

# add an init file that imports all code files
with open(os.path.join(package_path, '__init__.py'), 'w') as init_file:
all_str = '\n\n__all__ = ['
for code_file in os.listdir(path=package_path):
if code_file.endswith(".py") and not code_file.startswith("__init__"):
init_file.write(f'import {code_file.replace(".py", "")}\n')
all_str += f'"{code_file.replace(".py", "")}", '

all_str += "]\n\n"
init_file.write(all_str)

# include in path and import
sys.path.insert(1, package_path)
importlib.import_module(package_name)

logger.info(f'*** IMPORTANT ***: files required for the model {model_name} were downloaded and added to your code in:\n{package_path}\n'
f'These files will be downloaded to the same location each time the model is fetched from the deci-client.\n'
f'you can override this by passing models.get(... download_required_code=False) and importing the files yourself')
4 changes: 3 additions & 1 deletion src/super_gradients/training/models/all_architectures.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,14 +79,15 @@ class ModelNames:
REGNETY600 = "regnetY600"
REGNETY800 = "regnetY800"
CUSTOM_REGNET = "custom_regnet"
CUSTOM_ANYNET = "custom_anynet"
NAS_REGNET = "nas_regnet"
YOLOX_N = "yolox_n"
YOLOX_T = "yolox_t"
YOLOX_S = "yolox_s"
YOLOX_M = "yolox_m"
YOLOX_L = "yolox_l"
YOLOX_X = "yolox_x"
CUSTOM_YOLO_X = "CustomYoloX"
CUSTOM_YOLO_X = "custom_yolox"
SSD_MOBILENET_V1 = "ssd_mobilenet_v1"
SSD_LITE_MOBILENET_V2 = "ssd_lite_mobilenet_v2"
REPVGG_A0 = "repvgg_a0"
Expand Down Expand Up @@ -222,6 +223,7 @@ class ModelNames:
ModelNames.PP_LITE_B_SEG: PPLiteSegB,
ModelNames.PP_LITE_B_SEG50: PPLiteSegB,
ModelNames.PP_LITE_B_SEG75: PPLiteSegB,
ModelNames.CUSTOM_ANYNET: regnet.CustomAnyNet,
}

KD_ARCHITECTURES = {
Expand Down
26 changes: 20 additions & 6 deletions src/super_gradients/training/models/model_factory.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from pathlib import Path
from typing import Optional, Tuple
from typing import Type

Expand All @@ -23,13 +24,16 @@
logger = get_logger(__name__)


def get_architecture(model_name: str, arch_params: HpmStruct, pretrained_weights: str) -> Tuple[Type[torch.nn.Module], HpmStruct, str, bool]:
def get_architecture(model_name: str, arch_params: HpmStruct, pretrained_weights: str,
download_required_code: bool = True) -> Tuple[Type[torch.nn.Module], HpmStruct, str, bool]:
"""
Get the corresponding architecture class.
:param model_name: Define the model's architecture from models/ALL_ARCHITECTURES
:param arch_params: Architecture hyper parameters. e.g.: block, num_blocks, etc.
:param pretrained_weights: Describe the dataset of the pretrained weights (for example "imagenent")
:param download_required_code: if model is not found in SG and is downloaded from a remote client, overriding this parameter with False
will prevent additional code from being downloaded. This affects only models from remote client.
:return:
- architecture_cls: Class of the model
Expand All @@ -44,17 +48,23 @@ def get_architecture(model_name: str, arch_params: HpmStruct, pretrained_weights
logger.info(f'Required model {model_name} not found in local SuperGradients. Trying to load a model from remote deci lab')
deci_client = DeciClient()
_arch_params = deci_client.get_model_arch_params(model_name)
if download_required_code:
deci_client.download_and_load_model_additional_code(model_name, Path.cwd())

if _arch_params is None:
raise ValueError("Unsupported model name " + str(model_name) + ", see docs or all_architectures.py for supported nets.")
_arch_params = hydra.utils.instantiate(_arch_params)
model_name = _arch_params['model_name']
del _arch_params['model_name']
_arch_params = HpmStruct(**_arch_params)
_arch_params.override(**arch_params.to_dict())
model_name, arch_params, is_remote = _arch_params["model_name"], _arch_params, True
arch_params, is_remote = _arch_params, True
pretrained_weights = deci_client.get_model_weights(model_name)
return ARCHITECTURES[model_name], arch_params, pretrained_weights, is_remote


def instantiate_model(model_name: str, arch_params: dict, num_classes: int, pretrained_weights: str = None) -> torch.nn.Module:
def instantiate_model(model_name: str, arch_params: dict, num_classes: int,
pretrained_weights: str = None, download_required_code: bool = True) -> torch.nn.Module:
"""
Instantiates nn.Module according to architecture and arch_params, and handles pretrained weights and the required
module manipulation (i.e head replacement).
Expand All @@ -64,14 +74,16 @@ def instantiate_model(model_name: str, arch_params: dict, num_classes: int, pret
:param num_classes: Number of classes (defines the net's structure).
If None is given, will try to derrive from pretrained_weight's corresponding dataset.
:param pretrained_weights: Describe the dataset of the pretrained weights (for example "imagenent")
:param download_required_code: if model is not found in SG and is downloaded from a remote client, overriding this parameter with False
will prevent additional code from being downloaded. This affects only models from remote client.
:return: Instantiated model i.e torch.nn.Module, architecture_class (will be none when architecture is not str)
"""
if arch_params is None:
arch_params = {}
arch_params = core_utils.HpmStruct(**arch_params)

architecture_cls, arch_params, pretrained_weights, is_remote = get_architecture(model_name, arch_params, pretrained_weights)
architecture_cls, arch_params, pretrained_weights, is_remote = get_architecture(model_name, arch_params, pretrained_weights, download_required_code)

if not issubclass(architecture_cls, SgModule):
net = architecture_cls(**arch_params.to_dict(include_schema=False))
Expand Down Expand Up @@ -110,7 +122,7 @@ def instantiate_model(model_name: str, arch_params: dict, num_classes: int, pret

def get(model_name: str, arch_params: Optional[dict] = None, num_classes: int = None,
strict_load: StrictLoad = StrictLoad.NO_KEY_MATCHING, checkpoint_path: str = None,
pretrained_weights: str = None, load_backbone: bool = False) -> SgModule:
pretrained_weights: str = None, load_backbone: bool = False, download_required_code: bool = True) -> SgModule:
"""
:param model_name: Defines the model's architecture from models/ALL_ARCHITECTURES
:param arch_params: Architecture hyper parameters. e.g.: block, num_blocks, etc.
Expand All @@ -122,11 +134,13 @@ def get(model_name: str, arch_params: Optional[dict] = None, num_classes: int =
If provided, will automatically attempt to load the checkpoint.
:param pretrained_weights: Describe the dataset of the pretrained weights (for example "imagenent").
:param load_backbone: Load the provided checkpoint to model.backbone instead of model.
:param download_required_code: if model is not found in SG and is downloaded from a remote client, overriding this parameter with False
will prevent additional code from being downloaded. This affects only models from remote client.
NOTE: Passing pretrained_weights and checkpoint_path is ill-defined and will raise an error.
"""

net = instantiate_model(model_name, arch_params, num_classes, pretrained_weights)
net = instantiate_model(model_name, arch_params, num_classes, pretrained_weights, download_required_code)

if load_backbone and not checkpoint_path:
raise ValueError("Please set checkpoint_path when load_backbone=True")
Expand Down

0 comments on commit a73f9df

Please sign in to comment.