Added import of quantized modules so they get registered #542

Merged (8 commits, Dec 19, 2022)
src/super_gradients/modules/__init__.py (28 additions & 1 deletion)
@@ -3,7 +3,7 @@
from .repvgg_block import RepVGGBlock
from .se_blocks import SEBlock, EffectiveSEBlock
from .skip_connections import Residual, SkipConnection, CrossModelSkipConnection, BackboneInternalSkipConnection, HeadInternalSkipConnection

from super_gradients.common.abstractions.abstract_logger import get_logger

__all__ = [
"ConvBNAct",
@@ -17,3 +17,30 @@
"BackboneInternalSkipConnection",
"HeadInternalSkipConnection",
]

logger = get_logger(__name__)
try:
    # flake8 respects only the first occurrence of __all__ defined in the module's root
    from .quantization import QuantBottleneck  # noqa: F401
    from .quantization import QuantResidual  # noqa: F401
    from .quantization import QuantSkipConnection  # noqa: F401
    from .quantization import QuantCrossModelSkipConnection  # noqa: F401
    from .quantization import QuantBackboneInternalSkipConnection  # noqa: F401
    from .quantization import QuantHeadInternalSkipConnection  # noqa: F401

    quant_extensions = [
        "QuantBottleneck",
        "QuantResidual",
        "QuantSkipConnection",
        "QuantCrossModelSkipConnection",
        "QuantBackboneInternalSkipConnection",
        "QuantHeadInternalSkipConnection",
    ]

except (ImportError, NameError, ModuleNotFoundError) as import_err:
    logger.warning(f"Failed to import pytorch_quantization: {import_err}")
    quant_extensions = None


if quant_extensions is not None:
    __all__.extend(quant_extensions)
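
For context on why a bare import is enough here: per the PR title, the Quant* classes register themselves when their module is executed, so importing them purely for side effects (hence the `# noqa: F401` markers) is what makes them visible to the rest of the library. Below is a minimal standalone sketch of the same guarded optional-import pattern; the `optional_exports` list and the `"Residual"`/`"QuantResidual"` names are illustrative placeholders, not the actual super_gradients exports.

import logging

logger = logging.getLogger(__name__)

__all__ = ["Residual"]  # stand-in for the statically available exports

try:
    # Importing the optional dependency runs its module body, which is
    # where registration side effects happen; the import itself is the point.
    import pytorch_quantization  # noqa: F401

    optional_exports = ["QuantResidual"]  # hypothetical extension list
except (ImportError, NameError, ModuleNotFoundError) as err:
    # None (rather than an empty list) makes the "dependency missing"
    # state explicit to anyone inspecting the module later.
    logger.warning(f"Failed to import pytorch_quantization: {err}")
    optional_exports = None

if optional_exports is not None:
    __all__.extend(optional_exports)

With the change in this PR, `from super_gradients.modules import QuantSkipConnection` should work in environments where pytorch_quantization is installed, while environments without it get only a logged warning instead of a hard import error.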