diff --git a/src/sparseml/yolov8/trainers.py b/src/sparseml/yolov8/trainers.py
index eeeb8576a7d..fe816605985 100644
--- a/src/sparseml/yolov8/trainers.py
+++ b/src/sparseml/yolov8/trainers.py
@@ -205,22 +205,13 @@ def setup_model(self):
         LOGGER.info("Loaded previous weights from sparseml checkpoint")
         return ckpt
 
-    def _modify_arch_for_quantization(self):
-        layer_map = {"Bottleneck": Bottleneck, "Conv": Conv}
-        for name, layer in self.model.named_modules():
-            cls_name = layer.__class__.__name__
-            if cls_name in layer_map:
-                submodule_path = name.split(".")
-                parent_module = _get_submodule(self.model, submodule_path[:-1])
-                setattr(parent_module, submodule_path[-1], layer_map[cls_name](layer))
-
     def _build_managers(self, ckpt: Optional[dict]):
         if self.args.recipe is not None:
             self.manager = ScheduledModifierManager.from_yaml(
                 self.args.recipe, recipe_variables=self.args.recipe_args
             )
             if self.manager.quantization_modifiers:
-                self._modify_arch_for_quantization()
+                _modify_arch_for_quantization(self.model)
         if ckpt is None:
             return
 
@@ -557,6 +548,8 @@ def _load(self, weights: str):
                 "Applying structure from sparseml checkpoint "
                 f"at epoch {self.ckpt['epoch']}"
             )
+            if manager.quantization_modifiers:
+                _modify_arch_for_quantization(self.model)
             manager.apply_structure(self.model, epoch=epoch)
         else:
             LOGGER.info("No recipe from in sparseml checkpoint")
@@ -775,3 +768,13 @@ def _get_submodule(module: torch.nn.Module, path: List[str]) -> torch.nn.Module:
     if not path:
         return module
     return _get_submodule(getattr(module, path[0]), path[1:])
+
+
+def _modify_arch_for_quantization(model):
+    layer_map = {"Bottleneck": Bottleneck, "Conv": Conv}
+    for name, layer in model.named_modules():
+        cls_name = layer.__class__.__name__
+        if cls_name in layer_map:
+            submodule_path = name.split(".")
+            parent_module = _get_submodule(model, submodule_path[:-1])
+            setattr(parent_module, submodule_path[-1], layer_map[cls_name](layer))
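The patch promotes `_modify_arch_for_quantization` from a trainer method to a module-level function so that the checkpoint-loading path in `_load` can also rewrite the architecture before `manager.apply_structure` runs, not just the trainer's `_build_managers`. For context, below is a minimal, self-contained sketch of the replace-on-parent pattern the helper uses: walk `named_modules()`, match layers by class name, resolve the parent via the dotted path, and `setattr` a wrapper in place. `WrappedConv` and `modify_arch` here are hypothetical stand-ins for illustration; the real `layer_map` points at SparseML's quantization-aware `Bottleneck` and `Conv` classes.

```python
from typing import List

import torch


class WrappedConv(torch.nn.Module):
    # Hypothetical wrapper standing in for SparseML's quantization-friendly
    # layer classes; like them, it is constructed from the original layer,
    # mirroring layer_map[cls_name](layer) in the patch.
    def __init__(self, wrapped: torch.nn.Module):
        super().__init__()
        self.wrapped = wrapped

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.wrapped(x)


def _get_submodule(module: torch.nn.Module, path: List[str]) -> torch.nn.Module:
    # Recursively resolve a dotted module path, e.g. "model.0.cv1" split
    # into ["model", "0", "cv1"]; an empty path returns the module itself.
    if not path:
        return module
    return _get_submodule(getattr(module, path[0]), path[1:])


def modify_arch(model: torch.nn.Module) -> None:
    # Mirror of _modify_arch_for_quantization: find target layers by class
    # name, then replace them on their parent so the model is rewritten in
    # place before quantization structure is applied. Toy map; the patch
    # uses {"Bottleneck": Bottleneck, "Conv": Conv}.
    layer_map = {"Conv2d": WrappedConv}
    for name, layer in model.named_modules():
        cls_name = layer.__class__.__name__
        if cls_name in layer_map:
            submodule_path = name.split(".")
            parent_module = _get_submodule(model, submodule_path[:-1])
            setattr(parent_module, submodule_path[-1], layer_map[cls_name](layer))


model = torch.nn.Sequential(torch.nn.Conv2d(3, 8, 3), torch.nn.ReLU())
modify_arch(model)
assert isinstance(model[0], WrappedConv)
```

Note that the rewrite must happen before `manager.apply_structure` (as the second hunk ensures): applying quantization structure expects the wrapped layer classes to already be in place, which is why `_load` now calls the helper whenever the checkpoint recipe contains quantization modifiers.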