Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

added the ability to download and import external code for models fro… #420

Merged
merged 11 commits into from
Oct 11, 2022
42 changes: 42 additions & 0 deletions src/super_gradients/common/plugins/deci_client.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,12 @@
import json
import sys
from zipfile import ZipFile

import hydra

import importlib.util

import os
import pkg_resources
from hydra.core.global_hydra import GlobalHydra
from omegaconf import DictConfig
Expand Down Expand Up @@ -81,3 +87,39 @@ def get_model_weights(self, model_name: str) -> str:
return None

return self._get_file(model_name=model_name, file_name=AutoNACFileName.WEIGHTS_PTH)

def download_and_load_model_additional_code(self, model_name: str, target_path: str, package_name: str = "deci_model_code") -> None:
"""
try to download code files for this model.
if found, code files will be placed in the target_path/package_name and imported dynamically
"""

file = self._get_file(model_name=model_name, file_name=AutoNACFileName.CODE_ZIP)

package_path = os.path.join(target_path, package_name)
if file is not None:
# crete the directory
os.makedirs(package_path, exist_ok=True)

# extract code files
with ZipFile(file) as zipfile:
zipfile.extractall(package_path)

# add an init file that imports all code files
with open(os.path.join(package_path, '__init__.py'), 'w') as init_file:
all_str = '\n\n__all__ = ['
for code_file in os.listdir(path=package_path):
if code_file.endswith(".py") and not code_file.startswith("__init__"):
init_file.write(f'import {code_file.replace(".py", "")}\n')
all_str += f'"{code_file.replace(".py", "")}", '

all_str += "]\n\n"
init_file.write(all_str)

# include in path and import
sys.path.insert(1, package_path)
importlib.import_module(package_name)

logger.info(f'*** IMPORTANT ***: files required for the model {model_name} were downloaded and added to your code in:\n{package_path}\n'
f'These files will be downloaded to the same location each time the model is fetched from the deci-client.\n'
f'you can override this by passing models.get(... download_required_code=False) and importing the files yourself')
4 changes: 3 additions & 1 deletion src/super_gradients/training/models/all_architectures.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,14 +79,15 @@ class ModelNames:
REGNETY600 = "regnetY600"
REGNETY800 = "regnetY800"
CUSTOM_REGNET = "custom_regnet"
CUSTOM_ANYNET = "custom_anynet"
NAS_REGNET = "nas_regnet"
YOLOX_N = "yolox_n"
YOLOX_T = "yolox_t"
YOLOX_S = "yolox_s"
YOLOX_M = "yolox_m"
YOLOX_L = "yolox_l"
YOLOX_X = "yolox_x"
CUSTOM_YOLO_X = "CustomYoloX"
CUSTOM_YOLO_X = "custom_yolox"
SSD_MOBILENET_V1 = "ssd_mobilenet_v1"
SSD_LITE_MOBILENET_V2 = "ssd_lite_mobilenet_v2"
REPVGG_A0 = "repvgg_a0"
Expand Down Expand Up @@ -220,6 +221,7 @@ class ModelNames:
ModelNames.PP_LITE_B_SEG: PPLiteSegB,
ModelNames.PP_LITE_B_SEG50: PPLiteSegB,
ModelNames.PP_LITE_B_SEG75: PPLiteSegB,
ModelNames.CUSTOM_ANYNET: regnet.CustomAnyNet,
}

KD_ARCHITECTURES = {
Expand Down
13 changes: 10 additions & 3 deletions src/super_gradients/training/models/model_factory.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from pathlib import Path
from typing import Optional

import hydra
Expand All @@ -20,14 +21,16 @@
logger = get_logger(__name__)


def instantiate_model(name: str, arch_params: dict, pretrained_weights: str = None) -> SgModule:
def instantiate_model(name: str, arch_params: dict, pretrained_weights: str = None, download_required_code: bool = True) -> SgModule:
"""
Instantiates nn.Module according to architecture and arch_params, and handles pretrained weights and the required
module manipulation (i.e head replacement).

:param name: Defines the model's architecture from models/ALL_ARCHITECTURES
:param arch_params: Architecture's parameters passed to models c'tor.
:param pretrained_weights: string describing the dataset of the pretrained weights (for example "imagenent")
:param download_required_code: if model is not found in SG and is downloaded from a remote client, overriding this parameter with False
will prevent additional code from being downloaded. This affects only models from remote client.

:return: instantiated model i.e torch.nn.Module, architecture_class (will be none when architecture is not str)

Expand All @@ -49,6 +52,8 @@ def instantiate_model(name: str, arch_params: dict, pretrained_weights: str = No
logger.info(f'Required model {name} not found in local SuperGradients. Trying to load a model from remote deci lab')
deci_client = DeciClient()
_arch_params = deci_client.get_model_arch_params(name)
if download_required_code:
deci_client.download_and_load_model_additional_code(name, Path.cwd())

if _arch_params is not None:
_arch_params = hydra.utils.instantiate(_arch_params)
Expand Down Expand Up @@ -78,7 +83,7 @@ def instantiate_model(name: str, arch_params: dict, pretrained_weights: str = No

def get(model_name: str, arch_params: Optional[dict] = None, num_classes: int = None,
strict_load: StrictLoad = StrictLoad.NO_KEY_MATCHING, checkpoint_path: str = None,
pretrained_weights: str = None, load_backbone: bool = False) -> SgModule:
pretrained_weights: str = None, load_backbone: bool = False, download_required_code: bool = True) -> SgModule:
"""
:param model_name: Defines the model's architecture from models/ALL_ARCHITECTURES
:param num_classes: Number of classes (defines the net's structure). If None is given, will try to derrive from
Expand All @@ -92,6 +97,8 @@ def get(model_name: str, arch_params: Optional[dict] = None, num_classes: int =
(ie: path/to/checkpoint.pth). If provided, will automatically attempt to
load the checkpoint.
:param pretrained_weights: a string describing the dataset of the pretrained weights (for example "imagenent").
:param download_required_code: if model is not found in SG and is downloaded from a remote client, overriding this parameter with False
will prevent additional code from being downloaded. This affects only models from remote client.

NOTE: Passing pretrained_weights and checkpoint_path is ill-defined and will raise an error.

Expand All @@ -111,7 +118,7 @@ def get(model_name: str, arch_params: Optional[dict] = None, num_classes: int =
arch_params["num_classes"] = num_classes

arch_params = core_utils.HpmStruct(**arch_params)
net = instantiate_model(model_name, arch_params, pretrained_weights)
net = instantiate_model(model_name, arch_params, pretrained_weights, download_required_code)

if load_backbone and not checkpoint_path:
raise ValueError("Please set checkpoint_path when load_backbone=True")
Expand Down