Skip to content

Commit

Permalink
Added import of quantized modules so they get registered (#542)
Browse files Browse the repository at this point in the history
* Added import of quantized modules so they get registered

* Guarded exposure of quantized modules when there is no pytorch_quantization

* Added comment on why

* Changed logging to debug

Co-authored-by: Ofri Masad <[email protected]>
  • Loading branch information
spsancti and ofrimasad authored Dec 19, 2022
1 parent f7f0229 commit c455a52
Showing 1 changed file with 28 additions and 1 deletion.
29 changes: 28 additions & 1 deletion src/super_gradients/modules/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from .repvgg_block import RepVGGBlock
from .se_blocks import SEBlock, EffectiveSEBlock
from .skip_connections import Residual, SkipConnection, CrossModelSkipConnection, BackboneInternalSkipConnection, HeadInternalSkipConnection

from super_gradients.common.abstractions.abstract_logger import get_logger

__all__ = [
"ConvBNAct",
Expand All @@ -17,3 +17,30 @@
"BackboneInternalSkipConnection",
"HeadInternalSkipConnection",
]

logger = get_logger(__name__)
try:
# flake8 respects only the first occurence of __all__ defined in the module's root
from .quantization import QuantBottleneck # noqa: F401
from .quantization import QuantResidual # noqa: F401
from .quantization import QuantSkipConnection # noqa: F401
from .quantization import QuantCrossModelSkipConnection # noqa: F401
from .quantization import QuantBackboneInternalSkipConnection # noqa: F401
from .quantization import QuantHeadInternalSkipConnection # noqa: F401

quant_extensions = [
"QuantBottleneck",
"QuantResidual",
"QuantSkipConnection",
"QuantCrossModelSkipConnection",
"QuantBackboneInternalSkipConnection",
"QuantHeadInternalSkipConnection",
]

except (ImportError, NameError, ModuleNotFoundError) as import_err:
logger.debug(f"Failed to import pytorch_quantization: {import_err}")
quant_extensions = None


if quant_extensions is not None:
__all__.extend(quant_extensions)

0 comments on commit c455a52

Please sign in to comment.