diff --git a/src/super_gradients/training/models/detection_models/yolo_base.py b/src/super_gradients/training/models/detection_models/yolo_base.py index 0549a81443..955bcd62a9 100755 --- a/src/super_gradients/training/models/detection_models/yolo_base.py +++ b/src/super_gradients/training/models/detection_models/yolo_base.py @@ -1,3 +1,4 @@ +import collections import math import warnings from typing import Union, Type, List, Tuple, Optional @@ -590,6 +591,15 @@ def forward(self, x): def load_state_dict(self, state_dict, strict=True): try: + keys_dropped_in_sg_320 = { + "stride", + "_head.anchors._stride", + "_head.anchors._anchors", + "_head.anchors._anchor_grid", + "_head._modules_list.14.stride", + } + state_dict = collections.OrderedDict([(k, v) for k, v in state_dict.items() if k not in keys_dropped_in_sg_320]) + super().load_state_dict(state_dict, strict) except RuntimeError as e: raise RuntimeError(