From a331486e7eefac96046b74e7e5a7a99e38374f7b Mon Sep 17 00:00:00 2001 From: ofri masad Date: Sun, 9 Oct 2022 13:52:12 +0300 Subject: [PATCH 1/5] added the ability to download and import external code for models from the deci-lab fixed a bug in all_architectures.py (model names were missing or changed) --- .../common/plugins/deci_client.py | 42 +++++++++++++++++++ .../training/models/all_architectures.py | 4 +- .../training/models/model_factory.py | 13 ++++-- 3 files changed, 55 insertions(+), 4 deletions(-) diff --git a/src/super_gradients/common/plugins/deci_client.py b/src/super_gradients/common/plugins/deci_client.py index 77aa1f68e7..d2e4fcbcb2 100644 --- a/src/super_gradients/common/plugins/deci_client.py +++ b/src/super_gradients/common/plugins/deci_client.py @@ -1,6 +1,12 @@ import json +import sys +from zipfile import ZipFile import hydra + +import importlib.util + +import os import pkg_resources from hydra.core.global_hydra import GlobalHydra from omegaconf import DictConfig @@ -81,3 +87,39 @@ def get_model_weights(self, model_name: str) -> str: return None return self._get_file(model_name=model_name, file_name=AutoNACFileName.WEIGHTS_PTH) + + def download_and_load_model_additional_code(self, model_name: str, target_path: str, package_name: str = "deci_model_code") -> None: + """ + try to download code files for this model. 
+ if found, code files will be placed in the target_path/package_name and imported dynamically + """ + + file = self._get_file(model_name=model_name, file_name="code.zip") # TODO fix after lab client new version + + package_path = os.path.join(target_path, package_name) + if file is not None: + # create the directory + os.makedirs(package_path, exist_ok=True) + + # extract code files + with ZipFile(file) as zipfile: + zipfile.extractall(package_path) + + # add an init file that imports all code files + with open(os.path.join(package_path, '__init__.py'), 'w') as init_file: + all_str = '\n\n__all__ = [' + for code_file in os.listdir(path=package_path): + if code_file.endswith(".py") and not code_file.startswith("__init__"): + init_file.write(f'import {code_file.replace(".py", "")}\n') + all_str += f'"{code_file.replace(".py", "")}", ' + + all_str += "]\n\n" + init_file.write(all_str) + + # include in path and import + sys.path.insert(1, package_path) + importlib.import_module(package_name) + + logger.info(f'*** IMPORTANT ***: files required for the model {model_name} were downloaded and added to your code in:\n{package_path}\n' + f'These files will be downloaded to the same location each time the model is fetched from the deci-client.\n' + f'you can override this by passing models.get(... 
download_required_code=False) and importing the files yourself') diff --git a/src/super_gradients/training/models/all_architectures.py b/src/super_gradients/training/models/all_architectures.py index ac4444cdf0..cf2b6cfef6 100755 --- a/src/super_gradients/training/models/all_architectures.py +++ b/src/super_gradients/training/models/all_architectures.py @@ -79,6 +79,7 @@ class ModelNames: REGNETY600 = "regnetY600" REGNETY800 = "regnetY800" CUSTOM_REGNET = "custom_regnet" + CUSTOM_ANYNET = "custom_anynet" NAS_REGNET = "nas_regnet" YOLOX_N = "yolox_n" YOLOX_T = "yolox_t" @@ -86,7 +87,7 @@ class ModelNames: YOLOX_M = "yolox_m" YOLOX_L = "yolox_l" YOLOX_X = "yolox_x" - CUSTOM_YOLO_X = "CustomYoloX" + CUSTOM_YOLO_X = "custom_yolox" SSD_MOBILENET_V1 = "ssd_mobilenet_v1" SSD_LITE_MOBILENET_V2 = "ssd_lite_mobilenet_v2" REPVGG_A0 = "repvgg_a0" @@ -220,6 +221,7 @@ class ModelNames: ModelNames.PP_LITE_B_SEG: PPLiteSegB, ModelNames.PP_LITE_B_SEG50: PPLiteSegB, ModelNames.PP_LITE_B_SEG75: PPLiteSegB, + ModelNames.CUSTOM_ANYNET: regnet.CustomAnyNet, } KD_ARCHITECTURES = { diff --git a/src/super_gradients/training/models/model_factory.py b/src/super_gradients/training/models/model_factory.py index d67b6369c6..b8c9c68576 100644 --- a/src/super_gradients/training/models/model_factory.py +++ b/src/super_gradients/training/models/model_factory.py @@ -1,3 +1,4 @@ +from pathlib import Path from typing import Optional import hydra @@ -20,7 +21,7 @@ logger = get_logger(__name__) -def instantiate_model(name: str, arch_params: dict, pretrained_weights: str = None) -> SgModule: +def instantiate_model(name: str, arch_params: dict, pretrained_weights: str = None, download_required_code: bool = True) -> SgModule: """ Instantiates nn.Module according to architecture and arch_params, and handles pretrained weights and the required module manipulation (i.e head replacement). 
@@ -28,6 +29,8 @@ def instantiate_model(name: str, arch_params: dict, pretrained_weights: str = No :param name: Defines the model's architecture from models/ALL_ARCHITECTURES :param arch_params: Architecture's parameters passed to models c'tor. :param pretrained_weights: string describing the dataset of the pretrained weights (for example "imagenent") + :param download_required_code: if model is not found in SG and is downloaded from a remote client, overriding this parameter with False + will prevent additional code from being downloaded. This affects only models from remote client. :return: instantiated model i.e torch.nn.Module, architecture_class (will be none when architecture is not str) @@ -49,6 +52,8 @@ def instantiate_model(name: str, arch_params: dict, pretrained_weights: str = No logger.info(f'Required model {name} not found in local SuperGradients. Trying to load a model from remote deci lab') deci_client = DeciClient() _arch_params = deci_client.get_model_arch_params(name) + if download_required_code: + deci_client.download_and_load_model_additional_code(name, Path.cwd()) if _arch_params is not None: _arch_params = hydra.utils.instantiate(_arch_params) @@ -78,7 +83,7 @@ def instantiate_model(name: str, arch_params: dict, pretrained_weights: str = No def get(model_name: str, arch_params: Optional[dict] = None, num_classes: int = None, strict_load: StrictLoad = StrictLoad.NO_KEY_MATCHING, checkpoint_path: str = None, - pretrained_weights: str = None, load_backbone: bool = False) -> SgModule: + pretrained_weights: str = None, load_backbone: bool = False, download_required_code: bool = True) -> SgModule: """ :param model_name: Defines the model's architecture from models/ALL_ARCHITECTURES :param num_classes: Number of classes (defines the net's structure). If None is given, will try to derrive from @@ -92,6 +97,8 @@ def get(model_name: str, arch_params: Optional[dict] = None, num_classes: int = (ie: path/to/checkpoint.pth). 
If provided, will automatically attempt to load the checkpoint. :param pretrained_weights: a string describing the dataset of the pretrained weights (for example "imagenent"). + :param download_required_code: if model is not found in SG and is downloaded from a remote client, overriding this parameter with False + will prevent additional code from being downloaded. This affects only models from remote client. NOTE: Passing pretrained_weights and checkpoint_path is ill-defined and will raise an error. @@ -111,7 +118,7 @@ def get(model_name: str, arch_params: Optional[dict] = None, num_classes: int = arch_params["num_classes"] = num_classes arch_params = core_utils.HpmStruct(**arch_params) - net = instantiate_model(model_name, arch_params, pretrained_weights) + net = instantiate_model(model_name, arch_params, pretrained_weights, download_required_code) if load_backbone and not checkpoint_path: raise ValueError("Please set checkpoint_path when load_backbone=True") From c15da518f0e471463ac5dffb192545d9584c021d Mon Sep 17 00:00:00 2001 From: ofri masad Date: Sun, 9 Oct 2022 14:56:51 +0300 Subject: [PATCH 2/5] fix todo --- src/super_gradients/common/plugins/deci_client.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/super_gradients/common/plugins/deci_client.py b/src/super_gradients/common/plugins/deci_client.py index d2e4fcbcb2..12d5a09fb4 100644 --- a/src/super_gradients/common/plugins/deci_client.py +++ b/src/super_gradients/common/plugins/deci_client.py @@ -94,7 +94,7 @@ def download_and_load_model_additional_code(self, model_name: str, target_path: if found, code files will be placed in the target_path/package_name and imported dynamically """ - file = self._get_file(model_name=model_name, file_name="code.zip") # TODO fix after lab client new version + file = self._get_file(model_name=model_name, file_name=AutoNACFileName.CODE_ZIP) package_path = os.path.join(target_path, package_name) if file is not None: From 
370a63e3dab4f4b00cf754a47755428a531cb34f Mon Sep 17 00:00:00 2001 From: ofri masad Date: Tue, 11 Oct 2022 14:27:19 +0300 Subject: [PATCH 3/5] merge --- src/super_gradients/training/models/model_factory.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/src/super_gradients/training/models/model_factory.py b/src/super_gradients/training/models/model_factory.py index a4b6913561..62a587545a 100644 --- a/src/super_gradients/training/models/model_factory.py +++ b/src/super_gradients/training/models/model_factory.py @@ -56,12 +56,12 @@ def get_architecture(model_name: str, arch_params: HpmStruct, pretrained_weights _arch_params = hydra.utils.instantiate(_arch_params) _arch_params = HpmStruct(**_arch_params) _arch_params.override(**arch_params.to_dict()) - model_name, arch_params, is_remote = _arch_params["model_name"], _arch_params, True + model_name, arch_params, is_remote = _arch_params.model_name, _arch_params, True pretrained_weights = deci_client.get_model_weights(model_name) return ARCHITECTURES[model_name], arch_params, pretrained_weights, is_remote -def instantiate_model(model_name: str, arch_params: dict, num_classes: int, pretrained_weights: str = None) -> torch.nn.Module: +def instantiate_model(model_name: str, arch_params: dict, num_classes: int, pretrained_weights: str = None, download_required_code: bool = True) -> torch.nn.Module: """ Instantiates nn.Module according to architecture and arch_params, and handles pretrained weights and the required module manipulation (i.e head replacement). @@ -71,6 +71,8 @@ def instantiate_model(model_name: str, arch_params: dict, num_classes: int, pret :param num_classes: Number of classes (defines the net's structure). If None is given, will try to derrive from pretrained_weight's corresponding dataset. 
:param pretrained_weights: Describe the dataset of the pretrained weights (for example "imagenent") + :param download_required_code: if model is not found in SG and is downloaded from a remote client, overriding this parameter with False + will prevent additional code from being downloaded. This affects only models from remote client. :return: Instantiated model i.e torch.nn.Module, architecture_class (will be none when architecture is not str) """ @@ -78,7 +80,7 @@ def instantiate_model(model_name: str, arch_params: dict, num_classes: int, pret arch_params = {} arch_params = core_utils.HpmStruct(**arch_params) - architecture_cls, arch_params, pretrained_weights, is_remote = get_architecture(model_name, arch_params, pretrained_weights) + architecture_cls, arch_params, pretrained_weights, is_remote = get_architecture(model_name, arch_params, pretrained_weights, download_required_code) if not issubclass(architecture_cls, SgModule): net = architecture_cls(**arch_params.to_dict(include_schema=False)) From 8d9c232fb566700d47b3844106f6120cbe7f09ab Mon Sep 17 00:00:00 2001 From: ofri masad Date: Tue, 11 Oct 2022 14:45:18 +0300 Subject: [PATCH 4/5] fix redundant param --- src/super_gradients/training/models/model_factory.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/super_gradients/training/models/model_factory.py b/src/super_gradients/training/models/model_factory.py index 62a587545a..0abe8642f7 100644 --- a/src/super_gradients/training/models/model_factory.py +++ b/src/super_gradients/training/models/model_factory.py @@ -54,9 +54,11 @@ def get_architecture(model_name: str, arch_params: HpmStruct, pretrained_weights if _arch_params is None: raise ValueError("Unsupported model name " + str(model_name) + ", see docs or all_architectures.py for supported nets.") _arch_params = hydra.utils.instantiate(_arch_params) + model_name = _arch_params['model_name'] + del _arch_params['model_name'] _arch_params = HpmStruct(**_arch_params) 
_arch_params.override(**arch_params.to_dict()) - model_name, arch_params, is_remote = _arch_params.model_name, _arch_params, True + arch_params, is_remote = _arch_params, True pretrained_weights = deci_client.get_model_weights(model_name) return ARCHITECTURES[model_name], arch_params, pretrained_weights, is_remote From 0d47f155816c41f899ec84f80f954d49bb7e9ed2 Mon Sep 17 00:00:00 2001 From: ofri masad Date: Tue, 11 Oct 2022 14:46:44 +0300 Subject: [PATCH 5/5] lint --- src/super_gradients/training/models/model_factory.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/super_gradients/training/models/model_factory.py b/src/super_gradients/training/models/model_factory.py index 0abe8642f7..e49434e234 100644 --- a/src/super_gradients/training/models/model_factory.py +++ b/src/super_gradients/training/models/model_factory.py @@ -1,5 +1,4 @@ from pathlib import Path -from typing import Optional from typing import Optional, Tuple from typing import Type @@ -25,7 +24,8 @@ logger = get_logger(__name__) -def get_architecture(model_name: str, arch_params: HpmStruct, pretrained_weights: str, download_required_code: bool = True) -> Tuple[Type[torch.nn.Module], HpmStruct, str, bool]: +def get_architecture(model_name: str, arch_params: HpmStruct, pretrained_weights: str, + download_required_code: bool = True) -> Tuple[Type[torch.nn.Module], HpmStruct, str, bool]: """ Get the corresponding architecture class. 
@@ -63,7 +63,8 @@ def get_architecture(model_name: str, arch_params: HpmStruct, pretrained_weights return ARCHITECTURES[model_name], arch_params, pretrained_weights, is_remote -def instantiate_model(model_name: str, arch_params: dict, num_classes: int, pretrained_weights: str = None, download_required_code: bool = True) -> torch.nn.Module: +def instantiate_model(model_name: str, arch_params: dict, num_classes: int, + pretrained_weights: str = None, download_required_code: bool = True) -> torch.nn.Module: """ Instantiates nn.Module according to architecture and arch_params, and handles pretrained weights and the required module manipulation (i.e head replacement).